# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

__author__ = 'ian'

import sys
import os
import re

# Put the parent of this script's directory on the import path so that the
# `lib.samText` import below can be resolved when the script is run directly.
# (Presumably the `lib` package lives in that parent directory -- confirm.)
this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
import argparse
from lib.samText import SAMTextFileIterator, get_cigar_length, get_aligned_length

# Text passed to argparse as the program description (see get_args).
DESCRIPTION = "Adjust a SAM file of true mappings of simulated reads for 3' trimming by trimmomatic"
# Version string reported by the --version option.
VERSION = '0.1'


def get_args():
    """Parse the command line and return the resulting argparse namespace.

    The namespace carries:
      * ``verbose`` -- number of times -v was given (0 when omitted)
      * ``input``   -- path given with --in / -i (required)
      * ``out``     -- path given with --out / -o (required)
      * ``length``  -- path given with --length / -l (required)

    argparse itself prints a message and exits when a required option is
    missing or when --version / --help is requested.
    """
    argparser = argparse.ArgumentParser(description=DESCRIPTION)
    # standard options
    # The trailing space keeps the program name and the version number apart
    # ("prog 0.1" rather than "prog0.1").
    argparser.add_argument('--version', action='version', version='%(prog)s ' + VERSION)
    argparser.add_argument('--verbose', '-v', action='count', default=0,
                           help='Omit to see only fatal error messages; -v to see warnings; -vv to see warnings and '
                                'progress messages')
    # options to customize
    # '--in' cannot be used as an attribute name ('in' is a keyword), hence dest='input'.
    argparser.add_argument('--in', '-i', dest='input', required=True, help='Path to the input file; required')
    argparser.add_argument('--out', '-o', required=True, help='Path to the output file; required')
    argparser.add_argument('--length', '-l', required=True,
                           help='Path to the trimmomatic log file showing trimmed lengths; required')
    return argparser.parse_args()


def right_trim_read(read, new_end):
    """Shorten `read` to `new_end` bases and update its alignment fields.

    The read is modified in place and also returned.  The fields rewritten are
    seq, qual, cigar, alen, tags and, for reads flagged as reversed, pos.

    For a read that is not reversed the first `new_end` characters of seq/qual
    are kept; for a reversed read the last `new_end` characters are kept
    (presumably because the stored sequence of a reverse-strand read is the
    reverse complement of what was sequenced -- confirm against lib.samText).

    Raises ValueError when `new_end` exceeds the current read length, or when
    the trimmed read or its rebuilt CIGAR does not have length `new_end`.
    An MD field that does not parse as an integer also surfaces as ValueError.

    NOTE(review): a reversed read with new_end == 0 keeps its whole sequence
    (seq[-0:] is the full string) and is then rejected by the final length check.
    """
    loss = len(read) - new_end
    if loss < 0:
        raise ValueError('Specified truncated end %d is longer than current end %d' % (new_end, len(read)))
    if loss == 0:
        # nothing to trim
        return read
    if read.is_reversed():
        read.seq = read.seq[-new_end:]
        read.qual = read.qual[-new_end:]
        # Position just past the alignment; assumes read.alen is the aligned
        # span on the reference -- TODO confirm.
        origin = read.pos + read.alen
        # re.split with a capturing group yields [count, op, count, op, ..., ''];
        # [:-1] drops the trailing empty string.
        cigar = re.split('([DIMNSHP=X])', read.cigar)[:-1] # flat list of count, opcode pairs
        residue = new_end  # read bases still to be accounted for
        new_pos = origin
        new_cigar = []
        # Walk the operations from the last one back to the first, keeping
        # operations until `residue` bases have been covered.
        for i in range(len(cigar) - 2, -1, -2):
            count = int(cigar[i])
            opcode = cigar[i + 1]
            if count < residue or opcode == 'N':
                # The operation is kept whole (an N operation always is).
                # Only M, I, D and N are copied; S, H, P, = and X are dropped.
                if opcode in 'MIDN':
                    new_cigar.append('%d%s' % (count, opcode))
                # NOTE(review): only M reduces `residue` and only M/N move
                # `new_pos`; I and D are copied without either adjustment --
                # confirm whether indels can occur in the input.
                if opcode in 'M':
                    residue -= count
                    new_pos -= count
                elif opcode in 'N':
                    new_pos -= count
            else:
                # This operation covers the rest of the kept bases: shorten it
                # to `residue` and stop.
                if opcode in 'MIDN':
                    new_cigar.append('%d%s' % (residue, opcode))
                if opcode in 'M':
                    new_pos -= residue
                break
        # Operations were collected last-to-first; restore CIGAR order.
        new_cigar.reverse()
        read.cigar = ''.join(new_cigar)
        read.pos = new_pos
    else:
        read.seq = read.seq[:new_end]
        read.qual = read.qual[:new_end]
        # truncate cigar
        # re.split with a capturing group yields [count, op, count, op, ..., ''];
        # [:-1] drops the trailing empty string.
        cigar = re.split('([DIMNSHP=X])', read.cigar)[:-1] # flat list of count, opcode pairs
        residue = new_end  # read bases still to be accounted for
        new_cigar = []
        # Walk the operations first-to-last; read.pos is left unchanged here.
        for i in range(0, len(cigar), 2):
            count = int(cigar[i])
            opcode = cigar[i + 1]
            if count < residue or opcode == 'N':
                # Keep the operation whole; only M reduces `residue`.
                if opcode in 'MIDN':
                    new_cigar.append('%d%s' % (count, opcode))
                if opcode in 'M':
                    residue -= count
            else:
                # Final kept operation, shortened to the remaining bases.
                if opcode in 'MIDN':
                    new_cigar.append('%d%s' % (residue, opcode))
                break
        read.cigar = ''.join(new_cigar)
    read.alen = get_aligned_length(read.cigar)
    # update tags
    nm = 0  # number of mismatch entries retained in the trimmed MD tag
    new_tags = []
    for tag in read.tags:
        if 'MD:Z:' in tag:
            md = tag[5:]
            # Splitting on single bases gives [count, base, count, base, ..., count].
            # NOTE(review): MD deletion entries ('^' followed by bases) are not
            # handled by this pattern; the int() calls below would raise
            # ValueError on them.
            fields = re.split('([ACGTacgt])', md)
            new_md = []
            residue = new_end
            if len(fields) > 1:
                if read.is_reversed():
                    # Walk the counts from the last one backwards; entries are
                    # collected in reverse and flipped after the loop.
                    for i in range(len(fields) - 1, -1, -2):
                        count = int(fields[i])
                        if count < residue:
                            # The whole match run and the mismatch base before it are kept.
                            new_md.append(fields[i])
                            new_md.append(fields[i - 1])
                            nm += 1
                            residue -= count + 1
                        else:
                            # The match run reaches the trim point: keep only `residue` of it.
                            new_md.append(str(residue))
                            residue = 0
                            break
                else:
                    # Walk count/base pairs from the start; the final count
                    # (with no base after it) is not visited by this loop.
                    for i in range(0, len(fields) - 1, 2):
                        count = int(fields[i])
                        base = fields[i + 1]
                        if count < residue:
                            new_md.extend(fields[i:i + 2])
                            nm += 1
                            residue -= count + 1
                        else:
                            new_md.append(str(residue))
                            residue = 0
                            break
            # Whatever is left is one run of matches (this also covers an MD
            # value with no mismatches at all).
            if residue > 0:
                new_md.append(str(residue))
            if read.is_reversed():
                new_md.reverse()
            # Make sure the MD value starts and ends with a number.
            if not new_md[0].isdigit():
                new_md.insert(0, '0')
            if not new_md[-1].isdigit():
                new_md.append('0')
            new_tags.append(''.join(['MD:Z:'] + new_md))
        elif tag.startswith('NM:i:'):
            continue # will replace at end
        elif tag.startswith('XS:A'):
            # The XS:A tag is kept only while the trimmed CIGAR still contains an N operation.
            if 'N' in read.cigar:
                new_tags.append(tag)
        else:
            new_tags.append(tag)
    # NM is set to the number of mismatches kept in MD; it stays 0 when the
    # read carries no MD tag.
    new_tags.append('NM:i:%d' % nm)
    read.tags = new_tags
    # Sanity checks: both the sequence and the rebuilt CIGAR must match the requested length.
    if len(read) != new_end:
        raise ValueError('Length of read %s after right trimming, %d, is not equal to requested length %d' % (
        read.qname, len(read), new_end))
        # assert get_cigar_length(read.cigar) == new_end
    if get_cigar_length(read.cigar) != new_end:
        raise ValueError('Length %d of trimmed cigar %s in read %s is not equal to requested length %d' % (
        get_cigar_length(read.cigar), read.cigar, read.qname, new_end))
    return read


def left_trim_read(read, new_start):
    """Drop the first `new_start` bases of `read` and update its alignment fields.

    The read is modified in place and also returned.  The fields rewritten are
    seq, qual, cigar, alen, tags and, for reads not flagged as reversed, pos.

    For a read that is not reversed the bases are removed from the front of
    seq/qual and pos moves right accordingly; for a reversed read they are
    removed from the end of seq/qual and pos is left alone (presumably because
    the stored sequence of a reverse-strand read is reverse-complemented --
    confirm against lib.samText).

    Raises ValueError when `new_start` is negative or not smaller than the
    read length.  An MD field that does not parse as an integer also surfaces
    as ValueError.
    """
    if new_start == 0:
        return read
    if new_start < 0:
        raise ValueError('Specified start %d is negative' % (new_start))
    if new_start >= len(read):
        raise ValueError('Specified start %d is beyond read end %d' % (new_start, len(read)))
    original_len = len(read)
    reverse_strand = read.is_reversed()
    if reverse_strand:
        read.seq = read.seq[:-new_start]
        read.qual = read.qual[:-new_start]
        # re.split with a capturing group yields [count, op, count, op, ..., ''];
        # [:-1] drops the trailing empty string.
        cigar = re.split('([DIMNSHP=X])', read.cigar)[:-1]  # flat list of count, opcode pairs
        residue = original_len - new_start  # read bases still to be accounted for
        new_cigar = []
        # Keep operations from the start of the CIGAR until `residue` is used up.
        for i in range(0, len(cigar) - 1, 2):
            if residue <= 0:
                break
            count = int(cigar[i])
            opcode = cigar[i + 1]
            if count > residue or opcode == 'N':
                if opcode in 'N':
                    new_cigar.append('%d%s' % (count, opcode))
                if opcode in 'M':
                    # Final kept operation, shortened to the remaining bases.
                    new_cigar.append('%d%s' % (residue, opcode))
                    residue = 0
                # NOTE(review): an I or D longer than `residue` is silently
                # skipped here -- confirm whether indels can occur in the input.
            else:
                if opcode in 'MIDN':
                    new_cigar.append('%d%s' % (count, opcode))
                if opcode in 'MN':
                    residue -= count
        read.cigar = ''.join(new_cigar)
    else:
        read.seq = read.seq[new_start:]
        read.qual = read.qual[new_start:]
        # truncate cigar
        cigar = re.split('([DIMNSHP=X])', read.cigar)[:-1]  # flat list of count, opcode pairs
        residue = original_len - new_start  # read bases still to be accounted for
        new_cigar = []
        # Start just past the alignment and step back over every kept
        # operation; assumes read.alen is the aligned span on the reference --
        # TODO confirm.
        new_pos = read.pos + read.alen
        for i in range(len(cigar) - 2, -1, -2):
            count = int(cigar[i])
            opcode = cigar[i + 1]
            if count < residue or opcode == 'N':
                if opcode in 'MIDN':
                    new_cigar.insert(0, '%d%s' % (count, opcode))
                if opcode in 'M':
                    residue -= count
                # NOTE(review): D operations do not move new_pos -- confirm
                # whether indels can occur in the input.
                if opcode in 'MN':
                    new_pos -= count
            else:
                # First kept operation, shortened to the remaining bases.
                if opcode in 'MIDN':
                    new_cigar.insert(0, '%d%s' % (residue, opcode))
                if opcode in 'M':
                    new_pos -= residue
                break
        read.cigar = ''.join(new_cigar)
        read.pos = new_pos
    read.alen = get_aligned_length(read.cigar)
    # update tags
    nm = 0  # number of mismatch entries retained in the trimmed MD tag
    new_tags = []
    for tag in read.tags:
        # startswith (not `in`) so that only a real MD tag is parsed, in line
        # with the NM and XS checks below.
        if tag.startswith('MD:Z:'):
            # Splitting on single bases gives [count, base, count, base, ..., count].
            # NOTE(review): MD deletion entries ('^' followed by bases) are not
            # handled by this pattern; int() would raise ValueError on them.
            fields = re.split('([ACGTacgt])', tag[5:])
            new_md = []
            residue = original_len - new_start
            if len(fields) > 1:
                if reverse_strand:
                    # Keep entries from the left end of MD.
                    for i in range(0, len(fields) - 1, 2):
                        count = int(fields[i])
                        if count < residue:
                            new_md.append(fields[i])
                            new_md.append(fields[i + 1])
                            nm += 1
                            residue -= count + 1
                        else:
                            new_md.append(str(residue))
                            residue = 0
                            break
                else:
                    # Keep entries from the right end of MD, building the new
                    # value right-to-left.  fields[i] is a mismatch base and
                    # fields[i + 1] the match count that follows it.
                    for i in range(len(fields) - 2, -1, -2):
                        count = int(fields[i + 1])
                        if count < residue:
                            new_md.insert(0, fields[i + 1])
                            new_md.insert(0, fields[i])
                            nm += 1
                            residue -= count + 1
                        else:
                            new_md.insert(0, str(residue))
                            residue = 0
                            break
            if residue > 0:
                # The leftover run of matches belongs on the side the walk was
                # heading towards: the right end for reversed reads, the LEFT
                # end for forward reads (e.g. "10A5" trimmed by 2 is "8A5").
                if reverse_strand:
                    new_md.append(str(residue))
                else:
                    new_md.insert(0, str(residue))
            # Make sure the MD value starts and ends with a number.
            if not new_md[0].isdigit():
                new_md.insert(0, '0')
            if not new_md[-1].isdigit():
                new_md.append('0')
            new_tags.append(''.join(['MD:Z:'] + new_md))
        elif tag.startswith('NM:i:'):
            continue  # will replace at end
        elif tag.startswith('XS:A'):
            # The XS:A tag is kept only while the trimmed CIGAR still contains an N operation.
            if 'N' in read.cigar:
                new_tags.append(tag)
        else:
            new_tags.append(tag)
    # NM is set to the number of mismatches kept in MD; it stays 0 when the
    # read carries no MD tag.
    new_tags.append('NM:i:%d' % nm)
    read.tags = new_tags
    return read


def do_trimSAMreads(input, output, lengths):
    """Trim every read from `input` as directed by the trim log and write it out.

    Args:
        input:   iterable of read objects; each must expose `qname` and be
                 accepted by left_trim_read / right_trim_read.
        output:  writable text file object receiving one line per kept read.
        lengths: readable text file object; one whitespace-separated line is
                 consumed per read.  Field 0 is the read name, field 2 the
                 start offset and field 3 the end offset.

    A read is written only when its end offset is greater than zero.  Problems
    with a single read (name mismatch, missing or malformed log line, trimming
    errors raised as ValueError) are reported on stderr and the read is
    skipped; processing continues with the next read and the next log line.
    """
    for read in input:
        try:
            fields = lengths.readline().split()
            if not fields:
                # readline() returns '' once the log is exhausted; report it
                # instead of failing with IndexError on fields[0].
                raise ValueError('Trim log has no entry for read %s' % read.qname)
            name = fields[0]
            if name != read.qname:
                sys.stderr.write('Read %s is out of sync with trim log %s\n' % (read.qname, name))
                continue
            if len(fields) < 4:
                raise ValueError('Trim log entry for read %s has too few fields' % read.qname)
            start = int(fields[2])
            end = int(fields[3])
            if start > 0:
                read = left_trim_read(read, start)
            if end > 0:
                # After left trimming, the read must be cut down to end - start bases.
                read = right_trim_read(read, end - start)
                output.write('%s\n' % (read,))
        except ValueError as ve:
            sys.stderr.write('%s\n' % (ve,))


if __name__ == '__main__':
    # Script entry point: copy the SAM header (if any) to the output file and
    # then write the trimmed reads after it.
    args = get_args()
    infile = SAMTextFileIterator(args.input)
    # The with-statement closes the trim log and the output file even when
    # trimming fails part-way through.
    with open(args.length) as lengths, open(args.out, mode='w') as output:
        header = infile.header()
        if header:
            output.write('%s\n' % (header,))
        else:
            sys.stderr.write('Warning: input file %s has no SAM header\n' % args.input)
        do_trimSAMreads(infile, output, lengths)
    sys.stderr.write('%s done.\n' % sys.argv[0])
