###############################################################################
# Imports modules

from __future__ import division
import csv
import time
import argparse
import sys
import random
import re
from libraryFun import *


###############################################################################
# Parses input using argparse module

# Initiates input parser
parser = argparse.ArgumentParser(description='Extracts elements')

# Non-optional arguments: the file of oligos to read, and the name used for
# the output file and as the key of this run's primer assignment
parser.add_argument('oligo_file', help='File for oligos', type=str)

parser.add_argument('name', help='Name for output file', type=str)

# Optional arguments:
# NOTE(review): --exclude is parsed but never read anywhere in this script
parser.add_argument('-x', '--exclude', type=str)

parser.add_argument('-t', '--test', action='store_true',
                help='Run without writing the output and legend files')

parser.add_argument('-d', '--file_delimit', type=str, default='\t',
                help='Column delimiter of the oligo file (default: tab)')

parser.add_argument('-r', '--ref_file', type=str,
                help='Reference CSV (name, sequence) supplying the '
                     'nontargeting and safe controls')

parser.add_argument('-n', '--none_num', help='Number of nontargeting controls',
                type=int, default=0)

parser.add_argument('-sn', '--safe_num', help='Number of safe controls',
                type=int, default=0)

parser.add_argument('-p', '--primers', type=str,
                help='Value recorded as the primer assignment; when given, '
                     'no primers are picked or attached to the oligos')

parser.add_argument('-s', '--strip', type=int,
                help='Number of characters to remove from each end of every '
                     'sequence in the oligo file')

# Saves all input to object args
args = parser.parse_args()


###############################################################################
# Loads the existing primer legend and reads the oligos from the input file

# Resolves the output file base and the legend file path; how they are derived
# from the two arguments is decided by dissectFolder
out_file, legend_file = dissectFolder(args.oligo_file, args.name)

# Maps each name found in the legend to its stored primer assignment
name2primer = dict()

try:
    with open(legend_file, 'r') as legend_handle:

        # Each legend row holds a name followed by the text form of its
        # primer assignment
        # NOTE(review): eval executes whatever expression the legend file
        # contains; only use legend files from a trusted source
        for entry in csv.reader(legend_handle):
            name2primer[entry[0]] = eval(entry[1])

except IOError:
    # A missing or unreadable legend is not fatal; it is written at the end
    print('No legend file found, creating new one')

# Collects a [name, sequence] pair for every non-empty row of the oligo file
oligos = []

# The 'U' mode flag gives universal newlines on Python 2 but was removed in
# Python 3.11, where plain text mode already translates newlines by default
oligo_mode = 'rU' if sys.version_info[0] < 3 else 'r'

with open(args.oligo_file, oligo_mode) as oligo_open:
    oligo_csv = csv.reader(oligo_open, delimiter=args.file_delimit)

    for line in oligo_csv:

        # Skips blank rows
        if not line:
            continue

        # A row needs at least two columns: name and sequence
        try:
            oligo_name, oligo_seq = line[0], line[1].upper()
        except IndexError:
            sys.exit('Error: Failure to parse oligo file.\n' + str(line))

        # Remembers the latest full-length (unstripped) sequence; its first
        # and last characters may be used later when picking primers
        example_seq = oligo_seq

        # Trims the requested number of characters from both ends
        if args.strip:
            oligo_seq = oligo_seq[args.strip:-args.strip]

        oligos.append([oligo_name, oligo_seq])


###############################################################################
# Adds negative controls from the reference file, if one was given

if args.ref_file:

    # Lookup tables built from the reference file:
    #   name2seq    element name -> sequence
    #   gene2names  first name field, and upper-cased second name field,
    #               -> list of element names
    # NOTE(review): defaultdict is not imported explicitly in this file;
    # presumably it arrives through the libraryFun star import - confirm
    name2seq = {}
    gene2names = defaultdict(list)

    with open(args.ref_file, 'r') as ref_open:

        # Every row must hold exactly two columns: name and sequence
        for name, seq in csv.reader(ref_open, delimiter=','):

            # Names must consist of exactly four '_'-separated fields; only
            # the first two are used here
            ens, gene, cat, ID = name.split('_')

            for key in (ens, gene.upper()):
                gene2names[key].append(name)

            name2seq[name] = seq

            # The last reference sequence replaces any example sequence taken
            # from the oligo file
            example_seq = seq.upper()

    # Requests the controls, nontargeting first and then safe
    # NOTE(review): presumably each call returns the chosen control oligos
    # and their count - confirm against retrieveNegatives in libraryFun
    none_seqs, none = retrieveNegatives('NONE', gene2names, name2seq,
                                        args.none_num)
    safe_seqs, safe = retrieveNegatives('SAFE', gene2names, name2seq,
                                        args.safe_num)

    for control_seqs in (none_seqs, safe_seqs):
        oligos.extend(control_seqs)

    print('Total scrambled guides: ' + str(none))
    print('Total safe harbour guides: ' + str(safe))


###############################################################################
# Assigns the primers for this run and records them in the legend

# Drops any primer assignment already stored in the legend under this name,
# so it is not passed on to the primer picker
if args.name in name2primer:
    print('Warning: deleting previous primer assignment')
    del name2primer[args.name]

# Site patterns handed to the primer picker
# NOTE(review): '.' presumably acts as a single-character wildcard - confirm
# against pickPrimer in libraryFun
restriction_sites = [
    'GAATTC', 'CTCGAG', 'AAGCTT', 'CTGCAG',
    'CGTCTC', 'GAGACG', 'GAAGAC', 'GTCTTC',
    'CCA......TGG', 'GCT.AGC',
]

if args.primers:

    # A user-supplied value is stored as text, used as both the first element
    # and the two elements of the inner pair
    name2primer[args.name] = str((args.primers, (args.primers, args.primers)))

else:

    # example_seq is only bound if the oligo file or the reference file
    # contained at least one sequence; otherwise this raises a NameError
    flank_start = example_seq[:5]
    flank_end = example_seq[-5:]

    up_primer, down_primer, primer = pickPrimer(restriction_sites, name2primer,
                                                flank_start, flank_end)

    name2primer[args.name] = primer


###############################################################################
# Attaches primers to the oligos, removes duplicate names and sorts by name

# Builds the final [name, sequence] rows keyed by name, so that a later oligo
# replaces any earlier one carrying the same name
unique_rows = {}

for oligo in oligos:

    if not args.primers:
        # Primers picked by this script flank the oligo sequence
        unique_rows[oligo[0]] = [oligo[0], up_primer + oligo[1] + down_primer]
    else:
        # With user-supplied primers the sequence is left as it is
        unique_rows[oligo[0]] = [oligo[0], oligo[1]]

# sorted() is used instead of .sort() because dict views have no sort()
# method on Python 3; the result is the same name-ordered list on Python 2
output = sorted(unique_rows.values(), key=lambda x: x[0])

print('Total elements: ' + str(len(output)))


###############################################################################
# Writes the output file and the updated primer legend

# In test mode the script stops here, before any file is written; sys.exit
# with a message prints it to stderr and exits with a non-zero status
if args.test:
    sys.exit('Warning: Files not written')

# Writes one name,sequence row per element
with open(out_file + '_align.csv', 'w') as out_open:
    out_csv = csv.writer(out_open, delimiter=',', lineterminator='\n')
    out_csv.writerows(output)

# sorted() is used instead of .sort() because dict.items() returns a view
# without a sort() method on Python 3
legend = sorted(name2primer.items(), key=lambda x: x[0])

# Rewrites the whole legend, including the assignment made in this run
with open(legend_file, 'w') as legend_open:
    legend_csv = csv.writer(legend_open, delimiter=',', lineterminator='\n')
    legend_csv.writerows(legend)


###############################################################################
