
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

''' Represent a read alignment in SAM format

Created on Sep 25, 2010
@author: ian
'''
import sys, re
import pysam
from pysam import AlignedRead

# Matches a single CIGAR operation letter (M, N, D, I, S, H or P). The
# capturing group makes re.split() keep the matched letter in its result.
cigar_opcode_pat = re.compile(r"(M|N|D|I|S|H|P)")

class SAMHeader(dict):
    '''Dictionary holding the sections of a SAM header.

    A new instance starts with an 'HD' record (VN:1.3) and empty 'SQ', 'RG',
    'PG' and 'CO' lists. Item assignment accepts only those five keys.
    The position of a reference in the 'SQ' list is its tid.
    '''

    _LEGAL_KEYS = ('HD', 'SQ', 'RG', 'PG', 'CO')

    def __init__(self):
        dict.__init__(self)
        dict.__setitem__(self, 'HD', {'VN': '1.3'})
        for list_key in ('SQ', 'RG', 'PG', 'CO'):
            dict.__setitem__(self, list_key, [])

    def __setitem__(self, key, value):
        '''Store value under key; raise KeyError for keys other than HD, SQ, RG, PG, CO.'''
        if key not in self._LEGAL_KEYS:
            raise KeyError('SAMHeader.__setitem__, Illegal key: %s' % key)
        dict.__setitem__(self, key, value)

    def add_refseq(self, ref_name, ref_length):
        '''Append a reference sequence record; ref_length is converted with int().'''
        self.get('SQ').append({'SN': ref_name, 'LN': int(ref_length)})

    def populate_refseqs(self, name_list, length_list):
        '''Append one reference per (name, length) pair, pairing the lists with zip().'''
        for ref_name, ref_length in zip(name_list, length_list):
            self.add_refseq(ref_name, ref_length)

    def find_tid(self, ref_name):
        '''Return the index of ref_name in the 'SQ' list, or -1 if it is absent.'''
        for tid, ref in enumerate(self.get('SQ')):
            if ref['SN'] == ref_name:
                return tid
        return -1

    def get_ref_name(self, tid):
        '''Return the 'SN' value of the reference at index tid in the 'SQ' list.'''
        return self.get('SQ')[tid].get('SN')

    @staticmethod
    def _format_fields(record):
        '''Render a record dict as tab-separated KEY:value pairs.'''
        return '\t'.join(['%s:%s' % (key, str(value)) for key, value in record.items()])

    def __str__(self):
        '''Return the header as SAM text lines joined by newlines (no trailing newline).'''
        # Membership is tested on the mapping itself: a dict subclass has no
        # 'dict' attribute, so the former 'in self.dict' test always raised.
        lines = []
        if 'HD' in self:
            lines.append('@HD\t%s' % self._format_fields(self.get('HD')))
        for superkey in ('SQ', 'RG', 'PG'):
            for subdict in self.get(superkey, ()):
                lines.append('@%s\t%s' % (superkey, self._format_fields(subdict)))
        for item in self.get('CO', ()):
            lines.append('@CO\t%s' % item)
        return '\n'.join(lines)

    @staticmethod
    def from_lines(lines):
        '''Build a header from whitespace-separated lines.

        The first field of each line is taken as the reference name and the
        second as its length; further fields are ignored. Blank lines are
        skipped. A non-blank line with fewer than two fields raises IndexError.
        '''
        result = SAMHeader()
        for line in lines:
            fields = line.split()
            if not fields:
                continue
            result.add_refseq(fields[0], fields[1])
        return result

    @staticmethod
    def from_fai(fai_filename):
        '''Build a header from the file fai_filename, parsed with from_lines().'''
        # 'with' closes the file even if a line fails to parse
        with open(fai_filename) as handle:
            return SAMHeader.from_lines(handle)

    @staticmethod
    def from_samfile_header(header):
        '''Build a header by updating a default SAMHeader with the mapping header.'''
        result = SAMHeader()
        result.update(header)
        return result

class SAMRead(object):
    '''A read alignment that exposes SAM-style text fields.

    Wraps an aligned-read object (a pysam AlignedRead when none is supplied)
    together with the header that maps reference names to tids. Reference
    names, CIGAR and tags are exchanged as strings through the properties
    defined at the end of the class; positions are 0-based.
    '''

    # CIGAR operation letters indexed by their numeric code (SAM specification)
    _CIGAR_OPCODES = 'MIDNSHP=X'
    # One CIGAR item: a run length followed by an operation letter
    _CIGAR_ITEM_PAT = re.compile(r'(\d+)([MIDNSHP=X])')
    # Fields that take part in equality comparison
    _COMPARED_FIELDS = ('qname', 'flag', 'rname', 'pos', 'cigar', 'mrnm',
                        'mpos', 'isize', 'seq', 'qual')

    def __init__(self, header, aligned_read=None):
        '''Wrap aligned_read, or create a new unmapped read if none is given.

        header -- object used to translate reference names to tids and back
        '''
        object.__init__(self)
        self._bam_header = header
        if aligned_read:
            self._read = aligned_read
        else:
            self._read = AlignedRead()
            self._read.flag = 4 # unpaired, unreversed, and unmapped
            self._read.rname = -1
            self._read.mrnm  = -1
            self._read.pos = -1
            self._read.mpos = -1

    def write_to_samfile(self,  samfile):
        '''The header in samfile should match the _bam_header of this read'''
        samfile.write(self._read)

    def __str__(self):
        '''Return the read as a tab-separated SAM line with 1-based positions.'''
        str_list = [self.qname, str(self.flag), self.rname, str(self.pos + 1),
                    str(self.mapq), self.cigar, self.mrnm, str(self.mpos + 1),
                    str(self.isize), self.seq, self.qual]
        str_list.extend(self.tags)
        return '\t'.join(str_list)

    def __hash__(self):
        # sys.maxsize exists on both Python 2 and 3; sys.maxint is Python 2 only
        return (hash(self.qname) + self.flag + hash(self.rname) + hash(self.pos)
                + hash(self.cigar) + hash(self.mrnm) + self.mpos + self.isize
                + hash(self.seq) + hash(self.qual)) & sys.maxsize

    def __eq__(self, other):
        '''Compare the ten core SAM fields; any failure to read them means "not equal".'''
        try:
            return all(getattr(self, name) == getattr(other, name)
                       for name in self._COMPARED_FIELDS)
        except Exception:
            # best effort: an object lacking these fields is simply unequal
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly
        return not self.__eq__(other)

    def is_paired(self):
        return self._read.is_paired

    def set_paired_flag(self):
        self._read.is_paired = True

    def is_properly_paired(self):
        return self._read.is_proper_pair

    def set_proper_pair_flag(self):
        self._read.is_proper_pair = True

    def is_mapped(self):
        return self._read.is_unmapped == False

    def set_unmapped_flag(self):
        self._read.is_unmapped = True

    def set_mapped_flag(self):
        self._read.is_unmapped = False

    def set_reverse_strand_flag(self):
        self._read.is_reverse = True

    def is_reversed(self):
        return self._read.is_reverse

    def set_forward_strand_flag(self):
        self._read.is_reverse = False

    def is_mate_reversed(self):
        return self._read.mate_is_reverse

    def set_mate_reverse_strand_flag(self):
        self._read.mate_is_reverse = True

    def set_mate_forward_strand_flag(self):
        self._read.mate_is_reverse = False

    def is_first_of_pair(self):
        return self._read.is_read1

    def set_first_of_pair_flag(self):
        '''Mark the read as read 1 of a pair (also sets the paired flag).'''
        self._read.is_read1 = True
        self._read.is_read2 = False
        self._read.is_paired = True

    def is_second_of_pair(self):
        return self._read.is_read2

    def set_second_of_pair_flag(self):
        '''Mark the read as read 2 of a pair (also sets the paired flag).'''
        self._read.is_read1 = False
        self._read.is_read2 = True
        self._read.is_paired = True

    @staticmethod
    def fromString(line, header):
        '''Build a SAMRead from one tab-separated SAM text line.'''
        fields = line.strip().split('\t')
        return SAMRead.fromFields(fields, header)

    @staticmethod
    def fromFields(fields, header):
        '''Build a SAMRead from the 11 mandatory SAM fields plus optional tag strings.

        Positions in fields are 1-based and are stored 0-based.
        '''
        result = SAMRead(header)
        result.qname = fields[0]
        result.flag = int(fields[1])
        result.rname = fields[2]
        result.pos = int(fields[3]) - 1 # change from 1-based to 0-based
        result.mapq = int(fields[4])
        result.cigar = fields[5]
        result.mrnm = fields[6]
        result.mpos = int(fields[7]) - 1
        result.isize = int(fields[8])
        result.seq = fields[9]
        result.qual = fields[10]
        result.tags = fields[11:]
        return result

    def clone(self):
        '''Return a new SAMRead rebuilt from this read's SAM fields, sharing the header.'''
        fields = [self.qname, self.flag, self.rname, self.pos + 1, self.mapq,
                  self.cigar, self.mrnm, self.mpos + 1, self.isize, self.seq,
                  self.qual]
        return SAMRead.fromFields(fields + self.tags, self._bam_header)

    def get_end_position(self):
        return self._read.aend # position is 0-based

    def get_ref_len(self, refname):
        '''Return the 'LN' entry of reference refname in the header.'''
        # NOTE(review): an unknown name or '*' gives tid -1, which indexes the
        # LAST reference in the 'SQ' list rather than failing - confirm intent.
        tid = self.find_tid(refname)
        return self._bam_header['SQ'][tid]['LN']

    def set_ref_len(self, refname, length):
        '''Set the 'LN' entry of reference refname in the header to int(length).'''
        # NOTE(review): same tid -1 caveat as get_ref_len.
        tid = self.find_tid(refname)
        self._bam_header['SQ'][tid]['LN'] = int(length)

    def get_qname(self):
        return self._read.qname

    def set_qname(self, value):
        self._read.qname = value

    def get_flag(self):
        return self._read.flag

    def get_pos(self):
        return self._read.pos

    def get_mpos(self):
        return self._read.mpos

    def get_isize(self):
        return self._read.isize

    def get_seq(self):
        '''Return the sequence, or '*' when it is empty or unset.'''
        return self._read.seq or '*'

    def get_qual(self):
        '''Return the quality string, or '*' when it is empty or unset.'''
        return self._read.qual or '*'

    def get_mapq(self):
        return self._read.mapq

    def set_flag(self, value):
        self._read.flag = value

    def set_pos(self, value, extend_ref=False):
        '''Set the 0-based start position.

        With extend_ref true, a position at or beyond the recorded length of
        the read's reference grows that length to value + 1.
        '''
        # pos is 0-based in AlignedRead, so 0-based here for compatibility
        val = int(value)
        # valid 0-based positions are 0..LN-1, so val == LN is already outside
        if extend_ref and val >= self.get_ref_len(self.rname):
            self.set_ref_len(self.rname, val + 1)
        self._read.pos = val

    def set_mpos(self, value, extend_ref=False):
        '''Set the mate's 0-based start position; extend_ref works as in set_pos.'''
        # pos is 0-based in AlignedRead, so 0-based here for compatibility
        val = int(value)
        if extend_ref and val >= self.get_ref_len(self.mrnm):
            self.set_ref_len(self.mrnm, val + 1)
        self._read.mpos = val

    def set_isize(self, value):
        self._read.isize = value

    def set_seq(self, value):
        self._read.seq = value

    def set_qual(self, value):
        self._read.qual = value

    def set_mapq(self, value):
        self._read.mapq = value

    def get_tags(self):
        '''Return the optional fields as a list of 'NAME:TYPE:VALUE' strings.'''
        tagstrs = []
        read_tags = self._read.tags
        if read_tags:
            for tagname, tagvalue in read_tags:
                if isinstance(tagvalue, int):
                    tagstrs.append('%s:i:%d' % (tagname, tagvalue))
                elif isinstance(tagvalue, float):
                    tagstrs.append('%s:f:%f' % (tagname, tagvalue))
                elif tagvalue.isdigit() and int(tagvalue) <= 255 and tagname != 'MD': #A-type values get converted to string representation of character code
                    tagstrs.append('%s:A:%s' % (tagname, chr(int(tagvalue))))
                elif len(tagvalue) > 1:
                    tagstrs.append('%s:Z:%s' % (tagname, tagvalue))
                else:
                    tagstrs.append('%s:A:%s' % (tagname, tagvalue))
        return tagstrs

    def get_cigar(self):
        '''Return the CIGAR as a string such as '10M2I', or '*' when there is none.'''
        ops = self._read.cigar
        if not ops:
            return '*'
        return ''.join(['%d%s' % (count, self._CIGAR_OPCODES[code])
                        for code, count in ops])

    def get_mrnm(self):
        '''Name of the mate's reference sequence, or '*' when unset.'''
        if self._read.mrnm == -1:
            return '*'
        return self._bam_header.get_ref_name(self._read.mrnm)

    def set_tags(self, tag_list):
        '''Set the optional fields from a list of 'NAME:TYPE:VALUE' strings.

        Values of type 'i' become int and type 'f' become float; every other
        type is stored as the value string unchanged.
        '''
        tuples = []
        for tag in tag_list:
            tagname, tagtype, tagvalue = tag.split(':',  2)
            if tagtype == 'i':
                tagvalue = int(tagvalue)
            elif tagtype == 'f':
                tagvalue = float(tagvalue)
            tuples.append((tagname, tagvalue))
        self._read.tags = tuples

    def set_cigar(self, value):
        '''Set the CIGAR from its string form.

        An empty value or '*' clears the CIGAR. Raises ValueError if value is
        not a sequence of <count><op> items with op in 'MIDNSHP=X'.
        '''
        ops = []
        if value and value != '*':
            offset = 0
            # Consume the string item by item; the explicit failure branch
            # guarantees termination on malformed input.
            while offset < len(value):
                item = self._CIGAR_ITEM_PAT.match(value, offset)
                if item is None:
                    raise ValueError('SAMRead.set_cigar, Illegal CIGAR string: %s' % value)
                ops.append((self._CIGAR_OPCODES.index(item.group(2)), int(item.group(1))))
                offset = item.end()
        self._read.cigar = ops

    def set_mrnm(self, value, extend_refs=False):
        '''Set the mate's reference by name; '=' means the read's own reference.'''
        if value == '=':
            # SAM text writes '=' when the mate is on the same reference
            self._read.mrnm = self._read.rname
        else:
            self._read.mrnm = self.find_tid(value, extend_refs)

    def get_rname(self):
        """Name of the reference sequence """
        if self._read.rname == -1:
            return '*'
        return self._bam_header.get_ref_name(self._read.rname)

    def find_tid(self, value, extend_refs=False):
        '''Return the header's tid for reference name value; '*' maps to -1.

        extend_refs is accepted but not used by this implementation.
        '''
        if value == '*':
            return -1
        return self._bam_header.find_tid(value)

    def set_rname(self, value, extend_refs=False):
        self._read.rname = self.find_tid(value, extend_refs)

    def get_alen(self):
        '''Aligned length of the read'''
        return self._read.alen

    def get_header(self):
        '''The SAMHeader instance that defines the mapping from rname to tid for this read'''
        return self._bam_header

    def set_header(self,  header):
        self._bam_header = header

    rname = property(get_rname, set_rname)
    tags = property(get_tags, set_tags)
    cigar = property(get_cigar, set_cigar)
    mrnm = property(get_mrnm, set_mrnm)
    flag = property(get_flag, set_flag)
    pos = property(get_pos, set_pos)
    mpos = property(get_mpos, set_mpos)
    isize = property(get_isize, set_isize)
    seq = property(get_seq, set_seq)
    qual = property(get_qual, set_qual)
    mapq = property(get_mapq, set_mapq)
    qname = property(get_qname, set_qname)
    alen = property(get_alen)
    header = property(get_header,  set_header)

class SAMReadIterator(object):
    '''Base class for sources of reads; on its own it supplies no reads.'''

    def __init__(self):
        object.__init__(self)

    def iterator(self):
        '''Dummy implementation: return an iterator that yields nothing.'''
        # Raising StopIteration from a plain function would propagate to the
        # caller as an error instead of ending a for-loop.
        return iter(())

    def __iter__(self):
        '''Iterate over whatever iterator() provides.'''
        return self.iterator()

class SAMReadFileIterator(SAMReadIterator):
    '''Iterate over the reads of a SAM or BAM file as SAMRead objects.'''

    def __init__(self, filename, mode=None):
        '''Open filename with pysam.Samfile and load its header.

        mode -- passed to pysam.Samfile; when None it is chosen from the file
                extension ('r' for .sam, 'rb' for .bam).
        Raises ValueError when mode is None and the extension is neither.
        '''
        SAMReadIterator.__init__(self)
        if mode is None:
            if filename.endswith('.sam'):
                mode = 'r'
            elif filename.endswith('.bam'):
                mode = 'rb'
            else:
                raise ValueError('[Reads.samRead.SAMReadFileIterator] Format of SAM/BAM file %s is unknown' % filename)
        self.samfile = pysam.Samfile(filename, mode=mode)
        self.header = SAMHeader.from_samfile_header(self.samfile.header)

    def iterator(self):
        '''Yield a SAMRead, bound to this file's header, for each read in the file.'''
        for read in self.samfile:
            yield SAMRead(self.header, read)

    def __iter__(self):
        # single implementation shared with iterator()
        return self.iterator()

    def close(self):
        '''Close the underlying pysam file.'''
        self.samfile.close()



def generate_SAM_header_from_fai(fai_filename):
    '''Return @SQ header lines built from the file fai_filename.

    The first whitespace-separated field of each line is used as the
    reference name (SN) and the second as its length (LN). Blank lines are
    skipped; a non-blank line with fewer than two fields raises IndexError.
    The lines are joined with newlines, without a trailing newline.
    '''
    header_lines = []
    # 'with' closes the file even when a malformed line raises
    with open(fai_filename) as fai:
        for line in fai:
            fields = line.split()
            if not fields:
                continue
            header_lines.append('@SQ\tSN:%s\tLN:%s' % (fields[0], fields[1]))
    return '\n'.join(header_lines)
