##################################
#                                #
# Last modified 04/03/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

# Nucleotide complement table, including the IUPAC ambiguity codes
# (R/Y, K/M, B/V, D/H; S, W and N are self-complementary). Lower-case input
# is complemented to lower case so case information is preserved.
_COMPLEMENT = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
               'R': 'Y', 'Y': 'R', 'K': 'M', 'M': 'K', 'S': 'S', 'W': 'W',
               'B': 'V', 'V': 'B', 'D': 'H', 'H': 'D'}
_COMPLEMENT.update([(base.lower(), comp.lower())
                    for base, comp in list(_COMPLEMENT.items())])


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide sequence.

    preliminarysequence: a string of nucleotide characters (A, C, G, T, N and
        IUPAC ambiguity codes, in either case).
    Returns the sequence reversed and complemented, with case preserved.
    Raises KeyError for any character not in the complement table.
    """
    # Single join over the reversed input: linear time, unlike repeated +=.
    return ''.join(_COMPLEMENT[base] for base in reversed(preliminarysequence))


def _read_fasta(fasta):
    """Load every record of a FASTA file into a dict of name -> sequence.

    A record's name is the first whitespace-delimited token after '>', so a
    header such as '>chr1 some description' is stored as 'chr1' and can match
    the seqname column of a GTF file. Sequence lines are stripped and joined.
    Each record name is printed as it is read. Lines before the first header
    are ignored; an empty file yields an empty dict.
    """
    SequenceDict = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    SequenceDict[name] = ''.join(chunks)
                tokens = line[1:].split()
                name = tokens[0] if tokens else ''
                print(name)
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        SequenceDict[name] = ''.join(chunks)
    return SequenceDict


def _gtf_attribute(attributes, preferred, fallback):
    """Return the quoted value of `preferred` in a GTF attribute column, else of `fallback`.

    Raises IndexError when the chosen key is absent from the column.
    """
    key = preferred if (preferred + ' "') in attributes else fallback
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_gtf_exons(GTF, SequenceDict):
    """Group the exon records of a GTF file by (gene, transcript).

    Gene and transcript are identified by gene_name/transcript_name when
    present, falling back to gene_id/transcript_id. Comment lines, lines with
    fewer than 9 tab-separated columns, non-exon features and exons on
    sequences absent from SequenceDict are skipped.
    Returns a dict mapping (gene, transcript) -> list of
    (chrom, left, right, strand) with 1-based inclusive coordinates.
    """
    TranscriptDict = {}
    with open(GTF) as handle:
        for lineno, line in enumerate(handle, 1):
            if lineno % 100000 == 0:
                print('%d lines processed' % lineno)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Guard against blank/truncated lines before indexing columns.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            chrom = fields[0]
            if chrom not in SequenceDict:
                continue
            attributes = fields[8]
            gene = _gtf_attribute(attributes, 'gene_name', 'gene_id')
            transcript = _gtf_attribute(attributes, 'transcript_name', 'transcript_id')
            exon = (chrom, int(fields[3]), int(fields[4]), fields[6])
            TranscriptDict.setdefault((gene, transcript), []).append(exon)
    return TranscriptDict


def _transcript_sequence(exons, SequenceDict):
    """Splice a transcript's exons together in transcript (5'->3') order.

    The strand is taken from the exon with the smallest sort key. '-' strand
    exons are taken in descending coordinate order and reverse-complemented.
    Returns (sequence, ends), where ends[k] is the transcript coordinate at
    which exon k ends; ends is empty when the strand is neither '+' nor '-'.
    """
    exons = sorted(exons)
    strand = exons[0][3]
    if strand == '-':
        exons.reverse()
    elif strand != '+':
        return '', []
    pieces = []
    ends = []
    total = 0
    for chrom, left, right, _ in exons:
        # GTF coordinates are 1-based inclusive; Python slices are 0-based half-open.
        piece = SequenceDict[chrom][left - 1:right]
        if strand == '-':
            piece = getReverseComplement(piece)
        pieces.append(piece)
        total += len(piece)
        ends.append(total)
    return ''.join(pieces), ends


def run():
    """Command-line entry point: write circular-junction sequences as FASTA.

    Usage: python script.py fasta GTF span outputfilename

    For every multi-exon transcript, each exon boundary (the transcript end
    included) is joined to every exon start upstream of it (the transcript
    start included). The record is the `span` bases ending at the boundary
    followed by the `span` bases starting at the upstream exon start. Only
    full-length (2*span) sequences are written, each with a header of the form
    '>gene:transcript:CircJunc:<donor pos>:<acceptor pos>'.
    Exits with status 1 and a usage message when arguments are missing.
    """
    # Four positional arguments are required, so argv must have 5 entries.
    if len(sys.argv) < 5:
        print('usage: python %s fasta GTF span outputfilename' % sys.argv[0])
        print('\t this script will take a GTF file and output all circularized junctions, i.e. sequences joining a splice sites to all acceptors upstream of it in the transcript')
        print('\t the span parameter referse to the length of sequence on each side, i.e. if you have 75bp reads, you would want to use a span around 60 for stringency purposes')
        sys.exit(1)

    fasta = sys.argv[1]
    GTF = sys.argv[2]
    span = int(sys.argv[3])
    outfilename = sys.argv[4]

    SequenceDict = _read_fasta(fasta)
    TranscriptDict = _read_gtf_exons(GTF, SequenceDict)

    print('found %d transcripts' % len(TranscriptDict))
    with open(outfilename, 'w') as outfile:
        # Sorted iteration makes the output file reproducible.
        for count, ID in enumerate(sorted(TranscriptDict), 1):
            if count % 10000 == 0:
                print('%d transcripts sequences processed' % count)
            sequence, ends = _transcript_sequence(TranscriptDict[ID], SequenceDict)
            if len(ends) < 2:
                # Single-exon (or unknown-strand) transcripts have no junctions.
                continue
            gene, transcript = ID
            boundaries = [0] + ends
            for i in range(1, len(boundaries)):
                donor = boundaries[i]
                # Too close to the transcript start for a full upstream flank;
                # also avoids negative slice indices wrapping around.
                if donor < span:
                    continue
                tail = sequence[donor - span:donor]
                for acceptor in boundaries[:i]:
                    CircJuncSequence = tail + sequence[acceptor:acceptor + span]
                    if len(CircJuncSequence) < 2 * span:
                        continue
                    outline = '>' + gene + ':' + transcript + ':CircJunc:' + str(donor) + ':' + str(acceptor)
                    outfile.write(outline + '\n')
                    outfile.write(CircJuncSequence + '\n')

# Only run when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

