##################################
#                                #
# Last modified 02/17/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

# Complement of each base accepted on the minus strand; any other character
# (e.g. an ambiguity code) makes _reverse_complement raise KeyError.
_COMPLEMENT = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N'}


def _reverse_complement(seq):
    """Return the reverse complement of the upper-case sequence *seq*.

    Raises KeyError if *seq* contains a character other than A, C, G, T, N.
    """
    return ''.join(_COMPLEMENT[base] for base in reversed(seq))


def _read_fasta(fasta):
    """Load FASTA file *fasta* into a dict mapping record name -> sequence.

    The record name is the header line with surrounding whitespace stripped
    and the text after the first '>' kept (up to any further '>'). Sequence
    lines are whitespace-stripped and concatenated. Each record name is
    printed as it is encountered. Lines appearing before the first header are
    ignored, and an empty file yields an empty dict (the original crashed with
    NameError in both cases).
    """
    genome = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks)
                name = line.strip().split('>')[1]
                print(name)
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks)
    return genome


def _annotate_junctions(genome, junctions, outfile, doOneBased):
    """Copy junction lines to *outfile*, inserting splice-site bases.

    Each line of *junctions* is tab-separated: chrom, left, right, strand,
    then any extra fields. The script uses left = field 1 + 1 and
    right = field 2. If *doOneBased* is true, both are shifted down by one.
    For each coordinate it takes the four upper-cased genome bases at 0-based
    positions coord-2 .. coord+1. On the '-' strand the two windows are
    swapped and reverse-complemented. Three columns are inserted after the
    strand field: 5-exon (2 bases), 5-intron|3-intron (2 + 2 bases) and
    3-exon (2 bases).

    Header lines (starting with '#') get the three column titles inserted
    after their fourth field instead. Blank lines are skipped. Lines that
    cannot be annotated are reported on stdout and omitted from the output.
    """
    skipped = 0
    for line in junctions:
        fields = line.strip().split('\t')
        if line.startswith('#'):
            header = (fields[:4]
                      + ['5-exon', '5-intron|3-intron', '3-exon']
                      + fields[4:])
            outfile.write('\t'.join(header) + '\n')
            continue
        if not line.strip():
            continue  # blank line; previously crashed with IndexError
        if len(fields) < 4:
            print('%s too few fields, skipping' % (fields,))
            continue

        chrom = fields[0]
        left = int(fields[1]) + 1
        right = int(fields[2])
        strand = fields[3]
        if doOneBased:
            left -= 1
            right -= 1

        if strand not in ('+', '-'):
            print('%s unknown strand, skipping %d' % (fields[0:4], skipped))
            skipped += 1
            continue

        chrom_seq = genome.get(chrom)
        if chrom_seq is None:
            print('could not find %s %s %s' % (chrom, left, right))
            continue
        if min(left, right) < 2:
            # A negative slice start would silently wrap to the end of the
            # chromosome and produce empty/garbage columns.
            print('coordinates out of range %s %s %s' % (chrom, left, right))
            continue

        left_window = chrom_seq[left - 2:left + 2].upper()
        right_window = chrom_seq[right - 2:right + 2].upper()
        if strand == '+':
            seq5, seq3 = left_window, right_window
        else:
            # On the minus strand the 5' splice site lies at the right
            # coordinate, so the windows swap and are reverse-complemented.
            try:
                seq5 = _reverse_complement(right_window)
                seq3 = _reverse_complement(left_window)
            except KeyError:
                print('could not find %s %s %s' % (chrom, left, right))
                continue

        columns = [chrom, str(left), str(right), strand,
                   seq5[0:2], seq5[2:4] + '|' + seq3[0:2], seq3[2:4]]
        outfile.write('\t'.join(columns + fields[4:]) + '\n')


def run():
    """Command-line entry point.

    Usage: python <script> genome.fa junctions outfilename [-1-based]

    Loads every sequence in genome.fa and annotates each junction line of the
    junctions file with its flanking splice-site bases, writing the result to
    outfilename. Prints usage and exits with status 1 if fewer than three
    arguments are given (previously an IndexError when exactly two were).
    """
    if len(sys.argv) < 4:
        print('usage: python %s genome.fa junctions outfilename [-1-based]' % sys.argv[0])
        print('\tjunctions file format: chr <tab> left <tab> right <tab> strand <other fields>')
        print('\tthe output file will insert the junctions information between the first four fields and the rest of the fields')
        sys.exit(1)

    fasta = sys.argv[1]
    datafilename = sys.argv[2]
    outfilename = sys.argv[3]
    doOneBased = '-1-based' in sys.argv
    if doOneBased:
        print('assuming genome is 0-based and annotation 1-based, will shift all coordinates one base-pair to the left')

    # Open the output first so a bad output path fails before the
    # (potentially large) genome is loaded; `with` guarantees closing.
    with open(outfilename, 'w') as outfile:
        genome = _read_fasta(fasta)
        with open(datafilename) as junctions:
            _annotate_junctions(genome, junctions, outfile, doOneBased)
        
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    run()

