##################################
#                                #
# Last modified 5/6/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _count_targets(lines, conserved_only, disregard_6mers, only_8mers):
    """Tally target sites per (gene, miRNA) pair in a single pass.

    ``lines`` are the raw lines of the predictions file.  The first line is a
    tab-separated header whose cell at column ``i`` names the miRNA for gene
    column ``i`` of the data rows.  Each data row is read in groups of three
    cells starting at columns 0, 3, 6, ...: gene name, site type and
    conservation flag (an integer, parsed only when ``conserved_only``).

    A site is skipped when its gene cell is empty, when its site type is
    excluded by ``disregard_6mers`` ('6mer') or ``only_8mers`` ('6mer',
    'M8 7mer', 'A1 7mer'), or when ``conserved_only`` is set and its
    conservation flag is 0.

    Returns ``(column_miRNAs, counts)``: the header cells that can label a
    gene column (all but the last two) and a dict mapping
    ``(gene, miRNA)`` to the number of accepted sites.
    """
    header = lines[0].rstrip('\r\n').split('\t')
    # Data-row triples stop two cells before the end of the row; the header
    # is trimmed the same way so column indices line up.
    column_miRNAs = header[:max(len(header) - 2, 0)]

    excluded_types = set()
    if disregard_6mers:
        excluded_types.add('6mer')
    if only_8mers:
        excluded_types.update(['6mer', 'M8 7mer', 'A1 7mer'])

    counts = {}
    for line in lines[1:]:
        fields = line.rstrip('\r\n').split('\t')
        for i in range(0, len(fields) - 2, 3):
            gene = fields[i]
            if gene == '' or fields[i + 1] in excluded_types:
                continue
            # Any non-zero flag counts as conserved; the same test decides
            # both whether a gene is listed and whether the site is counted,
            # so no gene ends up as an all-zero row.
            if conserved_only and int(fields[i + 2]) == 0:
                continue
            key = (gene, column_miRNAs[i])
            counts[key] = counts.get(key, 0) + 1
    return column_miRNAs, counts


def _write_matrix(outfile, miRNAs, genes, counts):
    """Write the gene x miRNA count matrix as tab-separated text.

    Every cell, including the last one on a line, is followed by a tab, so
    the layout matches what downstream consumers of this file already read.
    Pairs missing from ``counts`` are written as 0.
    """
    outfile.write('Gene/miRNA\t' + ''.join(miR + '\t' for miR in miRNAs) + '\n')
    for gene in genes:
        cells = ''.join(str(counts.get((gene, miR), 0)) + '\t' for miR in miRNAs)
        outfile.write(gene + '\t' + cells + '\n')


def run():
    """Build a gene x miRNA target-site count matrix from a predictions file.

    Command line: ``targetpredictionsfile outfilename`` followed by any of
    ``-doConservedOnly``, ``-Disregard6mers`` and ``-do8mersOnly``.  Prints a
    usage message and exits with status 1 when fewer than two arguments are
    given; exits with status 1 when the predictions file is empty.
    Rows are genes (sorted), columns are the distinct non-empty miRNA names
    from the header (sorted), cells are site counts.
    """
    if len(sys.argv) < 3:
        print('usage: python %s targetpredictionsfile outfilename '
              '[-doConservedOnly] [-Disregard6mers] [-do8mersOnly]' % sys.argv[0])
        sys.exit(1)

    listofboundmiRNAsfile = sys.argv[1]
    outfilename = sys.argv[2]

    doConservedOnly = '-doConservedOnly' in sys.argv
    doDisregard6mers = '-Disregard6mers' in sys.argv
    do8mersOnly = '-do8mersOnly' in sys.argv
    if doConservedOnly:
        print('Only conserved seed matches considered')
    if doDisregard6mers:
        print('6mers not considered')
    if do8mersOnly:
        print('Only 8mers considered')

    with open(listofboundmiRNAsfile) as infile:
        lines = infile.readlines()
    if not lines:
        print('error: %s is empty' % listofboundmiRNAsfile)
        sys.exit(1)

    column_miRNAs, counts = _count_targets(
        lines, doConservedOnly, doDisregard6mers, do8mersOnly)
    print('finished parsing file')

    # Empty header cells (the site-type / conservation columns) are not
    # miRNAs; duplicates would only repeat identical columns.
    miRNAs = sorted(set(name for name in column_miRNAs if name))
    genes = sorted(set(gene for gene, _ in counts))
    print('finished creating matrix')

    # Opened only after the input parsed successfully, so a bad input does not
    # truncate an existing output file.
    with open(outfilename, 'w') as outfile:
        _write_matrix(outfile, miRNAs, genes, counts)

    print('finished outputing')

# Run only when executed as a script, so importing this module does not
# parse sys.argv, write files, or call sys.exit.
if __name__ == '__main__':
    run()
