##################################
#                                #
# Last modified 04/02/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    The input is read from its last character to its first and every
    character is replaced by its complement (A<->T, G<->C, N->N).  Case is
    preserved: lowercase input bases give lowercase output bases.

    Raises KeyError if the sequence contains any character other than
    A, C, G, T, N or their lowercase forms.
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Single join over the reversed sequence instead of growing a string one
    # character at a time inside an index loop.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _gtfAttribute(attributes, preferredKey, fallbackKey):
    """Return the quoted value of preferredKey in a GTF attribute string, or of fallbackKey when preferredKey is absent."""
    if (preferredKey + ' "') in attributes:
        key = preferredKey
    else:
        key = fallbackKey
    return attributes.split(key + ' "')[1].split('";')[0]


def _loadExonsByTranscript(gtf):
    """Parse the 'exon' lines of a GTF file into {gene: {transcript: [(chr, left, right, strand), ...]}}.

    Genes are keyed by gene_name (gene_id if there is no gene_name) and
    transcripts by transcript_name (transcript_id if there is none).
    Comment lines (starting with '#') and blank lines are skipped.
    """
    TranscriptDict = {}
    j = 0
    with open(gtf) as lineslist:
        for line in lineslist:
            j += 1
            if j % 100000 == 0:
                print('%d lines processed' % j)
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            geneName = _gtfAttribute(fields[8], 'gene_name', 'gene_id')
            transcriptName = _gtfAttribute(fields[8], 'transcript_name', 'transcript_id')
            TranscriptDict.setdefault(geneName, {}).setdefault(transcriptName, []).append(exon)
    return TranscriptDict


def _buildCoordinateMap(TranscriptDict):
    """Map transcript offsets of exon boundaries to genomic coordinates.

    Returns {gene: {transcript: {offset: genomicPosition}}}.  For each exon,
    taken in transcript order, the offset of its first base and the offset of
    its last base are recorded, plus one trailing entry just past the final
    exon that repeats the final exon's last genomic position.

    Side effect: each exon list in TranscriptDict is sorted in place, and
    reversed for '-' strand transcripts, so that it ends up in transcript
    order.  The strand of a transcript is taken from its first exon as read
    from the GTF; transcripts whose strand is neither '+' nor '-' get an
    empty map.
    """
    coordinateMap = {}
    for gene in TranscriptDict:
        coordinateMap[gene] = {}
        for transcript in TranscriptDict[gene]:
            exons = TranscriptDict[gene][transcript]
            strand = exons[0][3]
            exons.sort()
            if strand == '-':
                exons.reverse()
            boundaries = {}
            if strand in ('+', '-'):
                TranscriptPos = 0
                for (chrom, left, right, exonStrand) in exons:
                    if strand == '+':
                        (firstBase, lastBase) = (left, right)
                    else:
                        (firstBase, lastBase) = (right, left)
                    boundaries[TranscriptPos] = firstBase
                    TranscriptPos += (right - left)
                    boundaries[TranscriptPos] = lastBase
                    TranscriptPos += 1
                boundaries[TranscriptPos] = lastBase
            coordinateMap[gene][transcript] = boundaries
    return coordinateMap


def _parseJunction(position):
    """Return (splice, backSplice) as ints from a '...CircJunc:<splice>:<backSplice>...' field, or None if the field has no CircJunc tag."""
    if 'CircJunc' not in position:
        return None
    parts = position.split('CircJunc:')[1].split(':')
    return (int(parts[0]), int(parts[1]))


def _iterReadGroups(linelist):
    """Yield one list of (gene, transcript, junction1, junction2) per run of consecutive 'circular' lines sharing a read ID.

    Lines starting with '#', blank lines and lines whose seventh column is
    not 'circular' are ignored.  The final group is yielded when the input
    ends.
    """
    currentRead = None
    alignments = []
    for line in linelist:
        if line.startswith('#') or not line.strip():
            continue
        fields = line.strip().split('\t')
        if fields[6] != 'circular':
            continue
        readID = fields[0]
        if alignments and readID != currentRead:
            yield alignments
            alignments = []
        currentRead = readID
        alignments.append((fields[1], fields[2], _parseJunction(fields[3]), _parseJunction(fields[4])))
    # The last read has no following line to trigger the flush above.
    if alignments:
        yield alignments


def _countBackSplices(linelist, coordinateMap):
    """Tally reads per genomic back-splice junction.

    Returns (BackSpliceDict, ambiguous, unresolved) where BackSpliceDict is
    {gene: {(spliceGenomic, backSpliceGenomic): readCount}}, ambiguous is the
    number of reads whose alignments imply more than one distinct genomic
    junction (these are not counted), and unresolved is the number of reads
    none of whose alignments carries a CircJunc tag.

    A read is credited to the gene of its last listed alignment.  A KeyError
    is raised if a junction offset is not an exon boundary of the named
    transcript.
    """
    BackSpliceDict = {}
    ambiguous = 0
    unresolved = 0
    for alignments in _iterReadGroups(linelist):
        junctions = set()
        for (gene, transcript, junction1, junction2) in alignments:
            for junction in (junction1, junction2):
                if junction is None:
                    continue
                (splice, backSplice) = junction
                boundaries = coordinateMap[gene][transcript]
                # splice - 1 is looked up for the splice side and backSplice
                # itself for the back-splice side, as the reads file encodes them.
                junctions.add((boundaries[splice - 1], boundaries[backSplice]))
        if len(junctions) > 1:
            ambiguous += 1
        elif junctions:
            geneName = alignments[-1][0]
            genomicJunction = list(junctions)[0]
            counts = BackSpliceDict.setdefault(geneName, {})
            counts[genomicJunction] = counts.get(genomicJunction, 0) + 1
        else:
            unresolved += 1
    return (BackSpliceDict, ambiguous, unresolved)


def _readFasta(fasta):
    """Read a FASTA file into {header: sequence}, printing each header as it is met.

    The key is everything after the leading '>' on the header line (stripped
    of surrounding whitespace).  Sequence lines found before the first header
    are discarded.
    """
    GenomeDict = {}
    chrom = None
    chunks = []
    with open(fasta) as inputdatafile:
        for line in inputdatafile:
            if line[0] == '>':
                if chrom is not None:
                    GenomeDict[chrom] = ''.join(chunks)
                chrom = line.strip().split('>')[1]
                print(chrom)
                chunks = []
            else:
                chunks.append(line.strip())
    if chrom is not None:
        GenomeDict[chrom] = ''.join(chunks)
    return GenomeDict


def _shortestExon(exons):
    """Return the exon tuple with the smallest right - left (first one on ties), or None for an empty list."""
    if not exons:
        return None
    return min(exons, key=lambda exon: exon[2] - exon[1])


def _exonSequence(GenomeDict, exon, strand):
    """Return the genomic sequence of exon (reverse-complemented when strand is '-'), or '' when exon is None."""
    if exon is None:
        return ''
    (chrom, left, right, exonStrand) = exon
    # left/right are treated as 1-based inclusive, hence the left - 1 slice start.
    sequence = GenomeDict[chrom][left - 1:right]
    if strand == '-':
        sequence = getReverseComplement(sequence)
    return sequence


def _writeBackSplices(outfile, BackSpliceDict, TranscriptDict, GenomeDict):
    """Write the header line and one tab-separated row per (gene, junction) to outfile.

    For every junction, all exons of the gene (across all its transcripts)
    ending at the splice position and all exons starting at the back-splice
    position are collected; the shortest of each supplies the two sequence
    columns.  The chr and strand columns come from the matched splice exon
    (falling back to the back-splice exon, then to the last exon examined)
    rather than from whichever exon happened to be iterated last.
    """
    outfile.write('#GeneName\tchr\tSplice\tbackSplice\tstrand\tReads\tShortest_Splce_exon_sequence\tShortest_backSplce_exon_sequence\n')
    for gene in BackSpliceDict:
        for (splice, backsplice) in BackSpliceDict[gene]:
            spliceExons = []
            backspliceExons = []
            lastExonSeen = None
            for transcript in TranscriptDict[gene]:
                exons = TranscriptDict[gene][transcript]
                transcriptStrand = exons[0][3]
                if transcriptStrand not in ('+', '-'):
                    continue
                for exon in exons:
                    (chrom, left, right, exonStrand) = exon
                    lastExonSeen = exon
                    # The splice side is the transcript-order end of an exon,
                    # the back-splice side its transcript-order start.
                    if transcriptStrand == '+':
                        (exonEnd, exonStart) = (right, left)
                    else:
                        (exonEnd, exonStart) = (left, right)
                    if splice == exonEnd:
                        spliceExons.append(exon)
                    if backsplice == exonStart:
                        backspliceExons.append(exon)
            shortestSpliceExon = _shortestExon(spliceExons)
            shortestBackSpliceExon = _shortestExon(backspliceExons)
            reference = shortestSpliceExon or shortestBackSpliceExon or lastExonSeen
            if reference is None:
                (chrom, strand) = ('', '')
            else:
                (chrom, strand) = (reference[0], reference[3])
            columns = [gene,
                       chrom,
                       str(splice),
                       str(backsplice),
                       strand,
                       str(BackSpliceDict[gene][(splice, backsplice)]),
                       _exonSequence(GenomeDict, shortestSpliceExon, strand),
                       _exonSequence(GenomeDict, shortestBackSpliceExon, strand)]
            outfile.write('\t'.join(columns) + '\n')


def run():
    """Command-line entry point: count reads per circular back-splice junction.

    Reads four positional arguments from sys.argv: the genome FASTA, the GTF
    annotation, the reads table ('-' for standard input) and the output file
    name.  Exits with status 1 after printing the usage text when fewer than
    four arguments are given.

    Only reads-table lines whose seventh column is 'circular' are used.
    Reads whose alignments imply more than one distinct genomic junction are
    ignored and reported on standard output.

    Differences from the previous implementation:
      * the usage check now requires all four arguments (it used to let three
        through and then fail on sys.argv[4]);
      * the last read in the reads table is now counted (it used to be
        dropped because groups were only flushed when the read ID changed);
      * the output file is opened once and every input file is closed;
      * the 'found ... back splices' message counts junctions, not genes;
      * reads with no CircJunc tag on either mate are skipped and reported
        instead of raising TypeError.
    """
    if len(sys.argv) < 5:
        print('usage: python %s fasta gtf curcular.reads outfilename' % sys.argv[0])
        print('\tNote: The script will only consider read pairs aligning to circular junctions')
        print('\tNote: Use - to indicate standard input if the reads file is compressed and you want to stream it from bzip2 or gzip')
        sys.exit(1)

    fasta = sys.argv[1]
    gtf = sys.argv[2]
    reads = sys.argv[3]
    outputfilename = sys.argv[4]

    # Opened up front so an unwritable output path fails before the long
    # parsing steps, exactly as before; closed in the finally clause.
    outfile = open(outputfilename, 'w')
    try:
        TranscriptDict = _loadExonsByTranscript(gtf)
        coordinateMap = _buildCoordinateMap(TranscriptDict)

        if reads == '-':
            (BackSpliceDict, z, unresolved) = _countBackSplices(sys.stdin, coordinateMap)
        else:
            with open(reads) as linelist:
                (BackSpliceDict, z, unresolved) = _countBackSplices(linelist, coordinateMap)

        print('ignored  %d due to mapping to multiple genes or locations in the genome' % z)
        if unresolved:
            print('ignored  %d due to no circular junction in either read' % unresolved)
        junctionCount = sum([len(BackSpliceDict[gene]) for gene in BackSpliceDict])
        print('found %d back splices' % junctionCount)

        GenomeDict = _readFasta(fasta)
        _writeBackSplices(outfile, BackSpliceDict, TranscriptDict, GenomeDict)
    finally:
        outfile.close()

# Script entry point: run() is called unconditionally, so it executes when the
# file is run as a script and also if the file is ever imported as a module.
run()

