##################################
#                                #
# Last modified 7/17/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import math
import string
import sys

from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *

# Strand codes accepted from the annotation sources (cistematic uses F/R,
# GENCODE lists use +/-).
_FORWARD_STRANDS = ('F', '+')
_REVERSE_STRANDS = ('R', '-')


def _usage_and_exit(progname):
    """Print the command-line usage and exit with status 1."""
    print('usage: python %s genome outputfilename [-GENCODE GENCODE.genelist] '
          '[-AdditionalAnnotation filename startField] [RNA_name::RPKMfilename ...] '
          '[ChIP_name::ERANGEhtsfilename ...]' % progname)
    print('Note: the [-AdditionalAnnotation] option has to be used with a file with the '
          'following format: <name/ID> <chromosome> <start> <stop> starting with the startField')
    print('the [-AdditionalAnnotation] option will take the middle of the regions as a TSS')
    sys.exit(1)


def _option_values(argv, flag, count):
    """Return the `count` arguments that follow `flag` in argv, or None if `flag` is absent.

    Exits via the usage message when the flag is present but too few values follow it.
    """
    if flag not in argv:
        return None
    first = argv.index(flag) + 1
    values = argv[first:first + count]
    if len(values) < count:
        _usage_and_exit(argv[0])
    return values


def _parse_datasets(args):
    """Split `RNA_<name>::<file>` and `ChIP_<name>::<file>` arguments.

    Returns (rna, chip): two lists of (name, filename) pairs in command-line
    order. Arguments without '::' are ignored; '::' arguments with any other
    prefix are reported and ignored.
    """
    rna = []
    chip = []
    for arg in args:
        if '::' not in arg:
            continue
        label, filename = arg.split('::', 1)
        if label.startswith('RNA_'):
            rna.append((label[len('RNA_'):], filename))
        elif label.startswith('ChIP_'):
            chip.append((label[len('ChIP_'):], filename))
        else:
            print('ignoring dataset with unknown prefix: %s' % arg)
    return rna, chip


def _tss_for(strand, rmin, rmax):
    """Return the TSS of the span [rmin, rmax] on `strand`, or None for an unknown strand."""
    if strand in _FORWARD_STRANDS:
        return rmin
    if strand in _REVERSE_STRANDS:
        return rmax
    return None


def _add_entry(allDataDict, chrom, ID, name, strand, rmin, rmax, tss, strandLabel=None):
    """Store one annotation entry under allDataDict[chrom][ID].

    `strand` decides the sign of peak distances; `strandLabel` (defaults to
    `strand`) is the text written to the Orientation column.
    """
    if strandLabel is None:
        strandLabel = strand
    # NOTE(review): the header names these columns Start/End, but the values
    # are written max first, then min. Kept as-is so existing downstream
    # parsers keep working -- confirm the intended order.
    outline = '\t'.join([str(ID), name, chrom, str(rmax), str(rmin),
                         strandLabel, str(rmax - rmin)])
    allDataDict.setdefault(chrom, {})[ID] = {
        'TSS': tss,
        'orientation': strand,
        'GeneName': name,
        'length': rmax - rmin,
        'outline': outline,
    }


def _load_gencode(filename, allDataDict):
    """Load a GENCODE gene list (tab-separated name, chromosome, start, stop, strand).

    Entries whose name contains 'rRNA' or '7SK' are skipped, as are blank
    lines and entries with an unrecognized strand.
    """
    with open(filename) as infile:
        for lineNumber, line in enumerate(infile, 1):
            if lineNumber % 1000 == 0:
                print(lineNumber)
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            name = fields[0]
            if 'rRNA' in name or '7SK' in name:
                continue
            chrom = fields[1]
            rmin = int(fields[2])
            rmax = int(fields[3])
            strand = fields[4]
            tss = _tss_for(strand, rmin, rmax)
            if tss is None:
                print('skipping %s: unrecognized orientation %s' % (name, strand))
                continue
            _add_entry(allDataDict, chrom, name, name, strand, rmin, rmax, tss)


def _load_gene_models(genome, allDataDict):
    """Load one span per gene (min feature start .. max feature stop) from cistematic."""
    hg = Genome(genome)
    idb = geneinfoDB()
    featDict = hg.getallGeneFeatures()
    for count, geneID in enumerate(featDict):
        if count % 1000 == 0:
            print('%d genes processed' % count)
        info = idb.getGeneInfo((genome, geneID))
        if info:
            name = info[0]
        else:
            name = 'LOC' + str(geneID)
        features = featDict[geneID]
        rmin = min([int(feature[2]) for feature in features])
        rmax = max([int(feature[3]) for feature in features])
        chrom = 'chr' + str(features[0][1])
        strand = str(features[0][4])
        tss = _tss_for(strand, rmin, rmax)
        if tss is None:
            print('skipping %s: unrecognized orientation %s' % (geneID, strand))
            continue
        _add_entry(allDataDict, chrom, str(geneID), name, strand, rmin, rmax, tss)


def _load_additional_annotation(filename, fieldID, allDataDict):
    """Load extra regions whose ID, chromosome, start, stop columns begin at `fieldID`.

    The midpoint of each region is used as its TSS. '#' comment lines and
    blank lines are skipped.
    """
    with open(filename) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            ID = fields[fieldID]
            chrom = fields[fieldID + 1]
            start = int(fields[fieldID + 2])
            stop = int(fields[fieldID + 3])
            # Regions carry no strand: distances are signed as on the forward
            # strand while the Orientation column reads 'unknown'.
            _add_entry(allDataDict, chrom, ID, ID, 'F', start, stop,
                       int((start + stop) / 2.0), strandLabel='unknown')


def _append_rna_column(filename, allDataDict):
    """Append one RNA column: field 3 of `filename`, matched on field 0 (the entry ID).

    The first line for an ID wins. Entries absent from the file get an empty
    field so that later columns stay aligned.
    """
    values = {}
    with open(filename) as infile:
        for line in infile:
            if line.startswith('#'):
                continue
            # Strip the line ending so a last-column value carries no newline.
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) > 3 and fields[0] not in values:
                values[fields[0]] = fields[3]
    for genes in allDataDict.values():
        for ID, entry in genes.items():
            entry['outline'] += '\t' + values.get(ID, '')


def _nearest(positions, x):
    """Return the value in sorted, non-empty `positions` closest to x (the lower one on ties)."""
    i = bisect.bisect_left(positions, x)
    if i == 0:
        return positions[0]
    if i == len(positions):
        return positions[-1]
    before = positions[i - 1]
    after = positions[i]
    if x - before <= after - x:
        return before
    return after


def _site_columns(entry, sites):
    """Return the tab-prefixed count, distance-list and height-list fields for one entry.

    Distances are peak minus TSS on the forward strand and TSS minus peak on
    the reverse strand. Both lists are comma-terminated.
    """
    if not sites:
        return '\t0\t\t'
    tss = entry['TSS']
    strand = entry['orientation']
    distances = []
    for peak, _height in sites:
        if strand in _FORWARD_STRANDS:
            distances.append(str(peak - tss) + ',')
        elif strand in _REVERSE_STRANDS:
            distances.append(str(tss - peak) + ',')
    heights = [str(height) + ',' for _peak, height in sites]
    return '\t%d\t%s\t%s' % (len(sites), ''.join(distances), ''.join(heights))


def _append_chip_columns(filename, allDataDict, positions):
    """Assign each peak to its nearest TSS and append that TSS's peaks to every entry.

    Peak lines are tab-separated, with the chromosome in field 1, the peak
    position in field 9 and the peak height in field 10. Peaks on chromosomes
    whose name contains 'random', or that have no annotated TSS, are skipped.
    Every entry sharing a TSS reports the same peaks.
    """
    sitesByTSS = {}
    skipped = 0
    with open(filename) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.split('\t')
            chrom = fields[1]
            if 'random' in chrom:
                continue
            if not positions.get(chrom):
                skipped += 1
                continue
            peak = int(fields[9])
            height = float(fields[10])
            tss = _nearest(positions[chrom], peak)
            sitesByTSS.setdefault((chrom, tss), []).append((peak, height))
    if skipped:
        print('skipped %d peaks on chromosomes without annotations' % skipped)
    for chrom, genes in allDataDict.items():
        for entry in genes.values():
            sites = sitesByTSS.get((chrom, entry['TSS']), [])
            entry['outline'] += _site_columns(entry, sites)


def run():
    """Command-line entry point: annotate genes/regions with RNA values and nearby ChIP peaks.

    Reads its arguments from sys.argv (see _usage_and_exit). Writes one
    tab-separated row per annotation entry to `outputfilename`, sorted by
    chromosome and then by ID. The row has one column per RNA dataset and
    three columns per ChIP dataset (peak count, distances to TSS, heights).
    """
    argv = sys.argv
    if len(argv) < 3:
        _usage_and_exit(argv[0])

    genome = argv[1]
    outfilename = argv[2]
    RNA, ChIP = _parse_datasets(argv[3:])

    addAn = _option_values(argv, '-AdditionalAnnotation', 2)
    if addAn is not None:
        addAnFilename = addAn[0]
        addAnFieldID = int(addAn[1])
    gencode = _option_values(argv, '-GENCODE', 1)
    if gencode is not None:
        print('will use GENCODE annotation')

    # Opened up front so an unwritable path fails before the slow loading steps.
    with open(outfilename, 'w') as outfile:
        allDataDict = {}
        if gencode is not None:
            _load_gencode(gencode[0], allDataDict)
        else:
            _load_gene_models(genome, allDataDict)
        if addAn is not None:
            print('doAddAn')
            _load_additional_annotation(addAnFilename, addAnFieldID, allDataDict)

        header = '#GeneID\tGeneName\tChromosome\tStart\tEnd\tOrientation\tLength'
        for dataset, filename in RNA:
            header += '\tRNA ' + dataset
            print('processing RNA %s %s' % (dataset, filename))
            _append_rna_column(filename, allDataDict)

        # Sorted TSS coordinates per chromosome, built from the final entries
        # so that IDs overwritten during loading leave no stale TSS behind.
        positions = dict((chrom, sorted(set(e['TSS'] for e in genes.values())))
                         for chrom, genes in allDataDict.items())
        for dataset, filename in ChIP:
            header += '\t' + dataset + '\t\t'
            print('processing ChIP %s %s' % (dataset, filename))
            _append_chip_columns(filename, allDataDict, positions)

        outfile.write(header.strip() + '\n')
        for chrom in sorted(allDataDict):
            genes = allDataDict[chrom]
            for ID in sorted(genes):
                outfile.write(genes[ID]['outline'].strip() + '\n')
   
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
