##################################
#                                #
# Last modified 02/23/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import time
import math
import random
from sets import Set

def reverseComplement(sequence,DNA):
    """Return the reverse complement of sequence.

    sequence is walked from its last character to its first and every
    character is replaced by its entry in DNA, a dict mapping a base to
    its complement.  A character that is missing from DNA raises KeyError.
    An empty sequence yields an empty string.
    """
    # Join once instead of growing a string one character at a time,
    # which is quadratic in the sequence length.
    return ''.join([DNA[base] for base in reversed(sequence)])

def _safeFraction(numerator, denominator):
    """Return numerator/denominator as a float, or 0.0 if denominator is zero."""
    if not denominator:
        return 0.0
    return numerator / float(denominator)


def _loadReads(lineslist, inputFormat, minLength, maxLength):
    """Read sequences from an iterable of FASTA ('-f') or FASTQ ('-q') lines.

    Returns (reads, lengthCounts, totalReads): reads holds the sequences
    whose length lies within [minLength, maxLength]; lengthCounts maps a
    read length to the number of reads of that length and totalReads is
    the number of reads seen, both taken before the length filter.
    'U' is replaced by 'T' in every sequence.
    """
    reads = []
    lengthCounts = {}
    totalReads = 0
    lineNumber = 0
    for line in lineslist:
        lineNumber += 1
        if inputFormat == '-f':
            # Every non-header line is treated as one complete read.
            if line[0] == '>':
                continue
        else:
            position = lineNumber % 4
            if position == 1 and line[0] != '@':
                print('fastq file broken, exiting')
                sys.exit(1)
            # Only the second line of each four-line record is sequence.
            if position != 2:
                continue
        read = line.strip().replace('U','T')
        readlen = len(read)
        lengthCounts[readlen] = lengthCounts.get(readlen, 0) + 1
        totalReads += 1
        if minLength <= readlen <= maxLength:
            reads.append(read)
    return reads, lengthCounts, totalReads


def _writeLengthDistribution(filename, lengthCounts, totalReads):
    """Write a read length / count / fraction table, sorted by read length."""
    with open(filename, 'w') as outfile:
        outfile.write('#Read_length\tNumber_reads\tFraction\n')
        for readlen in sorted(lengthCounts):
            count = lengthCounts[readlen]
            outline = str(readlen) + '\t' + str(count) + '\t' + str(count / float(totalReads))
            outfile.write(outline + '\n')


def run():
    """Compute ping-pong overlap statistics for the reads named on the command line.

    Arguments are taken from sys.argv:
        input -f|-q minReadLength maxReadLength minOverlap maxOverlap outprefix [-collapseDups]
    ('-' as input means standard input).  Three files are written:

        outprefix.readLenghtDistribution_initial
            read length distribution of all reads in the input
        outprefix.readLenghtDistribution_postFiltering
            the same after length filtering and optional duplicate collapsing
        outprefix.PingPongFraction_and_1U10A_vs_offset
            one row per offset k in [minOverlap, maxOverlap]

    A read is counted as a ping-pong read at offset k when its first k
    bases are not their own reverse complement, equal the reverse
    complement of the first k bases of some read, and no longer prefix
    (k+1 .. maxOverlap) of the read has such a match.  Fractions whose
    denominator is zero are written as 0.0.

    Prints a usage message and exits with status 1 when arguments are
    missing or the format flag is neither -f nor -q.
    """
    # outprefix is sys.argv[7], so eight entries are needed, not seven.
    if len(sys.argv) < 8:
        print('usage: python %s input -f|-q minReadLength maxReadLength minOverlap maxOverlap outprefix [-collapseDups]' % sys.argv[0])
        print('\tNote: Use - for standard input, e.g. if you are streaming from an arhived file')
        sys.exit(1)

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n','-':'-'}

    doCollapseDups = '-collapseDups' in sys.argv

    fast = sys.argv[1]
    inputFormat = sys.argv[2]
    minLength = int(sys.argv[3])
    maxLength = int(sys.argv[4])
    minK = int(sys.argv[5])
    maxK = int(sys.argv[6])
    outprefix = sys.argv[7]

    # Without this check an unknown flag reads nothing and yields empty output.
    if inputFormat not in ('-f', '-q'):
        print('unrecognized input format %s, expected -f or -q' % inputFormat)
        sys.exit(1)

    if fast == '-':
        ReadList, ReadLenDict, TotalReads = _loadReads(sys.stdin, inputFormat, minLength, maxLength)
    else:
        infile = open(fast)
        try:
            ReadList, ReadLenDict, TotalReads = _loadReads(infile, inputFormat, minLength, maxLength)
        finally:
            infile.close()

    _writeLengthDistribution(outprefix + '.readLenghtDistribution_initial', ReadLenDict, TotalReads)

    if doCollapseDups:
        # Read order does not matter below: only counts and set lookups are used.
        ReadList = list(set(ReadList))

    numReads = len(ReadList)
    ReadLenDict = {}
    for read in ReadList:
        ReadLenDict[len(read)] = ReadLenDict.get(len(read), 0) + 1
    _writeLengthDistribution(outprefix + '.readLenghtDistribution_postFiltering', ReadLenDict, numReads)

    outfile = open(outprefix + '.PingPongFraction_and_1U10A_vs_offset','w')
    try:
        outline = '#Offset\tTotalReads\tPalindromic_reads\tNumber_ping_pong_non_palindromic_reads\tFraction_ping_pong_non_palindromic_reads\tNumber_1U_ping_pong_non_palindromic_reads\tNumber_10A_ping_pong_non_palindromic_reads\tNumber_1U10A_ping_pong_non_palindromic_reads\tFraction_1U_ping_pong_non_palindromic_reads\tFraction_10A_ping_pong_non_palindromic_reads\tFraction_1U10A_ping_pong_non_palindromic_reads'
        outline = outline + '\tFraction_of_ping_pong_reads_1U\tFraction_of_ping_pong_reads_10A\tFraction_of_ping_pong_reads_1U10A\t'
        outfile.write(outline + '\n')

        # For every offset k, collect the reverse complements of all
        # non-palindromic k-base read prefixes.
        FirstRevKmers = {}
        for k in range(minK, maxK + 1):
            print('%s %s' % (k, numReads))
            revKmers = set()
            for read in ReadList:
                kmer = read[0:k]
                revkmer = reverseComplement(kmer, DNA)
                if kmer != revkmer:
                    revKmers.add(revkmer)
            FirstRevKmers[k] = revKmers

        for k in range(minK, maxK + 1):
            Palindromic = 0
            InPingPongPairs = 0
            PP1U = 0
            PP10A = 0
            PP1U10A = 0
            for read in ReadList:
                kmer = read[0:k]
                if kmer == reverseComplement(kmer, DNA):
                    Palindromic += 1
                    continue
                if kmer not in FirstRevKmers[k]:
                    continue
                # A read that also matches at a larger offset is counted
                # there rather than at k.
                if any(read[0:kk] in FirstRevKmers[kk] for kk in range(k + 1, maxK + 1)):
                    continue
                InPingPongPairs += 1
                # NOTE(review): read[k - 1] raises IndexError for a read
                # shorter than k; this presumably relies on
                # minReadLength >= maxOverlap -- confirm with callers.
                has1U = read[0] == 'T'
                # "10A" is checked at the last base of the overlap (index
                # k - 1), which is position 10 only when k == 10.
                has10A = read[k - 1] == 'A'
                if has1U:
                    PP1U += 1
                if has10A:
                    PP10A += 1
                if has1U and has10A:
                    PP1U10A += 1

            fields = [
                k,
                numReads,
                Palindromic,
                InPingPongPairs,
                _safeFraction(InPingPongPairs, numReads - Palindromic),
                PP1U,
                PP10A,
                PP1U10A,
                _safeFraction(PP1U, numReads),
                _safeFraction(PP10A, numReads),
                _safeFraction(PP1U10A, numReads),
                _safeFraction(PP1U, InPingPongPairs),
                _safeFraction(PP10A, InPingPongPairs),
                _safeFraction(PP1U10A, InPingPongPairs),
            ]
            outfile.write('\t'.join([str(field) for field in fields]) + '\n')
    finally:
        outfile.close()
        
# Run the pipeline only when this file is executed as a script, so that
# importing the module (e.g. to reuse reverseComplement) does not parse
# sys.argv or call sys.exit.
if __name__ == '__main__':
    run()

