
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

"""Model GFF3 records generically, and genes, mRNAs, and exons specifically."""
import  re
from types import StringTypes, DictionaryType
from Bio import SeqRecord, Seq, Alphabet
from Bio.SeqFeature import SeqFeature, FeatureLocation

# Default number of flanking genomic bases that
# GFF3mRNA.get_padded_coding_sequence adds on each side of the coding sequence.
PAD_LENGTH = 99

def intervals_overlap(start0, end0, start1, end1):
    """Return True when [start0, end0] and [start1, end1] share at least one point.

    Both endpoints are inclusive, so intervals that touch at a single
    coordinate count as overlapping.
    """
    second_is_left = end1 < start0
    return not (second_is_left or start1 > end0)


class GFF3Record(object):
    """Generic GFF3 record"""
    def __init__(self, seqID, source, so_type, start, end, score, strand, phase, attributes):
        self._seqID = seqID
        self._source = source
        self._type = so_type
        self._start = int(start)
        self._end = int(end)
        self._score = score
        self._strand = strand
        self._phase = phase
        self.set_attributes(attributes)

    def get_seqID(self):
        return self._seqID

    def set_seqID(self, scaffold):
        self._seqID = str(scaffold)

    def get_score(self):
        return self._score

    def getScore(self):
        return self.get_score()

    def get_strand(self):
        return self._strand

    def getStrand(self):
        return self.get_strand()

    def set_score(self, value):
        self._score = value

    def setScore(self, value):
        self.set_score(value)

    def set_strand(self, value):
        self._strand = value

    def setStrand(self, value):
        self.set_strand(value)

    def get_attr_string(self):
        attr_list = ['%s=%s' % pair for pair in sorted(self._attributes.items())]
        return ';'.join(attr_list)

    def get_field_list(self):
        return [self._seqID, self._source, self._type, str(self._start), str(self._end), str(self._score), self._strand, str(self._phase), self.get_attr_string()]

    def __str__(self):
        return '\t'.join(self.get_field_list())

    def __repr__(self):
        return 'GFF3Record(%s)' % self.__str__()

    def __len__(self):
        return int(1 + abs(self._end - self._start))

    def chrom_num(self):
        pat = re.compile(r"\D*(\d+)")
        return int(pat.match(self._seqID).group(1))

    def get_start(self):
        return self._start

    def get_end(self):
        return self._end

    def set_start(self, start):
        self._start = int(start)

    def set_end(self, end):
        self._end = int(end)

    def get_attributes(self):
        return self._attributes

    def set_attributes(self, attributes):
        self._attributes = {}
        if isinstance(attributes, StringTypes) :
            for field in attributes.strip(';').split(';'):
                try:
                    (key,val) = field.split('=')
                    self._attributes[key] = val
                except:
                    continue
        elif isinstance(attributes,DictionaryType):
            self._attributes = attributes
        else:
            raise TypeError('Inappropriate type %s for GFF3Record attributes' % type(attributes))

    def add_attribute(self,  key,  value):
        self._attributes[key] = value

    def set_attribute(self,  key,  value):
        self.add_attribute(key,  value)

    def has_attribute(self,  key):
        return key in self._attributes

    def get_attribute(self,  key):
        return self._attributes[key]

    def remove_attribute(self,  key):
        del self._attributes[key]

    def getType(self):
        return self._type

    def get_type(self):
        return self._type

    def set_type(self, so_type):
        self._type = str(so_type)

    def get_source(self):
        return self._source

    def set_source(self, source):
        self._source = str(source)

    def get_phase(self):
        return self._phase

    def get_name(self):
        return self._attributes.get('Name')

    def set_name(self, name):
        self._attributes['Name'] = str(name)

    def get_ID(self):
        return self._attributes.get('ID')

    def getID(self):
        return self.get_ID()

    def set_ID(self, id):
        self._attributes['ID'] = str(id)

    def setID(self, id):
        self.set_ID(id)

    def get_parents(self):
        return self._attributes.get('Parent')

    def set_parents(self, parents):
        self._attributes['Parent'] = str(parents)

    def clear_parents(self):
        if 'Parent' in self._attributes:
            del self._attributes['Parent']

    def overlaps(self, other):
        if other._seqID != self._seqID :
            return False
        if other.get_end() < self.get_start() :
            return False
        if other.get_start() > self.get_end() :
            return False
        return True

    def includes(self, other):
        if other._seqID != self._seqID :
            return False
        if other.get_end() > self.get_end() :
            return False
        if other.get_start() < self.get_start() :
            return False
        return True


    def __cmp__(self, other):
        if not other:
            return 1
        try:
            if self.get_seqID() != other.get_seqID():
                return -1 if self.get_seqID() < other.get_seqID() else 1
            elif self.get_start() != other.get_start():
                return -1 if self.get_start() < other.get_start() else 1
            elif self.get_end() != other.get_end():
                return -1 if self.get_end() < other.get_end() else 1
            else:
                return 0
        except AttributeError, e:
            print e
            return 1

    def __eq__(self,  other):
        result = isinstance(other, GFF3Record)
        result = result and self.get_seqID() == other.get_seqID()
        result = result and self.getStrand() == other.getStrand()
        result = result and self.get_start() == other.get_start()
        result = result and self.get_end() == other.get_end()
        return result

    def clone(self):
        return GFF3Record(self._seqID, self._source, self._type, str(self._start), str(self._end), str(self._score), self._strand, self._phase, self.get_attr_string())

    def asSeqFeature(self):
        location = FeatureLocation(self._start, self._end)
        strand = {"+":1, "-": -1, ".":None}[self._strand]
        return SeqFeature(location=location, type=self._type, location_operator='join', strand=strand, id=self.getID(), qualifiers=self._attributes, sub_features=[], ref=None, ref_db=None)
# class GFF3Record ends here

class GFF3Gene(GFF3Record):
    """A GFF3 'gene' record together with its transcripts.

    A gene normally does not have a parent. Transcripts are kept in a dict
    keyed by transcript ID, so their iteration order is that of the dict.
    """
    def __init__(self, seqID, source, start, end, score, strand, phase, attributes):
        super(GFF3Gene, self).__init__( seqID, source, 'gene', start, end, score, strand, phase, attributes)
        self._transcripts = {}  # transcript ID -> transcript record

    def get_parent(self):
        """Genes have no parent; always returns None."""
        return None

    def add_transcript(self, transcript):
        """Attach transcript to this gene.

        Sets the transcript's Parent to this gene's ID, indexes it by its own
        ID (replacing any transcript with the same ID), and widens the gene's
        extent to cover it.
        """
        transcript.set_parents(self.getID())
        self._transcripts[transcript.getID()] = transcript
        self._start = min(self._start, transcript.get_start())
        self._end = max(self._end, transcript.get_end())

    def get_transcripts(self):
        return self._transcripts.values()

    def set_transcripts(self,  transcript_list):
        """Replace all transcripts with those in transcript_list."""
        self._transcripts = {}
        for transcript in transcript_list:
            self.add_transcript(transcript)

    def set_strand(self, value):
        """Set the strand of the gene and of every transcript."""
        self._strand = value
        for transcript in self.get_transcripts():
            transcript.set_strand(value)

    def __str__(self):
        """Gene line, then each transcript's text, then the '###' separator."""
        result = [GFF3Record.__str__(self)]
        for transcript in self.get_transcripts():
            result.append(str(transcript))
        result.append('###')
        return '\n'.join(result)

    def toString_with_start_and_stop_codons(self):
        """Like __str__, but each transcript always emits start and stop codon lines."""
        result = [GFF3Record.__str__(self)]
        for transcript in self.get_transcripts():
            result.append(transcript.toString_with_start_and_stop_codons())
        result.append('###')
        return '\n'.join(result)

    def __eq__(self,  other):
        """Equal when other is a GFF3Gene with equal seqID, strand and extent, and
        both transcript collections pair off one-to-one with equal transcripts."""
        if not (isinstance(other, GFF3Gene) and GFF3Record.__eq__(self, other)):
            return False
        unmatched = list(other.get_transcripts())
        for transcript in self.get_transcripts():
            for index, o_transcript in enumerate(unmatched):
                if transcript == o_transcript:
                    del unmatched[index]
                    break
            else:
                return False
        # Leftover transcripts on the other side mean the genes differ; without
        # this check a == b could be True while b == a is False.
        return not unmatched

    def get_CDS_length(self):
        """CDS length of the first transcript from get_transcripts (dict order),
        or the gene length when there are no transcripts."""
        if self._transcripts:
            return list(self.get_transcripts())[0].get_CDS_length()
        else:
            return self.__len__()

    def CDS_matches(self,  other):
        """Compare the coding region with other (a GFF3Gene or GFF3mRNA).

        Only one transcript per gene is compared (the first from
        get_transcripts). Returns False for other types or when either side
        has no transcript.
        """
        if isinstance(other,  GFF3Gene):
            o_transcripts = list(other.get_transcripts())
            if not o_transcripts:
                return False
            o_transcript = o_transcripts[0]
        elif isinstance(other,  GFF3mRNA):
            o_transcript = other
        else:
            return False
        transcripts = list(self.get_transcripts())
        if not transcripts:
            return False
        return transcripts[0].CDS_matches(o_transcript)

    def get_forward_sequence(self, genome_index):
        """Return the genomic slice under the gene on the forward strand.

        genome_index[seqID] is sliced with 0-based, end-exclusive bounds,
        hence the -1 on the lower 1-based coordinate.
        """
        seq_start = min(self.get_start(), self.get_end()) -1
        seq_end = max(self.get_start(), self.get_end())
        unspliced = genome_index[self.get_seqID()][seq_start : seq_end]
        return unspliced

    def get_gene_sequence(self, genome_index):
        """Return the gene's genomic sequence in gene orientation (reverse
        complement on '-'), with id set to the gene Name."""
        seq_rec = self.get_forward_sequence(genome_index)
        if self._strand == '-':
            seq_rec = SeqRecord.SeqRecord(seq_rec.seq.reverse_complement())
        seq_rec.id = self.get_name()
        seq_rec.description = 'Genomic sequence'
        return seq_rec


    @staticmethod
    def fromRecord(gff3record):
        """Build a GFF3Gene from a generic record's fields (the type column is dropped)."""
        fields = gff3record.get_field_list()
        del fields[2]
        return GFF3Gene(*fields)

# end of class GFF3Gene

class GFF3mRNA(GFF3Record):
    """A GFF3 'mRNA' record together with its exons and CDS bounds.

    The Parent of an mRNA is a gene. _CDSstart and _CDSstop hold the genomic
    coordinates of the first and last coding base in transcription order, so
    on the '-' strand _CDSstart is the larger coordinate. Both default to the
    ends of the mRNA itself.
    """
    def __init__(self, seqID, source, start, end, score, strand, phase, attributes):
        super(GFF3mRNA, self).__init__( seqID, source, 'mRNA', start, end, score, strand, phase, attributes)
        self._exons = []  # exon (or CDS part) records, in insertion order
        self._CDSstart = int(end) if strand == '-' else int(start)
        self._CDSstop = int(start) if strand == '-' else int(end)

    def clone(self):
        """Return a new GFF3mRNA with the same line fields.

        Exons are not copied and the CDS bounds revert to the mRNA extent.
        """
        return GFF3mRNA(self._seqID, self._source, str(self._start), str(self._end), str(self._score), self._strand, self._phase, self.get_attr_string())

    def get_CDS_start(self):
        return self._CDSstart

    def get_CDS_stop(self):
        return self._CDSstop


    def set_CDS_start(self, _start):
        """Set the CDS start (int or numeric string), widening the mRNA extent to include it."""
        self._CDSstart = int(_start)
        # Use the converted value: a str passed to min/max would end up stored
        # in _start/_end (or pick the wrong bound).
        self._start = min(self._start, self._CDSstart)
        self._end = max(self._end, self._CDSstart)

    def set_CDS_stop(self, _stop):
        """Set the CDS stop (int or numeric string), widening the mRNA extent to include it."""
        self._CDSstop = int(_stop)
        self._start = min(self._start, self._CDSstop)
        self._end = max(self._end, self._CDSstop)

    def get_parent(self):
        """Return the Parent attribute (the gene ID), or None."""
        return self._attributes.get("Parent")

    def clear_exons(self):
        self._exons = []

    def add_exon(self, new_exon):
        """Add new_exon to this transcript, resolving overlaps with existing exons.

        Overlap is tested on [start, end - 1], so exons that share only a
        boundary coordinate are not treated as overlapping. Against each
        overlapping existing exon:
          * identical extent: new_exon is a duplicate and is ignored;
          * new_exon contains it: the existing exon is removed (superseded);
          * it contains new_exon: new_exon is ignored;
          * otherwise (partial overlap): ValueError, with no change made.
        An accepted exon gets this mRNA's ID added as a parent (when the exon
        supports add_parent) and its ID set to '<mRNA ID>.cds'. The mRNA extent
        is widened to cover it.
        """
        kept = []
        for old_exon in self._exons:
            if not intervals_overlap(old_exon.get_start(), old_exon.get_end() - 1, new_exon.get_start(), new_exon.get_end() - 1):
                kept.append(old_exon)
                continue
            if new_exon.get_start() == old_exon.get_start() and new_exon.get_end() == old_exon.get_end():
                return
            if new_exon.get_start() <= old_exon.get_start() and new_exon.get_end() >= old_exon.get_end():
                # Superseded by the larger new exon: leave it out of `kept`.
                continue
            if old_exon.get_start() > new_exon.get_start() or old_exon.get_end() < new_exon.get_end():
                raise ValueError('Overlapping but non-identical exons in transcript: %s, %s' % (str(old_exon), str(new_exon)))
            # old_exon already covers new_exon: silently ignore the new one.
            return
        # Update in place so references obtained from get_exons() stay valid.
        self._exons[:] = kept
        self._exons.append(new_exon)
        try:
            new_exon.add_parent(self.getID())
            new_exon.setID(self.getID()+'.cds')
        except AttributeError as e:
            # e.g. a plain GFF3Record has no add_parent; report on stderr so
            # diagnostics do not mix with data on stdout.
            import sys
            sys.stderr.write('%s\n' % e)
        self._start = min(self._start, new_exon.get_start())
        self._end = max(self._end, new_exon.get_end())

    def add_start_codon(self,  codon):
        """Retype codon as 'start_codon' and move the CDS start to its 5' end.

        The codon record itself is not stored. set_CDS_start widens the mRNA
        extent if needed.
        """
        codon.set_type('start_codon')
        new_start = codon.get_end() if self.getStrand() == '-' else codon.get_start()
        self.set_CDS_start(new_start)

    def add_stop_codon(self,  codon):
        """Retype codon as 'stop_codon' and move the CDS stop to its 3' end.

        The codon record itself is not stored. set_CDS_stop widens the mRNA
        extent if needed.
        """
        codon.set_type('stop_codon')
        new_stop = codon.get_start() if self.getStrand() == '-' else codon.get_end()
        self.set_CDS_stop(new_stop)



    def get_exons(self):
        """Return the exon list itself (not a copy)."""
        return self._exons

    def isInGene(self, geneID):
        return self.get_parent() == geneID

    def make_start_codon(self):
        """Build a 'start_codon' GFF3Record spanning the 3 bases from the CDS start
        in transcription direction, with Parent = this mRNA and ID '<ID>.start'."""
        start_codon = GFF3Record.clone(self)
        start_codon.set_type('start_codon')
        start_codon_start = self._CDSstart if self.get_strand() != '-' else int(self._CDSstart) - 2
        start_codon_end = int(self._CDSstart) + 2 if self.get_strand() != '-' else self._CDSstart
        start_codon.set_start(start_codon_start)
        start_codon.set_end(start_codon_end)
        start_codon.set_parents(self.get_ID())
        start_codon.set_ID(self.get_ID() + '.start')
        return start_codon

    def make_stop_codon(self):
        """Build a 'stop_codon' GFF3Record spanning the 3 bases ending at the CDS stop
        in transcription direction, with Parent = this mRNA and ID '<ID>.stop'."""
        stop_codon = GFF3Record.clone(self)
        stop_codon.set_type('stop_codon')
        stop_codon_start = int(self._CDSstop) - 2 if self.get_strand() != '-' else self._CDSstop
        stop_codon_end = self._CDSstop if self.get_strand() != '-' else int(self._CDSstop) + 2
        stop_codon.set_start(stop_codon_start)
        stop_codon.set_end(stop_codon_end)
        stop_codon.set_parents(self.get_ID())
        stop_codon.set_ID(self.get_ID() + '.stop')
        return stop_codon

    def toString_with_start_and_stop_codons(self):
        """mRNA line, start codon, exons (stored order), stop codon, always."""
        result = [GFF3Record.__str__(self)]
        result.append(str(self.make_start_codon()))
        for exon in self.get_exons():
            result.append(str(exon))
        result.append(str(self.make_stop_codon()))
        return '\n'.join(result)

    def set_strand(self, value):
        """Set the strand of the mRNA and of every exon."""
        self._strand = value
        for exon in self.get_exons():
            exon.set_strand(value)

    def __str__(self):
        """mRNA line, then exon lines in stored order.

        A start_codon line is emitted only when the CDS start differs from the
        transcript's 5' end, and a stop_codon line only when the CDS stop
        differs from its 3' end (strand '+' or '-'; neither for other strands).
        """
        result = [GFF3Record.__str__(self)]
        if (self.get_strand() == '+' and self.get_CDS_start() != self.get_start()) or (self.get_strand() == '-' and self.get_CDS_start() != self.get_end()):
            result.append(str(self.make_start_codon()))
        for exon in self.get_exons():
            result.append(str(exon))
        if (self.get_strand() == '+' and self.get_CDS_stop() != self.get_end()) or (self.get_strand() == '-' and self.get_CDS_stop() != self.get_start()):
            result.append(str(self.make_stop_codon()))
        return '\n'.join(result)

    def __eq__(self,  other):
        """Equal when other is a GFF3mRNA with equal seqID, strand and extent and
        the same number of exons, pairwise equal after sorting."""
        if not (isinstance(other, GFF3mRNA) and GFF3Record.__eq__(self, other)):
            return False
        s_exons = sorted(self.get_exons())
        o_exons = sorted(other.get_exons())
        if len(s_exons) != len(o_exons):
            return False
        return all(s_exon == o_exon for s_exon, o_exon in zip(s_exons, o_exons))

    def CDS_matches(self,  other):
        """True if other is a GFF3mRNA whose extracted CDS equals this one's."""
        result = isinstance(other,  GFF3mRNA)
        result = result and self.extract_CDS() == other.extract_CDS()
        return result

    def extract_CDS(self):
        """Return a new GFF3mRNA spanning only the CDS.

        Its exons are 'CDS'-typed copies of this transcript's exons, clipped to
        the CDS range. The lowest/highest parts are then stretched to the CDS
        bounds if they fall short. This transcript is not modified.
        """
        CDS = self.clone()
        CDS.set_start(min(self._CDSstart, self._CDSstop))
        CDS.set_end(max(self._CDSstart, self._CDSstop))
        CDS.set_CDS_start(self._CDSstart)
        CDS.set_CDS_stop(self._CDSstop)
        for exon in sorted(self.get_exons()):
            if exon.get_end() < CDS.get_start() or exon.get_start() > CDS.get_end():
                continue
            cds_part=exon.clone()
            cds_part.set_type('CDS')
            if cds_part.get_start() < CDS.get_start():
                cds_part.set_start(CDS.get_start())
            if cds_part.get_end() > CDS.get_end():
                cds_part.set_end(CDS.get_end())
            CDS.add_exon(cds_part)
        cds_parts = CDS.get_exons()
        if cds_parts:
            # Correct clipping of terminal exon(s)
            if cds_parts[0].get_start() > CDS.get_start():
                cds_parts[0].set_start(CDS.get_start())
            if cds_parts[-1].get_end() < CDS.get_end():
                cds_parts[-1].set_end(CDS.get_end())
        return CDS

    def get_forward_sequence(self, genome_index):
        """Return the spliced forward-strand sequence of the exons, or None if there are none.

        genome_index[seqID] is sliced with 0-based, end-exclusive bounds and
        the pieces are joined with + (presumably Bio SeqRecords; confirm
        against callers). The lowest exon is extended down to the mRNA start
        and the highest up to the mRNA end, so CDS/UTR extensions beyond the
        annotated exons are included. The exon records are not modified.

        Raises:
            ValueError: if an exon starts before the mRNA's lower coordinate.
        """
        exons = sorted(self.get_exons())
        if not exons:
            return None
        seq_start = min(self.get_start(), self.get_end()) -1
        seq_end = max(self.get_start(), self.get_end())
        unspliced = genome_index[self.get_seqID()][seq_start : seq_end]
        # Work on copied [start, end] pairs so the terminal extension below does
        # not change the transcript's exons (which would make UTR/CDS lengths
        # depend on whether a sequence was fetched earlier).
        bounds = [[exon.get_start(), exon.get_end()] for exon in exons]
        if bounds[0][0] > self.get_start():
            bounds[0][0] = self.get_start()
        if bounds[-1][1] < self.get_end():
            bounds[-1][1] = self.get_end()
        spliced = None
        for exon_start, exon_end in bounds:
            rel_start = exon_start - seq_start -1
            if rel_start < 0:
                raise ValueError('[GFF3mRNA.get_forward_sequence] Negative relative start')
            rel_end = exon_end - seq_start
            exon_seq = unspliced[rel_start:rel_end]
            spliced = exon_seq if spliced is None else spliced + exon_seq
        return spliced

    def get_transcript_sequence(self, genome_index):
        """Return the spliced sequence in transcript orientation (reverse complement
        on '-'), or None if there are no exons."""
        seq_rec = self.get_forward_sequence(genome_index)
        if seq_rec is not None and self._strand == '-':
            seq_rec = SeqRecord.SeqRecord(seq_rec.seq.reverse_complement())
        return seq_rec

    def get_genomic_coordinate(self,  transcript_coordinate):
        """Map an offset along the spliced transcript to a genomic coordinate.

        Offset 0 maps to the lowest exon start on '+' (highest exon end on
        '-'); offsets walk through the sorted exons in transcription direction.

        Raises:
            ValueError: if the offset is outside [0, transcript length] or the
                transcript has no exons.
        """
        if not 0 <= transcript_coordinate <= self.get_transcript_length():
            raise ValueError('Coordinate out of range')
        exons = sorted(self.get_exons())
        residual = transcript_coordinate
        if self.get_strand() == '-':
            for exon in reversed(exons):
                if residual <= len(exon):
                    return exon.get_end() - residual
                residual -= len(exon)
        else:
            for exon in exons:
                if residual <= len(exon):
                    return exon.get_start() + residual
                residual -= len(exon)
        # Only reachable with no exons (transcript length 0).
        raise ValueError('Coordinate out of range')

    def get_utr5_length(self):
        """Number of exonic bases 5' of the CDS start (0 when CDS start == stop)."""
        if self._CDSstart == self._CDSstop:
            return 0  ## work around error in setting start and stop codons
        ## The UTR may contain introns ##
        exons = sorted(self.get_exons())
        length = 0
        if self._strand == '-':
            exons.reverse()
            for x in exons:
                if x.get_end() <= self._CDSstart:
                    break
                length += 1 + x.get_end() - max(x.get_start(),  self._CDSstart + 1)
        else:
            for x in exons:
                if x.get_start() >= self._CDSstart:
                    break
                length += 1 + min(x.get_end(),  self._CDSstart - 1) - x.get_start()
        return length

    def get_utr3_length(self):
        """Number of exonic bases 3' of the CDS stop (0 when CDS start == stop)."""
        if self._CDSstart == self._CDSstop:
            return 0  ## work around error in setting start and stop codons
        ## The UTR may contain introns ##
        exons = sorted(self.get_exons())
        length = 0
        if self._strand == '-':
            for x in exons:
                if x.get_start() >= self._CDSstop:
                    break
                length += 1 + min(x.get_end(),  self._CDSstop - 1) - x.get_start()
        else:
            exons.reverse()
            for x in exons:
                if x.get_end() <= self._CDSstop:
                    break
                length += 1 + x.get_end() - max(x.get_start(),  self._CDSstop + 1)
        return length

    def get_transcript_length(self):
        """Sum of exon lengths."""
        return sum([len(x) for x in self._exons])

    def get_CDS_length(self):
        """Transcript length minus both UTR lengths."""
        return self.get_transcript_length() - self.get_utr3_length() - self.get_utr5_length()

    def get_coding_sequence(self, genome_index):
        """Return the transcript sequence with the 5' and 3' UTRs sliced off."""
        utr5 = self.get_utr5_length()
        utr3 = self.get_utr3_length()
        if utr3 > 0:
            return self.get_transcript_sequence(genome_index)[utr5 : -utr3]
        else:
            return self.get_transcript_sequence(genome_index)[utr5 :]

    def get_padded_coding_sequence(self, genome_index, pad_length= PAD_LENGTH):
        """Return the coding sequence with up to pad_length genomic flanking bases on each side.

        Flanks are read directly from the genome around the CDS bounds (not
        spliced). The lower flank is clamped at the sequence start. On '-' both
        flanks are reverse-complemented and swapped to follow transcript
        orientation.
        """
        cds = self.get_coding_sequence(genome_index)
        left = min(self._CDSstart, self._CDSstop)
        right = max(self._CDSstart, self._CDSstop)
        left_start = max(0, left - pad_length-1)
        left_pad = genome_index[self.get_seqID()][left_start  : left-1].seq
        right_pad = genome_index[self.get_seqID()][right: right + pad_length ].seq if pad_length > 0 else Seq.Seq('')
        if self._strand == '-':
            return str(right_pad.reverse_complement()) + cds + str(left_pad.reverse_complement())
        else:
            return str(left_pad) + cds + str(right_pad)


    @staticmethod
    def fromRecord(gff3record):
        """Build a GFF3mRNA from a generic record's fields (the type column is dropped)."""
        fields = gff3record.get_field_list()
        del fields[2]
        return GFF3mRNA(*fields)

# end of class GFF3mRNA

class GFF3Exon(GFF3Record):
    """A GFF3 'exon' record.

    The Parent of an exon is a set of mRNAs, stored in the Parent attribute as
    a comma-separated list of transcript ID strings.
    """
    def __init__(self, seqID, source, start, end, score, strand, phase, attributes):
        super(GFF3Exon, self).__init__( seqID, source, 'exon', start, end, score, strand, phase, attributes)

    def clone(self):
        """Return a GFF3Exon copy (attributes re-parsed, so the dict is not shared)."""
        return GFF3Exon(self._seqID, self._source, str(self._start), str(self._end), str(self._score), self._strand, self._phase, self.get_attr_string())

    def get_parent(self):
        """Return the parent transcript IDs as a list of strings, or None if unset/empty."""
        if self.get_parents():
            return self.get_parents().split(',')
        return None

    def add_parent(self, parent):
        """Append parent to the comma-separated Parent attribute unless already listed."""
        parents = self.get_parent()
        if parents:
            if parent not in parents:
                parents.append(parent)
        else :
            parents = [parent]
        self.set_parents(','.join(parents))

    @staticmethod
    def fromRecord(gff3record):
        """Build a GFF3Exon from a generic record's fields (the type column is dropped)."""
        fields = gff3record.get_field_list()
        del fields[2]
        return GFF3Exon(*fields)

    def isInTranscript(self, transcriptID):
        """True if transcriptID is one of this exon's parent IDs (exact match)."""
        parents = self.get_parent()
        if parents and transcriptID in parents:
            return True
        return False

    def isInGene(self, geneID, transcripts=None):
        """Return True if any parent transcript of this exon belongs to gene geneID.

        Parents are stored as transcript ID strings, so resolving them needs
        transcripts: a mapping from transcript ID to an object with an
        isInGene(geneID) method (e.g. GFF3mRNA). Parent IDs missing from the
        mapping are ignored. An exon without parents returns False.

        Raises:
            TypeError: if the exon has parents but no transcripts mapping is given.
        """
        parents = self.get_parent()
        if not parents:
            return False
        if transcripts is None:
            raise TypeError('GFF3Exon.isInGene needs a transcripts mapping to resolve parent IDs %s' % ','.join(parents))
        for parent_id in parents:
            transcript = transcripts.get(parent_id)
            if transcript is not None and transcript.isInGene(geneID):
                return True
        return False

    def get_forward_sequence(self, genome_index):
        """Return the forward-strand Seq under the exon.

        genome_index[seqID] is sliced with 0-based, end-exclusive bounds,
        hence the -1 on the lower 1-based coordinate.
        """
        seq_start = min(self.get_start(), self.get_end()) -1
        seq_end = max(self.get_start(), self.get_end())
        return genome_index[self.get_seqID()][seq_start : seq_end].seq

# end of class GFF3Exon

