# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.


__author__ = 'ian'

import sys
import os
import re

# Make the parent of this script's directory importable, so that the
# `from lib.samText import ...` below can be resolved from there
# (presumably the `lib` package lives in that parent directory -- TODO confirm).
this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
import argparse
from lib.samText import SAMTextFileIterator, get_cigar_length

# Program description handed to argparse in get_args().
DESCRIPTION = "Adjust a SAM file of true mappings of simulated reads for 3' truncation of the reads"
# Version string reported by the --version option.
VERSION = '0.1'


def get_args(argv=None):
    """Parse the command-line options of the SAM truncation script.

    argv -- optional list of argument strings; when None (the default)
            argparse falls back to sys.argv[1:], as before.

    Returns the argparse namespace with the attributes `verbose`, `input`,
    `out` and `length`.
    """
    argparser = argparse.ArgumentParser(description=DESCRIPTION)
    # standard options
    # The trailing space keeps the program name and the version number apart
    # ("prog 0.1" rather than "prog0.1").
    argparser.add_argument('--version', action='version', version='%(prog)s ' + VERSION)
    argparser.add_argument('--verbose', '-v', action='count', default=0,
                           help='Omit to see only fatal error messages; -v to see warnings; -vv to see warnings and '
                                'progress messages')
    # options to customize
    # '--in' cannot be used as an attribute name (`in` is a keyword), hence dest='input'.
    argparser.add_argument('--in', '-i', dest='input', required=True, help='Path to the input file; required')
    argparser.add_argument('--out', '-o', required=True, help='Path to the output file; required')
    argparser.add_argument('--length', '-l', type=int, required=True, help='Target read length; required')
    return argparser.parse_args(argv)


# Splits a CIGAR string into alternating count / opcode tokens.
_CIGAR_OP_RE = re.compile('([DIMNSHP=X])')
# Splits an MD value into alternating match-count / mismatched-base tokens.
_MD_MISMATCH_RE = re.compile('([ACGTNacgtn])')
# Only these CIGAR operations are written to the truncated CIGAR.
_KEPT_CIGAR_OPS = 'MIDN'


def _cigar_pairs(cigar):
    """Return the CIGAR string as a list of (count_text, opcode) pairs."""
    # The split always ends with the (empty) text after the last opcode; drop it.
    tokens = _CIGAR_OP_RE.split(cigar)[:-1]
    return [(tokens[i], tokens[i + 1]) for i in range(0, len(tokens), 2)]


def _truncate_cigar(pairs, length):
    """Keep CIGAR operations, in the order given, until `length` bases are covered.

    Returns (kept, ref_span): `kept` is the list of retained 'countOP' strings
    in walking order, `ref_span` the number of reference positions spanned by
    the retained M and N operations.
    """
    kept = []
    ref_span = 0
    residue = length  # read bases still to be accounted for
    for count_text, opcode in pairs:
        count = int(count_text)
        if count < residue or opcode == 'N':
            # Operation fits entirely (an intron gap is always kept whole).
            if opcode in _KEPT_CIGAR_OPS:
                kept.append('%d%s' % (count, opcode))
            if opcode == 'M':
                residue -= count
                ref_span += count
            elif opcode == 'N':
                ref_span += count
        else:
            # Operation reaches the truncation point: shorten it and stop.
            if opcode in _KEPT_CIGAR_OPS:
                kept.append('%d%s' % (residue, opcode))
            if opcode == 'M':
                ref_span += residue
            break
    return kept, ref_span


def _truncate_md(md, length, from_end):
    """Truncate an MD value to `length` bases; return (new_md, mismatch_count).

    `from_end` selects which end is preserved: False keeps the first `length`
    bases, True keeps the last `length` bases.
    """
    fields = _MD_MISMATCH_RE.split(md)
    if from_end:
        # Walk the tag backwards; the result is flipped back afterwards.
        fields.reverse()
    kept = []
    mismatches = 0
    residue = length
    # Visit (match count, mismatched base) pairs.  The final count has no base
    # after it and is covered by the leftover `residue` below, so the index
    # never runs off either end of `fields`.
    for i in range(0, len(fields) - 1, 2):
        count = int(fields[i])
        if count >= residue:
            kept.append(str(residue))
            residue = 0
            break
        kept.extend(fields[i:i + 2])
        mismatches += 1
        residue -= count + 1  # the matches plus the mismatched base itself
    if residue > 0:
        kept.append(str(residue))
    if from_end:
        kept.reverse()
    # An MD value has to start and end with a number.
    if not kept[0].isdigit():
        kept.insert(0, '0')
    if not kept[-1].isdigit():
        kept.append('0')
    return ''.join(kept), mismatches


def truncate_read(read, length):
    """Truncate a mapped read to `length` bases, updating its alignment fields.

    The read is modified in place and also returned.  For a forward-strand
    read the first `length` bases are kept.  For a read reported as reversed
    by read.is_reversed() the last `length` bases are kept and read.pos is
    moved accordingly.  seq, qual, cigar and the MD tag are shortened, NM is
    replaced by the number of mismatches left in the MD tag (0 when there is
    no MD tag), and an XS:A tag is dropped when the truncated CIGAR no longer
    contains an N operation.  A read that already has exactly `length` bases
    is returned untouched.

    read   -- read object as produced by SAMTextFileIterator; must support
              len() and provide is_reversed(), seq, qual, pos, alen, cigar
              and tags (an iterable of 'TAG:TYPE:VALUE' strings).
    length -- target read length, a positive integer.

    Raises ValueError if `length` is not positive or exceeds the read length.

    NOTE(review): only M counts against the remaining read length and only
    M/N move the position of a reversed read; I, D and soft clips are not
    accounted for, and an MD value containing a deletion ('^') cannot be
    parsed.  This looks correct only for reads aligned with M and N
    operations -- confirm against get_cigar_length before extending.
    """
    if length < 1:
        # A zero length would make the [-length:] slices below keep everything.
        raise ValueError('Specified truncated length %d must be positive' % length)
    loss = len(read) - length
    if loss < 0:
        raise ValueError('Specified truncated length %d is longer than actual length %d' % (length, len(read)))
    if loss == 0:
        return read
    reverse = read.is_reversed()
    if reverse:
        read.seq = read.seq[-length:]
        read.qual = read.qual[-length:]
        # Position just past the alignment, taken before the CIGAR is rewritten.
        origin = read.pos + read.alen
        pairs = _cigar_pairs(read.cigar)
        pairs.reverse()  # walk from the retained end towards the start
        kept, ref_span = _truncate_cigar(pairs, length)
        kept.reverse()
        read.cigar = ''.join(kept)
        read.pos = origin - ref_span
    else:
        read.seq = read.seq[:length]
        read.qual = read.qual[:length]
        kept, _ = _truncate_cigar(_cigar_pairs(read.cigar), length)
        read.cigar = ''.join(kept)
    # update tags
    nm = 0
    new_tags = []
    for tag in read.tags:
        if tag.startswith('MD:Z:'):
            new_md, mismatches = _truncate_md(tag[5:], length, reverse)
            nm += mismatches
            new_tags.append('MD:Z:' + new_md)
        elif tag.startswith('NM:i:'):
            continue  # will replace at end
        elif tag.startswith('XS:A'):
            # The strand tag is only kept while the read is still spliced.
            if 'N' in read.cigar:
                new_tags.append(tag)
        else:
            new_tags.append(tag)
    new_tags.append('NM:i:%d' % nm)
    read.tags = new_tags
    # Internal consistency checks on the rewritten record.
    assert len(read) == length
    assert get_cigar_length(read.cigar) == length
    return read


def do_truncateSAM(input, output, length):
    """Truncate every read from `input` to `length` bases and print it to `output`.

    input  -- iterable of read objects accepted by truncate_read()
    output -- writable file-like object; one printed read per line
    length -- target read length, passed through to truncate_read()

    Any exception raised by truncate_read() propagates to the caller.
    """
    for read in input:
        print >> output, truncate_read(read, length)


# Script entry point: copy the SAM header, then write every read truncated
# to the requested length.
if __name__ == '__main__':
    args = get_args()
    infile = SAMTextFileIterator(args.input)
    # '-' selects standard output instead of a named file.
    to_stdout = args.out == '-'
    if to_stdout:
        output = sys.stdout
    else:
        output = open(args.out, mode='w')
    try:
        header = infile.header()
        if header:
            print >> output, header
        else:
            print >> sys.stderr, 'Warning: input file %s has no SAM header' % args.input
        do_truncateSAM(infile, output, args.length)
    finally:
        # Release the output even when truncation fails; standard output is
        # only flushed, it is not ours to close.
        if to_stdout:
            output.flush()
        else:
            output.close()
    print >> sys.stderr, sys.argv[0], 'done.'
