
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

''' Represent a read alignment in SAM textual format

Created on Sep 25, 2010
@author: ian
'''
import sys,  re
from Bio import SeqIO
from collections import defaultdict

class SAMText(object):
    '''A single read alignment record in SAM text format.

    The attributes mirror the 11 mandatory SAM columns (qname, flag, rname,
    pos, mapq, cigar, mrnm, mpos, isize, seq, qual); ``tags`` holds any
    remaining columns as raw strings.  ``alen`` is computed from the CIGAR
    string by get_aligned_length() when a record is parsed with
    fromString()/fromFields(); it is 0 for records built directly.

    Equality and hashing use the mandatory columns except ``mapq``;
    ``tags`` and ``alen`` are ignored.
    '''

    def __init__(self, Qname):
        self.qname = str(Qname)
        self.flag = 0
        self.rname = '*'
        self.pos = 0
        self.mapq = 255
        self.cigar = '*'
        self.mrnm = '*'
        self.mpos = 0
        self.isize = 0
        self.seq = '*'
        self.qual = '*'
        self.tags = []  # optional columns 12+, kept as unparsed strings
        self.alen = 0

    def __str__(self):
        '''Return the record as one tab-separated SAM line (no newline).'''
        str_list = [self.qname, str(self.flag), self.rname, str(self.pos),
                    str(self.mapq), self.cigar, self.mrnm, str(self.mpos),
                    str(self.isize), self.seq, self.qual]
        str_list.extend(self.tags)
        return '\t'.join(str_list)

    def __hash__(self):
        # Mask to a non-negative value.  sys.maxsize exists on Python 2.6+
        # and Python 3; sys.maxint (used previously) is Python 2 only.
        return (hash(self.qname) + self.flag + hash(self.rname)
                + hash(self.pos) + hash(self.cigar) + hash(self.mrnm)
                + self.mpos + self.isize + hash(self.seq)
                + hash(self.qual)) & sys.maxsize

    def __eq__(self, other):
        try:
            return ((self.qname == other.qname) and (self.flag == other.flag)
                    and (self.rname == other.rname) and (self.pos == other.pos)
                    and (self.cigar == other.cigar)
                    and (self.mrnm == other.mrnm)
                    and (self.mpos == other.mpos)
                    and (self.isize == other.isize)
                    and (self.seq == other.seq) and (self.qual == other.qual))
        except AttributeError:
            # other lacks the SAM fields, so it cannot be an equal record.
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; define it explicitly so
        # equal records do not also compare as unequal.
        return not self.__eq__(other)

    # --- FLAG setters (bit values: 0x1, 0x2, 0x10, 0x20, 0x40, 0x80) ---

    def set_paired_flag(self):
        self.flag |= 0x1

    def set_proper_pair_flag(self):
        self.flag |= 0x2

    def set_reverse_strand_flag(self):
        self.flag |= 0x10

    def set_forward_strand_flag(self):
        self.flag &= ~0x10

    def set_mate_reverse_strand_flag(self):
        self.flag |= 0x20

    def set_mate_forward_strand_flag(self):
        self.flag &= ~0x20

    def set_first_of_pair_flag(self):
        self.flag |= 0x40

    def set_second_of_pair_flag(self):
        self.flag |= 0x80

    # --- FLAG predicates: each returns the masked int (truthy/falsy) ---

    def is_reversed(self):
        return self.flag & 0x10

    def is_paired(self):
        return self.flag & 0x1

    def is_properly_paired(self):
        return self.flag & 0x2

    def is_mate_reversed(self):
        return self.flag & 0x20

    def is_first_of_pair(self):
        return self.flag & 0x40

    def is_second_of_pair(self):
        return self.flag & 0x80

    def is_mapped(self):
        # 4 (truthy) when bit 0x4 is clear, 0 when it is set.
        return (self.flag & 0x4) ^ 0x4

    def __len__(self):
        # NOTE(review): a missing sequence ('*') reports length 1.
        return len(self.seq)

    @staticmethod
    def fromString(line):
        '''Parse one tab-separated SAM record line into a SAMText.'''
        fields = line.strip().split('\t')
        return SAMText.fromFields(fields)

    @staticmethod
    def fromFields(fields):
        '''Build a SAMText from a list of column strings.

        Requires at least 11 fields (IndexError otherwise).  Numeric columns
        are converted with int(); columns beyond the 11th become ``tags``.
        '''
        result = SAMText(fields[0])
        result.flag = int(fields[1])
        result.rname = fields[2]
        result.pos = int(fields[3])
        result.mapq = int(fields[4])
        result.cigar = fields[5]
        result.mrnm = fields[6]
        result.mpos = int(fields[7])
        result.isize = int(fields[8])
        result.seq = fields[9]
        result.qual = fields[10]
        result.tags = fields[11:]
        result.alen = get_aligned_length(result.cigar)
        return result

# Splits a CIGAR string into alternating <length>, <operation> pieces.
_CIGAR_OP_RE = re.compile(r'([DIMNSHP=X])')

def _cigar_op_lengths(cigar):
    '''Return a defaultdict(int) of total length per CIGAR operation letter.

    A CIGAR of '*' or '' yields an empty mapping.  Characters after the
    last operation letter are ignored.
    '''
    # re.split with one capturing group yields [len, op, len, op, ..., tail];
    # dropping the tail leaves an even-length list of (len, op) pairs.
    pieces = _CIGAR_OP_RE.split(cigar.strip())[:-1]
    lengths = defaultdict(int)
    for count, op in zip(pieces[0::2], pieces[1::2]):
        lengths[op] += int(count)
    return lengths

def get_cigar_length(cigar):
    '''Return the sum of the M, I, S, = and X operation lengths in *cigar*.

    H, D, N and P operations are not counted.
    '''
    lengths = _cigar_op_lengths(cigar)
    return lengths['M'] + lengths['I'] + lengths['S'] + lengths['='] + lengths['X']

def get_aligned_length(cigar):
    '''Return the sum of the M, I, S, =, X and N operation lengths in *cigar*.

    This is get_cigar_length() plus the lengths of N operations.
    NOTE(review): insertions (I) and soft clips (S) are counted while
    deletions (D) are not, so this is not the reference span as the SAM
    spec defines it; confirm the intended semantics with callers.
    '''
    lengths = _cigar_op_lengths(cigar)
    return (lengths['M'] + lengths['I'] + lengths['S'] + lengths['=']
            + lengths['X'] + lengths['N'])

class SAMTextFileIterator(object):
    '''Iterate over the alignment records of a SAM text file.

    Lines starting with '@' before the first record are collected at
    construction time and returned by header().  Iterating yields one
    SAMText per non-blank record line.  Pass '-' as the filename to read
    from standard input.

    The underlying file is closed automatically when iteration is exhausted,
    on close(), or on leaving a ``with`` block.  Standard input is never
    closed.
    '''

    def __init__(self, filename):
        if filename == '-':
            self.input = sys.stdin
            self._owns_input = False
        else:
            self.input = open(filename, 'r')
            self._owns_input = True
        self._header = []
        # Read ahead: next_line holds the first line not yet consumed.
        self.next_line = self.input.readline()
        while self.next_line and self.next_line.startswith('@'):
            self._header.append(self.next_line.strip())
            self.next_line = self.input.readline()

    def header(self):
        '''Return the header lines joined by newlines (no trailing newline).'''
        return '\n'.join(self._header)

    def close(self):
        '''Close the input file if this object opened it (idempotent).'''
        if self._owns_input:
            self.input.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __iter__(self):
        # readline() returns '' only at EOF, which ends the loop.  Returning
        # from the generator (rather than raising StopIteration, which PEP 479
        # turns into RuntimeError on Python 3.7+) ends iteration.
        while self.next_line:
            if self.next_line.strip():  # skip blank lines, e.g. trailing ones
                yield SAMText.fromString(self.next_line)
            self.next_line = self.input.readline()
        self.close()



def generate_SAM_header_from_fai(fai_filename):
    '''Build SAM @SQ header lines from a FASTA index (.fai) file.

    Each non-blank line contributes '@SQ\\tSN:<col1>\\tLN:<col2>', in file
    order, using the first two whitespace-separated columns.  The lines are
    joined with newlines (no trailing newline).
    '''
    header_lines = []
    # 'with' guarantees the file is closed even if a line is malformed.
    with open(fai_filename) as fai:
        for line in fai:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing empty line
            ref_name = fields[0]
            ref_length = fields[1]
            header_lines.append('@SQ\tSN:%s\tLN:%s' % (ref_name, ref_length))
    return '\n'.join(header_lines)

def generate_SAM_header_from_fasta(fasta_filename):
    '''Build SAM @SQ header lines directly from a FASTA file.

    Records are indexed with Bio.SeqIO.index; one
    '@SQ\\tSN:<name>\\tLN:<length>' line is produced per record, ordered by
    sorted record name (not file order), joined with newlines.
    NOTE(review): the SeqIO index is not explicitly closed here.
    '''
    records = SeqIO.index(fasta_filename, 'fasta')
    return '\n'.join(
        '@SQ\tSN:%s\tLN:%s' % (name, len(records[name]))
        for name in sorted(records.keys())
    )


