##################################
#                                #
# Last modified 2018/11/27       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import random
from sets import Set
import time

def _parse_guide_line(line):
    """Parse one guidescan CSV record.

    Expected record layout (comma separated)::

        REGION__CHR:LEFT-RIGHT:STRAND,efficiency,specificity,OTsum;OTsummary,gRNA

    Returns a ``(region, offtarget_count, guide)`` triple where ``guide`` is
    the tuple ``(specificity, chr, left, right, gRNA, efficiency, strand,
    offtargets_sum, offtargets_summary)``.  The specificity score comes first
    (as a float) so that plain tuple sorting ranks guides by it.
    """
    fields = line.strip().split(',')
    name = fields[0]
    if '__' in name:
        # Split on the double underscore so that sequence names that
        # themselves contain '_' are not cut apart.
        pieces = name.split('__')
    else:
        pieces = name.split('_')
    region = pieces[0]
    target = pieces[-1]

    chrom, span, strand = target.split(':')[0:3]
    left, right = span.split('-')[0:2]

    cutting_efficiency_score = fields[1]
    cutting_specificity_score = float(fields[2])
    offtarget_fields = fields[3].split(';')
    offtargets_sum = offtarget_fields[0]
    offtargets_summary = offtarget_fields[1]
    gRNA = fields[4]

    guide = (cutting_specificity_score, chrom, left, right, gRNA,
             cutting_efficiency_score, strand, offtargets_sum,
             offtargets_summary)
    return region, int(offtargets_sum), guide


def _load_regions(path):
    """Read the guidescan CSV at *path* and group its guides by region.

    Returns a dict mapping each region name to
    ``{'OT': {offtarget_count: [guide, ...]}, 'all': [guide, ...]}``.
    Blank lines are skipped.
    """
    regions = {}
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            region, ot_count, guide = _parse_guide_line(line)
            entry = regions.setdefault(region, {'OT': {}, 'all': []})
            entry['OT'].setdefault(ot_count, []).append(guide)
            entry['all'].append(guide)
    return regions


def _select_guides(entry, mode, n, max_ot=None):
    """Pick up to *n* guides for one region, best specificity score first.

    mode 'CFD': rank every guide of the region by specificity score.
    mode 'OT':  use only the guides with the fewest off-targets.  When
                *max_ot* is given (the -relaxed option) and that set holds
                fewer than *n* guides, top up from the next off-target
                levels, in increasing order, up to and including *max_ot*.
    """
    if mode == 'CFD':
        return sorted(entry['all'], reverse=True)[:n]

    levels = sorted(entry['OT'])
    selected = sorted(entry['OT'][levels[0]], reverse=True)[:n]
    if max_ot is not None:
        for level in levels[1:]:
            if level > max_ot or len(selected) >= n:
                break
            ranked = sorted(entry['OT'][level], reverse=True)
            selected.extend(ranked[:n - len(selected)])
    return selected


def _write_pairs(outfilename, region_names, selected, add_g):
    """Write every cross-region pairing of the selected guides as TSV.

    *region_names* must be in the desired output order; each unordered pair
    of distinct regions is visited once.  If *add_g* is true a 'G' is
    prepended to both sgRNA sequences.
    """
    header = '#ID\tRegion1_chr\tRegion1_start\tRegion1_end\tRegion2_chr\tRegion2_start\tRegion2_end\tsgRNA_1\tsgRNA_2'
    header = header + '\tcutting_efficiency_score_sgRNA_1\tcutting_efficiency_score_sgRNA_2\tcutting_specificity_score_sgRNA_1\tcutting_specificity_score_sgRNA_2\tstrand_sgRNA_1\tstrand_sgRNA_2'
    header = header + '\tofftargets_sum_sgRNA_1\tofftargets_sum_sgRNA_2\tofftargets_summary_sgRNA_1\tofftargets_summary_sgRNA_2'

    prefix = 'G' if add_g else ''
    pair_id = 1
    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        for i, r1 in enumerate(region_names):
            # "chr:start-end" -> "chr<TAB>start<TAB>end"
            r1_columns = r1.replace(':', '\t').replace('-', '\t')
            for r2 in region_names[i + 1:]:
                r2_columns = r2.replace(':', '\t').replace('-', '\t')
                for (spec1, _chr1, _left1, _right1, gRNA1, eff1, strand1,
                     ot_sum1, ot_summary1) in selected[r1]:
                    for (spec2, _chr2, _left2, _right2, gRNA2, eff2, strand2,
                         ot_sum2, ot_summary2) in selected[r2]:
                        row = [
                            'pgRNA_' + str(pair_id),
                            r1_columns,
                            r2_columns,
                            prefix + gRNA1,
                            prefix + gRNA2,
                            str(eff1),
                            str(eff2),
                            str(spec1),
                            str(spec2),
                            strand1,
                            strand2,
                            ot_sum1,
                            ot_sum2,
                            ot_summary1,
                            ot_summary2,
                        ]
                        outfile.write('\t'.join(row) + '\n')
                        pair_id += 1


def run():
    """Command-line entry point: pick guides per region and write sgRNA pairs.

    Reads ``sys.argv``:
        1  guidescan CSV file
        2  N, the number of guides to pick per region
        3  'CFD' (rank by specificity score) or 'OT' (fewest off-targets)
        4  output file name
        optional '-addG'        prepend a G to every sgRNA in the output
        optional '-relaxed OT'  in OT mode, allow guides with up to OT
                                off-targets when the best set has fewer than N

    Prints one 'picked ...' line per region and exits with status 1 after
    printing the usage text when too few arguments are given.
    """
    # Four positional arguments are required (argv[1] .. argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s guidescan.csv N_guides_per_region CFD|OT outfile [-addG] [-relaxed OT]' % sys.argv[0])
        print('\tIf the OT option is picked, the script will pick guides only among the set of guides with the fewest predicted off-target sites for each region')
        print('\tIf that set is smaller than N, and you still want N guides, use the [-relaxed OT] option, in which case guides will be picked from sets with up to the specified OT number of off-targets')
        print('\tIf the CFD option is picked, the script will rank guides by the CFD')
        print('\tAssumed guidescan.csv format: chr11:5280592-5281000__chr11:5280603-5280625:+,52,0.67491047,2;2:0|3:2,CTTTATGATGCCGTTTGAGG')
        sys.exit(1)

    GS = sys.argv[1]
    N = int(sys.argv[2])
    CFDOT = sys.argv[3]
    outfilename = sys.argv[4]

    if CFDOT not in ('CFD', 'OT'):
        sys.exit('third argument must be CFD or OT, got %r' % CFDOT)

    doAddG = '-addG' in sys.argv

    # None means "strict": only the lowest off-target set is used in OT mode.
    RelOT = None
    if '-relaxed' in sys.argv:
        try:
            RelOT = int(sys.argv[sys.argv.index('-relaxed') + 1])
        except (IndexError, ValueError):
            sys.exit('-relaxed requires an integer number of off-targets')

    RegionDict = _load_regions(GS)
    regions = sorted(RegionDict)

    selected = {}
    for region in regions:
        selected[region] = _select_guides(RegionDict[region], CFDOT, N, RelOT)
        print('picked %d guides %s' % (len(selected[region]), region))

    _write_pairs(outfilename, regions, selected, doAddG)

# Run the command-line tool only when this file is executed as a script, so
# that importing the module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()
