#
#  trimquery.py
#  ENRAGE
#
#  Created by Ali Mortazavi on 8/12/08.
#

import sys
import optparse
from cistematic.core import complement
from commoncode import getConfigParser, getConfigBoolOption, getConfigOption

# Version banner. Because this is a module-level statement it is emitted
# whenever the module is imported, not only when it is run as a script.
print("trimreads: version 2.2")

def main(argv=None):
    """Command-line entry point: parse arguments and run trimreads().

    argv is a full argument vector with the program name at index 0; when it
    is None or empty, sys.argv is used instead. Expects three positional
    arguments (length, infile, outfile) plus the options built by
    getParser(). If fewer than three positionals are given, prints the usage
    text and exits with status 1.
    """
    argv = argv or sys.argv

    usage = "usage: python %prog length infile outfile [--fastq] [--fromback] [--paired] [--flip] [--filter maxN]"
    options, positional = getParser(usage).parse_args(argv[1:])

    if len(positional) < 3:
        for message in (usage, "\t where paired fragments are separated by a : when given the -paired flag"):
            print(message)
        sys.exit(1)

    lengthText, inPath, outPath = positional[:3]
    trimreads(int(lengthText), inPath, outPath,
              fastq=options.fastq,
              fromBack=options.fromBack,
              paired=options.paired,
              flipseq=options.flipseq,
              maxN=options.maxN)


def getParser(usage):
    """Build the optparse parser for the trimreads command line.

    Option defaults are read from the "trimreads" section of the configuration
    returned by getConfigParser(). Command-line flags are applied on top of
    those defaults by parse_args(). The boolean flags use store_true, so they
    can switch an option on but cannot switch off a config default of True.

    Args:
        usage: usage string handed to optparse.

    Returns:
        An optparse.OptionParser whose parsed options carry the attributes
        fastq, fromBack, paired, flipseq (bools) and maxN (int or None).
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--fastq", action="store_true", dest="fastq")
    parser.add_option("--fromback", action="store_true", dest="fromBack")
    parser.add_option("--paired", action="store_true", dest="paired")
    parser.add_option("--flip", action="store_true", dest="flipseq")
    parser.add_option("--filter", type="int", dest="maxN")

    configParser = getConfigParser()
    section = "trimreads"
    fastq = getConfigBoolOption(configParser, section, "fastq", False)
    fromBack = getConfigBoolOption(configParser, section, "fromBack", False)
    paired = getConfigBoolOption(configParser, section, "paired", False)
    flipseq = getConfigBoolOption(configParser, section, "flipseq", False)
    maxN = getConfigOption(configParser, section, "maxN", None)
    # optparse only converts --filter values given on the command line
    # (type="int"); a default taken from the config file is passed through
    # as-is (presumably a string). Coerce it so downstream "%d" formatting and
    # numeric comparisons behave the same as for a command-line value.
    if maxN is not None:
        maxN = int(maxN)

    parser.set_defaults(fastq=fastq, fromBack=fromBack, paired=paired, flipseq=flipseq, maxN=maxN)

    return parser


def _complementOrSame(seq):
    """Return complement(seq), or seq unchanged if complement() raises."""
    try:
        return complement(seq)
    except Exception:
        # Best effort, as before: leave text complement() cannot handle as-is.
        return seq


def _trimLine(line, cut, fromBack, paired, flipseq):
    """Return the output text (without header or newline) for one data line.

    cut is the signed split point: +length when trimming from the front,
    -length when keeping the end of the line (fromBack).
    """
    back = line[cut:]
    front = line[:cut]
    if paired:
        if flipseq:
            back = _complementOrSame(back)
            front = _complementOrSame(front)
        return "%s:%s" % (back, front)

    # Unpaired: only one half of the split is written.
    kept = back if fromBack else front
    if flipseq:
        kept = _complementOrSame(kept)
    return kept


def trimreads(length, inFileName, outFileName, fastq=False, fromBack=False, paired=False, flipseq=False, maxN=None):
    """Trim every read in a FASTA or FASTQ file to a fixed length.

    Args:
        length: number of bases to keep (or, with paired, the split point).
        inFileName: input FASTA (default) or FASTQ (fastq=True) file.
        outFileName: output file; overwritten.
        fastq: treat the input as FASTQ. Records are assumed to be four
            non-blank lines (header, sequence, "+" line, quality). The
            quality line gets the same trim/flip transform as its sequence
            and is written only when the sequence was kept.
        fromBack: keep the last `length` characters instead of the first.
        paired: write "<line[length:]>:<line[:length]>" (split at -length
            with fromBack). Lines shorter than 2 * length are skipped.
        flipseq: pass each written piece through cistematic's complement(),
            keeping the piece unchanged if complement() raises.
            NOTE(review): in FASTQ mode the quality line is transformed too,
            as in the original code; confirm this is intended.
        maxN: if not None, drop sequences containing more than maxN "N"
            characters. Strings such as config values are converted with int().

    Header lines (">" for FASTA; record lines 1 and 3 for FASTQ) are written
    in front of each kept data line. Blank lines are ignored. Prints progress
    and a summary to stdout.
    """
    cut = -length if fromBack else length
    pairedlength = 2 * length

    filtering = maxN is not None
    if filtering:
        maxN = int(maxN)
        print("filtering out reads with more than %d Ns" % maxN)

    print("trimming reads from %s to %d bp and saving them in %s" % (inFileName, length, outFileName))

    index = 0
    filtered = 0
    header = ""
    recordLine = 0      # FASTQ: position within the 4-line record (0..3)
    keepRecord = True   # FASTQ: whether the current record's sequence was written
    with open(inFileName) as infile:
        with open(outFileName, "w") as outfile:
            for line in infile:
                line = line.strip()
                if not line:
                    continue

                if fastq:
                    # Classify by position, not by first character: "@" and
                    # "+" are legal quality-score characters.
                    position = recordLine
                    recordLine = (recordLine + 1) % 4
                    if position in (0, 2):
                        header = line + "\n"
                        continue
                    if position == 3:
                        if keepRecord:
                            outfile.write(header + _trimLine(line, cut, fromBack, paired, flipseq) + "\n")
                        continue
                elif line.startswith(">"):
                    header = line + "\n"
                    continue

                # Sequence line: decide whether this read is kept.
                keepRecord = False
                if filtering and line.count("N") > maxN:
                    filtered += 1
                    continue
                if paired and len(line) < pairedlength:
                    continue
                keepRecord = True

                outfile.write(header + _trimLine(line, cut, fromBack, paired, flipseq) + "\n")
                index += 1
                if index % 1000000 == 0:
                    sys.stdout.write(". ")
                    sys.stdout.flush()

    print("returned %d reads" % index)
    if filtering:
        print("%d additional reads filtered" % filtered)


if __name__ == "__main__":
    # main() substitutes sys.argv when called without an argument vector.
    main()