##############################################################################
# David Morgens
##############################################################################
# Import necessary modules

import argparse
import ast
import csv
import os
import random
import re
import sys
from collections import defaultdict

from libraryFun import *


##############################################################################

current_version = '1.0'


##############################################################################
# Parses input using argparse module

# Initiates argument parser
parser = argparse.ArgumentParser(description='Find top guides and shRNAs for '
                                             'each query gene')

# Positional arguments: query gene list, off-target table, reference
# library and the base name for the output files

parser.add_argument('gene_file', type=str,
                    help='tab-delimited file whose first column lists the '
                         'query genes')

parser.add_argument('off_file', type=str,
                    help='tab-delimited off-target table (read only for '
                         'CRISPR deletion screens)')

parser.add_argument('ref_file', type=str,
                    help='reference library CSV of name,sequence rows; must '
                         'be one of the supported Data/*_ref.csv files')

parser.add_argument('name', help='name of output files', type=str)

# Optional arguments:
# Specify number of guides or shRNA per gene
parser.add_argument('-g', '--element_num', help='number of guides or hairpins',
                    default=3, type=int)

parser.add_argument('-s', '--split_mark', help='Custom delimiters',
                    type=str, default='_')

parser.add_argument('-n', '--none_num', help='Number of scrambled controls',
                    type=int, default=25)

parser.add_argument('-sn', '--safe_num', help='Number of safe controls',
                    type=int, default=25)

parser.add_argument('-sh', '--scramble_num', help='Number of scrambled hairpins',
                    type=int, default=50)

parser.add_argument('-st', '--screen_type', help='Screen type(s) to use',
                    type=str, nargs='+', default=[])

parser.add_argument('-l', '--L_thresh', help='Cutoff to keep data',
                    type=float, default=10.0)

parser.add_argument('-t', '--test', action='store_true',
                    help='dry run: do not write any output files')

parser.add_argument('-x', '--exclude', type=str, nargs='+',
                    help="keep only gene '0' elements whose category is "
                         "listed; other elements are unaffected")

parser.add_argument('-b', '--batch', action='store_true',
                    help='still use genes that have fewer reference elements '
                         'than --element_num')

parser.add_argument('-o', '--only', action='store_true',
                    help='skip the reference-library fallback for genes '
                         'without usable screen results')

parser.add_argument('-v', '--verbose', action='store_true',
                    help='print extra warnings')

# Saves input to args object
args = parser.parse_args()


###############################################################################
# Resolve output paths and detect the screen type from the reference file

out_file, legend_file = dissectFolder(args.gene_file, args.name)

# Supported reference libraries. Each entry holds the flags
# (Cas9, shRNA, maus, Int), the screen types whose results apply to that
# library, and the message announcing the detected screen type.
known_refs = {
    'Data/sgRNA_human_10_ref.csv':
        (True, False, False, False,
         ['Cas9-10', 'CHL_L1_batch', 'EJ_legin', 'KH_batch', 'KH_drug2',
          'KH_kras', 'KH_kras_DTKP', 'DM_high', 'DM_ricin_ko'],
         'CRISPR deletion screen detected'),
    'Data/sgRNA_mouse_10_ref.csv':
        (True, False, True, False, ['mm-Cas9-10'],
         'CRISPR deletion screen detected'),
    'Data/crisprI_human_20160604_ref.csv':
        (True, False, False, True, ['DM_ricin_int'],
         'CRISPRi screen detected'),
    'Data/shRNA_human_25_ref.csv':
        (False, True, False, False, ['shRNA'],
         'shRNA screen detected'),
    'Data/shRNA_mouse_ref.csv':
        (False, True, True, False, ['mouse'],
         'shRNA screen detected'),
}

# Normalise both sides so equivalent spellings (e.g. './Data/x.csv' or
# 'Data//x.csv') are recognised as well as the canonical path.
known_refs = {os.path.normpath(path): spec for path, spec in known_refs.items()}
ref_spec = known_refs.get(os.path.normpath(args.ref_file))

if ref_spec is None:
    sys.exit('Error: Unknown data file')

Cas9, shRNA, maus, Int, ref_screen_types, ref_message = ref_spec
args.screen_type += ref_screen_types
print(ref_message)


##############################################################################
# Read in list of query genes and convert into list

geneID2Name, geneID2Info, geneName2ID, geneEns2Name = retrieveInfo(mouse=maus)

# Query genes in first-seen order, upper-cased and de-duplicated
genes = []
seen_genes = set()  # O(1) duplicate check; `genes` keeps the input order

# NOTE(review): 'rU' keeps Python 2 universal-newline handling, but the mode
# was removed in Python 3.11 -- revisit if this script is ported.
with open(args.gene_file, 'rU') as gene_open:

    gene_csv = csv.reader(gene_open, delimiter='\t')

    for line in gene_csv:

        if not line:
            continue

        gene = line[0].upper().strip()

        # A whitespace-only first column would otherwise become a '' gene
        if not gene:
            continue

        # For shRNA screens, translate Ensembl IDs via geneEns2Name
        if shRNA and gene in geneEns2Name:
            gene = geneEns2Name[gene]

        if gene not in seen_genes:
            seen_genes.add(gene)
            genes.append(gene)


##############################################################################
# Locate all result files generated by analyzeCounts.py in Results

result_files = retrieveResults('analyzeCounts.py', args.screen_type,
                               current_version)

print('Result files found: {0}'.format(len(result_files)))


##############################################################################
# Pull each query gene's rows out of the result files and keep the elements
# that pass the thresholds

# Column layout of the analyzeCounts result files
ID_col, num_col, effect_col, L_stat_col, element_col = 0, 6, 7, 8, 12

cols = (L_stat_col, num_col, effect_col, element_col, ID_col)

# Screen-type-specific threshold handed to extractLikes together with
# L_thresh and 1.5 (presumably compared against num_col -- TODO confirm)
if not (Cas9 or shRNA):
    sys.exit('What?')

num_thresh = 7 if Cas9 else 20

threshes = (args.L_thresh, 1.5, num_thresh)

print('Extracting lines')
gene2lines = extractLines(genes, result_files, args.split_mark, element_col)

print('Extracting elements')
gene2likes, result_errors = extractLikes(gene2lines, cols, threshes, Cas9,
                                         shRNA, args.split_mark)

print('Number of errors in extraction: {0}'.format(len(result_errors)))


##############################################################################
# Read the reference library (name,sequence rows) and index its elements by
# gene, guide sub-sequence and element ID

print('Reading off-targets')

elementID2name_seq = {}
name2seq = {}
gene2names = defaultdict(list)
guide2name = {}
example_seq = None  # last accepted sequence; used later for primer picking

with open(args.ref_file, 'r') as ref_open:
    ref_csv = csv.reader(ref_open, delimiter=',')

    for line in ref_csv:

        # csv yields [] for blank lines, which would break the unpack below
        if not line:
            continue

        name, seq = line

        if Cas9:
            # Names split into ens/gene/category/ID fields; others skipped
            try:
                ens, gene, cat, ID = name.split(args.split_mark)
            except ValueError:
                continue
            # NOTE(review): registered under `ens` before the --exclude
            # filter below, so filtered entries stay reachable via this
            # key -- confirm that is intended.
            gene2names[ens] += [name]

        elif shRNA:
            # Names split into gene/category/hairpin-number fields; skip
            # malformed names the same way the Cas9 branch does
            try:
                gene, cat, pin_num = name.split(args.split_mark)
            except ValueError:
                continue
            ID = args.split_mark.join([gene, pin_num])

        # With --exclude, gene '0' elements survive only if their category
        # is listed
        if args.exclude:
            if gene == '0' and cat not in args.exclude:
                continue

        gene2names[gene.upper()] += [name]
        name2seq[name] = seq
        guide2name[seq[11:-17]] = name
        elementID2name_seq[ID] = (name, seq)

        example_seq = seq

# Fail early rather than with a NameError at primer selection
if example_seq is None:
    sys.exit('Error: no usable elements in ' + args.ref_file)

# Names of reference guides flagged for off-target hits. The off-target file
# is only consulted for CRISPR deletion screens (not shRNA or CRISPRi).
off_filter = []

if not (shRNA or Int):

    with open(args.off_file, 'r') as off_open:

        off_csv = csv.reader(off_open, delimiter='\t')

        for line in off_csv:

            if not line:
                continue

            # 17 tab-separated columns; only the guide sequence and the
            # comma-separated off-target counts are used here
            (chrom, start, end, strand, guide, gene, score, effect, pval,
             pval2, rho, off_exon, off_all, bah, bah2, off_pos_all,
             off_trans_all) = line

            # Guides absent from the reference (e.g. removed by --exclude)
            # can never be selected, so there is nothing to filter
            name = guide2name.get(guide)
            if name is None:
                continue

            # Presumably hit counts at 0..4 mismatches -- TODO confirm;
            # a guide is flagged when either of the first two is non-zero
            mis0, mis1, mis2, mis3, mis4 = [int(x) for x in off_all.split(',')]

            if mis0 or mis1:
                off_filter.append(name)

print(str(len(off_filter)) + ' off-target guides found')


##############################################################################
# Pick args.element_num elements per gene from its best-scoring screen

print('Selecting elements')

seqs = []
missed_genes = []
filtered = 0
off_filter_set = set(off_filter)  # O(1) membership tests in the loop

for gene in genes:

    if gene not in gene2likes:
        #print('Warning: ' + gene + ' has no hits')
        missed_genes.append(gene)
        continue

    like_elements = gene2likes[gene]

    # Sort by L statistic within each gene (best screen first)
    like_elements.sort(key=lambda x: x[0], reverse=True)

    # Take the first screen that supplies enough usable elements; elements
    # are not pooled across screens. Initialised here so a gene without
    # screens never sees another gene's leftovers.
    gene_seqs = []

    for screen in like_elements:
        gene_seqs = []

        for element_ID in screen[1]:
            # IDs absent from the reference library cannot be used
            if element_ID not in elementID2name_seq:
                continue

            element_name, element_seq = elementID2name_seq[element_ID]

            if element_name in off_filter_set:
                filtered += 1
                continue

            gene_seqs.append([element_name, element_seq])

        if len(gene_seqs) >= args.element_num:
            break

    if len(gene_seqs) < args.element_num:
        print('Warning: not enough good elements for ' + gene)
        missed_genes.append(gene)
    else:
        seqs.extend(gene_seqs[:args.element_num])


print('Out of ' + str(len(genes)) + ' genes, ' +
      str(len(missed_genes)) + ' have no corresponding hits')


##############################################################################
# Load primer assignments saved by a previous run (name -> primer)

name2primer = {}

try:

    with open(legend_file, 'r') as legend_open:
        legend_csv = csv.reader(legend_open)

        for line in legend_csv:
            if len(line) < 2:
                continue
            # The second column is str(primer) as written by csv.writer at
            # the end of this script. literal_eval parses literals only, so
            # an edited legend file cannot execute code the way eval() could.
            # Assumes the primer value is a plain literal (e.g. a tuple/list
            # of strings) -- TODO confirm against pickPrimer.
            name2primer[line[0]] = ast.literal_eval(line[1])

except IOError:
    print('No legend file found, creating new one')


###############################################################################
# Fallback for genes without usable screen results: take the first
# args.element_num reference elements listed for the gene (unless --only)

for gene in missed_genes:

    if args.only:
        if args.verbose:
            print('Warning: ' + str(gene) + ' excluded')
        continue

    geneID, name, entrez = retrieveIDs(gene, geneID2Name, geneName2ID, geneEns2Name)

    # Try the query itself, then its alternative identifiers, in that order
    for key in (gene, geneID, name, entrez):
        if key in gene2names:
            names = gene2names[key]
            break
    else:
        print('Warning: ' + gene + ' not found')
        continue

    # Duplicate element names are not expected; abort and show them
    if len(names) != len(set(names)):
        sys.exit(names)

    # Warn once; in batch mode the available elements are still used
    if len(names) < args.element_num:
        print('Warning: ' + gene + ' has too few elements')
        if not args.batch:
            continue

    # NOTE: unlike the result-based selection, this fallback does not apply
    # the off_filter check.
    for element_name in names[:max(args.element_num, 0)]:
        seqs.append([element_name, name2seq[element_name]])

print(str(filtered) + ' guides removed for off-targets')


###############################################################################
# Drop repeated element names (a later entry replaces an earlier one) and
# order the remaining elements by name. sorted() is used because
# dict.values() has no .sort() on Python 3.

seqs = sorted({x[0]: x for x in seqs}.values(), key=lambda x: x[0])

print('Total targeting elements: ' + str(len(seqs)))


###############################################################################
# Negative controls: append a random subset of each control pool

if Cas9:
    # 'SAFE' (safe harbour) and 'NONE' (scrambled) guides, shuffled in place
    safe_names = gene2names['SAFE']
    none_names = gene2names['NONE']

    for control_pool in (safe_names, none_names):
        random.shuffle(control_pool)

    safe_picks = safe_names[:args.safe_num]
    none_picks = none_names[:args.none_num]

    seqs.extend([ctrl, name2seq[ctrl]] for ctrl in safe_picks)
    seqs.extend([ctrl, name2seq[ctrl]] for ctrl in none_picks)

    safe = len(safe_picks)
    none = len(none_picks)

    print('Total scrambled guides: ' + str(none))
    print('Total safe harbour guides: ' + str(safe))

if shRNA:
    # Scrambled hairpins are stored under gene '0'
    scramble_names = gene2names['0']
    random.shuffle(scramble_names)

    scramble_picks = scramble_names[:args.scramble_num]
    seqs.extend([ctrl, name2seq[ctrl]] for ctrl in scramble_picks)
    scramble = len(scramble_picks)

    print('Total scrambled hairpins: ' + str(scramble))


###############################################################################
# Pick primers for this library via pickPrimer (plus a second set for the
# hU6 construct in CRISPR screens), replacing assignments from earlier runs

hU6_key = args.name + '_hU6'

# Remove stale assignments before picking, so their primers can be reused
if args.name in name2primer:
    print('Warning: deleting previous primer assignment')
    del name2primer[args.name]
    if Cas9:
        # pop(): the hU6 entry may be missing from the legend
        name2primer.pop(hU6_key, None)

elif Cas9 and hU6_key in name2primer:
    print('Warning: deleting previous primer assignments')
    del name2primer[hU6_key]

# Presumably site patterns (regex-style, '.' = any base) the chosen primers
# must avoid -- TODO confirm against pickPrimer
restriction_sites = ['CTGCAG','GAAGAC','GTCTTC','CCA......TGG','GCT.AGC']

up_primer, down_primer, primer = pickPrimer(restriction_sites, name2primer, example_seq[:5], example_seq[-5:])

name2primer[args.name] = primer

if Cas9:

    # Flanks wrapped around the raw guide when building the hU6 output
    hU6_up = 'gctttatatatcttgtggaaaggacgaaacacc'
    hU6_down = 'gtttaagagctaagctggaaacagcatagcaagtttaaataag'

    up_primer_hU6, down_primer_hU6, primer_hU6 = pickPrimer(restriction_sites, name2primer, '', '')

    name2primer[hU6_key] = primer_hU6


###############################################################################
# Wrap every element in its primers; CRISPR screens also get an hU6 version

output = [[name, up_primer + seq + down_primer] for name, seq in seqs]

if Cas9:
    # seq[11:-17] drops the outer library sequence around the guide
    # (assumed fixed 11/17-nt flanks -- TODO confirm)
    hU6_left = up_primer_hU6 + hU6_up
    hU6_right = hU6_down + down_primer_hU6
    output_hU6 = [[name, hU6_left + seq[11:-17] + hU6_right]
                  for name, seq in seqs]

print('Total elements: ' + str(len(output)))


###############################################################################
# Write the primer-wrapped elements and the updated primer legend

# Dry run: sys.exit with a message prints it to stderr and exits with status 1
if args.test:
    sys.exit('Warning: Files not written')

with open(out_file + '_align.csv', 'w') as out_open:
    out_csv = csv.writer(out_open, delimiter=',', lineterminator='\n')
    out_csv.writerows(output)

if Cas9:

    with open(out_file + '_hU6_align.csv', 'w') as out_open:
        out_csv = csv.writer(out_open, delimiter=',', lineterminator='\n')
        out_csv.writerows(output_hU6)

# sorted() rather than items().sort(): items() is a view on Python 3
legend = sorted(name2primer.items(), key=lambda x: x[0])

with open(legend_file, 'w') as legend_open:
    legend_csv = csv.writer(legend_open, delimiter=',', lineterminator='\n')
    legend_csv.writerows(legend)


###############################################################################
