##################################
#                                #
# Last modified 2019/02/26       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
import pyBigWig
from sets import Set

def _iter_lines(filename):
    """Yield the text lines of *filename*, transparently decompressing it.

    ``.bz2`` files are streamed through ``bzip2 -cd`` and ``.gz`` files through
    ``zcat`` (the same tools the script has always used); any other file is
    opened directly.  The tools are invoked with an argument list instead of a
    shell command string, so file names containing spaces or shell
    metacharacters are passed through verbatim.
    """
    import subprocess

    if filename.endswith('.bz2'):
        cmd = ['bzip2', '-cd', filename]
    elif filename.endswith('.gz'):
        cmd = ['zcat', filename]
    else:
        with open(filename) as handle:
            for line in handle:
                yield line
        return
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        proc.wait()


def _parse_guides(filename, offset):
    """Return ``{chrom: {cut_position: (CFD, strand)}}`` built from *filename*.

    Each non-comment, non-blank line is a comma-separated record such as
    ``chr19:40398966-40398986__chr19:40398974-40398996:-,48,0.17726806,...``:

    * the chromosome is the text before the first ``:`` of field 1;
    * the second-to-last ``:``-separated piece of field 1 is ``left-right``;
    * the last character of field 1 is the guide strand (``+`` or ``-``);
    * field 3 is the CFD value, kept as the original string.

    The cut position is ``right - 6 + offset`` for ``+`` guides and
    ``left + 6 - offset`` for ``-`` guides, so a positive *offset* moves the
    cut toward the guide's own 3' end.  (The 6 bp is presumably the 3 bp PAM
    plus the 3 bp Cas9 cut offset -- NOTE(review): confirm.)  Lines whose
    strand character is neither ``+`` nor ``-`` are skipped.  If two guides
    share a cut position, the later line wins.
    """
    guide_dict = {}
    for line in _iter_lines(filename):
        if line.startswith('#'):
            continue
        line = line.strip()
        if not line:
            continue
        fields = line.split(',')
        locus = fields[0]
        chrom = locus.split(':')[0]
        left_text, right_text = locus.split(':')[-2].split('-')[:2]
        left = int(left_text)
        right = int(right_text)
        strand = locus[-1]
        if strand == '+':
            cut = right - 3 - 3 + offset
        elif strand == '-':
            cut = left + 3 + 3 - offset
        else:
            # Previously this silently reused the prior line's cut site.
            continue
        guide_dict.setdefault(chrom, {})[cut] = (fields[2], strand)
    return guide_dict


def _guide_row(label, chrom_guides, positions, keep_strand=None):
    """Return one tab-separated output row (without newline).

    The row is *label* followed by one cell per genomic position in
    *positions*.  A cell holds the CFD string of the guide cutting at that
    position when there is one and, if *keep_strand* is given,
    ``keep_strand(guide_strand)`` is true.  Otherwise the cell is empty.
    """
    cells = [label]
    for position in positions:
        hit = chrom_guides.get(position)
        if hit is not None and (keep_strand is None or keep_strand(hit[1])):
            cells.append(hit[0])
        else:
            cells.append('')
    return '\t'.join(cells)


def run():
    """Write, for each region, the CFD scores of guides cutting near it.

    Command line (positional arguments first, flags after them)::

        inputfilename chrFieldID posField [strandField | -noStrand]
        upstream downstream guides outputfilename
        [-narrowPeak] [-stranded] [-offset bp]

    * ``chrFieldID`` / ``posField`` / ``strandField`` are 0-based column
      indices into the tab-separated region file (``.gz``/``.bz2`` allowed).
      ``-noStrand`` treats every region as ``+``.  Strands ``F``/``R`` are
      treated like ``+``/``-``.
    * ``-narrowPeak`` takes the position as column 2 + column 10 (peak start
      plus summit offset) instead of ``posField``.
    * ``guides`` is parsed by ``_parse_guides``.  ``-offset bp`` shifts every
      cut site by that many bp.
    * The output begins with a ``#`` header of relative offsets
      ``-upstream .. +downstream``.  Each region then gets one row labelled
      ``chr:pos:strand`` with one CFD cell per offset (empty where no guide
      cuts).  With ``-stranded`` it gets two rows instead, suffixed
      ``:same_strand_guides`` and ``:opp_strand_guides``.  For minus-strand
      regions the offsets are mirrored around ``pos``, so negative offsets
      are upstream on that strand.

    Usage and progress messages go to stdout.  The function exits with
    status 1 when fewer than eight arguments are given.
    """
    argv = sys.argv
    if len(argv) < 9:
        sys.stdout.write('usage: python %s inputfilename chrFieldID posField [strandField | -noStrand] upstream downstream guides outputfilename [-narrowPeak] [-stranded] [-offset bp]\n' % argv[0])
        sys.stdout.write('\tguides input format:\n')
        sys.stdout.write('\tchr19:40398966-40398986__chr19:40398974-40398996:-,48,0.17726806,15;2:1|3:14,GCGGGGATAAGGTCATGGGG\n')
        sys.exit(1)

    regionfilename = argv[1]
    chrFieldID = int(argv[2])
    posFieldID = int(argv[3])
    noStrand = argv[4] == '-noStrand'
    strandFieldID = None if noStrand else int(argv[4])
    upstream = int(argv[5])
    downstream = int(argv[6])
    guides = argv[7]
    outfilename = argv[8]

    doStranded = '-stranded' in argv
    doNP = '-narrowPeak' in argv

    OS = 0
    if '-offset' in argv:
        OS = int(argv[argv.index('-offset') + 1])
        sys.stdout.write('will shift cut sites regions by %d bp\n' % OS)

    GuideDict = _parse_guides(guides, OS)
    window = list(range(-upstream, downstream + 1))

    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(['#'] + [str(r) for r in window]) + '\n')
        LC = 0
        for line in _iter_lines(regionfilename):
            LC += 1
            if LC % 10000 == 0:
                sys.stdout.write('%d lines processed\n' % LC)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 3:
                continue
            chrom = fields[chrFieldID]
            if doNP:
                pos = int(fields[1]) + int(fields[9])
            else:
                pos = int(fields[posFieldID])
            strand = '+' if noStrand else fields[strandFieldID]

            if strand in ('+', 'F'):
                orientation = '+'
                positions = [pos + r for r in window]
            elif strand in ('-', 'R'):
                # Mirror around pos: relative offset r is genomic pos - r.
                orientation = '-'
                positions = [pos - r for r in window]
            else:
                # Unrecognised strand: emit just the label, as before.
                orientation = None
                positions = []

            chrom_guides = GuideDict.get(chrom, {})
            label = chrom + ':' + str(pos) + ':' + strand
            if doStranded:
                outfile.write(_guide_row(label + ':same_strand_guides',
                                         chrom_guides, positions,
                                         lambda s: s == orientation) + '\n')
                outfile.write(_guide_row(label + ':opp_strand_guides',
                                         chrom_guides, positions,
                                         lambda s: s != orientation) + '\n')
            else:
                outfile.write(_guide_row(label, chrom_guides, positions) + '\n')
   
# Run only when executed as a script, so the module can be imported safely.
if __name__ == '__main__':
    run()
