import sys
import string
import math

# Psyco is an optional accelerator: the script runs the same without it, so a
# failure to import or enable it is deliberately ignored.  Catching Exception
# (rather than a bare except) keeps that best-effort behavior while still
# letting KeyboardInterrupt and SystemExit propagate.
try:
    import psyco
    psyco.full()
except Exception:
    pass

def _load_known_mirnas(path, min_match):
    """Parse the known-miRNA FASTA file into the table used by run().

    The file is read as alternating header / sequence lines: the line that
    directly follows a '>' header is taken as that record's sequence.

    Returns a dict keyed by miRNA name (header text after '>' up to the
    first space).  Each value holds 'sequence', 'revsequence' (reverse
    complement), 'counts', 'revcounts' and 'family', where 'family' lists the
    names ('list') and sequence lengths ('lengths') of every record whose
    first min_match characters equal this miRNA's first min_match characters.

    Raises KeyError if a sequence contains a character other than A/T/G/C/N.
    """
    complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N'}

    with open(path) as handle:
        lines = handle.readlines()

    table = {}
    for idx, line in enumerate(lines):
        if not line.startswith('>'):
            continue
        name = line.split('>')[1].split(' ')[0].split('\n')[0]
        if idx + 1 >= len(lines):
            # A header on the very last line has no sequence to pair with.
            sys.stderr.write('warning: header %r has no sequence line; skipped\n' % name)
            continue
        # Pair by position, not by searching for the header text, so that
        # repeated identical headers each get their own sequence line.
        sequence = lines[idx + 1].split('\n')[0]
        table[name] = {
            'sequence': sequence,
            'revsequence': ''.join(complement[base] for base in reversed(sequence)),
            'counts': 0.0,
            'revcounts': 0.0,
            'family': {'list': [], 'lengths': []},
        }

    current = None
    for line in lines:
        if line.startswith('>'):
            current = line.split('>')[1].split(' ')[0].split('\n')[0]
            continue
        if current is None:
            # Sequence text before the first header belongs to no record.
            continue
        record_length = len(line.split('\n')[0])
        for name in table:
            sequence = table[name]['sequence']
            span = min(min_match, len(sequence))
            # The raw line (newline included) is sliced on purpose: a record
            # shorter than `span` then cannot compare equal.
            if line[0:span] == sequence[0:span]:
                family = table[name]['family']
                family['list'].append(current)
                family['lengths'].append(record_length)
    return table


def _probe_matches(raw_line, length, sequence, anywhere):
    """Test the first `length` characters of raw_line against sequence.

    With anywhere=True the probe may occur at any position in sequence;
    otherwise it must equal the first `length` characters of sequence.
    """
    probe = raw_line[0:length]
    if anywhere:
        return probe in sequence
    return probe == sequence[0:length]


def _assign_read(raw_line, table, min_match, anywhere):
    """Add one read's contribution to the 'counts' fields of table.

    raw_line is the read line exactly as read from the file (trailing newline
    included); it is sliced as-is so that a read shorter than the probe
    length picks up the newline and fails to match.
    """
    for name in table:
        entry = table[name]
        family = entry['family']['list']

        if len(family) == 1:
            # No other miRNA shares this prefix: first hit takes the whole
            # read and no further miRNAs are examined.
            span = min(min_match, len(entry['sequence']))
            if _probe_matches(raw_line, span, entry['sequence'], anywhere):
                entry['counts'] += 1
                break
            continue

        # Several miRNAs share the prefix: lengthen the probe until the read
        # singles out one member.  Every member of the family is visited by
        # the outer loop, hence the division by the family size below.
        longest = max(entry['family']['lengths'])
        for length in range(min_match, longest + 1):
            matches = [member for member in family
                       if _probe_matches(raw_line, length, table[member]['sequence'], anywhere)]
            if len(matches) == 1:
                table[matches[0]]['counts'] += 1.0 / len(family)
                break
            if len(matches) > 1 and length == longest:
                for member in matches:
                    table[member]['counts'] += 1.0 / (len(matches) * len(family))
                break
            if not matches:
                # Nothing matches at this length: fall back to the members
                # that still matched one character earlier and split the read.
                # NOTE(review): unlike the two branches above, this weight is
                # not divided by the family size although the outer loop
                # reaches it once per family member - confirm this is intended.
                fallback = [member for member in family
                            if _probe_matches(raw_line, length - 1, table[member]['sequence'], anywhere)]
                for member in fallback:
                    table[member]['counts'] += 1.0 / len(fallback)
                break


def run():
    """Count small-RNA reads per known miRNA and write a TSV report.

    Command line (read from sys.argv):
        smallRNA_fasta  known_miRNA_fasta  minimal_matches  outfile [-anymatch]

    Both FASTA files are read as alternating header / sequence lines.  A read
    is matched on its first minimal_matches characters; with -anymatch the
    probe may occur anywhere inside a miRNA sequence instead of only at its
    start.  The output file gets one row per miRNA with its sequence, length,
    count and TPM (count divided by number-of-reads / 1e6, where the number
    of reads is half the number of lines in the reads file).  Each row is
    also printed to stdout, as is a countdown every 10000 lines.

    Exits with status 1 after printing the usage text when fewer than the
    four required arguments are given.
    """
    # Four positional arguments are read below (sys.argv[1]..sys.argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s smallRNA_fasta_filename    known_miRNA_fasta_filename      minimal_matches   outfilename [-anymatch]' % sys.argv[0])
        sys.exit(1)

    inputfilename = sys.argv[1]
    miRNAfilename = sys.argv[2]
    MinMatch = int(sys.argv[3])
    outputfilename = sys.argv[4]
    doAny = '-anymatch' in sys.argv
    if doAny:
        print('doAny=True')

    miRDict = _load_known_mirnas(miRNAfilename, MinMatch)

    with open(outputfilename, 'w') as outfile:
        outfile.write('miRNA' + '\t' + 'sequence' + '\t' + 'length' + '\t' + 'Counts' + '\t' + 'TPM' + '\n')

        with open(inputfilename) as handle:
            lineslist = handle.readlines()
        total_lines = len(lineslist)
        ReadNumber = float(total_lines) / 2

        for idx, line in enumerate(lineslist):
            if idx % 10000 == 0:
                print(total_lines - idx)
            if line.startswith('>'):
                continue
            _assign_read(line, miRDict, MinMatch, doAny)

        per_million = ReadNumber / 1000000
        for miR in miRDict:
            sequence = miRDict[miR]['sequence']
            counts = miRDict[miR]['counts']
            # An empty reads file has no reads to normalise by; report 0.0
            # instead of dividing by zero.
            tpm = counts / per_million if per_million else 0.0
            outfile.write(miR + '\t' + sequence + '\t' + str(len(sequence)) + '\t' + str(counts) + '\t' + str(tpm) + '\n')
            line = miR + '\t' + sequence + '\t' + str(counts) + '\t' + str(tpm) + '\n'
            print(line)

# Run the command-line tool only when this file is executed as a script, so
# that importing the module does not parse sys.argv or call sys.exit().
if __name__ == '__main__':
    run()

