##################################
#                                #
# Last modified 03/06/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import string
import sys
import re

def countCAGSTG(sequence):
    """Return the number of CASSTG E-boxes (CAGGTG, CAGCTG, CACCTG) in sequence.

    Each of the three motifs is counted with non-overlapping, case-sensitive
    exact matching (str.count semantics) and the three tallies are summed.
    """
    motifs = ('CAGGTG', 'CAGCTG', 'CACCTG')
    return sum(sequence.count(motif) for motif in motifs)

def countEb5Eb(sequence):
    """Count pairs of CASSTG E-boxes separated by exactly a 5 bp spacer.

    A hit is a 17 bp window laid out as
        CA S S TG N N N N N CA S S TG
    where the SS dinucleotide of each E-box is one of GC, GG or CC. Every
    start offset is tested, so overlapping windows are each counted.
    Matching is case-sensitive.
    """
    core = ('GC', 'GG', 'CC')
    count = 0
    # Valid starts satisfy i + 17 <= len(sequence); the previous bound
    # (len - 17, exclusive) skipped a site ending exactly at the sequence end.
    for i in range(len(sequence) - 16):
        window = sequence[i:i + 17]
        if (window[0:2] == 'CA' and window[4:6] == 'TG'
                and window[11:13] == 'CA' and window[15:17] == 'TG'
                and window[2:4] in core and window[13:15] in core):
            count += 1
    return count

def countCAGSTGconservation(sequence):
    """Locate CASSTG E-boxes (CAGGTG / CAGCTG / CACCTG) in sequence.

    Returns a tuple (count, positions) where positions lists the 0-based
    start offset of each non-overlapping, case-sensitive match in scan
    order and count is len(positions).
    """
    ebox = re.compile('CA(GG|GC|CC)TG')
    starts = [match.start() for match in ebox.finditer(sequence)]
    return (len(starts), starts)

def run():

    if len(sys.argv) < 3:
        print 'usage: python %s genome.fa regionfile outfilename [-PhastCons PhastConsScoresDirectory FileSuffix threshold] [-conserved multiple-alignment-directory species_code] [-startField fieldID] [-skiprandom]' % sys.argv[0]
        sys.exit(1)

    genome = sys.argv[1]
    regionfilename = sys.argv[2]
    outputfilename = sys.argv[3]
    doConserved=False
    doPhastCons=False
    if '-PhastCons' in sys.argv:
        doPhastCons=True
        PCDir=sys.argv[sys.argv.index('-PhastCons') + 1]
        FileSuffix=sys.argv[sys.argv.index('-PhastCons') + 2]
        PCthreshold=float(sys.argv[sys.argv.index('-PhastCons') + 3])
    if '-conserved' in sys.argv:
        doConserved=True
        MFADir=sys.argv[sys.argv.index('-conserved') + 1]
        SpeciesCode=sys.argv[sys.argv.index('-conserved') + 2]
    startField=1
    if '-startField' in sys.argv:
        startField=int(sys.argv[sys.argv.index('-startField') + 1])
    if doPhastCons and doConserved:
        print 'choose either PhastCons or multiple alignment conservation' % sys.argv[0]
        sys.exit(1)

    doSkipRandom=False
    if '-skiprandom' in sys.argv:
        doSkipRandom=True
        print 'Skipping random chromosomes'

    GenomeDict={}

    sequence=''
    inputdatafile = open(fasta)
    for line in inputdatafile:
        if line[0]=='>':
            if sequence != '':
                GenomeDict[chr] = ''.join(sequence)
            chr = line.strip().split('>')[1]
            print chr
            sequence=[]
            Keep=False
            continue
        else:
            sequence.append(line.strip())
    GenomeDict[chr] = ''.join(sequence)

    outfile = open(outputfilename, 'w')
    
    listoflines = open(regionfilename)
    lineslist = listoflines.readlines()
    if not doConserved and not doPhastCons:
        for line in lineslist:
            if line[0]=='#':
                continue
            fields = line.split('\n')[0].split('\t')
            chromosome = fields[startField]
            if doSkipRandom and chromosome.count('random')>0:
                continue
            start = int(fields[startField+1])
            stop = int(fields[startField+2])
            sequence=GenomeDict[chromosome][start:stop]
            count=countCAGSTG(sequence)
            newline=line.split('\n')[0]+'\t'+str(count)+'\n'
            outfile.write(newline)
    print 'fninshed parsing region coordinates'
    if doPhastCons:
        conservationDict={}
        posList=[]
        PCDict={}
        i=0
        lefttogo=0
        for line in lineslist:
            if line[0]=='#':
                continue
            fields = line.split('\n')[0].split('\t')
            chromosome = fields[startField]
            if doSkipRandom and chromosome.count('random')>0:
                continue
            start = int(fields[startField+1])
            stop = int(fields[startField+2])
            if chromosome not in conservationDict.keys():
                conservationDict[chromosome]={}
            regioni='region'+str(i)
            conservationDict[chromosome][regioni]={}
            conservationDict[chromosome][regioni]['start']=start
            conservationDict[chromosome][regioni]['stop']=stop
            conservationDict[chromosome][regioni]['OUTPUTLINE']=line.split('\n')[0]
            if PCDict.has_key(chromosome):
                pass
            else:
                print chromosome
                PCDict[chromosome]={}
            for x in range(start,stop):
                PCDict[chromosome][x]=0.0
            i+=1
            lefttogo+=1
        headerline=''
        for i in range(len(fields)):
            headerline=headerline+str(i)+'\t'
        headerline=headerline+'E-boxes'+'\t'+'conserved E-boxes'+'\t'+'fraction of region conserved'+'\t'+'average region conservation'+'\n'
        outfile.write(headerline)
        for chromosome in conservationDict.keys():
            PCfile=PCDir+'/'+chromosome+'.'+FileSuffix
            PCfile=open(PCfile)
            PClines=PCfile.readlines()
            print 'len(PClines)', len(PClines)
            u=0
            print 'len(conservationDict[chromosome])', len(conservationDict[chromosome])
            i=len(PClines)
#            m = re.compile('^fixed.*start=(?p<start>\d+).*$')
            for line in PClines:
#                mo = m.search(line)
#                if mo:
#                    InitialPos = mo.group('start')
#                    u=0
#                    print line
                if line[0]=='f':
                    InitialPos=int(line.split(' ')[2].split('=')[1])
                    u=0
                else:
                    tau=InitialPos+u
                    if PCDict[chromosome].has_key(tau):
                        PCDict[chromosome][InitialPos + u]=float(line.strip())
                    u+=1 
                if i % 20000000 == 0:
                    print i
                i=i-1
            PCfile.close()
            for regioni in conservationDict[chromosome].keys():
                start=conservationDict[chromosome][regioni]['start']
                stop=conservationDict[chromosome][regioni]['stop']
                sequence=GenomeDict[chromosome][start:stop]
                (counts,positions)=countCAGSTGconservation(sequence)
                conserved=0
                for position in positions:
                    position=position+start
                    if position not in PCDict[chromosome].keys() or position+1 not in PCDict[chromosome].keys() or position+2 not in PCDict[chromosome].keys() or position+3 not in PCDict[chromosome].keys() or position+4 not in PCDict[chromosome].keys() or position+5 not in PCDict[chromosome].keys():
                        continue
                    if PCDict[chromosome][position]>=PCthreshold:
                        if PCDict[chromosome][position+1]>=PCthreshold:
                            if PCDict[chromosome][position+2]>=PCthreshold:
                                if PCDict[chromosome][position+3]>=PCthreshold:
                                    if PCDict[chromosome][position+4]>=PCthreshold:
                                        if PCDict[chromosome][position+5]>=PCthreshold:
                                            conserved+=1
                conservedfraction=0.0
                averageconserved=0.0
                for pos in range(start,stop):
                    if PCDict[chromosome].has_key(pos):
                        if PCDict[chromosome][pos]>=PCthreshold:
                            conservedfraction+=1
                        averageconserved=averageconserved+PCDict[chromosome][pos]
                conservedfraction=conservedfraction/(stop-start)
                averageconserved=averageconserved/(stop-start)
                newline=conservationDict[chromosome][regioni]['OUTPUTLINE']+'\t'+str(counts)+'\t'+str(conserved)+'\t'+str(conservedfraction)+'\t'+str(averageconserved)+'\n'
                outfile.write(newline)
                lefttogo=lefttogo-1
                if lefttogo % 100 == 0:
                    print lefttogo, ' regions left'
    if doConserved:
        conservationDict={}
        for line in lineslist:
            if line[0]=='#':
                continue
            if i % 100 == 0:
                print i
            i+=1
            fields = line.split('\n')[0].split('\t')
            chromosome = fields[startField]
            if doSkipRandom and chromosome.count('random')>0:
                continue
            start = int(fields[startField+1])
            stop = int(fields[startField+2])
            if chromosome not in conservationDict.keys():
                conservationDict[chromosome]={}
            regioni='region'+str(i)
            conservationDict[chromosome][regioni]={}
            conservationDict[chromosome][regioni]['start']=start
            conservationDict[chromosome][regioni]['stop']=stop
            conservationDict[chromosome][regioni]['alignment']={}
            conservationDict[chromosome][regioni]['OUTPUTLINE']=line.split('\n')[0]
        for chromosome in conservationDict.keys():
            maffile=MFADir+'/'+chromosome+'.maf'
            maffile=open(maffile)
            maflines=maffile.readlines()
            for regioni in conservationDict[chromosome].keys():
                print chromosome, regioni
                start=conservationDict[chromosome][regioni]['start']
                stop=conservationDict[chromosome][regioni]['stop']
                parsed=False
                InAlignment=False
                InRegion=False
                AlignmentStarted=False
                CurrentPos=0
                PreviousPos=0
                RegionContained=False
                for mafline in maflines:
                    if parsed:
                       print len(conservationDict[chromosome][regioni]['alignment'].keys())
                       break
                    if mafline[0]=='#':
                        continue
                    if mafline=='\n':
                        InAlignment=False
                        continue
                    if mafline[2:7]=='score':
                        if RegionContained:
                            parsed=True
                        InAlignment=True
                        AlignmentStarted=True
                        PreviousPos=CurrentPos
                        continue
                    if AlignmentStarted and InAlignment and not InRegion:
                        fields=mafline.split(chromosome)[1].split(' ')
                        for field in fields:
                            if field!='':
                                CurrentPos=int(field)
                                break
                        AlignmentStarted=False
                        if CurrentPos >= start:
                            InRegion=True
                            if CurrentPos >= stop:
                                RegionContained=True
                            else:
                                RegionContained=False
                            fields=mafline.split(' ')
                            species=fields[1].split('.')[0]
                            sequence=fields[len(fields)-1]
                            sequencefields=sequence.split('-')
                            sequence=''
                            for s in sequencefields:
                                sequence=sequence+s
                            if RegionContained:
                                seq=sequence[start-PreviousPos:stop-PreviousPos]
                            else:
                                seq=sequence[start-PreviousPos:CurrentPos-PreviousPos]
                            if species not in conservationDict[chromosome][regioni]['alignment'].keys():
                                conservationDict[chromosome][regioni]['alignment'][species]=seq
                            else:
                                conservationDict[chromosome][regioni]['alignment'][species]=conservationDict[chromosome][regioni]['alignment'][species]+seq 
                        continue
                    if AlignmentStarted and InAlignment and InRegion:
                        fields=mafline.split(chromosome)[1].split(' ')
                        for field in fields:
                            if field!='':
                                CurrentPos=int(field)
                                break
                        AlignmentStarted=False
                        if CurrentPos >= stop:
                            RegionContained=True
                        else:
                            RegionContained=False
                        fields=mafline.split(' ')
                        species=fields[1].split('.')[0]
                        sequence=fields[len(fields)-1]
                        sequencefields=sequence.split('-')
                        sequence=''
                        for s in sequencefields:
                            sequence=sequence+s
                        if RegionContained:
                            seq=sequence[start-PreviousPos:stop-PreviousPos]
                        else:
                            seq=sequence[start-PreviousPos:CurrentPos-PreviousPos]
                        if species not in conservationDict[chromosome][regioni]['alignment'].keys():
                            conservationDict[chromosome][regioni]['alignment'][species]=seq
                        else:
                            conservationDict[chromosome][regioni]['alignment'][species]=conservationDict[chromosome][regioni]['alignment'][species]+seq 
                        continue
                    if not AlignmentStarted and InAlignment and not InRegion:
                        continue
                    if not AlignmentStarted and InAlignment and InRegion:
                        fields=mafline.split(' ')
                        species=fields[1].split('.')[0]
                        sequence=fields[len(fields)-1]
                        sequencefields=sequence.split('-')
                        sequence=''
                        for s in sequencefields:
                            sequence=sequence+s
                        if RegionContained:
                            seq=sequence[start-PreviousPos:stop-PreviousPos]
                        else:
                            seq=sequence[start-PreviousPos:CurrentPos-PreviousPos]
                        if species not in conservationDict[chromosome][regioni]['alignment'].keys():
                            conservationDict[chromosome][regioni]['alignment'][species]=seq
                        else:
                            conservationDict[chromosome][regioni]['alignment'][species]=conservationDict[chromosome][regioni]['alignment'][species]+seq 
                        continue
                print conservationDict[chromosome][regioni]['alignment'].keys()
                conserved=0
                counts=countCAGSTG(conservationDict[chromosome][regioni]['alignment'][SpeciesCode])
                sequence=GenomeDict[chromosome][start:stop]
                print 'len(sequence)', stop-start
                print len(conservationDict[chromosome][regioni]['alignment'][SpeciesCode])
                (conservedcounts,positions)=countCAGSTGconservation(conservationDict[chromosome][regioni]['alignment'][SpeciesCode])
                print positions
                for species in conservationDict[chromosome][regioni]['alignment'].keys():
                    for position in positions:
                        Eboxalign=conservationDict[chromosome][regioni]['alignment'][species][position:position+6]
                        if Eboxalign[0:2]=='CA':
                            if Eboxalign[4:6]=='TG':
                                if Eboxalign[2:4]=='GG' or Eboxalign[2:4]=='GC' or Eboxalign[2:4]=='CC':
                                    conserved=conserved+1
                conservationDict[chromosome][regioni]['OUTPUTLINE']=conservationDict[chromosome][regioni]['OUTPUTLINE']+'\t'+str(counts)+'\t'+str(conserved)+'\n'
                outfile.write(conservationDict[chromosome][regioni]['OUTPUTLINE'])
                print conservationDict[chromosome][regioni]['OUTPUTLINE']

    outfile.close()

# Run the command-line tool only when executed as a script, not on import.
if __name__ == '__main__':
    run()

