##############################################################################
# David Morgens
##############################################################################
# Import necessary modules

import os
import sys
import argparse
import csv
from collections import defaultdict
import random
import re
from libraryFun import *


##############################################################################

current_version = '0.7'


##############################################################################
# Parses input using argparse module

# Initiates argument parser
parser = argparse.ArgumentParser(description='Find top guides and shRNAs for '
                                               'each query gene')

# Positional arguments: element type, name of output file, and query genes
parser.add_argument('element', choices=['shRNA', 'Cas9'],
                    help='Type of element to report')

parser.add_argument('name', help='name of output files', type=str)

parser.add_argument('genes', type=str, nargs='+',
                    help='One or more query genes')

# Optional arguments:

# Delimiter forwarded unchanged to extractLines and extractLikes
parser.add_argument('-s', '--split_mark', help='Custom delimiters',
                    type=str, default='_')

# Used as the first of the three thresholds handed to extractLikes
parser.add_argument('-l', '--L_thresh', help='Cutoff to keep data',
                    type=float, default=10.0)

# Upper bound on the elements listed in each output row
parser.add_argument('-e', '--num_elements',
                    help='Maximum number of elements reported per screen',
                    type=int, default=3)

# Upper bound on the output rows written for each gene
parser.add_argument('-t', '--num_screens',
                    help='Maximum number of screens reported per gene',
                    type=int, default=10)

# Switches the gene info lookup and the screen types to their mouse variants
parser.add_argument('-m', '--mouse', action='store_true',
                    help='Use mouse gene info and mouse screen types')


# Saves input to args object
args = parser.parse_args()


##############################################################################
# Retrieving info

# Gene lookup tables; the mouse flag is passed straight through to retrieveInfo
(geneID2Name, geneID2Info,
 geneName2ID, geneEns2Name) = retrieveInfo(mouse=args.mouse)

# Per-element settings: the shRNA/Cas9 flags, num_thresh (later used as the
# third threshold given to extractLikes), and the screen types to search
if args.element == 'shRNA':

    shRNA, Cas9 = True, False

    num_thresh = 20

    screen_types = (['Mouse'] if args.mouse
                    else ['shRNA', 'ricin_batch', 'sage'])

elif args.element == 'Cas9':

    shRNA, Cas9 = False, True

    num_thresh = 7

    screen_types = ['mm-Cas9-10'] if args.mouse else ['Cas9-10']


##############################################################################
# Locate all record files generated by analyzeCounts in Results

print('Finding result files')

# retrieveResults receives the generating script name, the screen types
# selected for this element, and the current version string
result_files = retrieveResults('analyzeCounts.py',
                               screen_types,
                               current_version)

num_result_files = len(result_files)
print('Result files found: ' + str(num_result_files))


##############################################################################
# Find gene-of-interest in all result files, retrieve information required

# Column positions (presumably indices into each result-file line -- confirm
# against extractLines/extractLikes), listed in ascending order
ID_col = 0
num_col = 6
effect_col = 9
L_stat_col = 10
element_col = 14

# Order expected by extractLikes: L statistic, number, effect, element, ID
cols = (L_stat_col, num_col, effect_col, element_col, ID_col)

# Thresholds given to extractLikes: the user L cutoff, a fixed 1.5, and the
# per-element num_thresh
threshes = (args.L_thresh, 1.5, num_thresh)

print('Extracting lines')
gene2lines = extractLines(args.genes, result_files,
                          args.split_mark, element_col)

print('Extracting elements')
gene2likes, result_errors = extractLikes(gene2lines, cols, threshes,
                                         Cas9, shRNA, args.split_mark)

num_result_errors = len(result_errors)
print('Number of errors in extraction: ' + str(num_result_errors))


##############################################################################
# Select the top screens and top elements for each query gene

gene2element_IDs = {}
missed_genes = []
output = []

for gene in args.genes:

    # Genes with no extracted entries are reported, recorded and skipped
    if gene not in gene2likes:
        print('Warning: ' + gene + ' has no hits')
        missed_genes.append(gene)
        continue

    # Order this gene's entries by their first field (the L statistic),
    # largest first; the list stored in gene2likes is sorted in place
    like_elements = gene2likes[gene]
    like_elements.sort(key=lambda entry: entry[0], reverse=True)

    # One output row per retained screen, capped at num_screens
    for like, elements, result_file, effect, element_effects in \
            like_elements[:args.num_screens]:

        # Pair each retained element with its effect, capped at num_elements
        element_out = [
            ' : '.join([str(element_effect), element])
            for element, element_effect in zip(
                elements[:args.num_elements],
                element_effects[:args.num_elements])
        ]

        output.append([gene, result_file, effect, like] + element_out)


###############################################################################
# Write the collected rows to the output csv file

# Output file is named <name>_<element>_top.csv
out_name = '_'.join([args.name, args.element, 'top.csv'])
header = ['Gene', 'Screen', 'Effect', 'casTLE Score', 'Elements']

with open(out_name, 'w') as out_open:
    out_csv = csv.writer(out_open, delimiter=',', lineterminator='\n')

    # Header first, then every collected row; rows may carry more than one
    # element column after the first four fields
    out_csv.writerow(header)
    out_csv.writerows(output)


###############################################################################
