##################################
#                                #
# Last modified 04/06/2016       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
from collections import Counter

def _read_gtf(path):
    """Return {gene_id: (chrom, left, right, strand)} from the exon lines of a GTF file.

    Lines starting with '#', lines with fewer than 9 tab-separated fields and
    features other than 'exon' are skipped. chrom and strand come from the first
    exon line seen for the gene; left/right are the smallest/largest coordinate
    over all of the gene's exons, across every transcript.
    """
    records = {}
    with open(path) as gtf:
        for line in gtf:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank or truncated lines would otherwise raise IndexError.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            left = int(fields[3])
            right = int(fields[4])
            gene_id = fields[8].split('gene_id "')[1].split('";')[0]
            record = records.get(gene_id)
            if record is None:
                records[gene_id] = [fields[0], min(left, right), max(left, right), fields[6]]
            else:
                record[1] = min(record[1], left, right)
                record[2] = max(record[2], left, right)
    return {gene_id: tuple(record) for gene_id, record in records.items()}


def _genes_by_chromosome(genes):
    """Group gene extents by chromosome as sorted lists of (left, right, gene_id, strand)."""
    by_chrom = {}
    for gene_id, (chrom, left, right, strand) in genes.items():
        by_chrom.setdefault(chrom, []).append((left, right, gene_id, strand))
    for gene_list in by_chrom.values():
        gene_list.sort()
    return by_chrom


def _read_orthologs(path, prefix1, prefix2, genes2):
    """Return {species-1 gene_id: species-2 gene_id} from an orthologous-groups table.

    Each non-comment line is tab-separated: a group name followed by entries of
    the form 'species:gene_id|...'. If a species appears several times in one
    group, its last entry wins. Pairs whose species-2 gene is absent from
    ``genes2`` are dropped, since they cannot be placed on a chromosome.
    """
    orthologs = {}
    with open(path) as og_file:
        for line in og_file:
            if line.startswith('#'):
                continue
            members = {}
            for entry in line.strip().split('\t')[1:]:
                species = entry.split(':')[0]
                members[species] = entry.split(':')[1].split('|')[0]
            if prefix1 in members and members.get(prefix2) in genes2:
                orthologs[members[prefix1]] = members[prefix2]
    return orthologs


def _accepted_window_starts(gene_list, orthologs, ortholog_chrom, index2, window, fraction):
    """Return start indexes of ``window``-gene windows passing both synteny tests.

    Test 1: the most common species-2 chromosome among the window's orthologs
    must hold at least ``fraction`` of the window's genes.
    Test 2: the orthologs on that chromosome must be compact. Their count divided
    by their positional span (max index - min index) must exceed ``fraction``.
    A span of 0 (a single ortholog position) is accepted rather than dividing by zero.
    """
    starts = []
    # + 1 so that the final window, ending at the last gene, is also evaluated.
    for start in range(len(gene_list) - window + 1):
        hits = [entry[2] for entry in gene_list[start:start + window] if entry[2] in orthologs]
        counts = Counter(ortholog_chrom[gene_id] for gene_id in hits).most_common(1)
        if not counts:
            continue
        best_chrom, best_count = counts[0]
        if float(best_count) / window < fraction:
            continue
        positions = [index2[orthologs[gene_id]] for gene_id in hits
                     if ortholog_chrom[gene_id] == best_chrom]
        span = max(positions) - min(positions)
        if span == 0 or float(best_count) / span > fraction:
            starts.append(start)
    return starts


def _merge_windows(starts, window):
    """Merge runs of consecutive window starts into (first_index, end_exclusive) blocks."""
    blocks = []
    for start in starts:
        # blocks[-1][1] - window is the start of the previous accepted window.
        if blocks and start == blocks[-1][1] - window + 1:
            blocks[-1] = (blocks[-1][0], start + window)
        else:
            blocks.append((start, start + window))
    return blocks


def _trim_block(gene_list, start, end, ortholog_chrom):
    """Return the index of the last gene in [start, end) whose ortholog is on the block's dominant chromosome."""
    chroms = [ortholog_chrom[entry[2]] for entry in gene_list[start:end]
              if entry[2] in ortholog_chrom]
    dominant = Counter(chroms).most_common(1)[0][0]
    last = end - 1
    # Terminates within the block: at least one gene in it maps to ``dominant``.
    while ortholog_chrom.get(gene_list[last][2]) != dominant:
        last -= 1
    return last


def run():
    """Write blocks of species-1 genes whose orthologs cluster on one species-2 chromosome.

    Command line:
        GTF1 GTF2 prefix1 prefix2 OrthologousGroups.txt window_size fraction_match outfilename

    Species-1 genes are sorted by position on each chromosome. Windows of
    ``window_size`` consecutive genes are tested (see _accepted_window_starts).
    Windows starting at consecutive genes are merged into blocks. Trailing genes
    whose ortholog is not on the block's dominant species-2 chromosome are trimmed.
    Each remaining gene is written as one tab-separated row, labelled OB1, OB2, ...
    Genes without an ortholog get '-' in the species-2 columns.
    """
    # Eight arguments are required, so argv must have 9 entries (argv[8] is read).
    if len(sys.argv) < 9:
        sys.stdout.write('usage: python %s GTF1 GTF2 prefix1 prefix2 OrthologousGroups.txt windows_size fraction_match outfilename\n' % sys.argv[0])
        sys.stdout.write('\tExample of a gene name in the OrthologousGroups.txt file:\n')
        sys.stdout.write('\t\tParamecium_biaurelia_V1-4:PBIGNG33362|PBIGNG33362|PBIGNT33362|PBIGNT33362|scaffold_0398|2934,25827\n')
        sys.stdout.write('\tin this case the prefix name would be Paramecium_biaurelia_V1-4\n')
        sys.exit(1)

    gtf1, gtf2, prefix1, prefix2, og_path = sys.argv[1:6]
    window = int(sys.argv[6])
    fraction = float(sys.argv[7])
    outfilename = sys.argv[8]
    if window < 1:
        sys.stdout.write('windows_size must be a positive integer\n')
        sys.exit(1)

    genes1 = _read_gtf(gtf1)
    sys.stdout.write('finished inputting %s\n' % gtf1)
    genes2 = _read_gtf(gtf2)
    sys.stdout.write('finished inputting %s\n' % gtf2)

    by_chrom1 = _genes_by_chromosome(genes1)
    by_chrom2 = _genes_by_chromosome(genes2)
    # Position of each species-2 gene in its chromosome's sorted list (replaces O(n) list.index).
    index2 = {}
    for gene_list in by_chrom2.values():
        for position, entry in enumerate(gene_list):
            index2[entry[2]] = position

    orthologs = _read_orthologs(og_path, prefix1, prefix2, genes2)
    ortholog_chrom = {gene_id: genes2[ortho][0] for gene_id, ortho in orthologs.items()}

    # Compute every block before touching the output file.
    blocks_by_chrom = {}
    for chrom, gene_list in by_chrom1.items():
        starts = _accepted_window_starts(gene_list, orthologs, ortholog_chrom,
                                         index2, window, fraction)
        blocks_by_chrom[chrom] = _merge_windows(starts, window)

    with open(outfilename, 'w') as outfile:
        # One header column per data column (the original had a stray empty column).
        outfile.write('#OrthologousBlock\tchr_1\tleft_1\tright_1\tstrand_1\tGeneID_1\tchr_2\tleft_2\tright_2\tstrand_2\tGeneID_2\n')
        block_number = 1
        for chrom in sorted(blocks_by_chrom):
            gene_list = by_chrom1[chrom]
            for start, end in blocks_by_chrom[chrom]:
                last = _trim_block(gene_list, start, end, ortholog_chrom)
                for left1, right1, gene1, strand1 in gene_list[start:last + 1]:
                    row = ['OB' + str(block_number), chrom, str(left1), str(right1), strand1, gene1]
                    ortho = orthologs.get(gene1)
                    if ortho is None:
                        row.extend(['-'] * 5)
                    else:
                        chrom2, left2, right2, strand2 = genes2[ortho]
                        row.extend([chrom2, str(left2), str(right2), strand2, ortho])
                    outfile.write('\t'.join(row) + '\n')
                block_number += 1

# Run only when executed as a script, so the module can be imported without side effects.
if __name__ == '__main__':
    run()
