##################################
#                                #
# Last modified 02/15/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set
from cistematic.core.geneinfo import geneinfoDB
from cistematic.genomes import Genome

def getReverse(sequence):
    """Return the reverse complement of a DNA sequence.

    The sequence is read from its last symbol to its first and every
    symbol is replaced by its complement (A<->T, C<->G, N->N).  Letter
    case is preserved, so lower-case input yields lower-case output.

    Raises KeyError if the sequence contains a symbol other than
    A, C, G, T or N (in either case).
    """

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N',
           'a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Build the result with a single join rather than repeated string
    # concatenation, which is quadratic in the sequence length.
    return ''.join([DNA[base] for base in reversed(sequence)])

def run():

    if len(sys.argv) < 3:
        print 'usage: python %s genome junctions outfilename [-withcounts]' % sys.argv[0]
        print '	junctions file format: chr <tab> left <tab> right'
        print '	if -withcount: chr <tab> left <tab> right <tab> total <tab> staggered'
        sys.exit(1)

    genome = sys.argv[1]
    datafilename = sys.argv[2]
    outfilename = sys.argv[3]

    doCounts=False
    if '-withcounts':
        doCounts=True

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N'}

    outfile = open(outfilename, 'w')
    hg = Genome(genome)

    outfile.write('#chr\tleft\tright\tstrand\ttotal\tstaggered\t5-exon\t5-intron|3-intron\t3-exon\n')

    NCCount=0
    lineslist  = open(datafilename)
    k=0
    for line in lineslist:
        k+=1
        if k % 10000 == 0:
            print k
        if line[0]=='#':
            continue
        fields = line.strip().split('\t')
        chr=fields[0]
        left=int(fields[1])-1
        right=int(fields[2])+1
        chromosome = chr[3:len(chr)]
        try:
            sequenceleft = string.upper(hg.sequence(chromosome,left-2,4))
            sequenceright = string.upper(hg.sequence(chromosome,right-2,4))
            print string.upper(hg.sequence(chromosome,left-4,8)), string.upper(hg.sequence(chromosome,right-4,8)), getReverse(hg.sequence(chromosome,left-4,8)), getReverse(hg.sequence(chromosome,right-4,8))
        except:
            print 'failed to retrieve sequence, skipping', chromosome,left,right, 
            continue
        print sequenceleft,sequenceright
        if sequenceleft[2:4] == 'GT' and sequenceright[0:2] == 'AG':
            sequence5 = sequenceleft
            sequence3 = sequenceright
            strand='+'
        elif getReverse(sequenceright)[2:4] == 'GT' and getReverse(sequenceleft)[0:2] == 'AG':
            sequence5 = getReverse(sequenceleft)
            sequence3 = getReverse(sequenceleft)
            strand='-'
        elif sequenceleft[2:4] == 'GC' and sequenceright[0:2] == 'AG':
            sequence5 = sequenceleft
            sequence3 = sequenceright
            strand='+'
        elif getReverse(sequenceright)[2:4] == 'GC' and getReverse(sequenceleft)[0:2] == 'AG':
            sequence5 = getReverse(sequenceleft)
            sequence3 = getReverse(sequenceleft)
            strand='-'
        elif sequenceleft[2:4] == 'AT' and sequenceright[0:2] == 'AC':
            sequence5 = sequenceleft
            sequence3 = sequenceright
            strand='+'
        elif getReverse(sequenceright)[2:4] == 'AT' and getReverse(sequenceleft)[0:2] == 'AC':
            sequence5 = getReverse(sequenceleft)
            sequence3 = getReverse(sequenceleft)
            strand='-'
        else:
            sequence5 = 'NC'
            sequence3 = 'NC'
            strand = '.'
            NCCount+=1
            print NCCount, 
        if doCounts:
            total=fields[3]
            staggered=fields[4]
            outline=chr+'\t'+str(left)+'\t'+str(right)+'\t'+strand+'\t'+total+'\t'+staggered+'\t'+sequence5[0:2]+'\t'+sequence5[2:4]+'|'+sequence3[0:2]+'\t'+sequence3[2:4]
        else:
            outline=line.strip()+'\t'+sequence5[0:2]+'\t'+sequence5[2:4]+'|'+sequence3[0:2]+'\t'+sequence3[2:4]
        print outline
        outfile.write(outline+'\n')

    outfile.close()
        
# Script entry point: run() is called unconditionally at module level,
# so it executes on import as well as when the file is run as a script.
run()

