##################################
#                                #
# Last modified 03/26/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import math
import string
import sys
from sets import Set

def _parse_bin_edges(spec):
    """Turn a comma-separated list of numbers into sorted histogram bin edges.

    The raw tokens are echoed to stdout.  0.0 is always added as an edge and
    duplicate edges are collapsed.

    spec -- string such as '1,2,5'
    Returns a sorted list of unique floats.
    Raises ValueError if a token cannot be parsed as a float.
    """
    tokens = spec.split(',')
    print(tokens)
    edges = set(float(token) for token in tokens)
    edges.add(0.0)
    return sorted(edges)


def _bin_label(value, edges):
    """Return the histogram key under which *value* is counted.

    value -- float to classify
    edges -- sorted list of bin edges (lower bounds of the bins)

    NaN values are keyed by their string form so they get a row/column of
    their own.  Every other value maps to the largest edge that is <= value;
    values below the smallest edge are clamped into the first bin and values
    at or above the largest edge fall into the last one.
    """
    if math.isnan(value):
        return str(value)
    # bisect_right gives the number of edges <= value; step back one to get
    # the lower bound of the bin, clamping underflow to index 0.
    return edges[max(bisect.bisect_right(edges, value) - 1, 0)]


def run():
    """Build a 2D histogram from two columns of a delimited text file.

    Command line (read from sys.argv):
        datafilename fieldID1 fieldID2 bins1 bins2 outfilename
        [-sumFields ID1,ID2,...] [-splitby string] [-FAIL]

    fieldID1 is binned with bins1 and written as the rows of the output
    table; fieldID2 (or, with -sumFields, the sum of the listed columns) is
    binned with bins2 and written as the columns.  Lines starting with '#'
    and blank lines are ignored.  With -FAIL, lines whose fieldID1 column is
    the literal string 'FAIL' are skipped as well.

    The output is a tab-separated table: a header line beginning with '#'
    listing the column bins, then one line per row bin with its counts.  NaN
    values are counted in an extra 'nan' row/column appended after the
    numeric bins.

    Exits with status 1 after printing usage when too few arguments are
    given.  Raises ValueError/IndexError on unparsable or too-short data
    lines.
    """
    # outfilename is read from sys.argv[6], so the script name plus six
    # positional arguments are required.
    if len(sys.argv) < 7:
        print('usage: python %s  datafilename datafileFieldID1 datafileFieldsID2 [first field: (0),number1,number2,number3...,numberN] [second field: (0),number1,number2,number3...,numberN] outfilename [-sumFields FieldID1,...] [-splitby string] [-FAIL]' % sys.argv[0])
        print('Note: the script will output the dataset in datafileFieldID1 on the Y and the one in datafileFieldID2 on the X axis')
        print('      -sumFields option refers to the second vector')
        sys.exit(1)

    datafilename = sys.argv[1]
    fieldID1 = int(sys.argv[2])
    fieldID2 = int(sys.argv[3])
    outfilename = sys.argv[6]
    binList1 = _parse_bin_edges(sys.argv[4])
    binList2 = _parse_bin_edges(sys.argv[5])

    splitBy = '\t'
    if '-splitby' in sys.argv:
        splitBy = sys.argv[sys.argv.index('-splitby') + 1]

    # None means "use fieldID2"; otherwise the listed columns are summed.
    sumFieldIDs = None
    if '-sumFields' in sys.argv:
        sumSpec = sys.argv[sys.argv.index('-sumFields') + 1]
        sumFieldIDs = [int(f) for f in sumSpec.split(',')]
        print('will sum fields %s' % (sumFieldIDs,))

    doFAIL = '-FAIL' in sys.argv

    # HistDict[rowKey][colKey] -> count.  Only non-zero cells are stored;
    # missing cells are written out as 0.
    HistDict = {}
    for edge in binList1:
        HistDict[edge] = {}
    knownCols = set(binList2)
    extraRows = []  # non-numeric row keys, in order of first appearance
    extraCols = []  # non-numeric column keys, in order of first appearance

    # Rows are binned as they are read, so memory use does not grow with the
    # size of the input.
    with open(datafilename) as infile:
        for lineNumber, line in enumerate(infile, 1):
            if lineNumber % 1000000 == 0:
                print('%d lines processed' % lineNumber)
            if line.startswith('#'):
                continue
            # NOTE: strip() also removes leading/trailing separators, so
            # empty first/last fields are dropped before splitting.
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split(splitBy)
            if doFAIL and fields[fieldID1] == 'FAIL':
                continue

            v1 = float(fields[fieldID1])
            if sumFieldIDs is None:
                v2 = float(fields[fieldID2])
            else:
                v2 = sum(float(fields[i]) for i in sumFieldIDs)

            rowKey = _bin_label(v1, binList1)
            colKey = _bin_label(v2, binList2)
            if rowKey not in HistDict:
                HistDict[rowKey] = {}
                extraRows.append(rowKey)
            if colKey not in knownCols:
                knownCols.add(colKey)
                extraCols.append(colKey)
            rowCounts = HistDict[rowKey]
            rowCounts[colKey] = rowCounts.get(colKey, 0) + 1

    colKeys = binList2 + extraCols
    # The output file is opened only after the input was read successfully,
    # so a failed run does not clobber an existing result.
    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(['#'] + [str(c) for c in colKeys]) + '\n')
        for rowKey in binList1 + extraRows:
            rowCounts = HistDict[rowKey]
            cells = [str(rowKey)]
            cells.extend(str(rowCounts.get(c, 0)) for c in colKeys)
            outfile.write('\t'.join(cells) + '\n')
        
# Run only when executed as a script, so that importing this module does not
# parse sys.argv or touch any files.
if __name__ == '__main__':
    run()

