##################################
#                                #
# Last modified 2017/12/14       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import random
from sets import Set
import time

def _parse_guide_line(line):
    """Extract the columns used for subsampling from one guide record.

    The line is stripped and split on tabs.  Returns a tuple
    ``(first_field, geneID, (sequence, alignment))`` taken from columns
    1, 2, 6 and 5 respectively.  Raises IndexError when the line has fewer
    than six tab-separated columns.
    """
    fields = line.strip().split('\t')
    return fields[0], fields[1], (fields[5], fields[4])


def run():
    """Subsample guides so that each gene keeps at most ``maxGuidesPerGene``.

    Command-line arguments (read from ``sys.argv``):
        1. input  -- tab-delimited guide file; column 2 is the gene ID,
           column 5 the alignment and column 6 the guide sequence.
        2. maxGuidesPerGene -- integer cap on unique guides kept per gene.
        3. outfile -- path of the file to write.

    A guide is identified by its (sequence, alignment) pair.  Genes with more
    unique guides than the cap get a uniform random sample of them; all other
    genes keep every guide.  Lines starting with '#' are copied to the output
    unchanged, blank lines are ignored, and every (sequence, alignment) pair is
    written at most once across the whole file.  The first column of each
    written line is replaced by ``sgRNA_<n>``, where ``n`` is the 1-based
    position of the record among the non-comment, non-blank input lines.

    Exits with status 1 after printing a usage message when fewer than three
    arguments are supplied.
    """
    # Three positional arguments are required, so argv must hold four items.
    if len(sys.argv) < 4:
        print('usage: python %s input maxGuidesPerGene outfile' % sys.argv[0])
        print('\tassumed input format:')
        print('\tsgRNA_1 ENSG00000186842.4       LINC00846       ENST00000334165.4       chr21:32573327:30M      GTAGGTCGTGTGTGTCGTCTGAGCATTTGC  GCAAATGCTCAGACGACACACACGACCTAC  0.533333333333')
        sys.exit(1)

    input_path = sys.argv[1]
    maxGuides = int(sys.argv[2])
    outfilename = sys.argv[3]

    # First pass: collect the unique guides of every gene.
    guides_by_gene = {}
    with open(input_path) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            _, geneID, guide = _parse_guide_line(line)
            guides_by_gene.setdefault(geneID, set()).add(guide)

    # Decide which guides survive for each gene.
    selected = {}
    for geneID, guides in guides_by_gene.items():
        if len(guides) > maxGuides:
            # Sample from a sorted list so that a seeded run is reproducible
            # regardless of set iteration order.
            selected[geneID] = set(random.sample(sorted(guides), maxGuides))
        else:
            selected[geneID] = guides

    # Second pass: emit the selected records in their original order.
    written = set()
    SG = 0
    with open(input_path) as infile:
        with open(outfilename, 'w') as outfile:
            for line in infile:
                if line.startswith('#'):
                    # Comment/header lines are passed through untouched.
                    outfile.write(line)
                    continue
                if not line.strip():
                    continue
                SG += 1
                first_field, geneID, guide = _parse_guide_line(line)
                if guide in written or guide not in selected[geneID]:
                    continue
                # Limit the replacement to the first occurrence so that an ID
                # that also appears inside a later column leaves it intact.
                outfile.write(line.replace(first_field, 'sgRNA_' + str(SG), 1))
                written.add(guide)

    print('finished subsampling guides')

# Script entry point.  The call is not wrapped in an
# ``if __name__ == '__main__'`` guard, so it also executes on import.
run()
