# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

'''Accumulate counts of reads mapping within the exons of each annotated transcript; produce simplified output for
use in simulation

Created on Sep 22, 2010
@author: ian
'''

import sys
import os
from collections import defaultdict
import pysam
from pysam import Samfile


def usage(msg=None):
    '''Print an optional error message and the command-line synopsis to
    stderr, then terminate the program with exit status 1.'''
    synopsis = 'Usage: python %s genes.gff3 output_counts.txt mapped_reads.bam' % sys.argv[0]
    if msg:
        print >> sys.stderr, msg
    print >> sys.stderr, synopsis
    sys.exit(1)


def get_interval_count(scaffold, start, end, bam):
    '''Return the number of reads in *bam* overlapping the interval
    [start, end) on *scaffold* (delegates to the BAM handle's count()).'''
    return bam.count(reference=scaffold, start=start, end=end)


def parse_gff_line(line):
    '''Parse one GFF3/GTF line into [scaffold, start, end, id].

    start and end are returned as ints (GFF coordinates, 1-based inclusive).
    The id is taken from the first of the attribute keys Name, Parent, ID,
    transcriptId, proteinID that is present (or None if none is).  Comment
    lines, blank lines and otherwise malformed records yield
    [None, None, None, None].
    '''
    try:
        fields = line.strip().split('\t')
        result = [fields[0], int(fields[3]), int(fields[4])]
        attr_dict = {}
        for attribute in fields[8].split(';'):
            if '=' in attribute:
                # split only on the FIRST '=' so attribute values that
                # themselves contain '=' (URLs, encoded notes) don't
                # invalidate the whole record
                key, value = attribute.split('=', 1)
            else:
                # GTF-style 'key value' attributes; split once so values
                # containing spaces survive
                key, value = attribute.split(None, 1)
            attr_dict[key] = value
        feature_id = None
        for key in ['Name', 'Parent', 'ID', 'transcriptId', 'proteinID']:
            if key in attr_dict:
                feature_id = attr_dict[key]
                break
        result.append(feature_id)
    except (ValueError, IndexError):
        # non-integer coordinates, too few fields, or an attribute with no
        # value: treat the whole line as unusable
        result = [None, None, None, None]
    return result


if __name__ == '__main__':
    try:
        gff = open(sys.argv[1])
        output = open(sys.argv[2], 'w')
    except Exception, e:
        usage(e)

    parse_line = parse_gff_line
    gff_lines = gff.readlines()

    bamfiles = sys.argv[3:]

    sample_counts = {}
    bamfile = bamfiles[0]
    try:
        print bamfile
        if not os.path.isfile(bamfile + '.bai'):
            print 'Indexing', bamfile
            pysam.index(bamfile)
        bam = Samfile(bamfile, mode='rb')
        counts = defaultdict(int)
        line_count = 0
        transcript_lengths = defaultdict(int)

        for line in gff_lines:
            try:
                line_count += 1
                scaffold, start, end, id = parse_line(line)
                if scaffold and id:
                    transcript_lengths[id] += end - start + 1
                    count = get_interval_count(scaffold, start, end, bam)
                    counts[id] += count
            except Exception, le:
                print >> sys.stderr, 'Problem processing', line
                print >> sys.stderr, le
            if line_count % 1000 == 0:
                print line_count, len(counts)
        sample_counts[bamfile] = counts
    except Exception, e:
        print >> sys.stderr, 'Problem processing', bamfile
        print >> sys.stderr, e
        print

    for id in sorted(sample_counts[bamfile].keys()):
        data = [id, str(sample_counts[bamfile][id])]
        print >> output, '\t'.join(data)

    output.close()
    print sys.argv[0], 'done.'
