# bpnet-lite

> [!IMPORTANT]
> bpnet-lite is not meant to replace the full-service implementations of BPNet or ChromBPNet and is still under development. Please see the official repositories for those projects for complete TensorFlow/Keras implementations of the models, along with tutorials on how to use them effectively. Although bpnet-lite is capable of loading models trained using these TensorFlow/Keras repositories into PyTorch and perfectly reproducing their outputs, and can train BPNet models to similar performance as the official BPNet repository, its fitting procedure for ChromBPNet models does not yet match and can sometimes significantly underperform models trained using the official ChromBPNet repository.

bpnet-lite is a lightweight version of BPNet [paper | code] and ChromBPNet [preprint | code], containing PyTorch reference implementations of both models. Additionally, it contains efficient data loaders and common operations one would perform with these trained models, including calculating attributions, running TF-MoDISco, and performing marginalization experiments. These operations are wrapped in command-line tools for ease of use and organized in a `pipeline` command representing the standard workflow. This package is primarily meant to be used for prototyping new ideas that involve modifying the code and for loading models trained using the official repositories into PyTorch.

### Installation

`pip install bpnet-lite`

### Data Preprocessing

> [!NOTE]
> As of v0.9.0, you can include BAM/SAM and .tsv/.tsv.gz files in the JSONs for the bpnet-lite command-line tools, and the conversion to bigWigs will be performed automatically using bam2bw. Because bam2bw is fast (around 500k records/second), it is no longer always necessary to preprocess your data separately.

BPNet and ChromBPNet models are both trained on read ends that have been mapped at basepair resolution (hence, the name).
Accordingly, the data used for training is made up of integers, with one count per read in the file (or two counts per fragment). Once you have used your favorite tool to align your FASTQ of reads to your genome of interest (we recommend ChroMAP), you can either use bam2bw to convert your BAM/SAM or fragment .tsv/.tsv.gz files to bigWig files, or put these raw data files in the JSON and have bpnet-lite do the conversion for you automatically.

If you are using stranded data, e.g., ChIP-seq:

```
bam2bw <.bam> <.bam> ... -s <.chrom.sizes/.fa> -n <name> -v
```

This command will create two bigWig files, one for the + strand and one for the - strand, using the name provided as the suffix.

If you are using unstranded data:

```
bam2bw <.bam> <.bam> ... -s <.chrom.sizes/.fa> -n <name> -v -u
```

If you have a file of fragments, usually formatted as a .tsv or .tsv.gz and coming from ATAC-seq or scATAC-seq data, you can use the -f flag to map both the start and the end (end-1, specifically) instead of just the 5' end. You will probably also want the -u flag because the underlying data is unstranded.

```
bam2bw <.tsv.gz> <.tsv.gz> ... -s <.chrom.sizes/.fa> -n <name> -v -u -f
```

These tools require positive loci (usually peaks for the respective activity, or elements like promoters) and negative loci (usually GC-matched background sequences) for training. One or more BED files of positive loci must be provided by the user, potentially acquired by applying a tool like MACS2 to your BAM files. The negative loci can be calculated using a command-line tool in this package, described later, or by specifying `"find_negatives": true` in the JSON.

### BPNet

BPNet is a convolutional neural network that maps nucleotide sequences to experimental readouts, e.g., ChIP-seq, ChIP-nexus, and ChIP-exo. It is composed of one big convolution layer, a series of dilated residual layers that mix information across distances, and another big convolution layer.
Importantly, BPNet makes predictions for the total (log) read count in the region as well as for the basepair-resolution profile, with the profile being a probability vector over each position. Although these models achieve high predictive accuracy, their main purpose is to estimate the influence of non-coding variants and to extract principles of the cis-regulatory code underlying the readouts being modeled. Specifically, when paired with a feature attribution algorithm like DeepLIFT/SHAP or in silico saturation mutagenesis, these models can assign to each nucleotide an importance in the model's predictions. These attributions can shed insight into how individual loci work and, when considered genome-wide, algorithms like TF-MoDISco can identify the repeated high-attribution patterns.

### BPNet Command Line Tools

bpnet-lite comes with a command-line tool, `bpnet`, that supports the steps necessary for training and using BPNet models. The fastest way to go from your raw data to results is to use the `bpnet pipeline-json` command followed by the `bpnet pipeline` command.

```
bpnet pipeline-json -s hg38.fa -p peaks.bed.gz -i input1.bam -i input2.bam -c control1.bam -c control2.bam -n test -o pipeline.json -m JASPAR_2024.meme
bpnet pipeline -p pipeline.json
```

The `pipeline-json` command takes in pointers to your data files and produces a properly formatted `pipeline.json` file. These data files usually include a reference genome, some number of input (and, optionally, control) BAM/SAM/tsv/tsv.gz files (the -i and -c arguments can be repeated), a BED file of positive loci, and a MEME-formatted motif database used for evaluation of the model.
The `pipeline` command takes in the JSON and (0) optionally preprocesses your BAM/SAM/tsv/tsv.gz files and identifies GC-matched negatives (you can provide your own bigWigs and/or negatives and skip the respective portions of this), (1) trains a BPNet model, (2) makes predictions on the provided loci, (3) calculates DeepLIFT/SHAP attributions on the provided loci, (4) calls seqlets and annotates them using ttl, (5) runs TF-MoDISco and generates a report, and (6) runs in silico marginalizations using the provided motif database.

These commands are separated because, although the first command produces a valid JSON that the second command can immediately use (no need to copy/paste JSONs from this GitHub anymore!), one may wish to modify some of the many parameters in the JSON. These parameters include the number of filters and layers in the model, the training and validation chromosomes, and even very technical ones like the number of shuffles to use when calculating attributions and the p-value threshold for calling seqlets. The defaults for most of these steps seem reasonable in practice, but there is immense flexibility, e.g., the ability to train the model using a reference genome and then make predictions or attributions on synthetic sequences or the reference genome of another species. In this manner, the JSON serves as documentation for the experiments that have been performed.

When running the pipeline, a JSON is produced for each of the steps (except for running TF-MoDISco and annotating the seqlets, which uses ttl). Each of these JSONs can be run by itself using the appropriate built-in command. Because some of the values in the JSONs for these steps are set programmatically when running the full pipeline, e.g., the filenames to read in and save to, being able to inspect each of the JSONs can be handy for debugging.
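If you want to modify parameters programmatically rather than by hand, the JSON can be edited with a few lines of standard-library Python before running `bpnet pipeline`. This is a minimal sketch; the key names used in the usage example below (`n_filters`, `training_chroms`) are illustrative assumptions, so check the JSON produced by `pipeline-json` for the actual parameter names.

```python
import json

def update_pipeline_json(path, **overrides):
    """Load a pipeline JSON, apply keyword overrides, and write it back."""
    with open(path) as infile:
        config = json.load(infile)

    # Overwrite only the provided keys; everything else is left untouched.
    config.update(overrides)

    with open(path, "w") as outfile:
        json.dump(config, outfile, indent=4)

    return config
```

One might then call something like `update_pipeline_json("pipeline.json", n_filters=128)` before running `bpnet pipeline -p pipeline.json`, keeping the JSON itself as the record of what was run.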
```
bpnet fit -p bpnet_fit_example.json
bpnet predict -p bpnet_predict_example.json
bpnet attribute -p bpnet_attribute_example.json
bpnet seqlets -p bpnet_seqlet_example.json
bpnet marginalize -p bpnet_marginalize_example.json
```

For a complete description of each of the JSONs and the command-line tools, see the example_jsons folder.

### ChromBPNet

> [!WARNING]
> Several users have reported that ChromBPNet models trained using bpnet-lite significantly underperform those trained using the official ChromBPNet repo. We are currently looking into this. Until we resolve the differences, please consider using the official repository for training your ChromBPNet models and then bpnet-lite for loading them into PyTorch.

ChromBPNet extends the original modeling framework to DNase-seq and ATAC-seq experiments. A separate framework is necessary because the cutting enzymes used in these experiments, particularly the hyperactive Tn5 enzyme used in ATAC-seq experiments, have soft sequence preferences that can distort the observed readouts. Hence, it becomes necessary to train a small BPNet model to explicitly capture this soft sequence bias (the "bias model") before subsequently training a second BPNet model jointly with the frozen bias model to capture the true drivers of accessibility (the "accessibility model"). Together, these models, and the manner in which their predictions are combined, are referred to as ChromBPNet.

Generally, one can perform the same analyses using ChromBPNet as one can using BPNet. However, an important note is that the full ChromBPNet model faithfully represents the experimental readout -- bias and all -- and so for most inspection tasks, e.g., variant effect prediction and interpretation, one should use only the accessibility model. Because the accessibility model itself is conceptually, and is also literally implemented as, a BPNet model, one can run the same procedures and use the BPNet command-line tool with it.
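The "manner in which their predictions are combined" can be made concrete with a small numeric sketch, mirroring `ChromBPNet.forward` in `bpnetlite/chrombpnet.py` below: the two models' profile logits are simply added, while the count heads, which predict log counts, are combined with a log-sum-exp so that the counts themselves add in count space.

```python
import math

# Sketch of ChromBPNet's count combination: each head predicts log counts,
# and the combined prediction is log(exp(acc) + exp(bias)), i.e., the
# bias and accessibility counts sum in count space.
def combine_log_counts(acc_log_counts, bias_log_counts):
    return math.log(math.exp(acc_log_counts) + math.exp(bias_log_counts))

acc = math.log(100.0)    # accessibility model predicts 100 reads
bias = math.log(25.0)    # bias model predicts 25 reads
combined = combine_log_counts(acc, bias)
# exp(combined) recovers 125.0: the predicted counts themselves sum
```

This is why attributions or variant-effect scores computed on the full model mix real accessibility signal with enzyme bias, and why the accessibility head alone is preferred for interpretation.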
bpnet-lite comes with a second command-line tool, `chrombpnet`, that supports the steps necessary for training and using ChromBPNet models. These commands are used in exactly the same way as the `bpnet` command-line tool, with only minor changes to the parameters in the JSON. Note that the predict, attribute, and marginalize commands will internally run their bpnet counterparts, but are still provided for convenience.

```
chrombpnet fit -p chrombpnet_fit_example.json
chrombpnet predict -p chrombpnet_predict_example.json
chrombpnet attribute -p chrombpnet_attribute_example.json
chrombpnet marginalize -p chrombpnet_marginalize_example.json
```

Similarly to `bpnet`, one can run the entire pipeline of commands specified above, in addition to also running TF-MoDISco and generating a report on the found motifs. Unlike `bpnet`, this command will run each of those steps for (1) the full ChromBPNet model, (2) the accessibility model alone, and (3) the bias model.

```
chrombpnet pipeline -p chrombpnet_pipeline_example.json
```

### Python API

> [!WARNING]
> This is no longer accurate as of v0.9.2 with the switch to the PeakNegativeSampler. I will update soon.

If you'd rather train and use BPNet/ChromBPNet models programmatically, you can use the Python API. The command-line tool is made up of wrappers around these methods and functions, so please take a look if you'd like additional documentation on how to get started.

The first step is loading data. Much like with the command-line tool, if you're using the built-in data loader, you need to specify where the FASTA containing sequences, the BED file containing loci, and the bigWig files to train on are. The signals need to be provided in a list, and the index of each bigWig in the list will correspond to a model output. Optionally, you can also provide control bigWigs. See the BPNet paper for how these control bigWigs get used during training.
```python
import torch

from tangermeme.io import extract_loci
from bpnetlite.io import PeakGenerator
from bpnetlite import BPNet

peaks = 'test/CTCF.peaks.bed'  # A set of loci to train on
seqs = '../../oak/common/hg38/hg38.fa'  # A set of sequences to train on
signals = ['test/CTCF.plus.bw', 'test/CTCF.minus.bw']  # A set of bigwigs
controls = ['test/CTCF.plus.ctl.bw', 'test/CTCF.minus.ctl.bw']  # A set of bigwigs
```

After specifying filepaths for each of these, you can create the data generator. If you have a set of chromosomes you'd like to use for training, you can pass those in as well. They must match exactly the names of the chromosomes given in the BED file.

```python
training_chroms = ['chr{}'.format(i) for i in range(1, 17)]

training_data = PeakGenerator(peaks, seqs, signals, controls,
    chroms=training_chroms)
```

The `PeakGenerator` function is a wrapper around several functions that extract data, pass them into a generator that applies shifts and shuffling, and pass that generator into a PyTorch data loader object for use during training. The end result is an object that can be directly iterated over while training a bpnet-lite model.

Although wrapping all that functionality is good for the training set, the validation set should remain constant during training. Hence, one should only use the `extract_loci` function, which is the first step in handling the training data.

```python
valid_chroms = ['chr{}'.format(i) for i in range(18, 23)]

X_valid, y_valid, X_ctl_valid = extract_loci(peaks, seqs, signals, controls,
    chroms=valid_chroms, max_jitter=0)
```

Note that this function can be used without control tracks and, in that case, will only return two arguments. Further, it can be used with only a FASTA and, in that case, will only return one argument: the extracted sequences.

Now, we can define the model. If you want to change the architecture, check out the documentation.
model = BPNet(n_outputs=2, n_control_tracks=2, trimming=(2114 - 1000) // 2).cuda() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) And, finally, we can call the fit_generator method to train the model. This function is largely just a training loop that trains the profile head using the multinomial log-likelihood loss and the count head using the mean-squared error loss, but a benefit of this built-in method is that it outputs a tsv of the training statistics that you can redirect to a log file. model.fit(training_data, optimizer, X_valid=X_valid, X_ctl_valid=X_ctl_valid, y_valid=y_valid) Because model is a PyTorch object, it can be trained using a custom training loop in the same way any base PyTorch model can be trained if you'd prefer to do that. Likewise, if you'd prefer to use a custom data generator you can write your own and pass that into the fit function. --- bpnetlite/chrombpnet.py # chrombpnet.py # Author: Jacob Schreiber import h5py import time import numpy import torch from .bpnet import BPNet from .losses import MNLLLoss from .losses import log1pMSELoss from .performance import calculate_performance_measures from .logging import Logger from tqdm import trange from tangermeme.predict import predict class _Exp(torch.nn.Module): def __init__(self): super(_Exp, self).__init__() def forward(self, X): return torch.exp(X) class _Log(torch.nn.Module): def __init__(self): super(_Log, self).__init__() def forward(self, X): return torch.log(X) class ChromBPNet(torch.nn.Module): """A ChromBPNet model. ChromBPNet is an extension of BPNet to handle chromatin accessibility data, in contrast to the protein binding data that BPNet handles. The distinction between these data types is that an enzyme used in DNase-seq and ATAC-seq experiments itself has a soft sequence preference, meaning that the strength of the signal is driven by real biology but that the exact read mapping locations are driven by the soft sequence bias of the enzyme. 
ChromBPNet handles this by treating the data using two models: a bias model that is initially trained on background (non-peak) regions where the bias dominates, and an accessibility model that is subsequently trained using a frozen version of the bias model. The bias model learns to remove the enzyme bias so that the accessibility model can learn real motifs. Parameters ---------- bias: torch.nn.Module This model takes in sequence and outputs the shape one would expect in ATAC-seq data due to Tn5 bias alone. This is usually a BPNet model from the bpnet-lite repo that has been trained on GC-matched non-peak regions. accessibility: torch.nn.Module This model takes in sequence and outputs the accessibility one would expect due to the components of the sequence, but also takes in a cell representation which modifies the parameters of the model, hence, "dynamic." This model is usually a DynamicBPNet model, defined below. name: str The name to prepend when saving the file. """ def __init__(self, bias, accessibility, name): super(ChromBPNet, self).__init__() for parameter in bias.parameters(): parameter.requires_grad = False self.bias = bias self.accessibility = accessibility self.name = name self.logger = None self.n_control_tracks = accessibility.n_control_tracks self.n_outputs = 1 self._log = _Log() self._exp1 = _Exp() self._exp2 = _Exp() def forward(self, X, X_ctl=None): """A forward pass through the network. This function is usually accessed through calling the model, e.g. doing `model(x)`. The method defines how inputs are transformed into the outputs through interactions with each of the layers. Parameters ---------- X: torch.tensor, shape=(-1, 4, 2114) A one-hot encoded sequence tensor. X_ctl: ignore An ignored parameter for consistency with attribution functions. Returns ------- y_profile: torch.tensor, shape=(-1, 1000) The predicted logit profile for each example. Note that this is not a normalized value. 
""" acc_profile, acc_counts = self.accessibility(X) bias_profile, bias_counts = self.bias(X) n0, n1 = acc_profile.shape[-1], bias_profile.shape[-1] w = (n1 - n0) // 2 y_profile = acc_profile + bias_profile[:, :, w:-w] y_counts = self._log(self._exp1(acc_counts) + self._exp2(bias_counts)) return y_profile, y_counts def fit(self, training_data, optimizer, X_valid=None, y_valid=None, max_epochs=100, batch_size=64, validation_iter=100, dtype='float32', device='cuda', early_stopping=None, verbose=True): """Fit the ChromBPNet model to data. Specifically, this function will fit the accessibility model to observed chromatin accessibility data, and assume that the bias model is frozen and pre-trained. Hence, the only parameters being trained in this function are those in the accessibility model. This function will save the best full ChromBPNet model, as well as the best accessibility model, found during training. Parameter --------- training_data: torch.utils.data.DataLoader A data set that generates one-hot encoded sequence as input and read count signal for the output. optimizer: torch.optim.Optimizer A PyTorch optimizer. X_valid: torch.Tensor or None, shape=(-1, 4, length) A tensor of one-hot encoded sequences to use as input for the validation steps. If None, do not do validation. Default is None. y_valid: torch.Tensor or None, shape=(-1, 1, length) A tensor of read counts matched with the `X_valid` input. If None, do not do validation. Default is None. max_epochs: int The maximum number of training epochs to perform before stopping. Default is 100. batch_size: int The number of examples to use in each batch. Default is 64. validation_iter: int The number of training batches to perform before doing another round of validation. Set higher to spend a higher percentage of time in the training step. dtype: str or torch.dtype The torch.dtype to use when training. Usually, this will be torch.float32 or torch.bfloat16. Default is torch.float32. 
device: str The device to use for training and inference. Typically, this will be 'cuda' but can be anything supported by torch. Default is 'cuda'. early_stopping: int or None Whether to stop training early. If None, continue training until max_epochs is reached. If an integer, continue training until that number of `validation_iter` ticks has been hit without improvement in performance. Default is None. verbose: bool Whether to print the log as it is being generated. A log will be returned at the end of training regardless of this option, but when False, nothing will be printed to the screen during training. Default is False """ print("Warning: BPNet and ChromBPNet models trained using bpnet-lite may underperform those trained using the official repositories. See the GitHub README for further documentation.") dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype y_bias_profile, y_bias_counts = predict(self.bias, X_valid, batch_size=batch_size, dtype=dtype, device='cuda') self.logger = Logger(["Epoch", "Iteration", "Training Time", "Validation Time", "Training MNLL", "Training Count MSE", "Validation MNLL", "Validation Profile Correlation", "Validation Count Pearson", "Validation Count MSE", "Saved?"], verbose=verbose) early_stop_count = 0 start, best_loss = time.time(), float("inf") self.logger.start() self.bias.eval() for epoch in range(max_epochs): for iteration, (X, y) in enumerate(training_data): self.accessibility.train() X = X.cuda().float() y = y.cuda() optimizer.zero_grad() with torch.autocast(device_type=device, dtype=dtype): acc_profile, acc_counts = self.accessibility(X) bias_profile, bias_counts = self.bias(X) y_profile = torch.nn.functional.log_softmax(acc_profile + bias_profile, dim=-1) y_counts = torch.logsumexp(torch.stack([acc_counts, bias_counts]), dim=0) profile_loss = MNLLLoss(y_profile, y).mean() count_loss = log1pMSELoss(y_counts, y.sum(dim=-1).reshape(-1, 1)).mean() profile_loss_ = profile_loss.item() count_loss_ = 
count_loss.item() loss = profile_loss + self.accessibility.alpha * count_loss loss.backward() optimizer.step() if verbose and iteration % validation_iter == 0: train_time = time.time() - start tic = time.time() with torch.no_grad(): self.accessibility.eval() y_profile, y_counts = predict(self.accessibility, X_valid, batch_size=batch_size, device=device, dtype=dtype) y_profile = torch.nn.functional.log_softmax( y_profile + y_bias_profile, dim=-1) y_counts = torch.logsumexp(torch.stack([y_counts, y_bias_counts]), dim=0) measures = calculate_performance_measures(y_profile, y_valid, y_counts, kernel_sigma=7, kernel_width=81, measures=['profile_mnll', 'profile_pearson', 'count_pearson', 'count_mse']) profile_corr = measures['profile_pearson'] count_corr = measures['count_pearson'] valid_loss = measures['profile_mnll'].mean() valid_loss += self.accessibility.alpha * measures['count_mse'].mean() valid_time = time.time() - tic self.logger.add([epoch, iteration, train_time, valid_time, profile_loss_, count_loss_, measures['profile_mnll'].mean().item(), numpy.nan_to_num(profile_corr).mean(), numpy.nan_to_num(count_corr).mean(), measures['count_mse'].mean().item(), (valid_loss < best_loss).item()]) if valid_loss < best_loss: torch.save(self, "{}.torch".format(self.name)) torch.save(self.accessibility, "{}.accessibility.torch".format(self.name)) best_loss = valid_loss early_stop_count = 0 else: early_stop_count += 1 start = time.time() if early_stopping is not None and early_stop_count >= early_stopping: break self.logger.save("{}.log".format(self.name)) if early_stopping is not None and early_stop_count >= early_stopping: break torch.save(self, "{}.final.torch".format(self.name)) torch.save(self, "{}.accessibility.final.torch".format(self.name)) @classmethod def from_chrombpnet_lite(self, bias_model, accessibility_model, name): """Load a ChromBPNet model trained in ChromBPNet-lite. 
Confusingly, ChromBPNet-lite is a package written by Surag Nair that reorganized the ChromBPNet library and then was reintegrated back into it. However, some ChromBPNet models are still around that were trained using this package and this is a method for loading those models, not the models trained using the ChromBPNet package and not ChromBPNet models trained using this package. This method takes in paths to a h5 file containing the weights of the bias model and the accessibility model, both trained and whose outputs are organized according to TensorFlow. The weights are loaded and shaped into a PyTorch model and can be used as such. Parameters ---------- bias model: str The filename of the bias model. accessibility_model: str The filename of the accessibility model. name: str The name to use when training the model and outputting to a file. Returns ------- model: bpnetlite.models.ChromBPNet A PyTorch ChromBPNet model compatible with the bpnet-lite package. """ bias = BPNet.from_chrombpnet_lite(bias_model) acc = BPNet.from_chrombpnet_lite(accessibility_model) return ChromBPNet(bias, acc, name) @classmethod def from_chrombpnet(self, bias_model, accessibility_model, name): """Load a ChromBPNet model trained using the official repository. This method takes in the path to a .h5 file containing the full model, i.e., the bias model AND the accessibility model. If you have two files -- one for the bias model, and one for the accessibility model -- load them up as separate BPNet models and create a ChromBPNet object afterwards. Parameters ---------- bias model: str The filename of the bias model. accessibility_model: str The filename of the accessibility model. name: str The name to use when training the model and outputting to a file. Returns ------- model: bpnetlite.models.ChromBPNet A PyTorch ChromBPNet model compatible with the bpnet-lite package. 
""" bias = BPNet.from_chrombpnet(bias_model) acc = BPNet.from_chrombpnet(accessibility_model) return ChromBPNet(bias, acc, name) --- bpnetlite/bpnet.py # bpnet.py # Author: Jacob Schreiber """ This module contains a reference implementation of BPNet that can be used or adapted for your own circumstances. The implementation takes in a stranded control track and makes predictions for stranded outputs. """ import h5py import time import numpy import torch from .losses import MNLLLoss from .losses import log1pMSELoss from .losses import _mixture_loss from .performance import pearson_corr from .performance import calculate_performance_measures from .logging import Logger from tqdm import tqdm from tangermeme.predict import predict torch.backends.cudnn.benchmark = True class ControlWrapper(torch.nn.Module): """This wrapper automatically creates a control track of all zeroes. This wrapper will check to see whether the model is expecting a control track (e.g., most BPNet-style models) and will create one with the expected shape. If no control track is expected then it will provide the normal output from the model. """ def __init__(self, model): super(ControlWrapper, self).__init__() self.model = model def forward(self, X, X_ctl=None): if X_ctl != None: return self.model(X, X_ctl) if self.model.n_control_tracks == 0: return self.model(X) X_ctl = torch.zeros(X.shape[0], self.model.n_control_tracks, X.shape[-1], dtype=X.dtype, device=X.device) return self.model(X, X_ctl) class _ProfileLogitScaling(torch.nn.Module): """This ugly class is necessary because of Captum. Captum internally registers classes as linear or non-linear. Because the profile wrapper performs some non-linear operations, those operations must be registered as such. However, the inputs to the wrapper are not the logits that are being modified in a non-linear manner but rather the original sequence that is subsequently run through the model. 
Hence, this object will contain all of the operations performed on the logits and can be registered. Parameters ---------- logits: torch.Tensor, shape=(-1, -1) The logits as they come out of a Chrom/BPNet model. """ def __init__(self): super(_ProfileLogitScaling, self).__init__() self.softmax = torch.nn.Softmax(dim=-1) def forward(self, logits): y_softmax = self.softmax(logits) return logits * y_softmax class ProfileWrapper(torch.nn.Module): """A wrapper class that returns transformed profiles. This class takes in a trained model and returns the weighted softmaxed outputs of the first dimension. Specifically, it takes the predicted "logits" and takes the dot product between them and the softmaxed versions of those logits. This is for convenience when using captum to calculate attribution scores. Parameters ---------- model: torch.nn.Module A torch model to be wrapped. """ def __init__(self, model): super(ProfileWrapper, self).__init__() self.model = model self.flatten = torch.nn.Flatten() self.scaling = _ProfileLogitScaling() def forward(self, X, X_ctl=None, **kwargs): logits = self.model(X, X_ctl, **kwargs)[0] logits = self.flatten(logits) logits = logits - torch.mean(logits, dim=-1, keepdims=True) return self.scaling(logits).sum(dim=-1, keepdims=True) class CountWrapper(torch.nn.Module): """A wrapper class that only returns the predicted counts. This class takes in a trained model and returns only the second output. For BPNet models, this means that it is only returning the count predictions. This is for convenience when using captum to calculate attribution scores. Parameters ---------- model: torch.nn.Module A torch model to be wrapped. """ def __init__(self, model): super(CountWrapper, self).__init__() self.model = model def forward(self, X, X_ctl=None, **kwargs): return self.model(X, X_ctl, **kwargs)[1] class BPNet(torch.nn.Module): """A basic BPNet model with stranded profile and total count prediction. This is a reference implementation for BPNet models. 
It exactly matches the architecture in the official ChromBPNet repository. It is very similar to the implementation in the official basepairmodels repository but differs in when the activation function is applied for the resifual layers. See the BasePairNet object below for an implementation that matches that repository. The model takes in one-hot encoded sequence, runs it through: (1) a single wide convolution operation THEN (2) a user-defined number of dilated residual convolutions THEN (3a) profile predictions done using a very wide convolution layer that also takes in stranded control tracks AND (3b) total count prediction done using an average pooling on the output from 2 followed by concatenation with the log1p of the sum of the stranded control tracks and then run through a dense layer. This implementation differs from the original BPNet implementation in two ways: (1) The model concatenates stranded control tracks for profile prediction as opposed to adding the two strands together and also then smoothing that track (2) The control input for the count prediction task is the log1p of the strand-wise sum of the control tracks, as opposed to the raw counts themselves. (3) A single log softmax is applied across both strands such that the logsumexp of both strands together is 0. Put another way, the two strands are concatenated together, a log softmax is applied, and the MNLL loss is calculated on the concatenation. (4) The count prediction task is predicting the total counts across both strands. The counts are then distributed across strands according to the single log softmax from 3. Parameters ---------- n_filters: int, optional The number of filters to use per convolution. Default is 64. n_layers: int, optional The number of dilated residual layers to include in the model. Default is 8. n_outputs: int, optional The number of profile outputs from the model. Generally either 1 or 2 depending on if the data is unstranded or stranded. Default is 2. 
n_control_tracks: int, optional The number of control tracks to feed into the model. When predicting TFs, this is usually 2. When predicting accessibility, this is usualy 0. When 0, this input is removed from the model. Default is 2. count_loss_weight: float, optional The weight to put on the count loss. profile_output_bias: bool, optional Whether to include a bias term in the final profile convolution. Removing this term can help with attribution stability and will usually not affect performance. Default is True. count_output_bias: bool, optional Whether to include a bias term in the linear layer used to predict counts. Removing this term can help with attribution stability but may affect performance. Default is True. name: str or None, optional The name to save the model to during training. trimming: int or None, optional The amount to trim from both sides of the input window to get the output window. This value is removed from both sides, so the total number of positions removed is 2*trimming. verbose: bool, optional Whether to display statistics during training. Setting this to False will still save the file at the end, but does not print anything to screen during training. Default is True. 
""" def __init__(self, n_filters=64, n_layers=8, n_outputs=2, n_control_tracks=2, count_loss_weight=1, profile_output_bias=True, count_output_bias=True, name=None, trimming=None, verbose=True): super(BPNet, self).__init__() self.n_filters = n_filters self.n_layers = n_layers self.n_outputs = n_outputs self.n_control_tracks = n_control_tracks self.count_loss_weight = count_loss_weight self.name = name or "bpnet.{}.{}".format(n_filters, n_layers) self.trimming = trimming or 47 + sum(2**i for i in range(1, n_layers+1)) self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10) self.irelu = torch.nn.ReLU() self.rconvs = torch.nn.ModuleList([ torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=2**i, dilation=2**i) for i in range(1, self.n_layers+1) ]) self.rrelus = torch.nn.ModuleList([ torch.nn.ReLU() for i in range(1, self.n_layers+1) ]) self.fconv = torch.nn.Conv1d(n_filters+n_control_tracks, n_outputs, kernel_size=75, padding=37, bias=profile_output_bias) n_count_control = 1 if n_control_tracks > 0 else 0 self.linear = torch.nn.Linear(n_filters+n_count_control, 1, bias=count_output_bias) self.logger = Logger(["Epoch", "Iteration", "Training Time", "Validation Time", "Training MNLL", "Training Count MSE", "Validation MNLL", "Validation Profile Pearson", "Validation Count Pearson", "Validation Count MSE", "Saved?"], verbose=verbose) def forward(self, X, X_ctl=None): """A forward pass of the model. This method takes in a nucleotide sequence X, a corresponding per-position value from a control track, and a per-locus value from the control track and makes predictions for the profile and for the counts. This per-locus value is usually the log(sum(X_ctl_profile)+1) when the control is an experimental read track but can also be the output from another model. Parameters ---------- X: torch.tensor, shape=(batch_size, 4, length) The one-hot encoded batch of sequences. 
        X_ctl: torch.tensor or None, shape=(batch_size, n_strands, length)
            A value representing the signal of the control at each position
            in the sequence. If no controls, pass in None. Default is None.

        Returns
        -------
        y_profile: torch.tensor, shape=(batch_size, n_strands, out_length)
            The output predictions for each strand trimmed to the output
            length.
        """

        start, end = self.trimming, X.shape[2] - self.trimming

        X = self.irelu(self.iconv(X))
        for i in range(self.n_layers):
            X_conv = self.rrelus[i](self.rconvs[i](X))
            X = torch.add(X, X_conv)

        if X_ctl is None:
            X_w_ctl = X
        else:
            X_w_ctl = torch.cat([X, X_ctl], dim=1)

        y_profile = self.fconv(X_w_ctl)[:, :, start:end]

        # counts prediction
        X = torch.mean(X[:, :, start-37:end+37], dim=2)
        if X_ctl is not None:
            X_ctl = torch.sum(X_ctl[:, :, start-37:end+37], dim=(1, 2))
            X_ctl = X_ctl.unsqueeze(-1)
            X = torch.cat([X, torch.log(X_ctl+1)], dim=-1)

        y_counts = self.linear(X).reshape(X.shape[0], 1)
        return y_profile, y_counts

    def fit(self, training_data, optimizer, scheduler=None, X_valid=None,
        X_ctl_valid=None, y_valid=None, max_epochs=100, batch_size=64,
        dtype='float32', device='cuda', early_stopping=None):
        """Fit the model to data and validate it periodically.

        This method controls the training of a BPNet model. It will fit the
        model to examples generated by the `training_data` DataLoader object
        and, if validation data is provided, will validate the model against
        it at the end of each epoch and return those values.

        Two versions of the model will be saved: the best model found during
        training according to the validation measures, and the final model at
        the end of training. Additionally, a log will be saved of the
        training and validation statistics, e.g. time and performance.

        Parameters
        ----------
        training_data: torch.utils.data.DataLoader
            A generator that produces examples to train on. If
            n_control_tracks is greater than 0, must produce two inputs,
            otherwise must produce only one input.
        optimizer: torch.optim.Optimizer
            An optimizer to control the training of the model.

        scheduler: torch.optim.lr_scheduler, optional
            An optional learning rate scheduler which changes the learning
            rate across batches. If None, do not use a scheduler. Default is
            None.

        X_valid: torch.tensor or None, shape=(n, 4, 2114)
            A block of sequences to validate on periodically. If None, do not
            perform validation. Default is None.

        X_ctl_valid: torch.tensor or None, shape=(n, n_control_tracks, 2114)
            A block of control sequences to validate on periodically. If
            n_control_tracks is None, pass in None. Default is None.

        y_valid: torch.tensor or None, shape=(n, n_outputs, 1000)
            A block of signals to validate against. Must be provided if
            X_valid is also provided. Default is None.

        max_epochs: int
            The maximum number of epochs to train for, as measured by the
            number of times that `training_data` is exhausted. Default is
            100.

        batch_size: int, optional
            The number of examples to include in each batch. Default is 64.

        dtype: str or torch.dtype
            The torch.dtype to use when training. Usually, this will be
            torch.float32 or torch.bfloat16. Default is torch.float32.

        device: str
            The device to use for training and inference. Typically, this
            will be 'cuda' but can be anything supported by torch. Default
            is 'cuda'.

        early_stopping: int or None, optional
            Whether to stop training early. If None, continue training until
            max_epochs is reached. If an integer, continue training until
            that number of epochs has been hit without improvement in
            performance. Default is None.
        """

        print("Warning: BPNet and ChromBPNet models trained using bpnet-lite "
            "may underperform those trained using the official repositories. "
            "See the GitHub README for further documentation.")

        if X_valid is not None:
            y_valid_counts = y_valid.sum(dim=2)

        if X_ctl_valid is not None:
            X_ctl_valid = (X_ctl_valid,)

        dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        iteration = 0
        early_stop_count = 0
        best_loss = float("inf")

        self.logger.start()
        for epoch in range(max_epochs):
            tic = time.time()

            for data in training_data:
                X, y, labels = data[0], data[-2], data[-1]
                X_ctl = data[1].to(device) if len(data) == 4 else None

                X = X.to(device).float()
                y = y.to(device)

                # Clear the optimizer and set the model to training mode
                optimizer.zero_grad()
                self.train()

                # Make one training step
                with torch.autocast(device_type=device, dtype=dtype):
                    y_hat_logits, y_hat_logcounts = self(X, X_ctl)
                    training_profile_loss_, training_count_loss_, loss = \
                        _mixture_loss(y, y_hat_logits, y_hat_logcounts,
                            self.count_loss_weight, labels)

                loss.backward()
                #torch.nn.utils.clip_grad_norm_(self.parameters(), 0.5)
                optimizer.step()
                iteration += 1

            train_time = time.time() - tic

            # Validate the model at the end of the epoch
            with torch.no_grad():
                self.eval()
                tic = time.time()

                y_hat_logits, y_hat_logcounts = predict(self, X_valid,
                    args=X_ctl_valid, batch_size=batch_size, dtype=dtype,
                    device=device)

                valid_profile_loss, valid_count_loss, valid_loss = \
                    _mixture_loss(y_valid, y_hat_logits, y_hat_logcounts,
                        self.count_loss_weight)

                measures = calculate_performance_measures(y_hat_logits,
                    y_valid, y_hat_logcounts, kernel_sigma=7,
                    kernel_width=81,
                    measures=['profile_pearson', 'count_pearson'])

                valid_profile_corr = numpy.nan_to_num(
                    measures['profile_pearson'])
                valid_count_corr = numpy.nan_to_num(
                    measures['count_pearson'])

                valid_time = time.time() - tic

                self.logger.add([epoch, iteration, train_time, valid_time,
                    training_profile_loss_, training_count_loss_,
                    valid_profile_loss, valid_profile_corr.mean(),
                    valid_count_corr.mean(), valid_count_loss,
                    (valid_loss < best_loss).item()])
                self.logger.save("{}.log".format(self.name))

                if valid_loss < best_loss:
                    torch.save(self, "{}.torch".format(self.name))
                    best_loss = valid_loss
                    early_stop_count = -1

            if scheduler is not None:
                scheduler.step(valid_loss)

            early_stop_count += 1
            if early_stopping is not None and \
                early_stop_count >= early_stopping:
                break

        torch.save(self, "{}.final.torch".format(self.name))

    @classmethod
    def from_chrombpnet_lite(cls, filename):
        """Loads a model from ChromBPNet-lite TensorFlow format.

        This method will load a ChromBPNet-lite model from TensorFlow format.
        Note that this is not the same as ChromBPNet format. Specifically,
        ChromBPNet-lite was a preceding package that had a slightly different
        saving format, whereas ChromBPNet is the packaged version of that
        code that is applied at scale.

        This method does not load the entire ChromBPNet model. If that is the
        desired behavior, see the `ChromBPNet` object and its associated
        loading functions. Instead, this loads a single BPNet model -- either
        the bias model or the accessibility model, depending on what is
        encoded in the stored file.

        Parameters
        ----------
        filename: str
            The name of the h5 file that stores the trained model parameters.

        Returns
        -------
        model: BPNet
            A BPNet model compatible with this repository in PyTorch.
        """
        h5 = h5py.File(filename, "r")
        w = h5['model_weights']

        if 'model_1' in w.keys():
            w = w['model_1']
            bias = False
        else:
            bias = True

        k, b = 'kernel:0', 'bias:0'
        name = "conv1d_{}_1" if not bias else "conv1d_{0}/conv1d_{0}"

        layer_names = []
        for layer_name in w.keys():
            try:
                idx = int(layer_name.split("_")[1])
                layer_names.append(idx)
            except:
                pass

        n_filters = w[name.format(1)][k].shape[2]
        n_layers = max(layer_names) - 2

        model = BPNet(n_layers=n_layers, n_filters=n_filters, n_outputs=1,
            n_control_tracks=0, trimming=None)

        convert_w = lambda x: torch.nn.Parameter(torch.tensor(
            x[:]).permute(2, 1, 0))
        convert_b = lambda x: torch.nn.Parameter(torch.tensor(x[:]))

        model.iconv.weight = convert_w(w[name.format(1)][k])
        model.iconv.bias = convert_b(w[name.format(1)][b])
        model.iconv.padding = 12

        for i in range(2, n_layers+2):
            model.rconvs[i-2].weight = convert_w(w[name.format(i)][k])
            model.rconvs[i-2].bias = convert_b(w[name.format(i)][b])

        model.fconv.weight = convert_w(w[name.format(n_layers+2)][k])
        model.fconv.bias = convert_b(w[name.format(n_layers+2)][b])
        model.fconv.padding = 12

        name = "logcounts_1" if not bias else "logcounts/logcounts"
        model.linear.weight = torch.nn.Parameter(torch.tensor(w[name][k][:].T))
        model.linear.bias = convert_b(w[name][b])
        return model

    @classmethod
    def from_chrombpnet(cls, filename):
        """Loads a model from ChromBPNet TensorFlow format.

        This method will load one of the components of a ChromBPNet model
        from TensorFlow format. Note that a full ChromBPNet model is made up
        of an accessibility model and a bias model and that this will load
        one of the two. Use `ChromBPNet.from_chrombpnet` to end up with the
        entire ChromBPNet model.

        Parameters
        ----------
        filename: str
            The name of the h5 file that stores the trained model parameters.

        Returns
        -------
        model: BPNet
            A BPNet model compatible with this repository in PyTorch.
        """
        h5 = h5py.File(filename, "r")
        w = h5['model_weights']

        if 'bpnet_1conv' in w.keys():
            prefix = ""
        else:
            prefix = "wo_bias_"

        namer = lambda prefix, suffix: '{0}{1}/{0}{1}'.format(prefix, suffix)
        k, b = 'kernel:0', 'bias:0'

        n_layers = 0
        for layer_name in w.keys():
            try:
                idx = int(layer_name.split("_")[-1].replace("conv", ""))
                n_layers = max(n_layers, idx)
            except:
                pass

        name = namer(prefix, "bpnet_1conv")
        n_filters = w[name][k].shape[2]

        model = BPNet(n_layers=n_layers, n_filters=n_filters, n_outputs=1,
            n_control_tracks=0)

        convert_w = lambda x: torch.nn.Parameter(torch.tensor(
            x[:]).permute(2, 1, 0))
        convert_b = lambda x: torch.nn.Parameter(torch.tensor(x[:]))

        iname = namer(prefix, 'bpnet_1st_conv')
        model.iconv.weight = convert_w(w[iname][k])
        model.iconv.bias = convert_b(w[iname][b])
        model.iconv.padding = ((21 - 1) // 2,)

        for i in range(1, n_layers+1):
            lname = namer(prefix, 'bpnet_{}conv'.format(i))
            model.rconvs[i-1].weight = convert_w(w[lname][k])
            model.rconvs[i-1].bias = convert_b(w[lname][b])

        prefix = prefix + "bpnet_" if prefix != "" else ""

        fname = namer(prefix, 'prof_out_precrop')
        model.fconv.weight = convert_w(w[fname][k])
        model.fconv.bias = convert_b(w[fname][b])
        model.fconv.padding = ((75 - 1) // 2,)

        name = namer(prefix, "logcount_predictions")
        model.linear.weight = torch.nn.Parameter(torch.tensor(w[name][k][:].T))
        model.linear.bias = convert_b(w[name][b])
        return model


class BasePairNet(torch.nn.Module):
    """A BPNet implementation matching that in basepairmodels.

    This is a BPNet implementation that matches the one in basepairmodels and
    can be used to load models trained from that repository, e.g., those
    trained as part of the atlas project. The architecture of the model is
    identical to `BPNet` except that output from the residual layers is added
    to the pre-activation outputs from the previous layer, rather than to the
    post-activation outputs from the previous layer.
    Additionally, the count prediction head takes the sum of the control
    track counts, adds two instead of one, and then takes the log. Neither
    detail dramatically changes performance of the model, but both are
    necessary to account for when loading trained models.

    Parameters
    ----------
    n_filters: int, optional
        The number of filters to use per convolution. Default is 64.

    n_layers: int, optional
        The number of dilated residual layers to include in the model.
        Default is 8.

    n_outputs: int, optional
        The number of profile outputs from the model. Generally either 1 or
        2 depending on if the data is unstranded or stranded. Default is 2.

    n_control_tracks: int, optional
        The number of control tracks to feed into the model. When predicting
        TFs, this is usually 2. When predicting accessibility, this is
        usually 0. When 0, this input is removed from the model. Default is
        2.

    count_loss_weight: float, optional
        The weight to put on the count loss. Default is 1.

    profile_output_bias: bool, optional
        Whether to include a bias term in the final profile convolution.
        Removing this term can help with attribution stability and will
        usually not affect performance. Default is True.

    count_output_bias: bool, optional
        Whether to include a bias term in the linear layer used to predict
        counts. Removing this term can help with attribution stability but
        may affect performance. Default is True.

    name: str or None, optional
        The name to save the model to during training.

    trimming: int or None, optional
        The amount to trim from both sides of the input window to get the
        output window. This value is removed from both sides, so the total
        number of positions removed is 2*trimming.

    verbose: bool, optional
        Whether to display statistics during training. Setting this to False
        will still save the file at the end, but does not print anything to
        screen during training. Default is True.
    """

    def __init__(self, n_filters=64, n_layers=8, n_outputs=2,
        n_control_tracks=2, count_loss_weight=1, profile_output_bias=True,
        count_output_bias=True, name=None, trimming=None, verbose=True):
        super(BasePairNet, self).__init__()
        self.n_filters = n_filters
        self.n_layers = n_layers
        self.n_outputs = n_outputs
        self.n_control_tracks = n_control_tracks
        self.count_loss_weight = count_loss_weight

        self.name = name or "bpnet.{}.{}".format(n_filters, n_layers)
        self.trimming = trimming or 2 ** n_layers

        self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10)
        self.irelu = torch.nn.ReLU()

        self.rconvs = torch.nn.ModuleList([
            torch.nn.Conv1d(n_filters, n_filters, kernel_size=3,
                padding=2**i, dilation=2**i)
            for i in range(1, self.n_layers+1)
        ])
        self.rrelus = torch.nn.ModuleList([
            torch.nn.ReLU() for i in range(1, self.n_layers+1)
        ])

        self.fconv = torch.nn.Conv1d(n_filters+n_control_tracks, n_outputs,
            kernel_size=75, padding=37, bias=profile_output_bias)

        n_count_control = 1 if n_control_tracks > 0 else 0
        self.linear = torch.nn.Linear(n_filters+n_count_control, 1,
            bias=count_output_bias)

        self.logger = Logger(["Epoch", "Iteration", "Training Time",
            "Validation Time", "Training MNLL", "Training Count MSE",
            "Validation MNLL", "Validation Profile Pearson",
            "Validation Count Pearson", "Validation Count MSE", "Saved?"],
            verbose=verbose)

    def forward(self, X, X_ctl=None):
        """A forward pass of the model.

        This method takes in a nucleotide sequence X, a corresponding
        per-position value from a control track, and a per-locus value from
        the control track and makes predictions for the profile and for the
        counts. This per-locus value is usually the log(sum(X_ctl_profile)+1)
        when the control is an experimental read track but can also be the
        output from another model.

        Parameters
        ----------
        X: torch.tensor, shape=(batch_size, 4, length)
            The one-hot encoded batch of sequences.
        X_ctl: torch.tensor or None, shape=(batch_size, n_strands, length)
            A value representing the signal of the control at each position
            in the sequence. If no controls, pass in None. Default is None.

        Returns
        -------
        y_profile: torch.tensor, shape=(batch_size, n_strands, out_length)
            The output predictions for each strand trimmed to the output
            length.
        """

        start, end = self.trimming, X.shape[2] - self.trimming

        X = self.iconv(X)
        for i in range(self.n_layers):
            X_a = self.rrelus[i](X)
            X_conv = self.rconvs[i](X_a)
            X = torch.add(X, X_conv)
        X = self.irelu(X)

        if X_ctl is None:
            X_w_ctl = X
        else:
            X_w_ctl = torch.cat([X, X_ctl], dim=1)

        y_profile = self.fconv(X_w_ctl)[:, :, start:end]

        # counts prediction
        X = torch.mean(X[:, :, start-37:end+37], dim=2)
        if X_ctl is not None:
            X_ctl = torch.sum(X_ctl[:, :, start:end], dim=(1, 2))
            X_ctl = X_ctl.unsqueeze(-1)
            X = torch.cat([X, torch.log(X_ctl+2)], dim=-1)

        y_counts = self.linear(X).reshape(X.shape[0], 1)
        return y_profile, y_counts

    @classmethod
    def from_bpnet(cls, filename):
        """Loads a model from BPNet TensorFlow format.

        This method will allow you to load a BPNet model from the
        basepairmodels repo that has been saved in TensorFlow format. You do
        not need to have TensorFlow installed to use this function. The
        result will be a model whose predictions and attributions are
        identical to those produced when using the TensorFlow code.

        Parameters
        ----------
        filename: str
            The name of the h5 file that stores the trained model parameters.

        Returns
        -------
        model: BPNet
            A BPNet model compatible with this repository in PyTorch.
        """
        h5 = h5py.File(filename, "r")
        w, k, b = h5['model_weights'], 'kernel:0', 'bias:0'

        extract = lambda name, suffix: w['{0}/{0}/{1}'.format(name,
            suffix)][:]
        convert_w = lambda x: torch.nn.Parameter(torch.tensor(x).permute(
            2, 1, 0))
        convert_b = lambda x: torch.nn.Parameter(torch.tensor(x))

        n_layers, n_filters = 0, extract("main_conv_0", k).shape[2]
        for layer_name in w.keys():
            if 'main_dil_conv' in layer_name:
                n_layers = max(n_layers, int(layer_name.split("_")[-1]))

        model = cls(n_layers=n_layers, n_filters=n_filters, n_outputs=2,
            n_control_tracks=2, trimming=(2114-1000)//2)

        model.iconv.weight = convert_w(extract("main_conv_0", k))
        model.iconv.bias = convert_b(extract("main_conv_0", b))
        model.iconv.padding = ((model.iconv.weight.shape[-1] - 1) // 2,)

        for i in range(1, n_layers+1):
            lname = "main_dil_conv_{}".format(i)
            model.rconvs[i-1].weight = convert_w(extract(lname, k))
            model.rconvs[i-1].bias = convert_b(extract(lname, b))

        w0 = model.fconv.weight.numpy(force=True)
        wph = extract("main_profile_head", k)
        wpp = extract("profile_predictions", k)[0, :2]

        conv_weight = numpy.zeros_like(w0.transpose(2, 1, 0))
        conv_weight[:, :n_filters] = wph.dot(wpp)
        conv_weight[37, n_filters:] = extract("profile_predictions", k)[0, 2:]

        model.fconv.weight = convert_w(conv_weight)
        model.fconv.bias = (convert_b(extract("main_profile_head", b) +
            extract("profile_predictions", b)))
        model.fconv.padding = ((model.fconv.weight.shape[-1] - 1) // 2,)

        linear_weight = numpy.zeros_like(model.linear.weight.numpy(
            force=True))
        linear_weight[:, :n_filters] = (extract("main_counts_head", k).T *
            extract("logcounts_predictions", k)[0])
        linear_weight[:, -1] = extract("logcounts_predictions", k)[1]

        model.linear.weight = convert_b(linear_weight)
        model.linear.bias = (convert_b(extract("main_counts_head", b) *
            extract("logcounts_predictions", k)[0] +
            extract("logcounts_predictions", b)))
        return model


--- bpnetlite/attribute.py

# attribute.py
# Author: Jacob Schreiber

from bpnetlite.bpnet import _ProfileLogitScaling
from bpnetlite.chrombpnet import _Log, _Exp

from tangermeme.ersatz import dinucleotide_shuffle

from tangermeme.deep_lift_shap import deep_lift_shap as t_deep_lift_shap
from tangermeme.deep_lift_shap import _nonlinear


def deep_lift_shap(model, X, args=None, target=0, batch_size=32,
    references=dinucleotide_shuffle, n_shuffles=20, return_references=False,
    hypothetical=False, warning_threshold=0.001,
    additional_nonlinear_ops=None, print_convergence_deltas=False,
    raw_outputs=False, device='cuda', random_state=None, verbose=False):
    """A wrapper that registers Chrom/BPNet's custom non-linearities.

    This function is just a wrapper for tangermeme's deep_lift_shap function,
    except that it automatically registers the layers that are necessary for
    using BPNet models. Specifically, it registers a scaling that is
    necessary for calculating the profile attributions and also registers
    the logsumexp operation for counts when using the full ChromBPNet model.

    Other than automatically registering the non-linearities, this wrapper
    does not modify the tangermeme outputs or alter the inputs in any way.
    It is simply for convenience so you do not need to reach into
    bpnet-lite's internals each time you want to calculate attributions.

    Parameters
    ----------
    model: torch.nn.Module
        A PyTorch model to use for making predictions. These models can take
        in any number of inputs and make any number of outputs. The
        additional inputs must be specified in the `args` parameter.

    X: torch.tensor, shape=(-1, len(alphabet), length)
        A set of one-hot encoded sequences to calculate attribution values
        for.

    args: tuple or None, optional
        An optional set of additional arguments to pass into the model. If
        provided, each element in the tuple or list is one input to the
        model and the element must be formatted to be the same batch size as
        `X`. If None, no additional arguments are passed into the forward
        function. Default is None.
    target: int, optional
        The output of the model to calculate gradients/attributions for.
        This will index the last dimension of the predictions. Default is 0.

    batch_size: int, optional
        The number of sequence-reference pairs to pass through DeepLiftShap
        at a time. Importantly, this is not the number of elements in `X`
        that are processed simultaneously (alongside ALL their references)
        but the total number of `X`-`reference` pairs that are processed.
        This means that, if you are in a memory-limited setting where you
        cannot process all references for even a single sequence
        simultaneously, the work is broken down into doing only a few
        references at a time. Default is 32.

    references: func or torch.Tensor, optional
        If a function is passed in, this function is applied to each
        sequence with the provided random state and number of shuffles. This
        function should serve to transform a sequence into some form of
        signal-null background, such as by shuffling it. If a torch.Tensor
        is passed in, that tensor must have shape
        `(len(X), n_shuffles, *X.shape[1:])`, in that for each sequence a
        number of shuffles are provided. Default is the function
        `dinucleotide_shuffle`.

    n_shuffles: int, optional
        The number of shuffles to use if a function is given for
        `references`. If a torch.Tensor is provided, this number is ignored.
        Default is 20.

    return_references: bool, optional
        Whether to return the references that were generated during this
        process. Only use if `references` is not a torch.Tensor. Default is
        False.

    hypothetical: bool, optional
        Whether to return attributions for all possible characters at each
        position or only for the character that is actually in the sequence.
        Practically, whether to multiply the returned attributions from
        captum with the one-hot encoded sequence. Default is False.

    warning_threshold: float, optional
        A threshold on the convergence delta that will always raise a
        warning if the delta is larger than it. Normal deltas are in the
        range of 1e-6 to 1e-8.
        Note that convergence deltas are calculated on the gradients prior
        to the aggr_func being applied to them. Default is 0.001.

    additional_nonlinear_ops: dict or None, optional
        If additional nonlinear ops need to be added to the dictionary of
        operations that can be handled by DeepLIFT/SHAP, pass a dictionary
        here where the keys are class types and the values are the functions
        that handle that sort of class. Make sure that the signature matches
        those of `_nonlinear` and `_maxpool` above. This can also be used to
        overwrite the hard-coded operations by passing in a dictionary with
        overlapping key names. If None, do not add any additional
        operations. Default is None.

    print_convergence_deltas: bool, optional
        Whether to print the convergence deltas for each example when using
        DeepLiftShap. Default is False.

    raw_outputs: bool, optional
        Whether to return the raw outputs from the method -- in this case,
        the multipliers for each example-reference pair -- or the processed
        attribution values. Default is False.

    device: str or torch.device, optional
        The device to move the model and batches to when making predictions.
        If set to 'cuda' without a GPU, this function will crash and must be
        set to 'cpu'. Default is 'cuda'.

    random_state: int or None or numpy.random.RandomState, optional
        The random seed to use to ensure determinism. If None, the process
        is not deterministic. Default is None.

    verbose: bool, optional
        Whether to display a progress bar. Default is False.

    Returns
    -------
    attributions: torch.tensor
        If `raw_outputs=False` (default), the attribution values with shape
        equal to `X`. If `raw_outputs=True`, the multipliers for each
        example-reference pair with shape equal to
        `(X.shape[0], n_shuffles, X.shape[1], X.shape[2])`.

    references: torch.tensor, optional
        The references used for each input sequence, with the shape
        (n_input_sequences, n_shuffles, 4, length). Only returned if
        `return_references = True`.
    """

    # Register the BPNet-specific non-linearities, merging in any additional
    # operations the user passed so that the parameter is not silently
    # ignored.
    nonlinear_ops = {
        _ProfileLogitScaling: _nonlinear,
        _Log: _nonlinear,
        _Exp: _nonlinear
    }
    if additional_nonlinear_ops is not None:
        nonlinear_ops.update(additional_nonlinear_ops)

    return t_deep_lift_shap(model=model, X=X, args=args, target=target,
        batch_size=batch_size, references=references, n_shuffles=n_shuffles,
        return_references=return_references, hypothetical=hypothetical,
        warning_threshold=warning_threshold,
        additional_nonlinear_ops=nonlinear_ops,
        print_convergence_deltas=print_convergence_deltas,
        raw_outputs=raw_outputs, device=device, random_state=random_state,
        verbose=verbose)
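The window arithmetic used throughout these files (the default `trimming` in `BPNet.__init__` and the 2114-input / 1000-output convention mentioned in the `fit` docstring) can be sanity-checked with a small stand-alone sketch; the helper name below is illustrative and not part of bpnet-lite:

```python
# Stand-alone sketch (illustrative helper, not part of the package): checks
# that the default BPNet trimming reproduces the 2114 -> 1000 window
# convention used in the docstrings above.

def default_trimming(n_layers=8):
    # Mirrors the expression in BPNet.__init__: 47 positions from the
    # half-widths of the first (21-wide) and final (75-wide) convolutions
    # (10 + 37), plus the dilated receptive field 2^1 + ... + 2^n_layers.
    return 47 + sum(2**i for i in range(1, n_layers + 1))

trim = default_trimming(8)
print(trim)             # 47 + 510 = 557
print(2114 - 2 * trim)  # input 2114 -> output 1000
```

With the default eight dilated layers, the 2114-bp input window is trimmed by 557 on each side, giving exactly the 1000-bp output profile the validation tensors expect.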