##################################
#                                #
# Last modified 2019/01/11       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import gzip
from sets import Set

def _open_gtf(path):
    """Open a plain or gzip-compressed GTF file for line-by-line reading."""
    if path.endswith('.gz'):
        # Python 3 needs an explicit text mode to yield str lines; Python 2's
        # gzip module only accepts binary modes, where lines are already str.
        mode = 'rt' if sys.version_info[0] >= 3 else 'rb'
        return gzip.open(path, mode)
    return open(path)


def _read_genes(lines):
    """Collapse GTF records into one entry per gene_id.

    Returns a dict mapping gene_id to [chromosome, strand, left, right].
    Chromosome and strand are taken from the first record seen for the gene;
    left/right are the smallest and largest coordinate over all its records.

    Raises IndexError if a data line has fewer than nine tab-separated fields
    or no gene_id "..." attribute, and ValueError if a coordinate is not an
    integer.
    """
    genes = {}
    for line in lines:
        line = line.strip()
        # Blank lines and comment/header lines carry no record.
        if not line or line.startswith('#'):
            continue
        fields = line.split('\t')
        gene_id = fields[8].split('gene_id "')[1].split('"')[0]
        start = int(fields[3])
        end = int(fields[4])
        entry = genes.get(gene_id)
        if entry is None:
            genes[gene_id] = [fields[0], fields[6],
                              min(start, end), max(start, end)]
        else:
            # Keep only the running extremes instead of every coordinate.
            entry[2] = min(entry[2], start, end)
            entry[3] = max(entry[3], start, end)
    return genes


def _tss_by_chromosome(genes):
    """Return {chromosome: sorted list of unique (TSS, strand) tuples}.

    The TSS is the right end for '-' genes and the left end otherwise. Genes
    whose strand is neither '+' nor '-' are reported on stdout and kept at
    their left end, so that they still separate their neighbours in the
    sorted order.
    """
    sites = {}
    for gene_id, (chrom, strand, left, right) in genes.items():
        if strand == '-':
            tss = right
        else:
            if strand != '+':
                print('unstranded transcript, skipping ' + gene_id)
            tss = left
        sites.setdefault(chrom, set()).add((tss, strand))
    ordered = {}
    for chrom, unique_sites in sites.items():
        ordered[chrom] = sorted(unique_sites)
    return ordered


def _count_strand_pairs(sites_by_chrom, min_genes):
    """Count adjacent TSS pairs on the same strand and on opposite strands.

    Chromosomes with fewer than min_genes unique (TSS, strand) entries are
    ignored, as is any pair in which either member is not on '+' or '-'.
    Returns (same, switch) as floats.
    """
    stranded = ('+', '-')
    same = 0.0
    switch = 0.0
    for sites in sites_by_chrom.values():
        if len(sites) < min_genes:
            continue
        for (_, strand1), (_, strand2) in zip(sites, sites[1:]):
            if strand1 not in stranded or strand2 not in stranded:
                continue
            if strand1 == strand2:
                same += 1
            else:
                switch += 1
    return same, switch


def run():
    """Report how often neighbouring genes in a GTF file change strand.

    Command-line arguments (read from sys.argv):
        1. input GTF file name (gzip-compressed if it ends in '.gz')
        2. minimum number of unique TSS entries a chromosome must have
        3. output file name

    Writes two tab-separated lines to the output file: the number of adjacent
    gene pairs considered and the fraction of those pairs on opposite strands
    ('nan' when no pair was considered). Prints a usage message and exits
    with status 1 when fewer than three arguments are given.
    """
    # Three arguments are required, so argv must hold at least four items.
    if len(sys.argv) < 4:
        print('usage: python %s inputfilename minGeneNumberPerChr outfilename' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    minGNPC = int(sys.argv[2])
    outputfilename = sys.argv[3]

    with _open_gtf(gtf) as listoflines:
        genes = _read_genes(listoflines)

    same, switch = _count_strand_pairs(_tss_by_chromosome(genes), minGNPC)
    total = same + switch
    # Avoid dividing by zero when no chromosome yields a usable pair.
    if total:
        frequency = str(switch / total)
    else:
        frequency = 'nan'

    # The output is opened only after parsing succeeds, so a failed run does
    # not leave a truncated result file behind.
    with open(outputfilename, 'w') as outfile:
        outfile.write('Number_gene_pairs_considered\t' + str(total) + '\n')
        outfile.write('strand_switch_frequency\t' + frequency + '\n')

# Run the analysis only when executed as a script, so that importing this
# module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

