##################################
#                                #
# Last modified 11/07/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import copy
from sets import Set

def reverseComplement(sequence, DNA=None):
    """Return the reverse complement of *sequence*.

    Parameters:
        sequence -- string of alignment characters.
        DNA      -- mapping from each character to its complement. When
                    omitted, a default table covering A/C/G/T/N (both cases)
                    and the '-' gap character is used.

    Raises KeyError if *sequence* contains a character missing from *DNA*.
    """
    if DNA is None:
        DNA = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
               'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n', '-': '-'}
    # A single join instead of repeated '+' avoids quadratic string building
    # on long alignment rows.
    return ''.join([DNA[base] for base in reversed(sequence)])

def fix_separator(line):
    """Collapse every run of two or more consecutive spaces in *line* into one space.

    Only the ASCII space character is affected; tabs and other whitespace
    are left untouched.
    """
    if '  ' not in line:
        return line
    # Each pass roughly halves every run of spaces, so the recursion depth is
    # only logarithmic in the length of the longest run.
    return fix_separator(line.replace('  ', ' '))

def findPosInAlignmentWithDashes(sequence, position):
    """Map an ungapped offset onto a column index of a gapped alignment row.

    The function grows a prefix of *sequence* until it contains at least
    *position* non-'-' characters, and returns that prefix length as an int.
    Gap characters inside the prefix are counted as extra columns.
    """
    target = position
    while True:
        gaps = sequence[:int(target)].count('-')
        if target - gaps >= position:
            return int(target)
        # Not enough bases yet: extend the prefix by the gaps seen so far.
        target = position + gaps

def _print_usage():
    """Print the command-line help text for run()."""
    print('usage: python %s maf_directory input chrFieldID genome_symbol outfile_prefix [-GTF] [-spliced] [-class_code symbol] [-singleFile]' % sys.argv[0])
    print('       Note: this script will output alignments in separate fasta file, one for each region')
    print('       if you use the -GTF option, the input file will be assumed to be GTF and sequence of transcripts will be outputted, at the plus or minus strand')
    print('       the -spliced and -class_code option are to be used with GTF files only')
    print('       Run on single chromosomes if a lot of regions or a full genome annotation GTF file are used')
    print('       -singleFile output option works for bed only')


def _load_gtf_exons(path, doSpliced, class_code):
    """Read exon records from a GTF file.

    Only lines whose third column is 'exon' are used. If class_code is not
    None, exons whose attribute column lacks a matching `class_code "..."`
    entry are skipped.

    Returns (RegionDict, TranscriptDict):
      TranscriptDict[chrom][transcript_id] -> sorted, de-duplicated list of
          (left, right, strand) exon tuples;
      RegionDict[chrom] -> set of every exon tuple kept on that chromosome.
    With doSpliced, transcripts that end up with a single exon are dropped.
    """
    TranscriptDict = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.split('\t')
            # Skip blank or truncated lines instead of raising IndexError.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            if class_code is not None:
                if 'class_code "' not in attributes:
                    continue
                if attributes.split('class_code "')[1].split('";')[0] != class_code:
                    continue
            exon = (int(fields[3]), int(fields[4]), fields[6])
            transcript_id = attributes.split('transcript_id "')[1].split('";')[0]
            TranscriptDict.setdefault(fields[0], {}).setdefault(transcript_id, []).append(exon)

    RegionDict = {}
    for chrom, transcripts in TranscriptDict.items():
        regions = RegionDict.setdefault(chrom, set())
        for transcript_id in list(transcripts):
            exons = sorted(set(transcripts[transcript_id]))
            if doSpliced and len(exons) == 1:
                del transcripts[transcript_id]
                continue
            transcripts[transcript_id] = exons
            regions.update(exons)
    return RegionDict, TranscriptDict


def _load_bed_regions(path, chrFieldID):
    """Read tab-separated regions from a BED-like file.

    The chromosome is taken from column chrFieldID, followed by the left and
    right coordinates. Every region is recorded on the '+' strand.
    Returns RegionDict[chrom] -> set of (left, right, '+') tuples.
    """
    RegionDict = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.split('\t')
            chrom = fields[chrFieldID]
            left = int(fields[chrFieldID + 1])
            right = int(fields[chrFieldID + 2])
            RegionDict.setdefault(chrom, set()).add((left, right, '+'))
    return RegionDict


def _load_maf_blocks(maf_path, genome, chrom, regions):
    """Collect the blocks of one chromosome's MAF file that overlap *regions*.

    Returns (speciesList, BlockDict, RegionToBlockDict):
      speciesList -- sorted names of every species seen in the file, plus
          *genome*;
      BlockDict[(start, end)] -> {species: whitespace-split 's' line} for
          each block whose *genome* row overlaps at least one region;
      RegionToBlockDict[region] -> (start, end) keys of the overlapping
          blocks, in file order.
    An 's' line counts as the *genome* row when its source field starts with
    *genome*. Every 10000 genome rows, a progress line is printed.
    """
    species_seen = set([genome])
    BlockDict = {}
    RegionToBlockDict = dict((region, []) for region in regions)
    block = None  # species -> fields for the current kept block, else None
    n_blocks = 0
    with open(maf_path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.split()
            if not fields:
                continue
            if fields[0] == 'a':
                # Any 'a' line starts a new block, whether or not it has a score.
                block = None
                continue
            if fields[0] != 's':
                continue
            if fields[1].startswith(genome):
                start = int(fields[2])
                end = start + int(fields[3])
                n_blocks += 1
                if n_blocks % 10000 == 0:
                    print('%s %d %d' % (chrom, start, end))
                keep = False
                for region in regions:
                    left, right = region[0], region[1]
                    # Overlap between [left, right] and [start, end); this is
                    # equivalent to the original four-way test when left <= right.
                    if left < end and right >= start:
                        RegionToBlockDict[region].append((start, end))
                        keep = True
                if keep:
                    block = {genome: fields}
                    BlockDict[(start, end)] = block
            else:
                species = fields[1].split('.')[0]
                species_seen.add(species)
                if block is not None:
                    block[species] = fields
    return sorted(species_seen), BlockDict, RegionToBlockDict


def _extract_alignment(region, block_keys, BlockDict, genome, speciesList):
    """Slice *region* out of its overlapping blocks.

    Returns {species: aligned text}. A species with no row in a block is
    padded with '-' so that every species ends up with the same length.
    """
    left, right = region[0], region[1]
    pieces = dict((species, []) for species in speciesList)
    last = len(block_keys) - 1
    for i, (start, end) in enumerate(block_keys):
        block = BlockDict[(start, end)]
        sequence = block[genome][6]
        if i == 0:
            # The region may start before the first overlapping block. Clamping
            # keeps a negative offset from becoming a negative slice index.
            leftPos = findPosInAlignmentWithDashes(sequence, max(left - start, 0))
        else:
            leftPos = 0
        if i == last:
            rightPos = findPosInAlignmentWithDashes(sequence, right - start)
        else:
            rightPos = len(sequence)
        # The region may end past the block. Clamping keeps the gap padding
        # the same length as the sliced rows.
        rightPos = min(rightPos, len(sequence))
        for species in speciesList:
            fields = block.get(species)
            if fields is None:
                pieces[species].append('-' * (rightPos - leftPos))
            else:
                pieces[species].append(fields[6][leftPos:rightPos])
    return dict((species, ''.join(parts)) for species, parts in pieces.items())


def _write_alignment(outfile, header, sequences, genome, speciesList, strand, DNA):
    """Write one FASTA alignment to *outfile*.

    The *genome* row is written first under *header*, followed by every other
    species. Rows are reverse-complemented when strand is '-'.
    """
    def orient(seq):
        if strand == '-':
            return reverseComplement(seq, DNA)
        return seq

    outfile.write(header)
    outfile.write(orient(sequences[genome]) + '\n')
    for species in speciesList:
        if species != genome:
            outfile.write('>' + species + '\n')
            outfile.write(orient(sequences[species]) + '\n')


def run():
    """Command-line entry point.

    Usage:
        maf_directory input chrFieldID genome_symbol outfile_prefix
        [-GTF] [-spliced] [-class_code symbol] [-singleFile]

    For each chromosome with regions, <maf_directory>/<chrom>.maf is read and
    the alignment of every region is written as FASTA.

    - BED mode: one file per region, or a single '<outfile_prefix>.MAF' file
      with -singleFile.
    - GTF mode: one file per transcript. The exons are concatenated in genomic
      order, and the whole transcript is reverse-complemented when it is on
      the '-' strand.

    NOTE(review): GTF coordinates are 1-based while MAF 'start' is 0-based.
    No offset is applied here; confirm this is intended.
    """
    if len(sys.argv) < 6:
        _print_usage()
        sys.exit(1)

    doGTF = '-GTF' in sys.argv
    doSingleFile = '-singleFile' in sys.argv

    MAF = sys.argv[1]
    inputPath = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    genome = sys.argv[4]
    outfileprefix = sys.argv[5]

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n','-':'-'}

    TranscriptDict = {}
    if doGTF:
        doSpliced = '-spliced' in sys.argv
        if doSpliced:
            print('will only look at transciprs with more than one exon')
        class_code = None
        if '-class_code' in sys.argv:
            class_code = sys.argv[sys.argv.index('-class_code') + 1]
            print('will only look at transciprs if class code ' + class_code)
        RegionDict, TranscriptDict = _load_gtf_exons(inputPath, doSpliced, class_code)
    else:
        RegionDict = _load_bed_regions(inputPath, chrFieldID)

    # Open the single output file once, so later chromosomes do not truncate it.
    singleFile = None
    if doSingleFile and not doGTF:
        singleFile = open(outfileprefix + '.MAF', 'w')
    try:
        for chrom in sorted(RegionDict):
            regions = RegionDict[chrom]
            if not regions:
                continue
            speciesList, BlockDict, RegionToBlockDict = _load_maf_blocks(
                MAF + '/' + chrom + '.maf', genome, chrom, regions)
            if doGTF:
                for transcript_id in sorted(TranscriptDict[chrom]):
                    exons = TranscriptDict[chrom][transcript_id]
                    strand = exons[0][2]
                    # Exons stay in ascending genomic order. Reverse-complementing
                    # the full concatenation already yields the correct
                    # minus-strand exon order.
                    parts = dict((species, []) for species in speciesList)
                    for exon in exons:
                        exon_seqs = _extract_alignment(
                            exon, RegionToBlockDict[exon], BlockDict, genome, speciesList)
                        for species in speciesList:
                            parts[species].append(exon_seqs[species])
                    sequences = dict((s, ''.join(p)) for s, p in parts.items())
                    with open(outfileprefix + '_' + transcript_id, 'w') as outfile:
                        _write_alignment(outfile, '>' + genome + '|' + transcript_id + '\n',
                                         sequences, genome, speciesList, strand, DNA)
            else:
                for region in sorted(regions):
                    left, right, strand = region
                    sequences = _extract_alignment(
                        region, RegionToBlockDict[region], BlockDict, genome, speciesList)
                    header = '>'+genome+'|'+chrom+':'+str(left)+'-'+str(right)+'|'+strand+'\n'
                    if singleFile is not None:
                        _write_alignment(singleFile, header, sequences, genome, speciesList, strand, DNA)
                        singleFile.write('###########################################################################\n')
                    else:
                        with open(outfileprefix+'_'+chrom+':'+str(left)+'-'+str(right), 'w') as outfile:
                            _write_alignment(outfile, header, sequences, genome, speciesList, strand, DNA)
    finally:
        if singleFile is not None:
            singleFile.close()
   
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
