##############################################################################
# David Morgens
##############################################################################
# Import necessary modules

import os
import sys
import argparse
import csv
from collections import defaultdict
import random
import re


##############################################################################
#

def revcompl(x):
    '''
    Returns the reverse complement of a DNA sequence.

    Only uppercase A, C, G and T are recognised; any other character
    raises KeyError.
    '''
    pairs = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}
    complement = ''.join(pairs[base] for base in x)
    return complement[::-1]


###############################################################################
#

def retrievePrimers():

    '''
    Returns the built-in list of sub-library primer pairs.

    Each entry is a tuple (name, (primer_1, primer_2)) of DNA sequence
    strings. The list is returned in sub-library order; a new list is
    built on every call.
    '''

    # NOTE(review): every entry from sublib_41 onward carries the name
    # prefix 'oMCB1725' (the same as sublib_41); confirm whether these
    # oligo names are intended or a copy-paste slip.
    primers = [
        ('oMCB1685_sublib_1', ('CTTAAACCGGCCAACATACC', 'ATGCTACTCGTTCCTTTCGA')),
        ('oMCB1686_sublib_2', ('TGCTCTTTATTCGTTGCGTC', 'TCTTATCGGTGCTTCGTTCT')),
        ('oMCB1687_sublib_3', ('TGAGCCTTATGATTTCCCGT', 'GTCCGTTTTCCTGAATGAGC')),
        ('oMCB1688_sublib_4', ('CGTTCTAAACGGCTAGATGC', 'AGTCTGTCTTTCCCCTTTCC')),
        ('oMCB1689_sublib_5', ('GTATCCGAAGCGTGGAGTAT', 'CAGGTATGCGTAGGAGTCAA')),
        ('oMCB1690_sublib_6', ('CCAAAGATTCAACCGTCCTG', 'ATTAGCCATTTCAGGACGGA')),
        ('oMCB1691_sublib_7', ('TATTCATGCTTGGACGGACT', 'ACTATGTACCGCTTGTTGGA')),
        ('oMCB1692_sublib_8', ('ATCGACAATGGTATGGCTGA', 'TATGTCTCCTAGCCACTCCT')),
        ('oMCB1693_sublib_9', ('GTCCTAGTGAGGAATACCGG', 'CCGAAGAATCGCAGATCCTA')),
        ('oMCB1694_sublib_10', ('TTAGATAGGTGTGTAGGCGC', 'TAAGGTGCGTACTAGCTGAC')),
        ('oMCB1695_sublib_11', ('TTCCGTTTATGCTTTCCAGC', 'TCCTTGGAGTTTAGAGCGAG')),
        ('oMCB1696_sublib_12', ('GTATAGTTTGTGCGGTGGTC', 'ATCAATCCCCTACACCTTCG')),
        ('oMCB1697_sublib_13', ('TCAGCCTTTCATTGATTGCG', 'TTCCTTGATACCGTAGCTCG')),
        ('oMCB1698_sublib_14', ('AGGGTCGTGGTTAAAGGTAC', 'CGTTTCTTTCCGGTCGTTAG')),
        ('oMCB1699_sublib_15', ('TGCAAGTGTACAAATCCAGC', 'GAACGGTGATCCCTTTCCTA')),
        ('oMCB1700_sublib_16', ('CTTAAGGTTTGCCCATTCCC', 'TGTTATAGCTTCCACGGTGT')),
        ('oMCB1701_sublib_17', ('TGGTTCGTTAGTCGATCTCC', 'AGACGGGATTTTACTGGGTC')),
        ('oMCB1702_sublib_18', ('TATTTTGTAGAGCGTTCGCG', 'TCTTTGCTTCGCAAGTCTTG')),
        ('oMCB1703_sublib_19', ('TTCTGTAAGTTTCGTCGGGA', 'CTAAACACCGCACCTCACTA')),
        ('oMCB1704_sublib_20', ('TTGACGTACGTAGGTTCTCC', 'GAACACAACTACACTGACGC')),
        ('oMCB1705_sublib_21', ('GAGATGAGTAGACGAGTGGG', 'ATGGTCACTGACTCGCATTA')),
        ('oMCB1706_sublib_22', ('CTTTGGGCTTTCAGATGAGC', 'CAAAGATTTCTGTCGGTCGG')),
        ('oMCB1707_sublib_23', ('TGTCATATGCTAACGTCCGT', 'TGGCTACTTTCTTAGCGGAA')),
        ('oMCB1708_sublib_24', ('TTGCGACATCACAATTCTCG', 'TACTTCGAGACTTCATGCGT')),
        ('oMCB1709_sublib_25', ('TCAGTATGGCGTCTTGAAGT', 'ATGGCCCGACCTCTATTATG')),
        ('oMCB1710_sublib_26', ('TCATGTCGTGACCAGTAGAC', 'TGGGTCTAGTGAACTTCGTC')),
        ('oMCB1711_sublib_27', ('AACTAACGGATTTAAGCGCG', 'AACATATGTTGCTTCGTCCG')),
        ('oMCB1712_sublib_28', ('CATTTTCTGTTCCCCAGTGG', 'TCGAGTTAGATTGTCACCCC')),
        ('oMCB1713_sublib_29', ('ATTTGCCTAACCACTCCACT', 'TCAGAGCTTTTCGGTACAGT')),
        ('oMCB1714_sublib_30', ('TGACTTATGAACCTTTGCGC', 'GCCCAGGAGTAGTCGTTAAT')),
        ('oMCB1715_sublib_31', ('ATAGGATTAGCTGATGGGCC', 'TCTGTGTTCCGACTAAGGTC')),
        ('oMCB1716_sublib_32', ('TGAGATTCGGGACTATTCGG', 'TCTGTTGTTAGACTCCGACC')),
        ('oMCB1717_sublib_33', ('TTGGTTAGTACACGGGACTC', 'GTACGTCTGAACTTGGGACT')),
        ('oMCB1718_sublib_34', ('ATTTGTGTATCGAGGCTCGT', 'AGACACGCGATTGTTTAACC')),
        ('oMCB1719_sublib_35', ('ATCGTTCCCCATCACATTCT', 'CCGTTCGTTTTGAGCACTTA')),
        ('oMCB1720_sublib_36', ('ATTACCATGTTATCGGGCGA', 'AGGTTAGGGAACGCAAGATT')),
        ('oMCB1721_sublib_37', ('TCGGTGGATATGACGTAACC', 'CCAGACTGTGCTCGTTATCT')),
        ('oMCB1722_sublib_38', ('GGTCAGATGGTTTACATGCG', 'AGTTGTTCTCTATCCGCGAT')),
        ('oMCB1723_sublib_39', ('TCTCGTTCGAAAATCATCGC', 'GATTAAATCTCGCCGGTGAC')),
        ('oMCB1724_sublib_40', ('TGCAAATGTGAGGTAGCAAC', 'TTGTAGTTTTCGCTTGCGTT')),
        ('oMCB1725_sublib_41', ('AAAGTCAAAGTGCGTTTCGT', 'TGTGTTGCTCTCTCATAGCC')),
        ('oMCB1725_sublib_42', ('GCTTATTCGTGCCGTGTTAT', 'TACTTTTGATTGCTGTGCCC')),
        ('oMCB1725_sublib_43', ('TTTGCTTCAGTCAGATTCGC', 'GTTCAATCACTGAATCCCGG')),
        ('oMCB1725_sublib_44', ('GTCGAGTCCTATGTAACCGT', 'CAGGGGTCGTCATATCTTCA')),
        ('oMCB1725_sublib_45', ('GTAAGATGGAAGCCGGGATA', 'CACCTCATAGAGCTGTGGAA')),
        ('oMCB1725_sublib_46', ('GGTGTCGCAACATGATCTAC', 'CGGTTCCTAGTCATGTTTGC')),
        ('oMCB1725_sublib_47', ('GTGCTAAGTCACACTGTTGG', 'TTGTACTAATCTCGTCCCGG')),
        ('oMCB1725_sublib_48', ('TCTAAACAGTTAGGCCCAGG', 'TTATGTTCACAACTGGCGTG')),
        ('oMCB1725_sublib_49', ('GTCTTTATACTTGCCTGCCG', 'TGGAACTGATTTGGCCTTTG')),
        ('oMCB1725_sublib_50', ('CACCGCGATCAATACAACTT', 'TATAGTTCCTCCCATGCACC')),
        ('oMCB1725_sublib_51', ('TTCGGATAGACTCAGGAAGC', 'ACAATAGACAGACCCATGCA')),
        ('oMCB1725_sublib_52', ('CCATTGATAGATTCGCTCGC', 'GAGTCGAGCTAGCATAGGAG')),
        ('oMCB1725_sublib_53', ('TTTTCTACTTTCCGGCTTGC', 'TTGTGGGAGCTTCTTACCAT')),
        ('oMCB1725_sublib_54', ('ATGACTATTGGGGTCGTACC', 'TCGTACGGGAATGACCATAG')),
        ('oMCB1725_sublib_55', ('TCGACAATAGTTGAGCCCTT', 'AGACACAACGTAGCCGATTA')),
        ('oMCB1725_sublib_56', ('GAGCCATGTGAAATGTGTGT', 'CGGACTAAAGGATCGAGTCA')),
        ('oMCB1725_sublib_57', ('CGTATACGTAAGGGTTCCGA', 'CATCGGATAACACAAAGCGT')),
        ('oMCB1725_sublib_58', ('TTATGATGTCCGGATACCCG', 'GATGTATACTCCACCGTGGT')),
        ('oMCB1725_sublib_59', ('TCTTAGAAATCCACGGGTCC', 'TGAGATATGTACCTGGTGCC')),
        ('oMCB1725_sublib_60', ('GAAGGGTGGATCATCGTACT', 'ATTCTTGGGCCTATCGTTGT')),
        ('oMCB1725_sublib_61', ('GGCTGTTAGTTTTAGAGCCG', 'AAACCATATACAGCCGTCGT')),
        ('oMCB1725_sublib_62', ('AGTGGTGTAGTGGCTTCTAC', 'TAGCTAAATCCCACCCGATG')),
        ('oMCB1725_sublib_63', ('CTCAGAGGGAGTTCAACTGT', 'GTGCGGTTACAGTTTTGACT')),
        ('oMCB1725_sublib_64', ('TTTGGCAGATCATTAACGGC', 'GGGACTACATAGGGTGACAG')),
        ('oMCB1725_sublib_65', ('TATGATCTCCGTACACGAGC', 'CGTTGTCGTTCCAAAGAAGT')),
        ('oMCB1725_sublib_66', ('AGTGCCATGTTATCCCTGAA', 'AGTCACACATATACGGACCC')),
        ('oMCB1725_sublib_67', ('TTATACATCTGGACGCCTCC', 'AGAGAACCCCTATTATGGCG')),
        ('oMCB1725_sublib_68', ('TCCTCGATTCTCCAATCAGG', 'TCGTTAGGCTAAAACATGCG')),
        ('oMCB1725_sublib_69', ('GCTTAACGCATTTCAAGCAC', 'TGATAGGTCGTTCAGCCTAC')),
        ('oMCB1725_sublib_70', ('CTTTTATGTTCCTCGCAGGG', 'TCGGGACTTTCATAAGCACT')),
        ('oMCB1725_sublib_71', ('GTGGGCGTTAGCAAATTACA', 'ATTTTATGCGTCCAGTTCGG')),
        ('oMCB1725_sublib_72', ('AGAGATTATTAGGCGTGGGG', 'AAGGCTGGTATTTCCCTTCA')),
        ('oMCB1725_sublib_73', ('TAGGATTACTGCTCGGTGAC', 'CATACTGTTGGTTGCTAGGC')),
        ('oMCB1725_sublib_74', ('TCGCGTGAGTGGTTCATATA', 'ATATACTGGATTCCGCCGTT')),
        ('oMCB1725_sublib_75', ('CAATAGATACCCACCCGTCA', 'ACTTATGAACCCTTGGCACT')),
        ('oMCB1725_sublib_76', ('ATATATCCGCCGTTGTACGT', 'ATAGATGTATGCCGTTCGGT')),
        ('oMCB1725_sublib_77', ('CGAGAGTCTCCCACGATATC', 'TCTCTGTTTTCCGCACTTTG')),
        ('oMCB1725_sublib_78', ('ATTCAGTTGGTCTTACGGGT', 'AGTTATTCGTCTTTCCCGGT')),
        ('oMCB1725_sublib_79', ('GGATTGCAACGTCAGGAAAT', 'TACAGGAATCTCCACGAAGC')),
        ('oMCB1725_sublib_80', ('GAATGTTGCAGACTGGAAGG', 'CCTCGGGCTTGTTACTAGAT')),
        ('oMCB1725_sublib_81', ('GTCCATGAATACAACACCGG', 'ATTCTTCCGTCCAACGTACT')),
        ('oMCB1725_sublib_82', ('TCGAACAATTTGCGATACCC', 'TAATCATACGAGTGGGCCTC')),
        ('oMCB1725_sublib_83', ('AAGTGCACATTTCGTTTCGA', 'AGTTGGTAGAATTGACCGGT'))
        ]

    return primers


###############################################################################
# Function which retrieves info about genes

def retrieveInfo(ref_base='/mnt/lab_data/bassik/All_Screen_Data/GenRef', mouse=False):

    '''
    Retrieves gene info for the screen type. Location of reference
    files can be changed, defaults to nearby GenRef folder.

    Files read from ref_base:
      ensRef.csv             - comma separated; column 0 is a gene name,
                               column 1 an Ensembl ID
      Mus_musculus.gene_info - tab separated with a header row; always read
      Homo_sapiens.gene_info - tab separated with a header row; read only
                               when mouse is False, after the mouse file,
                               so its entries overwrite shared keys

    A missing file prints a message and is skipped. Blank rows are ignored.

    Returns (geneID2Name, geneID2Info, geneName2ID, geneEns2Name), all
    defaultdicts that yield 'N/A' for unknown keys:
      geneID2Name  - Entrez ID -> upper-case symbol (gene_info column 2)
      geneID2Info  - Entrez ID -> gene_info column 8
      geneName2ID  - upper-case symbol -> Entrez ID
      geneEns2Name - Ensembl ID -> upper-case gene name
    '''

    # Finds info files downloaded from NCBI
    org_file_human = os.path.join(ref_base, 'Homo_sapiens.gene_info')
    org_file_mouse = os.path.join(ref_base, 'Mus_musculus.gene_info')

    # Custom Ensembl ID to gene name file
    ens_file = os.path.join(ref_base, 'ensRef.csv')

    geneID2Name = defaultdict(lambda: 'N/A')
    geneID2Info = defaultdict(lambda: 'N/A')
    geneName2ID = defaultdict(lambda: 'N/A')
    geneEns2Name = defaultdict(lambda: 'N/A')

    # Reads in Ensembl data
    try:
        with open(ens_file, 'r') as ens_open:
            for line in csv.reader(ens_open, delimiter=','):
                if not line:  # csv yields [] for blank lines
                    continue
                geneEns2Name[line[1]] = line[0].upper()

    except IOError:
        print('Ensembl information file not found.\n'
                + 'Use -r to change file location')

    # Mouse is always loaded; human (unless mouse=True) is loaded second,
    # so a symbol present in both ends up mapped to the human ID.
    organisms = [(org_file_mouse, 'Mouse')]
    if not mouse:
        organisms.append((org_file_human, 'Human'))

    for org_file, label in organisms:
        try:
            _readGeneInfo(org_file, geneID2Name, geneID2Info, geneName2ID)

        except IOError:
            print(label + ' information file not found.\n'
                    + 'Use -r to change file location')

    return geneID2Name, geneID2Info, geneName2ID, geneEns2Name


def _readGeneInfo(org_file, geneID2Name, geneID2Info, geneName2ID):

    '''
    Loads one tab-separated gene_info file into the three given dicts,
    skipping its header row and any blank rows. Raises IOError if the
    file cannot be opened.
    '''

    with open(org_file, 'r') as org_open:

        org_csv = csv.reader(org_open, delimiter='\t')
        next(org_csv, None)  # Skips header; tolerates an empty file

        for line in org_csv:
            if not line:
                continue
            # Entrez
            geneID2Name[line[1]] = line[2].upper()
            geneID2Info[line[1]] = line[8]
            geneName2ID[line[2].upper()] = line[1]


###############################################################################
# Converts IDs

def retrieveIDs(gene, geneID2Name, geneName2ID, geneEns2Name):

    '''
    Resolves an identifier to a (geneID, name, entrez) triple.

    The identifier is tried, in order, as a gene ID (key of geneID2Name),
    a gene name (key of geneName2ID) and an Ensembl ID (key of
    geneEns2Name). An identifier matching none of them is returned
    unchanged in all three positions. For an Ensembl ID whose gene name
    is not in geneName2ID, entrez is 'N/A'.
    '''

    if gene in geneID2Name:
        return gene, geneID2Name[gene], gene

    if gene in geneName2ID:
        geneID = geneName2ID[gene]
        return geneID, gene, geneID

    if gene in geneEns2Name:
        name = geneEns2Name[gene]
        # Test membership instead of indexing so a defaultdict passed in is
        # not grown with placeholder entries; 'N/A' matches the default
        # used by retrieveInfo.
        entrez = geneName2ID[name] if name in geneName2ID else 'N/A'
        return gene, name, entrez

    return gene, gene, gene


###############################################################################
#

def dissectFolder(gene_file, name):

    '''
    Derives output and legend paths from a gene file at
    <chip_loc>/<subfolder>/<file>.

    Returns (out_file, legend_file) where out_file is
    <chip_loc>/Align/<name> and legend_file is
    <chip_loc>/<chip folder name>_legend.csv.
    '''

    # Two levels up from the gene file is the chip folder.
    chip_loc = os.path.dirname(os.path.dirname(gene_file))
    chip_name = os.path.basename(chip_loc)

    legend_file = os.path.join(chip_loc, chip_name + '_legend.csv')
    out_file = os.path.join(chip_loc, 'Align', name)

    return out_file, legend_file


###############################################################################
#

def retrieveResults(script_all, screen_types, current_version,
                        path='/mnt/lab_data/bassik/All_Screen_Data/Results'):

    '''
    Collects result CSV files produced by a given script and screen type.

    Walks path for files whose name contains '_record.txt', skipping a fixed
    set of excluded directories, and parses each record. A record is kept
    when its script name starts with script_all, its version equals
    current_version and its screen type is in screen_types. The result file
    it names is resolved relative to the parent of path with '.csv'
    appended, and is returned only if it exists.

    Empty, malformed (truncated, short rows, non-numeric fields),
    wrong-version and missing-result records print a warning and are
    skipped.

    Returns a list of result file paths.
    '''

    excluded = ('DM/Growth', 'Results/DB', 'DM/Parnas2015', 'DM/Bassiketal')

    record_files = []
    for root, dirs, files in os.walk(path):
        if any(skip in root for skip in excluded):
            continue
        record_files.extend(os.path.join(root, fil)
                            for fil in files if '_record.txt' in fil)

    result_files = []
    for record_file in record_files:

        parsed = _parseRecord(record_file, script_all, current_version)
        if parsed is None:
            continue

        screen_type, res_file = parsed
        if screen_type not in screen_types:
            continue

        full_res_file = os.path.join(os.path.dirname(path), res_file + '.csv')

        if not os.path.exists(full_res_file):
            print('Warning: Result file not found: ' + full_res_file)
            continue

        result_files.append(full_res_file)

    return result_files


def _parseRecord(record_file, script_all, current_version):

    '''
    Parses one tab-separated record file.

    Returns (screen_type, res_file), or None when the record should be
    skipped. A warning is printed for empty, malformed or wrong-version
    records; a record from another script is skipped silently.
    '''

    # Rows after the (script, version) header, in file order, with the value
    # in column 1. Numeric fields are converted only to validate the record.
    fields = (('last_time', None), ('unt_file', None), ('trt_file', None),
              ('time_zero', None), ('res_file', None), ('screen_type', None),
              ('neg_name', None), ('split_mark', None), ('exclude', None),
              ('thresh', int), ('K', float), ('back', None),
              ('I_step', float), ('scale', int), ('draw_num', int))

    with open(record_file, 'r') as record_open:
        record_csv = csv.reader(record_open, delimiter='\t')

        try:
            header = next(record_csv)
        except StopIteration:
            print('Warning: Empty record file: ' + record_file)
            return None

        try:
            script, version = header
        except ValueError:
            print('Warning: Record formatting error: ' + record_file)
            return None

        if not script.startswith(script_all):
            return None

        if version != current_version:
            print('Incorrect version (' + version + '): ' + record_file)
            return None

        record = {}
        try:
            for field, convert in fields:
                value = next(record_csv)[1]
                record[field] = convert(value) if convert else value

        # Truncated record, row without a value column, or non-numeric value
        except (StopIteration, IndexError, ValueError):
            print('Warning: Record formatting error: ' + record_file)
            return None

    return record['screen_type'], record['res_file']


###############################################################################
#

def extractLines(genes, result_files, split_mark, element_col):

    '''
    Groups result-file rows by the requested gene they belong to.

    Each result CSV is read after skipping its header row; blank rows are
    ignored. A row is assigned when, checked in this order:
      - its column-0 ID is in genes,
      - its upper-cased column-1 name is in genes, or
      - the upper-cased prefix (text before split_mark) of the element name
        in column element_col is in genes.
    The element column must have the form "<effect>: <name>".

    Returns a defaultdict(list) mapping the matched key to a list of
    (result_file, row) tuples. NOTE: the key is the ID, name or element
    prefix exactly as it appears in the file, not its upper-cased form.

    A row that is too short or has a malformed element prints a warning and
    stops processing of the rest of that file.
    '''

    gene_set = set(genes)  # O(1) membership tests per row
    gene2lines = defaultdict(list)

    for result_file in result_files:

        with open(result_file, 'r') as result_open:

            result_csv = csv.reader(result_open, delimiter=',')

            # Skips header; tolerates an empty file
            next(result_csv, None)

            for line in result_csv:

                if not line:
                    continue

                try:
                    geneID = line[0]
                    name = line[1]
                    element_effect, element_name = line[element_col].split(': ')

                except (IndexError, ValueError):
                    print('Warning: Result formatting error: ' + result_file)
                    break

                if geneID in gene_set:
                    key = geneID

                elif name.upper() in gene_set:
                    key = name

                else:
                    ID_from_element = element_name.split(split_mark)[0]
                    if ID_from_element.upper() not in gene_set:
                        continue
                    key = ID_from_element

                gene2lines[key].append((result_file, line))

    return gene2lines


###############################################################################
#

def extractElements(elements, geneID, effect, Cas9, shRNA, split_mark):

    '''
    Parses "<effect>:<name>" element strings into (effect, element_ID) pairs.

    Cas9:  element_ID is the last split_mark-separated piece of the name,
           stripped of surrounding whitespace.
    shRNA: element_ID is geneID joined by split_mark to the integer formed
           from the digits of the last split_mark-separated piece (so
           leading zeros are dropped).

    Pairs are sorted in descending order of effect when effect > 0,
    ascending when effect < 0, and left in input order when effect == 0.

    Raises ValueError if an element is processed with neither Cas9 nor
    shRNA set.
    '''

    effects_guides = []

    for element in elements:

        # Split on the first ':' only so names containing ':' still parse.
        element_effect, element_name = element.split(':', 1)
        last_piece = element_name.split(split_mark)[-1]

        if Cas9:
            element_ID = last_piece.strip()

        elif shRNA:
            # Build the digit string explicitly; filter() returns an
            # iterator rather than a str on Python 3.
            digits = ''.join(c for c in last_piece if c.isdigit())
            element_ID = split_mark.join([geneID, str(int(digits))])

        else:
            raise ValueError('Element type unknown: set Cas9 or shRNA')

        effects_guides.append((float(element_effect), element_ID))

    # Order guides and hairpins by effect in the direction of the gene effect
    if effect > 0:
        effects_guides.sort(key=lambda x: x[0], reverse=True)

    elif effect < 0:
        effects_guides.sort(key=lambda x: x[0])

    return effects_guides


###############################################################################
#

def extractLikes(gene2lines, cols, threshes, Cas9, shRNA, split_mark):

    '''
    Filters result rows per gene and collects their element effects.

    gene2lines maps a gene key to a list of (result_file, row) tuples.
    cols is (L_stat_col, num_col, effect_col, element_col, ID_col); the
    element entries run from element_col to the end of the row.
    threshes is (L_thresh, off_thresh, num_thresh).

    A row is kept when its L statistic is at least L_thresh and its num
    column value is at least num_thresh. Within a kept row, elements whose
    absolute effect exceeds off_thresh times the absolute gene effect are
    dropped.

    Returns (gene2likes, result_errors):
      gene2likes    - defaultdict(list) mapping each gene key to a list of
                      (L_stat, element_IDs, result_file, effect,
                      element_effects) tuples
      result_errors - (message, result_file) tuples for rows whose fields
                      could not be read or converted; such rows are skipped
    '''

    gene2likes = defaultdict(list)
    result_errors = []

    L_thresh, off_thresh, num_thresh = threshes
    L_stat_col, num_col, effect_col, element_col, ID_col = cols

    for gene, lines in gene2lines.items():

        for result_file, line in lines:

            try:
                geneID = line[ID_col]
                L_stat = float(line[L_stat_col])
                num = int(line[num_col])
                effect = float(line[effect_col])

            except (IndexError, ValueError):
                # Record the malformed row and carry on with the others
                # instead of aborting the whole run.
                result_errors.append(('Warning: Result formatting error: ', result_file))
                continue

            if L_stat < L_thresh or num < num_thresh:
                continue

            effects_elements = extractElements(line[element_col:], geneID, effect,
                                               Cas9, shRNA, split_mark)

            kept_IDs = []
            kept_effects = []

            for element_effect, element_ID in effects_elements:

                # Keep only elements within off_thresh times the gene effect
                if abs(element_effect) > off_thresh * abs(effect):
                    continue

                kept_IDs.append(element_ID)
                kept_effects.append(element_effect)

            gene2likes[gene].append((L_stat, kept_IDs, result_file, effect, kept_effects))

    return gene2likes, result_errors


###############################################################################
#

def pickPrimer(restriction_sites, name2primer, context_up, context_down):

    '''
    Picks the first unused sub-library primer pair that creates no
    restriction site with the flanking context.

    Each entry of restriction_sites is a regular expression. For every
    primer pair from retrievePrimers() not already among name2primer's
    values, up_primer is the pair's first primer and down_primer the
    reverse complement of its second. The pair is accepted when no pattern
    matches up_primer + context_up or context_down + down_primer.

    Returns (up_primer, down_primer, primer), where primer is the full
    (name, (primer_1, primer_2)) entry. Exits the program via sys.exit
    if no suitable pair exists.
    '''

    # Compile each pattern once rather than once per primer pair.
    site_res = [re.compile(site) for site in restriction_sites]
    used = list(name2primer.values())

    for primer in retrievePrimers():

        if primer in used:
            continue

        up_primer, down_primer = primer[1][0], revcompl(primer[1][1])

        example_up = up_primer + context_up
        example_down = context_down + down_primer

        if not any(site_re.search(example_up) or site_re.search(example_down)
                   for site_re in site_res):
            return up_primer, down_primer, primer

    sys.exit('Error: Free primer pair not found')


###############################################################################
#

def retrieveNegatives(neg_name, gene2names, name2seq, number):

    '''
    Draws up to `number` random negative-control elements.

    A shuffled copy of gene2names[neg_name] is made, and the first
    `number` names are paired with their sequences from name2seq. The
    caller's list is not reordered.

    Returns (non_seqs, num): a list of [name, seq] pairs and its length.
    '''

    # Shuffle a copy so gene2names is not reordered as a side effect.
    neg_names = list(gene2names[neg_name])
    random.shuffle(neg_names)

    non_seqs = [[name, name2seq[name]] for name in neg_names[:number]]

    return non_seqs, len(non_seqs)


#########################################################################
# Calculates the GC content of an input string

def getGC(guide):
    '''
    Returns the fraction of characters in guide that are G or C (either
    case). An empty guide gives 0.0 instead of dividing by zero.
    '''
    if not guide:
        return 0.0
    numGC = sum(1 for nuc in guide if nuc in ('C', 'G', 'c', 'g'))
    return float(numGC) / len(guide)


###############################################################################
# Adds Gs

def addG(guide):
    '''
    Returns guide upper-cased and starting with 'G', prepending one when
    it does not already start with G. An empty guide becomes 'G'.
    '''
    upper = guide.upper()
    return upper if upper.startswith('G') else 'G' + upper


###############################################################################
