##################################
#                                #
# Last modified 2017/10/09       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

# Complement of every IUPAC nucleotide code, in both cases. Built once at
# import time instead of on every call, because the scanner calls
# getReverseComplement once per genome position.
_COMPLEMENT = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G',
               'a': 't', 't': 'a', 'g': 'c', 'c': 'g',
               'S': 'W', 'W': 'S', 's': 'w', 'w': 's',
               'M': 'K', 'K': 'M', 'm': 'k', 'k': 'm',
               'N': 'N', 'n': 'n',
               'H': 'D', 'D': 'H', 'h': 'd', 'd': 'h',
               'V': 'B', 'B': 'V', 'v': 'b', 'b': 'v',
               'R': 'Y', 'Y': 'R', 'r': 'y', 'y': 'r'}


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide string.

    Case is preserved, and IUPAC ambiguity codes (S, W, M, K, N, H, D, V,
    B, R, Y) are complemented as well.

    Raises:
        KeyError: if the sequence contains a character with no entry in the
            complement table (e.g. '-' or '.').
    """
    # join over a reversed iterator is linear; repeated += was quadratic.
    return ''.join(_COMPLEMENT[base] for base in reversed(preliminarysequence))

def PWMmatch(PWM, sequence, match):
    """Return True if `sequence` scores at least `match` against `PWM`.

    Args:
        PWM: dict mapping row index (0-based position in the motif) to a
            dict of uppercase nucleotide -> weight, plus a 'sum' key holding
            the maximum attainable score (sum of per-row maxima).
        sequence: string with the same length as the motif; case-insensitive.
        match: threshold on score / PWM['sum'], i.e. a fraction in [0, 1].

    Characters other than A/C/G/T (e.g. N) contribute nothing to the score.
    """
    PWMscore = 0.0
    for i, base in enumerate(sequence):
        # The PWM only has uppercase keys, so normalize case before lookup;
        # otherwise lowercase input would raise KeyError.
        base = base.upper()
        if base in ('A', 'C', 'G', 'T'):
            PWMscore += PWM[i][base]
    # PWMscore is a float, so this is true division even under Python 2.
    return PWMscore / PWM['sum'] >= match

def _readFasta(fasta):
    """Read a FASTA file into a dict of record name -> uppercased sequence.

    The record name is the header line without its leading '>' and without
    surrounding whitespace. Lines before the first header are ignored.
    """
    sequences = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    sequences[name] = ''.join(chunks).upper()
                name = line.strip()[1:]
                chunks = []
            else:
                chunks.append(line.strip())
    if name is not None:
        sequences[name] = ''.join(chunks).upper()
    return sequences


def _readPWM(PWMfile, doMEME):
    """Parse a PWM file and return (PWM, PWMmaxSum).

    PWM maps row index -> {uppercase nucleotide: weight}. PWMmaxSum is the
    sum of each row's largest weight (the best attainable score).
    """
    PWM = {}
    PWMmaxSum = 0.0
    i = 0
    with open(PWMfile) as handle:
        if doMEME:
            # Rows follow a 'letter-probability matrix' line, in A C G T order.
            # A motif ends at a '-----' separator (meme.txt) or a blank line
            # (minimal MEME format).
            # NOTE(review): rows of every motif in the file are appended into
            # one PWM, as the original script did; confirm this is intended.
            InMotif = False
            for line in handle:
                if line.startswith('letter-probability matrix'):
                    InMotif = True
                    continue
                if not InMotif:
                    continue
                fields = line.split()  # any whitespace run separates columns
                if line.startswith('-----') or not fields:
                    InMotif = False
                    continue
                weights = [float(fields[k]) for k in range(4)]
                PWM[i] = dict(zip('ACGT', weights))
                PWMmaxSum += max(weights)
                i += 1
        else:
            # Tab-separated rows: '<letters>\tw1\tw2\tw3\tw4', where the
            # letters give the column order of the weights.
            headerPrefixes = ('tagid', 'info', 'threshold', 'motif', 'sequence')
            for line in handle:
                if line.strip() == '' or line.startswith(headerPrefixes):
                    continue
                fields = line.strip().split('\t')
                nucleotides = fields[0].upper()
                weights = [float(fields[k]) for k in (1, 2, 3, 4)]
                PWM[i] = {nucleotides[k]: weights[k] for k in range(4)}
                PWMmaxSum += max(weights)
                i += 1
    return PWM, PWMmaxSum


def run():
    """Command-line entry point: scan FASTA sequences with a PWM.

    argv: fasta PWMfile match outputfilename [-fromMEME]

    For every window of motif length on every record, writes a line
    'name<TAB>start<TAB>end<TAB>strand' to outputfilename when the window
    ('+') or its reverse complement ('-') scores at least `match` as a
    fraction of the PWM's maximum attainable score.
    """
    # Four positional arguments are required (argv[1]..argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s fasta PWM match outputfilename [-fromMEME]' % sys.argv[0])
        print('\tAssumed PWM format: acgt\t0.105263\t0.491228\t0.350877\t0.052632\n\t(the script will match the order of the letters with the weights)')
        print('\tERANGE .mot files are also acceptable')
        print('\tUse the [-fromMEME] option to take directly MEME-produced motif files; the same order of nucleotides is assumed')
        print('\tthe match parameter is the minimum fraction (0-1) of the maximum attainable PWM score')
        sys.exit(1)

    fasta = sys.argv[1]
    PWMfile = sys.argv[2]
    match = float(sys.argv[3])
    outfilename = sys.argv[4]
    doMEME = '-fromMEME' in sys.argv

    SequenceDict = _readFasta(fasta)
    PWM, PWMmaxSum = _readPWM(PWMfile, doMEME)

    motifLength = len(PWM)  # counted before the 'sum' key is added
    if motifLength == 0:
        sys.stderr.write('no PWM rows could be parsed from %s\n' % PWMfile)
        sys.exit(1)
    PWM['sum'] = PWMmaxSum

    print(PWM)
    print('%s %s' % (PWMmaxSum, motifLength))
    print(len(SequenceDict))

    with open(outfilename, 'w') as outfile:
        for chrom in SequenceDict:
            chromSeq = SequenceDict[chrom]
            # +1 so the window that ends on the last base is scanned too.
            for pos in range(len(chromSeq) - motifLength + 1):
                if pos % 1000000 == 0:
                    print('%s %sM' % (chrom, pos // 1000000))
                window = chromSeq[pos:pos + motifLength]
                end = pos + motifLength
                if PWMmatch(PWM, window, match):
                    outfile.write('%s\t%d\t%d\t+\n' % (chrom, pos, end))
                if PWMmatch(PWM, getReverseComplement(window), match):
                    outfile.write('%s\t%d\t%d\t-\n' % (chrom, pos, end))
   
# Only run the scan when executed as a script, so importing this module does
# not parse argv and call sys.exit.
if __name__ == '__main__':
    run()
