import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,5"  # comma-separated, no spaces
from basepair.imports import *
modisco_dir = f"/srv/scratch/avsec/workspace/chipnexus/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile/"
imp_scores = f"/srv/scratch/avsec/workspace/chipnexus/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/grad.all.h5"
from basepair.modisco.table import ModiscoData
md = ModiscoData.load(modisco_dir, imp_scores)
# get NRLB PSAMs
import csv
import numpy as np
oct4_selex_psams = []
consensus_seqs = []
with open('POU5F1-NRLBConfig.csv', 'r') as csv_file:
    csv_reader = csv.reader(csv_file)
    next(csv_reader, None)  # skip the header
    for row in csv_reader:
        consensus_seqs.append(row[24])
        count = 26
        this_selex_psam = []
        while row[count] != 'NSB>':
            this_selex_psam.append(float(row[count]))
            count += 1
        this_selex_psam = np.array(this_selex_psam)
        this_selex_psam = this_selex_psam.reshape((count - 26) // 4, 4)
        oct4_selex_psams.append(this_selex_psam)
oct4_selex_psams = np.array(oct4_selex_psams)
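# Quick sanity check on the parsed PSAMs (added for illustration): each PSAM
# should be (length, 4). If the NRLB models have different lengths, np.array
# above falls back to an object array (newer numpy may require dtype=object).
for _idx, _psam in enumerate(oct4_selex_psams):
    assert _psam.shape[1] == 4, f"PSAM {_idx} is not (length, 4): {_psam.shape}"
print(f"Parsed {len(oct4_selex_psams)} PSAMs; lengths: {[len(p) for p in oct4_selex_psams]}")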
# one-hot encode the full peaks
def onehot_to_seq(onehot):
    # convert a one-hot encoded (length, 4) array back to an ACGT string
    seq = ""
    for pos in range(len(onehot)):
        char_idx = [i for i, e in enumerate(onehot[pos]) if e != 0][0]
        if char_idx == 0:
            char = 'A'
        elif char_idx == 1:
            char = 'C'
        elif char_idx == 2:
            char = 'G'
        elif char_idx == 3:
            char = 'T'
        seq += char
    return seq
def one_hot_encode_along_channel_axis(sequence):
    to_return = np.zeros((len(sequence), 4), dtype=np.int8)
    seq_to_one_hot_fill_in_array(zeros_array=to_return,
                                 sequence=sequence, one_hot_axis=1)
    return to_return
def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    assert one_hot_axis == 0 or one_hot_axis == 1
    if one_hot_axis == 0:
        assert zeros_array.shape[1] == len(sequence)
    elif one_hot_axis == 1:
        assert zeros_array.shape[0] == len(sequence)
    # will mutate zeros_array
    for i, char in enumerate(sequence):
        if char == "A" or char == "a":
            char_idx = 0
        elif char == "C" or char == "c":
            char_idx = 1
        elif char == "G" or char == "g":
            char_idx = 2
        elif char == "T" or char == "t":
            char_idx = 3
        elif char == "N" or char == "n":
            continue  # leave that position as all zeros
        else:
            raise RuntimeError("Unsupported character: " + str(char))
        if one_hot_axis == 0:
            zeros_array[char_idx, i] = 1
        elif one_hot_axis == 1:
            zeros_array[i, char_idx] = 1
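# Minimal round-trip check (added for illustration): encoding then decoding
# should recover the input for plain ACGT sequences. Note that onehot_to_seq
# would raise an IndexError on the all-zero rows that 'N' characters produce.
_example_seq = "ACGTAC"
assert onehot_to_seq(one_hot_encode_along_channel_axis(_example_seq)) == _example_seq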
full_peaks = []
with open("oct4_peaks.fa.out") as peaks_file:
    for line in peaks_file:
        if ">" in line:  # skip FASTA headers
            continue
        full_peaks.append(one_hot_encode_along_channel_axis(line.strip()))
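# Quick check (added for illustration): the rolling_window scans further down
# assume every peak is at least as long as the PSAM being matched against it.
print(f"Loaded {len(full_peaks)} peaks; first peak one-hot shape: {full_peaks[0].shape}")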
mr = ModiscoResult("modisco.h5")
mr.open()
tasks = [x.split("/")[0] for x in mr.tasks()]
output_dir = Path("/srv/www/kundaje/avsec/chipnexus/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile")
all_patterns = read_pkl(output_dir / 'patterns.pkl')
patterns = [x for x in all_patterns if x.attrs['pattern_group'] == 'nte']
from collections import OrderedDict
modisco_motif = OrderedDict()
for p in patterns:
    metacluster, pattern = p.name.split('/')
    modisco_motif.setdefault(metacluster, OrderedDict())[pattern] = p.seq  # try hyp_contrib
# # one-hot encode seqlets (earlier approach, superseded by md.get_seq / md.get_imp below)
# seqlets_dir = "modisco_seqlets/"
# metacluster_ids = list(modisco_motif.keys())
# pattern_ids = OrderedDict()
# for metacluster_id in metacluster_ids:
#     pattern_ids[metacluster_id] = list(modisco_motif[metacluster_id].keys())
# patterns_to_seqlets = OrderedDict()
# for metacluster_id in metacluster_ids:
#     patterns_to_seqlets[metacluster_id] = OrderedDict()
#     for pattern_id in pattern_ids[metacluster_id]:
#         patterns_to_seqlets[metacluster_id][pattern_id] = []
#         with open(seqlets_dir + metacluster_id + "/" + pattern_id + ".bed.out") as seq_f:
#             for line in seq_f:
#                 if ">" in line:
#                     continue
#                 patterns_to_seqlets[metacluster_id][pattern_id].append(one_hot_encode_along_channel_axis(line.strip()))
metacluster_ids = list(modisco_motif.keys())
pattern_ids = OrderedDict()
for metacluster_id in metacluster_ids:
    pattern_ids[metacluster_id] = list(modisco_motif[metacluster_id].keys())
patterns_to_kmers = OrderedDict()
patterns_to_seqlets = OrderedDict()
for metacluster_id in metacluster_ids:
    patterns_to_kmers[metacluster_id] = OrderedDict()
    patterns_to_seqlets[metacluster_id] = OrderedDict()
    for pattern_id in pattern_ids[metacluster_id]:
        i, j = md.get_trim_idx(metacluster_id + '/' + pattern_id)
        full_kmers = md.get_seq(metacluster_id + '/' + pattern_id)
        full_seqlets = md.get_imp(metacluster_id + '/' + pattern_id, task='Oct4', which='profile')
        patterns_to_kmers[metacluster_id][pattern_id] = full_kmers[:, i:j]
        patterns_to_seqlets[metacluster_id][pattern_id] = full_seqlets[:, i:j]
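# Sanity check (added for illustration; assumes md.get_seq and md.get_imp both
# return arrays of shape (n_seqlets, length, 4), so the trims line up):
for metacluster_id in metacluster_ids:
    for pattern_id in pattern_ids[metacluster_id]:
        _kmers = patterns_to_kmers[metacluster_id][pattern_id]
        _seqlets = patterns_to_seqlets[metacluster_id][pattern_id]
        assert _kmers.shape[:2] == _seqlets.shape[:2], (metacluster_id, pattern_id)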
# they see me rollin: strided view of all length-`window` windows along the last axis
def rolling_window(a, window):
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
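# Tiny illustration of rolling_window (added for illustration, not used below):
# it returns overlapping windows as a strided view, so no data is copied.
# rolling_window(np.arange(6), 3) -> [[0 1 2], [1 2 3], [2 3 4], [3 4 5]]
assert rolling_window(np.arange(6), 3).shape == (4, 3)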
# getting best matches
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt
from concise.utils.plot import seqlogo_fig, seqlogo
from modisco.visualization import viz_sequence
%matplotlib inline
psam_id_to_best_metacluster_id = OrderedDict()
psam_id_to_best_pattern_id = OrderedDict()
best_sub_patterns = OrderedDict()
for idx, psam in enumerate(oct4_selex_psams):
    print("PSAM for consensus: " + consensus_seqs[idx])
    viz_sequence.plot_weights(psam, figsize=(10, 3))
    best_match_score = float("-inf")
    for metacluster_id in metacluster_ids:
        for pattern_id in pattern_ids[metacluster_id]:
            motif = modisco_motif[metacluster_id][pattern_id]
            # scan the PSAM across the motif in both orientations
            per_position_scores_fwd = np.sum(np.multiply(rolling_window(motif.T, len(psam)), psam.T[:, None, :]), axis=(0, 2))
            per_position_scores_rev = np.sum(np.multiply(rolling_window(motif.T, len(psam)), psam[::-1, ::-1].T[:, None, :]), axis=(0, 2))
            per_position_scores = np.max(np.array([per_position_scores_fwd, per_position_scores_rev]), axis=0)
            per_position_scores_fwd_or_rev = np.argmax(np.array([per_position_scores_fwd, per_position_scores_rev]), axis=0)
            argmax_pos = np.argmax(per_position_scores)
            # perhaps use correlation (stats.pearsonr) instead of conv? (sketched after this loop)
            if per_position_scores[argmax_pos] >= best_match_score:
                best_metacluster_id = metacluster_id
                best_pattern_id = pattern_id
                best_sub_pattern = modisco_motif[best_metacluster_id][best_pattern_id][argmax_pos:argmax_pos + len(psam)]
                best_match_score = per_position_scores[argmax_pos]
                motif_orientation = per_position_scores_fwd_or_rev[argmax_pos]
                if motif_orientation == 1:
                    best_sub_pattern = best_sub_pattern[::-1, ::-1]
    psam_id_to_best_metacluster_id[idx] = best_metacluster_id
    psam_id_to_best_pattern_id[idx] = best_pattern_id
    best_sub_patterns[idx] = best_sub_pattern
    print("Best match is: " + best_metacluster_id + "/" + best_pattern_id + " with orientation: " + str(motif_orientation))
    viz_sequence.plot_weights(best_sub_pattern, figsize=(10, 3))
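# A sketch of the correlation-based alternative floated in the comment above
# (illustrative only, not used downstream; `per_position_pearsonr` is a
# hypothetical helper): score each offset by the Pearson correlation between
# the flattened PSAM and the motif window, rather than the raw convolution.
def per_position_pearsonr(motif, psam):
    windows = rolling_window(motif.T, len(psam))  # shape: (4, n_offsets, len(psam))
    return np.array([stats.pearsonr(windows[:, k, :].ravel(), psam.T.ravel())[0]
                     for k in range(windows.shape[1])])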
data = OrderedDict()  # keys are subseqlets (as strings)

def create_dict_val_for_subseqlet(key):
    data[key] = OrderedDict()
    data[key]['counts'] = 0  # how many times this subseqlet or its revcomp showed up
    data[key]['kmers'] = []  # list of kmers where the subseqlet showed up
    data[key]['seqlets'] = []  # list of seqlets where the subseqlet showed up
    data[key]['revcomps'] = []  # whether the subseqlet was found as a revcomp (True) or not (False)
    data[key]['modisco_patterns'] = []  # list of modisco patterns where the subseqlet showed up
    data[key]['psams'] = []  # list of NRLB PSAMs where the subseqlet showed up
    data[key]['importances'] = []  # list of sum(abs(importances))
    data[key]['affinities'] = []  # list of conv scores with the PSAM
    data[key]['modisco_matches'] = []  # list of conv scores with the modisco motif (TODO: switch to the modisco hypothetical motif)
    data[key]['num_peaks_containing_subseqlet'] = 0
    data[key]['freq_in_peaks'] = 0
    # scan all peaks for exact matches to the subseqlet in either orientation
    key_seq = one_hot_encode_along_channel_axis(key)
    for peak in full_peaks:
        matches_fwd = np.sum(np.multiply(rolling_window(peak.T, len(key_seq)), key_seq.T[:, None, :]), axis=(0, 2))
        matches_rev = np.sum(np.multiply(rolling_window(peak.T, len(key_seq)), key_seq[::-1, ::-1].T[:, None, :]), axis=(0, 2))
        matches = max(np.count_nonzero(matches_fwd == len(key_seq)), np.count_nonzero(matches_rev == len(key_seq)))
        data[key]['freq_in_peaks'] += matches
        if matches > 0:
            data[key]['num_peaks_containing_subseqlet'] += 1
for idx, psam in enumerate(oct4_selex_psams):
    print("PSAM for consensus: " + consensus_seqs[idx])
    viz_sequence.plot_weights(psam, figsize=(10, 3))
    peak_affinities = []
    for peak in full_peaks:
        peak_affinities.append(np.sum(np.multiply(rolling_window(peak.T, len(psam)), psam.T[:, None, :]), axis=(0, 2)))
    peak_affinities = np.array(peak_affinities).flatten()
    metacluster_id = psam_id_to_best_metacluster_id[idx]
    pattern_id = psam_id_to_best_pattern_id[idx]
    print("Best modisco match is " + metacluster_id + "/" + pattern_id + " with best matching sub-pattern:")
    viz_sequence.plot_weights(best_sub_patterns[idx], figsize=(10, 3))
    seqlet_affinities = []
    seqlet_affinities_argmax = []
    seqlet_affinities_orientation = []
    for seqlet in patterns_to_kmers[metacluster_id][pattern_id]:
        per_position_affinities_fwd = np.sum(np.multiply(rolling_window(seqlet.T, len(psam)), psam.T[:, None, :]), axis=(0, 2))
        per_position_affinities_rev = np.sum(np.multiply(rolling_window(seqlet.T, len(psam)), psam[::-1, ::-1].T[:, None, :]), axis=(0, 2))
        per_position_affinities = np.max(np.array([per_position_affinities_fwd, per_position_affinities_rev]), axis=0)
        per_position_affinities_fwd_or_rev = np.argmax(np.array([per_position_affinities_fwd, per_position_affinities_rev]), axis=0)
        argmax_pos = np.argmax(per_position_affinities)
        seqlet_affinities_argmax.append(argmax_pos)
        seqlet_affinities_orientation.append(per_position_affinities_fwd_or_rev[argmax_pos])
        seqlet_affinities.append(per_position_affinities[argmax_pos])
    seqlet_affinities = np.array(seqlet_affinities).flatten()
    sns.distplot(peak_affinities, bins='auto', label='full_peak_affinities')
    sns.distplot(seqlet_affinities, bins='auto', label='seqlet_affinities')
    plt.legend(loc='upper left')
    plt.title("Histogram of peak and seqlet affinities for pattern: " + metacluster_id + "/" + pattern_id)
    plt.show()
    sorted_seqlet_affinities = sorted(enumerate(seqlet_affinities), key=lambda x: x[1])
    sorted_seqlet_indices = [x[0] for x in sorted_seqlet_affinities]
    for seqlet_idx in sorted_seqlet_indices:
        the_kmer = patterns_to_kmers[metacluster_id][pattern_id][seqlet_idx]
        imp_seqlet = patterns_to_seqlets[metacluster_id][pattern_id][seqlet_idx]
        pos_of_best_match_to_nrlb_within_seqlet = seqlet_affinities_argmax[seqlet_idx]
        best_matching_subseqlet = the_kmer[pos_of_best_match_to_nrlb_within_seqlet:
                                           (pos_of_best_match_to_nrlb_within_seqlet + len(psam))]
        if seqlet_affinities_orientation[seqlet_idx] == 1:
            best_matching_subseqlet = best_matching_subseqlet[::-1, ::-1]
        # storing: if the revcomp is already a key, count this occurrence under that
        # key instead of creating a new entry (the old check would otherwise KeyError)
        key = onehot_to_seq(best_matching_subseqlet)
        revcomp_key = onehot_to_seq(best_matching_subseqlet[::-1, ::-1])
        if key not in data:
            if revcomp_key in data:
                key = revcomp_key
            else:
                create_dict_val_for_subseqlet(key)
        data[key]['counts'] += 1
        data[key]['kmers'].append(the_kmer)
        data[key]['seqlets'].append(imp_seqlet)
        data[key]['revcomps'].append(seqlet_affinities_orientation[seqlet_idx])
        data[key]['modisco_patterns'].append(metacluster_id + "/" + pattern_id)
        data[key]['psams'].append(psam)
        data[key]['importances'].append(np.sum(np.abs(imp_seqlet[pos_of_best_match_to_nrlb_within_seqlet:
                                                                 (pos_of_best_match_to_nrlb_within_seqlet + len(psam))])))
        masked_best_matching_subseqlet = best_matching_subseqlet * psam  # note: currently unused
        per_position_motif_match_fwd = np.sum(np.multiply(rolling_window(modisco_motif[metacluster_id][pattern_id].T, len(the_kmer)), the_kmer.T[:, None, :]), axis=(0, 2))
        per_position_motif_match_rev = np.sum(np.multiply(rolling_window(modisco_motif[metacluster_id][pattern_id].T, len(the_kmer)), the_kmer[::-1, ::-1].T[:, None, :]), axis=(0, 2))
        per_position_motif_match = max(np.max(per_position_motif_match_fwd), np.max(per_position_motif_match_rev))
        data[key]['affinities'].append(seqlet_affinities[seqlet_idx])
        data[key]['modisco_matches'].append(per_position_motif_match)
print('\033[1m' + "counts \t best subseqlet \t mean importance \t var importance \t mean affinity (conv with PSAM) \t modisco match (conv with motif) \t # peaks w/ subseqlet \t # subseqlets in peaks" + '\033[0m')
for key in data:
    print(str(data[key]['counts']) + " \t "
          + key + " \t \t "
          + str(np.mean(data[key]['importances'])) + " \t \t "
          + str(np.var(data[key]['importances'])) + " \t \t "
          + str(np.mean(data[key]['affinities'])) + " \t \t "
          + str(np.mean(data[key]['modisco_matches'])) + " \t \t "
          + str(data[key]['num_peaks_containing_subseqlet']) + " \t \t "
          + str(data[key]['freq_in_peaks']))
# TODO: json dump the dict with all the data
# TODO: generate tab separated csv with averages
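# A minimal sketch of the two TODOs above (illustrative only; the file names
# 'subseqlet_summary.json' / 'subseqlet_summary.tsv' are placeholders). The
# numpy arrays stored under 'kmers', 'seqlets' and 'psams' are not JSON
# serializable, so only the summary statistics are exported.
import json
columns = ['counts', 'mean_importance', 'var_importance', 'mean_affinity',
           'mean_modisco_match', 'num_peaks_containing_subseqlet', 'freq_in_peaks']
summary = OrderedDict()
for key in data:
    summary[key] = OrderedDict([
        ('counts', data[key]['counts']),
        ('mean_importance', float(np.mean(data[key]['importances']))),
        ('var_importance', float(np.var(data[key]['importances']))),
        ('mean_affinity', float(np.mean(data[key]['affinities']))),
        ('mean_modisco_match', float(np.mean(data[key]['modisco_matches']))),
        ('num_peaks_containing_subseqlet', data[key]['num_peaks_containing_subseqlet']),
        ('freq_in_peaks', data[key]['freq_in_peaks']),
    ])
with open('subseqlet_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)
with open('subseqlet_summary.tsv', 'w') as f:
    f.write('\t'.join(['subseqlet'] + columns) + '\n')
    for key, row in summary.items():
        f.write('\t'.join([key] + [str(row[c]) for c in columns]) + '\n')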