In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
In [2]:
from basepair.imports import *
Using TensorFlow backend.
In [3]:
modisco_dir = f"/srv/scratch/avsec/workspace/chipnexus/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile/"
imp_scores =  f"/srv/scratch/avsec/workspace/chipnexus/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/grad.all.h5"
In [4]:
from basepair.modisco.table import ModiscoData

md = ModiscoData.load(modisco_dir, imp_scores)
100%|██████████| 142/142 [44:13<00:00, 18.69s/it]
In [5]:
# get NRLB PSAMs
import csv
import numpy as np

oct4_selex_psams = []
consensus_seqs = []

with open('SOX2-NRLBConfig.csv', 'r') as csv_file:
    csv_reader = csv.reader(csv_file)
    next(csv_reader, None)  # skip the headers
    for row in csv_reader:
        consensus_seqs.append(row[24])
        count = 26
        this_selex_psam = []
        while(row[count] != 'NSB>'):
            this_selex_psam.append(float(row[count]))
            count+=1
        this_selex_psam = np.array(this_selex_psam)
        this_selex_psam = this_selex_psam.reshape(int((count-26)/4), 4)
        oct4_selex_psams.append(this_selex_psam)
        
oct4_selex_psams = np.array(oct4_selex_psams)
In [6]:
# one hot encode full peak
import os

def onehot_to_seq(onehot):
    seq = ""
    for pos in range(len(onehot)):
        char_idx = [i for i, e in enumerate(onehot[pos]) if e != 0][0]
        if char_idx == 0:
            char = 'A'
        elif char_idx == 1:
            char = "C"
        elif char_idx == 2:
            char = "G"
        elif char_idx == 3:
            char = "T"
        seq +=  char
    return seq

def one_hot_encode_along_channel_axis(sequence):
    to_return = np.zeros((len(sequence),4), dtype=np.int8)
    seq_to_one_hot_fill_in_array(zeros_array=to_return,
                                 sequence=sequence, one_hot_axis=1)
    return to_return

def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    assert one_hot_axis==0 or one_hot_axis==1
    if (one_hot_axis==0):
        assert zeros_array.shape[1] == len(sequence)
    elif (one_hot_axis==1): 
        assert zeros_array.shape[0] == len(sequence)
    #will mutate zeros_array
    for (i,char) in enumerate(sequence):
        if (char=="A" or char=="a"):
            char_idx = 0
        elif (char=="C" or char=="c"):
            char_idx = 1
        elif (char=="G" or char=="g"):
            char_idx = 2
        elif (char=="T" or char=="t"):
            char_idx = 3
        elif (char=="N" or char=="n"):
            continue #leave that pos as all 0's
        else:
            raise RuntimeError("Unsupported character: "+str(char))
        if (one_hot_axis==0):
            zeros_array[char_idx,i] = 1
        elif (one_hot_axis==1):
            zeros_array[i,char_idx] = 1
            
full_peaks = []
with open("sox2_peaks.fa.out") as peaks_file:
    for line in peaks_file:
        if ">" in line:
            continue
        full_peaks.append(one_hot_encode_along_channel_axis(line.strip()))
In [7]:
mr = ModiscoResult("../modisco.h5")
mr.open()
tasks = [x.split("/")[0] for x in mr.tasks()]

output_dir = Path("/srv/www/kundaje/avsec/chipnexus/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile")
all_patterns = read_pkl(output_dir / 'patterns.pkl')

patterns = [x for x in all_patterns if x.attrs['pattern_group'] == 'nte']
In [8]:
from collections import OrderedDict

modisco_motif = OrderedDict()
for p in patterns:
    metacluster, pattern = p.name.split('/')
    modisco_motif[metacluster] = OrderedDict()
In [9]:
for p in patterns:
    metacluster, pattern = p.name.split('/')
    modisco_motif[metacluster][pattern] = p.hyp_contrib['Sox2'] #p.seq
In [10]:
# # one hot encode seqlets
# seqlets_dir = "modisco_seqlets/"
# metacluster_ids = list(modisco_motif.keys())
# pattern_ids = OrderedDict()
# for metacluster_id in metacluster_ids:
#     pattern_ids[metacluster_id] = list(modisco_motif[metacluster_id].keys())

# patterns_to_seqlets = OrderedDict()

# for metacluster_id in metacluster_ids:
#     patterns_to_seqlets[metacluster_id] = OrderedDict()
#     for pattern_id in pattern_ids[metacluster_id]:
#         patterns_to_seqlets[metacluster_id][pattern_id] = []
#         with open(seqlets_dir+metacluster_id+"/"+pattern_id+".bed.out") as seq_f:
#             for line in seq_f:
#                 if ">" in line:
#                     continue
#                 patterns_to_seqlets[metacluster_id][pattern_id].append(one_hot_encode_along_channel_axis(line.strip()))
In [11]:
metacluster_ids = list(modisco_motif.keys())
pattern_ids = OrderedDict()
for metacluster_id in metacluster_ids:
    pattern_ids[metacluster_id] = list(modisco_motif[metacluster_id].keys())

patterns_to_kmers = OrderedDict()
patterns_to_seqlets = OrderedDict()

for metacluster_id in metacluster_ids:
    patterns_to_kmers[metacluster_id] = OrderedDict()
    patterns_to_seqlets[metacluster_id] = OrderedDict()
    for pattern_id in pattern_ids[metacluster_id]:
        i,j = md.get_trim_idx(metacluster_id+'/'+pattern_id) 
        full_kmers = md.get_seq(metacluster_id+'/'+pattern_id)
        full_seqlets = md.get_imp(metacluster_id+'/'+pattern_id, task='Sox2', which='profile')
        patterns_to_kmers[metacluster_id][pattern_id] = full_kmers[:, i:j]
        patterns_to_seqlets[metacluster_id][pattern_id] = full_seqlets[:, i:j]
In [12]:
# they see me rollin
def rolling_window(a, window):
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
In [13]:
# getting best matches
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt
from concise.utils.plot import seqlogo_fig, seqlogo
from modisco.visualization import viz_sequence

%matplotlib inline

psam_id_to_best_metacluster_id = OrderedDict()
psam_id_to_best_pattern_id = OrderedDict()
best_sub_patterns = OrderedDict()
for idx, psam in enumerate(oct4_selex_psams):
    print("PSAM for Consensus: "+consensus_seqs[idx])
    viz_sequence.plot_weights(psam, figsize=(10,3))
    
    best_match_score = float("-inf")
    for metacluster_id in metacluster_ids:
        for pattern_id in pattern_ids[metacluster_id]:
            motif = modisco_motif[metacluster_id][pattern_id]
            per_position_scores_fwd = np.sum(np.multiply(rolling_window(motif.T, len(psam)), psam.T[:,None,:]), axis=(0,2))
            per_position_scores_rev = np.sum(np.multiply(rolling_window(motif.T, len(psam)), psam[::-1,::-1].T[:,None,:]), axis=(0,2))
            per_position_scores = np.max(np.array([per_position_scores_fwd, per_position_scores_rev]), axis=0)
            per_position_scores_fwd_or_rev = np.argmax(np.array([per_position_scores_fwd, per_position_scores_rev]), axis=0)
            argmax_pos = np.argmax(per_position_scores)
            
            if per_position_scores[argmax_pos] >= best_match_score:    # perhaps use correlation (stats.pearsonr) instead of conv?
                best_metacluster_id = metacluster_id
                best_pattern_id = pattern_id
                best_sub_pattern = modisco_motif[best_metacluster_id][best_pattern_id][argmax_pos:argmax_pos+len(psam)]
                best_match_score = per_position_scores[argmax_pos]
                motif_orientation = per_position_scores_fwd_or_rev[argmax_pos]
                if motif_orientation == 1:
                    best_sub_pattern = best_sub_pattern[::-1,::-1]

    psam_id_to_best_metacluster_id[idx] = best_metacluster_id
    psam_id_to_best_pattern_id[idx] = best_pattern_id
    best_sub_patterns[idx] = best_sub_pattern
    print("Best match is: "+best_metacluster_id+"/"+best_pattern_id+" with orientation: "+str(motif_orientation))
    viz_sequence.plot_weights(best_sub_pattern, figsize=(10, 3))
PSAM for Consensus: AACGAACC
Best match is: metacluster_3/pattern_3 with orientation: 0
PSAM for Consensus: AACTACGC
Best match is: metacluster_3/pattern_3 with orientation: 0
PSAM for Consensus: AACGAACC
Best match is: metacluster_3/pattern_3 with orientation: 0
PSAM for Consensus: GGGTTCGGT
Best match is: metacluster_3/pattern_3 with orientation: 1
PSAM for Consensus: GGGTTATGT
Best match is: metacluster_0/pattern_15 with orientation: 0
PSAM for Consensus: GGGTTCGGT
Best match is: metacluster_3/pattern_3 with orientation: 1
PSAM for Consensus: AACGAACCCC
Best match is: metacluster_3/pattern_3 with orientation: 0
PSAM for Consensus: ACACGACCTC
Best match is: metacluster_3/pattern_3 with orientation: 0
PSAM for Consensus: ACCATACCTC
Best match is: metacluster_3/pattern_3 with orientation: 0
In [14]:
data = OrderedDict() # keys are subseqlets (as strings)

def create_dict_val_for_subseqlet(key):
    data[key] = OrderedDict()
    
    data[key]['counts'] = 0        # how many times this subseqlet or its revcomp showed up
    data[key]['kmers'] = []        # list of kmers where subseqlet showed up
    data[key]['seqlets'] = []      # list of seqlets where subseqlet showed up
    data[key]['revcomps'] = []     # list of whether subseqlet was found as revcomp (TRUE) or not (FALSE)
    data[key]['modisco_patterns'] = []    # list of modisco patterns where subseqlet showed up
    data[key]['psams'] = []        # list of nrlb psams where subseqlet showed up
    data[key]['importances'] = []  # list of sum(abs(importances))
    data[key]['affinities'] = []   # list of conv with PSAM
    data[key]['modisco_matches'] = []     # list of conv with modisco motif (TODO: change to conv with modisco hypothetical motif)
    
    data[key]['num_peaks_containing_subseqlet'] = 0
    data[key]['freq_in_peaks'] = 0   
    key_seq = one_hot_encode_along_channel_axis(key)
    for peak in full_peaks:
        matches_fwd = np.sum(np.multiply(rolling_window(peak.T, len(key_seq)), key_seq.T[:,None,:]), axis=(0,2))
        matches_rev = np.sum(np.multiply(rolling_window(peak.T, len(key_seq)), key_seq[::-1,::-1].T[:,None,:]), axis=(0,2))
        matches = max(np.count_nonzero(matches_fwd[matches_fwd==len(key_seq)]), np.count_nonzero(matches_rev[matches_rev==len(key_seq)]))
        data[key]['freq_in_peaks'] += matches
        if matches > 0:
            data[key]['num_peaks_containing_subseqlet'] += 1
In [15]:
for idx, psam in enumerate(oct4_selex_psams):
    print("PSAM for consensus: "+consensus_seqs[idx])
    viz_sequence.plot_weights(psam, figsize=(10,3))
    peak_affinities = []
    for peak in full_peaks:
        peak_affinities.append(np.sum(np.multiply(rolling_window(peak.T, len(psam)), psam.T[:,None,:]), axis=(0,2)))
    peak_affinities = np.array(peak_affinities).flatten()
    
    metacluster_id = psam_id_to_best_metacluster_id[idx]
    pattern_id = psam_id_to_best_pattern_id[idx]
    print("best modisco match is "+metacluster_id+"/"+pattern_id+" with best matching sub-pattern: ")
    viz_sequence.plot_weights(best_sub_patterns[idx], figsize=(10,3))
    seqlet_affinities = []
    seqlet_affinities_argmax = []
    seqlet_affinities_orientation = []
    for seqlet in patterns_to_kmers[metacluster_id][pattern_id]:
        per_position_affinities_fwd = np.sum(np.multiply(rolling_window(seqlet.T, len(psam)), psam.T[:,None,:]), axis=(0,2))
        per_position_affinities_rev = np.sum(np.multiply(rolling_window(seqlet.T, len(psam)), psam[::-1,::-1].T[:,None,:]), axis=(0,2))
        per_position_affinities = np.max(np.array([per_position_affinities_fwd, per_position_affinities_rev]),
                                         axis=0)
        per_position_affinities_fwd_or_rev = np.argmax(np.array([per_position_affinities_fwd, per_position_affinities_rev]),
                                                       axis=0)
        argmax_pos = np.argmax(per_position_affinities)
        seqlet_affinities_argmax.append(argmax_pos)
        seqlet_affinities_orientation.append(per_position_affinities_fwd_or_rev[argmax_pos])
        seqlet_affinities.append(per_position_affinities[argmax_pos])
    seqlet_affinities = np.array(seqlet_affinities).flatten()

    sns.distplot(peak_affinities, bins='auto', label='full_peak_affinities')
    sns.distplot(seqlet_affinities, bins='auto', label='seqlet_affinities')
    plt.legend(loc='upper left')
    plt.title("histogram of affinities and seqlet for pattern: "+metacluster_id+"/"+pattern_id)
    plt.show()
    
    sorted_seqlet_affinities = sorted(enumerate(seqlet_affinities), key=lambda x: x[1])
    sorted_seqlet_indices = [x[0] for x in sorted_seqlet_affinities]
    for seqlet_idx in sorted_seqlet_indices:
        if seqlet_idx % 100 == 0:
            print("finished another 100...")
        the_kmer = patterns_to_kmers[metacluster_id][pattern_id][seqlet_idx]
        imp_seqlet = patterns_to_seqlets[metacluster_id][pattern_id][seqlet_idx]
        
        pos_of_best_match_to_nrlb_within_seqlet = seqlet_affinities_argmax[seqlet_idx]
        best_matching_subseqlet = the_kmer[pos_of_best_match_to_nrlb_within_seqlet:
                                           (pos_of_best_match_to_nrlb_within_seqlet+len(psam))]
        if (seqlet_affinities_orientation[seqlet_idx]==1):
            best_matching_subseqlet = best_matching_subseqlet[::-1,::-1]
            
        # storing 
        key = onehot_to_seq(best_matching_subseqlet)
        rev_key = onehot_to_seq(best_matching_subseqlet[::-1,::-1])
        if (key not in data) and (rev_key not in data):
            create_dict_val_for_subseqlet(key)
        if data.get(key) is None:
            create_dict_val_for_subseqlet(key)

        data[key]['counts'] += 1
        data[key]['kmers'].append(the_kmer)
        data[key]['seqlets'].append(imp_seqlet)
        data[key]['revcomps'].append(seqlet_affinities_orientation[seqlet_idx])
        data[key]['modisco_patterns'].append(metacluster_id+"/"+pattern_id)
        data[key]['psams'].append(psam)
        data[key]['importances'].append(np.sum(np.abs(imp_seqlet[pos_of_best_match_to_nrlb_within_seqlet:
                                                                 (pos_of_best_match_to_nrlb_within_seqlet+len(psam))])))

        masked_best_matching_subseqlet = best_matching_subseqlet*psam

        per_position_motif_match_fwd = np.sum(np.multiply(rolling_window(modisco_motif[metacluster_id][pattern_id].T, len(the_kmer)), the_kmer.T[:,None,:]), axis=(0,2))
        per_position_motif_match_rev = np.sum(np.multiply(rolling_window(modisco_motif[metacluster_id][pattern_id].T, len(the_kmer)), the_kmer[::-1,::-1].T[:,None,:]), axis=(0,2))
        per_position_motif_match = max(np.max(per_position_motif_match_fwd), np.max(per_position_motif_match_rev))
        
        data[key]['affinities'].append(seqlet_affinities[seqlet_idx])
        data[key]['modisco_matches'].append(per_position_motif_match)
PSAM for consensus: AACGAACC
best modisco match is metacluster_3/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: AACTACGC
best modisco match is metacluster_3/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: AACGAACC
best modisco match is metacluster_3/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: GGGTTCGGT
best modisco match is metacluster_3/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: GGGTTATGT
best modisco match is metacluster_0/pattern_15 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: GGGTTCGGT
best modisco match is metacluster_3/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: AACGAACCCC
best modisco match is metacluster_3/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: ACACGACCTC
best modisco match is metacluster_3/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: ACCATACCTC
best modisco match is metacluster_3/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
In [16]:
lines = 10
print('\033[1m'+"counts \t best subseqlet \t mean importance \t var importance \t mean affinity (conv with PSAM) \t modisco match (conv with motif) \t # peaks w subseqlet \t # subseqlets in peaks"+'\033[0m')
for key in data:
    if lines == 0:
        break
    print(str(data[key]['counts'])+" \t "
          +key+" \t \t "
          +str(np.mean(data[key]['importances']))+" \t \t "
          +str(np.var(data[key]['importances']))+" \t \t \t "
          +str(np.mean(data[key]['affinities']))+" \t \t \t "
          +str(np.mean(data[key]['modisco_matches']))+" \t \t \t "
          +str(data[key]['num_peaks_containing_subseqlet'])+" \t \t \t "
          +str(data[key]['freq_in_peaks']))
    lines -= 1
counts 	 best subseqlet 	 mean importance 	 var importance 	 mean affinity (conv with PSAM) 	 modisco match (conv with motif) 	 # peaks w subseqlet 	 # subseqlets in peaks
6 	 AGGGCCCC 	 	 0.15896662 	 	 1.2960216e-05 	 	 	 1.322584532360931 	 	 	 0.14625814384622268 	 	 	 250 	 	 	 256
1 	 GCTCTCCC 	 	 0.067928694 	 	 0.0 	 	 	 2.4176685913634506 	 	 	 0.20032791079631623 	 	 	 491 	 	 	 502
4 	 AGCGCCCT 	 	 0.098678395 	 	 6.3617044e-05 	 	 	 2.3755446828207716 	 	 	 0.1545731429790857 	 	 	 93 	 	 	 94
6 	 GCCCTCTC 	 	 0.044738267 	 	 3.826668e-08 	 	 	 2.6010423716028677 	 	 	 0.17219194447965194 	 	 	 361 	 	 	 367
3 	 GCCCCCAA 	 	 0.045206208 	 	 1.3877788e-17 	 	 	 2.583512008747562 	 	 	 0.19556033287335825 	 	 	 399 	 	 	 402
3 	 TCCCCCTA 	 	 0.13128933 	 	 0.0 	 	 	 2.584666837037669 	 	 	 0.18913652668118164 	 	 	 223 	 	 	 226
12 	 GCCCCCTC 	 	 0.15680356 	 	 1.270622e-05 	 	 	 2.744585463985111 	 	 	 0.15996080705172488 	 	 	 359 	 	 	 369
6 	 CCCCCCTT 	 	 0.20427038 	 	 4.3302705e-05 	 	 	 2.7802702103541974 	 	 	 0.18341923400758564 	 	 	 438 	 	 	 447
2 	 AGCGCGCC 	 	 0.21166363 	 	 0.0 	 	 	 2.592671521045085 	 	 	 0.1757841600749192 	 	 	 30 	 	 	 34
6 	 GACCCTCC 	 	 0.16897821 	 	 4.1761777e-07 	 	 	 2.899984381824361 	 	 	 0.18650182554597985 	 	 	 255 	 	 	 260
In [17]:
for key in data:
    for idx in range(len(data[key]['kmers'])):
        data[key]['kmers'][idx] = data[key]['kmers'][idx].tolist()
    for idx in range(len(data[key]['seqlets'])):
        data[key]['seqlets'][idx] = data[key]['seqlets'][idx].tolist()
    for idx in range(len(data[key]['psams'])):
        data[key]['psams'][idx] = data[key]['psams'][idx].tolist()
In [18]:
# json dump the dict with all the data
import json

def default(o):
    if isinstance(o, np.int64):
        return int(o)

with open('data.json', 'w') as fp:
    json.dump(data, fp, default=default)
In [19]:
# generate csv with averages
import csv

with open('sox2_affinities.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, dialect='excel')
    writer.writerow(["subseqlet", "counts", "mean importance", "var importance",
                    "mean affinity (conv with PSAM)", "modisco match (conv with motif)",
                    "# peaks w subseqlet", "# subseqlets in peaks"])
    for key in data:
        writer.writerow([key, str(data[key]['counts']),
          str(np.mean(data[key]['importances'])),
          str(np.var(data[key]['importances'])),
          str(np.mean(data[key]['affinities'])),
          str(np.mean(data[key]['modisco_matches'])),
          str(data[key]['num_peaks_containing_subseqlet']),
          str(data[key]['freq_in_peaks'])])