In [1]:
import basepair
Using TensorFlow backend.
In [2]:
from keras.models import Model, load_model
from basepair.losses import twochannel_multinomial_nll
In [3]:
# Restrict this process to GPUs 3 and 5.
# The value is written as a plain comma-separated list without whitespace,
# which is the documented format for this variable.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in
# this process -- confirm nothing executed earlier already touched the GPU.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,5"
In [4]:
# The saved model references a custom loss by name, so that name has to be
# mapped back to the function object when the file is deserialised.
custom_losses = {"twochannel_multinomial_nll": twochannel_multinomial_nll}
model = load_model("model.h5", custom_objects=custom_losses)
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2018-09-17 14:06:19,521 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2018-09-17 14:06:31,791 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [5]:
from basepair.utils import read_pkl

# Load the pickled data splits and the fitted preprocessor that belong to the model.
# NOTE(review): these are absolute, user-specific paths and the files are pickles --
# only load them from a trusted location.
data_splits = read_pkl("/users/avsec/workspace/basepair-workflow/models/0/data.pkl")
train, valid, test = data_splits
preproc = read_pkl("/users/avsec/workspace/basepair-workflow/models/0/preprocessor.pkl")
In [6]:
from basepair.eval import evaluate

# Evaluate the loaded model on the validation split; the returned metrics are
# displayed as the cell output.
valid_inputs, valid_targets = valid[0], valid[1]
evaluate(model, valid_inputs, valid_targets)
Out[6]:
{'loss': 953.9045610385498,
 'profile/Oct4_loss': 271.63486264911495,
 'profile/Sox2_loss': 162.5192502521077,
 'profile/Klf4_loss': 228.8282409395277,
 'profile/Nanog_loss': 265.30063137148727,
 'counts/Oct4_loss': 0.612707477322428,
 'counts/Sox2_loss': 0.5658578240037332,
 'counts/Klf4_loss': 0.8421652589296793,
 'counts/Nanog_loss': 0.5414273575545541}
In [8]:
# Output heads in the order they are indexed below via model.outputs[task_id]:
# first every profile head, then every counts head.
task_names = ["{}/{}".format(head, tf)
              for head in ("profile", "counts")
              for tf in ("Oct4", "Sox2", "Klf4", "Nanog")]
In [10]:
import keras.backend as K
inp = model.inputs[0]


def _input_gradient_fn(task_name, output, strand):
    """Build a backend function returning d(target)/d(input) for one strand.

    For "counts" heads the target is the strand's output column itself.
    For the other (profile) heads the target is the strand's output weighted
    by its own softmax and summed over the last axis; the softmax weights are
    wrapped in stop_gradient so they act as constants during differentiation.
    """
    if "counts" in task_name:
        target = output[:, strand]
    else:
        strand_output = output[:, :, strand]
        weights = K.stop_gradient(K.softmax(strand_output))
        target = K.sum(weights * strand_output, axis=-1)
    return K.function([inp], K.gradients(target, inp))


# One gradient function per task; index 0 is stored under "pos", index 1 under "neg".
fn_pos, fn_neg = {}, {}
for task_id, task_name in enumerate(task_names):
    task_output = model.outputs[task_id]
    fn_pos[task_name] = _input_gradient_fn(task_name, task_output, 0)
    fn_neg[task_name] = _input_gradient_fn(task_name, task_output, 1)
In [11]:
import numpy as np
from basepair.data import numpy_minibatch

def _batched_input_gradients(grad_fn, inputs, batch_size=512):
    """Run a gradient function over `inputs` in minibatches and stack the results.

    `grad_fn` is called as grad_fn([batch]) and is expected to return a list
    holding a single array (the gradient w.r.t. the input batch).

    The single output is selected with [0] rather than np.array(...).squeeze():
    squeeze() removes *every* length-1 axis, so a final minibatch containing
    one example would lose its batch axis and break the concatenation.
    """
    return np.concatenate([grad_fn([batch])[0]
                           for batch in numpy_minibatch(inputs, batch_size)])


# Input gradients on the validation inputs, per task and per strand.
grads_pos = {}
grads_neg = {}

for task_name in task_names:
    grads_pos[task_name] = _batched_input_gradients(fn_pos[task_name], valid[0])
    grads_neg[task_name] = _batched_input_gradients(fn_neg[task_name], valid[0])
2018-09-17 14:16:05,969 [WARNING] git-lfs not installed
2018-09-17 14:16:06,128 [WARNING] git-lfs not installed
In [12]:
# Derive two score arrays per task from the strand gradients:
#   hyp_scores: pos + neg gradients, mean-centred along the last axis
#   scores:     hyp_scores multiplied element-wise by the validation inputs
hyp_scores, scores = {}, {}
for task_name in task_names:
    strand_sum = grads_pos[task_name] + grads_neg[task_name]
    centered = strand_sum - strand_sum.mean(-1, keepdims=True)
    hyp_scores[task_name] = centered
    scores[task_name] = centered * valid[0]
In [13]:
# Aliases used by the cells below: validation inputs and the two score dicts.
onehot_data, task_to_scores, task_to_hyp_scores = valid[0], scores, hyp_scores
In [14]:
from concise.utils.plot import seqlogo_fig, seqlogo
import matplotlib.pyplot as plt
# Two stacked logos sharing the x axis: scores on top, hyp_scores below,
# both for the first validation example of the profile/Oct4 task.
fig, (ax0, ax1) = plt.subplots(2, 1, sharex=True, figsize=(20, 6))

for axis, title, per_task in ((ax0, "scores", scores),
                              (ax1, "hyp_scores", hyp_scores)):
    axis.set_title(title)
    seqlogo(per_task["profile/Oct4"][0], ax=axis)
In [27]:
tfs = ["Oct4", "Sox2", "Klf4", "Nanog"]
def bpnet_predict(seqs, tasks=tfs, preproc=preproc):
    """Predict with the global `model` and return one result dict per sequence.

    Args:
        seqs: model input; passed unchanged to `model.predict` and indexed per
            example with kipoi's `get_dataset_item`.
        tasks: task names. The first len(tasks) model outputs are treated as
            profile heads and the remaining outputs as counts heads.
        preproc: preprocessor whose `objects['profile/<task>'].steps[1][1]`
            provides `inverse_transform` for the counts predictions.
            NOTE(review): presumably the second step of an sklearn Pipeline -- confirm.

    Returns:
        List of dicts (one per example) with keys:
            seq:           the example's input
            preds:         dict with `profile` (softmax-normalised), `counts`
                           and `scaled_profile` (profile * scale, keyed by task)
            scale_profile: dict task -> scaling factor used for `scaled_profile`
    """
    from basepair.math import softmax
    from kipoi.data_utils import get_dataset_item

    n_tasks = len(tasks)
    preds = model.predict(seqs, batch_size=512)

    # The profile heads are turned into probabilities with a softmax;
    # the counts heads are kept as predicted.
    preds_dict = dict(profile=[softmax(p) for p in preds[:n_tasks]],
                      counts=preds[n_tasks:])

    def pred2scale_strands(preds):
        """Compute the scaling factor for the profile in order to
        obtain the absolute counts
        """
        return {task: np.exp(preproc.objects[f'profile/{task}'].steps[1][1].inverse_transform(preds['counts'][i])) - 1
                for i, task in enumerate(tasks)}

    def predict_one(idx):
        """Assemble the result dict for example `idx`."""
        example_preds = get_dataset_item(preds_dict, idx)
        # Compute the scales once and reuse them for both `scaled_profile`
        # and `scale_profile` (previously they were computed twice per example).
        scales = pred2scale_strands(example_preds)
        example_preds['scaled_profile'] = {task: example_preds['profile'][i] * scales[task][np.newaxis]
                                           for i, task in enumerate(tasks)}
        return dict(
            seq=get_dataset_item(seqs, idx),
            preds=example_preds,
            # scaling factor to go from relative -> absolute counts
            scale_profile=scales,
        )

    return [predict_one(i) for i in range(len(seqs))]
In [28]:
# Run bpnet_predict on all validation inputs; the result is a list with one dict per example.
y_preds = bpnet_predict(onehot_data)
In [66]:
def plot_preds(y_preds, num_samples=5, figsize=(20, 2), binsize=1,
               tasks=("Oct4", "Sox2", "Klf4", "Nanog")):
    """Plot the predicted scaled profiles of the first `num_samples` examples.

    One figure is created per example, with one subplot per task showing both
    strands (column 0 labelled "pos", column 1 labelled "neg"). The legend
    reports `m`, the argmax position of the un-binned profile per strand.

    Args:
        y_preds: indexable by example; each item provides
            item['preds']['scaled_profile'][task].
        num_samples: number of examples (indices 0..num_samples-1) to plot.
        figsize: size of each per-example figure.
        binsize: bin size passed to `bin_counts` before plotting.
        tasks: task names to plot, one subplot each (defaults to the four TFs).
    """
    import matplotlib.pyplot as plt
    from basepair.preproc import bin_counts

    strand_labels = ("pos", "neg")
    n_tasks = len(tasks)

    # for visualization, we use bucketize
    for idx in range(num_samples):
        plt.figure(figsize=figsize)
        scaled_profiles = y_preds[idx]['preds']['scaled_profile']

        for column, task in enumerate(tasks, start=1):
            plt.subplot(1, n_tasks, column)
            # Only the first example's figure carries the task titles.
            if idx == 0:
                plt.title(task)

            profile = scaled_profiles[task]
            # Bin once per task and reuse the result for both strands.
            binned = bin_counts(profile, binsize=binsize)
            for strand, strand_label in enumerate(strand_labels):
                plt.plot(binned[:, strand],
                         label='{},m={}'.format(strand_label, np.argmax(profile[:, strand])))
            plt.legend()
In [67]:
# Plot the predicted scaled profiles for the first 5 examples (plot_preds default).
plot_preds(y_preds)
In [54]:
def onehot_to_seq(onehot):
    """Convert a one-hot encoded sequence back into a nucleotide string.

    Each row of `onehot` is one position; the index of its first non-zero
    entry selects the base (0 -> A, 1 -> C, 2 -> G, 3 -> T).

    A row containing only zeros is decoded as "N", which mirrors the encoder
    in this notebook (it leaves "N" positions as all zeros). Previously such a
    row raised an IndexError.

    Args:
        onehot: sequence of rows (e.g. array of shape (length, 4)).

    Returns:
        str with one character per row.

    Raises:
        ValueError: if the first non-zero entry of a row lies beyond channel 3.
    """
    alphabet = "ACGT"
    chars = []
    for row in onehot:
        char = "N"  # default for rows without any non-zero entry
        for channel, value in enumerate(row):
            if value != 0:
                if channel >= len(alphabet):
                    raise ValueError("Unsupported one-hot channel index: " + str(channel))
                char = alphabet[channel]
                break
        chars.append(char)
    # Join once at the end instead of growing a string character by character.
    return "".join(chars)
In [55]:
def one_hot_encode_along_channel_axis(sequence):
    """One-hot encode `sequence` into an int8 array of shape (len(sequence), 4).

    This is set up for 1d convolutions where examples have dimensions
    (len, num_channels); the channel axis is the axis for one-hot encoding.
    """
    n_positions = len(sequence)
    encoded = np.zeros((n_positions, 4), dtype=np.int8)
    # Filled in place by the helper.
    seq_to_one_hot_fill_in_array(zeros_array=encoded,
                                 sequence=sequence,
                                 one_hot_axis=1)
    return encoded

def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    """Write the one-hot encoding of `sequence` into `zeros_array` in place.

    Args:
        zeros_array: 2-D array to fill; its non-one-hot axis must have
            length len(sequence). It is mutated; nothing is returned.
        sequence: iterable of characters. A/C/G/T (either case) map to
            channels 0/1/2/3; N/n leaves the position untouched.
        one_hot_axis: 0 if channels are the first axis, 1 if they are the second.

    Raises:
        AssertionError: for an invalid `one_hot_axis` or a length mismatch.
        RuntimeError: for any character other than A, C, G, T, N (either case).
    """
    assert one_hot_axis in (0, 1)
    # The sequence runs along whichever axis is not the one-hot axis.
    sequence_axis = 1 if one_hot_axis == 0 else 0
    assert zeros_array.shape[sequence_axis] == len(sequence)

    channel_of = {"A": 0, "a": 0,
                  "C": 1, "c": 1,
                  "G": 2, "g": 2,
                  "T": 3, "t": 3}

    #will mutate zeros_array
    for position, char in enumerate(sequence):
        channel = channel_of.get(char)
        if channel is None:
            if char in ("N", "n"):
                continue #leave that pos as all 0's
            raise RuntimeError("Unsupported character: "+str(char))
        if one_hot_axis == 0:
            zeros_array[channel, position] = 1
        elif one_hot_axis == 1:
            zeros_array[position, channel] = 1
In [56]:
from deeplift.dinuc_shuffle import dinuc_shuffle
In [91]:
num_samples = 5
num_shuffles = 50

# For each of the first `num_samples` validation sequences: predict on
# `num_shuffles` dinuc_shuffle()d versions of the sequence, average the
# predicted scaled profiles per TF, and plot that average.
# Alternative baseline (not used): an all-zeros reference, np.zeros((200, 4)).
for idx in range(num_samples):
    # Decode once per example; only the shuffling differs between iterations.
    original_seq = onehot_to_seq(onehot_data[idx])

    # float64 accumulators, created from the first prediction so that the
    # profile length is taken from the model output instead of being hard-coded.
    profile_sums = {}

    for shuff in range(num_shuffles):
        reference = one_hot_encode_along_channel_axis(dinuc_shuffle(original_seq))

        pred = bpnet_predict([[reference]])
        for tf in tfs:
            profile = pred[0]['preds']['scaled_profile'][tf]
            if tf not in profile_sums:
                profile_sums[tf] = np.zeros(profile.shape, dtype=np.float64)
            profile_sums[tf] += profile

    # Same nested layout as one entry of bpnet_predict's output, stored under
    # key 0 so that plot_preds(..., num_samples=1) can index it.
    ref_pred = {0: {'preds': {'scaled_profile': {tf: profile_sums[tf] / num_shuffles
                                                 for tf in tfs}}}}

    plot_preds(ref_pred, num_samples=1)
In [68]:
def plot_truth(y, num_samples=5, figsize=(20, 2), binsize=1,
               tasks=("Oct4", "Sox2", "Klf4", "Nanog")):
    """Plot the profiles stored in `y` for the first `num_samples` examples.

    One figure is created per example, with one subplot per task showing both
    strands (last-axis index 0 labelled "pos", index 1 labelled "neg"). The
    legend reports `m`, the argmax position of the un-binned profile per strand.

    Args:
        y: mapping with keys 'profile/<task>'; each value is indexed as
            [example, position, strand].
        num_samples: number of examples (indices 0..num_samples-1) to plot.
        figsize: size of each per-example figure.
        binsize: bin size passed to `bin_counts` before plotting.
        tasks: task names to plot, one subplot each (defaults to the four TFs).
    """
    import matplotlib.pyplot as plt
    from basepair.preproc import bin_counts

    strand_labels = ("pos", "neg")
    n_tasks = len(tasks)

    # Bin each task's full array once up front; previously this was redone
    # for every example and every strand.
    binned = {task: bin_counts(y['profile/' + task], binsize=binsize) for task in tasks}

    # for visualization, we use bucketize
    for idx in range(num_samples):
        plt.figure(figsize=figsize)

        for column, task in enumerate(tasks, start=1):
            plt.subplot(1, n_tasks, column)
            # Only the first example's figure carries the task titles.
            if idx == 0:
                plt.title(task)

            observed = y['profile/' + task]
            for strand, strand_label in enumerate(strand_labels):
                plt.plot(binned[task][idx, :, strand],
                         label='{},m={}'.format(strand_label, np.argmax(observed[idx, :, strand])))
            plt.legend()
In [69]:
# Plot the profiles stored in valid[1] (passed to `evaluate` above as the targets) for the first 5 examples.
plot_truth(valid[1])