
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

'''
Convert between transcript-relative (spliced) and genome-relative (unspliced) coordinates,
and find the introns spanned by a transcript-relative interval.
Created on Jan 12, 2010
@author: ian
'''

import cPickle, sys,os
this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
from gff3Iterator import GFF3Iterator
from bisect import *

class scaffold_coords(object):
    '''
    Piecewise-linear map between genome (unspliced) and transcript (spliced)
    coordinates for one transcript on one genomic sequence.

    The map is two parallel lists of anchors: genome position
    unspliced_coords[i] corresponds to spliced offset spliced_coords[i].
    unspliced_coords is kept sorted.  Between consecutive anchors both
    coordinates advance together.

    Anchors are stored in forward (genome) orientation.  For strand '-' the
    getters flip spliced offsets using spliced_length, so callers work in
    transcript orientation.

    Instances are pickled by transcript_coords.save(), so the attribute names
    below are part of the on-disk format and must not change.
    '''
    def __init__(self, seqID, strand='+'):
        object.__init__(self)
        self.seq_id = str(seqID)
        self.strand = str(strand)
        self.spliced_coords = []
        self.unspliced_coords = []
        # Inclusive spans from the first anchor to the furthest one.
        self.unspliced_length = 0
        self.spliced_length = 0

    def add(self, unspliced, spliced):
        '''
        Record that genome position `unspliced` maps to spliced offset `spliced`.

        Anchors normally arrive in increasing genome order (fast append path).
        An out-of-order anchor is inserted at its sorted position.  The length
        attributes are updated to stay the inclusive span from the first anchor
        to the furthest one.
        '''
        if self.unspliced_coords and unspliced >= self.unspliced_coords[-1]:
            # Check before mutating so a failure cannot desynchronise the two lists.
            assert spliced >= self.spliced_coords[-1]
            self.unspliced_coords.append(unspliced)
            self.spliced_coords.append(spliced)
            # The first anchor is unchanged, so a running maximum is enough.
            self.spliced_length = max(self.spliced_length,
                                      spliced - self.spliced_coords[0] + 1)
        else:
            index = bisect_right(self.unspliced_coords, unspliced)
            self.unspliced_coords.insert(index, unspliced)
            self.spliced_coords.insert(index, spliced)
            # An insertion at the front moves the origin, which makes the
            # running maximum stale, so recompute the span from scratch.
            self.spliced_length = max(self.spliced_coords) - self.spliced_coords[0] + 1
        # unspliced_coords is sorted, so its span is simply last - first.
        self.unspliced_length = self.unspliced_coords[-1] - self.unspliced_coords[0] + 1

    def get_spliced_coord(self, unspliced_coord):
        '''
        Map genome position `unspliced_coord` to a spliced offset.

        A position inside an intron maps to the offset of the last base before
        the intron.  Positions outside the anchored range are extrapolated
        linearly from the nearest anchor.  For strand '-' the result is in
        transcript orientation.
        '''
        coord = unspliced_coord
        last = len(self.unspliced_coords) - 1
        # Last anchor strictly below coord, clamped to a valid index.
        index = max(0, min(bisect_left(self.unspliced_coords, coord) - 1, last))
        result = self.spliced_coords[index] + coord - self.unspliced_coords[index]
        if index < last:
            # unspliced_coord inside an intron gives spliced_coord of base before intron
            result = min(result, self.spliced_coords[index + 1] - 1)
            if coord == self.unspliced_coords[index + 1]:
                # Exactly on the next anchor.
                result = self.spliced_coords[index + 1]
        if self.strand == '-':
            result = self.spliced_length - result - 1
        return result

    def get_unspliced_coord(self, spliced_coord):
        '''
        Map spliced offset `spliced_coord` (transcript orientation) to a genome
        position.

        Offsets outside the anchored range are extrapolated linearly from the
        nearest anchor.
        '''
        if self.strand == '-':
            forward_coord = self.spliced_length - 1 - spliced_coord
        else:
            forward_coord = spliced_coord
        last = len(self.spliced_coords) - 1
        index = max(0, min(bisect_left(self.spliced_coords, forward_coord) - 1, last))
        if index < last and forward_coord == self.spliced_coords[index + 1]:
            # Exactly on the next anchor.
            return self.unspliced_coords[index + 1]
        return self.unspliced_coords[index] + forward_coord - self.spliced_coords[index]


class transcript_coords(object):
    '''
    Table of per-transcript coordinate maps (scaffold_coords objects) keyed by
    transcript ID.

    Build it from GFF3 transcripts with add(), persist it with save()/load(),
    and query it to map transcript-relative positions onto the genome.
    '''

    def __init__(self):
        # transcript ID -> scaffold_coords
        self._dict = {}

    def save(self, filename):
        '''Pickle the whole transcript table to `filename`.'''
        output = open(filename, 'w')
        try:
            cPickle.dump(self._dict, output)
        finally:
            output.close()

    def load(self, filename):
        '''
        Replace the transcript table with the one pickled in `filename`.

        Unpickling can execute arbitrary code, so only load files written by
        save(); never load a file from an untrusted source.
        '''
        self._filename = filename
        pickle_file = open(filename)
        try:
            self._dict = cPickle.load(pickle_file)
        finally:
            pickle_file.close()

    def get_genome_strand(self, transcript_id):
        '''Return the genomic strand recorded for `transcript_id`.'''
        return self._dict[transcript_id].strand

    def get_genome_coord(self, transcript_id, transcript_coord):
        '''
        Map a transcript-relative position onto the genome.

        `transcript_coord` may be any value int() accepts.  Returns a
        (sequence ID, genome coordinate) tuple.
        '''
        coords = self._dict[transcript_id]
        genome_coord = coords.get_unspliced_coord(int(transcript_coord))
        return coords.seq_id, genome_coord

    def get_donor_acceptor_pairs(self, transcript_id, start, end):
        '''
        Return the introns spanned by transcript positions `start`..`end`.

        `start` and `end` are spliced positions in transcript orientation, and
        their order does not matter.  Each intron is a pair of genome
        coordinates: the last base before the intron and the first base after
        it.  For a '+' transcript the pair is in that order; for any other
        strand value the pair is reversed.

        An empty list means no exon boundary lies inside the span.  If the span
        also runs past the end of the transcript, a warning is written to
        stderr.
        '''
        coords = self._dict[transcript_id]
        spliced = coords.spliced_coords
        unspliced = coords.unspliced_coords
        introns = []
        relative_start = int(start)
        relative_end = int(end)
        if coords.strand == '-':
            # Convert to forward-orientation spliced offsets.
            relative_start = coords.spliced_length - relative_start - 1
            relative_end = coords.spliced_length - relative_end - 1
        if relative_end < relative_start:
            relative_start, relative_end = relative_end, relative_start
        # The final anchor marks the last base of the transcript, not the start
        # of an exon, so it must never be reported as an intron boundary.
        last_anchor = len(spliced) - 1
        for i in range(1, len(spliced)):
            boundary = spliced[i]
            if i < last_anchor and relative_start < boundary <= relative_end:
                left_coord = unspliced[i - 1] + boundary - spliced[i - 1] - 1
                right_coord = unspliced[i]
                if coords.strand == '+':
                    introns.append((left_coord, right_coord))
                else:
                    introns.append((right_coord, left_coord))
            if relative_end <= boundary:
                return introns
        if not introns:
            # This point is reached only when the span runs past the last anchor.
            message = ' found in get_donor_acceptor_pair( %s %s %s )\n' % (transcript_id, start, end)
            sys.stderr.write('No left coordinate' + message)
            sys.stderr.write('No right coordinate' + message)
        return introns

    def add(self, transcript):
        '''
        Build and store the coordinate map for `transcript`, keyed by
        transcript.getID().

        One anchor is recorded where each exon's new bases begin, plus a final
        anchor at the last exon base.  Overlapping or abutting exons are
        merged: only bases not already covered by an earlier exon count
        towards the spliced length.

        Raises ValueError if the transcript contributes no exon bases.
        '''
        coords_lst = scaffold_coords(transcript.get_seqID(), transcript.getStrand())
        seq_len = 0                              # spliced bases counted so far
        seq_end = transcript.get_start() - 1     # last genome base already counted
        for exon in sorted(transcript.get_exons()):
            exon_end = exon.get_end()
            if exon_end <= seq_end:
                # The exon lies wholly inside bases already counted, so it
                # contributes nothing (and must not shrink seq_len).
                continue
            # For an overlapping exon, the new segment starts right after the
            # bases already counted, not at the exon's own start.
            segment_start = max(exon.get_start(), seq_end + 1)
            coords_lst.add(segment_start, seq_len)
            seq_len += min(len(exon), exon_end - seq_end) # in case of overlapping or abutting exons
            seq_end = exon_end
        if not coords_lst.unspliced_coords:
            raise ValueError('no exons found for transcript %s' % transcript.getID())
        coords_lst.add(seq_end, seq_len - 1)
        self._dict[transcript.getID()] = coords_lst

    def get_scaffold_coords(self, id):
        '''Return the scaffold_coords object stored for transcript `id`.'''
        return self._dict[id]
        
if __name__ == '__main__':
    # Build the transcript -> genome coordinate table for every transcript in
    # the GFF3 file named on the command line, and pickle it to <file>.coords.
    if len(sys.argv) < 2:
        sys.stderr.write('usage: %s annotation.gff3\n' % sys.argv[0])
        sys.exit(1)
    gff_filename = sys.argv[1]
    save_filename = gff_filename + '.coords'
    coords = transcript_coords()
    gff_file = open(gff_filename)
    try:
        # GFF3Iterator may read the file lazily, so keep it open until every
        # transcript has been processed.
        for gene in GFF3Iterator(gff_file).genes():
            for transcript in gene.get_transcripts():
                coords.add(transcript)
    finally:
        gff_file.close()
    coords.save(save_filename)
