Goal

  • train a decent seq-> profile + counts model for ChIP-seq

Resources

  • washu
    • session: 7-w-sox2-oct4-chipseq
In [8]:
# Use gpus 0,1
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"

from pathlib import Path
import sys
import numpy as np
import matplotlib.pyplot as plt
sys.path.append(str(Path(os.getcwd()).absolute().parent.parent))
sys.path.append('/opt/miniconda3/envs/basepair/lib/python3.6/site-packages')
from basepair import samplers
In [9]:
import basepair
In [10]:
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
In [11]:
ddir = '/home/prime/data'
In [12]:
bdir = "/data/sox2-oct4-chipseq/"

ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                          ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                          )
                         },
              fasta_file="/data/mm10_no_alt_analysis_set_ENCODE.fasta"
             )
In [13]:
def ds2bws(ds):
    return {task: {"pos": task_spec.pos_counts, "neg": task_spec.neg_counts} for task, task_spec in ds.task_specs.items()}
In [14]:
# Get the training data
train, valid, test = chip_exo_nexus(ds, peak_width=1000)
2018-11-10 01:01:13,044 [INFO] extract sequence
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-14-4a5227272298> in <module>()
      1 # Get the training data
----> 2 train, valid, test = chip_exo_nexus(ds, peak_width=1000)

/opt/miniconda3/envs/basepair/lib/python3.6/site-packages/gin/config.py in wrapper(*args, **kwargs)
   1007 
   1008       try:
-> 1009         return fn(*new_args, **new_kwargs)
   1010       except Exception as e:  # pylint: disable=broad-except
   1011         err_str = ''

~/basepair/basepair/datasets.py in chip_exo_nexus(dataspec, peak_width, shuffle, preprocessor, interval_augm, valid_chr, test_chr)
    100 
    101     logger.info("extract sequence")
--> 102     seq = FastaExtractor(dataspec.fasta_file)(intervals)
    103 
    104     logger.info("extract counts")

/opt/miniconda3/lib/python3.6/site-packages/genomelake/extractors.py in __call__(self, intervals, out, **kwargs)
     24     def __call__(self, intervals, out=None, **kwargs):
     25         data = self._check_or_create_output_array(intervals, out)
---> 26         self._extract(intervals, data, **kwargs)
     27         return data
     28 

/opt/miniconda3/lib/python3.6/site-packages/genomelake/extractors.py in _extract(self, intervals, out, **kwargs)
     91         for index, interval in enumerate(intervals):
     92             seq = self.fasta.fetch(str(interval.chrom), interval.start,
---> 93                                        interval.stop)
     94             one_hot_encode_sequence(seq, out[index, :, :])
     95 

KeyboardInterrupt: 
In [6]:
import numpy as np
print (train[0].shape)
print (train[1]['profile/Sox2'].shape)
print(train[1]['counts/Sox2'].shape)

# print (train[1]['profile/Sox2'].max())

for k in range(train[0].shape[0]):
    mask1 = train[1]['profile/Sox2'][k] > 0.1
    mask2 = train[1]['profile/Sox2'][k] < 0.9
    elems = train[1]['profile/Sox2'][k][mask1 & mask2]
    if len(elems):
        print (elems)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-6-14fcac34fc02> in <module>()
      1 import numpy as np
----> 2 print (train[0].shape)
      3 print (train[1]['profile/Sox2'].shape)
      4 print(train[1]['counts/Sox2'].shape)
      5 

NameError: name 'train' is not defined
In [9]:
train[1]['profile/Sox2'].shape
Out[9]:
(14727, 1000, 2)
In [2]:
import numpy as np
np.unique(train[1]['profile/Sox2'].shape)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-2-6dc085abce7f> in <module>()
      1 import numpy as np
----> 2 np.unique(train[1]['profile/Sox2'].shape)

NameError: name 'train' is not defined
In [113]:
 
In [141]:
class Plotter:

    def __init__(self, ys):
        #size = (N, 1000, 2)
        self.ys = np.array(ys)
        
    def plot(self, binsize=10, n=10, sigma=1.6, lows=10, fft='r', sort='random', figsize=(20, 2), fpath_template=None):
        if sort == 'random':
            idx_list = samplers.random(self.ys[0], n)
        elif sort == 'sum':
            idx_list = samplers.top_sum_count(self.ys, n)
            idx_list = idx_list[2:]
        else: # sort == 'max':
            idx_list = samplers.top_max_count(self.ys, n)
            
        
            
        for i, idx in enumerate(idx_list):
            
            bin0 = binify(self.ys[idx, :, 0], binsize=binsize)
            bin1 = binify(self.ys[idx, :, 1], binsize=binsize)
            
            gauss0 = gaussian_filter1d(bin0, sigma=sigma)
            gauss1 = gaussian_filter1d(bin1, sigma=sigma)
            
            fft_func  = np.fft.fft  if fft == 'd' else np.fft.rfft
            ifft_func = np.fft.ifft if fft == 'd' else np.fft.irfft
            
            length = len(bin0)
            
            fft_low0 = fft_func(bin0)[:10]
            fft_low1 = fft_func(bin0)[:10]
            
            ifft_low0 = np.maximum(ifft_func(fft_low0, n=length), 0)
            ifft_low1 = np.maximum(ifft_func(fft_low1, n=length), 0)
            
                
            fig = plt.figure(figsize=figsize)
            
            plt.subplot(141)
            if i == 0:
                plt.title("Binned")
            plt.plot(bin0)
            plt.plot(bin1)
            
            plt.subplot(142)
            if i == 0:
                plt.title("Gaussian σ=%.2f" % sigma)
            plt.plot(gauss0)
            plt.plot(gauss1)
            
            plt.subplot(143)
            if i == 0:
                plt.title("FFT with low-pass")
            plt.plot(fft_low0)
            plt.plot(fft_low1)
            
            plt.subplot(144)
            if i == 0:
                plt.title("IFFT with low-pass")
            plt.plot(ifft_low0)
            plt.plot(ifft_low1)
            
            if fpath_template is not None:
                plt.savefig(fpath_template.format(i) + '.png', dpi=600)
                plt.savefig(fpath_template.format(i) + '.pdf', dpi=600)
                plt.close(fig)    # close the figure
                show_figure(fig)
                plt.show()
In [142]:
plotter = Plotter(test[1]['profile/Sox2'])
plotter.plot(sort='sum', n=6)
/opt/miniconda3/lib/python3.6/site-packages/numpy/core/numeric.py:492: ComplexWarning: Casting complex values to real discards the imaginary part
  return array(a, dtype, copy=False, order=order)
In [143]:
class GaussianPlotter:

    def __init__(self, ys):
        #size = (N, 1000, 2)
        self.ys = np.array(ys)
        
    def plot(self, binsize=10, n=10, fft='r', sort='random', figsize=(20, 2), fpath_template=None):
        if sort == 'random':
            idx_list = samplers.random(self.ys[0], n)
        elif sort == 'sum':
            idx_list = samplers.top_sum_count(self.ys, n)
            idx_list = idx_list[2:]
        else: # sort == 'max':
            idx_list = samplers.top_max_count(self.ys, n)
        
        
        
        for i, idx in enumerate(idx_list):
            
            bin0 = binify(self.ys[idx, :, 0], binsize=binsize)
            bin1 = binify(self.ys[idx, :, 1], binsize=binsize)
            
            fig = plt.figure(figsize=figsize)
            
            plt.subplot(141)
            if i == 0:
                plt.title("Binned")
            plt.plot(bin0)
            plt.plot(bin1)
            
            for j, sigma in enumerate([1.6, 2.0, 3.0]):
                gauss0 = gaussian_filter1d(bin0, sigma=sigma)
                gauss1 = gaussian_filter1d(bin1, sigma=sigma)
            
                plt.subplot(142 + j)
                if i == 0:
                    plt.title("Gaussian σ=%f" % sigma)
                plt.plot(gauss0)
                plt.plot(gauss1)
            
            if fpath_template is not None:
                plt.savefig(fpath_template.format(i) + '.png', dpi=600)
                plt.savefig(fpath_template.format(i) + '.pdf', dpi=600)
                plt.close(fig)    # close the figure
                show_figure(fig)
                plt.show()
In [144]:
plotter = GaussianPlotter(test[1]['profile/Sox2'])
plotter.plot(sort='sum', n=6)
In [145]:
class IRFFTPlotter:

    def __init__(self, ys):
        #size = (N, 1000, 2)
        self.ys = np.array(ys)
        
    def plot(self, binsize=10, n=10, fft='r', sort='random', figsize=(20, 2), fpath_template=None):
        if sort == 'random':
            idx_list = samplers.random(self.ys[0], n)
        elif sort == 'sum':
            idx_list = samplers.top_sum_count(self.ys, n)
            idx_list = idx_list[2:]
        else: # sort == 'max':
            idx_list = samplers.top_max_count(self.ys, n)
        
        
        
        for i, idx in enumerate(idx_list):
            
            bin0 = binify(self.ys[idx, :, 0], binsize=binsize)
            bin1 = binify(self.ys[idx, :, 1], binsize=binsize)
            
            fft0 = np.fft.rfft(bin0)
            fft1 = np.fft.rfft(bin1)
            
            length = len(bin0)
            
            fig = plt.figure(figsize=figsize)
            
            plt.subplot(141)
            if i == 0:
                plt.title("Binned")
            plt.plot(bin0)
            plt.plot(bin1)
            
            for j, lows in enumerate([5, 10, 15]):
                fft_low0 = fft0[:lows]
                fft_low1 = fft1[:lows]

                ifft_low0 = np.maximum(np.fft.irfft(fft_low0, n=length), 0)
                ifft_low1 = np.maximum(np.fft.irfft(fft_low1, n=length), 0)
            
                plt.subplot(142 + j)
                if i == 0:
                    plt.title("Low pass IFFT with %d freqs" % lows)
                plt.plot(ifft_low0)
                plt.plot(ifft_low1)
            
            if fpath_template is not None:
                plt.savefig(fpath_template.format(i) + '.png', dpi=600)
                plt.savefig(fpath_template.format(i) + '.pdf', dpi=600)
                plt.close(fig)    # close the figure
                show_figure(fig)
                plt.show()
In [146]:
plotter = IRFFTPlotter(test[1]['profile/Sox2'])
plotter.plot(sort='sum', n=6)
In [156]:
seq = train[1]['profile/Sox2'][100, :, 0]
binned = binify(seq)[20:]
print (binned)
plt.plot(binned)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 2. 1. 4. 0. 1. 3. 1. 2.
 5. 3. 4. 2. 4. 2. 3. 4. 1. 4. 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]
Out[156]:
[<matplotlib.lines.Line2D at 0x7f281f3417b8>]
In [157]:
fft = np.fft.rfft(binned)
plt.plot(fft)
/opt/miniconda3/lib/python3.6/site-packages/numpy/core/numeric.py:492: ComplexWarning: Casting complex values to real discards the imaginary part
  return array(a, dtype, copy=False, order=order)
Out[157]:
[<matplotlib.lines.Line2D at 0x7f281f1ea7f0>]
In [158]:
ifft = np.fft.irfft(fft[:15], n=len(binned))
plt.plot(ifft)
Out[158]:
[<matplotlib.lines.Line2D at 0x7f282840ec88>]
In [159]:
from scipy.ndimage.filters import gaussian_filter1d
gauss = gaussian_filter1d(binned, sigma=1.2)

plt.plot(gauss)
Out[159]:
[<matplotlib.lines.Line2D at 0x7f281eebb160>]

setup a new model having two output branches

In [37]:
from basepair.math import softmax
from basepair import samplers
from basepair.preproc import bin_counts
import numpy as np

class Seq2Sox2Oct4:

    def __init__(self, x, y, model):
        self.x = x
        self.y = y 
        self.model = model
        # Make the prediction
        self.y_pred = [softmax(y) for y in model.predict(x)]
        
    def plot(self, n=10, kind='test', sort='random', figsize=(20, 2), fpath_template=None, binsize=1):
        import matplotlib.pyplot as plt
        if sort == 'random':
            idx_list = samplers.random(self.x, n)
        elif "_" in sort:
            kind, task = sort.split("_")
            #task_id = {"Sox2": 0, "Oct4": 1}[task]
            if kind == "max":
                idx_list = samplers.top_max_count(self.y[f"profile/{task}"], n)
            elif kind == "sum":
                idx_list = samplers.top_sum_count(self.y[f"profile/{task}"], n)
            else:
                raise ValueError("")
        else:
            raise ValueError(f"sort={sort} couldn't be interpreted")
            
        # for visualization, we use bucketize
        for i, idx in enumerate(idx_list):
            fig = plt.figure(figsize=figsize)
            plt.subplot(141)
            if i == 0:
                plt.title("Predicted Sox2")
            plt.plot(bin_counts(self.y_pred[0], binsize=binsize)[idx, :, 0], label='pos,m={}'.format(np.argmax(self.y_pred[0][idx, :, 0])))
            plt.plot(bin_counts(self.y_pred[0], binsize=binsize)[idx, :, 1], label='neg,m={}'.format(np.argmax(self.y_pred[0][idx, :, 1])))
            plt.legend()
            plt.subplot(142)
            if i == 0:
                plt.title("Observed Sox2")
            plt.plot(bin_counts(self.y["profile/Sox2"], binsize=binsize)[idx, :, 0], label='pos,m={}'.format(np.argmax(self.y["profile/Sox2"][idx, :, 0])))
            plt.plot(bin_counts(self.y["profile/Sox2"], binsize=binsize)[idx, :, 1], label='neg,m={}'.format(np.argmax(self.y["profile/Sox2"][idx, :, 1])))
            plt.legend()
            plt.subplot(143)
            if i == 0:
                plt.title("Predicted Oct4")
            plt.plot(bin_counts(self.y_pred[1], binsize=binsize)[idx, :, 0], label='pos,m={}'.format(np.argmax(self.y_pred[1][idx, :, 0])))
            plt.plot(bin_counts(self.y_pred[1], binsize=binsize)[idx, :, 1], label='neg,m={}'.format(np.argmax(self.y_pred[1][idx, :, 1])))
            plt.legend()
            plt.subplot(144)
            if i == 0:
                plt.title("Observed Oct4")
            plt.plot(bin_counts(self.y["profile/Oct4"], binsize=binsize)[idx, :, 0], label='pos,m={}'.format(np.argmax(self.y["profile/Oct4"][idx, :, 0])))
            plt.plot(bin_counts(self.y["profile/Oct4"], binsize=binsize)[idx, :, 1], label='neg,m={}'.format(np.argmax(self.y["profile/Oct4"][idx, :, 1])))
            plt.legend()
            if fpath_template is not None:
                plt.savefig(fpath_template.format(i) + '.png', dpi=600)
                plt.savefig(fpath_template.format(i) + '.pdf', dpi=600)
                plt.close(fig)    # close the figure
                show_figure(fig)
                plt.show()
In [41]:
pl = Seq2Sox2Oct4(test[0], test[1], model)
In [44]:
pl.plot(n=10, sort='sum_Sox2', binsize=50)