##############################################################################
# Kimberly Tsui
# 06/29/2015

##############################################################################
# Import necessary modules

import os
import sys
import argparse
import csv
from collections import defaultdict
import random
import re


##############################################################################

current_version = '1.0'


###############################################################################
# Helper function

def revcompl(x):
    """Return the reverse complement of the DNA sequence x.

    Only the uppercase bases A, C, G and T are recognised; any other
    character raises KeyError.
    """
    # Built once per call rather than once per base
    complement = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}

    # Complement every base, then reverse the order
    return ''.join([complement[base] for base in x][::-1])


##############################################################################
# Parses input using argparse module

# Initiates argument parser
parser = argparse.ArgumentParser(description='Find top guides and shRNAs for '
                                               'each query gene')

# Positional (required) arguments: query gene list, guide library file,
# off-target table and output file

parser.add_argument('gene_file', type=str,
                help='tab-delimited file of query genes (first column is used)')
parser.add_argument('align_file', type=str,
                help='comma-delimited guide library with name,sequence rows')
parser.add_argument('off_file', type=str,
                help='tab-delimited off-target table for the library guides')
parser.add_argument('file_out', help='path and name of output files', type=str)

# Optional arguments:
# Specify number of guides or shRNA
parser.add_argument('-g', '--guide_num', help='number of guides or hairpins', default=3, type=int)

# Delimiter separating the fields of an element name in align_file
parser.add_argument('-s', '--split_mark', help='Custom delimiters',
                type=str, default='_')

# Results whose L statistic falls below this value are ignored
parser.add_argument('-l', '--L_thresh', help='Cutoff to keep data',
                type=float, default=10.0)

# Run everything except the final write of file_out
parser.add_argument('-t', '--test', action='store_true',
                help='Run without writing the output file')

parser.add_argument('-sn', '--safe_num', help='Number of safe controls',
                type=int, default=2)


# Saves input to args object
args = parser.parse_args()


# The screen type is inferred from the exact align_file path given on the
# command line; only one library path is recognised
if args.align_file == 'Data/sgRNA_human_10_ref.csv':
    screen_type_all = 'Cas9-10'

else:
    # Passing the message to sys.exit writes it to stderr and exits with a
    # non-zero status, so a calling shell can detect the failure
    sys.exit('Screen not recognized')


##############################################################################
# Read in list of query genes and convert into list

# Query genes, upper-cased, in the order they appear in gene_file
genes = []

with open(args.gene_file, 'r') as gene_open:

    gene_csv = csv.reader(gene_open, delimiter='\t')

    for line in gene_csv:

        # Blank lines parse to an empty row; skip them and empty first
        # columns instead of failing on line[0] or storing an empty gene
        if not line or not line[0].strip():
            continue

        genes.append(line[0].strip().upper())


##############################################################################
# Locate all record files generated by analyzeCounts in Results

# Full paths of every file below the Results directory whose name contains
# '_record.txt'
record_list = []
analyzeCounts_records = []

for root, dirs, files in os.walk('/mnt/lab_data/bassik/All_Screen_Data/Results'):
    record_list.extend(os.path.join(root, fil)
                       for fil in files
                       if '_record.txt' in fil)


##############################################################################
# From record files, retrieve result location and version of analysis script used

# Paths of existing result files whose record matches the current script
# version and the requested screen type
result_file_list = []

# Result paths stored in record files are relative to this directory
data_dir = '/mnt/lab_data/bassik/All_Screen_Data/'

for record_file in record_list:

    with open(record_file, 'r') as record_open:
        rec_csv = csv.reader(record_open, delimiter='\t')

        # The first row holds the generating script and its version
        try:
            script, version = next(rec_csv)

        except StopIteration:
            print('Empty record file: ' + record_file)
            continue

        except ValueError:
            # First row does not have exactly two fields
            print('Warning: Record formating error')
            print(record_file)
            continue

        if not script.startswith('analyzeCounts.py'):
            continue

        if version != current_version:
            print('Record file outdated: ' + record_file)
            continue

        # Reset per record so a record that fails to parse is never matched
        # using values left over from the previous record
        res_file = None
        screen_type = None

        # The remaining rows are read positionally; the value of each is
        # taken from its second column. Most values are only parsed, which
        # checks the record layout, and are not used afterwards.
        try:
            last_time = next(rec_csv)[1]
            unt_file = next(rec_csv)[1]
            trt_file = next(rec_csv)[1]
            zero_files = next(rec_csv)[1]
            if zero_files:
                # NOTE(review): eval executes arbitrary expressions read from
                # the record file; ast.literal_eval would be safer if these
                # values are plain literals - confirm before changing
                zero_files = eval(zero_files)
            res_file = next(rec_csv)[1]
            screen_type = next(rec_csv)[1]
            neg_name = next(rec_csv)[1]
            split_mark = next(rec_csv)[1]
            exclude = next(rec_csv)[1]
            if exclude:
                # NOTE(review): same eval concern as for zero_files above
                exclude = eval(exclude)
            thresh = int(next(rec_csv)[1])
            K = float(next(rec_csv)[1])
            back = next(rec_csv)[1]
            I_step = float(next(rec_csv)[1])
            scale = int(next(rec_csv)[1])
            draw_num = int(next(rec_csv)[1])

        except (ValueError, IndexError, StopIteration):
            # Bad number, row with fewer than two columns, or truncated file
            print('Warning: Record formating error')
            print(record_file)

        # A formatting error in the later rows is only a warning, but the
        # record is unusable if the result path or screen type was not reached
        if res_file is None or screen_type is None:
            continue

        if screen_type != screen_type_all:
            continue

        res_path = data_dir + res_file + '.csv'

        if not os.path.exists(res_path):
            print('Warning: Result file not found.')
            print(res_path)
            continue

        result_file_list.append(res_path)

print('Result files found: ' + str(len(result_file_list)))


##############################################################################
# Find gene-of-interest in all result files, retrieve information required

# Maps each matched query gene to a list of (L statistic, ordered guide IDs)
# tuples, one tuple per result file row that matched
gene2like_guides = defaultdict(list)

# Set copy of the query genes for constant-time membership tests
query_genes = set(genes)

# Index of the first per-guide column; every later column is also a guide
guide_col = 12

for result_file in result_file_list:

    with open(result_file, 'r') as result_open:

        result_csv = csv.reader(result_open, delimiter=',')

        # Skips header (an empty file simply yields no rows)
        next(result_csv, None)

        for line in result_csv:

            if not line:
                continue

            try:
                geneID = line[0]
                name = line[1]
                effect = float(line[7])
                Lstat = float(line[8])
                # Requires at least one guide column to be present
                guide = line[guide_col]

            except (IndexError, ValueError):
                # Too few columns or a non-numeric statistic: stop reading
                # this result file
                print('Result formatting error: ' + result_file)
                break

            if Lstat < args.L_thresh:
                continue

            # Match on gene ID first, then on gene name. Query genes are
            # stored upper-cased, so the upper-cased form is used as the key.
            if geneID.upper() in query_genes:
                gene = geneID.upper()

            elif name.upper() in query_genes:
                gene = name.upper()

            else:
                continue

            effects_guides = []

            # Each guide column has the form '<effect>:<element name>'; the
            # guide ID is the last '_'-separated field of the element name
            for guide in line[guide_col:]:

                # Tolerate empty cells, e.g. from padded rows
                if not guide:
                    continue

                guide_effect, guide_name = guide.split(':')
                guideID = guide_name.split('_')[-1].strip()

                effects_guides.append((float(guide_effect), guideID))

            # Order guides so the strongest effect in the direction of the
            # gene-level effect comes first
            if effect > 0:
                effects_guides.sort(key=lambda x: x[0], reverse=True)

            elif effect < 0:
                effects_guides.sort(key=lambda x: x[0], reverse=False)

            else:
                # No direction to rank by
                continue

            # Drop guides whose effect magnitude exceeds 1.5 times the
            # gene-level effect
            guides = [guideID for guide_effect, guideID in effects_guides
                      if not abs(guide_effect) > 1.5 * abs(effect)]

            gene2like_guides[gene].append((Lstat, guides))


##############################################################################
# Read the guide library and the off-target table

# Lookup tables filled from the guide library (align_file)
guideID2name_seq = {}            # guide ID -> (element name, full sequence)
name2seq = {}                    # element name -> full sequence
gene2names = defaultdict(list)   # upper-cased gene field -> element names
ens2names = defaultdict(list)    # first name field -> element names
guide2name = {}                  # trimmed sequence -> element name

with open(args.align_file, 'r') as align_open:
    align_csv = csv.reader(align_open, delimiter=',')

    # Every row must hold exactly an element name and its sequence
    for name, seq in align_csv:

        # The element name must split into exactly four fields
        ens, gene, cat, guideID = name.split(args.split_mark)

        name2seq[name] = seq
        guideID2name_seq[guideID] = (name, seq)

        gene2names[gene.upper()].append(name)
        ens2names[ens].append(name)

        # Keyed by the sequence without its first 11 and last 17 characters
        guide2name[seq[11:-17]] = name

# Guide IDs rejected because of their off-target counts
off_filter = []

with open(args.off_file, 'r') as off_open:

    off_csv = csv.reader(off_open, delimiter='\t', lineterminator = '\n')

    for line in off_csv:
        # Each row must have exactly 17 columns; only guide and off_all are
        # used below
        chrom, start, end, strand, guide, gene, score, effect, pval, pval, rho, \
                off_exon, off_all, bah, bah2, off_pos_all, off_trans_all = line

        # Translate the guide sequence back to its element name, then take
        # the last '_'-separated field as the guide ID
        name = guide2name[guide]
        guideID = name.rsplit('_', 1)[-1]

        # off_all holds exactly five comma-separated integers
        # NOTE(review): presumably hit counts at 0 to 4 mismatches - confirm
        off_all_int = [int(count) for count in off_all.split(',')]
        mis0, mis1, mis2, mis3, mis4 = off_all_int

        # Reject the guide when either of the first two counts is non-zero
        if mis0 or mis1:
            off_filter.append(guideID)


##############################################################################
# Select up to guide_num guides for each query gene that has screen hits

# Guide IDs chosen from screen results, and the query genes with no results
all_guideIDs = []
missed_genes = []

for gene in genes:

    if gene in gene2like_guides:

        # Use the guide list from the entry with the highest L statistic
        ranked = gene2like_guides[gene]
        ranked.sort(key=lambda entry: entry[0], reverse=True)
        candidates = ranked[0][1]

        kept = 0

        for guideID in candidates:

            # Off-target guides are reported and skipped, even after the
            # quota for this gene has been reached
            if guideID in off_filter:
                print('Off-target guide found')

            elif kept == args.guide_num:
                break

            else:
                all_guideIDs.append(guideID)
                kept += 1

    else:
        missed_genes.append(gene)

print('Out of ' + str(len(genes)) + ' genes, ' +
        str(len(missed_genes)) + ' have no corresponding hits')


###############################################################################
# Output containers and the constant sequences added around each guide

# Rows of [oligo name, oligo sequence] collected for the output file
output_fwd = []
output_rev = []

# NOTE(review): fwd_up and fwd_down are not referenced anywhere else in this
# script. Their lengths (11 and 17) match the seq[11:-17] slice used to
# extract the guide, so they presumably record the library flanks - confirm.
fwd_up = 'CCACCTTGTTG'
fwd_down = 'GTTTAAGAGCTAAGCTG'

# Prefix and suffix placed around the guide sequence in each forward oligo
up_fwd = 'ttg'
down_fwd = 'gtttaagagc'

# Prefix and suffix placed around the reverse-complemented guide sequence in
# each reverse oligo
up_rev = 'ttagctcttaaac'
down_rev = 'caacaag'


###############################################################################
# Build forward and reverse oligos for the guides selected from screen results

for guideID in all_guideIDs:

    if guideID not in guideID2name_seq:
        # Skip guides missing from the library; falling through would emit
        # the name and sequence left over from an earlier iteration
        print('Warning: ' + guideID + ' not found')
        continue

    name, seq = guideID2name_seq[guideID]

    # Drop the first 11 and last 17 characters of the library sequence
    seq_raw = seq[11:-17]

    output_fwd.append([name + '_fwd', up_fwd + seq_raw + down_fwd])
    output_rev.append([name + '_rev', up_rev + revcompl(seq_raw) + down_rev])


###############################################################################
# For genes without screen hits, fall back to the first guides in the library

for gene in missed_genes:

    # A query may match the gene field of an element name or its first field.
    # NOTE(review): query genes are upper-cased but ens2names keys are not,
    # so the second lookup only matches identifiers that are already upper
    # case - confirm this is intended.
    if gene in gene2names:
        names = gene2names[gene]

    elif gene in ens2names:
        names = ens2names[gene]

    else:
        print('Warning: ' + gene + ' not found')
        continue

    # Duplicate element names mean the library file is inconsistent. Abort
    # with an explanatory message and a non-zero exit status.
    if len(names) != len(set(names)):
        sys.exit('Error: duplicate element names for ' + gene + ': ' +
                 str(names))

    if len(names) < args.guide_num:
        print('Warning: ' + gene + ' has too few guides')
        continue

    # Take the first guide_num elements in library order
    for name in names[:args.guide_num]:

        seq = name2seq[name]

        # Drop the first 11 and last 17 characters of the library sequence
        seq_raw = seq[11:-17]

        output_fwd.append([name + '_fwd', up_fwd + seq_raw + down_fwd])
        output_rev.append([name + '_rev', up_rev + revcompl(seq_raw) + down_rev])


###############################################################################
# Add randomly chosen safe controls

# Library elements whose gene field is 'SAFE', shuffled in place so that the
# first safe_num of them form a random selection
safe_names = gene2names['SAFE']
random.shuffle(safe_names)

for name in safe_names[:args.safe_num]:

    # Drop the first 11 and last 17 characters of the library sequence
    seq_raw = name2seq[name][11:-17]

    fwd_oligo = up_fwd + seq_raw + down_fwd
    rev_oligo = up_rev + revcompl(seq_raw) + down_rev

    output_fwd.append([name + '_fwd', fwd_oligo])
    output_rev.append([name + '_rev', rev_oligo])


###############################################################################
# Interleave forward and reverse oligos in groups of 12

# Final row order: twelve forward oligos, then the matching twelve reverse
# oligos, repeated until both lists are exhausted
output = []

# Stepping the range avoids the '/' division, whose float result is rejected
# by range() when true division is in effect
for start in range(0, len(output_fwd), 12):
    output.extend(output_fwd[start:start + 12])
    output.extend(output_rev[start:start + 12])

print('Total elements: ' + str(len(output)))


###############################################################################
# Write the oligos to the output file unless this is a test run

# Test runs stop here, before anything is written to disk
if args.test:
    sys.exit('Warning: Files not written')

# One comma-separated row per oligo: name, sequence
with open(args.file_out, 'w') as out_open:
    csv.writer(out_open, delimiter=',', lineterminator='\n').writerows(output)


###############################################################################
