# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

'''Output a set of simulated RNA-Seq short reads from a set of transcripts.

Positional arguments (in order):
gene_models.gff3
gene_id-copy_numbers.txt
origin_prob_dict.shelf
output_filename_base
probability_files_directory
genome.fa

Created on Oct 8, 2009
@author: ian
Modified to use Stranded_Read_Initiation_Probability_Vectors 2012-12-07

The relative probability of a substitution error at each position in a read was estimated empirically
from counts of correct and incorrect bases in a set of reads mapped to a genome, and the relative probability
of each quality code for correct and incorrect bases at each position was estimated from the same data.
'''
import sys
import logging
import os
import traceback
import tempfile
from gzip import GzipFile
from collections import defaultdict
import random
from optparse import OptionParser
import shelve
from bisect import bisect

from Bio import SeqIO, SeqRecord, Seq, Alphabet
from Bio.Seq import reverse_complement


this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
from lib.transcript_coords import transcript_coords
from lib.samText import SAMText as SAMRead
from lib.gff3Iterator import GFF3Iterator
from lib.repackageSimulation import doRepackageSimulation
from lib import generateShortReads
from calcTranscriptFragmentationProbabilities import Stranded_Read_Initiation_Probability_Vector

MAX_ILLUMINA_QUALITY = 62
# Quality-character tables indexed by quality value q (0..MAX_ILLUMINA_QUALITY).
# Illumina encoding uses an ASCII offset of 64 ('@' == 0); Phred+33 uses 33 ('!' == 0).
illumina_quality_code = list(map(chr, range(64, 64 + MAX_ILLUMINA_QUALITY + 1)))
phred_quality_code = list(map(chr, range(33, 33 + MAX_ILLUMINA_QUALITY + 1)))
# Translate an Illumina-offset quality character to the Phred+33 character of the same value.
illumina2phred_quality_code = {ill_char: phred_char
                               for ill_char, phred_char in zip(illumina_quality_code, phred_quality_code)}


def _read_last_column_floats(path):
    """Return the last whitespace-separated field of every line of *path*, as floats."""
    with open(path) as handle:
        return [float(line.split()[-1]) for line in handle]


def _read_float_rows(path):
    """Return every line of *path* as a list of floats (one list per line)."""
    with open(path) as handle:
        return [[float(field) for field in line.split()] for line in handle]


def load_probability_tables(probability_files_dir):
    """Load the per-position error and quality probability tables used by the simulator.

    Files are read from *probability_files_dir*:

    - error_probabilities.txt / BBB.error_probabilities.txt / BBB.init_probabilities.txt:
      one value per line (the last column is used), one line per read position.
    - Correct_reads.quality_scores.cumulative_frequency.txt /
      Incorrect_reads.quality_scores.cumulative_frequency.txt: one row of cumulative
      frequencies per read position.

    :param probability_files_dir: directory containing the five table files
    :return: (error_prob, BBB_error_prob, BBB_init_prob, correct_qual_freq, incorrect_qual_freq)
    :raises IOError: if a file is missing or unreadable
    :raises ValueError: if a field cannot be parsed as a float
    """
    def table_path(name):
        return os.path.join(probability_files_dir, name)

    # Each file is opened inside a context manager so no handle leaks when a later file is
    # missing or a line fails to parse.
    error_prob = _read_last_column_floats(table_path('error_probabilities.txt'))
    BBB_error_prob = _read_last_column_floats(table_path('BBB.error_probabilities.txt'))
    BBB_init_prob = _read_last_column_floats(table_path('BBB.init_probabilities.txt'))
    correct_qual_freq = _read_float_rows(table_path('Correct_reads.quality_scores.cumulative_frequency.txt'))
    incorrect_qual_freq = _read_float_rows(table_path('Incorrect_reads.quality_scores.cumulative_frequency.txt'))
    return error_prob, BBB_error_prob, BBB_init_prob, correct_qual_freq, incorrect_qual_freq


def gene_copy_iterator(filename):
    """Yield (gene_id, copy_number) pairs from a tab-separated copy-number file.

    The gene ID is the first column and the integer copy number the second.
    Lines with no second column, or whose second column is not an integer
    (e.g. a header line), are skipped. Extra columns are ignored.

    :param filename: path to the copy-number file
    :return: generator of (str, int) tuples
    """
    # The context manager closes the file when the generator is exhausted or closed.
    with open(filename) as lines:
        for line in lines:
            fields = line.strip().split('\t')
            try:
                gene_id = fields[0]
                copies = int(fields[1])
            except (IndexError, ValueError):
                continue
            yield gene_id, copies


def introduceSubstitutions(sequence_record, error_prob, BBB_error_prob, BBB_init_prob, correct_qual_freq,
                           incorrect_qual_freq, N_rate):
    """Return a copy of a read with simulated substitution errors and quality characters.

    Positions are processed left to right. At each position the read may switch into
    the "BBB" state with probability BBB_init_prob[pos]; once entered, the state lasts
    for the rest of the read. A substitution happens with probability
    BBB_error_prob[pos] in the BBB state and error_prob[pos] otherwise. A substitution
    becomes 'N' with probability N_rate; otherwise the base is replaced by a different
    base drawn uniformly from A/C/G/T.

    Quality 2 is the Read Segment Quality Control Indicator ("BBB..." in Illumina
    encoding). It is assigned to every 'N' and to every base while in the BBB state.
    An 'N' does not itself force the BBB state. Other bases draw a quality by
    bisecting the per-position cumulative-frequency row: incorrect_qual_freq for
    substituted bases, correct_qual_freq otherwise.

    :param sequence_record: read to copy; its annotations dict is shared (not copied) with the result
    :param N_rate: probability that a substitution produces 'N'
    :return: new SeqRecord whose letter_annotations['phred_quality'] holds a string of
        Illumina-offset (chr(q + 64)) quality characters rather than Phred integers
    """
    bases = list(str(sequence_record.seq))
    qual_chars = []
    in_bbb_state = False
    for pos, original_base in enumerate(bases):
        if not in_bbb_state:
            in_bbb_state = random.random() <= BBB_init_prob[pos]
        error_p = BBB_error_prob[pos] if in_bbb_state else error_prob[pos]
        substituted = random.random() <= error_p
        if not substituted:
            if in_bbb_state:
                quality = 2
            else:
                quality = bisect(correct_qual_freq[pos], random.random())
        elif random.random() <= N_rate:
            bases[pos] = 'N'
            quality = 2
        else:
            replacement = random.choice('ACGT')
            while replacement == original_base:
                replacement = random.choice('ACGT')
            bases[pos] = replacement
            if in_bbb_state:
                quality = 2
            else:
                quality = bisect(incorrect_qual_freq[pos], random.random())
        qual_chars.append(illumina_quality_code[quality])

    mutated = SeqRecord.SeqRecord(Seq.Seq(''.join(bases), Alphabet.generic_dna), sequence_record.id)
    mutated.letter_annotations['phred_quality'] = ''.join(qual_chars)
    mutated.annotations = sequence_record.annotations
    return mutated


def estimate_substitution_error_rate(read_length, error_prob_vector, BBB_error_prob_vector, BBB_init_prob_vector):
    """Return the expected per-base substitution error rate over the first *read_length* positions.

    The probability of being in the BBB state at position i is tracked recursively:
    bbb_i = bbb_{i-1} + BBB_init[i] * (1 - bbb_{i-1}). This mirrors introduceSubstitutions,
    where the BBB state is entered before the error draw at that position. The
    expected error at position i mixes the normal and BBB error probabilities by that
    state probability.

    :param read_length: number of read positions to average over (the vectors must be at least this long)
    :return: mean expected error rate as a float; 0.0 when read_length <= 0
    """
    if read_length <= 0:
        return 0.0
    total_error = 0.0
    bbb_rate = 0.0
    for i in range(read_length):
        bbb_rate += BBB_init_prob_vector[i] * (1 - bbb_rate)
        total_error += (1 - bbb_rate) * error_prob_vector[i] + bbb_rate * BBB_error_prob_vector[i]
    # Divide by the read_length parameter; the previous code used the global options.read_length.
    return total_error / float(read_length)


def adjust_error_rate(target_rate, read_length, error_prob_vector, BBB_error_prob_vector, BBB_init_prob_vector):
    """Scale the substitution probability vectors so the expected error rate equals *target_rate*.

    Both the normal and the BBB error vectors are multiplied by the same factor,
    target_rate / current_rate. The BBB initiation vector is returned unchanged.
    The factor and the resulting rate are logged.

    :param target_rate: desired mean per-base substitution rate
    :param read_length: number of read positions used for the estimate
    :return: (scaled error_prob_vector, scaled BBB_error_prob_vector, BBB_init_prob_vector)
    :raises ValueError: if target_rate is non-zero but the empirical error profile has zero
        expected error, so no scaling can reach the target
    """
    err_rate = estimate_substitution_error_rate(read_length, error_prob_vector, BBB_error_prob_vector,
                                                BBB_init_prob_vector)
    if target_rate == 0:
        # A zero target disables substitutions regardless of the empirical rate (and avoids 0/0).
        adjustment_factor = 0.0
    elif err_rate > 0:
        adjustment_factor = float(target_rate) / err_rate
    else:
        raise ValueError('Cannot scale a zero substitution error profile to target rate %g' % target_rate)
    logging.info('Error rate adjustment factor = %.3f' % adjustment_factor)
    error_prob_vector = [e * adjustment_factor for e in error_prob_vector]
    BBB_error_prob_vector = [e * adjustment_factor for e in BBB_error_prob_vector]
    adjusted_err_rate = estimate_substitution_error_rate(read_length, error_prob_vector, BBB_error_prob_vector,
                                                         BBB_init_prob_vector)
    logging.info('Adjusted error rate  = %.3f' % adjusted_err_rate)
    return error_prob_vector, BBB_error_prob_vector, BBB_init_prob_vector


def generateSam(read, coords):
    """Build a SAM record giving the true genomic alignment of a simulated read.

    The read's transcript-relative coordinates (annotations 'start' / 'end') are
    converted to genome coordinates with ``coords``. If the genomic span is at least
    as long as the read, any donor/acceptor pairs reported by ``coords`` are spliced
    into the CIGAR as 'N' operations.

    :param read: simulated read. Its annotations must contain 'cigar_list' (a list of
        [length, op] pairs), 'start', 'end' and 'transcript_id'.
        letter_annotations['phred_quality'] must hold Illumina-offset quality
        characters (keys of illumina2phred_quality_code); they are converted to
        Phred+33 here.
    :type read: Bio.SeqRecord.SeqRecord
    :param coords: transcript-to-genome coordinate mapper
    :type coords: transcript_coords
    :return: SAM record with rname, pos, seq, qual, cigar, the reverse-strand flag when
        applicable, and an XS:A tag for spliced reads. For spliced reads,
        read.annotations['cigar_list'] is also replaced with the spliced CIGAR.
    :rtype: SAMRead
    :raises ValueError: if splicing would produce a negative CIGAR segment length
    """
    # Shallow copy: the inner [length, op] lists are still shared with
    # read.annotations['cigar_list'], so the in-place edits below modify them too.
    cigar_list = read.annotations['cigar_list'][:]
    sam = SAMRead(read.id)
    start = int(read.annotations['start'])
    end = int(read.annotations['end'])
    gene_id = read.annotations['transcript_id']
    # IDs with more than two '|'-separated fields are mapped to 'mRNA' + the third field
    # (presumably to match the IDs registered in coords -- verify against the GFF3 IDs).
    if len(gene_id.split('|')) > 2:
        transcript_id = 'mRNA' + gene_id.split('|')[2]
    else:
        transcript_id = gene_id
    # get_genome_coord appears to return (reference name, position): [0] is used as RNAME, [1] as POS.
    genome_start = coords.get_genome_coord(transcript_id, start)
    genome_end = coords.get_genome_coord(transcript_id, end)
    # The read lies on the reverse strand when its genomic end precedes its genomic start.
    strand = '-' if genome_end[1] < genome_start[1] else '+'
    sam.rname = genome_start[0]
    sam.pos = min(genome_start[1], genome_end[1])
    sam.seq = str(read.seq)
    # Re-encode Illumina-offset quality characters as Phred+33 for SAM output.
    sam.qual = ''.join([illumina2phred_quality_code[q] for q in read.letter_annotations['phred_quality']])
    trans_coords = coords.get_scaffold_coords(transcript_id)
    transcript_strand = trans_coords.strand
    if strand == '-':
        # Reverse-strand reads are stored reverse-complemented, with reversed qualities.
        sam.seq = reverse_complement(sam.seq)
        sam.qual = sam.qual[::-1]
        sam.set_reverse_strand_flag()
    # A genomic span at least as long as the read suggests the read crosses one or more introns.
    if abs(genome_start[1] - genome_end[1]) >= len(read):
        introns = coords.get_donor_acceptor_pairs(transcript_id, start, end)
        if introns:
            for donor, acceptor in introns:
                # ex1 appears to be the number of read bases before this intron, measured in
                # spliced (transcript) coordinates from the alignment start.
                ex1 = abs(min(trans_coords.get_spliced_coord(donor),
                              trans_coords.get_spliced_coord(acceptor)) - trans_coords.get_spliced_coord(sam.pos)) + 1
                if transcript_strand == '-':
                    ex1 -= 1
                gap = abs(acceptor - donor) - 1
                # NOTE(review): ex2 is computed but never used.
                ex2 = len(read) - ex1
                # Find the CIGAR segment in which the intron boundary falls; only read-consuming
                # ops ('M' and 'I') advance the cursor.
                cursor = 0
                for i in range(len(cigar_list)):
                    if cigar_list[i][1] in 'MI':
                        if cursor + cigar_list[i][0] > ex1:
                            break
                        cursor += cigar_list[i][0]
                # NOTE(review): if cigar_list is empty, the loop above never binds i.
                # Split segment i at the boundary: new_m bases before the intron, new_m2 after it.
                new_m = ex1 - cursor # adjustment for different coordinate systems
                new_m2 = cursor + cigar_list[i][0] - ex1
                if new_m < 0 or new_m2 < 0:
                    raise ValueError('Negative segment length in SAM cigar: ' + str(read) + ' ' + str(coords._dict))
                letter = cigar_list[i][1]
                cigar_list[i][0] = new_m
                # Both branches leave the list as [..., segment i, gap N, remainder of segment i, ...].
                if i + 1 < len(cigar_list):
                    cigar_list.insert(i + 1, [new_m2, letter])
                    cigar_list.insert(i + 1, [gap, 'N'])
                else:
                    cigar_list.append([gap, 'N'])
                    cigar_list.append([new_m2, letter])
            # The XS strand tag is derived from the last intron processed only.
            intron_dir = '+' if acceptor > donor else '-'
            sam.tags.append('XS:A:%s' % intron_dir)
            filtered_cigar = [s for s in cigar_list if s[0] > 0] # remove zero-length segments
            i = 0
            while i < len(filtered_cigar) - 1:
                if filtered_cigar[i][1] == filtered_cigar[i + 1][1]:
                    filtered_cigar[i][0] += filtered_cigar[i + 1][0] # collapse adjacent segments with same opcode
                    del filtered_cigar[i + 1]
                else:
                    i += 1
            # Strip leading non-'M' segments, shifting POS past them.
            # NOTE(review): `i > len(filtered_cigar) - 2` can never hold while the loop condition
            # `i < len(filtered_cigar) - 1` does, so trailing non-'M' segments are never removed.
            # Also, POS is advanced for any leading op, including 'I', which does not consume the
            # reference -- confirm intent.
            i = 0
            while i < len(filtered_cigar) - 1:
                if filtered_cigar[i][1] != 'M' and (i == 0 or i > len(filtered_cigar) - 2):
                    sam.pos += filtered_cigar[i][0]
                    del filtered_cigar[i]
                else:
                    i += 1
            read.annotations['cigar_list'] = filtered_cigar
    sam.cigar = ''.join(['%d%s' % (length, letter) for length, letter in read.annotations['cigar_list']])
    return sam


def get_transcript_record(transcript, genome_index, options, origin_prob_dict):
    """ Create a Bio.SeqRecord corresponding to a transcript isoform and annotate it with info
    needed for simulated read generation

    The record's sequence is upper-cased, and the record gets the transcript's ID,
    name and strand (annotations['strand']). Per-base read-origin probabilities are
    attached as letter_annotations 'origin_probability_plus' and
    'origin_probability_minus', relative to the transcript sequence. For '-' strand
    transcripts the stored vectors are reversed and their strands swapped. Stored
    vectors shorter than the transcript are zero-padded.

    :param transcript: isoform to extract
    :type transcript: GFF3mRNA
    :param genome_index: genome sequences used to extract the transcript sequence
    :type genome_index: Bio.SeqIO.dict
    :param options: parsed command-line options (only size_upper is read)
    :type options: Namespace
    :param origin_prob_dict: origin probability vectors keyed by transcript ID
    :type origin_prob_dict: shelve
    :return: the annotated record, or None (after reporting why) if the transcript is
        shorter than options.size_upper, has no origin vector, has total plus-strand
        origin probability below 0.05, or the vector cannot be sliced
    :rtype: SeqRecord or None
    """
    transcript_rec = transcript.get_transcript_sequence(genome_index)
    transcript_rec.seq = transcript_rec.seq.upper()
    transcript_rec.id = transcript.get_ID()
    transcript_rec.name = transcript.get_name()
    transcript_rec.annotations['strand'] = transcript.get_strand()
    seq_len = len(transcript_rec)

    if seq_len < options.size_upper:
        message = '%s is too short: %d < %d' % (transcript_rec.id, seq_len, options.size_upper)
        sys.stderr.write(message + '\n')
        logging.warning(message)  # logging.warn is an obsolete alias
        return None

    if transcript_rec.id not in origin_prob_dict:
        sys.stderr.write('No origin probability vector found for %s\n' % transcript_rec.id)
        sys.stderr.write(str(transcript) + '\n')
        logging.error('No origin probability vector found for ' + transcript_rec.id)
        return None
    origin_rate_vector = origin_prob_dict[transcript_rec.id]

    if sum(origin_rate_vector.plus) < 0.05:
        logging.warning('Zero origin probability for ' + transcript_rec.id)
        return None

    # Zero-pad short vectors so every transcript position has an origin probability.
    # (Objects read from a non-writeback shelve are copies, so this does not modify the shelf.)
    if len(origin_rate_vector) < seq_len:
        origin_rate_vector.plus.extend([0] * (seq_len - len(origin_rate_vector.plus)))
        origin_rate_vector.minus.extend([0] * (seq_len - len(origin_rate_vector.minus)))
    try:
        if transcript.get_strand() == '-':
            # Stored vectors are presumably in genome orientation: reverse and swap strands
            # so they run 5'->3' along the transcript.
            transcript_rec.letter_annotations['origin_probability_plus'] = origin_rate_vector.minus[::-1][:seq_len]
            transcript_rec.letter_annotations['origin_probability_minus'] = origin_rate_vector.plus[::-1][:seq_len]
        else:
            transcript_rec.letter_annotations['origin_probability_plus'] = origin_rate_vector.plus[:seq_len]
            transcript_rec.letter_annotations['origin_probability_minus'] = origin_rate_vector.minus[:seq_len]
    except TypeError as te:
        sys.stderr.write('%s\n' % te)
        sys.stderr.write('%s transcript length: %d origin_rate_vector length: %d\n' % (
            transcript_rec.id, seq_len, len(origin_rate_vector)))
        logging.error('origin_rate_vector length error for %s' % transcript_rec.id)
        return None
    return transcript_rec


def load_transcripts(gene_filename):
    """Make lists of isoforms for each gene

    Genes are keyed by gene.get_name(); each value lists that gene's transcripts in
    the order GFF3Iterator yields them. Genes sharing a name are merged into one list.

    :param gene_filename: Name of a .gff3 file describing genes, transcripts, and exons
    :type gene_filename: String
    :return: transcript_map
    :rtype: defaultdict(list)
    """
    transcript_map = defaultdict(list)
    # The file is fully consumed inside the context manager, so the handle is always closed.
    with open(gene_filename) as gff3_handle:
        for gene in GFF3Iterator(gff3_handle).genes():
            transcript_map[gene.get_name()].extend(gene.get_transcripts())
    return transcript_map


def add_sam_pair_info(sam_pair, insert_size):
    """Fill in mate fields and pair flags on the two SAM records of one read pair.

    The mate-reverse flag goes on the record whose mate is reversed: if the first
    record is reversed, the second gets the flag; otherwise the first gets it (the
    second is assumed to be reversed). Each record's mate reference and mate
    position are copied from the other. The first/second-of-pair flags are set. The
    leftmost record gets a positive insert size and the other a negative one; on a
    position tie, the first record gets the negative value.

    :param sam_pair: two-element sequence of SAM records, modified in place
    :param insert_size: absolute template length
    :return: the same sam_pair object
    """
    first, second = sam_pair[0], sam_pair[1]
    mate_of_reversed = second if first.is_reversed() else first
    mate_of_reversed.set_mate_reverse_strand_flag()

    first.mrnm, second.mrnm = second.rname, first.rname
    first.mpos, second.mpos = second.pos, first.pos

    first.set_first_of_pair_flag()
    second.set_second_of_pair_flag()

    if first.pos >= second.pos:
        first.isize, second.isize = -insert_size, insert_size
    else:
        first.isize, second.isize = insert_size, -insert_size
    return sam_pair


def generate_SAM_header_from_fai(fai_filename):
    """Build SAM @SQ header lines from a FASTA index (.fai) file.

    Only the first two whitespace-separated columns of each line are used: the
    reference name (SN) and its length (LN). References keep their order from the
    index file. Blank lines are skipped.

    :param fai_filename: path to the .fai index
    :return: newline-joined @SQ lines (no trailing newline)
    """
    header_lines = []
    with open(fai_filename) as fai:
        for line in fai:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            ref_name, ref_length = fields[0], fields[1]
            header_lines.append('@SQ\tSN:%s\tLN:%s' % (ref_name, ref_length))
    return '\n'.join(header_lines)


def generate_SAM_header_from_fasta(fasta_filename):
    """Build SAM @SQ header lines (SN = record ID, LN = sequence length) from a FASTA file.

    Unlike the .fai-based variant, references are emitted sorted by name.

    :param fasta_filename: path to the genome FASTA
    :return: newline-joined @SQ lines (no trailing newline)
    """
    genome_index = SeqIO.index(fasta_filename, 'fasta')
    sq_lines = ('@SQ\tSN:%s\tLN:%s' % (name, len(genome_index[name]))
                for name in sorted(genome_index.keys()))
    return '\n'.join(sq_lines)


def get_command_line(argv=None):
    """Parse the simulator's command-line options and six positional arguments.

    :param argv: argument list to parse, excluding the program name; None (the default)
        means sys.argv[1:]
    :return: (options, args), where args holds the six positional arguments
    :raises SystemExit: via OptionParser.error (status 2) when the positional-argument count is wrong
    """
    usage = "Usage: %prog  [options] gene_models.gff3 gene_id-copy_numbers.txt origin_prob_dict.shelf " \
            "output_filename_base probability_files_directory genome.fa"
    parser = OptionParser(usage=usage)
    parser.add_option("-l", "--length", dest="read_length", type='int', default=38,
                      help="length of a read; default: %default")
    parser.add_option("-1", "--single", dest="pairs", action="store_false", default=True,
                      help="single or paired reads; default: paired")
    parser.add_option("-2", "--pair", dest="pairs", action="store_true",
                      help="generate paired reads (the default)")
    parser.add_option("--minsize", dest="size_min", type='int', default=150,
                      help="lower bound for size filter; default: %default")
    parser.add_option("--lowsize", dest="size_lower", type='int', default=175,
                      help="lower end of pass range for size filter; default: %default")
    parser.add_option("--highsize", dest="size_upper", type='int', default=250,
                      help="upper end of pass range for size filter; default: %default")
    parser.add_option("--maxsize", dest="size_max", type='int', default=300,
                      help="upper bound for size filter; default: %default")
    parser.add_option("-i", "--indelrate", dest="indel_rate", type='float', default=0,
                      help="probability of an indel error at any position; default: %default")
    parser.add_option("-s", "--subrate", dest="subst_rate", type='float', default=0,
                      help="probability of a substitution error at any position; default: %default")
    parser.add_option("-N", "--Nrate", dest="N_rate", type='float', default=0,
                      help="probability that a substitution will introduce an N; default: %default")
    options, args = parser.parse_args(argv)
    if len(args) != 6:
        # OptionParser.error prints the usage and exits (status 2); nothing after it runs.
        parser.error("Wrong number of arguments")
    return options, args


if __name__ == '__main__':
    try:
        options, args = get_command_line()
        transcript_map = load_transcripts(args[0])
        # Pre-count copies so every read (or read pair) can get a unique, randomly ordered name.
        total_copies = sum(copy_number for gene_id, copy_number in gene_copy_iterator(args[1]))
        if options.pairs:
            # Both mates of a pair share one name, so half as many names are needed
            # (assumes copy numbers count individual reads -- TODO confirm).
            total_copies //= 2
        # Fresh iterator for the main simulation loop below (the one above has been consumed).
        gene_copies = gene_copy_iterator(args[1])
        origin_prob_dict = shelve.open(args[2], 'r')

        output_filename_base = os.path.abspath(args[3])
        output_file_name = output_filename_base + '.true_mappings.sam'
        output = open(output_file_name, 'w')
        probability_files_dir = os.path.abspath(args[4])
        error_prob, BBB_error_prob, BBB_init_prob, correct_qual_freq, incorrect_qual_freq = load_probability_tables(
            probability_files_dir)

        genome_seq = args[5]
        # SeqIO.to_dict reads every record into memory, so the handle can be closed right away.
        with open(genome_seq) as genome_handle:
            genome_index = SeqIO.to_dict(SeqIO.parse(genome_handle, 'fasta'))

        if options.read_length > len(error_prob):
            raise ValueError('Error probability vector must be at least as long as reads')
        if options.size_min <= options.read_length:
            raise ValueError('Read length must be less than minimum fragment length')
        if options.pairs and options.size_lower < 2 * options.read_length:
            sys.stderr.write('''
            WARNING: Read length is too close to target fragment length!
             Some read pairs will have negative insert size.
             ''' + '\n')
    except Exception as e:
        sys.stderr.write(str(e) + '\n')
        sys.exit(1)

    LOG_FILENAME = output_filename_base + '_simulateRNA-Seq.log'
    logging.basicConfig(filename=LOG_FILENAME, level=logging.DEBUG, filemode='w')

    logging.info('read_length = %d' % options.read_length)
    logging.info('pairs = %s' % options.pairs)
    logging.info('size_min = %d' % options.size_min)
    logging.info('size_lower = %d' % options.size_lower)
    logging.info('size_upper = %d' % options.size_upper)
    logging.info('size_max = %d' % options.size_max)
    logging.info('indel_rate = %f' % options.indel_rate)
    logging.info('subst_rate = %f' % options.subst_rate)
    logging.info('N_rate = %f' % options.N_rate)
    logging.info('transcripts file = %s' % os.path.abspath(args[0]))
    logging.info('transcript copy numbers file = %s' % os.path.abspath(args[1]))
    logging.info('predicted total number of %s = %d' % ('read pairs' if options.pairs else 'reads', total_copies))
    logging.info('origin probability file = %s' % os.path.abspath(args[2]))
    logging.info('output file = %s' % output_file_name)
    logging.info('error probability file = %s' % os.path.join(probability_files_dir, 'error_probabilities.txt'))
    logging.info('BBB error probability file = %s' % os.path.join(probability_files_dir, 'BBB.error_probabilities.txt'))
    # Probability of converting to BBB state if the current quality is 2 and the current base is not N
    logging.info('BBB conditional transition probability file = %s' % os.path.join(probability_files_dir,
                                                                                   'BBB.init_probabilities.txt'))
    logging.info('Correct qualities file = %s' % os.path.join(probability_files_dir,
                                                              'Correct_reads.quality_scores.cumulative_frequency.txt'))
    # Log the same path load_probability_tables actually opens (capital 'I').
    logging.info('Incorrect qualities file = %s' % os.path.join(probability_files_dir,
                                                                'Incorrect_reads.quality_scores.cumulative_frequency'
                                                                '.txt'))
    logging.info('\n')

    logging.info('\tReads\tID\tLength')

    error_prob, BBB_error_prob, BBB_init_prob = adjust_error_rate(options.subst_rate, options.read_length, error_prob,
                                                                  BBB_error_prob, BBB_init_prob)

    fai_path = genome_seq + '.fai'
    if os.path.isfile(fai_path) and os.path.getsize(fai_path) > 0:
        header = generate_SAM_header_from_fai(fai_path)
    else:
        header = generate_SAM_header_from_fasta(genome_seq)
    output.write(header + '\n')

    # Random permutation of "1".."total_copies" used as read names (QNAMEs).
    # NOTE(review): if a gene yields more reads than its copy number (e.g. through isoform
    # rounding), codes[code_index] raises IndexError and the rest of that gene is skipped.
    codes = [str(i) for i in range(1, total_copies + 1)]
    random.shuffle(codes)

    coords = transcript_coords()
    total = 0
    code_index = 0
    for gene_id, copy_number in gene_copies:
        try:
            sys.stdout.write('%s %s\n' % (gene_id, copy_number))
            if gene_id not in transcript_map:
                sys.stderr.write('No transcript found for %s\n' % gene_id)
                logging.error('No transcript found for ' + gene_id)
                continue
            isoform_list = transcript_map[gene_id]
            # When isoforms carry a 'depth' attribute, split the gene's copies among them in
            # proportion to depth (rounded to the nearest integer).
            if len(isoform_list) > 1 and isoform_list[0].has_attribute('depth'):
                depths = [float(isoform.get_attribute('depth')) for isoform in isoform_list]
                total_height = sum(depths)
                isoform_copies = [int(0.5 + copy_number * depth / total_height) for depth in depths]
            else:
                isoform_copies = [copy_number]
            for transcript, isoform_copy_number in zip(isoform_list, isoform_copies):
                transcript_rec = get_transcript_record(transcript, genome_index, options, origin_prob_dict)
                if transcript_rec is None:
                    continue
                coords.add(transcript)
                if isoform_copy_number <= 0:
                    continue
                if options.pairs:
                    readpairset = generateShortReads.generatePairedReadSet(transcript_rec, isoform_copy_number,
                                                                           options.read_length, options.size_min,
                                                                           options.size_lower, options.size_upper,
                                                                           options.size_max, options.indel_rate)
                    i = 0
                    logging.info('\t%d\t%s\t%d\t%d' % (
                        len(readpairset) * 2, transcript_rec.id, len(transcript_rec), isoform_copy_number))
                    for readpair in readpairset:
                        try:
                            # Reads without an explicit CIGAR are treated as full-length matches.
                            for read in readpair:
                                if 'cigar_list' not in read.annotations:
                                    read.annotations['cigar_list'] = [[options.read_length, 'M']]
                            sam = []
                            for j in range(2):
                                read = readpair[j]
                                if options.subst_rate > 0:
                                    read = introduceSubstitutions(read, error_prob, BBB_error_prob, BBB_init_prob,
                                                                  correct_qual_freq, incorrect_qual_freq,
                                                                  options.N_rate)
                                else:
                                    # Without substitutions every base gets quality 38.
                                    read.letter_annotations['phred_quality'] = illumina_quality_code[38] * len(read)
                                sam.append(generateSam(read, coords))
                            sam = add_sam_pair_info(sam, readpair[0].annotations['isize'])

                            # Both mates share the same randomly assigned name.
                            for mate in sam:
                                mate.set_paired_flag()
                                mate.set_proper_pair_flag()
                                mate.qname = codes[code_index]
                                output.write(str(mate) + '\n')
                            code_index += 1
                            i += 1
                        except ValueError as ve:
                            logging.error(ve)
                else:
                    readset = generateShortReads.generateUnpairedReadSet(transcript_rec, isoform_copy_number,
                                                                         options.read_length, options.size_min,
                                                                         options.size_lower, options.size_upper,
                                                                         options.size_max, options.indel_rate)
                    i = 0
                    logging.info('\t%d\t%s\t%d' % (len(readset), transcript_rec.id, len(transcript_rec)))
                    for read in readset:
                        try:
                            # Same full-length-match default as the paired branch: generateSam requires
                            # a 'cigar_list' annotation and would otherwise abort the whole gene with KeyError.
                            if 'cigar_list' not in read.annotations:
                                read.annotations['cigar_list'] = [[options.read_length, 'M']]
                            if options.subst_rate > 0:
                                read = introduceSubstitutions(read, error_prob, BBB_error_prob, BBB_init_prob,
                                                              correct_qual_freq, incorrect_qual_freq,
                                                              options.N_rate)
                            else:
                                read.letter_annotations['phred_quality'] = illumina_quality_code[38] * len(read)
                            sam = generateSam(read, coords)
                            sam.qname = codes[code_index]
                            output.write(str(sam) + '\n')
                            i += 1
                            code_index += 1
                        except ValueError as ve:
                            logging.error(ve)
                total += i
                sys.stdout.write('Total: {:n}\n'.format(total))
        except Exception as e:
            logging.error(e)
            logging.error('in transcript {}'.format(gene_id))
            traceback.print_exc()
    if options.pairs:
        summary = "Generated {0:n} pairs of reads of length {1:d}".format(total, options.read_length)
    else:
        summary = "Generated {0:n} reads of length {1:d}".format(total, options.read_length)
    logging.info(summary)
    sys.stdout.write('\n' + summary + '\n')
    if generateShortReads.candidates:
        accept_rate = float(generateShortReads.accepts) / generateShortReads.candidates
    else:
        accept_rate = 0
    accept_pct = "Size filter accept rate was {:.2%}".format(accept_rate)
    logging.info(accept_pct)
    sys.stdout.write(accept_pct + '\n')

    output.close()
    origin_prob_dict.close()
    sys.stderr.write('%s done. \n' % sys.argv[0])

