##################################
#                                #
# Last modified 2018/11/18       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import gzip
from sets import Set

def _iter_data_lines(path):
    """Yield the data lines of a plain or gzip-compressed text file.

    Lines starting with '#' and blank lines are skipped. The file handle is
    always closed, even if the consumer stops early or an error occurs.
    """
    if path.endswith('.gz'):
        handle = gzip.open(path)
    else:
        handle = open(path)
    try:
        for line in handle:
            if not isinstance(line, str):
                # gzip handles yield bytes on Python 3 (str on Python 2).
                line = line.decode()
            if line.startswith('#') or not line.strip():
                continue
            yield line
    finally:
        handle.close()


def _read_gene_ids(path, column):
    """Return the set of gene IDs in tab-separated field `column` of `path`."""
    gene_ids = set()
    for line in _iter_data_lines(path):
        gene_ids.add(line.strip().split('\t')[column])
    return gene_ids


def _read_tss_by_chromosome(gtf):
    """Parse a GTF file into {chromosome: [(TSS, strand, geneID), ...]}.

    A gene's extent is the min/max over all coordinates of the GTF records
    carrying its gene_id. The TSS is the right end for '-' strand genes and
    the left end otherwise. Each chromosome's list is sorted by TSS so that
    list neighbours are genomic neighbours.
    """
    genes = {}
    for line in _iter_data_lines(gtf):
        fields = line.strip().split('\t')
        gene_id = fields[8].split('gene_id "')[1].split('"')[0]
        if gene_id not in genes:
            # Strand and chromosome are taken from the first record seen.
            genes[gene_id] = {'strand': fields[6],
                              'chr': fields[0],
                              'coordinates': []}
        genes[gene_id]['coordinates'].append(int(fields[3]))
        genes[gene_id]['coordinates'].append(int(fields[4]))

    by_chrom = {}
    for gene_id in genes:
        record = genes[gene_id]
        if record['strand'] == '-':
            tss = max(record['coordinates'])
        else:
            # '+' and any unstranded/unknown value fall back to the left end
            # instead of silently reusing the previous gene's TSS.
            tss = min(record['coordinates'])
        by_chrom.setdefault(record['chr'], []).append(
            (tss, record['strand'], gene_id))

    for entries in by_chrom.values():
        entries.sort()
    return by_chrom


def _count_neighbor_pairs(genes_by_chrom, radius, up_genes, down_genes):
    """Tally how the 1..radius following genes relate to each up/down gene.

    Returns counts[strand_class][distance][outcome] where strand_class is
    'ss' (same strand) or 'ds' (different strand), distance is 1..radius and
    outcome is 'same', 'ns' or 'diff' (neighbour changes in the same
    direction, is in neither list, or changes in the opposite direction).

    Exits with status 1 if an anchor gene is present in both lists.
    """
    counts = {}
    for strand_class in ('ss', 'ds'):
        counts[strand_class] = {}
        for distance in range(1, radius + 1):
            counts[strand_class][distance] = {'same': 0, 'ns': 0, 'diff': 0}

    for chrom in genes_by_chrom:
        ordered = genes_by_chrom[chrom]
        # Only anchors with a full window of `radius` following genes are
        # used; range() is empty when the chromosome has too few genes.
        for i in range(len(ordered) - radius):
            (tss1, strand1, id1) = ordered[i]
            is_up = id1 in up_genes
            is_down = id1 in down_genes
            if not (is_up or is_down):
                continue
            if is_up and is_down:
                print('gene found in both the up and down lists, exiting')
                print(id1)
                sys.exit(1)
            sign1 = 'U' if is_up else 'D'
            for j in range(1, radius + 1):
                (tss2, strand2, id2) = ordered[i + j]
                if id2 in up_genes:
                    sign2 = 'U'
                elif id2 in down_genes:
                    sign2 = 'D'
                else:
                    sign2 = 'ns'
                strand_class = 'ss' if strand1 == strand2 else 'ds'
                if sign2 == 'ns':
                    outcome = 'ns'
                elif sign2 == sign1:
                    outcome = 'same'
                else:
                    outcome = 'diff'
                counts[strand_class][j][outcome] += 1
    return counts


def _fraction(count, total):
    """Return count/total as a float, or NaN when there are no observations."""
    if total == 0:
        return float('nan')
    return count / float(total)


def _write_table(outputfilename, counts, radius):
    """Write per-distance raw counts followed by per-strand-class fractions."""
    header = '#Distance\tSameStrand_Same\tSameStrand_ns\tSameStrand_Diff\tDiffStrand_Same\tDiffStrand_ns\tDiffStrand_Diff\tSameStrand_Same\tSameStrand_ns\tSameStrand_Diff\tDiffStrand_Same\tDiffStrand_ns\tDiffStrand_Diff'
    outcomes = ('same', 'ns', 'diff')
    outfile = open(outputfilename, 'w')
    try:
        outfile.write(header + '\n')
        for distance in range(1, radius + 1):
            columns = [str(distance)]
            # Columns 2-7: raw counts, same-strand block then diff-strand.
            for strand_class in ('ss', 'ds'):
                for outcome in outcomes:
                    columns.append(str(counts[strand_class][distance][outcome]))
            # Columns 8-13: the same counts as fractions of their strand
            # class total (NaN when that class has no pairs at this distance).
            for strand_class in ('ss', 'ds'):
                bucket = counts[strand_class][distance]
                total = sum(bucket[outcome] for outcome in outcomes)
                for outcome in outcomes:
                    columns.append(str(_fraction(bucket[outcome], total)))
            outfile.write('\t'.join(columns) + '\n')
    finally:
        outfile.close()


def run():
    """Command-line entry point.

    Expected arguments (sys.argv[1:]):
        inputfilename   GTF annotation (optionally .gz)
        radius          number of following genes to examine per anchor gene
        up-genes        tab-separated file listing up-regulated genes
        geneIDfieldID   0-based column holding the gene ID in up-genes
        down-genes      tab-separated file listing down-regulated genes
        geneIDfieldID   0-based column holding the gene ID in down-genes
        outfilename     output table path

    For every gene in either list, its next `radius` genes (ordered by TSS on
    the same chromosome) are classified by relative strand and by whether
    they change in the same direction, the opposite direction, or not at all.
    """
    # Seven user arguments are required, so argv must hold at least 8 items.
    if len(sys.argv) < 8:
        print('usage: python %s inputfilename radius(genes) up-genes geneIDfieldID down-genes geneIDfieldID outfilename' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    R = int(sys.argv[2])
    upgenes = sys.argv[3]
    upID = int(sys.argv[4])
    downgenes = sys.argv[5]
    downID = int(sys.argv[6])
    outputfilename = sys.argv[7]

    up_gene_ids = _read_gene_ids(upgenes, upID)
    # The down list has its own ID column; it must not reuse upID.
    down_gene_ids = _read_gene_ids(downgenes, downID)

    genes_by_chrom = _read_tss_by_chromosome(gtf)
    counts = _count_neighbor_pairs(genes_by_chrom, R, up_gene_ids, down_gene_ids)
    _write_table(outputfilename, counts, R)

# Run the pipeline only when executed as a script, so that importing this
# module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

