##################################
#                                #
# Last modified 03/18/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import numpy
import math
from sets import Set
import string
from cistematic.core.geneinfo import geneinfoDB
from cistematic.genomes import Genome

from cistematic.core.geneinfo import geneinfoDB
from cistematic.genomes import Genome

def getSequence(hg,chromosome,start,stop):
    """Fetch the sequence of `chromosome` between `start` and `stop` from `hg`.

    `hg` is queried as ``hg.sequence(name, start, stop - start)``, i.e. with a
    start coordinate and a length. A leading 'chr' prefix is removed from the
    chromosome name before the query (e.g. 'chr1' -> '1'); names without that
    prefix are passed through unchanged rather than having their first three
    characters chopped off.

    Returns whatever ``hg.sequence`` returns (a sequence string for a
    cistematic Genome).
    """
    if chromosome.startswith('chr'):
        chromosome = chromosome[3:]
    return hg.sequence(chromosome, start, stop - start)

def getReverseComplement(DNA,sequence):
    """Return the reverse complement of `sequence`.

    `DNA` maps each base to its complement (e.g. {'A': 'T', ...}); every
    character of `sequence` must be a key of `DNA`, otherwise KeyError is
    raised. The result is built with a single join instead of repeated string
    concatenation, which was quadratic in the sequence length.
    """
    return ''.join(DNA[base] for base in reversed(sequence))

def HammingDistance(string1,string2):
    """Return the number of differing positions between two strings.

    Positions are compared pairwise over the shorter length. Every extra
    character in the longer string also counts as one difference. The result
    is always an int. The previous version used math.fabs, which silently
    turned it into a float, so callers that str() the value wrote a mix of
    '3' and '3.0'.
    """
    mismatches = sum(1 for a, b in zip(string1, string2) if a != b)
    return mismatches + abs(len(string1) - len(string2))

def _read_records(path):
    """Yield (stripped_line, tab_separated_fields) for each data line of `path`.

    Lines starting with '#' and blank lines are skipped. The file is closed
    once iteration finishes.
    """
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            text = line.strip()
            if not text:
                continue
            yield text, text.split('\t')


def _gtf_attribute(attributes, key):
    """Return the quoted value of `key "..."` in a GTF attribute column, or None if absent."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker, 1)[1].split('"', 1)[0]


def _gene1_distance(hg, exons, left, right, exon_bp):
    """Smallest Hamming distance between the window after the junction's right end
    and the start of any exon (of gene1) beginning to the right of `left`.

    Returns `exon_bp` when no exon qualifies or none scores lower.
    """
    best = exon_bp
    for (chrom, exon_left, exon_right, exon_strand) in exons:
        if exon_left > left:
            exon_seq = getSequence(hg, chrom, exon_left - 1, exon_left + exon_bp - 1)
            junction_seq = getSequence(hg, chrom, right, right + exon_bp)
            # min() keeps `best` on ties, matching a strict '<' update.
            best = min(best, HammingDistance(exon_seq, junction_seq))
    return best


def _gene2_distance(hg, exons, left, right, exon_bp, shift):
    """Smallest Hamming distance between the window ending at the junction's left end
    and the end of any exon (of gene2) finishing to the left of `right`.

    `shift` moves both windows by the given number of bases. Returns `exon_bp`
    when no exon qualifies or none scores lower.
    """
    best = exon_bp
    for (chrom, exon_left, exon_right, exon_strand) in exons:
        if exon_right < right:
            exon_seq = getSequence(hg, chrom, exon_right - exon_bp + shift, exon_right + shift)
            junction_seq = getSequence(hg, chrom, left - exon_bp + 1 + shift, left + 1 + shift)
            best = min(best, HammingDistance(exon_seq, junction_seq))
    return best


def run():
    """Score junctions by sequence similarity between junction flanks and annotated exon ends.

    Command line: junctions GTF <ExonBP> genome outfilename

    Junctions file: tab-separated; column 2/3 are the left/right coordinates
    and columns 7/8 the two gene names. GTF 'exon' records whose gene_name
    appears in the junctions file are collected per gene.

    For each junction whose two genes both have GTF exons, the original line
    is written with three extra columns:
      1. gene1 distance: min Hamming distance between the ExonBP bases after
         the junction's right coordinate and the first ExonBP bases of any
         gene1 exon starting right of the junction's left coordinate;
      2. gene2 distance: min Hamming distance between the ExonBP bases ending
         at the junction's left coordinate and the last ExonBP bases of any
         gene2 exon ending left of the junction's right coordinate;
      3. the minimum of the two.
    Distances default to ExonBP when no exon qualifies (including genes whose
    strand is neither '+' nor '-'). Junctions between genes on different
    strands get 'different strands' in all three columns.
    """
    # Five positional arguments are read below (argv[1]..argv[5]).
    if len(sys.argv) < 6:
        print('usage: python %s junctions GTF <number base pairs into exons to be considered> genome outfilename' % sys.argv[0])
        print('   junctions format:')
        print('   chrY	9175621	9196544	+	known exon to known exon, different genes	55.0	TSPY4	TSPY8	novel	GT|AG')
        sys.exit(1)

    junctions = sys.argv[1]
    GTF = sys.argv[2]
    ExonBP = int(sys.argv[3])
    genome = sys.argv[4]
    outfilename = sys.argv[5]

    hg = Genome(genome)

    # Pass 1: collect the gene names referenced by any junction.
    wanted = set()
    for text, fields in _read_records(junctions):
        wanted.add(fields[6])
        wanted.add(fields[7])

    print('finished inputting junctions')

    # Pass 2: gather the distinct exons (chrom, left, right, strand) of each wanted gene.
    gene_exons = {}
    for text, fields in _read_records(GTF):
        if len(fields) < 9 or fields[2] != 'exon':
            continue
        geneName = _gtf_attribute(fields[8], 'gene_name')
        if geneName not in wanted:
            continue
        exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
        gene_exons.setdefault(geneName, set()).add(exon)

    print('finished inputting GTF')
    print('finished sorting exons')

    # Pass 3: score every junction whose genes both have exons.
    with open(outfilename, 'w') as outfile:
        for text, fields in _read_records(junctions):
            left = int(fields[1])
            right = int(fields[2])
            exons1 = gene_exons.get(fields[6])
            exons2 = gene_exons.get(fields[7])
            if not exons1 or not exons2:
                continue
            # The strand of an arbitrary exon stands in for the gene's strand.
            strand1 = next(iter(exons1))[3]
            strand2 = next(iter(exons2))[3]
            if strand1 != strand2:
                outfile.write(text + '\tdifferent strands\tdifferent strands\tdifferent strands\n')
                continue
            # Initialise here so an unexpected strand value cannot reuse the
            # previous junction's scores (or hit an undefined name).
            gene1HAM = ExonBP
            gene2HAM = ExonBP
            if strand1 in ('+', '-'):
                # NOTE(review): the original '-' strand branch shifts both gene2
                # windows 1 bp left relative to '+'; gene1 windows are identical
                # for both strands. Preserved as-is -- confirm this is intended.
                shift = 0 if strand1 == '+' else -1
                gene1HAM = _gene1_distance(hg, exons1, left, right, ExonBP)
                gene2HAM = _gene2_distance(hg, exons2, left, right, ExonBP, shift)
            outline = text + '\t' + str(gene1HAM) + '\t' + str(gene2HAM) + '\t' + str(min(gene1HAM, gene2HAM)) + '\n'
            outfile.write(outline)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
