##################################
#                                #
# Last modified 2018/04/07       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import random
from sets import Set
import os

def _iter_lines(path):
    """Yield the raw lines of *path*, decompressing on the fly when needed.

    Files ending in ``.bz2``, ``.gz`` or ``.zip`` are streamed through the
    external ``bzip2``, ``gunzip`` or ``unzip`` tools; any other file is
    opened directly.  The external command is passed as an argument list
    (no shell), so file names containing spaces or shell metacharacters are
    handled safely, and the pipe is always closed and reaped.
    """
    # Local import: the file-level import block does not provide subprocess.
    import subprocess

    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        cmd = None

    if cmd is None:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        # Only a warning: some tools use a non-zero status for mere warnings.
        sys.stderr.write('warning: "%s" exited with status %d\n'
                         % (' '.join(cmd), status))


def _data_lines(path, skip_comments):
    """Yield stripped, non-empty lines of *path*.

    Blank lines are skipped rather than being treated as end of file.
    Lines starting with ``#`` are dropped when *skip_comments* is true.
    """
    for raw in _iter_lines(path):
        line = raw.strip()
        if not line:
            continue
        if skip_comments and line.startswith('#'):
            continue
        yield line


def _load_peaks(path, peakChrFieldID, peakFieldID):
    """Return a dict mapping ``chr:left-right`` to a peak position.

    *peakFieldID* selects how the position is obtained:

    * ``'narrowPeak'`` - columns 0, 1, 2 hold chr/left/right and column 9
      holds an offset that is added to ``left``; *peakChrFieldID* is ignored.
    * ``'middle'`` - the integer midpoint of left and right.
    * anything else - parsed as the (0-based) column holding the position.

    Outside narrowPeak mode, chr/left/right are read from column
    *peakChrFieldID* and the two columns after it.
    """
    narrow = (peakFieldID == 'narrowPeak')
    middle = (peakFieldID == 'middle')
    if not narrow:
        # Command-line values arrive as strings; convert once, up front.
        chr_col = int(peakChrFieldID)
        if not middle:
            pos_col = int(peakFieldID)

    peak_by_region = {}
    for line in _data_lines(path, True):
        fields = line.split('\t')
        if narrow:
            chrom = fields[0]
            left = int(fields[1])
            right = int(fields[2])
            peak = left + int(fields[9])
        else:
            chrom = fields[chr_col]
            left = int(fields[chr_col + 1])
            right = int(fields[chr_col + 2])
            if middle:
                peak = (right + left) // 2
            else:
                peak = int(fields[pos_col])
        peak_by_region['%s:%d-%d' % (chrom, left, right)] = peak
    return peak_by_region


def _load_motifs(path, chr_col, strand_col, id_col):
    """Return ``{chr: {position: (motif_id, strand)}}`` for every motif base.

    chr/left/right are read from column *chr_col* and the two columns after
    it; every position in ``range(left, right)`` is recorded.  When motifs
    overlap, the one read last wins for the shared positions.
    """
    motifs_by_chrom = {}
    for line in _data_lines(path, True):
        fields = line.split('\t')
        chrom = fields[chr_col]
        left = int(fields[chr_col + 1])
        right = int(fields[chr_col + 2])
        hit = (fields[id_col], fields[strand_col])
        positions = motifs_by_chrom.setdefault(chrom, {})
        for pos in range(left, right):
            positions[pos] = hit
    return motifs_by_chrom


def _load_guides(path, maxOT, peak_by_region, motifs_by_chrom):
    """Parse the guide file and return ``{region: [guide tuples]}``.

    The file is read as repeated blocks: a comma-free line naming the region,
    a header line starting with ``chromosome,target site``, then one
    comma-separated guide per line.  Guides whose off-target sum exceeds
    *maxOT* are dropped.  When *motifs_by_chrom* is ``None`` every remaining
    guide is kept with motif fields ``'n/a'``; otherwise a guide is kept only
    if some position in ``range(left, right)`` is covered by a motif, and the
    motif at the first such position is reported.
    """
    guides_by_region = {}
    region = None
    for line in _data_lines(path, False):
        if line.startswith('chromosome,target site'):
            if region is None:
                print('guide table header found before any region line, exiting')
                sys.exit(1)
            if region in guides_by_region:
                print('duplicate regions detected, exiting:')
                print(region)
                sys.exit(1)
            guides_by_region[region] = []
            continue
        if ',' not in line:
            region = line
            continue

        fields = line.split(',')
        chrom = fields[0]
        left = int(fields[1])
        right = int(fields[2])
        gRNA = fields[3]
        cutting_efficiency_score = fields[4]
        cutting_specificity_score = fields[5]
        strand = fields[6]
        offtargets_sum = int(fields[7])
        offtargets_summary = fields[8]
        if offtargets_sum > maxOT:
            continue

        if motifs_by_chrom is None:
            motif, motifstrand = 'n/a', 'n/a'
        else:
            positions = motifs_by_chrom.get(chrom, {})
            hit = next((positions[pos] for pos in range(left, right)
                        if pos in positions), None)
            if hit is None:
                continue
            motif, motifstrand = hit

        if region not in guides_by_region:
            print('guide found outside of a region block, exiting:')
            print(line)
            sys.exit(1)
        if region not in peak_by_region:
            print('region not found in peaks file, exiting:')
            print(region)
            sys.exit(1)

        guides_by_region[region].append(
            (chrom, left, right, gRNA, cutting_efficiency_score,
             cutting_specificity_score, strand, offtargets_sum,
             offtargets_summary, motif, motifstrand, peak_by_region[region]))
    return guides_by_region


def _write_guides(outfilename, guides_by_region):
    """Write the guide table to *outfilename* as tab-separated text.

    Regions are written in sorted order.  Within a region, exact duplicate
    guides are written once, in order of first appearance, so the sgRNA
    numbering is reproducible between runs.
    """
    header = '\t'.join([
        '#sgRNA_ID', 'sgRNA', 'chr', 'left', 'right', 'strand', 'motif',
        'motifStrand', 'ChIP-seq_peak_pos', 'cutting_efficiency_score',
        'cutting_specificity_score', 'offtargets_sum', 'offtargets_summary'])

    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        sgRNA = 1
        for region in sorted(guides_by_region):
            seen = set()
            for guide in guides_by_region[region]:
                if guide in seen:
                    continue
                seen.add(guide)
                (chrom, left, right, gRNA, cutting_efficiency_score,
                 cutting_specificity_score, strand, offtargets_sum,
                 offtargets_summary, motif, motifstrand, peak) = guide
                row = ['sgRNA_' + str(sgRNA), gRNA, chrom, str(left),
                       str(right), strand, motif, motifstrand, str(peak),
                       str(cutting_efficiency_score),
                       str(cutting_specificity_score), str(offtargets_sum),
                       str(offtargets_summary)]
                outfile.write('\t'.join(row) + '\n')
                sgRNA += 1


def run():
    """Filter guides by off-target count and motif overlap; write a table.

    Reads its arguments from ``sys.argv`` (see the usage message below):
    a guide CSV, a motif file with its column indices, a peaks file with its
    column selectors, the maximum allowed off-target sum and the output file
    name.  The optional ``-ignoreMotifs`` flag skips loading the motif file
    and keeps guides regardless of motif overlap.

    Exits with status 1 when too few arguments are given or when the guide
    file is inconsistent (duplicate region, guide outside a region block,
    region missing from the peaks file).
    """
    # Ten positional arguments are required, so argv needs 11 entries.
    if len(sys.argv) < 11:
        print('usage: python %s guidescan.csv motifs motifsChrFieldID motifsStrandFieldID motifsIDFieldID  peaks peakChrFieldID peakPosFieldID|middle|narrowPeak maxOffTarget outfile [-ignoreMotifs]' % sys.argv[0])
        print('\tNote: all input files can be compressed')
        sys.exit(1)

    GS = sys.argv[1]
    motifs = sys.argv[2]
    motChrFieldID = int(sys.argv[3])
    motStrandFieldID = int(sys.argv[4])
    motIDFieldID = int(sys.argv[5])
    peaks = sys.argv[6]
    # Kept as strings here: peakChrFieldID is unused (and may be a
    # placeholder) in narrowPeak mode, so conversion happens in _load_peaks.
    peakChrFieldID = sys.argv[7]
    peakFieldID = sys.argv[8]
    maxOT = int(sys.argv[9])
    outfilename = sys.argv[10]

    doIgnoreMotifs = '-ignoreMotifs' in sys.argv

    peakDict = _load_peaks(peaks, peakChrFieldID, peakFieldID)
    print('finished inputting peaks')

    if doIgnoreMotifs:
        MotifDict = None
    else:
        MotifDict = _load_motifs(motifs, motChrFieldID, motStrandFieldID,
                                 motIDFieldID)
        print('finished inputting motifs')

    RegionDict = _load_guides(GS, maxOT, peakDict, MotifDict)
    print('finished inputting guides')

    _write_guides(outfilename, RegionDict)

# Run the pipeline only when executed as a script, so that importing this
# module does not parse sys.argv or exit the importing process.
if __name__ == '__main__':
    run()
