# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

''' Convert alignments in a SAM output file to a list of fragment origins
Created on Sept. 13, 2010

@author: ian
'''
import sys, re, heapq

# Alignment line with an NH:i tag. Capture groups, in order:
#   read_name, flag, refseq, _offset, cigar, num_matches
# The fifth whitespace-separated field is matched but not captured, and the
# CIGAR group accepts only digits and the M / N operations.
# Raw strings keep the regex escapes (\S, \s, \d, \w) out of Python's own
# string-escape processing.
aln_re = re.compile(r'(\S+)\s+(\d+)\s+(\w+)\s+(\d+)\s+\S+\s+([0-9MN]+).+\s+NH:i:(\d+)')
# Fallback for lines without an NH:i tag: the same first five capture groups,
# with no num_matches group.
aln_re2 = re.compile(r'(\S+)\s+(\d+)\s+(\w+)\s+(\d+)\s+\S+\s+([0-9MN]+)\s+')


class SortedOutput(object):
    """Buffer (key, value) records per group and write per-key tallies in key order.

    Records are kept in a min-heap. When a key is released, every buffered
    record with that key is merged into one output line of the form
    ``group<TAB>key<TAB>left<TAB>right<TAB>left+right``, where ``left`` counts
    records whose value is ``'-'`` and ``right`` counts all other records.

    Attributes:
        output: object with a ``write`` method that receives the output lines.
        group: name of the group currently being buffered ('' before the first add).
        heap: heapq-ordered list of (key, value) tuples awaiting output.
        groups_seen: set of every group name that has been started.
        cursor: last key written for the current group, or None if none yet.
    """

    def __init__(self, output):
        """Create an empty buffer that writes its lines to *output*."""
        super(SortedOutput, self).__init__()
        self.output = output
        self.group = ''
        self.heap = []
        self.groups_seen = set()
        self.cursor = None

    def _write_next_position(self):
        """Pop the lowest buffered key and write its tally line."""
        position, left, right = self.pop_next_position()
        # write() instead of the Python-2-only ``print >>`` statement; the
        # bytes produced are the same (one line terminated by '\n').
        self.output.write(
            '%s\t%d\t%d\t%d\t%d\n' % (self.group, position, left, right, left + right))

    def flush(self):
        """Write every buffered key of the current group, then reset the cursor."""
        while self.heap:
            self._write_next_position()
        self.cursor = None

    def add(self, key, value, group='default_group'):
        """Buffer one record.

        Starting a new group first flushes the previous group's records.

        Raises:
            ValueError: if *group* was already started and left earlier, or if
                *key* is lower than a key already written for the current group
                (both indicate that the input was not sorted).
        """
        if group != self.group:
            self.flush()  # also resets self.cursor for the new group
            if group in self.groups_seen:
                raise ValueError('[SortedOutput.add] Input Sorting error; Group %s seen more than once' % group)
            self.group = group
            self.groups_seen.add(group)
        elif self.cursor is not None and self.cursor > key:
            # This check used to sit in the group-change branch right after the
            # cursor was cleared, where it could never fire.
            raise ValueError(
                '[SortedOutput.add] Input Sorting error; Key %s is less than previous output' % str(key))
        heapq.heappush(self.heap, (key, value))

    def lowest_key(self):
        """Return the smallest buffered key, or None when nothing is buffered."""
        if self.heap:
            return self.heap[0][0]
        return None

    def __len__(self):
        """Return the number of buffered records (not distinct keys)."""
        return len(self.heap)

    def output_up_to(self, limit_key):
        """Write all buffered keys that are <= *limit_key*, in ascending order.

        A *limit_key* of None means "no limit known yet" and writes nothing.

        Raises:
            ValueError: if a key greater than *limit_key* was already written.
        """
        if limit_key is None:
            return
        if self.cursor is not None and self.cursor > limit_key:
            raise ValueError(
                '[SortedOutput.output_up_to] Input Sorting error; Key %s is less than previous output' % str(limit_key))
        while self.heap and self.heap[0][0] <= limit_key:
            self._write_next_position()

    def pop_next_position(self):
        """Remove every record sharing the lowest key and tally them.

        Returns:
            (key, left, right): the key, the number of removed records whose
            value is '-', and the number of removed records with any other value.

        Raises:
            IndexError: if nothing is buffered.
        """
        if not self.heap:
            raise IndexError('[SortedOutput.pop_next_position] no buffered records')
        left = right = 0
        self.cursor = self.heap[0][0]
        while self.heap and self.heap[0][0] == self.cursor:
            value = heapq.heappop(self.heap)[1]
            if value == '-':
                left += 1
            else:
                right += 1
        return self.cursor, left, right

# end class SortedOutput

def parse_alignment(read_name, flag, refseq, _offset, cigar, num_matches):
    """Turn the captured fields of one alignment line into an origin record.

    Args:
        read_name: name of the read, returned unchanged.
        flag: flag field as text or int; bit 16 selects the '-' strand.
        refseq: reference sequence name, returned unchanged.
        _offset: start coordinate of the alignment, as text or int.
        cigar: CIGAR text made of numbers followed by M or N.
        num_matches: number of alignments reported for the read.

    Returns:
        (read_name, refseq, position, strand, score) where position is the
        start coordinate for '+' alignments and the last aligned coordinate
        for '-' alignments, and score is 1 / num_matches as a float.
    """
    start = int(_offset)
    on_reverse = bool(int(flag) & 16)
    origin = start
    if on_reverse:
        # Every M and N length contributes to the span; treating both letters
        # as one separator yields the same numeric pieces as nested splitting.
        lengths = [int(piece) for piece in cigar.replace('N', 'M').split('M') if piece]
        # Coordinates are inclusive, hence the -1 when stepping to the far end.
        origin = start + sum(lengths) - 1
    weight = 1 / float(num_matches)
    return read_name, refseq, origin, ('-' if on_reverse else '+'), weight


def usage():
    """Write the command-line synopsis to standard output and exit with status 1."""
    # write() replaces the Python-2-only print statement; sys.exit replaces the
    # interactive ``exit`` helper, which only exists when the site module is loaded.
    sys.stdout.write('Usage: SAM2origin.py  input.sam  output.txt\n')
    sys.exit(1)


def process_alns(alns, output_sorter):
    """Feed the alignments of one read into *output_sorter*.

    Args:
        alns: iterable of (name, chrom, position, orientation, score) tuples.
        output_sorter: object providing ``lowest_key()``, ``output_up_to(key)``
            and ``add(key, value, group)``.

    For each '+' alignment whose position is beyond the current position, the
    sorter is asked to write everything up to the current position before the
    current position advances. Every alignment is then added to the sorter
    keyed by position, with the orientation as value and chrom as group.
    """
    curr_position = output_sorter.lowest_key()  # None while the sorter is empty
    for aln in alns:
        name, chrom, position, orientation, score = aln
        if orientation == '+':
            if curr_position is None:
                # Nothing buffered yet, so there is nothing to write; comparing
                # an int with None would raise TypeError on Python 3.
                curr_position = position
            elif position > curr_position:
                output_sorter.output_up_to(curr_position)
                curr_position = position
        output_sorter.add(position, orientation, chrom)


def do_SAM2origin(SAM_lines, origin_counts):
    """Convert SAM alignment lines into sorted origin counts.

    Args:
        SAM_lines: iterable of text lines in SAM format.
        origin_counts: object with a ``write`` method that receives the output.

    Lines starting with '@' are skipped. Each remaining line is parsed with
    ``aln_re`` or, failing that, with ``aln_re2`` (in which case num_matches is
    taken to be 1). Lines matching neither pattern are reported on standard
    error. Consecutive alignments with the same read name are handed to
    ``process_alns`` together.
    """
    output_sorter = SortedOutput(origin_counts)
    currname = ''
    alns = []
    for line in SAM_lines:
        if line.startswith('@'):
            continue
        m = aln_re.search(line)
        if m:
            n_aln = parse_alignment(*m.groups())
        else:
            # Only try the fallback pattern when the primary one failed.
            m = aln_re2.search(line)
            if not m:
                # Same bytes as the former ``print >> sys.stderr, 'Reject:', line``.
                sys.stderr.write('Reject: ' + line + '\n')
                continue
            n_aln = parse_alignment(*m.groups(), num_matches=1)
        n_name = n_aln[0]
        if n_name != currname:
            # A new read name closes the batch collected for the previous read.
            if alns:
                process_alns(alns, output_sorter)
            currname = n_name
            alns = []
        alns.append(n_aln)
    if alns:
        process_alns(alns, output_sorter)
    output_sorter.flush()


if __name__ == '__main__':
    # Command line: SAM2origin.py input.sam output.txt
    # Either path may be '-' to use standard input / standard output.
    if len(sys.argv) != 3:
        usage()
    sam_path, origins_path = sys.argv[1], sys.argv[2]

    # Local names avoid shadowing the builtin ``input``.
    sam_in = sys.stdin if sam_path == '-' else open(sam_path)
    try:
        origins_out = sys.stdout if origins_path == '-' else open(origins_path, 'w')
        try:
            do_SAM2origin(sam_in, origins_out)
        finally:
            # Close only files opened here; standard output is flushed instead.
            if origins_out is sys.stdout:
                origins_out.flush()
            else:
                origins_out.close()
    finally:
        if sam_in is not sys.stdin:
            sam_in.close()
