Goal

  • fine-tune the ChIP-nexus model

Open questions

  • training a genome-wide model
    • how does the count distribution look like genome-wide, how in the peaks and how in the accessible regions?
      • is MSE still a good loss function
In [1]:
from basepair.imports import *
from basepair.datasets import *
from kipoi.data_utils import numpy_collate_concat
from basepair.config import valid_chr, test_chr
from basepair.utils import read_json
Using TensorFlow backend.
In [2]:
from concise.losses import binary_crossentropy_masked
from concise.metrics import accuracy
from concise.eval_metrics import auprc
from basepair.samplers import StratifiedRandomBatchSampler

from keras.models import Sequential
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, History, CSVLogger, ModelCheckpoint
In [3]:
# test StratifiedRandomBatchSampler
sampler = StratifiedRandomBatchSampler(np.array([0,0,0,0, 0, 1,1]), p_vec=[0.5, 0.5], batch_size=2)

list(sampler)
Out[3]:
[[4, 5], [1, 6], [3, 5]]
In [4]:
create_tf_session(0)
Out[4]:
<tensorflow.python.client.session.Session at 0x7fb7b1c65080>
In [5]:
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
In [6]:
# Load the model
model = load_model(model_dir / "model.h5")
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2018-12-11 03:46:46,539 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2018-12-11 03:46:55,279 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.

Load the data

In [7]:
batch_size = 256
num_workers = 8
In [8]:
intervalspec = read_json("/srv/scratch/avsec/workspace/chipnexus/data/processed/chipseq/labels/chipnexus/accessible/intervalspec.json")

tasks = intervalspec['task_names']

intervals_file = intervalspec['chipnexus']['intervals_file']

fasta_file = "/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta"
In [9]:
intervals_file
Out[9]:
'/srv/scratch/avsec/workspace/chipnexus/data/processed/chipseq/labels/chipnexus/accessible/oct4-sox2.intervals_file.DNase-accessible.tsv.gz'
In [10]:
train = SeqClassification(intervals_file, fasta_file, excl_chromosomes=valid_chr+test_chr)
In [11]:
valid = SeqClassification(intervals_file, fasta_file, incl_chromosomes=valid_chr)
In [12]:
len(train)
Out[12]:
3813248
In [13]:
len(valid)
Out[13]:
1143141

Frozen model

Add a model on top

In [14]:
transfer_to = "add_9"

inp = kl.Input((model.inputs[0].shape[1].value, 4), name='seq')
# Transferred part
tmodel = Model(model.inputs,
               model.get_layer(transfer_to).output)

# Freeze all the layers up to (including) the freeze_to layer
for l in tmodel.layers:
    l.trainable = False

# define the top model
top_model = Sequential([
    kl.MaxPool1D(50, input_shape=tmodel.output_shape[1:]),
    kl.Flatten(),
    kl.Dense(64),
    kl.BatchNormalization(),
    kl.Activation('relu'),
    kl.Dropout(0.2),
    #kl.Dense(128),
    #kl.BatchNormalization(),
    #kl.Activation('relu'),
    #kl.Dropout(0.5),
    kl.Dense(2, activation='sigmoid')
])

final_model = Model([inp], top_model(tmodel(inp)))
final_model.compile(Adam(lr=0.003), binary_crossentropy_masked, metrics=[accuracy])
In [21]:
final_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
seq (InputLayer)             (None, 1000, 4)           0         
_________________________________________________________________
model_1 (Model)              (None, 1000, 64)          117632    
_________________________________________________________________
sequential_1 (Sequential)    (None, 2)                 82370     
=================================================================
Total params: 200,002
Trainable params: 82,242
Non-trainable params: 117,760
_________________________________________________________________

Train the model

In [20]:
mdir = f"{ddir}/processed/chipseq/exp/models/finetune-bpnet/fine-tune-conv"

ckp_file = f"{mdir}/model.h5"
history_path = f"{mdir}/history.csv"
In [21]:
!mkdir -p {mdir}
In [22]:
from basepair.samplers import StratifiedRandomBatchSampler

train_it = train.batch_train_iter(shuffle=False,
                                  batch_size=1,
                                  drop_last=None,
                                  batch_sampler=StratifiedRandomBatchSampler(train.get_targets().max(axis=1),
                                                                             batch_size=batch_size,
                                                                             p_vec=[0.95, 0.05],
                                                                             verbose=True),
                                  num_workers=num_workers)
next(train_it)
valid_it = valid.batch_train_iter(batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
next(valid_it);
Using batch sizes:
[243  13]
In [ ]:
# -------------------------------------------
final_model.fit_generator(train_it,
                    epochs=100,
                    steps_per_epoch=int(len(train) / batch_size * 0.1),
                    validation_data=valid_it,
                    validation_steps=int(len(valid) / batch_size * 0.2),
                    # sample_weight=sample_weight,
                    callbacks=[EarlyStopping(patience=4),
                               CSVLogger(history_path),
                               ModelCheckpoint(ckp_file, save_best_only=True)]
                    )
final_model = load_model(ckp_file)
Epoch 1/100
1489/1489 [==============================] - 135s 91ms/step - loss: 0.0946 - accuracy: 0.9698 - val_loss: 0.0472 - val_accuracy: 0.9909
Epoch 2/100
1488/1489 [============================>.] - ETA: 0s - loss: 0.0800 - accuracy: 0.9747
In [26]:
a=1

Evaluate

In [28]:
valid_it = valid.batch_predict_iter(batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers*2)
preds = final_model.predict_generator(valid_it, verbose=1, steps=len(valid) // batch_size)
n_points = len(valid) // batch_size
4465/4465 [==============================] - 231s 52ms/step
In [29]:
print("Oct4:")
auprc(valid.get_targets()[:(n_points*batch_size),0], preds[:,0])
Oct4:
Out[29]:
0.301518766323592
In [30]:
print("Sox2:")
auprc(valid.get_targets()[:(n_points*batch_size),1], preds[:,1])
Sox2:
Out[30]:
0.14600623931197712

Further fine-tuned model accuracy

In [31]:
for l in final_model.layers[1].layers:
    l.trainable = True
final_model.compile(Adam(lr=0.003), binary_crossentropy_masked, metrics=[accuracy])
In [32]:
# -------------------------------------------
final_model.fit_generator(train_it,
                    epochs=100,
                    steps_per_epoch=int(len(train) / batch_size * 0.1),
                    validation_data=valid_it,
                    validation_steps=int(len(valid) / batch_size * 0.2),
                    # sample_weight=sample_weight,
                    callbacks=[EarlyStopping(patience=4),
                               CSVLogger(history_path),
                               ModelCheckpoint(ckp_file, save_best_only=True)]
                    )
Epoch 1/100
1488/1489 [============================>.] - ETA: 0s - loss: 0.0787 - accuracy: 0.9748
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-32-26e48c871750> in <module>
      8                     callbacks=[EarlyStopping(patience=4),
      9                                CSVLogger(history_path),
---> 10                                ModelCheckpoint(ckp_file, save_best_only=True)]
     11                     )

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/comet_ml/monkey_patching.py in wrapper(*args, **kwargs)
    241                     )
    242 
--> 243         return_value = original(*args, **kwargs)
    244 
    245         # Call after callbacks once we have the return value

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2242                                 workers=workers,
   2243                                 use_multiprocessing=use_multiprocessing,
-> 2244                                 max_queue_size=max_queue_size)
   2245                         else:
   2246                             # No need for try/except because

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/engine/training.py in evaluate_generator(self, generator, steps, max_queue_size, workers, use_multiprocessing)
   2360 
   2361             while steps_done < steps:
-> 2362                 generator_output = next(output_generator)
   2363                 if not hasattr(generator_output, '__len__'):
   2364                     raise ValueError('Output of generator should be a tuple '

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
    783                 all_finished = all([not thread.is_alive() for thread in self._threads])
    784                 if all_finished and self.queue.empty():
--> 785                     raise StopIteration()
    786                 else:
    787                     time.sleep(self.wait_time)

StopIteration: 
In [ ]:
valid_it = valid.batch_predict_iter(batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers*2)
In [ ]:
preds = final_model.predict_generator(valid_it, verbose=1, steps=len(valid) // batch_size)
In [ ]:
n_points = len(valid) // batch_size
In [ ]:
print("Oct4:")
auprc(valid.get_targets()[:(n_points*batch_size),0], preds[:,0])
In [ ]:
print("Sox2:")
auprc(valid.get_targets()[:(n_points*batch_size),1], preds[:,1])
In [72]:
print("Oct4:")
auprc(valid.get_targets()[:(n_points*batch_size),0], preds[:,0])
Oct4:
Out[72]:
0.3844448817381261
In [73]:
print("Sox2:")
auprc(valid.get_targets()[:(n_points*batch_size),1], preds[:,1])
Sox2:
Out[73]:
0.22240325400318217

Open question

  • what accuracy do we achieve if we directly fit the model?
In [85]:
class BPNetSequenceClassifier:

    def __init__(self, shapes, num_tasks,
                 filters=64,
                 conv1_kernel_size=25,
                 n_dil_layers=9,
                 pool_size=50,
                 fc_units=[64],
                 dropout=0.2):
        import keras.layers as kl
        assert len(num_filters) == len(conv_width)

        # configure inputs
        keras_inputs = self.get_keras_inputs(shapes)
        inputs = self.reshape_keras_inputs(keras_inputs)

        # convolve sequence
        inp = inputs["data/genome_data_dir"]
        first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)
        prev_layers = [first_conv]
        for i in range(1, n_dil_layers + 1):
            if i == 1:
                prev_sum = first_conv
            else:
                prev_sum = kl.add(prev_layers)
            conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
            prev_layers.append(conv_output)
        combined_conv = kl.add(prev_layers, name='final_conv')

        seq_preds = kl.MaxPooling1D(pool_size)(combined_conv)
        seq_preds = kl.Flatten()(seq_preds)
        for units in fc_units:
            seq_preds = kl.Dense(units)(seq_preds)
            seq_preds = kl.BatchNormalization()(seq_preds)
            seq_preds = kl.Activation('relu')(seq_preds)
            seq_preds = kl.Dropout(dropout)(seq_preds)
        seq_preds = kl.Dense(2, activation='sigmoid')(seq_preds)
        
        self.model = Model(input=keras_inputs.values(), output=seq_preds)

Training a model from scratch

In [ ]:
from basepair.models import binary_seq_multitask
In [ ]:
mdir = f"{ddir}/processed/chipseq/exp/models/finetune-bpnet/from-scratch3"
!mkdir -p {mdir}
ckp_file = f"{mdir}/model.h5"
history_path = f"{mdir}/history.csv"
In [ ]:
train_it = train.batch_train_iter(shuffle=False,
                                  batch_size=1,
                                  drop_last=None,
                                  batch_sampler=StratifiedRandomBatchSampler(train.get_targets().max(axis=1),
                                                                             batch_size=batch_size,
                                                                             p_vec=[0.95, 0.05],
                                                                             verbose=True),
                                  num_workers=num_workers)
next(train_it)
valid_it = valid.batch_train_iter(batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
next(valid_it);
In [ ]:
final_model = binary_seq_multitask()
In [ ]:
# -------------------------------------------
final_model.fit_generator(train_it,
                    epochs=100,
                    steps_per_epoch=int(len(train) / batch_size * 0.1),
                    validation_data=valid_it,
                    validation_steps=int(len(valid) / batch_size * 0.2),
                    callbacks=[EarlyStopping(patience=4),
                               CSVLogger(history_path),
                               ModelCheckpoint(ckp_file, save_best_only=True)]
                    )
In [ ]:
final_model = load_model(ckp_file)
In [ ]:
# get predictions
valid_it = valid.batch_predict_iter(batch_size=batch_size, shuffle=False, num_workers=num_workers*2)
preds = final_model.predict_generator(valid_it, verbose=1, steps=len(valid) // batch_size)
n_points = len(valid) // batch_size
In [ ]:
print("Oct4:")
auprc(valid.get_targets()[:(n_points*batch_size),0], preds[:,0])
In [ ]:
print("Sox2:")
auprc(valid.get_targets()[:(n_points*batch_size),1], preds[:,1])

Test the trainer

In [19]:
from basepair.trainers import KerasTrainer
In [20]:
final_model = binary_seq_multitask()
In [21]:
trainer = KerasTrainer(final_model, train, valid, 
                       f"{ddir}/processed/chipseq/exp/models/finetune-bpnet/from-scratch2")
In [ ]:
trainer.train(train_epoch_frac=0.1, valid_epoch_frac=0.2,
               train_batch_sampler=StratifiedRandomBatchSampler(trainer.train_dataset.get_targets().max(axis=1),
                                                                                  batch_size=batch_size,
                                                                                  p_vec=[0.95, 0.05],
                                                                                  verbose=True)
             )
Using batch sizes:
[243  13]
Epoch 1/100
1489/1489 [==============================] - 348s 233ms/step - loss: 0.1359 - accuracy: 0.9673 - val_loss: 0.0527 - val_accuracy: 0.9911
Epoch 2/100
  63/1489 [>.............................] - ETA: 4:48 - loss: 0.1256 - accuracy: 0.9697
In [35]:
from basepair.metrics import MetricsMultiTask, MetricsConcise
In [36]:
trainer.evaluate(MetricsMultiTask(MetricsConcise(['auprc']), ["Oct4", "Sox2"]))
4466it [04:17, 17.31it/s]                          
Out[36]:
OrderedDict([('Oct4', OrderedDict([('auprc', 0.34719715685409297)])),
             ('Sox2', OrderedDict([('auprc', 0.18396284759858067)]))])
In [38]:
!ls {trainer.output_dir}
evaluation.valid.json				   history.csv
events.out.tfevents.1537743409.surya.stanford.edu  model.h5
In [40]:
pd.read_csv(f"{trainer.output_dir}/history.csv")
Out[40]:
epoch accuracy loss val_accuracy val_loss
0 0 0.967336 0.135875 0.991116 0.052740
1 1 0.969678 0.125083 0.991285 0.065987
2 2 0.972112 0.099555 0.975883 0.111574
3 3 0.974504 0.081547 0.991572 0.038957
4 4 0.975163 0.077799 0.991634 0.033537
5 5 0.975634 0.075258 0.991216 0.047692
6 6 0.976106 0.072712 0.978279 0.081050
7 7 0.976420 0.070774 0.991768 0.036325
8 8 0.977098 0.068409 0.989134 0.039493
In [ ]:
!cat {trainer.output_dir}/evaluation.valid.json