from basepair.imports import *
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
from basepair.exp.paper.config import tf_colors
from basepair.functions import mean
from basepair.cli.imp_score import ImpScoreFile
create_tf_session(0)
from basepair.seqmodel import SeqModel
from basepair.exp.chipnexus.simulate import random_seq
from concise.preprocessing import encodeDNA
# Probe two trained SeqModels with random and all-zero sequences and plot the
# across-sequence mean of the predicted Oct4/Sox2 profiles.

# Number and length of the random probe sequences.
N_SEQS = 512
SEQ_LEN = 1000
# Tasks whose mean predicted profiles are plotted (output keys '<task>/profile').
TASKS = ('Oct4', 'Sox2')

seqs = encodeDNA([random_seq(SEQ_LEN) for _ in range(N_SEQS)])


def _neutral_inputs(model, seq):
    """Return ``model``'s neutral bias inputs with ``seq`` as the 'seq' input.

    NOTE(review): the meaning of the two ``neutral_bias_inputs`` arguments is
    not visible here; they are kept exactly as in the original analysis.
    """
    inputs = model.neutral_bias_inputs(1000, 1000)
    inputs['seq'] = seq
    return inputs


def _plot_mean_profiles(preds, title):
    """Plot the across-sequence mean of each task's profile channels 0 and 1.

    A new figure is opened per call. Otherwise, when this file runs as a
    plain script rather than as separate notebook cells, every comparison
    piles up on a single axes.
    """
    plt.figure()
    for task in TASKS:
        mean_profile = preds['{}/profile'.format(task)].mean(axis=0)
        # assumes the last axis holds two channels (presumably the two
        # strands) -- TODO confirm against the model's output spec
        for channel in (0, 1):
            plt.plot(mean_profile[:, channel],
                     label='{} [:, {}]'.format(task, channel))
    plt.title(title)
    plt.legend()


# First model. NOTE(review): presumably the data-augmented counterpart of the
# second model below (whose mdir lacks the trailing TRUE flag) -- confirm.
mdir = 'output/nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE,TRUE'
m = SeqModel.from_mdir(mdir)

# Random sequences.
x = _neutral_inputs(m, seqs)
preds = m.predict(x)
_plot_mean_profiles(preds, 'model 1: random seqs, predict')

# All-zero sequence input: both post- and pre-activation outputs. Both calls
# receive the full input dict (sequence + neutral bias inputs), matching how
# predict_preact is called for the second model, rather than the bare array.
x = _neutral_inputs(m, np.zeros_like(seqs))
preds = m.predict(x)
_plot_mean_profiles(preds, 'model 1: zero seqs, predict')
preds = m.predict_preact(x)
_plot_mean_profiles(preds, 'model 1: zero seqs, predict_preact')

# Not using data augmentation
mdir = 'output/nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'
m = SeqModel.from_mdir(mdir)

x = _neutral_inputs(m, seqs)
preds = m.predict_preact(x)
_plot_mean_profiles(preds, 'model 2: random seqs, predict_preact')

x = _neutral_inputs(m, np.zeros_like(seqs))
preds = m.predict_preact(x)
_plot_mean_profiles(preds, 'model 2: zero seqs, predict_preact')