
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

""" From a known transcript, simulate cleavage, size filtering of fragments, and generation of short reads from fragment ends.

Created on Oct 8, 2009
@author: ian
"""

import random, sys, os
from Bio import  Alphabet
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import math
from bisect import bisect_left

# Module-wide running totals maintained by generatePCRfrags: `candidates`
# accumulates the number of tries and `accepts` the number of fragments kept.
candidates = accepts = 0

def reverse_complement(sequence_record):
    """
    Custom version of reverse-complement that adjusts special annotations as well as sequence.

    The returned record carries a shallow copy of the input annotations in which
    'start' and 'end' are swapped, 'strand' is flipped and, when present,
    'cigar_list' is replaced by a new list in reversed order. The strand tag
    embedded in the record id ('[+]' or '[-]') is flipped as well.

    :param sequence_record: record with 'start', 'end' and 'strand' annotations
    :type sequence_record: Bio.SeqRecord
    :return: a new record; the input record is not modified
    :rtype: Bio.SeqRecord
    :raises KeyError: if 'start', 'end' or 'strand' is missing from the annotations
    """
    rc_seq = sequence_record.seq.reverse_complement()
    # Flip only the bracketed strand tag. Replacing every '+' or '-' in the id
    # would also rewrite hyphens or plus signs that belong to the transcript name.
    if '[+]' in sequence_record.id:
        rc_id = sequence_record.id.replace('[+]', '[-]')
    else:
        rc_id = sequence_record.id.replace('[-]', '[+]')
    rc = SeqRecord(rc_seq, rc_id)
    rc.annotations = sequence_record.annotations.copy()
    rc.annotations['start'] = sequence_record.annotations['end']
    rc.annotations['end'] = sequence_record.annotations['start']
    rc.annotations['strand'] = '-' if sequence_record.annotations['strand'] == '+' else '+'
    if "cigar_list" in rc.annotations:
        # Store a real list rather than the one-shot iterator returned by
        # reversed(): an iterator can be walked only once and cannot itself be
        # passed to reversed(), so reverse-complementing the result again failed.
        rc.annotations["cigar_list"] = list(rc.annotations["cigar_list"])[::-1]
    return rc

def fuzzySizeFilter(length, minimum, lower, upper, maximum):
    """Return True or False for acceptance of a fragment with given length.

    Accepts lengths between lower and upper (inclusive) with probability 1.
    The probability of acceptance decreases linearly from 1 at lower to 0 at
    minimum, and from 1 at upper to 0 at maximum. Lengths outside
    [minimum, maximum] are never accepted, and neither are lengths equal to
    minimum or maximum unless they also lie inside [lower, upper].

    One random number is drawn only when length falls on one of the two ramps.
    """
    if not (minimum <= length <= maximum):
        return False
    if lower <= length <= upper:
        return True
    if length < lower:
        # Distance above the hard minimum, so the probability rises toward lower.
        # minimum <= length < lower guarantees a non-zero denominator.
        fraction = float(length - minimum) / (lower - minimum)
    else:
        # upper < length <= maximum guarantees a non-zero denominator.
        fraction = float(maximum - length) / (maximum - upper)
    # Strict comparison: random.random() can return 0.0, which must not be
    # accepted when the acceptance probability is exactly zero.
    return random.random() < fraction


def introduceIndels(sequence_record, indel_rate):
    """Randomly double or delete nucleotides in a DNA sequence.

    sequence_record is a SeqRecord.
    indel_rate is the probability of insertion or deletion at each position;
    an indel is an insertion (the base is doubled) or a deletion (the base is
    dropped) with equal probability.

    The edits are recorded in annotations['cigar_list'] as a list of
    [length, operation] pairs, where operation is 'M' (bases kept), 'I'
    (inserted base) or 'D' (deleted base).

    Returns a new SeqRecord with a copy of the input annotations, except when
    indel_rate is 0: then the input record itself is returned after its
    annotations have been given a single all-'M' cigar_list.
    """
    if indel_rate == 0:
        sequence_record.annotations['cigar_list'] = [[len(sequence_record), 'M']]
        return sequence_record
    new_sequence = []
    cigar_list = []   # must be a list: segments are appended as they are found
    matched = 0       # length of the current run of unmodified bases
    for base in str(sequence_record.seq):
        if random.random() <= indel_rate:
            # Close the pending run of matches, skipping empty runs so that no
            # zero-length segment ends up in the cigar list.
            if matched > 0:
                cigar_list.append([matched, 'M'])
            matched = 0
            if random.random() < 0.5:
                # Insertion: keep the base (it starts the next match run) and
                # emit one extra copy of it.
                new_sequence.append(base)
                matched += 1
                new_sequence.append(base)
                cigar_list.append([1, 'I'])
            else:
                # Deletion: the base is left out of the new sequence.
                cigar_list.append([1, 'D'])
        else:
            new_sequence.append(base)
            matched += 1
    if matched > 0:
        cigar_list.append([matched, 'M'])
    # Seq is imported as the class itself (from Bio.Seq import Seq).
    new_record = SeqRecord(Seq(''.join(new_sequence), Alphabet.generic_dna), sequence_record.id)
    new_record.annotations = sequence_record.annotations.copy()
    new_record.annotations['cigar_list'] = cigar_list
    return new_record

def _cumulative_distribution(rates):
    """Return the running sum of rates, scaled so that the last entry is 1.0."""
    cumulative = []
    running_total = 0.0   # float accumulator: avoids integer division under Python 2
    for rate in rates:
        running_total += rate
        cumulative.append(running_total)
    total = cumulative[-1]
    return [value / total for value in cumulative]


def generatePCRfrags(frag_target, transcript, size_min, size_lower, size_upper, size_max, indel_rate=0,  max_tries=2000000):
    """Return a set of random, size-filtered fragments from a transcript.

    frag_target is the number of fragments wanted.
    transcript is a DNA SeqRecord whose letter_annotations
    'origin_probability_plus' and 'origin_probability_minus' give per-position
    weights used to draw the left and right fragment ends respectively.
    size_min, size_lower, size_upper, size_max are the parameters of the size filter
    (see fuzzySizeFilter).
    indel_rate is the probability of insertion or deletion at each position.
    max_tries is a positive integer to guard against infinite looping; every
    rejected or accepted draw counts as one try, and generation stops once the
    budget is used up, so fewer than frag_target fragments may be returned.

    Each fragment is annotated with 'transcript_id', 'strand', 'start', 'end'
    (0-based, inclusive positions in the transcript) and 'isize'. About half of
    the fragments are returned reverse-complemented.

    Side effects: prints a one-line summary and adds the number of tries and
    fragments to the module-level counters `candidates` and `accepts`.

    Returns a list of SeqRecords.
    """
    global candidates, accepts
    # Setup cumulative probability arrays
    cum_probs_plus = _cumulative_distribution(transcript.letter_annotations['origin_probability_plus'])
    cum_probs_minus = _cumulative_distribution(transcript.letter_annotations['origin_probability_minus'])

    # Generate fragments
    mrna = str(transcript.seq)
    transcript_length = len(transcript)
    last_index = len(cum_probs_minus) - 1
    annotation_common = {'transcript_id' : transcript.id, 'strand' : '+'}
    tries = 0
    frags = []
    while len(frags) < frag_target and tries < max_tries:
        # Choose a left end. A left end too close to the transcript end costs
        # one try and is re-drawn via the outer loop, so max_tries also bounds
        # this step (a transcript no longer than size_lower never yields one).
        initiate0 = bisect_left(cum_probs_plus, random.random())
        if initiate0 >= transcript_length - size_lower:
            tries += 1
            continue
        # Now choose a right end at an acceptable distance
        min_right = cum_probs_minus[initiate0 + size_min]
        size_range = cum_probs_minus[min(initiate0 + size_max - 1, last_index)] - min_right
        if size_range < 0.002:
            # Too little right-end probability mass in the allowed window.
            tries += 1
            continue
        # Re-draw the right end until the size filter accepts it or the try
        # budget runs out. tries < max_tries holds on entry, so at least one
        # draw is always made.
        accepted = False
        while not accepted and tries < max_tries:
            initiate1 = bisect_left(cum_probs_minus,  min_right + size_range * random.random())
            tries += 1
            accepted = fuzzySizeFilter(initiate1 - initiate0 + 1,  size_min, size_lower, size_upper, size_max)
        if not accepted or initiate1 < initiate0:
            continue
        # Format the fragment as a SeqRecord
        seq_fragment = mrna[initiate0:initiate1+1]
        annotations = annotation_common.copy()
        # NOTE(review): the start coordinate is the left end regardless of the
        # transcript's own strand annotation.
        annotations['start'] = initiate0
        annotations['end'] = initiate1
        annotations['isize'] = len(seq_fragment)
        fragment = SeqRecord(Seq(seq_fragment, Alphabet.generic_dna), id='%s[+]:%d..%d' % (transcript.id, initiate0, initiate1), annotations=annotations)
        if indel_rate > 0:
            fragment = introduceIndels(fragment, indel_rate)
        # Return either orientation of the fragment with equal probability.
        if random.random() < 0.5:
            frags.append(fragment)
        else:
            frags.append(reverse_complement(fragment))
    # Single parenthesised argument: prints identically under Python 2 and 3.
    print('{:n} fragments generated in {:n} tries.'.format(len(frags), tries))
    candidates += tries
    accepts += len(frags)
    return frags



def getReadPair(sequence, read_length):
    """Return [forward read, reverse read] taken from the two ends of sequence.

    The first read is the 5' read of sequence itself; the second is the 5' read
    of its reverse complement.
    """
    forward_read = get5primeRead(sequence, read_length)
    reverse_read = get5primeRead(reverse_complement(sequence), read_length)
    return [forward_read, reverse_read]

def get5primeRead(sequence, read_length):
    """Return the first read_length bases of sequence as an annotated read.

    sequence is a SeqRecord-like fragment whose annotations contain 'start' and
    'strand', and optionally 'cigar_list' (a sequence of [length, operation]
    pairs with operation 'M', 'I' or 'D').

    The read receives a copy of the fragment annotations with:
      * 'cigar_list' (if present) truncated so that it describes exactly the
        bases of the read. 'M' and 'I' segments use up read bases; 'D' segments
        do not, so they are kept without counting against the read length.
      * 'end' recomputed from 'start' and the span the read covers on the
        transcript (read bases in 'M' segments plus 'D' lengths, or simply the
        read length when there is no cigar_list). The span is added for strand
        '+' and subtracted otherwise.

    If the fragment is shorter than read_length, the actual read length is used.
    The fragment and its annotations are not modified.
    """
    read = sequence[:read_length]
    read.annotations = sequence.annotations.copy()
    query_length = len(read)          # may be less than read_length for short fragments
    reference_span = query_length
    if 'cigar_list' in read.annotations:
        read_cigar_list = []
        reference_span = 0
        consumed = 0                  # read bases accounted for so far
        for seg_length, operation in read.annotations['cigar_list']:
            if seg_length <= 0:
                continue
            if consumed >= query_length:
                break
            if operation == 'D':
                # A deletion removes transcript bases only; it uses no read bases.
                read_cigar_list.append([seg_length, operation])
                reference_span += seg_length
                continue
            used = min(seg_length, query_length - consumed)
            # Append a fresh pair so the read does not share segment lists
            # with the fragment.
            read_cigar_list.append([used, operation])
            consumed += used
            if operation != 'I':
                reference_span += used
        read.annotations['cigar_list'] = read_cigar_list
    # in GFF, feature length = feature end + 1 -feature start
    start = read.annotations['start']
    if read.annotations['strand'] == '+':
        read.annotations['end'] = start + reference_span - 1
    else:
        read.annotations['end'] = start - reference_span + 1
    return read

def generateReadSet(transcript, copy_number, read_length, size_min, size_lower, size_upper, size_max, paired=False, indel_rate=0):
    """Generate reads from size-filtered random fragments of transcript.

    Returns an empty list when the transcript is shorter than size_min.
    Otherwise up to copy_number fragments are generated (with a budget of
    10 * copy_number tries) and each one is turned into a read pair when paired
    is true, or into a single 5' read otherwise.
    """
    # Guard clause: a transcript shorter than size_min cannot yield a fragment.
    if len(transcript) < size_min:
        return []
    fragments = generatePCRfrags(
        copy_number, transcript, size_min, size_lower, size_upper, size_max,
        indel_rate, max_tries=10 * copy_number)
    make_reads = getReadPair if paired else get5primeRead
    return [make_reads(fragment, read_length) for fragment in fragments]

def generateUnpairedReadSet(transcript, copy_number, read_length, size_min, size_lower, size_upper, size_max, indel_rate=0):
    """Generate single-end reads: generateReadSet with paired switched off."""
    return generateReadSet(
        transcript, copy_number, read_length,
        size_min, size_lower, size_upper, size_max,
        paired=False, indel_rate=indel_rate)

def generatePairedReadSet(transcript, copy_number, read_length, size_min, size_lower, size_upper, size_max, indel_rate=0):
    """Generate paired-end reads: generateReadSet with paired switched on.

    Each fragment yields two reads, so only half of copy_number fragments are
    requested (rounded down).
    """
    # Explicit floor division keeps the fragment target an integer whether or
    # not true division is in effect.
    return generateReadSet(transcript, copy_number // 2, read_length, size_min, size_lower, size_upper, size_max, True, indel_rate)


#if __name__ == '__main__':
#    test_rec = SeqRecord.SeqRecord(Seq.Seq('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'),'Test')
#    test_rec.letter_annotations['fragmentation_probability'] = [0.2] * len(test_rec)
#    frags = cleaveSequence(test_rec)
#    for frag in frags:
#        if fuzzySizeFilter(len(frag), 2, 3, 4, 8):
#            print '+', str(frag.seq)
#        else:
#            print '-', str(frag.seq)
#    sized = [frag for frag in frags if fuzzySizeFilter(len(frag), 2, 3, 4, 8)]
#    print [str(f.seq) for f in sized]
#    reads = [getReadPair(frag, 2) for frag in sized]
#    print reads
#    print(sum(reads, []))

