##################################
#                                #
# Last modified 2018/01/26       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bz2
import gzip
import math
import os
import random
import string
import sys
import time
from sets import Set

def _open_lines(path):
    """Yield the lines of *path* as text, decompressing by file extension.

    ``.bz2`` files are read through :mod:`bz2`, ``.gz`` files through
    :mod:`gzip`, and anything else as a plain file. Reading happens in
    Python rather than through a shell pipeline, so file names containing
    spaces or shell metacharacters work, and a missing or unreadable file
    raises instead of silently producing no input. The handle is closed
    when iteration finishes or the generator is closed.
    """
    if path.endswith('.bz2'):
        handle = bz2.BZ2File(path, 'rb')
    elif path.endswith('.gz'):
        handle = gzip.open(path, 'rb')
    else:
        handle = open(path, 'rb')
    try:
        for raw in handle:
            # Binary reads give bytes on Python 3 but str on Python 2;
            # decode only in the former case so both emit the same text.
            if not isinstance(raw, str):
                raw = raw.decode('utf-8')
            yield raw
    finally:
        handle.close()


def _records(path):
    """Yield ``(text, label)`` pairs for the lines of *path*.

    ``text`` is the line with surrounding whitespace stripped (the form that
    is written to the output). Lines beginning with ``#`` are yielded with
    ``label=None``; for data lines ``label`` is the 4th tab-separated column.
    Blank lines are skipped.

    Raises:
        ValueError: if a non-blank data line has fewer than four columns.
    """
    for number, raw in enumerate(_open_lines(path), 1):
        text = raw.strip()
        if raw.startswith('#'):
            yield text, None
        elif text:
            fields = text.split('\t')
            if len(fields) < 4:
                raise ValueError('%s line %d: expected at least 4 tab-separated '
                                 'columns, found %d' % (path, number, len(fields)))
            yield text, fields[3]


def run():
    """Command-line entry point: downsample label-0 lines of a labelled TSV.

    Usage: ``python SCRIPT labelregions.tsv[.gz|.bz2] N_pos N_negatives``

    Column 4 of each data line is its label (``1``, ``0`` or ``-1``). The
    file is read twice:

    1. Count the label-``1`` and label-``0`` lines.
    2. Write to stdout every ``#`` line, every label ``1`` and ``-1`` line,
       and each label ``0`` line independently with probability
       ``p = (count_1 * N_negatives / N_pos) / count_0``. This keeps all of
       them if ``p >= 1`` and none if the file has no label-``0`` lines.
       Lines with any other label are dropped.

    Lines are written with surrounding whitespace stripped. Usage errors are
    reported on stderr and exit with status 1.
    """
    if len(sys.argv) < 4:  # script name + three positional arguments
        sys.stderr.write('usage: python %s labelregions.tsv.gz N_pos N_negatives\n' % sys.argv[0])
        sys.stderr.write('\tThe script will print to stdout by default\n')
        sys.stderr.write('\tThe numbers of positives and negatives are used to define the ratio to which negatives will be downsampled to\n')
        sys.exit(1)

    path = sys.argv[1]
    n_pos = int(sys.argv[2])
    n_neg = int(sys.argv[3])
    if n_pos <= 0 or n_neg < 0:
        # n_pos is a divisor below; a negative n_neg has no meaning.
        sys.stderr.write('N_pos must be positive and N_negatives non-negative\n')
        sys.exit(1)

    # Pass 1: count the positives and negatives actually present.
    input_pos = 0
    input_neg = 0
    for _text, label in _records(path):
        if label == '1':
            input_pos += 1
        elif label == '0':
            input_neg += 1

    # Keep enough negatives, in expectation, that negatives:positives in the
    # output matches N_negatives:N_pos. float() forces true division on
    # Python 2. With no negatives in the input there is nothing to sample.
    wanted_negs = input_pos * (float(n_neg) / n_pos)
    p_keep = wanted_negs / input_neg if input_neg else 0.0

    # Pass 2: stream the file again and emit the kept lines.
    write = sys.stdout.write
    for text, label in _records(path):
        if label == '0':
            if random.random() <= p_keep:
                write(text + '\n')
        elif label is None or label in ('1', '-1'):
            # label None marks a '#' header/comment line.
            write(text + '\n')


if __name__ == '__main__':
    run()