##################################
#                                #
# Last modified 11/14/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string

def _int_option(argv, flag):
    """Return the integer that follows *flag* in *argv*, or None if *flag* is absent.

    Exits with an error message when the flag is the last argument and
    therefore has no value. A non-integer value raises ValueError.
    """
    if flag not in argv:
        return None
    value_index = argv.index(flag) + 1
    if value_index >= len(argv):
        sys.exit('error: option %s requires an integer value' % flag)
    return int(argv[value_index])


def _tabulate_bases(fasta_path, trimmed_path, length, min_len, max_len, t_to_u):
    """Write trimmed reads to *trimmed_path* and count bases per position.

    Every non-header line of *fasta_path* is treated as one complete read
    (multi-line records are not joined). Reads whose full, untrimmed length
    is below *min_len* or above *max_len* are skipped; either limit may be
    None to disable it. Kept reads are cut to their first *length*
    characters, optionally converted T->U, and written together with the
    most recent header line.

    Returns a list with one dict per position, mapping the upper-cased base
    to the number of kept reads carrying it at that position.
    """
    counts = []
    header = None
    with open(fasta_path) as reads:
        with open(trimmed_path, 'w') as trimmed:
            for line in reads:
                if line.startswith('>'):
                    header = line
                    continue
                read = line.strip()
                if not read:
                    # Blank lines would otherwise become empty records.
                    continue
                # Length filters apply to the whole read, not the trimmed part.
                if min_len is not None and len(read) < min_len:
                    continue
                if max_len is not None and len(read) > max_len:
                    continue
                if header is None:
                    sys.exit('error: %s has sequence data before the first ">" header'
                             % fasta_path)
                sequence = read[:length]
                if t_to_u:
                    sequence = sequence.replace('T', 'U').replace('t', 'u')
                trimmed.write(header)
                trimmed.write(sequence + '\n')
                for position, base in enumerate(sequence):
                    if position == len(counts):
                        counts.append({})
                    base = base.upper()
                    column = counts[position]
                    column[base] = column.get(base, 0) + 1
    return counts


def run():
    """Build a base-frequency table and a sequence logo from a FASTA file.

    Reads its arguments from sys.argv:
        fasta first_N_bp weblogo_location outfile_prefix
        [-TtoU] [-minReadLen bp] [-maxReadLen bp] [-w #] [-h #]

    Writes <prefix>.PWM, a tab-separated table with one row per observed
    base and one column per position holding the fraction of reads with
    that base at that position. It then runs the program given as
    weblogo_location on a temporary <prefix>.temp file of trimmed reads,
    asking for output at <prefix>.png. The temporary file is removed
    afterwards, also when an error occurs.

    Exits with status 1 and a usage message if fewer than four positional
    arguments are supplied.
    """
    # Four positional arguments are required, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s fasta first_N_bp weblogo_location oufile_prefix [-TtoU] [-minReadLen bp] [-maxReadLen bp] [-w #] [-h #]' % sys.argv[0])
        sys.exit(1)

    fasta_path = sys.argv[1]
    length = int(sys.argv[2])
    logo_path = sys.argv[3]
    prefix = sys.argv[4]

    t_to_u = '-TtoU' in sys.argv
    min_len = _int_option(sys.argv, '-minReadLen')
    max_len = _int_option(sys.argv, '-maxReadLen')
    width = _int_option(sys.argv, '-w')
    height = _int_option(sys.argv, '-h')

    trimmed_path = prefix + '.temp'
    try:
        counts = _tabulate_bases(fasta_path, trimmed_path, length,
                                 min_len, max_len, t_to_u)

        bases = sorted(set(base for column in counts for base in column))
        # Each position's denominator is the number of reads that reach it.
        totals = [float(sum(column.values())) for column in counts]

        with open(prefix + '.PWM', 'w') as pwm:
            for base in bases:
                row = [base]
                for column, total in zip(counts, totals):
                    row.append(str(column.get(base, 0) / total))
                pwm.write('\t'.join(row) + '\n')

        # The command is run through the shell, so weblogo_location may
        # itself contain several words (e.g. an interpreter plus a script).
        parts = [logo_path, '-f', trimmed_path, '-F PNG -c -a -Y -n -M -k 1']
        if width is not None:
            parts.append('-w ' + str(width))
        if height is not None:
            parts.append('-h ' + str(height))
        parts.append('-o ' + prefix + '.png')
        command = ' '.join(parts)

        status = os.system(command)
        if status != 0:
            sys.stderr.write('warning: logo command returned status %s: %s\n'
                             % (status, command))
    finally:
        if os.path.exists(trimmed_path):
            os.remove(trimmed_path)
   
# Script entry point: this module-level call runs the tool as soon as the
# file is executed (it is not guarded, so importing the file runs it too).
run()
