##################################
#                                #
# Last modified 03/27/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

def normalize(p):
    """Return p transformed by log2(n + 1) and scaled to sum to 1.

    Every entry n of p is first mapped to log2(n + 1); each transformed
    entry is then divided by the total of all transformed entries.

    p -- sequence of numbers.  Entries must be greater than -1, otherwise
         math.log raises ValueError.

    Returns a new list with one value per entry of p ([] for an empty p).
    Raises ZeroDivisionError when the transformed entries sum to zero,
    e.g. when every entry of p is 0.
    """
    logged = [math.log(value + 1, 2) for value in p]

    # Accumulate left to right so the floating-point total is built in
    # the same order as the entries appear in p.
    total = 0
    for term in logged:
        total = total + term

    return [term / total for term in logged]

def nolognormalize(p):
    """Return p shifted by a pseudocount of 1 and scaled to sum to 1.

    Every entry n of p is mapped to n + 1 (no log transform) and then
    divided by the total of all shifted entries.

    p -- sequence of numbers.

    Returns a new list of floats with one value per entry of p ([] for
    an empty p).  Raises ZeroDivisionError when the shifted entries sum
    to zero.
    """
    shifted = [value + 1 for value in p]

    # Start from a float so the divisions below are true divisions even
    # when p holds integers.
    total = 0.0
    for term in shifted:
        total += term

    return [term / total for term in shifted]

def H(p):
    """Return the entropy of p in bits: -sum(pi * log2(pi)).

    Entries equal to 0 are skipped, so they contribute nothing to the
    result.

    p -- sequence of non-negative numbers; a negative entry makes
         math.log raise ValueError.

    Returns 0 (an int) when p is empty or contains only zeros, otherwise
    a float.
    """
    entropy = 0
    for prob in p:
        if prob == 0:
            # log(0) is undefined; zero entries are simply left out.
            continue
        entropy = entropy - prob * math.log(prob, 2)

    return entropy

def JSsp(p1,p2):
    """Return the Jensen-Shannon specificity score 1 - sqrt(JS(p1, p2)).

    JS is computed as H(m) - (H(p1) + H(p2)) / 2, where m is the
    element-wise mean of p1 and p2 and H is the entropy in bits.

    p1, p2 -- sequences of non-negative numbers.  Only the first len(p1)
              entries of p2 enter the mean; a shorter p2 raises
              IndexError.

    Returns a float, or the string 'math error' when JS comes out
    genuinely negative (both vectors are then printed to stdout so the
    offending input can be inspected).
    """
    # For equal-length non-negative vectors JS cannot be negative, so a
    # negative value this close to zero is floating-point rounding noise
    # and is treated as an exact zero instead of an error.
    rounding_tolerance = 1e-9

    mixture = [(p1[i] + p2[i]) / 2.0 for i in range(len(p1))]
    JS = H(mixture) - (H(p1) + H(p2)) / 2.0

    if -rounding_tolerance < JS < 0:
        JS = 0.0

    if JS < 0:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(p1)
        print(p2)
        return 'math error'

    return 1 - math.sqrt(JS)

def _parse_field_ids(spec):
    """Expand a field spec such as '2,5:7' into the sorted list [2, 5, 6, 7]."""
    ids = []
    for item in spec.split(','):
        if ':' in item:
            bounds = item.split(':')
            # from:to ranges are inclusive on both ends.
            ids.extend(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
            ids.append(int(item))
    ids.sort()
    return ids


def _expression_vector(columns, fieldIDs, skipFail, cap):
    """Collect the float values of the selected columns, optionally capped."""
    values = []
    for ID in fieldIDs:
        if skipFail and columns[ID] == 'FAIL':
            continue
        value = float(columns[ID])
        if cap is not None:
            value = min(value, cap)
        values.append(value)
    return values


def _max_specificity(p):
    """Return the largest JSsp score of p against every one-hot vector."""
    best = 0
    size = len(p)
    for hot in range(size):
        unit = [0] * size
        unit[hot] = 1
        score = JSsp(p, unit)
        # JSsp signals a failed computation with the string 'math error'.
        if score != 'math error' and score > best:
            best = score
    return best


def run():
    """Command-line entry point: append a JSsp column to a tab-delimited table.

    Arguments are read from sys.argv:

        table fields outfilename [-nolognormalize] [-cap maxvalue]
                                 [-cufflinksStatus minNumberSamples]

    table       -- input file; lines starting with '#' are copied as
                   header lines with a 'JSsp' column name appended.
    fields      -- 0-based column indices, comma separated and/or as
                   inclusive from:to ranges.
    outfilename -- output file; each kept line is written with its score
                   appended as a final tab-separated column.

    -nolognormalize   use nolognormalize() instead of normalize().
    -cap              clip every value to at most maxvalue.
    -cufflinksStatus  ignore selected fields equal to 'FAIL' and skip
                      rows with fewer than minNumberSamples values left.

    Rows with no usable values, an all-zero maximum, or values that
    cannot be normalized are not written to the output.  Prints the
    usage text and exits with status 1 when a required argument is
    missing.
    """
    # Three positional arguments are required, so argv needs four entries.
    if len(sys.argv) < 4:
        print('usage: python %s table fields outfilename [-nolognormalize] [-cap maxvalue] [-cufflinksStatus minNumberSamples]' % sys.argv[0])
        print('       fields comma separated or in from:to (including) format')
        sys.exit(1)

    table = sys.argv[1]
    outfilename = sys.argv[3]
    fieldIDs = _parse_field_ids(sys.argv[2])

    doCufflinksStatus = '-cufflinksStatus' in sys.argv
    minSamples = 0
    if doCufflinksStatus:
        minSamples = int(sys.argv[sys.argv.index('-cufflinksStatus')+1])
        print('will ignore FAIL values and skip genes with fewer than %s OK values' % minSamples)

    doLogNormalize = '-nolognormalize' not in sys.argv

    # cap stays None when no capping was requested.
    cap = None
    if '-cap' in sys.argv:
        cap = float(sys.argv[sys.argv.index('-cap')+1])

    print(fieldIDs)

    # The with blocks close both files even if a row fails to parse.
    with open(outfilename, 'w') as outfile:
        with open(table) as listoflines:
            for line in listoflines:
                if line.startswith('#'):
                    outfile.write(line.strip() + '\tJSsp\n')
                    continue
                stripped = line.strip()
                if not stripped:
                    # A blank line (e.g. at the end of the file) has no data.
                    continue
                columns = stripped.split('\t')
                p = _expression_vector(columns, fieldIDs, doCufflinksStatus, cap)
                if len(p) == 0:
                    continue
                if doCufflinksStatus and len(p) < minSamples:
                    continue
                if max(p) == 0:
                    continue
                if doLogNormalize:
                    try:
                        p = normalize(p)
                    except (ValueError, ZeroDivisionError):
                        # log of a value <= 0, or a zero total.
                        print('can not normalize vector %s' % (p,))
                        continue
                else:
                    p = nolognormalize(p)
                outfile.write(stripped + '\t' + str(_max_specificity(p)) + '\n')

# Script entry point.  The call is unconditional, so it also executes when
# this file is imported as a module, not only when it is run directly.
run()

