# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.


__author__ = 'ian'

import sys, os, re
from ctabix import Tabixfile
import argparse
this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
from lib.gff3Iterator import GFF3Iterator
from calcTranscriptFragmentationProbabilities import get_probs_for_transcript

# Program description shown by argparse in the --help output.
DESCRIPTION = 'For each transcript, extract origin counts from all ocl files in a list and output stranded ' \
              'probability vectors in a format suitable for R'
# Version string reported by the --version option.
VERSION = '0.1'

# Matches a sample identifier of the form 'R' followed by four digits (e.g. 'R0123') inside a count file name;
# the match is used as the Sample column value in the output tables.
# Raw string so that '\d' reaches the regex engine verbatim instead of being an invalid string escape.
id_pat = re.compile(r'(R\d{4})')


def get_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace with the attributes
          verbose -- number of times -v was given (0 when omitted)
          input   -- open readable file object for the transcript gff3 input (sys.stdin by default)
          out_dir -- path of the directory that receives the output files
          counts  -- list of tabix-indexed origin count file names

    Exits with a usage message (standard argparse behaviour) when --out_dir or --counts is missing.
    """
    argparser = argparse.ArgumentParser(description=DESCRIPTION)
    # standard options
    # Separate program name and version with a space so the output reads 'prog 0.1', not 'prog0.1'.
    argparser.add_argument('--version', action='version', version='%(prog)s ' + VERSION)
    argparser.add_argument('--verbose', '-v', action='count', default=0,
                           help='Omit to see only fatal error messages; -v to see warnings; -vv to see warnings and '
                                'progress messages')
    # options to customize
    argparser.add_argument('--in', '-i', dest='input', type=argparse.FileType('r'), nargs='?', default=sys.stdin,
                           help='Path to the transcript input gff3 file; if omitted or -, input is read from stdin')
    # Both options below are documented as required; enforce that here so a missing option produces a usage
    # error instead of a TypeError later when the value None is used as a path or iterated over.
    argparser.add_argument('--out_dir', '-o', required=True,
                           help='Path to the output file directory; required')
    argparser.add_argument('--counts', '-c', nargs='+', required=True,
                           help='Space-delimited list of tabix-indexed origin count file names; required')
    return argparser.parse_args()


if __name__ == '__main__':
    # Script entry point: for every transcript in the input gff3, write '<transcript name>.probs.txt' into
    # the output directory. Each file is a tab-separated table with one row per position and count file,
    # holding the plus- and minus-strand values returned by get_probs_for_transcript.
    args = get_args()

    # The sample label and the tabix handle depend only on the count file, so resolve them once up front
    # rather than once per transcript.
    samples = []
    for count_file in args.counts:
        match = id_pat.search(count_file)
        # Fall back to the full file name when it contains no recognisable sample id.
        file_id = match.group(1) if match else count_file
        samples.append((file_id, Tabixfile(count_file)))

    for gene in GFF3Iterator(args.input).genes():
        for transcript in gene.get_transcripts():
            out_path = os.path.join(args.out_dir, transcript.get_name() + '.probs.txt')
            # The with-block closes every transcript's file, including when an error occurs, and avoids
            # touching an unbound file variable for genes that have no transcripts.
            with open(out_path, 'w') as outfile:
                outfile.write('Forward\tReverse\tPosition\tSample\n')
                for file_id, count_tabix in samples:
                    vector = get_probs_for_transcript(transcript, count_tabix)
                    for i in range(len(vector)):
                        outfile.write('%.8f\t%.8f\t%d\t%s\n' % (vector.plus[i], vector.minus[i], i, file_id))

    sys.stderr.write('%s done.\n' % sys.argv[0])
