Goal

  • implement fast sub-array extraction

Conclusion

  • using numba, we only get a speedup of ~2.5x (451 ms → 177 ms for 100k windows of width 70, per the `%timeit` results below)
In [172]:
import numpy as np
In [165]:
from basepair.modisco.results import Seqlet
In [162]:
# Benchmark sizes
N = 100_000           # number of examples (rows) in `profiles`
N_seqlets = 100_000   # number of windows (seqlets) to extract: 100k
seqlen = 1_000        # positions per example
profile_width = 70    # width of every extracted window
In [163]:
# Generate all the data (seeded so the timings and the equality check are reproducible)
np.random.seed(0)
# N * seqlen * 2 float64 values -> ~1.6 GB at the default sizes
profiles = np.random.randn(N, seqlen, 2)
# randint's `high` is exclusive; +1 lets a window end exactly at the last position
starts = np.random.randint(0, seqlen - profile_width + 1, N_seqlets)
ends = starts + profile_width
# (0, 2) draws 0 or 1 so both strands occur. The previous (0, 1) always returned 0,
# which meant the reverse-complement path was never exercised or compared.
strand = np.random.randint(0, 2, len(ends))
example_idx = np.random.randint(0, N, N_seqlets)
strand_str = np.where(strand, "-", "+")
In [164]:
def extract_signal(x, seqlets, rc_fn=lambda x: x[::-1, ::-1]):
    """Stack the window of `x` covered by each seqlet (pure-Python reference).

    Args:
      x: array indexed as x[example, position, ...]
      seqlets: iterable of seqlet-like objects supporting item access for
        'example', 'start', 'end' and 'rc'
      rc_fn: transform applied to a window whose seqlet has a truthy 'rc';
        the default reverses the first two axes of the window

    Returns:
      np.array stacking one window per seqlet along a new leading axis
    """
    windows = []
    for seqlet in seqlets:
        window = x[seqlet['example'], seqlet['start']:seqlet['end']]
        windows.append(rc_fn(window) if seqlet['rc'] else window)
    return np.stack(windows)
In [166]:
# One Seqlet per generated window, built from the parallel arrays above
seqlets = [
    Seqlet(seqname=seqname,
           start=start,
           end=end,
           name=None,
           strand=strand_char)
    for seqname, start, end, strand_char in zip(example_idx, starts, ends, strand_str)
]
In [167]:
%timeit out = extract_signal(profiles, seqlets)
451 ms ± 22.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [168]:
from numba import jit
In [176]:
@jit
def extract_range(A, idx, starts, rc, length):
    """Fast extraction of fixed-length windows using numba.

    Window i is A[idx[i], starts[i]:starts[i] + length], optionally reverse-complemented.

    Args:
      A: array from which to extract the signal, indexed as A[row, position, ...]
      idx: sequence of row indices into A (one per window)
      starts: start position of each window (same length as idx)
      rc: per-window flags; where truthy, the window is reverse-complemented,
        i.e. window[::-1, ::-1] (reverses the position axis and the axis after it)
      length: number of positions to extract per window

    Returns:
      np.array of shape (len(idx), length) + A.shape[2:], with A's dtype
    """
    # Keep A's dtype instead of np.empty's float64 default, which would silently cast.
    out = np.empty((len(idx), length) + A.shape[2:], dtype=A.dtype)
    for i in range(len(idx)):
        window = A[idx[i], starts[i]:starts[i] + length]
        if rc[i]:
            # Reverse the forward window. The previous slice(start, start + length, -1)
            # has start < stop with a negative step, so it was always empty.
            out[i, ...] = window[::-1, ::-1]
        else:
            out[i, ...] = window
    return out
In [177]:
%timeit out2 = extract_range(profiles, example_idx, starts, rc=strand, length=70)
177 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [171]:
# %timeit executes its statement inside a generated helper function, so results
# assigned in the timing cells are not available here; compute them explicitly.
out = extract_signal(profiles, seqlets)
out2 = extract_range(profiles, example_idx, starts, rc=strand, length=profile_width)
# Unlike np.all(out == out2), this also requires identical shapes and reports mismatches.
np.testing.assert_array_equal(out, out2)