In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
In [2]:
from basepair.imports import *
Using TensorFlow backend.
In [3]:
modisco_dir = f"/srv/scratch/avsec/workspace/chipnexus/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile/"
imp_scores =  f"/srv/scratch/avsec/workspace/chipnexus/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/grad.all.h5"
In [4]:
from basepair.modisco.table import ModiscoData

md = ModiscoData.load(modisco_dir, imp_scores)
100%|██████████| 142/142 [32:22<00:00, 13.68s/it]
In [5]:
# get NRLB PSAMs
import csv
import numpy as np

oct4_selex_psams = []
consensus_seqs = []

with open('ESRRB-NRLBConfig.csv', 'r') as csv_file:
    csv_reader = csv.reader(csv_file)
    next(csv_reader, None)  # skip the headers
    for row in csv_reader:
        consensus_seqs.append(row[24])
        count = 26
        this_selex_psam = []
        while(row[count] != 'NSB>'):
            this_selex_psam.append(float(row[count]))
            count+=1
        this_selex_psam = np.array(this_selex_psam)
        this_selex_psam = this_selex_psam.reshape(int((count-26)/4), 4)
        oct4_selex_psams.append(this_selex_psam)
        
oct4_selex_psams = np.array(oct4_selex_psams)
In [6]:
# one hot encode full peak
import os

def onehot_to_seq(onehot):
    seq = ""
    for pos in range(len(onehot)):
        char_idx = [i for i, e in enumerate(onehot[pos]) if e != 0][0]
        if char_idx == 0:
            char = 'A'
        elif char_idx == 1:
            char = "C"
        elif char_idx == 2:
            char = "G"
        elif char_idx == 3:
            char = "T"
        seq +=  char
    return seq

def one_hot_encode_along_channel_axis(sequence):
    to_return = np.zeros((len(sequence),4), dtype=np.int8)
    seq_to_one_hot_fill_in_array(zeros_array=to_return,
                                 sequence=sequence, one_hot_axis=1)
    return to_return

def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    assert one_hot_axis==0 or one_hot_axis==1
    if (one_hot_axis==0):
        assert zeros_array.shape[1] == len(sequence)
    elif (one_hot_axis==1): 
        assert zeros_array.shape[0] == len(sequence)
    #will mutate zeros_array
    for (i,char) in enumerate(sequence):
        if (char=="A" or char=="a"):
            char_idx = 0
        elif (char=="C" or char=="c"):
            char_idx = 1
        elif (char=="G" or char=="g"):
            char_idx = 2
        elif (char=="T" or char=="t"):
            char_idx = 3
        elif (char=="N" or char=="n"):
            continue #leave that pos as all 0's
        else:
            raise RuntimeError("Unsupported character: "+str(char))
        if (one_hot_axis==0):
            zeros_array[char_idx,i] = 1
        elif (one_hot_axis==1):
            zeros_array[i,char_idx] = 1
            
full_peaks = []
with open("nanog_peaks.fa.out") as peaks_file:
    for line in peaks_file:
        if ">" in line:
            continue
        full_peaks.append(one_hot_encode_along_channel_axis(line.strip()))
In [7]:
mr = ModiscoResult("../modisco.h5")
mr.open()
tasks = [x.split("/")[0] for x in mr.tasks()]

output_dir = Path("/srv/www/kundaje/avsec/chipnexus/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile")
all_patterns = read_pkl(output_dir / 'patterns.pkl')

patterns = [x for x in all_patterns if x.attrs['pattern_group'] == 'nte']
In [8]:
from collections import OrderedDict

modisco_motif = OrderedDict()
for p in patterns:
    metacluster, pattern = p.name.split('/')
    modisco_motif[metacluster] = OrderedDict()
In [9]:
for p in patterns:
    metacluster, pattern = p.name.split('/')
    modisco_motif[metacluster][pattern] = p.hyp_contrib['Nanog'] #p.seq
In [10]:
# # one hot encode seqlets
# seqlets_dir = "modisco_seqlets/"
# metacluster_ids = list(modisco_motif.keys())
# pattern_ids = OrderedDict()
# for metacluster_id in metacluster_ids:
#     pattern_ids[metacluster_id] = list(modisco_motif[metacluster_id].keys())

# patterns_to_seqlets = OrderedDict()

# for metacluster_id in metacluster_ids:
#     patterns_to_seqlets[metacluster_id] = OrderedDict()
#     for pattern_id in pattern_ids[metacluster_id]:
#         patterns_to_seqlets[metacluster_id][pattern_id] = []
#         with open(seqlets_dir+metacluster_id+"/"+pattern_id+".bed.out") as seq_f:
#             for line in seq_f:
#                 if ">" in line:
#                     continue
#                 patterns_to_seqlets[metacluster_id][pattern_id].append(one_hot_encode_along_channel_axis(line.strip()))
In [11]:
metacluster_ids = list(modisco_motif.keys())
pattern_ids = OrderedDict()
for metacluster_id in metacluster_ids:
    pattern_ids[metacluster_id] = list(modisco_motif[metacluster_id].keys())

patterns_to_kmers = OrderedDict()
patterns_to_seqlets = OrderedDict()

for metacluster_id in metacluster_ids:
    patterns_to_kmers[metacluster_id] = OrderedDict()
    patterns_to_seqlets[metacluster_id] = OrderedDict()
    for pattern_id in pattern_ids[metacluster_id]:
        i,j = md.get_trim_idx(metacluster_id+'/'+pattern_id) 
        full_kmers = md.get_seq(metacluster_id+'/'+pattern_id)
        full_seqlets = md.get_imp(metacluster_id+'/'+pattern_id, task='Nanog', which='profile')
        patterns_to_kmers[metacluster_id][pattern_id] = full_kmers[:, i:j]
        patterns_to_seqlets[metacluster_id][pattern_id] = full_seqlets[:, i:j]
In [12]:
# they see me rollin
def rolling_window(a, window):
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
In [13]:
# getting best matches
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt
from concise.utils.plot import seqlogo_fig, seqlogo
from modisco.visualization import viz_sequence

%matplotlib inline

psam_id_to_best_metacluster_id = OrderedDict()
psam_id_to_best_pattern_id = OrderedDict()
best_sub_patterns = OrderedDict()
for idx, psam in enumerate(oct4_selex_psams):
    print("PSAM for Consensus: "+consensus_seqs[idx])
    viz_sequence.plot_weights(psam, figsize=(10,3))
    
    best_match_score = float("-inf")
    for metacluster_id in metacluster_ids:
        for pattern_id in pattern_ids[metacluster_id]:
            motif = modisco_motif[metacluster_id][pattern_id]
            per_position_scores_fwd = np.sum(np.multiply(rolling_window(motif.T, len(psam)), psam.T[:,None,:]), axis=(0,2))
            per_position_scores_rev = np.sum(np.multiply(rolling_window(motif.T, len(psam)), psam[::-1,::-1].T[:,None,:]), axis=(0,2))
            per_position_scores = np.max(np.array([per_position_scores_fwd, per_position_scores_rev]), axis=0)
            per_position_scores_fwd_or_rev = np.argmax(np.array([per_position_scores_fwd, per_position_scores_rev]), axis=0)
            argmax_pos = np.argmax(per_position_scores)
            
            if per_position_scores[argmax_pos] >= best_match_score:    # perhaps use correlation (stats.pearsonr) instead of conv?
                best_metacluster_id = metacluster_id
                best_pattern_id = pattern_id
                best_sub_pattern = modisco_motif[best_metacluster_id][best_pattern_id][argmax_pos:argmax_pos+len(psam)]
                best_match_score = per_position_scores[argmax_pos]
                motif_orientation = per_position_scores_fwd_or_rev[argmax_pos]
                if motif_orientation == 1:
                    best_sub_pattern = best_sub_pattern[::-1,::-1]

    psam_id_to_best_metacluster_id[idx] = best_metacluster_id
    psam_id_to_best_pattern_id[idx] = best_pattern_id
    best_sub_patterns[idx] = best_sub_pattern
    print("Best match is: "+best_metacluster_id+"/"+best_pattern_id+" with orientation: "+str(motif_orientation))
    viz_sequence.plot_weights(best_sub_pattern, figsize=(10, 3))
PSAM for Consensus: CAACGTCC
Best match is: metacluster_0/pattern_1 with orientation: 0
PSAM for Consensus: CCAACGTC
Best match is: metacluster_0/pattern_1 with orientation: 0
PSAM for Consensus: AACGTCAC
Best match is: metacluster_0/pattern_16 with orientation: 0
PSAM for Consensus: ACCTTCCCC
Best match is: metacluster_0/pattern_3 with orientation: 0
PSAM for Consensus: GACCTTCCC
Best match is: metacluster_0/pattern_3 with orientation: 0
PSAM for Consensus: TGACCTTGA
Best match is: metacluster_0/pattern_1 with orientation: 0
PSAM for Consensus: AAGGTCATGG
Best match is: metacluster_0/pattern_1 with orientation: 1
PSAM for Consensus: CAAGGTCATG
Best match is: metacluster_0/pattern_1 with orientation: 1
PSAM for Consensus: TCAAGGTCAT
Best match is: metacluster_0/pattern_1 with orientation: 1
In [14]:
data = OrderedDict() # keys are subseqlets (as strings)

def create_dict_val_for_subseqlet(key):
    data[key] = OrderedDict()
    
    data[key]['counts'] = 0        # how many times this subseqlet or its revcomp showed up
    data[key]['kmers'] = []        # list of kmers where subseqlet showed up
    data[key]['seqlets'] = []      # list of seqlets where subseqlet showed up
    data[key]['revcomps'] = []     # list of whether subseqlet was found as revcomp (TRUE) or not (FALSE)
    data[key]['modisco_patterns'] = []    # list of modisco patterns where subseqlet showed up
    data[key]['psams'] = []        # list of nrlb psams where subseqlet showed up
    data[key]['importances'] = []  # list of sum(abs(importances))
    data[key]['affinities'] = []   # list of conv with PSAM
    data[key]['modisco_matches'] = []     # list of conv with modisco motif (TODO: change to conv with modisco hypothetical motif)
    
    data[key]['num_peaks_containing_subseqlet'] = 0
    data[key]['freq_in_peaks'] = 0   
    key_seq = one_hot_encode_along_channel_axis(key)
    for peak in full_peaks:
        matches_fwd = np.sum(np.multiply(rolling_window(peak.T, len(key_seq)), key_seq.T[:,None,:]), axis=(0,2))
        matches_rev = np.sum(np.multiply(rolling_window(peak.T, len(key_seq)), key_seq[::-1,::-1].T[:,None,:]), axis=(0,2))
        matches = max(np.count_nonzero(matches_fwd[matches_fwd==len(key_seq)]), np.count_nonzero(matches_rev[matches_rev==len(key_seq)]))
        data[key]['freq_in_peaks'] += matches
        if matches > 0:
            data[key]['num_peaks_containing_subseqlet'] += 1
In [15]:
for idx, psam in enumerate(oct4_selex_psams):
    print("PSAM for consensus: "+consensus_seqs[idx])
    viz_sequence.plot_weights(psam, figsize=(10,3))
    peak_affinities = []
    for peak in full_peaks:
        peak_affinities.append(np.sum(np.multiply(rolling_window(peak.T, len(psam)), psam.T[:,None,:]), axis=(0,2)))
    peak_affinities = np.array(peak_affinities).flatten()
    
    metacluster_id = psam_id_to_best_metacluster_id[idx]
    pattern_id = psam_id_to_best_pattern_id[idx]
    print("best modisco match is "+metacluster_id+"/"+pattern_id+" with best matching sub-pattern: ")
    viz_sequence.plot_weights(best_sub_patterns[idx], figsize=(10,3))
    seqlet_affinities = []
    for seqlet_idx, seqlet in enumerate(patterns_to_kmers[metacluster_id][pattern_id]):
        imp_seqlet = patterns_to_seqlets[metacluster_id][pattern_id][seqlet_idx]
        per_position_affinities_fwd = np.sum(np.multiply(rolling_window(seqlet.T, len(psam)), psam.T[:,None,:]), axis=(0,2))
        per_position_affinities_rev = np.sum(np.multiply(rolling_window(seqlet.T, len(psam)), psam[::-1,::-1].T[:,None,:]), axis=(0,2))
        per_position_affinities = np.max(np.array([per_position_affinities_fwd, per_position_affinities_rev]),
                                         axis=0)
        per_position_affinities_fwd_or_rev = np.argmax(np.array([per_position_affinities_fwd, per_position_affinities_rev]),
                                                       axis=0)
        if len(per_position_affinities) == 0:
            continue
        
        argmax_pos = np.argmax(per_position_affinities)
        seqlet_affinities.append(per_position_affinities[argmax_pos])

        if seqlet_idx % 100 == 0:
            print("finished another 100...")
        
        pos_of_best_match_to_nrlb_within_seqlet = np.argmax(per_position_affinities)
        best_matching_subseqlet = seqlet[pos_of_best_match_to_nrlb_within_seqlet:
                                         (pos_of_best_match_to_nrlb_within_seqlet+len(psam))]
        if (per_position_affinities_fwd_or_rev[argmax_pos]==1):
            best_matching_subseqlet = best_matching_subseqlet[::-1,::-1]
            
        # storing 
        key = onehot_to_seq(best_matching_subseqlet)
        rev_key = onehot_to_seq(best_matching_subseqlet[::-1,::-1])
        if (key not in data) and (rev_key not in data):
            create_dict_val_for_subseqlet(key)
        if data.get(key) is None:
            create_dict_val_for_subseqlet(key)

        data[key]['counts'] += 1
        data[key]['kmers'].append(seqlet)
        data[key]['seqlets'].append(imp_seqlet)
        data[key]['revcomps'].append(per_position_affinities_fwd_or_rev[argmax_pos])
        data[key]['modisco_patterns'].append(metacluster_id+"/"+pattern_id)
        data[key]['psams'].append(psam)
        data[key]['importances'].append(np.sum(np.abs(imp_seqlet[pos_of_best_match_to_nrlb_within_seqlet:
                                                                 (pos_of_best_match_to_nrlb_within_seqlet+len(psam))])))

        masked_best_matching_subseqlet = best_matching_subseqlet*psam

        per_position_motif_match_fwd = np.sum(np.multiply(rolling_window(modisco_motif[metacluster_id][pattern_id].T, len(seqlet)), seqlet.T[:,None,:]), axis=(0,2))
        per_position_motif_match_rev = np.sum(np.multiply(rolling_window(modisco_motif[metacluster_id][pattern_id].T, len(seqlet)), seqlet[::-1,::-1].T[:,None,:]), axis=(0,2))
        per_position_motif_match = max(np.max(per_position_motif_match_fwd), np.max(per_position_motif_match_rev))
        
        data[key]['affinities'].append(seqlet_affinities[seqlet_idx])
        data[key]['modisco_matches'].append(per_position_motif_match)
        
    sns.distplot(peak_affinities, bins='auto', label='full_peak_affinities')
    sns.distplot(seqlet_affinities, bins='auto', label='seqlet_affinities')
    plt.legend(loc='upper left')
    plt.title("histogram of affinities and seqlet for pattern: "+metacluster_id+"/"+pattern_id)
    plt.show()
PSAM for consensus: CAACGTCC
best modisco match is metacluster_0/pattern_1 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: CCAACGTC
best modisco match is metacluster_0/pattern_1 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: AACGTCAC
best modisco match is metacluster_0/pattern_16 with best matching sub-pattern: 
finished another 100...
finished another 100...
PSAM for consensus: ACCTTCCCC
best modisco match is metacluster_0/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: GACCTTCCC
best modisco match is metacluster_0/pattern_3 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: TGACCTTGA
best modisco match is metacluster_0/pattern_1 with best matching sub-pattern: 
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
finished another 100...
PSAM for consensus: AAGGTCATGG
best modisco match is metacluster_0/pattern_1 with best matching sub-pattern: 
PSAM for consensus: CAAGGTCATG
best modisco match is metacluster_0/pattern_1 with best matching sub-pattern: 
PSAM for consensus: TCAAGGTCAT
best modisco match is metacluster_0/pattern_1 with best matching sub-pattern: 
In [16]:
lines = 10
print('\033[1m'+"counts \t best subseqlet \t mean importance \t var importance \t mean affinity (conv with PSAM) \t modisco match (conv with motif) \t # peaks w subseqlet \t # subseqlets in peaks"+'\033[0m')
for key in data:
    if lines == 0:
        break
    print(str(data[key]['counts'])+" \t "
          +key+" \t \t "
          +str(np.mean(data[key]['importances']))+" \t \t "
          +str(np.var(data[key]['importances']))+" \t \t \t "
          +str(np.mean(data[key]['affinities']))+" \t \t \t "
          +str(np.mean(data[key]['modisco_matches']))+" \t \t \t "
          +str(data[key]['num_peaks_containing_subseqlet'])+" \t \t \t "
          +str(data[key]['freq_in_peaks']))
    lines -= 1
counts 	 best subseqlet 	 mean importance 	 var importance 	 mean affinity (conv with PSAM) 	 modisco match (conv with motif) 	 # peaks w subseqlet 	 # subseqlets in peaks
401 	 TATTGTTC 	 	 0.6486699 	 	 0.05221678 	 	 	 5.396753067168517 	 	 	 0.525682793872351 	 	 	 1057 	 	 	 1079
1 	 TAAGACAG 	 	 0.6495832 	 	 0.0 	 	 	 3.8326865831759735 	 	 	 0.048759771880730295 	 	 	 1125 	 	 	 1190
151 	 CTTTGTTT 	 	 0.49674335 	 	 0.046326328 	 	 	 4.531007507168745 	 	 	 0.43030205879901395 	 	 	 2893 	 	 	 3113
57 	 CATTGTTG 	 	 0.5652236 	 	 0.07856894 	 	 	 5.048634186644363 	 	 	 0.5260693258412787 	 	 	 825 	 	 	 845
44 	 CATTGTTA 	 	 0.5689499 	 	 0.04731375 	 	 	 5.972033543928671 	 	 	 0.4943357332864295 	 	 	 999 	 	 	 1050
179 	 CATTGTTC 	 	 0.5948906 	 	 0.049722385 	 	 	 6.010816469360159 	 	 	 0.5463014129016006 	 	 	 995 	 	 	 1014
9 	 TATTCTTC 	 	 0.8140186 	 	 0.19658348 	 	 	 4.329155558354802 	 	 	 0.3748211978513267 	 	 	 808 	 	 	 820
98 	 TATTGTCC 	 	 0.6240449 	 	 0.056015443 	 	 	 6.582329054903907 	 	 	 0.5014618846506542 	 	 	 636 	 	 	 657
6 	 GAACAACA 	 	 0.7387287 	 	 0.067915164 	 	 	 5.655591877507217 	 	 	 0.28980443029397085 	 	 	 891 	 	 	 904
51 	 TATTGTTT 	 	 0.57866824 	 	 0.085564636 	 	 	 4.783763462991267 	 	 	 0.487361972193665 	 	 	 1343 	 	 	 1382
In [17]:
for key in data:
    for idx in range(len(data[key]['kmers'])):
        data[key]['kmers'][idx] = data[key]['kmers'][idx].tolist()
    for idx in range(len(data[key]['seqlets'])):
        data[key]['seqlets'][idx] = data[key]['seqlets'][idx].tolist()
    for idx in range(len(data[key]['psams'])):
        data[key]['psams'][idx] = data[key]['psams'][idx].tolist()
In [18]:
# json dump the dict with all the data
import json

def default(o):
    if isinstance(o, np.int64):
        return int(o)

with open('data.json', 'w') as fp:
    json.dump(data, fp, default=default)
In [19]:
# generate csv with averages
import csv

with open('nanog_affinities.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, dialect='excel')
    writer.writerow(["subseqlet", "counts", "mean importance", "var importance",
                    "mean affinity (conv with PSAM)", "modisco match (conv with motif)",
                    "# peaks w subseqlet", "# subseqlets in peaks"])
    for key in data:
        writer.writerow([key, str(data[key]['counts']),
          str(np.mean(data[key]['importances'])),
          str(np.var(data[key]['importances'])),
          str(np.mean(data[key]['affinities'])),
          str(np.mean(data[key]['modisco_matches'])),
          str(data[key]['num_peaks_containing_subseqlet']),
          str(data[key]['freq_in_peaks'])])