##################################
#                                #
# Last modified 2020/05/21       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
from sets import Set

# Column header shared by the .same_module and .diff_module tables.
_PAIR_HEADER = ('#module1\tmodule2\tgene1\tgene2\tchr1\tchr2\tsame_chromosome?'
                '\tTSS1\tTSS2\tstrand1\tstrand2\tTSS_distance(bp)'
                '\tTSS_distance(gene_positions)\tsame_strand?')


def _iter_lines(path):
    """Yield the text lines of ``path``, decompressing it if necessary.

    ``.bz2`` files are piped through ``bzip2 -cd`` and ``.gz``/``.bgz`` files
    through ``zcat``; anything else is opened directly.  The external command
    is passed as an argument list (no shell), so paths containing spaces or
    shell metacharacters are handled correctly.
    """
    if path.endswith('.bz2'):
        command = ['bzip2', '-cd', path]
    elif path.endswith(('.gz', '.bgz')):
        command = ['zcat', path]
    else:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    # Imported here so the block does not depend on a file-level import.
    import subprocess

    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        # Always release the pipe and reap the child process.
        proc.stdout.close()
        proc.wait()


def _load_gene_positions(gmap_path):
    """Parse the GMAP table into per-gene positions and per-chromosome TSS ranks.

    Columns are read by 0-based index: 8 = strand, 9 = gene, 13 = chromosome,
    15 = left coordinate, 16 = right coordinate.  Lines starting with '#' and
    blank lines are skipped.

    Returns ``(gene_pos, tss_rank)`` where ``gene_pos[gene]`` is
    ``(chrom, TSS, left, right, strand)`` (the last record wins for repeated
    gene names) and ``tss_rank[chrom][TSS]`` is the index of that TSS among
    the sorted, distinct TSS positions recorded on the chromosome.
    """
    gene_pos = {}
    chrom_tss = {}
    for line in _iter_lines(gmap_path):
        if line.startswith('#') or not line.strip():
            continue
        fields = line.strip().split('\t')
        chrom = fields[13]
        strand = fields[8]
        gene = fields[9]
        left = int(fields[15])
        right = int(fields[16])
        # Minus-strand genes start at their right end.  Every other strand
        # value (including unknown ones such as '.') uses the left end, so a
        # record can never inherit the TSS of the previous line.
        tss = right if strand == '-' else left
        chrom_tss.setdefault(chrom, set()).add(tss)
        gene_pos[gene] = (chrom, tss, left, right, strand)

    # Precompute the rank of every TSS once instead of calling list.index()
    # for every gene pair.
    tss_rank = {}
    for chrom, positions in chrom_tss.items():
        tss_rank[chrom] = dict(
            (tss, rank) for rank, tss in enumerate(sorted(positions)))
    return gene_pos, tss_rank


def _load_modules(modules_path, gene_field, module_field, gene_pos):
    """Read module membership as a list of ``(module, [genes])`` pairs.

    ``gene_field`` and ``module_field`` are 0-based column indices.  Modules
    and genes keep their order of first appearance and genes are de-duplicated
    within a module.  Genes missing from ``gene_pos`` are reported on stdout
    and dropped; their module is still recorded.
    """
    order = []
    members = {}
    member_sets = {}
    for line in _iter_lines(modules_path):
        if line.startswith('#') or not line.strip():
            continue
        fields = line.strip().split('\t')
        module = fields[module_field]
        gene = fields[gene_field]
        if module not in members:
            members[module] = []
            member_sets[module] = set()
            order.append(module)
        if gene not in gene_pos:
            print('%s not found in chromosome positions file, removing from module %s'
                  % (gene, module))
            continue
        if gene not in member_sets[module]:
            member_sets[module].add(gene)
            members[module].append(gene)
    return [(module, members[module]) for module in order]


def _same_module_pairs(modules):
    """Yield ``(module, module, gene1, gene2)`` for genes sharing a module."""
    for module, genes in modules:
        for gene1 in genes:
            for gene2 in genes:
                yield module, module, gene1, gene2


def _diff_module_pairs(modules):
    """Yield ``(module1, module2, gene1, gene2)`` for genes in distinct modules."""
    for module1, genes1 in modules:
        for gene1 in genes1:
            for module2, genes2 in modules:
                if module1 == module2:
                    continue
                for gene2 in genes2:
                    yield module1, module2, gene1, gene2


def _format_pair(module1, module2, gene1, gene2, gene_pos, tss_rank):
    """Build one tab-separated output row describing a gene pair."""
    chr1, tss1, _left1, _right1, strand1 = gene_pos[gene1]
    chr2, tss2, _left2, _right2, strand2 = gene_pos[gene2]
    columns = [module1, module2, gene1, gene2, chr1, chr2]
    if chr1 != chr2:
        # Positional columns are meaningless across chromosomes.
        columns.append('no')
        columns.extend(['--'] * 7)
    else:
        columns.extend(['yes', str(tss1), str(tss2), strand1, strand2])
        # math.fabs keeps the historical float formatting (e.g. '400.0').
        columns.append(str(math.fabs(tss1 - tss2)))
        ranks = tss_rank[chr1]
        columns.append(str(math.fabs(ranks[tss1] - ranks[tss2])))
        if strand1 == strand2:
            columns.append('yes')
        else:
            columns.append('no')
    return '\t'.join(columns)


def _write_pair_table(path, candidates, gene_pos, tss_rank):
    """Write the header and one row per distinct unordered gene pair.

    Self-pairs are skipped, and a pair is written only the first time it is
    encountered, regardless of gene order or of the modules involved.
    """
    seen = set()
    with open(path, 'w') as outfile:
        outfile.write(_PAIR_HEADER + '\n')
        for module1, module2, gene1, gene2 in candidates:
            if gene1 == gene2:
                continue
            if (gene1, gene2) in seen or (gene2, gene1) in seen:
                continue
            seen.add((gene1, gene2))
            row = _format_pair(module1, module2, gene1, gene2,
                               gene_pos, tss_rank)
            outfile.write(row + '\n')


def run():
    """Compare TSS positions of gene pairs within and across modules.

    Command line: ``modules geneFieldID moduleFieldID GMAP outprefix``.

    * ``modules``   - tab-separated gene-to-module table (may be .bz2/.gz/.bgz).
    * ``geneFieldID`` / ``moduleFieldID`` - 0-based column indices of the gene
      and module names in ``modules``.
    * ``GMAP``      - tab-separated gene position table (may be compressed).
    * ``outprefix`` - prefix of the two output files.

    Writes ``<outprefix>.same_module`` (pairs of genes sharing a module) and
    ``<outprefix>.diff_module`` (pairs of genes from different modules).  Each
    row reports whether the genes share a chromosome and, if so, both TSSs,
    both strands, the TSS distance in bp, the distance in ranks among the
    chromosome's distinct TSS positions, and whether the strands match.

    Prints the usage message and exits with status 1 when fewer than five
    arguments are supplied.
    """
    # Five arguments are required: sys.argv[5] is read below.
    if len(sys.argv) < 6:
        print('usage: python %s modules geneFieldID moduleFieldID GMAP outfile' % sys.argv[0])
        sys.exit(1)

    modules_path = sys.argv[1]
    geneFieldID = int(sys.argv[2])
    moduleFieldID = int(sys.argv[3])
    GMAP = sys.argv[4]
    outprefix = sys.argv[5]

    gene_pos, tss_rank = _load_gene_positions(GMAP)
    modules = _load_modules(modules_path, geneFieldID, moduleFieldID, gene_pos)

    _write_pair_table(outprefix + '.same_module',
                      _same_module_pairs(modules), gene_pos, tss_rank)
    _write_pair_table(outprefix + '.diff_module',
                      _diff_module_pairs(modules), gene_pos, tss_rank)
        
# Only run the command-line entry point when executed as a script, so the
# module can be imported without parsing sys.argv or exiting.
if __name__ == '__main__':
    run()

