import numpy as np
from basepair.modisco.results import Seqlet
N = 100000  # number of example rows in `profiles`
N_seqlets = 100000  # 100k seqlets
seqlen = 1000
profile_width = 70
# Generate all the data
# NOTE: float64 array with N * seqlen * 2 entries (~1.6 GB at these sizes).
profiles = np.random.randn(N, seqlen, 2)
# randint's upper bound is exclusive; the + 1 allows the last full-width
# window (the one ending exactly at seqlen) to be sampled too.
starts = np.random.randint(0, seqlen - profile_width + 1, N_seqlets)
ends = starts + profile_width
# Exclusive upper bound of 2 draws both 0 and 1, so both the forward and the
# reverse-complement code paths get exercised by the benchmark.
strand = np.random.randint(0, 2, len(ends))
example_idx = np.random.randint(0, N, N_seqlets)
strand_str = np.where(strand, "-", "+")
def extract_signal(x, seqlets, rc_fn=lambda x: x[::-1, ::-1]):
    """Cut one window per seqlet out of `x` and stack them.

    Args:
      x: array indexed as x[example, position, ...]
      seqlets: iterable of items supporting lookup by the keys
        'example', 'start', 'end' and 'rc'
      rc_fn: applied to a window whose 'rc' entry is truthy; the default
        reverses the window's first two axes

    Returns:
      np.ndarray with the windows stacked along a new leading axis
    """
    windows = []
    for seqlet in seqlets:
        window = x[seqlet['example'], seqlet['start']:seqlet['end']]
        windows.append(rc_fn(window) if seqlet['rc'] else window)
    return np.stack(windows)
seqlets = [Seqlet(seqname=example_idx[i],
start=starts[i],
end=ends[i],
name=None,
strand=strand_str[i]
)
for i in range(len(starts))
]
%timeit out = extract_signal(profiles, seqlets)
from numba import jit
@jit
def extract_range(A, idx, starts, rc, length):
    """Fast extract_range using numba.

    Extracts the window A[idx[i], starts[i]:starts[i] + length] for every i
    and stacks the windows along a new leading axis.

    Args:
      A: array from which to extract the signal, indexed as A[row, position, ...]
        (the rc flip needs at least one axis after position)
      idx: list of indices (row-idx), one per window
      starts: where each range starts
      rc: per-window flag; when truthy the window is reverse complemented by
        the [::-1, ::-1] operation (flip positions and the next axis)
      length: what length to extract

    Returns:
      np.array of shape (len(idx), length) + A.shape[2:] with A's dtype
    """
    out = np.empty((len(idx), length) + A.shape[2:], dtype=A.dtype)
    for i in range(len(idx)):
        window = A[idx[i], starts[i]:starts[i] + length]
        if rc[i]:
            # Slice forward first, then reverse: slice(start, stop, -1) with
            # start < stop is empty, so the flip cannot be folded into the range.
            out[i, ...] = window[::-1, ::-1]
        else:
            out[i, ...] = window
    return out
%timeit out2 = extract_range(profiles, example_idx, starts, rc=strand, length=70)
assert np.all(out == out2)