
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

""" From a known set of transcripts, generate short reads starting at all positions for testing read assembly tools.

Created on Jun 7, 2010
@author: ian
"""

import random
from Bio import Alphabet
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

# Module-wide running totals, updated by generateFrags on every call:
# fragments produced by cleavage, and fragments kept for read generation.
candidates = accepts = 0


def reverse_complement(sequence_record):
    """Return a new SeqRecord holding the reverse complement of sequence_record.

    sequence_record is a SeqRecord whose annotations contain 'start', 'end'
    and 'strand' (a KeyError is raised otherwise).

    In the returned record:
      * the bracketed strand marker in the id is flipped ('[+]' <-> '[-]');
      * 'start' and 'end' are swapped and 'strand' is inverted;
      * 'cigar_list', when present, is replaced by a reversed *list* so that
        it reads in the same direction as the new sequence.
    The input record and its annotations are left untouched.
    """
    rc_seq = sequence_record.seq.reverse_complement()
    source_id = sequence_record.id
    # Flip only the '[+]'/'[-]' marker.  Replacing bare '+'/'-' characters
    # would also rewrite them inside the transcript name (e.g. 'my-gene').
    if '[+]' in source_id:
        rc_id = source_id.replace('[+]', '[-]')
    else:
        rc_id = source_id.replace('[-]', '[+]')
    rc = SeqRecord(rc_seq, rc_id)
    rc.annotations = sequence_record.annotations.copy()
    rc.annotations['start'] = sequence_record.annotations['end']
    rc.annotations['end'] = sequence_record.annotations['start']
    rc.annotations['strand'] = '-' if sequence_record.annotations['strand'] == '+' else '+'
    if "cigar_list" in rc.annotations:
        # reversed() alone yields a one-shot iterator; shallow copies of the
        # annotations would share it and find it exhausted after one pass.
        rc.annotations["cigar_list"] = list(reversed(rc.annotations["cigar_list"]))
    return rc


def cleaveSequence(transcript, frag_len):
    """Return every sliding-window fragment of a transcript, in order.

    transcript is a DNA SeqRecord; frag_len sets the window size.
    One forward-strand SeqRecord is produced for each offset in
    range(len(transcript) - frag_len).  Each fragment spans the inclusive
    coordinates offset..offset + frag_len, i.e. it holds frag_len + 1 bases.
    An empty list is returned when the transcript is not longer than frag_len.

    Every fragment is annotated with 'transcript_id', 'strand' ('+'),
    'start' and 'end', and its id is '<transcript id>[+]:<start>..<end>'.
    """
    # NOTE(review): windows are frag_len + 1 bases long (inclusive end);
    # confirm this is intended rather than an off-by-one.
    bases = str(transcript.seq)
    transcript_id = transcript.id
    window_bounds = ((offset, offset + frag_len) for offset in range(len(transcript) - frag_len))
    return [
        SeqRecord(
            Seq(bases[first:last + 1], Alphabet.generic_dna),
            id='%s[+]:%d..%d' % (transcript_id, first, last),
            annotations={'transcript_id': transcript_id, 'strand': '+', 'start': first, 'end': last}
        )
        for first, last in window_bounds
    ]


def introduceIndels(sequence_record, indel_rate):
    """Randomly double or delete nucleotides in a DNA sequence.

    sequence_record is a SeqRecord.
    indel_rate is the probability of insertion or deletion at each position;
    a selected position is doubled (insertion) or dropped (deletion) with
    equal probability.

    When indel_rate is 0 the input record itself is annotated in place with a
    single all-match 'cigar_list' and returned.  Otherwise a new SeqRecord is
    returned with the same id, a copy of the annotations and a 'cigar_list'
    of [length, op] pairs ('M', 'I' or 'D') describing the edits.
    """
    if indel_rate == 0:
        # Nothing to mutate: record a pure-match CIGAR on the input and reuse it.
        sequence_record.annotations['cigar_list'] = [[len(sequence_record), 'M']]
        return sequence_record
    new_sequence = []
    cigar_list = []
    matched = 0  # length of the current run of unmodified bases
    for base in str(sequence_record.seq):
        if indel_rate > 0 and random.random() <= indel_rate:
            # Close the pending match run, but never emit a zero-length 'M'
            # (happens for an indel at position 0 or right after a deletion).
            if matched > 0:
                cigar_list.append([matched, 'M'])
            matched = 0
            if random.random() < 0.5:
                # Insertion: emit the base twice; one copy counts as inserted,
                # the other starts the next match run.
                new_sequence.append(base)
                matched += 1
                new_sequence.append(base)
                cigar_list.append([1, 'I'])
            else:
                # Deletion: the base is left out of the new sequence.
                cigar_list.append([1, 'D'])
        else:
            new_sequence.append(base)
            matched += 1
    if matched > 0:
        cigar_list.append([matched, 'M'])
    # Seq is the class itself (from Bio.Seq import Seq); 'Seq.Seq' would raise
    # AttributeError.
    new_record = SeqRecord(Seq(''.join(new_sequence), Alphabet.generic_dna), sequence_record.id)
    new_record.annotations = sequence_record.annotations.copy()
    new_record.annotations['cigar_list'] = cigar_list
    return new_record


def generateFrags(transcript, size_min, size_lower=200, size_upper=250, size_max=300, indel_rate=0):
    """Return all sliding-window fragments of a transcript on both strands.

    transcript is a DNA SeqRecord; size_min is handed to cleaveSequence as
    the fragment size.
    size_lower, size_upper, size_max and indel_rate are accepted but not used
    by the current implementation: no size filtering is applied and
    introduceIndels is always called with a rate of 0.

    Each forward fragment gets an 'isize' annotation (its length); the
    reverse complements are appended after all forward fragments, and every
    record ends up with a 'cigar_list' annotation.
    The module counters `candidates` and `accepts` are both increased by the
    number of forward fragments.
    Returns a list of SeqRecords.
    """
    global candidates, accepts
    forward = cleaveSequence(transcript, size_min)
    forward_count = len(forward)
    candidates += forward_count
    # Every candidate is accepted, so both counters grow together.
    accepts += forward_count
    for record in forward:
        record.annotations['isize'] = len(record)
    both_strands = forward + [reverse_complement(record) for record in forward]
    return [introduceIndels(record, 0) for record in both_strands]


def getReadPair(sequence, read_length):
    """Return a two-element list of reads taken from both ends of a fragment.

    The first read is the leading read_length bases of sequence; the second
    is the leading read_length bases of its reverse complement.
    """
    first_read = get5primeRead(sequence, read_length)
    mate_read = get5primeRead(reverse_complement(sequence), read_length)
    return [first_read, mate_read]


def get5primeRead(sequence, read_length):
    """Return the first read_length bases of a fragment as a read.

    sequence is a SeqRecord whose annotations contain 'start' and 'strand'.
    The read is the slice sequence[:read_length] and carries a shallow copy
    of the fragment's annotations in which:
      * 'end' is start + read_length - 1 on the '+' strand and
        start - read_length + 1 otherwise;
      * 'cigar_list', when present, is truncated to the segments covering
        the first read_length bases of the read.
    The fragment's own annotations and CIGAR segments are not modified.
    """
    read = sequence[:read_length]
    read.annotations = sequence.annotations.copy()
    start = read.annotations['start']
    # NOTE(review): 'end' assumes the read is exactly read_length bases long
    # and ignores indels; confirm whether it should follow the CIGAR instead.
    if read.annotations['strand'] == '+':
        read.annotations['end'] = start + read_length - 1
    else:
        read.annotations['end'] = start - read_length + 1
    if 'cigar_list' in read.annotations:
        read_cigar_list = []
        consumed = 0  # read bases accounted for so far
        for segment in read.annotations['cigar_list']:
            if consumed >= read_length:
                # Read fully covered; also keeps the CIGAR from ending in 'D'.
                break
            length, op = segment[0], segment[1]
            if op == 'D':
                # A deletion removes bases from the read, so it uses up none
                # of the read_length budget.
                read_cigar_list.append([length, op])
                continue
            take = min(length, read_length - consumed)
            if take > 0:
                read_cigar_list.append([take, op])
                consumed += take
        read.annotations['cigar_list'] = read_cigar_list
    return read


def generateReadSet(transcript, copy_number, read_length, size_min, size_lower, size_upper, size_max, paired=False,
                    indel_rate=0):
    """Generate reads from copy_number rounds of fragmenting a transcript.

    Each round calls generateFrags with the size and indel parameters and
    turns every fragment into a read of read_length bases.
    With paired true each list element is the two-element list returned by
    getReadPair; otherwise it is the single read from get5primeRead.
    Returns the reads of all rounds concatenated in one list.
    """
    collected = []
    for _ in range(copy_number):
        fragments = generateFrags(transcript, size_min, size_lower, size_upper, size_max, indel_rate)
        make_read = getReadPair if paired else get5primeRead
        collected += [make_read(fragment, read_length) for fragment in fragments]
    return collected


def generateUnpairedReadSet(transcript, copy_number, read_length, size_min, size_lower, size_upper, size_max,
                            indel_rate=0):
    """Generate single-end reads: generateReadSet with paired switched off."""
    return generateReadSet(
        transcript, copy_number, read_length,
        size_min, size_lower, size_upper, size_max,
        paired=False, indel_rate=indel_rate)


def generatePairedReadSet(transcript, copy_number, read_length, size_min, size_lower, size_upper, size_max,
                          indel_rate=0):
    """Generate paired-end reads: generateReadSet with paired switched on."""
    return generateReadSet(
        transcript, copy_number, read_length,
        size_min, size_lower, size_upper, size_max,
        paired=True, indel_rate=indel_rate)


