
# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

'''Process a master simulation file into separate read1, read2, and sam files with random labels

Input file should be a sequence of stanzas with the format:
readset name
read1<TAB>qual1
sam1
[ read2<TAB>qual2
sam2 ]
<blank line>

If the read2, sam2 lines are omitted, the empty ..._2.fastq file produced can be deleted.

Created on Jun 1, 2010
@author: ian
'''

import sys, os, random
import logging
from Bio import SeqIO

def usage():
    '''Print the command-line synopsis to standard output and exit with status 1.'''
    # sys.stdout.write emits exactly what the old print statement did, but
    # is also accepted by Python 3 parsers.
    sys.stdout.write('Usage: python %s input.txt sample_num genome.fai\n' % sys.argv[0])
    sys.exit(1)

def generate_SAM_header_from_fai(fai_filename):
    '''Build SAM @SQ header lines from a FASTA index (.fai) file.

    The first two whitespace-separated fields of every non-blank line are
    used as the reference name and length. One '@SQ' line is produced per
    reference, in file order.

    Returns the header lines joined with newlines (no trailing newline).
    Raises IndexError if a non-blank line has fewer than two fields.
    '''
    header_lines = []
    # The with-block closes the index even if a malformed line raises.
    with open(fai_filename) as fai:
        for line in fai:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing one) instead of crashing.
                continue
            header_lines.append('@SQ\tSN:%s\tLN:%s' % (fields[0], fields[1]))
    return '\n'.join(header_lines)

def generate_SAM_header_from_fasta(fasta_filename):
    '''Build SAM @SQ header lines by indexing a FASTA file with Bio.SeqIO.

    One '@SQ' line is produced per record, ordered by sorted record name,
    with LN set to the length of the indexed record.

    Returns the header lines joined with newlines (no trailing newline).
    '''
    records = SeqIO.index(fasta_filename, 'fasta')
    names = list(records.keys())
    names.sort()
    return '\n'.join(
        ['@SQ\tSN:%s\tLN:%s' % (name, len(records[name])) for name in names]
    )

def format_fq(line, label):
    '''Format a tab-separated "read<TAB>qual" line as a four-line FASTQ record.

    line  -- string holding exactly two tab-separated fields: bases, qualities
    label -- read name placed after the leading '@'

    Returns the record text without a trailing newline.
    Raises ValueError, with the label, the offending line and the original
    message joined by '|', when the line does not split into two fields.
    '''
    # Only the unpacking can fail; keep the try body that small.
    try:
        read, qual = line.split('\t')
    except ValueError as ve:
        raise ValueError('|'.join([label, line, str(ve)]))
    return '@%s\n%s\n+\n%s' % (label, read, qual)

def relabel_sam(line, label):
    '''Return the tab-separated SAM line with its first field replaced by label.'''
    # Everything after the first tab is kept verbatim; a line without any
    # tab consists of the first field only and becomes just the label.
    _old_name, tab, remainder = line.partition('\t')
    return label + tab + remainder

def output_group(group, fq_list, samout, label):
    '''Write one stanza to the FASTQ and SAM outputs under a new label.

    group   -- stanza lines: [readset name, read1/qual1, sam1] optionally
               followed by [read2/qual2, sam2]
    fq_list -- writable file-like objects; index 0 receives read 1 and
               index 1 receives read 2
    samout  -- writable file-like object receiving the relabelled SAM lines
    label   -- replacement read name; paired reads get '/1' and '/2' appended

    Stanzas with fewer than three lines are ignored. Any exception raised
    while formatting or writing is logged with logging.error and swallowed.
    '''
    if len(group) < 3:
        return
    try:
        # Format every record before writing anything, so that a malformed
        # read2/sam2 line cannot leave read 1 written without its mate.
        if len(group) > 4:
            records = [
                (fq_list[0], format_fq(group[1], label + '/1')),
                (samout, relabel_sam(group[2], label + '/1')),
                (fq_list[1], format_fq(group[3], label + '/2')),
                (samout, relabel_sam(group[4], label + '/2')),
            ]
        else:
            records = [
                (fq_list[0], format_fq(group[1], label)),
                (samout, relabel_sam(group[2], label)),
            ]
        for handle, text in records:
            handle.write(text + '\n')
    except Exception as e:
        logging.error(e)

def doRepackageSimulation(input,  sample_num,  samout,  fqs,  fai_path):
    '''Split a master simulation stream into FASTQ and SAM output with shuffled labels.

    input      -- iterable of lines; stanzas are separated by blank lines
    sample_num -- number of labels available; the labels '1' .. str(sample_num)
                  are shuffled and handed out one per stanza
    samout     -- writable file-like object receiving the @SQ header followed
                  by the relabelled SAM lines
    fqs        -- writable file-like objects for the read 1 and read 2 FASTQ
                  records
    fai_path   -- path of the genome .fai index; when it is missing or empty
                  the header is built from the FASTA file instead

    Raises IndexError when the input holds more stanzas than sample_num.
    '''
    if os.path.isfile(fai_path) and os.path.getsize(fai_path) > 0:
        header = generate_SAM_header_from_fai(fai_path)
    else:
        # No usable index: fall back to the FASTA file. Only strip the
        # suffix when it really is '.fai' rather than chopping any path.
        if fai_path.endswith('.fai'):
            fasta_path = fai_path[:-4]
        else:
            fasta_path = fai_path
        header = generate_SAM_header_from_fasta(fasta_path)
    samout.write(header + '\n')

    codes = [str(n) for n in range(1, sample_num + 1)]
    random.shuffle(codes)

    def emit(stanza, index):
        # Fail with an explanation rather than a bare list-index error.
        if index >= len(codes):
            raise IndexError(
                'input contains more than sample_num (%d) stanzas' % sample_num)
        output_group(stanza, fqs, samout, codes[index])

    group = []
    used = 0
    for line in input:
        text = line.strip()
        if text:
            group.append(text)
        elif group:
            # A blank (or whitespace-only / CRLF) line ends the stanza.
            # Repeated blank lines do not consume labels.
            emit(group, used)
            group = []
            used += 1
    if group:
        # The input did not end with a blank line: flush the last stanza.
        emit(group, used)


if __name__ == '__main__':
    # Validate the command line before creating any output file, so a bad
    # invocation prints the usage message without leaving empty files behind.
    if len(sys.argv) < 4:
        usage()
    try:
        sample_num = int(sys.argv[2])
    except ValueError:
        usage()
    input_path = sys.argv[1]
    fai = sys.argv[3]
    basepath = os.path.splitext(input_path)[0]

    LOG_FILENAME = basepath + '.repackage.log'
    logging.basicConfig(filename=LOG_FILENAME, level=logging.DEBUG, filemode='w')

    # Track every handle as it is opened so all of them are closed even if
    # a later open() or the repackaging itself fails.
    opened = []
    try:
        infile = open(input_path)
        opened.append(infile)
        samout = open(basepath + '.true_mappings.sam', 'w')
        opened.append(samout)
        fqs = []
        for j in range(2):
            fq = open(basepath + '_%d.fastq' % (j + 1), 'w')
            opened.append(fq)
            fqs.append(fq)
        doRepackageSimulation(infile, sample_num, samout, fqs, fai)
    finally:
        for handle in opened:
            handle.close()

    sys.stdout.write('%s done. \n' % sys.argv[0])
