##################################
#                                #
# Last modified 7/24/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _filter_row(fields, lenfields, firstsitefield, upstream, downstream):
    """Build the output line for one tab-split data row.

    The first ``firstsitefield`` fields are copied through unchanged.  The
    remaining fields are read in groups of three: a count field, a
    comma-separated list the code treats as distances, and a parallel
    comma-separated list the code treats as peaks.  Only entries whose
    integer distance lies in ``[upstream, downstream]`` are kept, and the
    count is recomputed from the kept entries.

    Returns the assembled line, terminated by a newline.
    Raises ValueError if a distance entry is not an integer, and IndexError
    if the peaks column of a non-zero site is missing.
    """
    parts = [fields[j] + '\t' for j in range(0, firstsitefield)]

    # The historical upper bound is lenfields + 1; clamp it to the row length
    # so a row with exactly as many fields as the header does not index past
    # its end.  Rows that carry an extra field are processed as before.
    stop = min(lenfields + 1, len(fields))

    zero = False
    for j in range(firstsitefield, stop):
        position = (j - firstsitefield) % 3
        if position == 0:
            # Count column: a literal '0' emits an empty triplet right away;
            # any other value defers output to the distances column.
            zero = fields[j] == '0'
            if zero:
                parts.append('0\t\t\t')
        elif position == 1 and not zero:
            distances = fields[j].split(',')
            peaks = fields[j + 1].split(',')
            # The last split element is always ignored (presumably the lists
            # are comma-terminated, leaving an empty trailing element).
            kept = [y for y in range(len(distances) - 1)
                    if upstream <= int(distances[y]) <= downstream]
            parts.append(str(len(kept)) + '\t')
            parts.append(''.join([distances[y] + ',' for y in kept]) + '\t')
            parts.append(''.join([peaks[y] + ',' for y in kept]) + '\t')
        # position == 2 (peaks column) is consumed together with position 1.

    parts.append('\n')
    return ''.join(parts)


def run():
    """Filter the per-site entries of a tab-delimited matrix by distance.

    Command-line arguments (read from ``sys.argv``):
      1. input matrix file
      2. minimum upstream bp (negated internally to form the lower bound)
      3. maximum downstream bp (upper bound)
      4. zero-based index of the first site field
      5. output matrix file

    The first input line is copied to the output verbatim.  Lines starting
    with '#' are skipped; lines with fewer fields than the first line are
    reported on stdout and skipped.  Every other line is rewritten so that
    each site keeps only the entries inside the distance window.

    Exits with status 1 after printing usage if arguments are missing.
    """
    # Five arguments plus the program name are required (argv[5] is read).
    if len(sys.argv) < 6:
        print('usage: python %s inputmatrixfile minupstreambp maxdownstrembp firstSiteField outputmatrixfilename ' % sys.argv[0])
        sys.exit(1)

    inputfilename = sys.argv[1]
    upstream = -int(sys.argv[2])
    downstream = int(sys.argv[3])
    firstsitefield = int(sys.argv[4])
    outfilename = sys.argv[5]

    # Read the input fully before touching the output file, and make sure the
    # handle is released even if reading fails.
    infile = open(inputfilename)
    try:
        lineslist = infile.readlines()
    finally:
        infile.close()

    outfile = open(outfilename, 'w')
    try:
        if not lineslist:
            # Empty input: leave an empty output file instead of crashing.
            return

        outfile.write(lineslist[0])
        # The header is split without stripping, so its field count is the
        # reference every data row is compared against.
        lenfields = len(lineslist[0].split('\t'))

        for index, line in enumerate(lineslist):
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < lenfields:
                print('warining: data in row %s truncated; skipping row' % (index + 1))
                print(line)
                continue
            outfile.write(_filter_row(fields, lenfields, firstsitefield,
                                      upstream, downstream))
    finally:
        outfile.close()

# Execute the conversion unconditionally at module load; there is no
# __main__ guard here, so importing this file also runs it against sys.argv.
run()
