from basepair.config import get_data_dir, create_tf_session
from keras.models import load_model
from basepair.datasets import *
from basepair.preproc import transform_data
from basepair.plots import regression_eval
# Scratch value; `a` is never read before it is reassigned further down.
a=1
# Set up the TensorFlow session through the project helper.
# NOTE(review): the meaning of the argument `1` is not visible here - confirm in basepair.config.
create_tf_session(1)
# Load the trained Keras model from a hard-coded absolute path
# (the hyper-parameters are encoded in the file name).
model = load_model("/users/avsec/workspace/basepair/basepair/../data/processed/chipnexus/exp/models/p-multi-task/seq_mutlitask_filters=32,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,use_profile=True,use_counts=True,c_task_weight=10,lr=0.004.5.h5")
# NOTE(review): the next line is a notebook-style shell command; it is not valid Python in a plain script.
ls -latr /users/avsec/workspace/basepair/basepair/../data/processed/chipnexus/exp/models/p-multi-task
# Load the data and split it. Below, split[0] is fed to model.predict and
# split[1][k] is compared against prediction k (presumably the targets - TODO confirm).
data = sox2_oct4_peaks_sox2()
train, valid, test = transform_data(data, use_profile=True, use_counts=True)
# Model predictions on the validation and test inputs.
yv_pred = model.predict(valid[0])
yt_pred = model.predict(test[0])
# Exploratory plots of the validation-set count tracks.
# NOTE(review): `plt` and `np` are not imported explicitly in this file; presumably they come
# from `from basepair.datasets import *` - confirm.
len(valid[1])
# Maximum along axis 1 of the first validation target; the last axis has (at least) two
# columns, which the axis labels below call "pos" and "neg".
max_counts = valid[1][0].max(axis=1)
plt.scatter(np.log10(1+max_counts[:,0]), np.log10(1+max_counts[:,1]))
plt.xlabel("Number of pos counts of the max peak(log10)");
plt.ylabel("Number of neg counts of the max peak(log10)");
# Histogram of the per-sequence maximum (averaged over the last axis), on a log10 scale ...
plt.hist(np.log10(1+valid[1][0].max(axis=1).mean(axis=1)), bins=50);
plt.xlabel("Number of counts for the max peak")
# ... and on the raw scale.
plt.hist(valid[1][0].max(axis=1).mean(axis=1), bins=50);
plt.xlabel("Number of counts for the max peak")
# Histogram of the per-sequence total (summed over axis 1, averaged over the last axis)
# for target 0, x-axis clipped to [0, 300].
plt.hist(valid[1][0].sum(axis=1).mean(axis=1), bins=200);
plt.xlabel("Number of counts per peak");
plt.xlim([0, 300]);
# Same histogram for target 1.
plt.hist(valid[1][1].sum(axis=1).mean(axis=1), bins=200);
plt.xlabel("Number of counts per peak");
plt.xlim([0, 300]);
# Predicted peak positions on the validation set: softmax of the first model output, thresholded at 0.02.
# NOTE(review): `softmax` is not defined or imported explicitly in this file - confirm its origin and axis.
is_peak_pred = softmax(yv_pred[0])>0.02
is_peak_pred.sum()
# Distribution of the per-position fraction of counts (each position divided by the sum over axis 1).
plt.hist(np.ravel(valid[1][0]/valid[1][0].sum(axis=1, keepdims=True)), bins=1000);
plt.xlim([0.001, 0.05])
# Observed peak positions: positions holding more than 5% of the counts along axis 1.
# NOTE(review): a track whose sum along axis 1 is zero yields 0/0 = NaN here, which compares as False.
is_peak = valid[1][0]/valid[1][0].sum(axis=1, keepdims=True)>0.05
is_peak_oct4 = valid[1][1]/valid[1][1].sum(axis=1, keepdims=True)>0.05
# Quick inspection of the mask: number of peaks, shape, fraction of positive positions.
is_peak.sum()
is_peak.shape
is_peak.mean()
from concise.eval_metrics import auprc
# TODO - compute the auprc in the regions with only sufficiently high counts
# Keep only sequences whose total counts (summed over axis 1, averaged over the last axis) exceed 54.
# NOTE(review): 54 is a magic number; the median of this same quantity is computed further down,
# so it is presumably that median - confirm.
do_eval = valid[1][0].sum(axis=1).mean(axis=1) > 54
do_eval.mean()
do_eval.sum()
a = is_peak_pred[do_eval]
# Distribution of the number of peak positions per sequence (summed over axis 1 and the last axis)
pd.Series(is_peak.sum(1).sum(1)).value_counts().sort_index().plot(kind='bar')
# Fraction, then number, of sequences with exactly one peak position in every column of the last axis.
np.all(is_peak.sum(1) ==1, axis=1).mean()
np.all(is_peak.sum(1) ==1, axis=1).sum()
# Events with exactly two peaks
# Coverage masks on the test set (same > 54 total-count rule as for the validation set),
# one per target, combined so that a sequence must pass for both.
do_eval_sox2 = test[1][0].sum(axis=1).mean(axis=1) > 54
do_eval_ocf4 = test[1][1].sum(axis=1).mean(axis=1) > 54
do_eval = do_eval_sox2 & do_eval_ocf4
# Observed peak mask for test target 1 (positions holding more than 5% of the counts along axis 1).
is_peak = test[1][1]/test[1][1].sum(axis=1, keepdims=True)>0.05
is_peak.sum()
is_peak.size
is_peak.mean()
# NOTE(review): `tid` is not assigned anywhere above this point in file order (it is first bound by
# the `for tid in [0, 1]` loop further down), and the following statement is duplicated.
fractions = test[1][tid]/test[1][tid].sum(axis=1, keepdims=True)
fractions = test[1][tid]/test[1][tid].sum(axis=1, keepdims=True)
do_eval.mean()
# Restrict the fractions to the sequences passing the coverage mask.
fractions = fractions[do_eval]
np.ravel(fractions).max()
# Positions strictly between 1% and 5%; `amb` is not used again in this file.
amb = (fractions<0.05) & (fractions>0.01)
# Bucket every position into <1%, 1%-5% and >=5% and count the members of each bucket.
unique, counts = np.unique(np.digitize(fractions, [0.01, 0.05]), return_counts=True)
# Total counts
counts
# Percentages
counts / counts.sum()
# NOTE(review): this assumes all three buckets are non-empty, otherwise `counts` has fewer than 3 entries.
plt.bar(["0-1%", "1%-5%", "5%-"], counts)
plt.xlabel("Percentage bucket");
# Histogram of the fractions below 5%.
plt.hist(np.ravel(fractions[fractions<0.05]), bins=100);
np.ravel(fractions).min()
is_peak.sum()
def bin_counts_max(x, binsize=2):
    """Down-sample axis 1 of a 3-D array by taking the maximum of each bin.

    Args:
        x: array with exactly three dimensions, ``(d0, d1, d2)``.
        binsize: number of consecutive positions along axis 1 merged into
            one bin. Trailing positions that do not fill a complete bin are
            dropped.

    Returns:
        Float array of shape ``(d0, d1 // binsize, d2)`` where entry
        ``[a, i, c]`` is the maximum of ``x[a, binsize*i:binsize*(i+1), c]``.
    """
    assert len(x.shape) == 3
    n_seq, seq_len, n_channels = x.shape
    outlen = seq_len // binsize
    # Cut off the incomplete trailing bin, then expose the bins as their own
    # axis so that a single vectorized reduction replaces the Python loop.
    binned = x[:, :outlen * binsize, :].reshape(n_seq, outlen, binsize, n_channels)
    return binned.max(axis=2).astype(float)
def bin_counts_amb(x, binsize=2):
    """Down-sample peak labels along axis 1, keeping track of ambiguous bins.

    The input is expected to hold the labels 1 (peak), -1 (ambiguous) and
    anything else (treated as background). Each bin of ``binsize``
    consecutive positions is collapsed to a single label:

    * 1  if the bin contains at least one peak,
    * -1 if it contains no peak but at least one ambiguous position,
    * 0  otherwise.

    Args:
        x: array with exactly three dimensions, ``(d0, d1, d2)``.
        binsize: number of consecutive positions along axis 1 merged into
            one bin. Trailing positions that do not fill a complete bin are
            dropped.

    Returns:
        Float array of shape ``(d0, d1 // binsize, d2)`` with values in
        ``{-1.0, 0.0, 1.0}``.
    """
    assert len(x.shape) == 3
    n_seq, seq_len, n_channels = x.shape
    outlen = seq_len // binsize
    # Expose the bins as their own axis (dropping the incomplete trailing bin).
    binned = x[:, :outlen * binsize, :].reshape(n_seq, outlen, binsize, n_channels)
    has_peak = (binned == 1).any(axis=2)
    has_amb = (binned == -1).any(axis=2)
    # A peak takes precedence over an ambiguous position in the same bin.
    return np.where(has_peak, 1.0, np.where(has_amb, -1.0, 0.0))
def bin_counts_summary(x, binsize=2, fn=np.max):
    """Down-sample axis 1 of a 3-D array with an arbitrary summary function.

    Args:
        x: array with exactly three dimensions, ``(d0, d1, d2)``.
        binsize: number of consecutive positions along axis 1 merged into
            one bin. Trailing positions that do not fill a complete bin are
            dropped.
        fn: callable applied to every 1-D slice of ``binsize`` values taken
            along axis 1; it must return a scalar. Defaults to ``np.max``.

    Returns:
        Float array of shape ``(d0, d1 // binsize, d2)`` where entry
        ``[a, i, c]`` is ``fn(x[a, binsize*i:binsize*(i+1), c])``.
    """
    assert len(x.shape) == 3
    outlen = x.shape[1] // binsize
    xout = np.zeros((x.shape[0], outlen, x.shape[2]))
    for i in range(outlen):
        window = x[:, (binsize * i):(binsize * (i + 1)), :]
        # apply_along_axis hands `fn` one 1-D slice (a single bin of a single
        # sequence/channel) at a time and collapses axis 1.
        xout[:, i, :] = np.apply_along_axis(fn, 1, window)
    return xout
# Per-position fraction of the counts along axis 1 for test target `tid`.
# NOTE(review): `tid` is not bound above this point in file order (first assigned by the
# `for tid in [0, 1]` loop below).
fracs = test[1][tid]/test[1][tid].sum(axis=1, keepdims=True)
# Peak: >= 5% of the counts; ambiguous: between 1% (inclusive) and 5% (exclusive).
is_peak = fracs>=0.05
ambigous = (fracs<0.05) & (fracs>=0.01)
# Fraction of peak / ambiguous positions among the sequences passing the coverage mask.
np.mean(fracs[do_eval]>=0.05)
np.mean(ambigous[do_eval])
# Accumulator for the per-(task, binsize) evaluation records appended by the loop below.
out= []
def fn(x):
    """Summarise one bin of labels: -1 if any entry is -1, else the maximum.

    Args:
        x: array of labels for a single bin.

    Returns:
        ``-1`` when the bin contains at least one ``-1`` entry, otherwise
        ``np.max(x)``.
    """
    if np.any(x == -1):
        return -1
    # The original computed np.max(x) here but never returned it, so every
    # bin without a -1 entry evaluated to None.
    return np.max(x)
# Evaluate per-position peak calling on the test set for both tasks and a range of bin sizes
# (the figures below title tid 0 as "Sox2" and tid 1 as "Oct4"). One record per
# (tid, binsize) combination is appended to `out`.
for tid in [0, 1]:
    # Predicted per-position probabilities for the sequences passing the coverage mask.
    probs = softmax(yt_pred[tid])[do_eval]
    # Random baseline: np.random.permutation shuffles along the first axis, so every
    # sequence is paired with the predictions of another sequence.
    random = np.random.permutation(probs)

    # The labels do not depend on the bin size, so build them once per task:
    #   1 -> peak (>= 5% of the counts), -1 -> ambiguous (1% <= fraction < 5%), 0 -> background.
    fracs = test[1][tid]/test[1][tid].sum(axis=1, keepdims=True)
    is_peak = (fracs>=0.05).astype(float)
    ambigous = (fracs<0.05) & (fracs>=0.01)
    is_peak[ambigous] = -1

    for binsize in [1, 2, 4, 10]:
        #keep = do_eval.reshape((-1,1,1)) & ~ambigous
        #y_true = np.ravel(is_peak[do_eval])
        y_true = np.ravel(bin_counts_amb(is_peak[do_eval],binsize))
        # Class balance among the unambiguous bins only.
        imbalance = np.sum(y_true==1)/np.sum(y_true >=0)
        n_positives = np.sum(y_true==1)
        n_ambigous = np.sum(y_true==-1)
        frac_ambigous = n_ambigous / y_true.size
        print(f"tid: {tid}, binsize: {binsize}, imbalance: {imbalance}, n_positives: {n_positives}, n_ambigous: {n_ambigous}, frac_ambigous: {frac_ambigous}")
        # NOTE(review): y_true still contains -1 for ambiguous bins; presumably `auprc`
        # masks those out - confirm in concise.eval_metrics.
        res = auprc(y_true,
                    np.ravel(bin_counts_max(probs, binsize)))
        print(f"auprc: {res:.2f}")
        res_random = auprc(y_true,
                           np.ravel(bin_counts_max(random, binsize)))
        out.append({"binsize": binsize,
                    "auprc": res,
                    "random_auprc": res_random,
                    "tid": tid,
                    "n_positives": n_positives,
                    "imbalance": imbalance
                    })
# Collect the evaluation records into a table: one row per (tid, binsize) combination.
df = pd.DataFrame.from_dict(out)
[x for x in df.iterrows()][0]

# Figure: auPRC as a function of the bin size, model vs. shuffled baseline,
# one panel per task, each point annotated with the number of positives and the class balance.
fig = plt.figure(figsize=(8,3))
fig.subplots_adjust(bottom=0.2)

# Left panel: tid 0.
plt.subplot(121)
plt.semilogx(df.binsize[df.tid==0], df.auprc[df.tid==0], "-o", label="BPNet")
plt.semilogx(df.binsize[df.tid==0], df.random_auprc[df.tid==0], "-o", label="Random")
# Annotate every point except the last one just below the marker ...
for i, s in df[df.tid==0].iloc[:-1].iterrows():
    plt.text(s['binsize'], s['auprc']-0.1, f"n_pos: {int(s['n_positives'])} ({s['imbalance']*100:.1f}%)")
# ... and shift the label of the last point to the left so that it stays inside the axes.
s = df[df.tid==0].iloc[-1]
plt.text(s['binsize']-7.1, s['auprc']-0.01, f"n_pos: {int(s['n_positives'])} ({s['imbalance']*100:.1f}%)")
#plt.text(s['binsize'], s['auprc']-0.2, f"imb: ")
plt.xticks(df.binsize, df.binsize);
plt.xlabel("Binsize")
plt.ylabel("auPRC");
plt.title("Sox2")
plt.legend()

# Right panel: tid 1.
plt.subplot(122)
plt.semilogx(df.binsize[df.tid==1], df.auprc[df.tid==1], "-o", label="BPNet")
plt.semilogx(df.binsize[df.tid==1], df.random_auprc[df.tid==1], "-o", label="Random")
plt.xticks(df.binsize, df.binsize);
for i, s in df[df.tid==1].iloc[:-1].iterrows():
    plt.text(s['binsize'], s['auprc']-0.07, f"n_pos: {int(s['n_positives'])} ({s['imbalance']*100:.1f}%)")
s = df[df.tid==1].iloc[-1]
plt.text(s['binsize']-7.1, s['auprc']-0.01, f"n_pos: {int(s['n_positives'])} ({s['imbalance']*100:.1f}%)")
plt.legend()
plt.title("Oct4")
plt.tight_layout()
#plt.title("Peak prediction precision")
plt.xlabel("Binsize")
#plt.ylabel("auPRC");
#plt.savefig('fig/icml18/auprc-peak.png', dpi=600)
#plt.savefig('fig/icml18/auprc-peak.pdf', dpi=600)
#plt.close(fig) # close the figure
fig
# Same auPRC-vs-binsize figure as above, without the per-point text annotations.
fig = plt.figure(figsize=(8,3))
fig.subplots_adjust(bottom=0.2)
# Left panel: rows of `df` with tid == 0.
plt.subplot(121)
plt.semilogx(df.binsize[df.tid==0], df.auprc[df.tid==0], "-o", label="BPNet")
plt.semilogx(df.binsize[df.tid==0], df.random_auprc[df.tid==0], "-o", label="Random")
plt.xticks(df.binsize, df.binsize);
plt.xlabel("Binsize")
plt.ylabel("auPRC");
plt.title("Sox2")
plt.legend()
# Right panel: rows of `df` with tid == 1.
plt.subplot(122)
plt.semilogx(df.binsize[df.tid==1], df.auprc[df.tid==1], "-o", label="BPNet")
plt.semilogx(df.binsize[df.tid==1], df.random_auprc[df.tid==1], "-o", label="Random")
plt.xticks(df.binsize, df.binsize);
plt.legend()
plt.title("Oct4")
plt.tight_layout()
#plt.title("Peak prediction precision")
plt.xlabel("Binsize")
#plt.ylabel("auPRC");
#plt.savefig('fig/icml18/auprc-peak-sox2.png', dpi=600)
#plt.savefig('fig/icml18/auprc-peak-sox2.pdf', dpi=600)
#plt.close(fig) # close the figure
# Bare expression: displays the figure when run as a notebook cell.
fig
# Ad-hoc sanity checks; the bare expressions only display a value when run as notebook cells.
# NOTE(review): in file order, `is_peak` and `do_eval` are derived from the test set at this point
# (and `is_peak` carries -1 for ambiguous positions after the loop above), whereas `is_peak_pred`
# was computed from the validation predictions - the comparisons below mix the two splits.
do_eval.sum()
is_peak[do_eval].mean()
is_peak[do_eval].sum()
is_peak[do_eval].size
auprc(np.ravel(bin_counts_max(is_peak[do_eval], 1)), np.ravel(bin_counts_max(is_peak_pred[do_eval], 1)))
# TODO - what if we summarize the ... (note left unfinished by the author)
# Element-wise agreement between observed and predicted peak masks (statement repeated three times).
np.mean(is_peak == is_peak_pred)
np.mean(is_peak == is_peak_pred)
is_peak.sum()
is_peak_pred.sum()
np.mean(is_peak == is_peak_pred)
# Sum of the peak labels along axis 1, one value per sequence and column of the last axis.
peaks = is_peak.sum(axis=1)
plt.scatter(peaks[:,0], peaks[:,1], alpha=0.2)
plt.xlabel("Number of pos peaks per sequence")
plt.ylabel("Number of neg peaks per sequence")
# Bar chart of the value counts of column 1 (statement duplicated).
pd.Series(peaks[:,1]).value_counts().sort_index().plot(kind='bar')
pd.Series(peaks[:,1]).value_counts().sort_index().plot(kind='bar')
# Focus only on the sequences with at most 2 peaks per sequence
plt.hist(peaks[:,0], bins=20);
plt.xlabel("Number of peaks per sequence");
# Median per-sequence total count on the validation set (the same quantity that is thresholded at 54 above).
np.median(valid[1][0].sum(axis=1).mean(axis=1))
# Scatter of the per-sequence maxima; `max_counts` was computed from the validation set near the top of the file.
plt.scatter(np.log10(1+max_counts[:,0]), np.log10(1+max_counts[:,1]))
plt.xlabel("Number of pos counts (log10)");
plt.ylabel("Number of neg counts (log10)");
# Recompute the predictions of the multi-task model ...
yt_pred = model.predict(test[0])
yv_pred = model.predict(valid[0])
# ... and of a second model.
# NOTE(review): `model_reg` is not defined anywhere in the visible part of this file.
yt_pred_reg = model_reg.predict(test[0])
yv_pred_reg = model_reg.predict(valid[0])
# Observed vs. predicted values for outputs 2 and 3, each summed over axis 1.
# NOTE(review): in file order these two-argument calls reach the `regression_eval` imported from
# basepair.plots; the three-argument version defined below would reject them.
# Sox2
regression_eval(valid[1][2].sum(1), yv_pred[2].sum(1))
# Oct4
regression_eval(valid[1][3].sum(1), yv_pred[3].sum(1))
is_peak.sum()
def regression_eval(y_true, y_pred, task):
    """Scatter-plot observed against predicted values on the current axes.

    The Spearman rank correlation between the two inputs is shown in the
    plot title.

    NOTE: this definition shadows the ``regression_eval`` imported from
    ``basepair.plots`` at the top of the file and, unlike the two-argument
    calls made earlier in the file, requires ``task``.

    Args:
        y_true: observed values (plotted on the y-axis).
        y_pred: predicted values (plotted on the x-axis), same length as
            ``y_true``.
        task: text prefixed to the correlation in the plot title.

    Returns:
        None. The plot is drawn through ``plt`` on the active axes.
    """
    from scipy.stats import spearmanr

    # Only the Spearman coefficient is reported; the Pearson correlation the
    # original also computed was never used.
    spearman, _ = spearmanr(y_true, y_pred)
    plt.scatter(y_pred, y_true, alpha=0.3)
    plt.xlabel("Predicted log(count+1)")
    plt.ylabel("Observed log(count+1)")
    plt.title(f"{task}; R_spearman={spearman:.2f}")
# 2x2 figure of observed vs. predicted values on the test set.
# Top row: outputs 2 and 3 of `model`; bottom row: outputs 0 and 1 of `model_reg`,
# compared against the same test targets 2 and 3.
fig = plt.figure(figsize=(8,6))
fig.subplots_adjust(bottom=0.2)
plt.subplot(221)
regression_eval(test[1][2].sum(1), yt_pred[2].sum(1), "Sox2, profile+count")
# Blank out the x-label on the top row (and the y-label on the right column below).
plt.xlabel("")
plt.subplot(222)
regression_eval(test[1][3].sum(1), yt_pred[3].sum(1), "Oct4, profile+count")
plt.ylabel("")
plt.xlabel("")
plt.subplot(223)
regression_eval(test[1][2].sum(1), yt_pred_reg[0].sum(1), "Sox2, count")
plt.subplot(224)
regression_eval(test[1][3].sum(1), yt_pred_reg[1].sum(1), "Oct4, count")
plt.ylabel("")
plt.tight_layout()
# Write the figure to disk (relative paths; the target directory must already exist).
plt.savefig('fig/icml18/count-pred-scatter.png', dpi=600)
plt.savefig('fig/icml18/count-pred-scatter.pdf', dpi=600)
plt.close(fig) # close the figure
# Bare expression: displays the figure when run as a notebook cell.
fig
# 1x2 figure: observed vs. predicted values on the test set for outputs 2 and 3 of `model` only.
fig = plt.figure(figsize=(8,3))
fig.subplots_adjust(bottom=0.2)
plt.subplot(121)
regression_eval(test[1][2].sum(1), yt_pred[2].sum(1), "Sox2")
plt.subplot(122)
regression_eval(test[1][3].sum(1), yt_pred[3].sum(1), "Oct4")
# Blank out the y-label on the right panel.
plt.ylabel("")
plt.tight_layout()
# Write the figure to disk (relative paths; the target directory must already exist).
plt.savefig('fig/icml18/count-pred-scatter-profile-only.png', dpi=600)
plt.savefig('fig/icml18/count-pred-scatter-profile-only.pdf', dpi=600)
plt.close(fig) # close the figure
# Bare expression: displays the figure when run as a notebook cell.
fig
# Scratch assignments; `a` is not read afterwards in the visible part of the file.
a=1
a = "{}=1"
from basepair.plots import *
# Wrap the test inputs, test targets and the model in the project's plotting helper.
p = Seq2Sox2Oct4(test[0], test[1], model)
# NOTE(review): the next line is an incomplete expression (tab-completion residue) and a syntax error.
p.
# Indices selected by samplers.top_max_count for p.y[1] followed by those for p.y[0].
# NOTE(review): `samplers` is only imported explicitly further down (`from basepair import samplers`);
# unless the star import above provides it, this line fails in file order.
idx_list = list(samplers.top_max_count(p.y[1], 2)) + list(samplers.top_max_count(p.y[0], 2))
idx_list
from matplotlib.ticker import FormatStrFormatter
# Draw one 2x2 figure per entry of `idx_list`: the top row plots p.y[i][idx] (y-label "Observed"),
# the bottom row plots p.y_pred[i][idx] (y-label "Predicted"); column 0 is titled "Sox2" and
# column 1 "Oct4" (titles are only set for the first figure, j == 0).
# When `fpath_template` is given, each figure is saved as fpath_template.format(j) + '.png' / '.pdf'.
# The parameter `n` is accepted but never used in the body.
# NOTE(review): the indentation of this body has been lost in this file, so the nesting of the
# `if` blocks below cannot be determined from the text alone - restore it from the original notebook.
def plot(p, idx_list, n=10, figsize=(20,2), fpath_template=None):
import matplotlib.pyplot as plt
#assert len(sox2_idx) == oct4_idx
#n = len(list(idx_dict.items())[0][1])
# j: position of the figure in the series, idx: index of the example to draw.
for j, idx in enumerate(idx_list):
fig, axes = plt.subplots(2, 2, sharex=True, figsize=figsize,
gridspec_kw = {'hspace':0})
fig.subplots_adjust(top=0.2)
# i: column of the figure.
for i in range(2):
prot = {0:"Sox2", 1: "Oct4"}
# Label of the example, with a line break inserted after every "-".
anno = p.labels.iloc[idx].replace("-", "-\n ")
#print(anno)
# Top row: the two columns of the last axis of p.y, labelled 'pos' and 'neg'.
# The .format(...) calls are no-ops because the label strings contain no placeholders.
axes[0, i].plot(p.y[i][idx,:,0], label='pos'.format(np.argmax(p.y[i][idx,:,0])))
axes[0, i].plot(p.y[i][idx,:,1], label='neg'.format(np.argmax(p.y[i][idx,:,1])))
axes[0, i].yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
if j != 0:
axes[0, i].xaxis.set_ticks_position('none')
if j ==0:
axes[0, i].set_title(f"{prot[i]}")
if i==0:
axes[0, i].set_ylabel("Observed");
#axes[0, i].get_yaxis().set_label_coords(-0.15,0.5)
# Write the example label as free text near the left edge of the panel.
axes[0, i].text(-7, p.y[i][idx,:,0].max()*0.45, anno, fontsize=9);
axes[0, i].yaxis.set_visible(False)
# Legend without a frame; the inner `if j == 0` repeats the outer condition.
if i==1 and j == 0:
if j == 0:
leg = axes[0, i].legend();
leg.get_frame().set_linewidth(0.0)
# Bottom row: the two columns of the last axis of p.y_pred, labelled with their argmax position.
axes[1, i].plot(p.y_pred[i][idx,:,0], label='{}'.format(np.argmax(p.y_pred[i][idx,:,0])))
axes[1, i].plot(p.y_pred[i][idx,:,1], label='{}'.format(np.argmax(p.y_pred[i][idx,:,1])))
#axes[1, i].legend();
if i == 0:
axes[1, i].set_ylabel("Predicted");
#axes[1, i].get_yaxis().set_label_coords(-0.15,0.5)
# Hide the x-axis on every figure except the last one of the series.
if j != len(idx_list) -1 :
axes[1, i].xaxis.set_visible(False)
axes[1, i].yaxis.set_visible(False)
if j == len(idx_list) - 1:
fig.text(0.5, -0.1, 'Position', ha='center')
#axes[1, i].legend();
plt.tight_layout()#pad=0.4, w_pad=0.5, h_pad=1.0)
# TODO - add a label spanning across multiple labels
if fpath_template is not None:
plt.savefig(fpath_template.format(j)+ '.png', dpi=600)
plt.savefig(fpath_template.format(j)+ '.pdf', dpi=600)
plt.close(fig) # close the figure
# NOTE(review): `show_figure` is not defined or imported explicitly in this file - confirm its origin.
show_figure(fig)
plt.show()
# NOTE(review): the next line is an interactive shell alias call; it is not valid Python in a plain script.
lf fig/
# Save the profile figures for four hand-picked example indices.
plot(p, [1788, 1743, 1949, 1950], figsize=(10, 1.5), fpath_template="fig/icml18/profiles2/{}")
# Two stacked panels: a two-column signal on top and a sequence logo below.
# NOTE(review): `pattern`, `signal` and `sequence` are not defined anywhere in the visible part of this file.
fig, (ax0, ax1) = plt.subplots(2, 1, sharex=True, figsize=(20,2))
ax0.set_title(f"{pattern}")
ax0.plot(signal[:,0], label='fwd')
ax0.plot(signal[:,1], label='rev')
ax0.legend()
seqlogo(sequence, ax=ax1)
# Use the plotter's own `plot` method, with two different sort orders, saving to file-name templates.
p.plot(sort='max_sox2',figsize=(20,2), fpath_template="fig/icml18/profiles/max_sox.{}")
p.plot(sort='max_oct4',figsize=(20,2), fpath_template="fig/icml18/profiles/max_oct4.{}")
# Rebind `p` to a different plotting helper that additionally receives test[2].
p = Seq2Nexus(test[0], test[1],test[2], model)
from basepair import samplers
samplers.top_max_count(p.y[1], 1)
# Example index used by the gradient figures below.
i = 1788
# NOTE(review): this span reads like the body of a method pasted at top level: `tracks`, `figsize`,
# `seq_height`, `n`, `self`, `rc_fn`, `pattern`, `rc_vec`, `start_vec`, `width`, `ylim`, `legend` and
# `sequence` are not defined in the visible part of this file, and the indentation of the loop body
# has been lost, so the nesting below cannot be determined from the text alone.
# One panel per track plus a final panel (relative height `seq_height`) for the sequence logo.
fig, ax = plt.subplots(1 + len(tracks), 1, sharex=True, figsize=figsize, gridspec_kw = {'height_ratios':[1]*len(tracks) + [seq_height]})
ax[0].set_title(f"motif {i} ({n})")
# NOTE(review): this loop rebinds `i`, overwriting the example index assigned above.
for i, (k,y) in enumerate(tracks.items()):
# Mean over axis 0 of the signal extracted for `pattern`.
signal = self.extract_signal(y, rc_fn)[pattern].mean(axis=0)
if rc_vec is not None and rc_vec[i]:
signal = rc_fn(signal)
if start_vec is not None:
start = start_vec[i]
signal = signal[start:(start+width)]
# Plot both columns of the signal against 1-based positions.
ax[i].plot(np.arange(1,len(signal)+1), signal[:,0], label='pos')
ax[i].plot(np.arange(1,len(signal)+1), signal[:,1], label='neg')
ax[i].set_ylabel(f"{k}")
ax[i].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
ax[i].spines['top'].set_visible(False)
ax[i].spines['right'].set_visible(False)
ax[i].spines['bottom'].set_visible(False)
ax[i].xaxis.set_ticks_position('none')
ax[i].set_ylim(ylim)
if legend:
ax[i].legend()
# Last panel: sequence logo with the same axis clean-up as the track panels.
seqlogo(sequence, ax=ax[-1])
ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
ax[-1].set_ylabel("Inf. content")
ax[-1].spines['top'].set_visible(False)
ax[-1].spines['right'].set_visible(False)
ax[-1].spines['bottom'].set_visible(False)
ax[-1].set_xticks(list(range(0, len(sequence)+1, 5)));
def plot_track(track, idx, ax):
    """Plot both columns of the last axis of one example of ``track`` on ``ax``.

    Args:
        track: array indexed as ``track[idx, :, 0]`` and ``track[idx, :, 1]``;
            column 0 is labelled "positive strand", column 1 "negative strand".
        idx: index of the example along the first axis.
        ax: axes-like object providing ``plot``.
    """
    ax.plot(track[idx, :, 0], label='positive strand')
    ax.plot(track[idx, :, 1], label='negative strand')
def strip_borders(ax):
    """Declutter ``ax``: hide three spines and the x tick marks.

    The top, right and bottom spines are hidden, the x-axis tick marks are
    removed and the y tick labels are formatted with one decimal place.

    Args:
        ax: matplotlib axes to modify in place.
    """
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_ticks_position('none')
# Nine stacked panels for example 1788, restricted to positions 53..152:
# rows 0-3 predicted / observed tracks, rows 4-7 per-strand "max" importance scores, row 8 the
# averaged "count" importance scores.
fig, ax = plt.subplots(9, 1, sharex=True,
gridspec_kw = {'wspace':0, 'hspace':0},
figsize=(12, 9))
#gridspec_kw = {'height_ratios':[1]*len(tracks) + [seq_height]})
i=1788
# TODO - enable for all
# ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
xlim = [53, 152]
# ----------------
# tracks (4x)
plot_track(p.y_pred[0], i, ax[0])
ax[0].set_ylabel("Sox2\npred")
ax[0].set_xlim(xlim)
plot_track(p.y[0], i, ax[1])
# Caption written into the first panel (in data coordinates of ax[0]).
ax[0].text(xlim[0]+1, 0.07, "Predicted and observed ChipNexus profile")
ax[1].set_ylabel("Sox2\nobs")
ax[1].set_yticks([0, 10])
ax[1].set_xlim(xlim)
ax[1].legend();
plot_track(p.y_pred[1], i, ax[2])
ax[2].set_ylabel("Oct4\npred")
ax[2].set_yticks([0, 0.05])
ax[2].set_xlim(xlim)
plot_track(p.y[1], i, ax[3])
ax[3].set_ylabel("Oct4\nobs")
ax[3].set_yticks([0, 50])
ax[3].set_xlim(xlim)
# ----------------
# Importance scores: p.input_grad(...) multiplied element-wise by the input itself.
# NOTE(review): the exact meaning of the input_grad arguments (strand, task index, "max"/"count")
# is defined in the plotter class, which is not visible here - confirm.
gi = p.input_grad(p.x[[i]], 'pos', 0, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[4])
ax[4].set_ylabel("Sox2 pos\nmax p.")
ax[4].text(xlim[0]+1, 0.9, "Importance scores: d(output)/d(input) * input")
ax[4].set_xlim(xlim)
gi = p.input_grad(p.x[[i]], 'neg', 0, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[5])
ax[5].set_ylabel("Sox2 neg\nmax p.")
ax[5].set_xlim(xlim)
gip = p.input_grad(p.x[[i]], 'pos', 1, "max") * p.x[[i]]
seqlogo(gip[0], ax=ax[6])
ax[6].set_ylabel("Oct4 pos\nmax p.")
ax[6].set_xlim(xlim)
gin = p.input_grad(p.x[[i]], 'neg', 1, "max") * p.x[[i]]
seqlogo(gin[0], ax=ax[7])
ax[7].set_ylabel("Oct4 neg\nmax p.")
ax[7].set_xlim(xlim)
# "count" importance scores for both strands and both tasks, averaged into one logo.
# NOTE(review): the *_pos variables hold the 'neg' gradients and vice versa; the names are swapped,
# which is harmless here because only the average of all four is plotted.
gi_pos = p.input_grad(p.x[[i]], 'neg', 0, "count") * p.x[[i]]
gi_neg = p.input_grad(p.x[[i]], 'pos', 0, "count") * p.x[[i]]
gi_pos2 = p.input_grad(p.x[[i]], 'neg', 1, "count") * p.x[[i]]
gi_neg2 = p.input_grad(p.x[[i]], 'pos', 1, "count") * p.x[[i]]
seqlogo((gi_pos[0]+gi_neg[0] + gi_pos2[0]+gi_neg2[0])/4, ax=ax[8])
ax[8].set_ylabel("Total\ncounts")
ax[8].set_yticks([0, 0.3]);
ax[8].set_xticks(list(range(xlim[0]+2, xlim[1]+1, 5)));
ax[8].set_xlabel("Position");
ax[8].set_xlim(xlim)
# seq importances
# Write the figure to disk (relative paths; the target directory must already exist).
plt.savefig('fig/icml18/grad_track.1788.png', dpi=600)
plt.savefig('fig/icml18/grad_track.1788.pdf', dpi=600)
plt.close(fig) # close the figure
# Bare expression: displays the figure when run as a notebook cell.
fig
# Ten stacked panels for example `i` over the full sequence length:
# rows 0-3 predicted / observed tracks, rows 4-7 per-strand "max" importance scores,
# rows 8-9 strand-averaged "count" importance scores per task.
fig, ax = plt.subplots(10, 1, sharex=True,
                       gridspec_kw = {'wspace':0, 'hspace':0},
                       figsize=(17, 12))
#gridspec_kw = {'height_ratios':[1]*len(tracks) + [seq_height]})
# TODO - enable for all
# ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
# ----------------
# tracks (4x)
plot_track(p.y_pred[0], i, ax[0])
ax[0].legend();
ax[0].set_ylabel("Sox2\nPredicted")
plot_track(p.y[0], i, ax[1])
ax[1].set_ylabel("Sox2\nObserved")
plot_track(p.y_pred[1], i, ax[2])
ax[2].set_ylabel("Oct4\nPredicted")
plot_track(p.y[1], i, ax[3])
ax[3].set_ylabel("Oct4\nObserved")
# ----------------
# Importance scores: p.input_grad(...) multiplied element-wise by the input itself.
gi = p.input_grad(p.x[[i]], 'pos', 0, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[4])
ax[4].set_ylabel("Sox2 pos\ng(max) * i")
gi = p.input_grad(p.x[[i]], 'neg', 0, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[5])
# This panel shows the 'neg' gradient; the label previously repeated "Sox2 pos".
ax[5].set_ylabel("Sox2 neg\ng(max) * i")
gi = p.input_grad(p.x[[i]], 'pos', 1, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[6])
ax[6].set_ylabel("Oct4 pos\ng(max) * i")
gi = p.input_grad(p.x[[i]], 'neg', 1, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[7])
ax[7].set_ylabel("Oct4 neg\ng(max) * i")
# "count" importance scores averaged over both strands, first task ...
gi_pos = p.input_grad(p.x[[i]], 'pos', 0, "count") * p.x[[i]]
gi_neg = p.input_grad(p.x[[i]], 'neg', 0, "count") * p.x[[i]]
seqlogo((gi_pos[0]+gi_neg[0])/2, ax=ax[8])
ax[8].set_ylabel("Sox2 pos+neg\ng(counts) * i")
# ... and second task.
gi_pos = p.input_grad(p.x[[i]], 'pos', 1, "count") * p.x[[i]]
gi_neg = p.input_grad(p.x[[i]], 'neg', 1, "count") * p.x[[i]]
seqlogo((gi_pos[0]+gi_neg[0])/2, ax=ax[9])
ax[9].set_ylabel("Oct4 pos+neg\ng(counts) * i")
ax[9].set_xticks(list(range(0, p.seq_len, 5)));
# seq importances
#plt.savefig('fig/icml18/grad_track.1788.png', dpi=600)
#plt.savefig('fig/icml18/grad_track.1788.pdf', dpi=600)
#plt.close(fig) # close the figure
#fig
# Built-in plots of the helper including the "max" sequence gradients, for two sort orders.
p.plot(sort='max_sox2', seq_grad='max', figsize=(20,12))
p.plot(sort='max_oct4', seq_grad='max', figsize=(20,12))