##################################
#                                #
# Last modified 2017/06/01       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
from sets import Set
import string

def _read_config(config):
    """Parse the config file into a list of dataset descriptions.

    Lines starting with '#' and blank lines are skipped.  Every other line
    is split on tabs and must hold at least five fields:
    label, file, chrFieldID, leftFieldID, rightFieldID.

    Returns a list of (label, path, chrID, leftID, rightID) tuples in file
    order, with the three field IDs converted to int.
    """
    datasets = []
    with open(config) as handle:
        for line in handle:
            if line.startswith('#') or line.strip() == '':
                continue
            fields = line.strip().split('\t')
            datasets.append((fields[0], fields[1],
                             int(fields[2]), int(fields[3]), int(fields[4])))
    return datasets


def _read_regions(regions, chrFieldID, leftFieldID, rightFieldID,
                  header_suffix, outfile):
    """Load the regions file and prepare the per-base coverage table.

    Lines starting with '#' are treated as headers: they are written to
    outfile immediately with header_suffix appended.  Blank lines are
    skipped.  Every other line is split on tabs and its chromosome, left
    and right fields are read from the given 0-based column indexes.

    Returns (rows, coverage):
        rows     -- list of (stripped line, chromosome, left, right)
        coverage -- dict: chromosome -> position -> set of labels (empty
                    for now), holding every position of the half-open
                    interval [left, right) of every region.
    """
    rows = []
    coverage = {}
    with open(regions) as handle:
        for line in handle:
            text = line.strip()
            if line.startswith('#'):
                outfile.write(text + header_suffix + '\n')
                continue
            if text == '':
                continue
            fields = text.split('\t')
            chrom = fields[chrFieldID]
            left = int(fields[leftFieldID])
            right = int(fields[rightFieldID])
            # The chromosome is registered even when the interval is empty,
            # so the lookup made while writing the results cannot fail.
            positions = coverage.setdefault(chrom, {})
            for pos in range(left, right):
                positions[pos] = set()
            rows.append((text, chrom, left, right))
    return rows, coverage


def _iter_data_lines(path):
    """Yield the lines of a data file, decompressing it on the fly if needed.

    Files ending in .bz2, .gz or .zip are streamed through the external
    bzip2, gunzip or unzip program respectively; any other file is read
    directly.  A warning is written to stderr when the external program
    exits with a non-zero status.
    """
    import subprocess  # local import: only this helper needs it

    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        cmd = None

    if cmd is None:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters from being split or interpreted.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        sys.stderr.write('warning: "%s" exited with status %s\n'
                         % (' '.join(cmd), status))


def _mark_coverage(coverage, label, path, chrID, leftID, rightID):
    """Add label to every tracked position overlapped by an interval in path.

    Blank lines and lines starting with '#' (after stripping) are skipped,
    as are intervals on chromosomes that have no regions.  Positions that
    are not part of any region are ignored.
    """
    for raw in _iter_data_lines(path):
        line = raw.strip()
        # A blank line is skipped rather than treated as end of input.
        if line == '' or line.startswith('#'):
            continue
        fields = line.split('\t')
        positions = coverage.get(fields[chrID])
        if positions is None:
            continue
        for pos in range(int(fields[leftID]), int(fields[rightID])):
            hits = positions.get(pos)
            if hits is not None:
                hits.add(label)


def run():
    """Annotate each region with per-label presence/absence of overlaps.

    Command line (read from sys.argv):
        regionFile chrFieldID leftFieldID rightFieldID config outfilename

    The config file lists one data file per line as tab-separated
    label, file, chrFieldID, leftFieldID, rightFieldID; several files may
    share a label.  All field IDs are 0-based column indexes and intervals
    are half-open: [left, right).

    The output contains every '#' header line of the regions file followed
    by every region line, each extended with one tab-separated column per
    distinct label (sorted).  A region's column holds 1 if at least one
    base of the region is covered by an interval from any file carrying
    that label, and 0 otherwise.

    Prints the usage text and exits with status 1 when arguments are
    missing.
    """
    # print(...) is always given a single pre-formatted string so that the
    # output is the same whether print is a statement or a function.

    # Six arguments follow the script name, so sys.argv needs seven entries.
    if len(sys.argv) < 7:
        print('usage: python %s regionFile chrFieldID leftFieldID rightFieldID config outfilename' % sys.argv[0])
        print('\tconfig file format: label\tfile\tchrFieldID\tleftFieldID\trightFieldID')
        print('\tindividual files can be zipped')
        sys.exit(1)

    regions = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    leftFieldID = int(sys.argv[3])
    rightFieldID = int(sys.argv[4])
    config = sys.argv[5]
    outfilename = sys.argv[6]

    datasets = _read_config(config)
    labels = sorted(set([entry[0] for entry in datasets]))
    header_suffix = ''.join(['\t' + label for label in labels])

    print('parsing regions file')

    with open(outfilename, 'w') as outfile:
        rows, coverage = _read_regions(regions, chrFieldID, leftFieldID,
                                       rightFieldID, header_suffix, outfile)

        print('finished parsing regions file')

        for k, (label, path, chrID, leftID, rightID) in enumerate(datasets, 1):
            print('%s %s %s' % (k, label, path))
            _mark_coverage(coverage, label, path, chrID, leftID, rightID)

        print('finished parsing data')

        for text, chrom, left, right in rows:
            # Collect every label seen on any base of the region.
            found = set()
            positions = coverage[chrom]
            for pos in range(left, right):
                found.update(positions[pos])
            flags = ['1' if label in found else '0' for label in labels]
            outfile.write('\t'.join([text] + flags) + '\n')

# Run the tool only when this file is executed as a script, so that
# importing it as a module does not trigger command-line parsing.
if __name__ == '__main__':
    run()

