##################################
#                                #
# Last modified 2018/04/30       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import random
from sets import Set
import time

def _adapter(value):
    """Return an adapter sequence from the command line, mapping 'nan' to no adapter."""
    if value == 'nan':
        return ''
    return value


def _read_safes(path, IDfieldID, seqfieldID, minSafeLen):
    """Return (ID, sequence) tuples for the safe guides listed in a TSV file.

    Lines starting with '#' and blank lines are skipped, as are safes whose
    sequence is shorter than minSafeLen.
    """
    SafesList = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or line.strip() == '':
                continue
            fields = line.strip().split('\t')
            ID = fields[IDfieldID]
            seq = fields[seqfieldID]
            if len(seq) < minSafeLen:
                continue
            SafesList.append((ID, seq))
    return SafesList


def _parse_target_pair(line):
    """Return ((GID1, seq1), (GID2, seq2)) for one row of the guide-pair TSV.

    Column indices follow the 'assumed input format' printed in the usage
    message; each GID is 'chr:start-end|strand|offtargets_summary'.
    """
    fields = line.strip().split('\t')
    GID1 = fields[1] + ':' + fields[2] + '-' + fields[3] + '|' + fields[13] + '|' + fields[17]
    GID2 = fields[4] + ':' + fields[5] + '-' + fields[6] + '|' + fields[14] + '|' + fields[18]
    return (GID1, fields[7]), (GID2, fields[8])


def _format_pair(IDA, seqA, IDB, seqB, AdL, AdM, AdR, doAddG):
    """Return one output line 'IDA__IDB,<oligo>' for the ordered guide pair (A, B).

    The oligo is AdL + seqA + AdM + seqB + AdR, with a 'G' prepended to each
    guide sequence when doAddG is set.
    """
    G = 'G' if doAddG else ''
    return IDA + '__' + IDB + ',' + AdL + G + seqA + AdM + G + seqB + AdR


def _candidate_safe_pairs(SafesList, N1):
    """Return a random pool of ordered (IDSafe1, seqSafe1, IDSafe2, seqSafe2) safe pairs.

    Each safe is paired with k safes drawn without replacement; pairs whose two
    sequences are identical are dropped. k = ceil(N1 / len(SafesList)) + 1,
    capped at len(SafesList), so that when safe sequences are distinct the pool
    holds at least N1 pairs (or every possible ordered pair, if fewer exist).
    """
    SafePairsList = []
    n = len(SafesList)
    if n == 0:
        return SafePairsList
    # The previous min(2, ...) capped the pool at 2 * n pairs, which made
    # random.sample(pool, N1) fail for N1 >= 2 * n (and sometimes for smaller N1
    # when a safe drew itself).
    k = min(n, (N1 + n - 1) // n + 1)
    for (IDSafe1, seqSafe1) in SafesList:
        for (IDSafe2, seqSafe2) in random.sample(SafesList, k):
            if seqSafe2 != seqSafe1:
                SafePairsList.append((IDSafe1, seqSafe1, IDSafe2, seqSafe2))
    return SafePairsList


def run():
    """Build a paired-guide oligo library from the command-line arguments.

    Writes three sections to outfile, one 'ID1__ID2,<oligo>' line per oligo:
      1. every targeting guide pair in input.tsv, in file order;
      2. each distinct targeting guide paired with N2 randomly chosen safes;
      3. N1 randomly chosen ordered pairs of safes with different sequences.
    Flags: -addG prepends 'G' to every guide sequence, -randomizeOrder swaps
    the two guides of each oligo with probability 0.5, and -minSafeLen bp
    drops safes shorter than bp. Exits with status 1 (after printing usage)
    when fewer than ten positional arguments are given.
    """
    # Ten positional arguments are required: sys.argv[1] .. sys.argv[10].
    if len(sys.argv) < 11:
        print('usage: python %s input.tsv safes.tsv IDFieldID sequenceFieldID N1(safe_pairs) N2(safes_per_guide) left_adapter_sequence middle_sequence right_adapter_sequence outfile [-addG] [-randomizeOrder] [-minSafeLen bp]' % sys.argv[0])
        print('assumed input format:')
        print('\t\t#ID\tRegion1_chr\tRegion1_start\tRegion1_end\tRegion2_chr\tRegion2_start\tRegion2_end\tsgRNA_1\tsgRNA_2\tcutting_efficiency_score_sgRNA_1\tcutting_efficiency_score_sgRNA_2\tcutting_specificity_score_sgRNA_1\tcutting_specificity_score_sgRNA_2\tstrand_sgRNA_1\tstrand_sgRNA_2\tofftargets_sum_sgRNA_1\tofftargets_sum_sgRNA_2\tofftargets_summary_sgRNA_1\tofftargets_summary_sgRNA_2')
        print('\tBy default the script will pair a random N2 safes to each guide plus a random set of N1 pairs between safes')
        print('\tUse nan if you do not want adapters')
        sys.exit(1)

    inputfilename = sys.argv[1]
    safesfilename = sys.argv[2]
    IDfieldID = int(sys.argv[3])
    seqfieldID = int(sys.argv[4])
    N1 = int(sys.argv[5])
    N2 = int(sys.argv[6])
    AdL = _adapter(sys.argv[7])
    AdM = _adapter(sys.argv[8])
    AdR = _adapter(sys.argv[9])
    outfilename = sys.argv[10]

    doAddG = '-addG' in sys.argv

    doRO = '-randomizeOrder' in sys.argv
    if doRO:
        print('will randomize guide order')

    minSafeLen = 0
    if '-minSafeLen' in sys.argv:
        minSafeLen = int(sys.argv[sys.argv.index('-minSafeLen') + 1])
        print('will only include safes that are at least %d bp long' % minSafeLen)

    SafesList = _read_safes(safesfilename, IDfieldID, seqfieldID, minSafeLen)

    def pair_line(IDA, seqA, IDB, seqB):
        # Bind the adapter/-addG settings once for all three sections.
        return _format_pair(IDA, seqA, IDB, seqB, AdL, AdM, AdR, doAddG)

    TargetGuides = set()
    with open(outfilename, 'w') as outfile:
        # Section 1: targeting guide pairs.
        with open(inputfilename) as infile:
            for line in infile:
                # Blank lines (e.g. a trailing newline) are skipped like comments;
                # they would otherwise fail on the column lookups.
                if line.startswith('#') or line.strip() == '':
                    continue
                (GID1, seq1), (GID2, seq2) = _parse_target_pair(line)
                # random.random() is only drawn when -randomizeOrder is set.
                if doRO and random.random() >= 0.5:
                    outline = pair_line(GID2, seq2, GID1, seq1)
                else:
                    outline = pair_line(GID1, seq1, GID2, seq2)
                outfile.write(outline + '\n')
                TargetGuides.add((GID1, seq1))
                TargetGuides.add((GID2, seq2))

        print('finished outputting targetting guide pairs')

        # Section 2: each distinct targeting guide with N2 random safes.
        for (GID, seq) in sorted(TargetGuides):
            for (IDSafe, seqSafe) in random.sample(SafesList, N2):
                if doRO and random.random() < 0.5:
                    outline = pair_line(IDSafe, seqSafe, GID, seq)
                else:
                    outline = pair_line(GID, seq, IDSafe, seqSafe)
                # Written per safe so all N2 pairs of this guide are kept.
                outfile.write(outline + '\n')

        print('finished outputting targetting guide and safe pairs')

        # Section 3: N1 random safe-safe pairs.
        SafePairs = random.sample(_candidate_safe_pairs(SafesList, N1), N1)
        for (IDSafe1, seqSafe1, IDSafe2, seqSafe2) in SafePairs:
            if doRO and random.random() < 0.5:
                outline = pair_line(IDSafe2, seqSafe2, IDSafe1, seqSafe1)
            else:
                outline = pair_line(IDSafe1, seqSafe1, IDSafe2, seqSafe2)
            outfile.write(outline + '\n')

        print('finished outputting safe pairs')

# Only run the library builder when executed as a script, not on import.
if __name__ == '__main__':
    run()
