##################################
#                                #
# Last modified 2018/04/01       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set
import random

def _print_usage():
    """Print the command-line help text to standard output."""
    print('usage: python %s input geneID|geneName|transcriptID number_guides_per_array number_arrays_per_target outfile [-shuffle] [-rRNA] [-prefix prefix] [-safes]' % sys.argv[0])
    print('\tassumed input format: ')
    print('\t\t#sgRNA_ID\tgeneIDs\tgeneNames\ttranscripts\tmapping(s)\tgRNA_sequence\tgRNA_sequence_reverse_complement\tGC%')
    print('\t\tsgRNA_1\tENSG00000001561.6\tENPP4\tENST00000321037.4\tchr6:46129994:30M\tACGTCCCTATCTGCGCCGCTCGGGGCGCTC\tGAGCGCCCCGAGCGGCGCAGATAGGGACGT\t0.733333333333')
    print('\t\ti.e. the output of Cas13sgRNAdesign_Finalize.py')
    print('\t\tguides are assumed to be sorted')
    print('\tthe geneID|geneName|transcriptID|transcriptName paramter tells the scrit whether to aggregate guides over genes or over individual transcripts')
    print('\tby default the script will pick the middle guide(s) to be the first in the array; use the [-shuffle] option to randomize the order')
    print('\tthe [-prefix] option will append the prefix to the array IDs')
    print('\tthe [-safes] option will treat guides as safes, i.e. geneIDs will be ignored and treated as the same target')


def _read_guides(inputfilename, GorT, doRibo, doSafes):
    """Read the guide table and group its rows by target.

    Lines starting with '#' and blank lines are skipped.  Every other line
    is split on tabs and must provide at least seven columns (ID, gene IDs,
    gene names, transcript IDs, mapping, sequence, reverse complement).

    Returns a dict mapping each target to a list of
    (posToSortOn, ID, pos, sequence, revsequence) tuples.
    """
    guideDict = {}
    with open(inputfilename) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no guide.
                continue
            fields = stripped.split('\t')
            ID = fields[0]
            geneID = fields[1]
            geneName = fields[2]
            transcriptID = fields[3]
            pos = fields[4]
            sequence = fields[5]
            revsequence = fields[6]
            if doRibo:
                # Only the first entry of each comma-separated annotation
                # is used in rRNA mode.
                geneName = geneName.split(',')[0]
                geneID = geneID.split(',')[0]
                transcriptID = transcriptID.split(',')[0]
            if doSafes:
                target = 'Safe'
            elif GorT == 'geneID':
                target = geneID
            elif GorT == 'geneName':
                target = geneName
            else:
                # run() has already restricted GorT to the three valid modes.
                target = transcriptID
            posToSortOn = int(pos.split(':')[1])
            if doRibo and ',' in pos:
                # With several mappings, sort on the coordinate that follows
                # the transcript ID (minus any '-antisense' suffix) in the
                # mapping string.
                posToSortOn = int(pos.split(transcriptID.split('-antisense')[0])[1].split(':')[1])
            guideDict.setdefault(target, []).append((posToSortOn, ID, pos, sequence, revsequence))
    return guideDict


def _split_into_bins(guides, Nsg):
    """Cut the sorted guide list into Nsg consecutive bins of equal size.

    Each bin holds len(guides) // Nsg guides.  NOTE: when the guide count is
    not a multiple of Nsg, the trailing remainder belongs to no bin and is
    therefore never selected (this matches the historical behaviour).
    """
    K = len(guides) // Nsg
    return [guides[i * K:(i + 1) * K] for i in range(Nsg)]


def _draw_arrays(bins, Narr):
    """Build Narr guide combinations, each taking one guide from every bin.

    Within a bin, guides are drawn at random without replacement until every
    distinct guide has been used once; later draws pick at random from the
    whole bin.
    """
    unseen = []
    for members in bins:
        pool = list(set(members))
        random.shuffle(pool)
        unseen.append(pool)
    guideCombosList = []
    for _ in range(Narr):
        guideCombo = []
        for pool, members in zip(unseen, bins):
            if pool:
                guideCombo.append(pool.pop())
            else:
                guideCombo.append(random.choice(members))
        guideCombosList.append(guideCombo)
    return guideCombosList


def run():
    """Assemble sgRNA arrays from a guide table, driven by sys.argv.

    Positional arguments: input file, aggregation mode (geneID, geneName or
    transcriptID), number of guides per array, number of arrays per target,
    and output file.  Optional flags: -shuffle, -rRNA, -prefix <prefix> and
    -safes.

    Guides are grouped by target, sorted, split into as many consecutive
    bins as there are guides per array, and each array receives one randomly
    chosen guide per bin.  Arrays list their guides in bin order unless
    -shuffle is given.  Targets with fewer guides than requested per array
    are reported and skipped.

    Prints the usage text or an error message and exits with status 1 when
    the arguments are missing or invalid.
    """
    # Five positional arguments are required, so argv needs six entries.
    if len(sys.argv) < 6:
        _print_usage()
        sys.exit(1)

    inputfilename = sys.argv[1]
    GorT = sys.argv[2]
    Nsg = int(sys.argv[3])
    Narr = int(sys.argv[4])
    outfilename = sys.argv[5]

    doShuffle = '-shuffle' in sys.argv
    if doShuffle:
        print('will shuffle guide order')

    doRibo = '-rRNA' in sys.argv
    if doRibo:
        print('will treat guides as targetting rRNA')

    doSafes = '-safes' in sys.argv
    if doSafes:
        print('will treat guides as safes')

    Prefix = ''
    if '-prefix' in sys.argv:
        prefixIndex = sys.argv.index('-prefix') + 1
        if prefixIndex >= len(sys.argv):
            print('the -prefix option requires a value')
            sys.exit(1)
        Prefix = sys.argv[prefixIndex]
        print('will append %s to array IDs' % Prefix)

    # The aggregation mode is ignored for safes, so only validate it otherwise.
    if not doSafes and GorT not in ('geneID', 'geneName', 'transcriptID'):
        print('unrecognized aggregation mode: %s' % GorT)
        sys.exit(1)

    if Nsg < 1:
        print('number_guides_per_array must be at least 1')
        sys.exit(1)

    # Read all input before opening the output, so a failed read does not
    # leave a partial output file behind.
    guideDict = _read_guides(inputfilename, GorT, doRibo, doSafes)

    header = ['#sgRNA_array_ID', 'gene|transcript']
    header += ['guide_' + str(i + 1) + '_sequence' for i in range(Nsg)]
    header += ['guide_' + str(i + 1) + '_position' for i in range(Nsg)]

    A = 0
    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(header) + '\n')
        for target in guideDict:
            guides = guideDict[target]
            if len(guides) < Nsg:
                print('insufficient number of guides for target %s skipping' % target)
                continue
            guides.sort()
            bins = _split_into_bins(guides, Nsg)
            for guideCombo in _draw_arrays(bins, Narr):
                A += 1
                if doShuffle:
                    random.shuffle(guideCombo)
                outfields = ['sgRNA_array_' + Prefix + '_' + str(A), target]
                outfields += [sequence for (p, ID, pos, sequence, revsequence) in guideCombo]
                outfields += [pos for (p, ID, pos, sequence, revsequence) in guideCombo]
                outfile.write('\t'.join(outfields) + '\n')

# Only run the command-line entry point when executed as a script, so that
# importing this module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()
