In [ ]:
import os

# Restrict this notebook to GPUs 3 and 5. This must run before any CUDA
# context is created (i.e. before importing/initialising TF or PyTorch).
# BUGFIX: CUDA expects a comma-separated list WITHOUT spaces; with "3, 5"
# the entry after the space may be ignored by the CUDA runtime.
os.environ["CUDA_VISIBLE_DEVICES"] = "3,5"
In [ ]:
# Wildcard import of the project's convenience namespace. Based on later
# cells it must provide at least ModiscoResult, Path and read_pkl.
# NOTE(review): star imports hide where names come from — consider
# importing the needed names explicitly.
from basepair.imports import *
In [ ]:
# Locations of the trained model's modisco output and the per-base
# importance scores (gradient x input) it was computed from.
modisco_dir = ("/srv/scratch/avsec/workspace/chipnexus/data/processed/chipnexus"
               "/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile/")
imp_scores = ("/srv/scratch/avsec/workspace/chipnexus/data/processed/chipnexus"
              "/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/grad.all.h5")
In [ ]:
from basepair.modisco.table import ModiscoData

# Load the modisco results together with the importance scores; `md` is
# queried later for trim indices, seqlet sequences and importance tracks.
md = ModiscoData.load(modisco_dir, imp_scores)
In [ ]:
# get NRLB PSAMs
import csv
import numpy as np

oct4_selex_psams = []
consensus_seqs = []

# Column layout of the NRLB export (assumed from the original indices):
# column 24 holds the consensus sequence; flat PSAM values start at
# column 26 and run until an 'NSB>' sentinel cell.
_CONSENSUS_COL = 24
_PSAM_START_COL = 26

with open('POU5F1-NRLBConfig.csv', 'r') as csv_file:
    rows = csv.reader(csv_file)
    next(rows, None)  # skip the header row
    for row in rows:
        consensus_seqs.append(row[_CONSENSUS_COL])
        flat_values = []
        col = _PSAM_START_COL
        while row[col] != 'NSB>':
            flat_values.append(float(row[col]))
            col += 1
        # 4 values per motif position -> reshape flat list to (L, 4)
        oct4_selex_psams.append(np.array(flat_values).reshape(-1, 4))

oct4_selex_psams = np.array(oct4_selex_psams)
In [ ]:
# one hot encode full peak
import os

def onehot_to_seq(onehot):
    """Decode a (L, 4) one-hot array (channel order A, C, G, T) to a string.

    BUGFIX: a position whose row is all zeros (how 'N' bases are encoded by
    one_hot_encode_along_channel_axis) used to raise IndexError; it is now
    decoded back to 'N'.
    """
    bases = "ACGT"
    chars = []
    for row in onehot:
        hot = np.flatnonzero(row)  # indices of non-zero channels
        chars.append(bases[hot[0]] if hot.size else "N")
    return "".join(chars)

def one_hot_encode_along_channel_axis(sequence):
    """Return a (len(sequence), 4) int8 one-hot array for a DNA string."""
    encoded = np.zeros((len(sequence), 4), dtype=np.int8)
    seq_to_one_hot_fill_in_array(zeros_array=encoded,
                                 sequence=sequence, one_hot_axis=1)
    return encoded

def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    """Write the one-hot encoding of `sequence` into `zeros_array` in place.

    `one_hot_axis` selects whether the 4 base channels run along axis 0 or
    axis 1. 'N'/'n' positions are left all-zero; any other character raises
    RuntimeError.
    """
    assert one_hot_axis == 0 or one_hot_axis == 1
    seq_len_axis = 1 if one_hot_axis == 0 else 0
    assert zeros_array.shape[seq_len_axis] == len(sequence)
    base_to_channel = {"A": 0, "C": 1, "G": 2, "T": 3}
    for pos, base in enumerate(sequence):
        upper = base.upper()
        if upper == "N":
            continue  # leave that position as all 0's
        if upper not in base_to_channel:
            raise RuntimeError("Unsupported character: "+str(base))
        channel = base_to_channel[upper]
        if one_hot_axis == 0:
            zeros_array[channel, pos] = 1
        else:
            zeros_array[pos, channel] = 1
            
# One-hot encode every peak sequence from the FASTA-like file, skipping
# the '>' header lines.
with open("oct4_peaks.fa.out") as peaks_file:
    full_peaks = [one_hot_encode_along_channel_axis(line.strip())
                  for line in peaks_file
                  if ">" not in line]
In [ ]:
# Open the modisco result file and load the pre-computed pattern objects.
# NOTE(review): "modisco.h5" is a relative path while modisco_dir above is
# absolute — confirm the notebook's working directory contains this file.
mr = ModiscoResult("modisco.h5")
mr.open()
tasks = [x.split("/")[0] for x in mr.tasks()]

output_dir = Path("/srv/www/kundaje/avsec/chipnexus/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile")
all_patterns = read_pkl(output_dir / 'patterns.pkl')

# Keep only the patterns in the 'nte' pattern group.
patterns = [x for x in all_patterns if x.attrs['pattern_group'] == 'nte']
In [ ]:
from collections import OrderedDict

# One (initially empty) nested dict per metacluster; the pattern entries
# are filled in by the next cell.
modisco_motif = OrderedDict(
    (p.name.split('/')[0], OrderedDict()) for p in patterns
)
In [ ]:
# Store each pattern's sequence PFM under metacluster/pattern.
for pat in patterns:
    mc_id, p_id = pat.name.split('/')
    modisco_motif[mc_id][p_id] = pat.seq  # try hyp_contrib
In [ ]:
# # one hot encode seqlets
# seqlets_dir = "modisco_seqlets/"
# metacluster_ids = list(modisco_motif.keys())
# pattern_ids = OrderedDict()
# for metacluster_id in metacluster_ids:
#     pattern_ids[metacluster_id] = list(modisco_motif[metacluster_id].keys())

# patterns_to_seqlets = OrderedDict()

# for metacluster_id in metacluster_ids:
#     patterns_to_seqlets[metacluster_id] = OrderedDict()
#     for pattern_id in pattern_ids[metacluster_id]:
#         patterns_to_seqlets[metacluster_id][pattern_id] = []
#         with open(seqlets_dir+metacluster_id+"/"+pattern_id+".bed.out") as seq_f:
#             for line in seq_f:
#                 if ">" in line:
#                     continue
#                 patterns_to_seqlets[metacluster_id][pattern_id].append(one_hot_encode_along_channel_axis(line.strip()))
In [ ]:
# For every pattern, collect the trimmed one-hot seqlet sequences ("kmers")
# and the matching trimmed per-seqlet importance windows.
metacluster_ids = list(modisco_motif.keys())
pattern_ids = OrderedDict()
for metacluster_id in metacluster_ids:
    pattern_ids[metacluster_id] = list(modisco_motif[metacluster_id].keys())

patterns_to_kmers = OrderedDict()
patterns_to_seqlets = OrderedDict()

for metacluster_id in metacluster_ids:
    patterns_to_kmers[metacluster_id] = OrderedDict()
    patterns_to_seqlets[metacluster_id] = OrderedDict()
    for pattern_id in pattern_ids[metacluster_id]:
        # Trim boundaries (indices into the full seqlet window) for this pattern.
        i,j = md.get_trim_idx(metacluster_id+'/'+pattern_id) 
        full_kmers = md.get_seq(metacluster_id+'/'+pattern_id)
        # NOTE(review): importance track is hard-coded to the Oct4 profile
        # task — confirm this is intended for non-Oct4 patterns as well.
        full_seqlets = md.get_imp(metacluster_id+'/'+pattern_id, task='Oct4', which='profile')
        patterns_to_kmers[metacluster_id][pattern_id] = full_kmers[:, i:j]
        patterns_to_seqlets[metacluster_id][pattern_id] = full_seqlets[:, i:j]
In [ ]:
# they see me rollin
def rolling_window(a, window):
    """Return a read-only strided view of all length-`window` slices along
    the last axis of `a`; shape is a.shape[:-1] + (n - window + 1, window).

    BUGFIX: the view is marked writeable=False — as_strided views share
    overlapping memory, so accidental writes would silently corrupt `a`.
    All call sites in this notebook only read the result (np.multiply).
    """
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides,
                                           writeable=False)
In [ ]:
# getting best matches
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt
from concise.utils.plot import seqlogo_fig, seqlogo
from modisco.visualization import viz_sequence

%matplotlib inline

# For each NRLB PSAM, scan every modisco pattern PFM with a sliding
# dot-product (forward and with the PSAM reverse-complemented) and record
# the best-matching pattern, the matching PSAM-length slice of it, and the
# orientation of the match.
# NOTE(review): `stats` is only referenced by the pearsonr suggestion in the
# comment below — it is currently unused.
psam_id_to_best_metacluster_id = OrderedDict()
psam_id_to_best_pattern_id = OrderedDict()
best_sub_patterns = OrderedDict()   # PSAM-length slice of the best pattern, oriented like the PSAM
for idx, psam in enumerate(oct4_selex_psams):
    print("PSAM for Consensus: "+consensus_seqs[idx])
    viz_sequence.plot_weights(psam, figsize=(10,3))
    
    best_match_score = float("-inf")
    for metacluster_id in metacluster_ids:
        for pattern_id in pattern_ids[metacluster_id]:
            motif = modisco_motif[metacluster_id][pattern_id]
            # Dot product of the PSAM with every length-len(psam) window of
            # the motif; the _rev variant uses the reverse-complemented PSAM
            # (both axes flipped: positions and A/C/G/T channels).
            per_position_scores_fwd = np.sum(np.multiply(rolling_window(motif.T, len(psam)), psam.T[:,None,:]), axis=(0,2))
            per_position_scores_rev = np.sum(np.multiply(rolling_window(motif.T, len(psam)), psam[::-1,::-1].T[:,None,:]), axis=(0,2))
            # Per window, keep the better strand and remember which it was
            # (0 = forward, 1 = reverse complement).
            per_position_scores = np.max(np.array([per_position_scores_fwd, per_position_scores_rev]), axis=0)
            per_position_scores_fwd_or_rev = np.argmax(np.array([per_position_scores_fwd, per_position_scores_rev]), axis=0)
            argmax_pos = np.argmax(per_position_scores)
            
            # NOTE(review): `>=` means later patterns win ties over earlier ones.
            if per_position_scores[argmax_pos] >= best_match_score:    # perhaps use correlation (stats.pearsonr) instead of conv?
                best_metacluster_id = metacluster_id
                best_pattern_id = pattern_id
                best_sub_pattern = modisco_motif[best_metacluster_id][best_pattern_id][argmax_pos:argmax_pos+len(psam)]
                best_match_score = per_position_scores[argmax_pos]
                motif_orientation = per_position_scores_fwd_or_rev[argmax_pos]
                if motif_orientation == 1:
                    # Reverse-complement the slice so it is oriented like the PSAM.
                    best_sub_pattern = best_sub_pattern[::-1,::-1]

    psam_id_to_best_metacluster_id[idx] = best_metacluster_id
    psam_id_to_best_pattern_id[idx] = best_pattern_id
    best_sub_patterns[idx] = best_sub_pattern
    print("Best match is: "+best_metacluster_id+"/"+best_pattern_id+" with orientation: "+str(motif_orientation))
    viz_sequence.plot_weights(best_sub_pattern, figsize=(10, 3))
In [ ]:
data = OrderedDict() # keys are subseqlets (as strings)

def create_dict_val_for_subseqlet(key):
    """Initialise the stats entry data[key] for a new subseqlet string and
    pre-compute its occurrence counts across all full peaks (either strand).

    Mutates the module-level `data` dict; reads the module-level `full_peaks`.
    """
    data[key] = OrderedDict()
    
    data[key]['counts'] = 0        # how many times this subseqlet or its revcomp showed up
    data[key]['kmers'] = []        # list of kmers where subseqlet showed up
    data[key]['seqlets'] = []      # list of seqlets where subseqlet showed up
    data[key]['revcomps'] = []     # list of whether subseqlet was found as revcomp (TRUE) or not (FALSE)
    data[key]['modisco_patterns'] = []    # list of modisco patterns where subseqlet showed up
    data[key]['psams'] = []        # list of nrlb psams where subseqlet showed up
    data[key]['importances'] = []  # list of sum(abs(importances))
    data[key]['affinities'] = []   # list of conv with PSAM
    data[key]['modisco_matches'] = []     # list of conv with modisco motif (TODO: change to conv with modisco hypothetical motif)
    
    data[key]['num_peaks_containing_subseqlet'] = 0
    data[key]['freq_in_peaks'] = 0   
    key_seq = one_hot_encode_along_channel_axis(key)
    for peak in full_peaks:
        # A window matching exactly scores len(key_seq) under the sliding
        # dot-product; count windows hitting that maximum on each strand.
        matches_fwd = np.sum(np.multiply(rolling_window(peak.T, len(key_seq)), key_seq.T[:,None,:]), axis=(0,2))
        matches_rev = np.sum(np.multiply(rolling_window(peak.T, len(key_seq)), key_seq[::-1,::-1].T[:,None,:]), axis=(0,2))
        # NOTE(review): max() of the two strand counts, not their sum — a peak
        # with hits on both strands is undercounted; confirm this is intended.
        matches = max(np.count_nonzero(matches_fwd[matches_fwd==len(key_seq)]), np.count_nonzero(matches_rev[matches_rev==len(key_seq)]))
        data[key]['freq_in_peaks'] += matches
        if matches > 0:
            data[key]['num_peaks_containing_subseqlet'] += 1
In [ ]:
# For every PSAM: plot affinity distributions over peaks vs. seqlets of its
# best-matching modisco pattern, then accumulate per-subseqlet statistics
# into the module-level `data` dict.
for idx, psam in enumerate(oct4_selex_psams):
    print("PSAM for consensus: "+consensus_seqs[idx])
    viz_sequence.plot_weights(psam, figsize=(10,3))

    # PSAM affinity at every forward-strand position of every full peak.
    peak_affinities = []
    for peak in full_peaks:
        peak_affinities.append(np.sum(np.multiply(rolling_window(peak.T, len(psam)), psam.T[:,None,:]), axis=(0,2)))
    peak_affinities = np.array(peak_affinities).flatten()

    metacluster_id = psam_id_to_best_metacluster_id[idx]
    pattern_id = psam_id_to_best_pattern_id[idx]
    print("best modisco match is "+metacluster_id+"/"+pattern_id+" with best matching sub-pattern: ")
    viz_sequence.plot_weights(best_sub_patterns[idx], figsize=(10,3))

    # Best PSAM match (score, offset, strand) within each seqlet of that pattern.
    seqlet_affinities = []
    seqlet_affinities_argmax = []
    seqlet_affinities_orientation = []
    for seqlet in patterns_to_kmers[metacluster_id][pattern_id]:
        per_position_affinities_fwd = np.sum(np.multiply(rolling_window(seqlet.T, len(psam)), psam.T[:,None,:]), axis=(0,2))
        per_position_affinities_rev = np.sum(np.multiply(rolling_window(seqlet.T, len(psam)), psam[::-1,::-1].T[:,None,:]), axis=(0,2))
        per_position_affinities = np.max(np.array([per_position_affinities_fwd, per_position_affinities_rev]),
                                         axis=0)
        per_position_affinities_fwd_or_rev = np.argmax(np.array([per_position_affinities_fwd, per_position_affinities_rev]),
                                                       axis=0)
        argmax_pos = np.argmax(per_position_affinities)
        seqlet_affinities_argmax.append(argmax_pos)
        seqlet_affinities_orientation.append(per_position_affinities_fwd_or_rev[argmax_pos])
        seqlet_affinities.append(per_position_affinities[argmax_pos])
    seqlet_affinities = np.array(seqlet_affinities).flatten()

    sns.distplot(peak_affinities, bins='auto', label='full_peak_affinities')
    sns.distplot(seqlet_affinities, bins='auto', label='seqlet_affinities')
    plt.legend(loc='upper left')
    plt.title("histogram of affinities and seqlet for pattern: "+metacluster_id+"/"+pattern_id)
    plt.show()

    # Walk seqlets in increasing affinity order, cutting out the PSAM-length
    # subseqlet at the best match and folding it into `data`.
    sorted_seqlet_affinities = sorted(enumerate(seqlet_affinities), key=lambda x: x[1])
    sorted_seqlet_indices = [x[0] for x in sorted_seqlet_affinities]
    for seqlet_idx in sorted_seqlet_indices:
        the_kmer = patterns_to_kmers[metacluster_id][pattern_id][seqlet_idx]
        imp_seqlet = patterns_to_seqlets[metacluster_id][pattern_id][seqlet_idx]

        pos_of_best_match_to_nrlb_within_seqlet = seqlet_affinities_argmax[seqlet_idx]
        best_matching_subseqlet = the_kmer[pos_of_best_match_to_nrlb_within_seqlet:
                                           (pos_of_best_match_to_nrlb_within_seqlet+len(psam))]
        if (seqlet_affinities_orientation[seqlet_idx]==1):
            # Orient the subseqlet like the PSAM before keying on it.
            best_matching_subseqlet = best_matching_subseqlet[::-1,::-1]

        # A subseqlet and its reverse complement share one entry in `data`.
        # BUGFIX: the old code only *checked* the revcomp key but still indexed
        # data[key] — if only the revcomp entry existed this raised KeyError.
        # Now we reuse the existing revcomp entry in that case.
        key = onehot_to_seq(best_matching_subseqlet)
        revcomp_key = onehot_to_seq(best_matching_subseqlet[::-1,::-1])
        if key not in data:
            if revcomp_key in data:
                key = revcomp_key
            else:
                create_dict_val_for_subseqlet(key)

        data[key]['counts'] += 1
        data[key]['kmers'].append(the_kmer)
        data[key]['seqlets'].append(imp_seqlet)
        data[key]['revcomps'].append(seqlet_affinities_orientation[seqlet_idx])
        data[key]['modisco_patterns'].append(metacluster_id+"/"+pattern_id)
        data[key]['psams'].append(psam)
        # Total absolute importance over the matched window.
        data[key]['importances'].append(np.sum(np.abs(imp_seqlet[pos_of_best_match_to_nrlb_within_seqlet:
                                                                 (pos_of_best_match_to_nrlb_within_seqlet+len(psam))])))

        # (dropped unused `masked_best_matching_subseqlet = best_matching_subseqlet*psam`)

        # Best sliding-window agreement between the full modisco motif and the
        # kmer, in either orientation of the kmer.
        per_position_motif_match_fwd = np.sum(np.multiply(rolling_window(modisco_motif[metacluster_id][pattern_id].T, len(the_kmer)), the_kmer.T[:,None,:]), axis=(0,2))
        per_position_motif_match_rev = np.sum(np.multiply(rolling_window(modisco_motif[metacluster_id][pattern_id].T, len(the_kmer)), the_kmer[::-1,::-1].T[:,None,:]), axis=(0,2))
        per_position_motif_match = max(np.max(per_position_motif_match_fwd), np.max(per_position_motif_match_rev))

        data[key]['affinities'].append(seqlet_affinities[seqlet_idx])
        data[key]['modisco_matches'].append(per_position_motif_match)
In [ ]:
# Tab-separated summary table, one row per subseqlet (header in bold ANSI).
print('\033[1m'+"counts \t best subseqlet \t mean importance \t var importance \t mean affinity (conv with PSAM) \t modisco match (conv with motif) \t # peaks w subseqlet \t # subseqlets in peaks"+'\033[0m')
for key, entry in data.items():
    fields = [
        str(entry['counts']) + " \t " + key,
        str(np.mean(entry['importances'])),
        str(np.var(entry['importances'])),
        str(np.mean(entry['affinities'])),
        str(np.mean(entry['modisco_matches'])),
        str(entry['num_peaks_containing_subseqlet']),
        str(entry['freq_in_peaks']),
    ]
    print(" \t \t ".join(fields))
In [ ]:
# TODO: json dump the dict with all the data
# TODO: generate tab separated csv with averages