##################################
#                                #
# Last modified 2018/11/08       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import random
from sets import Set
import time

def _parse_guidescan(path):
    """Parse a GuideScan CSV export into a nested dict.

    Layout handled here: a region name on its own line (any non-blank line
    without a comma), then a header line starting with
    'chromosome,target site', then one comma-separated row per guide.

    Returns {region: {offtargets_sum (int): [guide, ...]}}, where each guide
    is a tuple of the first nine columns, kept as strings:
    (chr, left, right, gRNA, cutting_efficiency_score,
     cutting_specificity_score, strand, offtargets_sum, offtargets_summary).

    Exits with status 1 on a duplicated region, a header with no preceding
    region name, or a guide row that does not follow a region header.
    """
    region_dict = {}
    region = None
    with open(path) as infile:
        for line in infile:
            stripped = line.strip()
            if not stripped:
                # Blank lines contain no comma; without this check they would
                # be taken as a region named '' and hijack the next table.
                continue
            if stripped.startswith('chromosome,target site'):
                if region is None:
                    print('table header found before any region name, exiting')
                    sys.exit(1)
                if region in region_dict:
                    print('duplicate regions detected, exiting:')
                    print(region)
                    sys.exit(1)
                region_dict[region] = {}
                continue
            if ',' not in stripped:
                region = stripped
                continue
            fields = stripped.split(',')
            if region not in region_dict or len(fields) < 9:
                print('unexpected guide row, exiting:')
                print(stripped)
                sys.exit(1)
            guide = tuple(fields[:9])
            # Bucket guides by their total predicted off-target count.
            region_dict[region].setdefault(int(fields[7]), []).append(guide)
    return region_dict


def _select_guides(guides_by_ot, n_guides, relaxed_ot=None):
    """Return up to n_guides guides for one region.

    Guides with the fewest off-targets are preferred. If that set holds more
    than n_guides guides, a random subset of n_guides is returned. Otherwise
    all of them are kept. When relaxed_ot is given, the shortfall is then
    topped up with random guides from the next off-target buckets
    (min+1, min+2, ...), never going beyond relaxed_ot off-targets.
    Returns [] for a region that has no guides.
    """
    if not guides_by_ot:
        return []
    min_ot = min(guides_by_ot)
    best = guides_by_ot[min_ot]
    if len(best) > n_guides:
        return random.sample(best, n_guides)
    selected = list(best)
    if relaxed_ot is not None:
        # Cap at the user-supplied OT limit as well as the largest bucket seen.
        limit = min(max(guides_by_ot), relaxed_ot)
        for ot in range(min_ot + 1, limit + 1):
            missing = n_guides - len(selected)
            if missing <= 0:
                break
            if ot in guides_by_ot:
                candidates = guides_by_ot[ot]
                selected.extend(random.sample(candidates, min(missing, len(candidates))))
    return selected


def _write_pairs(outfilename, regions, selected, add_g):
    """Write a tab-separated table pairing guides across regions.

    Region i is paired with every later region j, in the order given. For
    each such pair, every selected guide of the first region is combined
    with every selected guide of the second. Rows are numbered pgRNA_1,
    pgRNA_2, and so on.

    Region names are split on ':' and '-' into chr/start/end columns, so
    they presumably look like 'chr:start-end'. With add_g, a 'G' is
    prepended to each sgRNA sequence.
    """
    prefix = 'G' if add_g else ''
    header = ('#ID\tRegion1_chr\tRegion1_start\tRegion1_end\tRegion2_chr\tRegion2_start\tRegion2_end\tsgRNA_1\tsgRNA_2'
              '\tcutting_efficiency_score_sgRNA_1\tcutting_efficiency_score_sgRNA_2\tcutting_specificity_score_sgRNA_1\tcutting_specificity_score_sgRNA_2\tstrand_sgRNA_1\tstrand_sgRNA_2'
              '\tofftargets_sum_sgRNA_1\tofftargets_sum_sgRNA_2\tofftargets_summary_sgRNA_1\tofftargets_summary_sgRNA_2')
    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        pgrna = 1
        for i, r1 in enumerate(regions):
            cols1 = r1.replace(':', '\t').replace('-', '\t')
            for r2 in regions[i + 1:]:
                cols2 = r2.replace(':', '\t').replace('-', '\t')
                for g1 in selected[r1]:
                    for g2 in selected[r2]:
                        row = ['pgRNA_' + str(pgrna), cols1, cols2,
                               prefix + g1[3], prefix + g2[3]]
                        # Interleave: efficiency, specificity, strand,
                        # offtargets_sum and offtargets_summary for
                        # sgRNA_1 then sgRNA_2.
                        for k in range(4, 9):
                            row.extend((g1[k], g2[k]))
                        outfile.write('\t'.join(row) + '\n')
                        pgrna += 1


def run():
    """Command-line entry point.

    Usage: guidescan.csv N_guides_per_region outfile [-relaxed OT] [-addG]

    The script parses the GuideScan CSV and selects up to N guides per
    region, preferring the fewest off-targets. With -relaxed OT, it tops up
    from buckets of up to OT off-targets. It prints a per-region summary and
    writes every cross-region guide pair to outfile. Selection uses
    random.sample, so results vary between runs unless the RNG is seeded.
    """
    if len(sys.argv) < 4:
        print('usage: python %s guidescan.csv N_guides_per_region outfile [-relaxed OT] [-addG]' % sys.argv[0])
        print('\tBy default the script will pick guides only among the set of guides with the fewest predicted off-target sites for each region')
        print('\tIf that set is smaller than N, and you still want N guides, use the [-relaxed OT] option, in which case guides will be picked from sets with up to the specified OT number of off-targets')
        sys.exit(1)

    guidescan_path = sys.argv[1]
    n_guides = int(sys.argv[2])
    outfilename = sys.argv[3]
    add_g = '-addG' in sys.argv

    relaxed_ot = None
    if '-relaxed' in sys.argv:
        value_index = sys.argv.index('-relaxed') + 1
        if value_index >= len(sys.argv):
            print('-relaxed requires an OT value, exiting')
            sys.exit(1)
        relaxed_ot = int(sys.argv[value_index])

    region_dict = _parse_guidescan(guidescan_path)
    regions = sorted(region_dict)

    selected = {}
    for region in regions:
        guides_by_ot = region_dict[region]
        selected[region] = _select_guides(guides_by_ot, n_guides, relaxed_ot)
        if guides_by_ot:
            min_ot = min(guides_by_ot)
            print('%s %s %s' % (region, min_ot, len(guides_by_ot[min_ot])))
        else:
            print('%s has no guides' % region)
        print('picked %d guides' % len(selected[region]))

    _write_pairs(outfilename, regions, selected, add_g)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
