##################################
#                                #
# Last modified 07/30/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import re
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Each base is complemented (A<->T, G<->C, N->N, case preserved) and the
    result is read from the last base to the first.

    Raises KeyError if the sequence contains a character other than
    A/C/G/T/N (upper or lower case).
    """
    DNA = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
           'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n'}
    # One join instead of repeated concatenation, which copies the growing
    # partial result on every base.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

def _parse_fasta(input_files):
    """Read FASTA files into a dict mapping record name -> sequence.

    The record name is the header line after '>' up to the first newline;
    the sequence is the remaining lines concatenated with line breaks
    (\\n and \\r) removed. A name seen again, in the same or a later file,
    overwrites the earlier record. Each name is printed as it is read.
    """
    genome = {}
    for inputfilename in input_files:
        with open(inputfilename) as handle:
            contents = handle.read()
        for record in contents.split('>'):
            if not record:
                continue
            # partition() copes with a header that has no trailing newline,
            # where find('\n') would return -1 and chop the last character.
            name, _, body = record.partition('\n')
            name = name.rstrip('\r')
            print(name)
            genome[name] = body.replace('\r', '').replace('\n', '')
    return genome


def _find_all(haystack, needle):
    """Return the start offsets of every occurrence of needle in haystack,
    overlapping ones included (e.g. 'AA' occurs at 0 and 1 in 'AAA')."""
    positions = []
    start = haystack.find(needle)
    while start != -1:
        positions.append(start)
        # Advance by one, not len(needle), so overlapping hits are kept.
        start = haystack.find(needle, start + 1)
    return positions


def run():
    """Map every k-mer of a query sequence onto one or more FASTA genomes.

    Command line: fasta1,(fasta2,...,fastaN) sequence k outfile_prefix

    For each k from the given minimum up to len(sequence), writes the file
    <outfile_prefix>.k<k> with one tab-separated line per hit:
    k-mer, chromosome, start, end (0-based offsets into the chromosome
    sequence, end exclusive). Hits of the k-mer and of its reverse
    complement are both reported, overlapping hits included; a position hit
    on both strands is written once. Lines for one (k-mer, chromosome) pair
    are ordered numerically by start.
    """
    # Five entries are needed: the script name plus four arguments.
    if len(sys.argv) < 5:
        print('usage: python %s fasta1,(fasta2,...,fastaN) sequence k outfile_prefix' % sys.argv[0])
        sys.exit(1)

    input_files = sys.argv[1].split(',')
    sequence = sys.argv[2]
    minK = int(sys.argv[3])
    outfile_prefix = sys.argv[4]

    genome = _parse_fasta(input_files)
    chrList = sorted(genome)

    for k in range(minK, len(sequence) + 1):
        print(k)
        with open(outfile_prefix + '.k' + str(k), 'w') as outfile:
            for i in range(len(sequence) - k + 1):
                kmer = sequence[i:i + k]
                print(kmer)
                revkmer = getReverseComplement(kmer)
                for chrom in chrList:
                    chrom_seq = genome[chrom]
                    # NOTE(review): matching is case-sensitive, so lowercase
                    # (e.g. soft-masked) genome regions will not match an
                    # uppercase query -- confirm this is intended.
                    # A set merges a hit found on both strands (palindromes).
                    starts = set(_find_all(chrom_seq, kmer))
                    starts.update(_find_all(chrom_seq, revkmer))
                    for start in sorted(starts):
                        outfile.write('%s\t%s\t%d\t%d\n'
                                      % (kmer, chrom, start, start + k))
# Run the command-line entry point only when executed as a script, so the
# module can be imported (e.g. for getReverseComplement) without side effects.
if __name__ == '__main__':
    run()

