Goal

  • evaluate the model predictions
In [ ]:
from basepair.config import get_data_dir, create_tf_session
from keras.models import load_model
from basepair.datasets import *

from basepair.preproc import transform_data

from basepair.plots import regression_eval
In [ ]:
a=1
In [ ]:
create_tf_session(1)
In [ ]:
model = load_model("/users/avsec/workspace/basepair/basepair/../data/processed/chipnexus/exp/models/p-multi-task/seq_mutlitask_filters=32,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,use_profile=True,use_counts=True,c_task_weight=10,lr=0.004.5.h5")
In [319]:
ls -latr /users/avsec/workspace/basepair/basepair/../data/processed/chipnexus/exp/models/p-multi-task
total 2264
drwxrwxr-x 6 avsec users   4096 May 19 11:18 ../
-rw-rw-r-- 1 avsec users 383152 May 19 11:19 seq_mutlitask_filters=32,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,use_profile=True,use_counts=True,lr=0.004.1.h5
-rw-rw-r-- 1 avsec users 383312 May 19 11:23 seq_mutlitask_filters=32,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,use_profile=True,use_counts=True,c_task_weight=100,lr=0.004.1.h5
-rw-rw-r-- 1 avsec users 383312 May 19 11:37 seq_mutlitask_filters=32,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,use_profile=True,use_counts=True,c_task_weight=100,lr=0.004.2.h5
-rw-rw-r-- 1 avsec users 383320 May 19 11:38 seq_mutlitask_filters=32,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,use_profile=True,use_counts=True,c_task_weight=100,lr=0.004.3.h5
-rw-rw-r-- 1 avsec users 383320 May 19 11:40 seq_mutlitask_filters=32,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,use_profile=True,use_counts=True,c_task_weight=1,lr=0.004.4.h5
drwxrwxr-x 2 avsec users   4096 May 19 11:41 ./
-rw-rw-r-- 1 avsec users 383376 May 19 11:42 seq_mutlitask_filters=32,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,use_profile=True,use_counts=True,c_task_weight=10,lr=0.004.5.h5
In [4]:
data = sox2_oct4_peaks_sox2()
In [7]:
train, valid, test = transform_data(data, use_profile=True, use_counts=True)
In [84]:
yv_pred = model.predict(valid[0])
In [85]:
yt_pred = model.predict(test[0])
In [21]:
len(valid[1])
Out[21]:
4
In [24]:
max_counts = valid[1][0].max(axis=1)
In [29]:
plt.scatter(np.log10(1+max_counts[:,0]), np.log10(1+max_counts[:,1]))
plt.xlabel("Number of pos counts of the max peak(log10)");
plt.ylabel("Number of neg counts of the max peak(log10)");
In [33]:
plt.hist(np.log10(1+valid[1][0].max(axis=1).mean(axis=1)), bins=50);
plt.xlabel("Number of counts for the max peak")
Out[33]:
Text(0.5,0,'Number of counts for the max peak')
In [34]:
plt.hist(valid[1][0].max(axis=1).mean(axis=1), bins=50);
plt.xlabel("Number of counts for the max peak")
Out[34]:
Text(0.5,0,'Number of counts for the max peak')
In [42]:
plt.hist(valid[1][0].sum(axis=1).mean(axis=1), bins=200);
plt.xlabel("Number of counts per peak");
plt.xlim([0, 300]);
In [244]:
plt.hist(valid[1][1].sum(axis=1).mean(axis=1), bins=200);
plt.xlabel("Number of counts per peak");
plt.xlim([0, 300]);

Number of peaks per sequence

In [149]:
is_peak_pred = softmax(yv_pred[0])>0.02
In [ ]:
 
In [113]:
is_peak_pred.sum()
Out[113]:
9026
In [404]:
plt.hist(np.ravel(valid[1][0]/valid[1][0].sum(axis=1, keepdims=True)), bins=1000);
plt.xlim([0.001, 0.05])
Out[404]:
(0.001, 0.05)
In [405]:
is_peak = valid[1][0]/valid[1][0].sum(axis=1, keepdims=True)>0.05
In [245]:
is_peak_oct4 = valid[1][1]/valid[1][1].sum(axis=1, keepdims=True)>0.05
In [412]:
is_peak.sum()
Out[412]:
9474
In [411]:
is_peak.shape
Out[411]:
(1884, 201, 2)
In [224]:
is_peak.mean()
Out[224]:
0.0125091104984631
In [118]:
from concise.eval_metrics import auprc
In [ ]:
# TODO - compute the auprc in the regions with only sufficiently high counts
In [ ]:
 
In [241]:
do_eval = valid[1][0].sum(axis=1).mean(axis=1) > 54
In [204]:
do_eval.mean()
Out[204]:
0.04883227176220807
In [205]:
do_eval.sum()
Out[205]:
92
In [ ]:
 
In [145]:
a = is_peak_pred[do_eval]
In [226]:
# Number of peaks per sequence on average
pd.Series(is_peak.sum(1).sum(1)).value_counts().sort_index().plot(kind='bar')
Out[226]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fc657a7d7f0>
In [233]:
np.all(is_peak.sum(1) ==1, axis=1).mean()
Out[233]:
0.0881104033970276
In [234]:
np.all(is_peak.sum(1) ==1, axis=1).sum()
Out[234]:
166
In [235]:
# Events with exactly two peaks
In [413]:
do_eval_sox2 = test[1][0].sum(axis=1).mean(axis=1) > 54
In [414]:
do_eval_ocf4 = test[1][1].sum(axis=1).mean(axis=1) > 54
In [415]:
do_eval = do_eval_sox2 & do_eval_ocf4
In [416]:
is_peak = test[1][1]/test[1][1].sum(axis=1, keepdims=True)>0.05
In [287]:
 
Out[287]:
0.42303609341825904
In [303]:
is_peak.sum()
Out[303]:
8628
In [304]:
is_peak.size
Out[304]:
784302
In [302]:
is_peak.mean()
Out[302]:
0.011000864462923721
In [422]:
fractions = test[1][tid]/test[1][tid].sum(axis=1, keepdims=True)
In [422]:
fractions = test[1][tid]/test[1][tid].sum(axis=1, keepdims=True)
In [481]:
do_eval.mean()
Out[481]:
0.42439774474628394
In [482]:
fractions = fractions[do_eval]
In [424]:
np.ravel(fractions).max()
Out[424]:
0.5
In [483]:
amb = (fractions<0.05) & (fractions>0.01)
In [484]:
unique, counts = np.unique(np.digitize(fractions, [0.01, 0.05]), return_counts=True)
In [490]:
# Total counts
counts
Out[490]:
array([271173,  61207,    476])
In [486]:
# Percentages
counts / counts.sum()
Out[486]:
array([0.81468563, 0.18388432, 0.00143005])
In [487]:
plt.bar(["0-1%", "1%-5%", "5%-"], counts)
plt.xlabel("Percentage bucket");
In [488]:
plt.hist(np.ravel(fractions[fractions<0.05]), bins=100);
In [425]:
np.ravel(fractions).min()
Out[425]:
0.0
In [271]:
is_peak.sum()
Out[271]:
6195
In [468]:
def bin_counts_max(x, binsize=2):
    """Max-pool a (samples, positions, tracks) count array along the position axis.

    Args:
        x: 3-d array of shape (n_samples, seq_len, n_tracks).
        binsize: number of adjacent positions collapsed into one bin;
            trailing positions that do not fill a full bin are dropped.

    Returns:
        Float array of shape (n_samples, seq_len // binsize, n_tracks)
        holding the per-bin maximum.
    """
    assert x.ndim == 3
    n_bins = x.shape[1] // binsize
    pooled = [x[:, start:start + binsize, :].max(axis=1)
              for start in range(0, n_bins * binsize, binsize)]
    return np.stack(pooled, axis=1).astype(float)

def bin_counts_amb(x, binsize=2):
    """Bin a ternary label array {-1: ambiguous, 0: negative, 1: peak}.

    Per-bin rule (peak wins over ambiguity):
      * any position == 1  -> 1
      * else any position == -1 -> -1
      * else -> 0

    Args:
        x: 3-d label array of shape (n_samples, seq_len, n_tracks).
        binsize: positions per bin; a trailing partial bin is dropped.

    Returns:
        Float array of shape (n_samples, seq_len // binsize, n_tracks).
    """
    assert x.ndim == 3
    n_bins = x.shape[1] // binsize
    out = np.empty((x.shape[0], n_bins, x.shape[2]), dtype=float)
    for b in range(n_bins):
        window = x[:, b * binsize:b * binsize + binsize, :]
        has_peak = np.any(window == 1, axis=1)
        has_amb = np.any(window == -1, axis=1)
        out[:, b, :] = np.where(has_peak, 1.0, np.where(has_amb, -1.0, 0.0))
    return out

def bin_counts_summary(x, binsize=2, fn=np.max):
    """Bin a (samples, positions, tracks) array with an arbitrary 1-d summary fn.

    Args:
        x: 3-d array of shape (n_samples, seq_len, n_tracks).
        binsize: positions per bin; a trailing partial bin is dropped.
        fn: reduction applied to each bin's positions (e.g. np.max, np.sum).

    Returns:
        Float array of shape (n_samples, seq_len // binsize, n_tracks).
    """
    assert x.ndim == 3
    n_bins = x.shape[1] // binsize
    binned = [np.apply_along_axis(fn, 1, x[:, b * binsize:(b + 1) * binsize, :])
              for b in range(n_bins)]
    return np.stack(binned, axis=1).astype(float)
In [496]:
fracs = test[1][tid]/test[1][tid].sum(axis=1, keepdims=True)
is_peak = fracs>=0.05
ambigous = (fracs<0.05) & (fracs>=0.01)
In [495]:
np.mean(fracs[do_eval]>=0.05)
Out[495]:
0.00564508375994424
In [498]:
np.mean(ambigous[do_eval])
Out[498]:
0.2133024491071214
In [561]:
out= []
def fn(x):
    """Summarize a label window: -1 if any position is ambiguous, else the max.

    Bug fix: the original fell through without `return` in the else branch,
    so unambiguous windows returned None instead of np.max(x).
    """
    if np.any(x == -1):
        return -1
    return np.max(x)

for tid in [0, 1]:
    random = np.random.permutation(softmax(yt_pred[tid])[do_eval])
    for binsize in [1, 2, 4, 10]:
        fracs = test[1][tid]/test[1][tid].sum(axis=1, keepdims=True)
        is_peak = (fracs>=0.05).astype(float)
        ambigous = (fracs<0.05) & (fracs>=0.01)
        is_peak[ambigous] = -1
        #keep = do_eval.reshape((-1,1,1)) & ~ambigous
        #y_true = np.ravel(is_peak[do_eval])
        y_true = np.ravel(bin_counts_amb(is_peak[do_eval],binsize))
        
        imbalance = np.sum(y_true==1)/np.sum(y_true >=0)
        n_positives = np.sum(y_true==1)
        n_ambigous = np.sum(y_true==-1)
        frac_ambigous = n_ambigous / y_true.size
        print(f"tid: {tid}, binsize: {binsize}, imbalance: {imbalance}, n_positives: {n_positives}, n_ambigous: {n_ambigous}, frac_ambigous: {frac_ambigous}")
        res = auprc(y_true, 
                    np.ravel(bin_counts_max(softmax(yt_pred[tid])[do_eval], binsize)))
        print(f"auprc: {res:.2f}")
        res_random = auprc(y_true,
                           np.ravel(bin_counts_max(random, binsize)))
        out.append({"binsize": binsize, 
                    "auprc": res, 
                    "random_auprc": res_random,
                   "tid": tid,
                    "n_positives": n_positives,
                    "imbalance": imbalance
                   })

    df = pd.DataFrame.from_dict(out)
tid: 0, binsize: 1, imbalance: 0.007175672218042672, n_positives: 1879, n_ambigous: 70999, frac_ambigous: 0.2133024491071214
auprc: 0.44
tid: 0, binsize: 2, imbalance: 0.016665580455650003, n_positives: 1790, n_ambigous: 58193, frac_ambigous: 0.35140700483091786
auprc: 0.55
tid: 0, binsize: 4, imbalance: 0.043440442079410564, n_positives: 1698, n_ambigous: 43712, frac_ambigous: 0.5279227053140096
auprc: 0.67
tid: 0, binsize: 10, imbalance: 0.1855305086746135, n_positives: 1572, n_ambigous: 24647, frac_ambigous: 0.7441727053140097
auprc: 0.84
tid: 1, binsize: 1, imbalance: 0.001755084583278026, n_positives: 476, n_ambigous: 61644, frac_ambigous: 0.18519720239382798
auprc: 0.31
tid: 1, binsize: 2, imbalance: 0.00400941340538656, n_positives: 460, n_ambigous: 50870, frac_ambigous: 0.30718599033816424
auprc: 0.38
tid: 1, binsize: 4, imbalance: 0.01036531994287557, n_positives: 450, n_ambigous: 39386, frac_ambigous: 0.4756763285024155
auprc: 0.44
tid: 1, binsize: 10, imbalance: 0.046274425595877175, n_positives: 431, n_ambigous: 23806, frac_ambigous: 0.7187801932367149
auprc: 0.61
In [525]:
[x for x in df.iterrows()][0]
Out[525]:
(0, auprc              0.435988
 binsize            1.000000
 imbalance          0.007176
 n_positives     1879.000000
 random_auprc       0.010004
 tid                0.000000
 Name: 0, dtype: float64)
In [557]:
fig = plt.figure(figsize=(8,3))
fig.subplots_adjust(bottom=0.2)
plt.subplot(121)
plt.semilogx(df.binsize[df.tid==0], df.auprc[df.tid==0], "-o", label="BPNet")
plt.semilogx(df.binsize[df.tid==0], df.random_auprc[df.tid==0], "-o", label="Random")
for i, s in df[df.tid==0].iloc[:-1].iterrows():
    plt.text(s['binsize'], s['auprc']-0.1, f"n_pos: {int(s['n_positives'])} ({s['imbalance']*100:.1f}%)")
s = df[df.tid==0].iloc[-1]
plt.text(s['binsize']-7.1, s['auprc']-0.01, f"n_pos: {int(s['n_positives'])} ({s['imbalance']*100:.1f}%)")
    #plt.text(s['binsize'], s['auprc']-0.2, f"imb: ")
plt.xticks(df.binsize, df.binsize);
plt.xlabel("Binsize")
plt.ylabel("auPRC");
plt.title("Sox2")

plt.legend()
plt.subplot(122)
plt.semilogx(df.binsize[df.tid==1], df.auprc[df.tid==1], "-o", label="BPNet")
plt.semilogx(df.binsize[df.tid==1], df.random_auprc[df.tid==1], "-o", label="Random")
plt.xticks(df.binsize, df.binsize);
for i, s in df[df.tid==1].iloc[:-1].iterrows():
    plt.text(s['binsize'], s['auprc']-0.07, f"n_pos: {int(s['n_positives'])} ({s['imbalance']*100:.1f}%)")
s = df[df.tid==1].iloc[-1]
plt.text(s['binsize']-7.1, s['auprc']-0.01, f"n_pos: {int(s['n_positives'])} ({s['imbalance']*100:.1f}%)")

plt.legend()
plt.title("Oct4")
plt.tight_layout()
#plt.title("Peak prediction precision")
plt.xlabel("Binsize")
#plt.ylabel("auPRC");
#plt.savefig('fig/icml18/auprc-peak.png', dpi=600)
#plt.savefig('fig/icml18/auprc-peak.pdf', dpi=600)
#plt.close(fig)    # close the figure
fig
Out[557]:
In [306]:
fig = plt.figure(figsize=(8,3))
fig.subplots_adjust(bottom=0.2)
plt.subplot(121)
plt.semilogx(df.binsize[df.tid==0], df.auprc[df.tid==0], "-o", label="BPNet")
plt.semilogx(df.binsize[df.tid==0], df.random_auprc[df.tid==0], "-o", label="Random")
plt.xticks(df.binsize, df.binsize);
plt.xlabel("Binsize")
plt.ylabel("auPRC");
plt.title("Sox2")

plt.legend()
plt.subplot(122)
plt.semilogx(df.binsize[df.tid==1], df.auprc[df.tid==1], "-o", label="BPNet")
plt.semilogx(df.binsize[df.tid==1], df.random_auprc[df.tid==1], "-o", label="Random")
plt.xticks(df.binsize, df.binsize);
plt.legend()
plt.title("Oct4")
plt.tight_layout()
#plt.title("Peak prediction precision")
plt.xlabel("Binsize")
#plt.ylabel("auPRC");
#plt.savefig('fig/icml18/auprc-peak-sox2.png', dpi=600)
#plt.savefig('fig/icml18/auprc-peak-sox2.pdf', dpi=600)
#plt.close(fig)    # close the figure
fig
Out[306]:
In [197]:
do_eval.sum()
Out[197]:
341
In [160]:
is_peak[do_eval].mean()
Out[160]:
0.005137068127877157
In [161]:
is_peak[do_eval].sum()
Out[161]:
1935
In [162]:
is_peak[do_eval].size
Out[162]:
376674
In [131]:
auprc(np.ravel(bin_counts_max(is_peak[do_eval], 1)), np.ravel(bin_counts_max(is_peak_pred[do_eval], 1)))
Out[131]:
0.32336555228047453
In [ ]:
# TODO - what if we summarize the stuff
In [116]:
np.mean(is_peak == is_peak_pred)
Out[116]:
0.9796558608232722
In [99]:
np.mean(is_peak == is_peak_pred)
Out[99]:
0.9874908895015369
In [100]:
is_peak.sum()
Out[100]:
9474
In [101]:
is_peak_pred.sum()
Out[101]:
0
In [93]:
np.mean(is_peak == is_peak_pred)
Out[93]:
0.9874908895015369
In [76]:
peaks = is_peak.sum(axis=1)
In [87]:
plt.scatter(peaks[:,0], peaks[:,1], alpha=0.2)
plt.xlabel("Number of pos peaks per sequence")
plt.ylabel("Number of neg peaks per sequence")
Out[87]:
Text(0,0.5,'Number of neg peaks per sequence')
In [88]:
pd.Series(peaks[:,1]).value_counts().sort_index().plot(kind='bar')
Out[88]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fc65216b550>
In [83]:
pd.Series(peaks[:,1]).value_counts().sort_index().plot(kind='bar')
Out[83]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fc65247dcc0>
In [ ]:
# Focus only on the peaks with max 2 peaks per sequence
In [57]:
plt.hist(peaks[:,0], bins=20);
plt.xlabel("Number of peaks per sequence");
In [43]:
np.median(valid[1][0].sum(axis=1).mean(axis=1))
Out[43]:
54.0
In [ ]:
 
In [28]:
plt.scatter(np.log10(1+max_counts[:,0]), np.log10(1+max_counts[:,1]))
plt.xlabel("Number of pos counts (log10)");
plt.ylabel("Number of neg counts (log10)");

Valid

TODO - the performance of Sox2 could be better

In [329]:
yt_pred = model.predict(test[0])
yv_pred = model.predict(valid[0])
In [330]:
yt_pred_reg = model_reg.predict(test[0])
yv_pred_reg = model_reg.predict(valid[0])
In [296]:
# Sox2
regression_eval(valid[1][2].sum(1), yv_pred[2].sum(1))
In [297]:
# Oct4
regression_eval(valid[1][3].sum(1), yv_pred[3].sum(1))

Test

In [303]:
is_peak.sum()
Out[303]:
8628
In [758]:
def regression_eval(y_true, y_pred, task):
    """Scatter-plot predicted vs. observed values and annotate with Spearman R.

    NOTE(review): this local definition shadows `regression_eval` imported
    from basepair.plots at the top of the notebook.

    Fix: the original also computed `pearsonr` and discarded the result;
    the unused computation is removed.

    Args:
        y_true: 1-d array of observed values (log(count+1) scale per axis label).
        y_pred: 1-d array of predicted values, same length as y_true.
        task: title prefix identifying the task (e.g. "Sox2").
    """
    from scipy.stats import spearmanr
    spearman, _ = spearmanr(y_true, y_pred)
    plt.scatter(y_pred, y_true, alpha=0.3)
    plt.xlabel("Predicted  log(count+1)")
    plt.ylabel("Observed  log(count+1)")
    plt.title(f"{task}; R_spearman={spearman:.2f}")
In [350]:
# Sox2
fig = plt.figure(figsize=(8,6))
fig.subplots_adjust(bottom=0.2)
plt.subplot(221)
regression_eval(test[1][2].sum(1), yt_pred[2].sum(1), "Sox2, profile+count")
plt.xlabel("")
plt.subplot(222)
regression_eval(test[1][3].sum(1), yt_pred[3].sum(1), "Oct4, profile+count")
plt.ylabel("")
plt.xlabel("")
plt.subplot(223)
regression_eval(test[1][2].sum(1), yt_pred_reg[0].sum(1), "Sox2, count")
plt.subplot(224)
regression_eval(test[1][3].sum(1), yt_pred_reg[1].sum(1), "Oct4, count")
plt.ylabel("")
plt.tight_layout()
plt.savefig('fig/icml18/count-pred-scatter.png', dpi=600)
plt.savefig('fig/icml18/count-pred-scatter.pdf', dpi=600)
plt.close(fig)    # close the figure
fig
Out[350]:
In [759]:
# Sox2
fig = plt.figure(figsize=(8,3))
fig.subplots_adjust(bottom=0.2)
plt.subplot(121)
regression_eval(test[1][2].sum(1), yt_pred[2].sum(1), "Sox2")
plt.subplot(122)
regression_eval(test[1][3].sum(1), yt_pred[3].sum(1), "Oct4")
plt.ylabel("")
plt.tight_layout()
plt.savefig('fig/icml18/count-pred-scatter-profile-only.png', dpi=600)
plt.savefig('fig/icml18/count-pred-scatter-profile-only.pdf', dpi=600)
plt.close(fig)    # close the figure
fig
Out[759]:
In [375]:
a=1
In [353]:
a = "{}=1"
In [562]:
from basepair.plots import *
In [368]:
p = Seq2Sox2Oct4(test[0], test[1], model)
In [563]:
p.
Out[563]:
<basepair.plots.Seq2Nexus at 0x7fc6e027f0b8>
In [568]:
idx_list = list(samplers.top_max_count(p.y[1], 2)) + list(samplers.top_max_count(p.y[0], 2))
In [569]:
idx_list
Out[569]:
[1788, 1743, 1949, 1950]
In [583]:
from matplotlib.ticker import FormatStrFormatter
In [609]:
 
Out[609]:
'chr8:70555760-70555961'
In [737]:
def plot(p, idx_list, n=10, figsize=(20,2), fpath_template=None):
    """Plot observed (top row) vs. predicted (bottom row) ChIP-nexus profiles
    for each example index in `idx_list`, one 2x2 figure per index
    (columns: Sox2, Oct4).

    Args:
        p: prediction wrapper exposing .y, .y_pred (lists of (n, L, 2) arrays)
           and .labels (pandas Series of region annotations).
        idx_list: example indices to plot, one figure each.
        n: unused — NOTE(review): kept only for interface compatibility?
        figsize: per-figure size passed to plt.subplots.
        fpath_template: optional path template with a '{}' slot for the
            figure counter; when given, figures are saved as .png/.pdf.
    """
    import matplotlib.pyplot as plt
    #assert len(sox2_idx) == oct4_idx
    
    
    
    #n = len(list(idx_dict.items())[0][1])
    for j, idx in enumerate(idx_list):
        # One 2x2 figure per example: rows = observed/predicted, cols = Sox2/Oct4.
        fig, axes = plt.subplots(2, 2, sharex=True, figsize=figsize, 
                             gridspec_kw = {'hspace':0})
        fig.subplots_adjust(top=0.2)
        for i in range(2):
            # i indexes the task: 0 = Sox2, 1 = Oct4.
            prot = {0:"Sox2", 1: "Oct4"}
            # Region annotation, re-wrapped after '-' for compact display.
            anno = p.labels.iloc[idx].replace("-", "-\n    ")
            #print(anno)
            # NOTE(review): 'pos'.format(...) has no placeholder, so the
            # argmax values are computed but never shown — likely meant
            # 'pos {}'.format(...). Left as-is.
            axes[0, i].plot(p.y[i][idx,:,0], label='pos'.format(np.argmax(p.y[i][idx,:,0])))
            axes[0, i].plot(p.y[i][idx,:,1], label='neg'.format(np.argmax(p.y[i][idx,:,1])))

            axes[0, i].yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
            
            if j != 0:
                axes[0, i].xaxis.set_ticks_position('none') 

            if j ==0:
                # Column titles only on the first figure.
                axes[0, i].set_title(f"{prot[i]}")
                
            if i==0:
                axes[0, i].set_ylabel("Observed");
                #axes[0, i].get_yaxis().set_label_coords(-0.15,0.5)
                # Annotate the region label left of the axis.
                axes[0, i].text(-7, p.y[i][idx,:,0].max()*0.45, anno, fontsize=9);
            axes[0, i].yaxis.set_visible(False)
            if i==1 and j == 0:
                # Legend only once: top-right panel of the first figure.
                if j == 0:
                    leg = axes[0, i].legend();
                    leg.get_frame().set_linewidth(0.0)

                
                
            # Bottom row: predicted profiles; labels show the argmax position.
            axes[1, i].plot(p.y_pred[i][idx,:,0], label='{}'.format(np.argmax(p.y_pred[i][idx,:,0])))
            axes[1, i].plot(p.y_pred[i][idx,:,1], label='{}'.format(np.argmax(p.y_pred[i][idx,:,1])))
            #axes[1, i].legend();
            if i == 0:
                axes[1, i].set_ylabel("Predicted");
                #axes[1, i].get_yaxis().set_label_coords(-0.15,0.5)
            if j != len(idx_list) -1 :
                axes[1, i].xaxis.set_visible(False)
            axes[1, i].yaxis.set_visible(False)
            if j == len(idx_list) - 1:
                # Shared x-axis caption under the last figure only.
                fig.text(0.5, -0.1, 'Position', ha='center')
            #axes[1, i].legend();
            plt.tight_layout()#pad=0.4, w_pad=0.5, h_pad=1.0)

            # TODO - add a label spanning accross multiple labels

        if fpath_template is not None:
            # Save both raster and vector versions, then re-display the
            # closed figure (show_figure is an external helper).
            plt.savefig(fpath_template.format(j)+ '.png', dpi=600)
            plt.savefig(fpath_template.format(j)+ '.pdf', dpi=600)
            plt.close(fig)    # close the figure
            show_figure(fig)
            plt.show()
In [734]:
lf fig/
In [738]:
plot(p, [1788, 1743, 1949, 1950], figsize=(10, 1.5), fpath_template="fig/icml18/profiles2/{}")
In [589]:
fig, (ax0, ax1) = plt.subplots(2, 1, sharex=True, figsize=(20,2))
        ax0.set_title(f"{pattern}")
        ax0.plot(signal[:,0], label='fwd')
        ax0.plot(signal[:,1], label='rev')
        ax0.legend()
        seqlogo(sequence, ax=ax1)
  File "<ipython-input-589-1f761894feec>", line 2
    ax0.set_title(f"{pattern}")
    ^
IndentationError: unexpected indent
In [369]:
p.plot(sort='max_sox2',figsize=(20,2), fpath_template="fig/icml18/profiles/max_sox.{}")
In [370]:
p.plot(sort='max_oct4',figsize=(20,2), fpath_template="fig/icml18/profiles/max_oct4.{}")
In [379]:
p = Seq2Nexus(test[0], test[1],test[2], model)
In [382]:
from basepair import samplers
In [384]:
samplers.top_max_count(p.y[1], 1)
Out[384]:
Int64Index([1788], dtype='int64')
In [385]:
i = 1788
In [ ]:
fig, ax = plt.subplots(1 + len(tracks), 1, sharex=True, figsize=figsize, gridspec_kw = {'height_ratios':[1]*len(tracks) + [seq_height]})
ax[0].set_title(f"motif {i} ({n})")
for i, (k,y) in enumerate(tracks.items()):
    signal = self.extract_signal(y, rc_fn)[pattern].mean(axis=0)
    if rc_vec is not None and rc_vec[i]:
        signal = rc_fn(signal)
    if start_vec is not None:
        start = start_vec[i]
        signal = signal[start:(start+width)]

    ax[i].plot(np.arange(1,len(signal)+1), signal[:,0], label='pos')
    ax[i].plot(np.arange(1,len(signal)+1), signal[:,1], label='neg')
    ax[i].set_ylabel(f"{k}")
    ax[i].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax[i].spines['top'].set_visible(False)
    ax[i].spines['right'].set_visible(False)
    ax[i].spines['bottom'].set_visible(False)
    ax[i].xaxis.set_ticks_position('none') 
    ax[i].set_ylim(ylim)
    if legend:
        ax[i].legend()

seqlogo(sequence, ax=ax[-1])
ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
ax[-1].set_ylabel("Inf. content")
ax[-1].spines['top'].set_visible(False)
ax[-1].spines['right'].set_visible(False)
ax[-1].spines['bottom'].set_visible(False)
ax[-1].set_xticks(list(range(0, len(sequence)+1, 5)));
In [395]:
def plot_track(track, idx, ax):
    """Draw both strands of one example's coverage track on the given axes.

    Args:
        track: array of shape (n_samples, seq_len, 2); last axis = strands.
        idx: example index to plot.
        ax: matplotlib axes to draw on.
    """
    for strand, strand_label in ((0, 'positive strand'), (1, 'negative strand')):
        ax.plot(track[idx, :, strand], label=strand_label)
In [ ]:
def strip_borders(ax):
    """Declutter an axes: %.1f y-tick format, hide top/right/bottom spines
    and x tick marks."""
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_ticks_position('none')
In [756]:
# 9-panel figure for a single example: four profile tracks
# (pred/obs for Sox2 and Oct4) plus five sequence-importance logos.
fig, ax = plt.subplots(9, 1, sharex=True, 
                       gridspec_kw = {'wspace':0, 'hspace':0},
                       figsize=(12, 9))
                       #gridspec_kw = {'height_ratios':[1]*len(tracks) + [seq_height]})

i=1788  # example index (top max-count example; see samplers.top_max_count above)
# TODO - enable for all
# ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
xlim = [53, 152]  # zoom into the central region of the window
# ----------------
# tracks (4x)
plot_track(p.y_pred[0], i, ax[0])
ax[0].set_ylabel("Sox2\npred")
ax[0].set_xlim(xlim)
plot_track(p.y[0], i, ax[1])
ax[0].text(xlim[0]+1, 0.07, "Predicted and observed ChipNexus profile")
ax[1].set_ylabel("Sox2\nobs")
ax[1].set_yticks([0, 10])
ax[1].set_xlim(xlim)
ax[1].legend();
plot_track(p.y_pred[1], i, ax[2])
ax[2].set_ylabel("Oct4\npred")
ax[2].set_yticks([0, 0.05])
ax[2].set_xlim(xlim)
plot_track(p.y[1], i, ax[3])
ax[3].set_ylabel("Oct4\nobs")
ax[3].set_yticks([0, 50])
ax[3].set_xlim(xlim)
# ----------------
# Importance logos: gradient of the max-profile output w.r.t. input, times input.
gi = p.input_grad(p.x[[i]], 'pos', 0, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[4])
ax[4].set_ylabel("Sox2 pos\nmax p.")
ax[4].text(xlim[0]+1, 0.9, "Importance scores: d(output)/d(input) * input")
ax[4].set_xlim(xlim)

gi = p.input_grad(p.x[[i]], 'neg', 0, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[5])
ax[5].set_ylabel("Sox2 neg\nmax p.")
ax[5].set_xlim(xlim)

gip = p.input_grad(p.x[[i]], 'pos', 1, "max") * p.x[[i]]
seqlogo(gip[0], ax=ax[6])
ax[6].set_ylabel("Oct4 pos\nmax p.")
ax[6].set_xlim(xlim)

gin = p.input_grad(p.x[[i]], 'neg', 1, "max") * p.x[[i]]
seqlogo(gin[0], ax=ax[7])
ax[7].set_ylabel("Oct4 neg\nmax p.")
ax[7].set_xlim(xlim)

# Count-head gradients, averaged over both strands of both tasks.
# Fix: the original bound the 'neg'-strand result to gi_pos and vice versa;
# names now match the strand argument (the sum order is unchanged).
gi_cnt_neg_sox2 = p.input_grad(p.x[[i]], 'neg', 0, "count") * p.x[[i]]
gi_cnt_pos_sox2 = p.input_grad(p.x[[i]], 'pos', 0, "count") * p.x[[i]]
gi_cnt_neg_oct4 = p.input_grad(p.x[[i]], 'neg', 1, "count") * p.x[[i]]
gi_cnt_pos_oct4 = p.input_grad(p.x[[i]], 'pos', 1, "count") * p.x[[i]]

seqlogo((gi_cnt_neg_sox2[0]+gi_cnt_pos_sox2[0] + gi_cnt_neg_oct4[0]+gi_cnt_pos_oct4[0])/4, ax=ax[8])
ax[8].set_ylabel("Total\ncounts")
ax[8].set_yticks([0, 0.3]); 
ax[8].set_xticks(list(range(xlim[0]+2, xlim[1]+1, 5))); 
ax[8].set_xlabel("Position"); 
ax[8].set_xlim(xlim)

# seq importances
plt.savefig('fig/icml18/grad_track.1788.png', dpi=600)
plt.savefig('fig/icml18/grad_track.1788.pdf', dpi=600)
plt.close(fig)    # close the figure
fig
Out[756]:
In [581]:
# 10-panel diagnostic figure for example `i` (set in an earlier cell):
# four profile tracks + six gradient-times-input sequence logos.
fig, ax = plt.subplots(10, 1, sharex=True, 
                       gridspec_kw = {'wspace':0, 'hspace':0},
                       figsize=(17, 12))
                       #gridspec_kw = {'height_ratios':[1]*len(tracks) + [seq_height]})

# TODO - enable for all
# ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
# ----------------
# tracks (4x)
plot_track(p.y_pred[0], i, ax[0])
ax[0].legend();
ax[0].set_ylabel("Sox2\nPredicted")
plot_track(p.y[0], i, ax[1])
ax[1].set_ylabel("Sox2\nObserved")
plot_track(p.y_pred[1], i, ax[2])
ax[2].set_ylabel("Oct4\nPredicted")
plot_track(p.y[1], i, ax[3])
ax[3].set_ylabel("Oct4\nObserved")
# ----------------
# Max-profile gradients times input, per task and strand.
gi = p.input_grad(p.x[[i]], 'pos', 0, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[4])
ax[4].set_ylabel("Sox2 pos\ng(max) * i")

gi = p.input_grad(p.x[[i]], 'neg', 0, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[5])
# Fix: this panel shows the negative-strand gradient; the original label
# was a copy-paste duplicate saying "Sox2 pos".
ax[5].set_ylabel("Sox2 neg\ng(max) * i")

gi = p.input_grad(p.x[[i]], 'pos', 1, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[6])
ax[6].set_ylabel("Oct4 pos\ng(max) * i")

gi = p.input_grad(p.x[[i]], 'neg', 1, "max") * p.x[[i]]
seqlogo(gi[0], ax=ax[7])
ax[7].set_ylabel("Oct4 neg\ng(max) * i")

# Count-head gradients averaged over strands.
# Fix: the original bound the 'neg'-strand result to gi_pos and vice versa;
# names now match the strand argument (the sum order is unchanged).
gi_cnt_neg = p.input_grad(p.x[[i]], 'neg', 0, "count") * p.x[[i]]
gi_cnt_pos = p.input_grad(p.x[[i]], 'pos', 0, "count") * p.x[[i]]

seqlogo((gi_cnt_neg[0]+gi_cnt_pos[0])/2, ax=ax[8])
ax[8].set_ylabel("Sox2 pos+neg\ng(counts) * i")

gi_cnt_neg = p.input_grad(p.x[[i]], 'neg', 1, "count") * p.x[[i]]
gi_cnt_pos = p.input_grad(p.x[[i]], 'pos', 1, "count") * p.x[[i]]

seqlogo((gi_cnt_neg[0]+gi_cnt_pos[0])/2, ax=ax[9])
ax[9].set_ylabel("Oct4 pos+neg\ng(counts) * i")
ax[9].set_xticks(list(range(0, p.seq_len, 5))); 

# seq importances
#plt.savefig('fig/icml18/grad_track.1788.png', dpi=600)
#plt.savefig('fig/icml18/grad_track.1788.pdf', dpi=600)
#plt.close(fig)    # close the figure
#fig
In [ ]:
 
In [19]:
p.plot(sort='max_sox2', seq_grad='max', figsize=(20,12))
In [ ]:
 
In [380]:
p.plot(sort='max_oct4', seq_grad='max', figsize=(20,12))