#!/usr/bin/python

# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

"""Compare two mappings of the same read set onto the same genome

Created on May 6, 2010
@author: ian
"""

import sys, os
import itertools
from collections import defaultdict
this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
from lib.samRead import SAMRead,  SAMReadFileIterator
from pysam import Samfile

class MapComparison(object):
    """Compare two SAM mappings of the same read set, holding both in memory.

    Each mapping is reduced to a dict that maps a read name to the list of
    (rname, pos, flag, cigar) tuples recorded for that read.  Counts and
    read-name subsets are computed lazily from the two dicts and cached.
    """

    def __init__(self, map1, map2):
        """Read both mappings.

        map1 and map2 are iterables of SAM-format text lines.
        """
        object.__init__(self)
        self.identical_count = 0
        self.different_count = 0
        self.only_map1_count = 0
        self.only_map2_count = 0
        self.differently_mapped_reads = set()
        self.reads_only_in_map1 = set()
        self.reads_only_in_map2 = set()
        # Explicit "already computed" flags.  Testing the results themselves
        # (count == 0, empty set) cannot tell "not computed yet" apart from
        # "computed, and the answer is zero/empty".
        self._counts_done = False
        self._subsets_done = False
        print('Reading map1')
        self.map1_dict = self.make_read_dict(map1)
        print('Reading map2')
        self.map2_dict = self.make_read_dict(map2)
        print('Both maps have been read.')

    def make_read_dict(self, mapping):
        """Return a defaultdict: read name -> list of (rname, pos, flag, cigar).

        Header lines (starting with '@') and records whose reference name is
        '*' are skipped.  For paired reads '/1' or '/2' is appended to the
        read name so that the two mates are stored under different keys.
        """
        result = defaultdict(list)
        for read in mapping:
            if read.startswith('@'):
                continue
            sam = SAMRead.fromString(read)
            if sam.rname == '*':
                continue
            if sam.is_paired():
                if sam.is_first_of_pair():
                    sam.qname += '/1'
                elif sam.is_second_of_pair():
                    sam.qname += '/2'
            result[sam.qname].append((sam.rname, sam.pos, sam.flag, sam.cigar))
        return result

    def _calc_counts(self):
        """Compute the four counters from scratch and store them on self.

        Mappings of reads present in both maps are counted as identical when
        the same tuple occurs in both maps; otherwise they are counted as
        different, once for each map they occur in.  Reads present in only
        one map are counted once per read.
        """
        identical = different = only_map1 = only_map2 = 0
        for read1 in self.map1_dict:
            if read1 not in self.map2_dict:
                only_map1 += 1
                continue
            mappings2 = self.map2_dict[read1]
            for mapping in self.map1_dict[read1]:
                if mapping in mappings2:
                    identical += 1
                else:
                    different += 1
        for read2 in self.map2_dict:
            if read2 not in self.map1_dict:
                only_map2 += 1
                continue
            mappings1 = self.map1_dict[read2]
            for mapping in self.map2_dict[read2]:
                # identical mappings were already counted in the first pass
                if mapping not in mappings1:
                    different += 1
        # Assign rather than accumulate so that a repeated call can never
        # double the totals.
        self.identical_count = identical
        self.different_count = different
        self.only_map1_count = only_map1
        self.only_map2_count = only_map2
        self._counts_done = True

    def _calc_subsets(self):
        """Fill the three read-name sets (differently mapped, only in map1, only in map2)."""
        for read1 in self.map1_dict:
            if read1 not in self.map2_dict:
                self.reads_only_in_map1.add(read1)
                continue
            mappings2 = self.map2_dict[read1]
            for mapping in self.map1_dict[read1]:
                if mapping not in mappings2:
                    self.differently_mapped_reads.add(read1)
        for read2 in self.map2_dict:
            if read2 not in self.map1_dict:
                self.reads_only_in_map2.add(read2)
                continue
            mappings1 = self.map1_dict[read2]
            for mapping in self.map2_dict[read2]:
                if mapping not in mappings1:
                    self.differently_mapped_reads.add(read2)
        self._subsets_done = True

    def get_counts(self):
        """Returns a tuple of ints: identical_count, different_count, only_map1_count, only_map2_count
        """
        if not self._counts_done:
            self._calc_counts()
        return self.identical_count, self.different_count, self.only_map1_count, self.only_map2_count

    def get_differently_mapped_reads(self):
        """Returns a set of read qnames
        """
        if not self._subsets_done:
            self._calc_subsets()
        return self.differently_mapped_reads

    def get_reads_only_in_map1(self):
        """Returns a set of read qnames
        """
        if not self._subsets_done:
            self._calc_subsets()
        return self.reads_only_in_map1

    def get_reads_only_in_map2(self):
        """Returns a set of read qnames
        """
        if not self._subsets_done:
            self._calc_subsets()
        return self.reads_only_in_map2

def compare(map1, map2):
    """Count how two mappings of one read set onto one genome agree.

    map1 and map2 are iterators over the two mappings, in SAM format.
    Both are loaded into a MapComparison, whose counts are returned as a
    tuple of ints:
    identical_count, different_count, only_map1_count, only_map2_count
    """
    return MapComparison(map1, map2).get_counts()

def simplify_id(id):
    """Return the read id without a leading 'seq.' and without anything from the first '/' on."""
    prefix = 'seq.'
    trimmed = id[len(prefix):] if id.startswith(prefix) else id
    # partition()[0] is everything before the first '/', or the whole string
    # when there is no '/'
    return trimmed.partition('/')[0]

def get_id(sam_line):
    """Return the simplified read id of a read object.

    sam_line must provide get_qname(); the name it returns is passed
    through simplify_id().
    """
    return simplify_id(sam_line.get_qname())

def do_nothing(m):
    """Handler that ignores its argument; used as the default handler."""
    return None

class SequentialMapComparison(object):
    '''Streaming comparison of two mappings of the same read set.

    Assumes that maps are sorted by read id (either numerically or
    alphabetically, the same way in both maps).

    The comparison runs inside the constructor.  Every mapping is passed to
    one of the three handlers as it is classified, and a summary is written
    to out_stream at the end.  The totals stay available afterwards as
    read_count, common_count, difference_count_1, difference_count_2, only1
    and only2.
    '''
    def __init__(self, map1, map2, difference_handler=do_nothing, only1_handler=do_nothing, only2_handler=do_nothing, out_stream=sys.stdout,   discard_suffix=False, include_secondary=False):
        """Compare map1 with map2.

        map1, map2         -- iterables of read objects, sorted by read id
        difference_handler -- called with each read that is mapped differently
        only1_handler      -- called with each read mapped only in map1
        only2_handler      -- called with each read mapped only in map2
        out_stream         -- file-like object that receives the summary
        discard_suffix     -- compare read names after simplify_id()
        include_secondary  -- also compare non-primary alignments
        """
        object.__init__(self)
        # consecutive records with the same simplified read id form one group
        self.iters = [itertools.groupby(mapn, get_id) for mapn in (map1, map2)]
        self.handle_differently_mapped_read = difference_handler
        self.handle_read_only_in_map1 = only1_handler
        self.handle_read_only_in_map2 = only2_handler
        self.outstream = out_stream
        self.discard_id_suffix = discard_suffix
        self.include_secondary_alignments = include_secondary
        self._calc_subsets()

    def read_iter(self, itr):
        """Fetch the next group of reads from one groupby iterator.

        Returns (read_id, group, status).  group holds the mapped reads of
        the group, restricted to primary alignments unless secondary
        alignments were requested.  At the end of the input the result is
        ('zz', [], False).
        """
        try:
            read_id, grouper = next(itr)
            while read_id.startswith('@'):
                read_id, grouper = next(itr)
        except StopIteration:
            return 'zz', [], False
        # some SAM/BAM files include unmapped reads
        group = [read for read in grouper if read.is_mapped() and (read.is_primary_alignment() or self.include_secondary_alignments)]
        return read_id, group, True

    def read_groups(self):
        """Merge the two sorted inputs; yield (read_id, [group1, group2]).

        A group is an empty list when the corresponding map has no record
        for that read id (or only records filtered out by read_iter).
        """
        # initialize
        more = []
        ids = []
        groups = []
        for itr in self.iters:
            read_id, group, status = self.read_iter(itr)
            more.append(status)
            ids.append(read_id)
            groups.append(group)

        # main loop
        while any(more):
            # Only inputs that still have data compete for the smallest id;
            # the end-of-input placeholder must never be chosen, otherwise
            # the loop could not advance.
            live_ids = [read_id for read_id, active in zip(ids, more) if active]
            if all(read_id.isdigit() for read_id in live_ids): # numerical order
                # keep the id text itself (not str(int(id))) so that ids with
                # leading zeros still match their own input
                curr_id = min(live_ids, key=int)
            else: # alphabetical order
                curr_id = min(live_ids)
            result = []
            for i, read_id in enumerate(ids):
                if more[i] and read_id == curr_id:
                    result.append(groups[i])
                    ids[i], groups[i], more[i] = self.read_iter(self.iters[i])
                else:
                    result.append([])
            yield curr_id, result

    def _mapping_key(self, a_map):
        """Return the comparison key [qname, strand, rname, pos, cigar] of a read.

        Only bit 16 of the flag takes part in the comparison.
        """
        if self.discard_id_suffix:
            qname = simplify_id(a_map.qname)
        else:
            qname = a_map.qname
        return [qname, str(int(a_map.flag) & 16), a_map.rname, a_map.pos, a_map.cigar]

    def _classify(self, reads, keys, other_keys, tag, only_handler):
        """Classify the mappings of one read in one map against the other map.

        A mapping whose key also occurs in the other map is common.
        Otherwise it is "different" when the other map has a mapping of this
        read with the same strand value; the read then gets `tag` appended
        to its tags and goes to the difference handler.  All remaining
        mappings go to only_handler.

        Returns the tuple (common, different, only) of counts.
        """
        common = different = only = 0
        other_strands = set(key[1] for key in other_keys)
        for read, key in zip(reads, keys):
            if key in other_keys:
                common += 1
            elif key[1] in other_strands:
                tags = read.get_tags()
                tags.append(tag)
                read.set_tags(tags)
                self.handle_differently_mapped_read(read)
                different += 1
            else:
                only_handler(read)
                only += 1
        return common, different, only

    def _write_summary(self):
        """Write the six summary lines to the output stream."""
        # print >> None wrote to sys.stdout; keep that fallback
        out = sys.stdout if self.outstream is None else self.outstream
        out.write('Number of reads = %s\n' % (self.read_count,))
        out.write('%s mappings were common to both maps.\n' % (self.common_count,))
        out.write('%s mappings were different in the first map.\n' % (self.difference_count_1,))
        out.write('%s mappings were different in the second map.\n' % (self.difference_count_2,))
        out.write('%s mappings were only found in the first map.\n' % (self.only1,))
        out.write('%s mappings were only found in the second map.\n' % (self.only2,))

    def _calc_subsets(self):
        """Run the whole comparison, call the handlers and write the summary."""
        self.read_count = 0
        self.common_count = 0
        self.difference_count_1 = self.difference_count_2 = 0
        self.only1 = 0
        self.only2 = 0
        for read_id, mappings in self.read_groups():
            reads1, reads2 = mappings
            keys1 = [self._mapping_key(read) for read in reads1]
            keys2 = [self._mapping_key(read) for read in reads2]
            self.read_count += 1
            common, different, only = self._classify(
                reads1, keys1, keys2, 'XM:i:1', self.handle_read_only_in_map1)
            self.common_count += common
            self.difference_count_1 += different
            self.only1 += only
            # common mappings are counted from the map1 side only
            common, different, only = self._classify(
                reads2, keys2, keys1, 'XM:i:2', self.handle_read_only_in_map2)
            self.difference_count_2 += different
            self.only2 += only
        self._write_summary()

def make_save_handler(out_handle):
    """Return a handler that writes a mapping as one tab-separated line.

    out_handle is a writable file-like object.  The returned handler takes
    an iterable of fields, converts each field with str() and writes them to
    out_handle joined by tabs and followed by a newline.
    """
    def handler(mapping):
        # Written with write() instead of the "print >>" chevron form so the
        # handler behaves the same under Python 2 and Python 3.  As with the
        # chevron form, a None handle means sys.stdout.
        import sys
        stream = sys.stdout if out_handle is None else out_handle
        stream.write('\t'.join([str(f) for f in mapping]) + '\n')
    return handler

def make_save_SAM_handler(out_filename,  sam_header):
    """Open an output Samfile and return it with a handler that writes to it.

    The file is opened with mode 'wb' and the given header.  The returned
    handler calls write_to_samfile() on each read it is given.  The caller
    is responsible for closing the returned file.

    Returns the tuple (output file, handler).
    """
    writer = Samfile(out_filename, mode='wb', header=sam_header)

    def handler(read):
        read.write_to_samfile(writer)

    return writer, handler

class Counter(object):
    """A simple mutable integer tally."""

    def __init__(self, initial_value=0):
        """Start the tally at int(initial_value)."""
        super(Counter, self).__init__()
        self._value = int(initial_value)

    def increment(self):
        """Add one to the tally."""
        self._value = self._value + 1

    def get_count(self):
        """Return the current tally."""
        return self._value

def make_count_handler(initial_value=0):
    """Return (counter, handler); each call of handler increments counter.

    counter is a Counter started at initial_value.  The handler ignores its
    argument.
    """
    tally = Counter(initial_value)

    def handler(mapping):
        tally.increment()

    return tally, handler

def make_accumulate_handler(accumulator):
    """Return a handler that appends every mapping it receives to accumulator.

    accumulator is any object with an append() method, e.g. a list; it is
    modified in place.
    """
    def handler(mapping):
        accumulator.append(mapping)

    return handler

def usage():
    """Print the command-line synopsis and exit with status 1."""
    synopsis = 'Usage: python %s map1.sam map2.sam output_stem/ [discard ID suffix (True/False) [include secondary alignments (True/False) ]]' % sys.argv[0]
    print(synopsis)
    sys.exit(1)

if __name__ == '__main__':
    # Command line: map1 map2 output_stem [discard_id_suffix [include_secondary]]
    # Check the argument count up front instead of catching IndexError around
    # the whole run, which reported any IndexError raised during the
    # comparison as a usage error.
    if len(sys.argv) < 4:
        usage()

    map1 = SAMReadFileIterator(sys.argv[1])
    map2 = SAMReadFileIterator(sys.argv[2])
    output_stem = sys.argv[3]

    # Some programs discard the /1,/2 suffix from paired reads;
    # discard_id_suffix must be set to True for the output of these programs
    discard_id_suffix = len(sys.argv) > 4 and sys.argv[4] == 'True'
    include_secondary_alignments = len(sys.argv) > 5 and sys.argv[5] == 'True'

    out_names = ('differently_mapped_reads.bam',
                 'reads_only_in_map1.bam',
                 'reads_only_in_map2.bam')
    out_files = []
    handlers = []
    try:
        for out_name in out_names:
            # NOTE(review): all three outputs, including the map2-only reads,
            # are written with the header of map1 -- confirm that both maps
            # always share the same reference header.
            out_file, handler = make_save_SAM_handler(os.path.join(output_stem, out_name), map1.header)
            out_files.append(out_file)
            handlers.append(handler)
        # the comparison runs inside the constructor
        comparison = SequentialMapComparison(map1, map2, handlers[0], handlers[1], handlers[2],
                                             out_stream=sys.stdout, discard_suffix=discard_id_suffix,
                                             include_secondary=include_secondary_alignments)
    finally:
        # close whatever was opened, even if the comparison failed
        for out_file in out_files:
            out_file.close()
    sys.stderr.write('%s done.\n' % sys.argv[0])
