##################################
#                                #
# Last modified 05/15/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Case is preserved base by base.  Besides A/C/G/T/N, the IUPAC
    ambiguity codes (R, Y, S, W, K, M, B, D, H, V) are complemented too,
    so sequences containing them no longer fail.

    Args:
        preliminarysequence: DNA sequence string (may be empty).

    Returns:
        The reverse-complemented sequence as a string.

    Raises:
        KeyError: if the sequence contains a character that is not a
            nucleotide code listed in the table below.
    """
    upper = {
        'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
        # IUPAC ambiguity codes and their complements.
        'R': 'Y', 'Y': 'R', 'S': 'S', 'W': 'W', 'K': 'M', 'M': 'K',
        'B': 'V', 'V': 'B', 'D': 'H', 'H': 'D',
    }
    complement = dict(upper)
    for base, partner in upper.items():
        complement[base.lower()] = partner.lower()

    # A single join over the reversed input avoids rebuilding the result
    # string once per base.
    return ''.join([complement[base] for base in reversed(preliminarysequence)])

def _gtf_attribute(attributes, key):
    """Return the quoted value stored under `key` in a GTF attribute column."""
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_transcript_exons(gtf_path):
    """Map (gene_name, gene_id, transcript_id) to its list of exon tuples.

    Each exon is stored as (chromosome, left, right, orientation) with the
    coordinates converted to int.  Only 'exon' records are used.
    """
    transcripts = {}
    with open(gtf_path) as gtf:
        for line in gtf:
            if line.startswith('#'):
                continue
            record = line.strip()
            if not record:
                # Blank lines carry no record; skip them instead of failing
                # on the missing feature column.
                continue
            fields = record.split('\t')
            if fields[2] != 'exon':
                continue
            key = (_gtf_attribute(fields[8], 'gene_name'),
                   _gtf_attribute(fields[8], 'gene_id'),
                   _gtf_attribute(fields[8], 'transcript_id'))
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(key, []).append(exon)
    return transcripts


def _read_fasta(fasta_path, suffix):
    """Read a FASTA file into a dict of sequence name -> sequence string.

    The name is the header text after '>' up to the first occurrence of
    `suffix`.  Every name is printed as it is encountered.
    """
    sequences = {}
    name = None
    chunks = []
    with open(fasta_path) as fasta:
        for line in fasta:
            if line.startswith('>'):
                if name is not None:
                    sequences[name] = ''.join(chunks)
                name = line.strip().split('>')[1].split(suffix)[0]
                print(name)
                chunks = []
            elif name is not None:
                # Lines before the first header belong to no record.
                chunks.append(line.strip())
    # Store the last record only if the file contained a header at all, so
    # an empty file yields an empty dict rather than a stale entry.
    if name is not None:
        sequences[name] = ''.join(chunks)
    return sequences


def _transcript_sequence(exons, genome):
    """Return the spliced sequence of one transcript.

    Exons are joined in sorted order using 1-based inclusive coordinates.
    The strand of the first sorted exon decides the orientation: '-' is
    reverse-complemented, anything other than '+' or '-' yields ''.
    """
    ordered = sorted(exons)
    orientation = ordered[0][3]
    if orientation not in ('+', '-'):
        return ''
    spliced = ''.join([genome[chrom][left - 1:right]
                       for (chrom, left, right, strand) in ordered])
    if orientation == '-':
        spliced = getReverseComplement(spliced)
    return spliced


def _write_fasta_record(outfile, header, sequence, width=50):
    """Write one FASTA record, wrapping the sequence at `width` characters."""
    outfile.write('>' + header + '\n')
    for start in range(0, len(sequence), width):
        outfile.write(sequence[start:start + width] + '\n')


def run():
    """Write maternal and paternal transcript sequences to one FASTA file.

    Command line (read from sys.argv):
        maternal_gtf maternal_fasta paternal_gtf paternal_fasta outfilename
        [-polyA length]

    For every transcript in the paternal GTF, in sorted key order, the
    maternal record is written first and the paternal record second, with
    headers 'GeneName:GeneID:TranscriptID:mat' and '...:pat'.  With
    -polyA, that many 'A' characters are appended to every sequence.

    Exits with status 1 after printing the usage text when fewer than five
    positional arguments are given, or when -polyA has no value.

    Raises:
        KeyError: if a paternal transcript is missing from the maternal
            GTF, or an exon refers to a chromosome absent from the FASTA.
        IndexError: if an exon record lacks gene_id, gene_name or
            transcript_id, or has fewer than nine columns.
        ValueError: if the -polyA length or an exon coordinate is not an
            integer.
    """
    # Five positional arguments are read (sys.argv[1] .. sys.argv[5]), so
    # the script name plus arguments must total at least six.
    if len(sys.argv) < 6:
        print('usage: python %s maternal_gtf maternal_fasta paternal_gtf paternal_fasta outfilename [-polyA length]' % sys.argv[0])
        print('       GTF files must have gene names')
        print('       GTF files must have the same set of genes and transcripts')
        sys.exit(1)

    mat_gtf = sys.argv[1]
    mat_fasta = sys.argv[2]
    pat_gtf = sys.argv[3]
    pat_fasta = sys.argv[4]
    outputfilename = sys.argv[5]

    # An empty tail makes the append below a no-op when -polyA is absent.
    tail = ''
    if '-polyA' in sys.argv:
        length_index = sys.argv.index('-polyA') + 1
        if length_index >= len(sys.argv):
            print('-polyA requires a length argument')
            sys.exit(1)
        tailsize = int(sys.argv[length_index])
        tail = 'A' * tailsize
        print('will add a polyA tail of  %d nt' % tailsize)

    mat_transcripts = _read_transcript_exons(mat_gtf)
    pat_transcripts = _read_transcript_exons(pat_gtf)

    mat_genome = _read_fasta(mat_fasta, '_maternal')
    pat_genome = _read_fasta(pat_fasta, '_paternal')

    for chrom in mat_genome.keys():
        print('%s %d %d' % (chrom, len(mat_genome[chrom]), len(pat_genome[chrom])))

    keys = sorted(pat_transcripts.keys())
    print('Found %d' % len(keys))

    with open(outputfilename, 'w') as outfile:
        processed = 0
        for key in keys:
            processed += 1
            if processed % 10000 == 0:
                print('%d transcripts sequences processed' % processed)
            print(mat_transcripts[key])
            print(pat_transcripts[key])

            header = ':'.join(key)
            mat_sequence = _transcript_sequence(mat_transcripts[key], mat_genome)
            _write_fasta_record(outfile, header + ':mat', mat_sequence + tail)
            pat_sequence = _transcript_sequence(pat_transcripts[key], pat_genome)
            _write_fasta_record(outfile, header + ':pat', pat_sequence + tail)

# Script entry point: run() is called unconditionally at module load, so
# importing this file executes it as well.
run()

