
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

"""Produce complete genes or contigs from a GFF3 source"""
import sys, os

from gff3Record import GFF3Record, GFF3Gene, GFF3mRNA, GFF3Exon
from collections import defaultdict

class GFF3Iterator:
    """Assemble GFF3 records into gene objects and hand them out in input order.

    Records are read line by line from ``source`` (or pushed in directly with
    ``add_record``).  A child record that arrives before its parent is parked
    in an "orphan" table keyed by the parent ID and attached as soon as the
    parent record is added.
    """

    # Record types handled by the transcript branch of add_record.
    _TRANSCRIPT_TYPES = ('mRNA', 'transcript', 'contig', 'cDNA_match',
                         'protein_match', 'expressed_sequence_match')
    # Transcript-level types that have no separate gene line: the record is
    # registered both as a gene and as that gene's transcript.
    _SELF_PARENTED_TYPES = ('contig', 'cDNA_match', 'protein_match',
                            'expressed_sequence_match')
    # Record types stored as exons of their parent transcript(s).
    _EXON_TYPES = ('exon', 'CDS', 'match', 'match_part',
                   'three_prime_UTR', 'five_prime_UTR')

    def __init__(self,  source):
        """source is an iterator over input lines, e.g. a file handle."""
        self._source = source
        self._genes = {}
        self._transcripts = {}
        self._ids = []  # remember order of genes in input
        self._orphan_transcripts = defaultdict(list)
        self._orphan_exons = defaultdict(list)
        self._orphan_CDSstarts = defaultdict(list)
        self._orphan_CDSstops = defaultdict(list)

    def genes(self):
        """Read the source and yield the accumulated genes one by one.

        Genes are yielded each time a '###' line is seen and once more at the
        end of the input.  Unlike contigs() and contigs2genes(), no check is
        made that the orphan tables are empty when '###' is encountered.
        """
        for gene in self._scan(check_orphans=False):
            yield gene

    def gene_iterator(self):
        """Yield the genes of an iterator that was loaded via add_record.

        Asserts that no orphan transcripts or exons are pending, yields the
        genes in insertion order, then clears the gene and transcript tables.
        """
        assert len(self._orphan_transcripts) == 0
        assert len(self._orphan_exons) == 0
        for gene in self._emit_pending():
            yield gene

    def contigs(self):
        """Read the source and yield each gene after gene2contig() retyped it."""
        for contig in self._scan(convert=self.gene2contig):
            yield contig

    def gene2contig(self, gene):
        """Retype a gene in place as a contig and return it.

        The gene and its first transcript get type 'contig'; every exon of
        that first transcript gets type 'match'.  Other transcripts are left
        untouched.  Raises IndexError if the gene has no transcripts.
        """
        gene.set_type('contig')
        transcript = gene.get_transcripts()[0]
        transcript.set_type('contig')
        for exon in transcript.get_exons():
            exon.set_type('match')
        return gene

    def contigs2genes(self):
        """Read the source and yield each gene with all exons typed 'exon'."""
        for gene in self._scan(convert=self._retype_exons):
            yield gene

    def add_record(self, record):
        """Add a GFF3Record (one .gff3 line) to the data being accumulated.

        The record is dispatched on its type; records of any type not handled
        below are silently ignored.
        """
        record_type = record.getType()
        if record_type == 'gene':
            self._add_gene(record)
        elif record_type in self._TRANSCRIPT_TYPES:
            self._add_transcript(record,
                                 record_type in self._SELF_PARENTED_TYPES)
        elif record_type in self._EXON_TYPES:
            self._add_exon(record, record_type)
        elif record_type == 'start_codon':
            # The CDS start is the record's end on the '-' strand,
            # otherwise its start.
            if record.getStrand() == '-':
                position = record.get_end()
            else:
                position = record.get_start()
            self._set_cds_boundary(record, '_CDSstart',
                                   self._orphan_CDSstarts, position)
        elif record_type == 'stop_codon':
            # Mirror image of start_codon: start on '-', end otherwise.
            if record.getStrand() == '-':
                position = record.get_start()
            else:
                position = record.get_end()
            self._set_cds_boundary(record, '_CDSstop',
                                   self._orphan_CDSstops, position)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    def _scan(self, convert=None, check_orphans=True):
        """Parse the source, yielding genes at every '###' and at end of input.

        convert, when given, is called on each gene and its result is yielded
        instead of the gene.  check_orphans enables the assertions that no
        orphan transcripts/exons remain when a '###' line is reached.
        Lines with fewer than nine tab-separated columns are reported on
        stderr and skipped.
        """
        for line in self._source:
            if len(line.strip()) < 3:
                continue
            if line.startswith('###'):
                # Signal that all forward references have been resolved
                if check_orphans:
                    assert len(self._orphan_transcripts) == 0
                    assert len(self._orphan_exons) == 0
                for gene in self._emit_pending(convert):
                    yield gene
            elif line.startswith('#'):
                continue
            elif len(line.split('\t')) > 8:
                self.add_record(GFF3Record(*(line.strip().split('\t'))))
            else:
                # write() instead of the Python-2-only "print >>" statement;
                # the emitted text is the same.
                sys.stderr.write('Strange input line: %s\n' % line)
        # End of input: hand out what is left.  Only the ID list is reset
        # here; the gene and transcript tables are left as they are.
        for gene in self._emit_pending(convert, clear_tables=False):
            yield gene

    def _emit_pending(self, convert=None, clear_tables=True):
        """Yield stored genes in input order, then reset the bookkeeping."""
        for gene_id in self._ids:
            gene = self._genes[gene_id]
            if convert is not None:
                gene = convert(gene)
            yield gene
        self._ids = []
        if clear_tables:
            self._genes.clear()
            self._transcripts.clear()

    def _retype_exons(self, gene):
        """Set the type of every exon of every transcript to 'exon'."""
        for transcript in gene.get_transcripts():
            for exon in transcript.get_exons():
                exon.set_type('exon')
        return gene

    def _add_gene(self, record):
        """Register a gene record and adopt transcripts that were waiting."""
        gene_id = record.getID()
        if gene_id not in self._ids:
            self._genes[gene_id] = GFF3Gene.fromRecord(record)
            self._ids.append(gene_id)
        if gene_id in self._orphan_transcripts:
            for transcript in self._orphan_transcripts[gene_id]:
                self._genes[gene_id].add_transcript(transcript)
            del self._orphan_transcripts[gene_id]

    def _add_transcript(self, record, self_parented):
        """Register a transcript-level record and adopt waiting children."""
        transcript_id = record.getID()
        transcript = GFF3mRNA.fromRecord(record)
        self._transcripts[transcript_id] = transcript
        if self_parented:
            # No gene line exists for these types: build one from the same
            # record and use the record's own ID as the parent.
            # NOTE(review): a repeated ID replaces the stored gene and is
            # appended to _ids again, so it would be yielded twice - confirm
            # whether input can contain repeated IDs.
            parent = transcript_id
            self._genes[transcript_id] = GFF3Gene.fromRecord(record)
            self._ids.append(transcript_id)
        else:
            # NOTE(review): the parent string is used whole here, whereas
            # exon/codon records split it on ',' - confirm transcripts never
            # list several parents.
            parent = record.get_parents()
        if parent in self._genes:
            self._genes[parent].add_transcript(transcript)
        else:
            self._orphan_transcripts[parent].append(transcript)
        if transcript_id in self._orphan_exons:
            for exon in self._orphan_exons[transcript_id]:
                transcript.add_exon(exon)
            del self._orphan_exons[transcript_id]
        if transcript_id in self._orphan_CDSstarts:
            # If several starts were parked, the last one wins.
            for cds_start in self._orphan_CDSstarts[transcript_id]:
                transcript._CDSstart = cds_start
            del self._orphan_CDSstarts[transcript_id]
        if transcript_id in self._orphan_CDSstops:
            # If several stops were parked, the last one wins.
            for cds_stop in self._orphan_CDSstops[transcript_id]:
                transcript._CDSstop = cds_stop
            del self._orphan_CDSstops[transcript_id]

    def _add_exon(self, record, record_type):
        """Attach an exon-like record to each parent transcript, or park it."""
        new_exon = GFF3Exon.fromRecord(record)
        new_exon.set_type(record_type)
        # The same exon object is shared by all of its parents.
        for parent in record.get_parents().rstrip(',').split(','):
            if parent in self._transcripts:
                self._transcripts[parent].add_exon(new_exon)
            else:
                self._orphan_exons[parent].append(new_exon)

    def _set_cds_boundary(self, record, attribute, orphan_table, position):
        """Store a CDS boundary on each parent transcript, or park it.

        attribute is the transcript attribute to set ('_CDSstart' or
        '_CDSstop'); orphan_table is the matching orphan dictionary.
        """
        for parent in record.get_parents().rstrip(',').split(','):
            if parent in self._transcripts:
                setattr(self._transcripts[parent], attribute, position)
            else:
                orphan_table[parent].append(position)


if __name__ == '__main__':
    # Script entry point: print every gene found in the GFF3 file named by
    # the first command-line argument, each followed by a blank line.
    # The with-block closes the input file once all genes have been printed.
    with open(sys.argv[1]) as gff3_handle:
        for gene in GFF3Iterator(gff3_handle).genes():
            # Same output as "print gene" followed by an empty "print",
            # written so that it runs under both Python 2 and Python 3.
            sys.stdout.write('%s\n\n' % (gene,))

