###############################################################################
# Imports modules

from __future__ import division

import argparse
import ast
import csv
import os
import random
import re
import sys
import time
from collections import defaultdict

from libraryFun import *


###############################################################################
# Parses command-line input using the argparse module

# Initiates input parser
parser = argparse.ArgumentParser(description='Extracts elements')

# Positional (required) arguments: the file listing the genes to extract, the
# reference file of element names and sequences, and the name used for the
# output files and for this run's entry in the primer legend
parser.add_argument('gene_file', help='File for genes', type=str)

parser.add_argument('ref_file', help='File for alignments', type=str)

parser.add_argument('name', help='Name for output file', type=str)

# Optional arguments:
parser.add_argument('-s', '--split_mark', help='Custom delimiters',
                    type=str, default='_')

parser.add_argument('-n', '--none_num', help='Number of nontargeting controls',
                    type=int, default=1000)

parser.add_argument('-sn', '--safe_num', help='Number of safe controls',
                    type=int, default=1000)

parser.add_argument('-sh', '--scramble_num', help='Number of scrambled controls',
                    type=int, default=2000)

# NOTE(review): element_num is not read anywhere in this script - confirm
# whether a per-gene cap on elements was intended
parser.add_argument('-e', '--element_num',
                    help='Maximum number of elements', type=int, default=10)

# NOTE(review): override_file is not read anywhere in this script
parser.add_argument('-of', '--override_file', action='store_true')

# Test mode: everything runs, but the script stops before writing any files
parser.add_argument('-t', '--test', action='store_true',
                    help='Run without writing output files')

# Saves all input to object args
args = parser.parse_args()

###############################################################################
# Resolves output paths and detects the screen type from the reference file

out_file, legend_file = dissectFolder(args.gene_file, args.name)

# Known reference files, grouped by how their element names are parsed when
# the reference file is read. Defined once so that the screen-type lists and
# the mouse list can no longer disagree about a file name.
CAS9_REFS = ['Data/sgRNA_human_10_ref.csv', 'Data/sgRNA_mouse_10_ref.csv']

SHRNA_REFS = ['Data/shRNA_human_25_ref.csv', 'Data/shRNA_mouse_ref.csv']

MISC_REFS = ['Data/crisprI_mouse_0701_ref.csv',
             'Data/crisprA_human_0710_ref.csv',
             'Data/crisprI_human_20160604_ref.csv',
             'Data/crisprA_mouse_20170418_ref.csv']

# Mouse libraries. Previously this listed crisprA_mouse_20161020, which is not
# an accepted reference file (it exits as unknown above), so the accepted
# crisprA_mouse_20170418 file was wrongly treated as a human screen.
MOUSE_REFS = ['Data/sgRNA_mouse_10_ref.csv',
              'Data/shRNA_mouse_ref.csv',
              'Data/crisprI_mouse_0701_ref.csv',
              'Data/crisprA_mouse_20170418_ref.csv']

# The lists above are disjoint, so at most one of these flags is set
Cas9 = args.ref_file in CAS9_REFS
shRNA = args.ref_file in SHRNA_REFS

if Cas9:
    print('CRISPR deletion screen detected')

elif shRNA:
    print('shRNA screen detected')

elif args.ref_file in MISC_REFS:
    print('Misc screen detected')

else:
    sys.exit('Error: Unknown data file')

maus = args.ref_file in MOUSE_REFS

if maus:
    print('Mouse screen detected')


###############################################################################
# Loads the existing primer legend (output name -> primer), if there is one

print('Reading files')

name2primer = {}

try:

    with open(legend_file, 'r') as legend_open:
        legend_csv = csv.reader(legend_open)

        for line in legend_csv:

            if not line:
                continue

            # The legend is written by this script from str(primer), so the
            # second column is parsed back into a Python value.
            # literal_eval (not eval) keeps a corrupt or tampered legend file
            # from executing code. This assumes primers are plain literals,
            # e.g. tuples of strings - TODO confirm against pickPrimer.
            name2primer[line[0]] = ast.literal_eval(line[1])

except IOError:
    print('No legend file found, creating new one')

# Collects the requested genes: the first column of each non-blank row,
# upper-cased and stripped of surrounding whitespace
genes = set()

# 'U' (universal newlines) is what the Python 2 csv reader needs, but the
# 'U' mode flag was removed from open() in Python 3.11 (ValueError). Python 3
# text mode already translates newlines.
gene_mode = 'rU' if sys.version_info[0] < 3 else 'r'

with open(args.gene_file, gene_mode) as gene_open:

    for line in csv.reader(gene_open):

        if not line:
            continue

        gene_key = line[0].upper().strip()

        # Skip rows whose first cell is empty (e.g. ",foo")
        if gene_key:
            genes.add(gene_key)

# Indexes the reference file:
#   name2seq   - element name -> sequence
#   gene2names - upper-cased gene key -> list of element names
# example_seq is left holding the last sequence read.
name2seq = {}
gene2names = defaultdict(list)

with open(args.ref_file, 'r') as ref_open:
    ref_csv = csv.reader(ref_open, delimiter=',')

    for line in ref_csv:

        # Tolerate blank rows (e.g. a trailing newline), as the gene file does
        if not line:
            continue

        try:
            name, seq = line
        except ValueError:
            sys.exit('Error: ref file syntax' + '\n' + str(line))

        fields = name.split(args.split_mark)

        try:
            if Cas9:
                # Four fields. The first field (presumably an Ensembl ID) is
                # indexed as-is, in addition to the gene field below.
                ens, gene, _, _ = fields
                gene2names[ens] += [name]

            elif shRNA:
                # Three fields; the first is the gene
                gene, _, _ = fields

            else:
                gene = fields[0]

        except ValueError:
            # Previously an uncaught traceback on a malformed element name
            sys.exit('Error: ref file element name syntax' + '\n' + str(line))

        gene2names[gene.upper()] += [name]
        name2seq[name] = seq
        example_seq = seq

# Without at least one element, example_seq is never set and the primer
# search would fail later with a NameError
if not name2seq:
    sys.exit('Error: ref file contains no elements')


###############################################################################
# Assigns a primer pair to this run's output name

print('Finding primer')

# Re-running with an existing name replaces that name's previous assignment
if args.name in name2primer:
    print('Warning: deleting previous primer assignment')
    name2primer.pop(args.name)

# Restriction-site patterns handed to pickPrimer. The '.' looks like a regex
# wildcard (presumably any base) - confirm in pickPrimer.
restriction_sites = [
    'CTGCAG',
    'GAAGAC',
    'GTCTTC',
    'CCA......TGG',
    'GCT.AGC',
]

# pickPrimer also receives the first and last five bases of the last
# reference sequence read
seq_head = example_seq[:5]
seq_tail = example_seq[-5:]

up_primer, down_primer, primer = pickPrimer(
    restriction_sites, name2primer, seq_head, seq_tail)

name2primer[args.name] = primer


###############################################################################
# Collects the targeting elements for every requested gene

print('Finding elements')

seqs = []
gene_num = 0

geneID2Name, geneID2Info, geneName2ID, geneEns2Name = retrieveInfo(mouse=maus)

# Sorted so that "not found" warnings print in a reproducible order
for gene in sorted(genes):

    geneID, gene_name, entrez = retrieveIDs(gene, geneID2Name, geneName2ID, geneEns2Name)

    # Try the gene as given, then each identifier returned by retrieveIDs,
    # in the same priority order as before; the first hit wins
    for key in (gene, geneID, gene_name, entrez):
        if key in gene2names:
            names = gene2names[key]
            break
    else:
        print('Warning: ' + gene + ' not found')
        continue

    gene_num += 1
    seqs.extend([element, name2seq[element]] for element in names)

# An element can be reached through more than one gene key, so deduplicate by
# name and then sort by name. sorted() is used because dict.values() has no
# .sort() method on Python 3.
unique_elements = dict((entry[0], entry) for entry in seqs)
seqs = sorted(unique_elements.values(), key=lambda entry: entry[0])

print('Number of genes: ' + str(gene_num))
print('Total targeting elements: ' + str(len(seqs)))


###############################################################################
# Negative controls

if Cas9 or not shRNA:
    # CRISPR-style libraries (Cas9 and misc) take non-targeting and
    # safe-harbour controls. Misc libraries look them up under '0'-prefixed
    # labels.
    if Cas9:
        none_label, safe_label = 'NONE', 'SAFE'
    else:
        none_label, safe_label = '0NONE', '0SAFE'

    none_seqs, none = retrieveNegatives(none_label, gene2names, name2seq, args.none_num)
    safe_seqs, safe = retrieveNegatives(safe_label, gene2names, name2seq, args.safe_num)

    seqs.extend(none_seqs)
    seqs.extend(safe_seqs)

    print('Total scrambled guides: ' + str(none))
    print('Total safe harbour guides: ' + str(safe))

else:
    # shRNA libraries take scrambled hairpins, looked up under the label '0'
    scramble_seqs, scramble = retrieveNegatives('0', gene2names, name2seq, args.scramble_num)

    seqs.extend(scramble_seqs)

    print('Total scrambled hairpins: ' + str(scramble))


###############################################################################
# Flanks every element sequence with the selected upstream/downstream primers

output = [[name, up_primer + seq + down_primer] for name, seq in seqs]

print('Total elements: ' + str(len(output)))


###############################################################################
# Writes the primer-wrapped elements and the updated primer legend

if args.test:
    # Test mode stops before touching any files. sys.exit with a message
    # prints it to stderr and exits with status 1.
    sys.exit('Warning: Files not written')

with open(out_file + '_align.csv', 'w') as out_open:
    out_csv = csv.writer(out_open, delimiter=',', lineterminator='\n')
    out_csv.writerows(output)

# sorted() rather than items().sort(): dict views have no .sort() on Python 3
legend = sorted(name2primer.items(), key=lambda entry: entry[0])

with open(legend_file, 'w') as legend_open:
    legend_csv = csv.writer(legend_open, delimiter=',', lineterminator='\n')
    legend_csv.writerows(legend)



###############################################################################
