
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

"""Create read islands from RNA-Seq splice and coverage data

Created: 2011-03-16
Author: ian
"""
import sys, os, random,  subprocess
from pysam import Tabixfile
from collections import defaultdict, deque,  namedtuple
from gff3Record import GFF3Exon,  GFF3mRNA,  GFF3Record

# Tunable thresholds for island detection and output filtering.
MINIMUM_COVER = 3            # minimum read depth for a position to count as covered
MIN_TRANSCRIPT_LENGTH = 100  # minimum summed exon length for an output transcript
MIN_EXON_LENGTH = 50         # minimum flank length when testing for depth jumps
MAX_READTHROUGH = 0.6        # maximum splice score accepted -- presumably a read-through fraction; TODO confirm units
ALPHA=0.1                    # significance level for the depth-jump permutation test

# One bedGraph-style coverage interval: constant depth over [start, end].
Coverblock = namedtuple('Coverblock', 'seqID start end depth')
# One observed splice junction; fields taken verbatim from the tab-separated juncs file.
Splice = namedtuple('Splice', 'seqID start end strand score count')

def usage(msg=None):
    """Print an optional error message and the usage line to stderr, then exit(1)."""
    if msg:
        print >> sys.stderr,  msg
    print >> sys.stderr,  'Usage: python %s genome_seq.fa  observed.juncs.gz coverage.wig.gz  output.gff3' % sys.argv[0]
    sys.exit(1)

def get_cover_blocks(exon,  coverage_tabix):
    cover = []
    try:
        wiglines = list(coverage_tabix.fetch(exon.get_seqID(), exon.get_start(), exon.get_end()-1))
        for line in wiglines:
            fields = line.strip().split('\t')
            block = Coverblock(fields[0],  max(exon.get_start(), int(fields[1]) + 1),  min(exon.get_end()-1, int(fields[2])),  int(fields[3]))
            cover.append(block)
    except ValueError,  ve:
        print >> sys.stderr, '[find_read_islands.get_cover_blocks]',   ve
    return cover


def contiguous_cover(scaffold,  start,  end, coverage_tabix,   min_cover=MINIMUM_COVER):
    """Return maximal runs of coverage >= min_cover within scaffold[start:end].

    Input:
        scaffold - sequence ID to query
        start, end - region boundaries (coerced to int)
        coverage_tabix - tabix-indexed bedGraph-style coverage file
        min_cover - minimum depth for a position to belong to an island
    Output:
        list of mutable [seqID, start, end] lists, one per island
    """
    islands = []
    curr_island = None   # island currently being extended, or None
    try:
        wiglines = list(coverage_tabix.fetch(scaffold,  int(start),  int(end)))
        for line in wiglines:
            fields = line.strip().split('\t')
            if len(fields) < 4:
                continue   # skip malformed coverage lines
            # clip the block to the requested region; +1 shifts the 0-based
            # bedGraph start -- TODO confirm coordinate convention
            block = Coverblock(fields[0],  max(int(start), int(fields[1]) + 1),  min(int(end), int(fields[2])),  int(fields[3]))
            if block.depth < min_cover:
                # coverage gap: close the current island if it has any width
                if curr_island:
                    if curr_island[2] > curr_island[1] :
                        islands.append(curr_island)
                    curr_island = None
            else:
                if curr_island:
                    curr_island[2] = block.end   # extend the open island
                else:
                    curr_island = [block[0],  block[1],  block[2]]
        # flush the island still open at the end of the region
        if curr_island and curr_island[2] > curr_island[1]:
            islands.append(curr_island)
    except Exception,  ve:
        print >> sys.stderr,  '[find_read_islands:contiguous_cover]',  ve
    return islands

def make_transcript_from_island(island,  id):
    """Wrap a [seqID, start, end] island in a single-exon GFF3mRNA with the given ID."""
    mrna = GFF3mRNA(island[0],  'RNA-Seq',  island[1],  island[2]+1,  '.',  '.',  '.',  'ID=%s' % id)
    mrna.add_exon(make_exon_from_island(island))
    return mrna

def make_exon_from_island(island):
    """Convert a [seqID, start, end] island into a GFF3Exon.

    Raises ValueError when the island is degenerate (end < start - 1).
    """
    seq_id = island[0]
    first = island[1]
    last = island[2]
    if last < first - 1:
        raise ValueError('[find_read_islands.make_exon_from_island] Too short!' + str(island))
    return GFF3Exon(seq_id,  'RNA-Seq',  first,  last+1,  '.',  '.',  '.',  '')

def break_transcript_here(transcript,  splice):
    """Split a transcript into upstream and downstream clones around a splice.

    Input:
        transcript - GFF3mRNA to split (the original is not modified)
        splice - Splice namedtuple marking the break point
    Output:
        (upstream, downstream) GFF3mRNA pair with IDs suffixed "U"/"D";
        exons straddling the splice are split via break_exon_here
    """
    id = transcript.get_ID()
    upstream = transcript.clone()
    upstream.set_end(int(splice.start)+1)
    try:
        # trim the CDS on the side facing the break; on the '-' strand the
        # CDS "start" is the rightmost coordinate
        if upstream.get_strand() == '-':
            upstream.set_CDS_start(int(splice.start)+1)
        else:
            upstream.set_CDS_stop(int(splice.start)+1)
    except :
        pass   # NOTE(review): bare except -- presumably transcripts without a CDS raise here; confirm
    upstream.set_ID(  id + "U")
    downstream = transcript.clone()
    downstream.set_start(int(splice.end))
    try:
        if downstream.get_strand() == '-':
            downstream.set_CDS_stop(int(splice.end))
        else:
            downstream.set_CDS_start(int(splice.end))
    except:
        pass   # NOTE(review): bare except, same assumption as above
    downstream.set_ID(  id + "D")
    # distribute exons to whichever side of the break they fall on
    for exon in transcript.get_exons():
        if exon.get_end() <= int(splice.start):
            upstream.add_exon(exon)
        elif exon.get_start() >= int(splice.end):
            downstream.add_exon(exon)
        else:
            # exon straddles the splice: split it and keep the non-empty pieces
            exon_5,  exon_3 = break_exon_here(exon,  splice)
            if exon_5:
                upstream.add_exon(exon_5)
            if exon_3:
                downstream.add_exon(exon_3)
    return upstream,  downstream

def break_exon_here(exon,  splice):
    """Split an exon around a splice.

    Returns (five_prime_piece, three_prime_piece); either may be None when
    the exon does not extend past the corresponding splice boundary.
    """
    splice_start = int(splice.start)
    splice_end = int(splice.end)
    five_piece = None
    three_piece = None
    if exon.get_start() < splice_start:
        five_piece = exon.clone()
        five_piece.set_end(splice_start + 1)
    if exon.get_end() > splice_end:
        three_piece = exon.clone()
        three_piece.set_start(splice_end)
    return five_piece,  three_piece

def filter_and_output_contigs(fragments, output,  min_length= MIN_TRANSCRIPT_LENGTH,  min_cover= MINIMUM_COVER):
    """Write fragments passing length and coverage filters to *output* as GFF3.

    Input:
        fragments - iterable of GFF3mRNA-like fragments carrying exons
        output - writable file object for GFF3 lines
        min_length - minimum summed exon length to emit a fragment
        min_cover - minimum of the per-exon 'cov' attributes (the max is used)
    Side effects:
        prints one 'contig' line, one 'match' line per exon, and a '###'
        separator for each fragment that passes both filters
    Raises:
        ValueError (re-raised after logging) when a fragment lacks 'cov'
    """
    for fragment in fragments:
        try:
            f_len = sum([len(exon) for exon in fragment.get_exons()])
            # every exon needs a 'cov' attribute; inherit the fragment's if absent
            for exon in fragment.get_exons():
                if not exon.has_attribute('cov'):
                    if fragment.has_attribute('cov'):
                        exon.add_attribute('cov',  fragment.get_attribute('cov'))
                    else:
                        raise ValueError('[find_read_islands.filter_and_output_contigs.l133] Fragment %s lacks a value for cov' % fragment.get_ID())
            if f_len >= min_length:
                # fragment coverage = deepest exon
                f_cov = max([float(exon.get_attribute('cov')) for exon in fragment.get_exons()])
                if  f_cov >= min_cover:
                        lines = []
                        id = fragment.get_ID()
                        fragment.set_score(f_cov)
                        trans_attrs = 'length=%d;Name=%s;ID=%s' % (f_len,  id,  id)
                        exon_attrs = 'Name=%s;Parent=%s' % (id,  id)
                        lines.append('\t'.join([str(f) for f in [fragment._seqID, fragment._source, 'contig', fragment._start, fragment._end, fragment._score, fragment._strand, fragment._phase, trans_attrs]]))
                        # emit exons in coordinate order, accumulating Target
                        # offsets along the contig
                        exons = fragment.get_exons()
                        to_sort = [ (e.get_start(),  e) for e in exons]
                        exons = [tup[1] for tup in sorted(to_sort)]
                        offset = 0
                        for exon in exons:
                            e_len = len(exon)
                            target = ';Target=%s %d %d +' % (id,  1 + offset,  e_len + offset)
                            exon.set_score(float(exon.get_attribute('cov')))
                            lines.append('\t'.join([str(f) for f in [exon._seqID, exon._source, 'match', exon._start, exon._end, exon._score, exon._strand, exon._phase, exon_attrs + target]]))
                            offset += e_len
                        lines.append('###')
                        print >> output,  '\n'.join(lines)
        except ValueError,  ve:
            print >> sys.stderr,  ve
            raise

def make_transcripts(splice_list,  coverage_list,  min_cover=MINIMUM_COVER):
    '''Create one or more transcripts from a list of splice sites and coverage records
    Input:
        splice_list - list of Splice, already filtered for adequate coverage
        coverage_list - list of Coverblock

    Output:
        list of GFF3mRNA instances

    NOTE(review): this function appears unfinished -- exon1 built in the loop
    below is never collected, so at most one exon and one transcript are ever
    produced; confirm whether this function is still in use.
    '''
    transcripts = []
    exons = []
    if splice_list:
        intron0 = splice_list[0]
        end0 = int(intron0.start)   # first exon ends where the first intron starts

        # find the coverage block containing the end of the first exon
        i = 0
        for i,  block in enumerate(coverage_list):
            if int(block.end) >= end0:
                break
        else:
            raise ValueError('End of 1st exon not found in coverage list')
        # walk back to the start of the contiguous well-covered run
        while i >= 0:
            if int(coverage_list[i].depth) < min_cover:
                break
            i -= 1
        i += 1
        start0 = int(coverage_list[i].start)
        exon0 = GFF3Exon(coverage_list[i].seqID,  'RNA-Seq',  start0 + 1,  end0 + 1,  '.', '.', '.', 'ID=exon0')
        exons.append(exon0)
        exon0.set_strand(intron0.strand)

        if len(splice_list) > 1:
            for intron1 in splice_list[1:]:
                # NOTE(review): exon1 is constructed but never appended to exons -- dead code?
                exon1 = GFF3Exon(coverage_list[0].seqID,  'RNA-Seq',  int(intron0.end) + 1,  int(intron1.start) + 1,  '.', intron1.strand, '.', 'ID=exon0')

    if exons:
        transcript = GFF3mRNA.fromRecord(exons[0])
        transcripts.append(transcript)
        transcript.set_ID('t%s' % len(transcripts))
        for exon in exons:
            transcript.add_exon(exon)

    return transcripts


def get_mean_depth(blocks):
    """Length-weighted mean depth over a list of coverage blocks (0 if empty)."""
    total_area = 0
    total_length = 0
    for blk in blocks:
        span = int(blk.end) - int(blk.start) + 1
        total_length += span
        total_area += span * int(blk.depth)
    if total_length < 1:
        return 0
    return float(total_area) / total_length

def get_sum_of_squared_deviations(blocks,  mean):
    """Length-weighted sum of squared deviations of block depths from *mean*."""
    return sum(((1 + int(blk.end) - int(blk.start)) * (int(blk.depth) - mean) ** 2
                for blk in blocks), 0.0)

def find_extreme_partition(blocks, left_margin=0, right_margin=0):
    """Find the block boundary that maximizes |mean(left) - mean(right)|.

    Every split between consecutive blocks is considered, provided the left
    part covers at least max(1, left_margin) bases and the right part at
    least max(1, right_margin).  Returns (divider_start, max_difference,
    left_mean, right_mean); if no admissible split exists, the divider
    defaults to the first block's start and the difference is 0.
    """
    sizes = [1 + int(b.end) - int(b.start) for b in blocks]
    covers = [int(b.depth) for b in blocks]
    fallback_divider = int(blocks[0].start)
    min_left = max(1, left_margin)
    min_right = max(1, right_margin)
    head_len = 0
    head_area = 0
    tail_len = sum(sizes)
    tail_area = sum(s * c for s, c in zip(sizes, covers))
    best_index = -1
    best_diff = 0
    best_left = best_right = tail_area / tail_len
    for idx in range(len(blocks)):
        head_len += sizes[idx]
        head_area += sizes[idx] * covers[idx]
        tail_len -= sizes[idx]
        tail_area -= sizes[idx] * covers[idx]
        if head_len < min_left or tail_len < min_right:
            continue
        mean_gap = abs(tail_area / tail_len - head_area / head_len)
        if mean_gap > best_diff:
            best_index = idx
            best_diff = mean_gap
            best_left = head_area / head_len
            best_right = tail_area / tail_len
    if best_index > -1:
        return int(blocks[best_index].start), best_diff, best_left, best_right
    return fallback_divider, best_diff, best_left, best_right

def extract_depth_changes(blocks):
    """Re-encode blocks as (length, depth-delta) jumps relative to the previous block.

    The first jump carries the first block's length with a delta of 0, and
    all start fields are zeroed.  Shuffling these jumps and re-applying them
    (apply_depth_changes) yields permuted coverage profiles that keep the
    same block lengths and depth transitions.
    """
    first = blocks[0]
    jumps = [Coverblock(first.seqID, 0, int(first.end) - int(first.start), 0)]
    for earlier, later in zip(blocks, blocks[1:]):
        jumps.append(Coverblock(later.seqID, 0,
                                int(later.end) - int(later.start),
                                int(later.depth) - int(earlier.depth)))
    return jumps

def apply_depth_changes(jumps):
    """Rebuild a cumulative coverage profile from (length, depth-delta) jumps."""
    head = jumps[0]
    profile = [Coverblock(head.seqID, 0, head.end, head.depth)]
    for step in jumps[1:]:
        prev = profile[-1]
        profile.append(Coverblock(step.seqID, prev.end, prev.end + step.end,
                                  prev.depth + step.depth))
    return profile

def sample_extreme_partitions(max_iter, blocks, threshold,  critical_tail_count,  left_margin=0, right_margin=0):
    """Permutation-sample the null distribution of extreme partition differences.

    Shuffles the depth-change encoding of *blocks* up to max_iter times,
    recording the maximal mean difference of each shuffled profile.  Stops
    early once more than critical_tail_count samples exceed *threshold*,
    since the observed difference can no longer be significant.  Returns the
    sampled differences in ascending order.
    """
    shuffled_jumps = extract_depth_changes(blocks)
    sampled_diffs = []
    exceed_count = 0
    for _ in xrange(max_iter):
        random.shuffle(shuffled_jumps)
        permuted_profile = apply_depth_changes(shuffled_jumps)
        partition = find_extreme_partition(permuted_profile, left_margin, right_margin)
        sampled_diffs.append(partition[1])
        if partition[1] > threshold:
            exceed_count += 1
            if exceed_count > critical_tail_count:
                break
    sampled_diffs.sort()
    return sampled_diffs

def calculate_tail_area(x, distribution):
    """Return the fraction of a sorted (ascending) distribution that is >= x.

    Input:
        x - observed statistic
        distribution - non-empty list sorted in ascending order
    Output:
        empirical tail probability in [0, 1]
    """
    # Scan from the largest value down.  The range must reach
    # len(distribution) so that the smallest element (distribution[-len]) is
    # also examined; the previous range(1, len(distribution)) skipped it,
    # overestimating p as 1.0 whenever x fell between the two smallest values.
    for i in range(1, len(distribution) + 1):
        if distribution[-i] < x:
            rank = i - 1
            break
    else:
        rank = len(distribution)   # every element is >= x
    p = float(rank) / len(distribution)
    return p

def find_1_jump(blocks, left_margin=MIN_EXON_LENGTH, right_margin=MIN_EXON_LENGTH):
    """Locate the most extreme depth jump in *blocks* and its permutation p-value.

    Returns (divider_position, p_value, left_mean, right_mean).
    """
    divider, observed_diff, left_mean, right_mean = find_extreme_partition(blocks, left_margin, right_margin)
    # 1000 permutations; sampling may stop early once the tail already
    # exceeds ALPHA * 1000 samples
    tail_limit = 1000 * ALPHA
    null_diffs = sample_extreme_partitions(1000, blocks, observed_diff,  tail_limit, left_margin, right_margin)
    p_value = calculate_tail_area(observed_diff, null_diffs)
    return divider, p_value, left_mean, right_mean

def break_at_depth_jumps(transcript_deque, coverage_tabix):
    """Attach mean-coverage ('cov') attributes to transcripts and their exons.

    Despite the name, the depth-jump splitting logic is currently disabled
    (commented out below): every transcript of total exon length >= 2 passes
    through unchanged apart from the added attributes.

    Input:
        transcript_deque - deque of GFF3mRNA; emptied by this function
        coverage_tabix - tabix-indexed coverage file for depth lookups
    Output:
        deque of annotated transcript fragments
    """
    jumps = []   # NOTE(review): unused since the splitting code was commented out
    fragments = deque()
    while transcript_deque:
        current_transcript = transcript_deque.popleft()
        current_transcript_len = sum([len(exon) for exon in current_transcript.get_exons()])
        if current_transcript_len < 2:
            continue   # drop degenerate transcripts
        for exon in current_transcript.get_exons():
            if len(exon) < 2:
                print >> sys.stderr,  '[find_read_islands.break_at_depth_jumps] Short exon:\n',  str(current_transcript)
            if not exon.has_attribute('cov'):
                blocks = get_cover_blocks(exon,  coverage_tabix)
                exon_mean_depth = get_mean_depth(blocks)
                exon.add_attribute('cov',  exon_mean_depth)
        # transcript coverage = length-weighted mean of exon coverages
        current_transcript_cov = sum([len(exon) * float(exon.get_attribute('cov')) for exon in current_transcript.get_exons() ]) / current_transcript_len
        current_transcript.add_attribute('cov',  current_transcript_cov)
#        if current_transcript_len < MIN_TRANSCRIPT_LENGTH:
        fragments.append(current_transcript)
#        else:
#            for exon in current_transcript.get_exons():
#                if exon.get_start() + 2 * MIN_EXON_LENGTH >= exon.get_end():
#                    continue
#                if float(exon.get_attribute('cov')) < MINIMUM_COVER :
#                    continue
#                # check whether exon should be split because of jump in coverage depth
#                blocks = get_cover_blocks(exon,  coverage_tabix)
#                divider, p,  left_mean,  right_mean = find_1_jump(blocks, left_margin=MIN_EXON_LENGTH, right_margin=MIN_EXON_LENGTH)
#                if p < ALPHA and min([divider - exon.get_start(),  exon.get_end() - divider]) >= MIN_EXON_LENGTH:
#                    j_splice = Splice(current_transcript.get_seqID(),  divider,  divider+1,  current_transcript.get_strand(), 1,  1)
#                    upstream,  downstream = break_transcript_here(current_transcript,  j_splice)
#                    if len(downstream) > 1:
#                        downstream.add_attribute('cov',  right_mean)
#                        sorted(downstream.get_exons())[0].add_attribute('cov',  right_mean)
#                        transcript_deque.appendleft(downstream)
#                    if len(upstream) > 1:
#                        upstream.add_attribute('cov',  left_mean)
#                        sorted(upstream.get_exons())[-1].add_attribute('cov',  left_mean)
#                        transcript_deque.appendleft(upstream)
#                    break
#            else:
#                fragments.append(current_transcript)

    return fragments

def get_scaffold_islands(args):
    """Build read-island transcript fragments for one scaffold.

    Input:
        args - (scaffold name, scaffold length, juncs Tabixfile, coverage
               Tabixfile); packed as one tuple, presumably for use with a
               pool-style map -- confirm
    Output:
        list of deques of coverage-annotated transcript fragments

    Coverage islands between consecutive accepted splices become exons;
    an island that reaches a splice is chained across it into a multi-exon
    transcript, while coverage gaps and strand changes close the current
    transcript and start a new one.
    """
    scaffold, scaffold_len, juncs_tabix,  cover_tabix = args
    curr_transcript = None   # transcript currently being extended across splices
    search_start = 1
    island_num = 0           # counter for unique island IDs on this scaffold
    island_transcripts = []
    result = []
    try:
        # load splices, keep those with acceptable score and enough supporting
        # reads, then collapse overlapping splices to the best-supported one
        all_splices = [Splice._make(j.split('\t')[:6]) for j in juncs_tabix.fetch(scaffold)]
        s_splices = [s for s in all_splices if float(s.score) <= MAX_READTHROUGH and int(s.count) >= MINIMUM_COVER]
        splices = [s_splices[0]]
        for s in s_splices[1:]:
            if int(s.start) > int(splices[-1].end):
                splices.append(s)
            elif int(s.count) > int(splices[-1].count): # keep only best-supported of overlapping splices
                splices[-1] = s
    except Exception,  e:
        # also reached via IndexError when the scaffold has no accepted splices
        print >> sys.stderr,  '[find_read_islands.l372]',  e
        splices = []
    try:
        if splices:
            # region from the scaffold start up to the first splice
            inter_splice0 = (scaffold,  1,  int(splices[0].start))
            try:
                islands = contiguous_cover(*inter_splice0, coverage_tabix=cover_tabix,  min_cover=MINIMUM_COVER)
                if islands and len(islands) > 1:
                    # all but the last island are complete single-exon transcripts
                    for island in islands[:-1]:
                        try:
                            island_num += 1
                            transcript = make_transcript_from_island(island, '%s.Island%06d' % (scaffold, island_num))
                            island_transcripts.append(transcript)
                        except ValueError,  ve:
                            print >> sys.stderr,  ['find_read_islands.l386'],  ve

                if islands:
                    # the last island may continue across the first splice
                    try:
                        island_num += 1
                        curr_transcript = make_transcript_from_island(islands[-1], '%s.Island%06d' % (scaffold, island_num))
                        curr_transcript.set_strand(splices[0].strand)
                        if islands[-1][2] != int(splices[0].start):
                            # island does not reach the splice: close it now
                            island_transcripts.append(curr_transcript)
                            curr_transcript = None
                    except ValueError,  ve:
                        print >> sys.stderr,  ['find_read_islands.l397'],  ve
                        curr_transcript = None
            except Exception,  e:
                print >> sys.stderr,  '[find_read_islands.l400]',  e
                raise
            if len(splices) > 1:
                # regions between consecutive splices
                for splice1, splice2 in zip(splices[:-1], splices[1:]):
                    if int(splice1.end) >= int (splice2.start):
                        print >> sys.stderr,  'Overlapping splices!',  scaffold,  splice1.end,  splice2.start
                        continue
                    inter_splice = (scaffold,  int(splice1.end),  int(splice2.start))
                    try:
                        islands = contiguous_cover(*inter_splice, coverage_tabix=cover_tabix,  min_cover=MINIMUM_COVER)
                        if islands:
                            if curr_transcript:
                                # first island joins the transcript spanning splice1
                                exon = make_exon_from_island(islands[0])
                                curr_transcript.add_exon(exon)
                            else:
                                island_num += 1
                                curr_transcript = make_transcript_from_island(islands[0], '%s.Island%06d' % (scaffold, island_num))
                            curr_transcript.set_strand(splice1.strand)
                            if len(islands) > 1:
                                # coverage gaps: each later island starts a new transcript
                                for island in islands[1:]:
                                    island_transcripts.append(curr_transcript)
                                    island_num += 1
                                    curr_transcript = make_transcript_from_island(island, '%s.Island%06d' % (scaffold, island_num))
                                    # NOTE(review): finished transcripts are flushed
                                    # directly through the module-level 'output' file
                                    # here rather than via 'result' -- confirm this
                                    # mixed output path is intended
                                    for transcript in island_transcripts:
                                        try:
                                            t_deque = deque([transcript])
                                            fragments = break_at_depth_jumps(t_deque, cover_tabix)
                                            filter_and_output_contigs(fragments, output,  1,  1)
                                        except Exception,  e:
                                            print >> sys.stderr,  '[find_read_islands.l429]',  e
                                            raise
                                    island_transcripts = []

                            if splice2.strand != splice1.strand:
                                if len(islands) == 1: # Start a new transcript with first exon overlapping last exon of current transcript; leave finding break between transcripts until later
                                    island_transcripts.append(curr_transcript)
                                    island_num += 1
                                    curr_transcript = make_transcript_from_island(islands[0], '%s.Island%06d' % (scaffold, island_num))
                                    curr_transcript.set_strand(splice2.strand)
                    except Exception,  e:
                        print >> sys.stderr,  '[find_read_islands.l440]',  e
                        raise
            search_start = splices[-1].end
    # now past last splice
        try:
            if int(search_start) < int(scaffold_len) :
                islands = contiguous_cover(scaffold, search_start, scaffold_len, coverage_tabix=cover_tabix,  min_cover=MINIMUM_COVER)
                #if islands and splices:
                    #assert int(splices[-1].end) == islands[0][1]
                if islands:
                    if curr_transcript:
                        # first island after the last splice completes the open transcript
                        try:
                            curr_transcript.add_exon(make_exon_from_island(islands[0]))
                        except ValueError, ve:
                            print >> sys.stderr,  '[find_read_islands.l455]',  ve
                        island_transcripts.append(curr_transcript)
                        curr_transcript = None
                    else:
                        try:
                            island_num += 1
                            transcript = make_transcript_from_island(islands[0], '%s.Island%06d' % (scaffold, island_num))
                            if splices:
                                transcript.set_strand(splices[-1].strand)
                            island_transcripts.append(transcript)
                        except ValueError, ve:
                            print >> sys.stderr,  '[find_read_islands.l465]',  ve

                    # remaining islands are isolated single-exon transcripts
                    for island in islands[1:]:
                        try:
                            island_num += 1
                            transcript = make_transcript_from_island(island, '%s.Island%06d' % (scaffold, island_num))
                            island_transcripts.append(transcript)
                        except ValueError, ve:
                            print >> sys.stderr,  '[find_read_islands.l473]',  ve

        except Exception,  e:
            print >> sys.stderr,  '[find_read_islands.l476]',  e

        # annotate coverage on everything still pending and hand it back
        for transcript in island_transcripts:
            t_deque = deque([transcript])
            fragments = break_at_depth_jumps(t_deque, cover_tabix)
            result.append(fragments)
    except ValueError, ve:
        print >> sys.stderr,  '[find_read_islands.l483]',  ve
        raise
    return result



if __name__ == '__main__':
    # Open all inputs/outputs up front; any failure falls through to usage().
    try:
        genome_seq = sys.argv[1]
        fai_path = genome_seq + '.fai'
        # build the FASTA index on demand so scaffold names/lengths can be read
        if not os.path.isfile(fai_path):
            # NOTE(review): shell=True with an argv-derived path -- shell
            # injection risk if run on untrusted filenames
            subprocess.call('samtools faidx %s' % genome_seq,  shell=True)
        scaffolds = open(fai_path)
        juncs_tabix = Tabixfile(sys.argv[2])
        cover_tabix = Tabixfile(sys.argv[3])
        output = open(sys.argv[4],  'w')
    except Exception,  e:
        usage(e)

    # one pass per scaffold listed in the .fai index
    for line in scaffolds:
        scaffold, scaffold_len = line.split()[:2]
        args = (scaffold, scaffold_len, juncs_tabix,  cover_tabix)
        scaffold_islands = get_scaffold_islands(args)
        for fragments in scaffold_islands:
            # thresholds of 0: fragments were already screened/annotated upstream
            filter_and_output_contigs(fragments, output,  min_length=0,  min_cover=0)
    output.close()
    print sys.argv[0],  'done.'
