# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

'''Given a read alignment in SAM format with MD tags, count the frequency of each quality score by base position for
correct and erroneous bases and Ns

Created on May 18, 2010
@author: ian
'''

import sys
import os
import re

this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
from numpy import *
from simulateRNA_Seq import MAX_ILLUMINA_QUALITY
from lib.samRead import SAMReadFileIterator, SAMHeader
#import traceback

def usage(msg):
    """Print an optional error message and the usage line to stderr, then exit(1).

    Parameters:
        - msg  Anything renderable with str() (typically a string or an Exception);
               falsy values (None, '') suppress the message line.
    """
    if msg:
        # write() instead of the py2-only `print >>` chevron; same output on py2 and py3
        sys.stderr.write('%s\n' % (msg,))
    sys.stderr.write('Usage: python %s input_withMD.sam  maximum_read_length  out_stream  maximum_errors_per_read '
                     'sample_size\n' % sys.argv[0])
    sys.exit(1)


def decode_MD(tags):
    """Decode the MD tag from a list of SAM optional fields into a per-base error list.

    Returns a list with one entry per aligned read base: 0 for a base matching the
    reference, 1 for a substitution. Reference bases deleted from the read (the
    '^'-prefixed runs in the MD string) consume no read positions and are skipped.

    Raises ValueError if no MD tag is present.

    NOTE: the previous `re.split('\\D', md)` implementation produced empty fields for
    deletion runs (e.g. '10^AC5'), making int('') raise and silently dropping the read.
    """
    for tag in tags:
        if tag.startswith('MD:'):
            md = tag.split(':')[-1]
            result = []
            # MD grammar: number (match run) | ^LETTERS (deletion) | letter (mismatch)
            for matches, deletion, mismatch in re.findall(r'(\d+)|\^([A-Za-z]+)|([A-Za-z])', md):
                if matches:
                    result += [0] * int(matches)
                elif mismatch:
                    result.append(1)
                # deletions: no read base to mark, nothing appended
            return result
    raise ValueError('No MD tag found in ' + str(tags))


def check_BBB(sam_read):
    """Mark the trailing run of quality-2 ('#') scores at the 3' end of a read.

    SAM stores SEQ/QUAL in reference orientation, so for a reverse-strand alignment
    the read's 3' tail sits at the START of the stored quality string; for a forward
    alignment it sits at the END. Either way we scan the 3'-end-first string and
    count the leading '#' run.

    Returns a 0/1 list as long as the quality string, indexed in the SAME orientation
    the caller uses after its own conditional reversal (1 marks a trailing-B position).
    Returns [] for an empty quality string (the old version raised IndexError).
    """
    tail_first = sam_read.qual if sam_read.is_reversed() else sam_read.qual[::-1]
    run = 0
    for q in tail_first:
        if q != '#':
            break
        run += 1
    return [0] * (len(tail_first) - run) + [1] * run


def count_correct_and_wrong_qualities_by_position(sam_file, max_read_len, max_err, sample_size):
    """ Process a SAM file and return matrices of quality score frequencies at each read position for correct and
    erroneous reads.

    Parameters:
        - sam_file     Iterable of SAM reads (with MD tags) open for reading
        - max_read_len Length of the longest read, int
        - max_err      Ignore reads with more errors than this, int
        - sample_size  Stop after this many successfully processed reads, int

    Returns a list of six [max_read_len x (MAX_ILLUMINA_QUALITY+1)] count matrices:
    [right, wrong, Ns, BBB_right, BBB_wrong, BBB_Ns], where the BBB_* matrices cover
    positions inside a trailing quality-2 ('#') run.

    Reads that raise (e.g. malformed tags, qual longer than decoded MD) are reported
    to stderr and skipped without counting toward sample_size.
    """
    shape = [max_read_len, MAX_ILLUMINA_QUALITY + 1]
    right, wrong, Ns = zeros(shape), zeros(shape), zeros(shape)
    BBB_right, BBB_wrong, BBB_Ns = zeros(shape), zeros(shape), zeros(shape)
    # indexed as counts[in_BBB_run][error_class]: class 0 correct, 1 substitution, 2 'N'
    counts = [[right, wrong, Ns], [BBB_right, BBB_wrong, BBB_Ns]]
    read_count = 0
    for sam in sam_file:
        if read_count >= sample_size:  # was `>`: processed sample_size + 1 reads
            break
        try:
            BBB = check_BBB(sam)
            subs = decode_MD(sam.tags)
            if sum(subs) > max_err:
                continue
            # 2 for an 'N' base call, otherwise the MD match/mismatch flag (0/1)
            errs = [2 if c == 'N' else s for s, c in zip(subs, sam.seq)]
            qual = sam.qual
            if sam.is_reversed():
                # restore original read orientation so position i is read position i
                errs.reverse()
                qual = sam.qual[::-1]
            for i, qual_code in enumerate(qual):
                score = ord(qual_code) - 33  # Phred+33 decoding
                if BBB[i] and score > 2:
                    sys.stderr.write('BBB-qual inconsistency!\n')
                counts[BBB[i]][errs[i]][i][score] += 1
            read_count += 1
        except Exception as e:
            # best-effort: report and skip malformed reads rather than aborting the run
            sys.stderr.write('%s\n' % e)

    return counts[0] + counts[1]


if __name__ == '__main__':
    # --- command-line parsing: any failure prints usage and exits(1) ---
    try:
        input_file = SAMReadFileIterator(sys.argv[1])
        sam_iter = input_file.iterator()
        max_read_len = int(sys.argv[2])
        if sys.argv[3] == '-':
            out_stream = sys.stdout
            # derived probability files go next to the input when writing to stdout
            save_dir = os.path.dirname(os.path.abspath(sys.argv[1]))
        else:
            out_stream = open(sys.argv[3], 'w')
            save_dir = os.path.dirname(os.path.abspath(sys.argv[3]))
        max_err = max_read_len  # default: effectively unlimited errors per read
        if len(sys.argv) > 4:
            max_err = int(sys.argv[4])
        sample_size = sys.maxint  # default: use every read
        if len(sys.argv) > 5:
            sample_size = int(sys.argv[5])
    except Exception as e:
        usage(e)

    def write_matrix(stream, title, matrix):
        """Write a title line then one tab-separated row per read position."""
        stream.write('%s\n' % title)
        for row in matrix:
            stream.write('\t'.join([str(x) for x in row]) + '\n')

    def save_vector(filename, values):
        """Write one value per line to save_dir/filename."""
        f = open(os.path.join(save_dir, filename), 'w')
        for v in values:
            f.write('%s\n' % v)
        f.close()

    def save_cumulative_freq(filename, count_matrix):
        """Write per-position cumulative quality-score frequencies to save_dir/filename."""
        totals = sum(count_matrix, axis=1)
        totals.shape = (-1, 1)  # column vector so the division broadcasts per row
        cum_freq = true_divide(count_matrix, totals).cumsum(axis=1)
        f = open(os.path.join(save_dir, filename), 'w')
        for row in cum_freq:
            f.write('\t'.join([str(x) for x in row]) + '\n')
        f.close()

    right, wrong, Ns, BBB_right, BBB_wrong, BBB_Ns = count_correct_and_wrong_qualities_by_position(
        sam_iter, max_read_len, max_err, sample_size)
    rows = len(right)

    # --- raw count matrices, one section per error class ---
    write_matrix(out_stream, 'Correct reads:', right)
    write_matrix(out_stream, '\nIncorrect reads:', wrong)
    write_matrix(out_stream, '\nN reads:', Ns)
    write_matrix(out_stream, '\nBBB_Correct reads:', BBB_right)
    write_matrix(out_stream, '\nBBB_Incorrect reads:', BBB_wrong)
    write_matrix(out_stream, '\nBBB_N reads:', BBB_Ns)
    if out_stream is not sys.stdout:
        # bug fix: closing sys.stdout (the '-' case) broke the final status message
        out_stream.close()

    # --- per-position error probabilities (nan where a position has no reads) ---
    normal_total = (sum(right, axis=1) + sum(wrong, axis=1) + sum(Ns, axis=1))
    err_rate = true_divide((sum(wrong, axis=1) + sum(Ns, axis=1)), normal_total)
    BBB_total = (sum(BBB_right, axis=1) + sum(BBB_wrong, axis=1) + sum(BBB_Ns, axis=1))
    BBB_err_rate = true_divide((sum(BBB_wrong, axis=1) + sum(BBB_Ns, axis=1)), BBB_total)
    BBB_rate = true_divide(BBB_total, (normal_total + BBB_total))
    # probability that a BBB run starts at position r given it had not started earlier
    BBB_init_rate = [BBB_rate[0]]
    for r in range(1, rows):
        BBB_init_rate.append((BBB_rate[r] - BBB_rate[r - 1]) / (1. - BBB_rate[r - 1]))

    save_vector('error_probabilities.txt', err_rate)
    save_vector('BBB.error_probabilities.txt', BBB_err_rate)
    save_vector('BBB_rate.txt', BBB_rate)
    save_vector('BBB.init_probabilities.txt', BBB_init_rate)

    # --- cumulative quality-code frequency distributions ---
    save_cumulative_freq('Correct_reads.quality_scores.cumulative_frequency.txt', right)
    save_cumulative_freq('Incorrect_reads.quality_scores.cumulative_frequency.txt', wrong)

    sys.stdout.write('%s done.\n' % sys.argv[0])

