Affinity Distillation¶

In [1]:
import random
import glob
import h5py
import numpy as np
import pandas as pd
from pyfaidx import Fasta
from tangermeme.utils import one_hot_encode, characters
from tangermeme.ersatz import substitute, dinucleotide_shuffle
from Bio.Seq import Seq
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from adjustText import adjust_text
from sklearn.linear_model import RANSACRegressor

from tensorflow.keras.models import load_model
# Need to import to avoid an error
from bpnet.model.arch import *
/oak/stanford/groups/akundaje/shouvikm/miniconda3/envs/bpnet/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
2026-01-02 02:46:02.757727: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
In [2]:
def affinity_distillation(motif, model, fold=0):
    model_file = f"/users/shouvikm/BPNet/models/{model}/fold_{fold}/model_split000"
    bpnet = load_model(model_file)

    # TODO: change for each fold
    test_chromosomes = ["chr1", "chr3", "chr6"]
    peak_file = glob.glob(f"/users/shouvikm/data/Cell_Lines_Ped_Leukemia/Reh/*/{model}/idr_peaks/peaks_inliers.bed")[0]
    peaks_df = pd.read_csv(peak_file, sep='\t', header=None)
    held_out_peaks_df = peaks_df[peaks_df[0].isin(test_chromosomes)]

    input_seq_len = 2114
    num_marginalization_samples = 100
    genome = Fasta("/users/shouvikm/BPNet/data/hg38.genome.fa")
    nucleotides = ['A', 'T', 'C', 'G']

    seed = 0
    random_sequences = []
    motif_inserted_sequences = []
    while(len(random_sequences) < num_marginalization_samples):
        random.seed(seed)

        random_peak_index = random.randint(0, held_out_peaks_df.shape[0]-1)
        random_peak = held_out_peaks_df.iloc[random_peak_index]
        chr, start, end = random_peak[0], random_peak[1], random_peak[2]
        peak_center = start + ((end - start) // 2)
        random_seq_start = peak_center - (input_seq_len // 2)
        random_seq_end = peak_center + (input_seq_len // 2)
        random_seq = genome[chr][random_seq_start:random_seq_end].seq

        all_valid_nucleotides = all([letter in nucleotides for letter in random_seq])
        if not all_valid_nucleotides: 
            seed += 1
            continue

        random_seq_one_hot = one_hot_encode(random_seq).unsqueeze(0)
        random_seq_shuf_one_hot = dinucleotide_shuffle(random_seq_one_hot, n=1, random_state=seed)
        random_seq_bpnet_input = random_seq_shuf_one_hot[0].numpy().transpose((0, 2, 1))\
                                    .astype("float32")
        random_sequences.append(random_seq_bpnet_input)
        
        motif_inserted_seq_one_hot = substitute(random_seq_shuf_one_hot[0], motif)
        motif_inserted_seq_bpnet_input = motif_inserted_seq_one_hot.numpy().transpose((0, 2, 1))\
                                            .astype("float32")
        motif_inserted_sequences.append(motif_inserted_seq_bpnet_input)
        
        seed += 1
    
    random_sequences = np.vstack(random_sequences)
    motif_inserted_sequences = np.vstack(motif_inserted_sequences)
        
    profile_bias = np.zeros((num_marginalization_samples, 1000, 2), dtype=np.float32)
    counts_bias = np.zeros((num_marginalization_samples, 2), dtype=np.float32)
    _, pred_logcounts_random_sequences = bpnet.predict([random_sequences, 
                                                        profile_bias, 
                                                        counts_bias])
    _, pred_logcounts_motif_inserted_sequences = bpnet.predict([motif_inserted_sequences, 
                                                                profile_bias, 
                                                                counts_bias])

    log_fold_changes = pred_logcounts_motif_inserted_sequences - pred_logcounts_random_sequences
    marginalization_score = log_fold_changes.mean()
    return marginalization_score
In [3]:
def generate_all_1bp_substitutions(core_motif):
    motif_mutations = [core_motif]
    nucleotides = ['A', 'C', 'G', 'T']
    for i, motif_nucleotide in enumerate(core_motif):
        for possible_nucleotide in nucleotides:
            if motif_nucleotide is not possible_nucleotide:
                motif_mutated = (core_motif[:i] + possible_nucleotide + core_motif[i+1:])
                motif_mutations.append(motif_mutated)
    return motif_mutations

def generate_all_2bp_substitutions(core_motif):
    len_core_motif = len(core_motif)
    motif_mutations = [core_motif]
    nucleotides = ['A', 'C', 'G', 'T']
    for i, motif_nucleotide_i in enumerate(core_motif):
        for j in range(i+1, len_core_motif):
            motif_nucleotide_j = core_motif[j]
            for possible_nucleotide_i in nucleotides:
                for possible_nucleotide_j in nucleotides:
                    if (possible_nucleotide_i is not motif_nucleotide_i and
                        possible_nucleotide_j is not motif_nucleotide_j):
                        motif_mutated = (core_motif[:i] + 
                                         possible_nucleotide_i + 
                                         core_motif[i+1:j] + 
                                         possible_nucleotide_j + 
                                         core_motif[j+1:])
                        motif_mutations.append(motif_mutated)
    return motif_mutations
In [4]:
core_motif = 'TGTGGT'
motif_mutations_1bp = generate_all_1bp_substitutions(core_motif)
motif_mutations_1bp
Out[4]:
['TGTGGT',
 'AGTGGT',
 'CGTGGT',
 'GGTGGT',
 'TATGGT',
 'TCTGGT',
 'TTTGGT',
 'TGAGGT',
 'TGCGGT',
 'TGGGGT',
 'TGTAGT',
 'TGTCGT',
 'TGTTGT',
 'TGTGAT',
 'TGTGCT',
 'TGTGTT',
 'TGTGGA',
 'TGTGGC',
 'TGTGGG']
In [5]:
core_motif = 'TGTGGT'
motif_mutations_2bp = generate_all_2bp_substitutions(core_motif)
motif_mutations_2bp
Out[5]:
['TGTGGT',
 'AATGGT',
 'ACTGGT',
 'ATTGGT',
 'CATGGT',
 'CCTGGT',
 'CTTGGT',
 'GATGGT',
 'GCTGGT',
 'GTTGGT',
 'AGAGGT',
 'AGCGGT',
 'AGGGGT',
 'CGAGGT',
 'CGCGGT',
 'CGGGGT',
 'GGAGGT',
 'GGCGGT',
 'GGGGGT',
 'AGTAGT',
 'AGTCGT',
 'AGTTGT',
 'CGTAGT',
 'CGTCGT',
 'CGTTGT',
 'GGTAGT',
 'GGTCGT',
 'GGTTGT',
 'AGTGAT',
 'AGTGCT',
 'AGTGTT',
 'CGTGAT',
 'CGTGCT',
 'CGTGTT',
 'GGTGAT',
 'GGTGCT',
 'GGTGTT',
 'AGTGGA',
 'AGTGGC',
 'AGTGGG',
 'CGTGGA',
 'CGTGGC',
 'CGTGGG',
 'GGTGGA',
 'GGTGGC',
 'GGTGGG',
 'TAAGGT',
 'TACGGT',
 'TAGGGT',
 'TCAGGT',
 'TCCGGT',
 'TCGGGT',
 'TTAGGT',
 'TTCGGT',
 'TTGGGT',
 'TATAGT',
 'TATCGT',
 'TATTGT',
 'TCTAGT',
 'TCTCGT',
 'TCTTGT',
 'TTTAGT',
 'TTTCGT',
 'TTTTGT',
 'TATGAT',
 'TATGCT',
 'TATGTT',
 'TCTGAT',
 'TCTGCT',
 'TCTGTT',
 'TTTGAT',
 'TTTGCT',
 'TTTGTT',
 'TATGGA',
 'TATGGC',
 'TATGGG',
 'TCTGGA',
 'TCTGGC',
 'TCTGGG',
 'TTTGGA',
 'TTTGGC',
 'TTTGGG',
 'TGAAGT',
 'TGACGT',
 'TGATGT',
 'TGCAGT',
 'TGCCGT',
 'TGCTGT',
 'TGGAGT',
 'TGGCGT',
 'TGGTGT',
 'TGAGAT',
 'TGAGCT',
 'TGAGTT',
 'TGCGAT',
 'TGCGCT',
 'TGCGTT',
 'TGGGAT',
 'TGGGCT',
 'TGGGTT',
 'TGAGGA',
 'TGAGGC',
 'TGAGGG',
 'TGCGGA',
 'TGCGGC',
 'TGCGGG',
 'TGGGGA',
 'TGGGGC',
 'TGGGGG',
 'TGTAAT',
 'TGTACT',
 'TGTATT',
 'TGTCAT',
 'TGTCCT',
 'TGTCTT',
 'TGTTAT',
 'TGTTCT',
 'TGTTTT',
 'TGTAGA',
 'TGTAGC',
 'TGTAGG',
 'TGTCGA',
 'TGTCGC',
 'TGTCGG',
 'TGTTGA',
 'TGTTGC',
 'TGTTGG',
 'TGTGAA',
 'TGTGAC',
 'TGTGAG',
 'TGTGCA',
 'TGTGCC',
 'TGTGCG',
 'TGTGTA',
 'TGTGTC',
 'TGTGTG']
In [6]:
len(set(motif_mutations_1bp)), len(set(motif_mutations_2bp))
Out[6]:
(19, 136)
In [7]:
set(motif_mutations_1bp) & set(motif_mutations_2bp)
Out[7]:
{'TGTGGT'}
In [8]:
motif_mutations = set(motif_mutations_1bp + motif_mutations_2bp)
len(motif_mutations)
Out[8]:
154
In [9]:
models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA', 'Reh_DMSO_RUNX1', 
          'REH_ETV6-RUNX1_rep1', 'REH_ETV6-RUNX1_rep2', 'Reh_ETV6_Atlas',
          'Reh_RUNX1_B', 'REH_RUNX1_rep1', 'REH_RUNX1_rep2']

motif_affinities = []
for i, motif in enumerate(motif_mutations):
    if i%10==0: print(i)
    affinities = []
    for model in models:
        affinity = affinity_distillation(motif, model)
        affinities.append(affinity)
    motif_affinities.append(affinities)

motif_affinities_df = pd.DataFrame(motif_affinities)
motif_affinities_df.columns = models
motif_affinities_df.index = motif_mutations
motif_affinities_df
0
2025-12-31 07:38:30.842255: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2025-12-31 07:38:30.842396: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2025-12-31 07:38:30.914500: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:2a:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2025-12-31 07:38:30.915031: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 1 with properties: 
pciBusID: 0000:3d:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2025-12-31 07:38:30.915540: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 2 with properties: 
pciBusID: 0000:ab:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2025-12-31 07:38:30.915931: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 3 with properties: 
pciBusID: 0000:bd:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2025-12-31 07:38:30.915985: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2025-12-31 07:38:30.916101: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2025-12-31 07:38:30.916167: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2025-12-31 07:38:30.922975: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
2025-12-31 07:38:30.927381: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
2025-12-31 07:38:30.936735: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.10
2025-12-31 07:38:30.942177: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.11
2025-12-31 07:38:30.942307: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
2025-12-31 07:38:30.945625: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1862] Adding visible gpu devices: 0, 1, 2, 3
2025-12-31 07:38:30.946749: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-12-31 07:38:30.954679: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
2025-12-31 07:38:31.773727: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:2a:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2025-12-31 07:38:31.774014: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 1 with properties: 
pciBusID: 0000:3d:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2025-12-31 07:38:31.774267: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 2 with properties: 
pciBusID: 0000:ab:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2025-12-31 07:38:31.774513: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 3 with properties: 
pciBusID: 0000:bd:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2025-12-31 07:38:31.774578: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2025-12-31 07:38:31.774597: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2025-12-31 07:38:31.774614: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2025-12-31 07:38:31.774651: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
2025-12-31 07:38:31.774673: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
2025-12-31 07:38:31.774694: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.10
2025-12-31 07:38:31.774715: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.11
2025-12-31 07:38:31.774731: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
2025-12-31 07:38:31.776565: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1862] Adding visible gpu devices: 0, 1, 2, 3
2025-12-31 07:38:31.776618: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2025-12-31 07:38:39.365201: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1261] Device interconnect StreamExecutor with strength 1 edge matrix:
2025-12-31 07:38:39.365231: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1267]      0 1 2 3 
2025-12-31 07:38:39.365235: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1280] 0:   N Y Y Y 
2025-12-31 07:38:39.365239: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1280] 1:   Y N Y Y 
2025-12-31 07:38:39.365241: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1280] 2:   Y Y N Y 
2025-12-31 07:38:39.365243: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1280] 3:   Y Y Y N 
2025-12-31 07:38:39.366487: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1406] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 33956 MB memory) -> physical GPU (device: 0, name: NVIDIA L40S, pci bus id: 0000:2a:00.0, compute capability: 8.9)
2025-12-31 07:38:39.367385: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1406] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 3679 MB memory) -> physical GPU (device: 1, name: NVIDIA L40S, pci bus id: 0000:3d:00.0, compute capability: 8.9)
2025-12-31 07:38:39.368000: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1406] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:2 with 24366 MB memory) -> physical GPU (device: 2, name: NVIDIA L40S, pci bus id: 0000:ab:00.0, compute capability: 8.9)
2025-12-31 07:38:39.368618: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1406] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:3 with 24366 MB memory) -> physical GPU (device: 3, name: NVIDIA L40S, pci bus id: 0000:bd:00.0, compute capability: 8.9)
2025-12-31 07:38:41.050506: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2025-12-31 07:38:41.050841: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 3600000000 Hz
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
Out[9]:
Reh_HPA Reh_PAS Reh_DMSO_HPA Reh_DMSO_RUNX1 REH_ETV6-RUNX1_rep1 REH_ETV6-RUNX1_rep2 Reh_ETV6_Atlas Reh_RUNX1_B REH_RUNX1_rep1 REH_RUNX1_rep2
CGTGGA -0.000212 -0.002838 -0.013116 -0.004291 0.001467 0.000104 -0.003589 0.006039 0.000440 -0.000896
TGTAGA -0.004476 -0.006378 -0.019743 -0.006555 -0.000499 -0.005409 -0.007663 -0.012253 -0.002754 -0.012684
ATTGGT -0.001626 -0.001803 -0.009646 -0.003104 -0.001095 -0.001675 0.001616 -0.009277 0.001535 0.005572
TCTGCT -0.004606 -0.000835 -0.009488 0.001543 -0.000567 0.003377 0.005039 -0.000926 0.000868 0.002461
AATGGT -0.000532 0.003069 -0.005507 -0.002775 -0.000613 0.004116 -0.003518 -0.007429 -0.000510 0.004722
... ... ... ... ... ... ... ... ... ... ...
TGGGGG 0.005018 0.007617 -0.000674 0.002684 0.002752 0.011724 0.016001 0.002306 0.008107 0.025045
TATGGT 0.005461 0.013459 -0.004942 0.001546 -0.000926 0.014691 0.000959 -0.003847 0.001954 0.020968
GGTAGT 0.012945 0.009405 0.011137 0.005504 0.000353 0.012853 0.011881 0.006217 0.004101 0.022718
AGTGGG 0.005152 0.008557 0.007419 0.000260 0.001790 0.014683 0.007167 0.004002 0.003688 0.013900
TATGTT -0.003757 -0.002668 -0.017644 -0.004485 -0.002120 -0.004898 -0.001480 -0.013669 -0.001131 -0.001744

154 rows × 10 columns

In [10]:
motif_affinities_df.to_csv('affinity_distillation/motif_affinities_TGTGGT.csv')
In [9]:
models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA', 'REH_ETV6-RUNX1_rep2', 
          'Reh_ETV6_Atlas','Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 'REH_RUNX1_rep2']

fusion_ER_models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA','REH_ETV6-RUNX1_rep2', 
                    'Reh_ETV6_Atlas']
native_ER_models = ['Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 'REH_RUNX1_rep2']
In [12]:
motif_affinities_df = pd.read_csv('affinity_distillation/motif_affinities_TGTGGT.csv', 
                                  index_col='Unnamed: 0')
motif_affinities_df = motif_affinities_df[models]
# multi_index_columns = pd.MultiIndex.from_tuples(
#     [('Fusion ETV6-RUNX1 models', model) for model in fusion_ER_models] + 
#     [('Native RUNX1 models', model) for model in native_ER_models]
# )
# motif_affinities_df.columns = multi_index_columns
motif_affinities_df
Out[12]:
Reh_HPA Reh_PAS Reh_DMSO_HPA REH_ETV6-RUNX1_rep2 Reh_ETV6_Atlas Reh_DMSO_RUNX1 Reh_RUNX1_B REH_RUNX1_rep2
CGTGGA -0.000212 -0.002838 -0.013116 0.000104 -0.003589 -0.004291 0.006039 -0.000896
TGTAGA -0.004476 -0.006378 -0.019743 -0.005409 -0.007663 -0.006555 -0.012253 -0.012684
ATTGGT -0.001626 -0.001803 -0.009646 -0.001675 0.001616 -0.003104 -0.009277 0.005572
TCTGCT -0.004606 -0.000835 -0.009488 0.003377 0.005039 0.001543 -0.000926 0.002461
AATGGT -0.000532 0.003069 -0.005507 0.004116 -0.003518 -0.002775 -0.007429 0.004722
... ... ... ... ... ... ... ... ...
TGGGGG 0.005018 0.007617 -0.000674 0.011724 0.016001 0.002684 0.002306 0.025045
TATGGT 0.005461 0.013459 -0.004942 0.014691 0.000959 0.001546 -0.003847 0.020968
GGTAGT 0.012945 0.009405 0.011137 0.012853 0.011881 0.005504 0.006217 0.022718
AGTGGG 0.005152 0.008557 0.007419 0.014683 0.007167 0.000260 0.004002 0.013900
TATGTT -0.003757 -0.002668 -0.017644 -0.004898 -0.001480 -0.004485 -0.013669 -0.001744

154 rows × 8 columns

In [6]:
def create_pairwise_affinities_scatter_plot(model_1, model_2, ax, size_points_by_residuals=True, 
                                            point_size_scale_residual=10000,
                                            label_points_residual_percentile=95):
    ransac_regression = RANSACRegressor(random_state=0)
    ransac_regression.fit(motif_affinities_df[model_1].values.reshape(-1, 1), 
                          motif_affinities_df[model_2].values)
    y_pred = ransac_regression.predict(motif_affinities_df[model_1].values.reshape(-1, 1))
    residuals = (motif_affinities_df[model_2] - y_pred) ** 2
    ax.scatter(motif_affinities_df[model_1], motif_affinities_df[model_2], 
                s=point_size_scale_residual*residuals if size_points_by_residuals else 10)
    residuals_nth_percentile = np.percentile(residuals, label_points_residual_percentile)
    texts = []
    for i, motif in enumerate(motif_affinities_df.index):
        if residuals[i] >= residuals_nth_percentile:
            texts.append(ax.text(motif_affinities_df[model_1][i], 
                                 motif_affinities_df[model_2][i], 
                                 motif, ha='center', va='center'))
    adjust_text(texts, ax=ax)
    ax.axline((0, ransac_regression.estimator_.intercept_), slope=ransac_regression.estimator_.coef_[0], 
               color='tab:blue')
    ax.axline((0, 0), slope=1, color='gray', linestyle='--')
    ax.set_xlabel(model_1)
    ax.set_ylabel(model_2)
In [7]:
model_1 = 'Reh_HPA'
model_2 = 'Reh_PAS'
fig, ax = plt.subplots(figsize=(5, 5))
create_pairwise_affinities_scatter_plot(model_1, model_2, ax)
No description has been provided for this image
In [8]:
model_1 = 'Reh_HPA'
model_2 = 'REH_RUNX1_rep2'
fig, ax = plt.subplots(figsize=(5, 5))
create_pairwise_affinities_scatter_plot(model_1, model_2, ax)
No description has been provided for this image
In [9]:
model_1 = 'Reh_HPA'
model_2 = 'Reh_RUNX1_B'
fig, ax = plt.subplots(figsize=(5, 5))
create_pairwise_affinities_scatter_plot(model_1, model_2, ax)
No description has been provided for this image
In [10]:
model_1 = 'REH_RUNX1_rep2'
model_2 = 'Reh_RUNX1_B'
fig, ax = plt.subplots(figsize=(5, 5))
create_pairwise_affinities_scatter_plot(model_1, model_2, ax)
No description has been provided for this image
In [13]:
fusion_ER_models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA','REH_ETV6-RUNX1_rep2', 
                    'Reh_ETV6_Atlas']
native_ER_models = ['Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 'REH_RUNX1_rep2']


fig, axs = plt.subplots(3, 5, figsize=(20, 10))
for row, model_1 in enumerate(native_ER_models):
    for col, model_2 in enumerate(fusion_ER_models):
        create_pairwise_affinities_scatter_plot(model_2, model_1, axs[row, col])
plt.tight_layout()
No description has been provided for this image
In [13]:
def calculate_differential_affinity_motifs(model, background_models, motif_affinities_df):
    ransac_regression = RANSACRegressor(random_state=0)
    ransac_regression.fit(motif_affinities_df[background_models].values, 
                          motif_affinities_df[model].values)
    y_pred = ransac_regression.predict(motif_affinities_df[background_models].values)
    residuals = motif_affinities_df[model] - y_pred
    residuals_scaled = (residuals - residuals.mean()) / residuals.std()
    differential_affinity_motifs = residuals_scaled.sort_values(ascending=False)
    return differential_affinity_motifs
In [14]:
model = 'Reh_HPA'
background_models = native_ER_models
differential_affinity_motifs = calculate_differential_affinity_motifs(model, background_models, 
                                                                      motif_affinities_df)
differential_affinity_motifs.head(10)
Out[14]:
TGTCGT    2.333437
TGTTGT    2.064227
GGTTGT    1.949198
AGTTGT    1.650539
TGTCGG    1.641491
TGACGT    1.594752
TGTAGT    1.485391
TGTTTT    1.434448
TTTCGT    1.383797
TGTCGC    1.345671
Name: Reh_HPA, dtype: float64
In [15]:
model = 'Reh_RUNX1_B'
background_models = fusion_ER_models
differential_affinity_motifs = calculate_differential_affinity_motifs(model, background_models, 
                                                                      motif_affinities_df)
differential_affinity_motifs.head(10)
Out[15]:
CGTGCT    1.803505
TACGGT    1.763948
TGCGAT    1.547879
TGCGCT    1.428850
TGCCGT    1.405616
TGCGTT    1.286048
TGGCGT    1.230265
TGTCGC    1.185657
CGTCGT    1.109339
TGTGCC    1.106918
Name: Reh_RUNX1_B, dtype: float64
In [16]:
fusion_ER_models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA','REH_ETV6-RUNX1_rep2', 
                    'Reh_ETV6_Atlas']
native_ER_models = ['Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 'REH_RUNX1_rep2']

model_diff_affinity_motifs_df = pd.DataFrame()
for model in fusion_ER_models:
    differential_affinity_motifs = calculate_differential_affinity_motifs(model, native_ER_models, 
                                                                          motif_affinities_df)
    differential_affinity_motifs = (differential_affinity_motifs.index + ' (' + 
                                    differential_affinity_motifs.round(2).values.astype(str) + ')')
    model_diff_affinity_motifs_df[model] = differential_affinity_motifs[:10]
    
for model in native_ER_models:
    differential_affinity_motifs = calculate_differential_affinity_motifs(model, fusion_ER_models, 
                                                                          motif_affinities_df)
    differential_affinity_motifs = (differential_affinity_motifs.index + ' (' + 
                                    differential_affinity_motifs.round(2).values.astype(str) + ')')
    model_diff_affinity_motifs_df[model] = differential_affinity_motifs[:10]

multi_index_columns = pd.MultiIndex.from_tuples(
    [('Fusion ETV6-RUNX1 models', model + ' (score)') for model in fusion_ER_models] + 
    [('Native RUNX1 models', model + ' (score)') for model in native_ER_models]
)
model_diff_affinity_motifs_df.columns = multi_index_columns
model_diff_affinity_motifs_df
Out[16]:
Fusion ETV6-RUNX1 models Native RUNX1 models
Reh_HPA (score) Reh_PAS (score) Reh_DMSO_HPA (score) REH_ETV6-RUNX1_rep2 (score) Reh_ETV6_Atlas (score) Reh_DMSO_RUNX1 (score) Reh_RUNX1_B (score) REH_RUNX1_rep2 (score)
0 TGTCGT (2.33) TGTCGT (2.49) AGTGGT (3.88) TGCGGT (7.81) GTTGGT (1.42) TGCGGT (9.77) CGTGCT (1.8) CGCGGT (5.82)
1 TGTTGT (2.06) TGCCGT (2.0) TGTGGG (2.45) TGTGGT (7.16) TGTTGG (1.23) TGTGGT (5.16) TACGGT (1.76) TGCGGT (3.15)
2 GGTTGT (1.95) TGTCGC (1.45) AGCGGT (2.24) CGCGGT (2.06) TGTCGG (1.19) CGCGGT (3.05) TGCGAT (1.55) CGTGGT (2.75)
3 AGTTGT (1.65) TGTCGG (1.31) GGTGGT (1.9) AGCGGT (1.82) TTTGGG (1.1) TGCGGC (2.08) TGCGCT (1.43) TGCGGG (1.87)
4 TGTCGG (1.64) TACGGT (1.14) TGAGGT (1.71) CGTGGT (1.61) TTTGTT (1.05) AGCGGT (1.41) TGCCGT (1.41) TGAGGT (1.86)
5 TGACGT (1.59) TGTTGT (1.04) AGTGGG (1.71) AGTGGT (1.55) TGTTTT (0.93) TGCGGG (1.3) TGCGTT (1.29) TTAGGT (1.53)
6 TGTAGT (1.49) AGTCGT (1.02) TGTGGC (1.59) TGCGGC (1.55) GGGGGT (0.91) GGCGGT (1.06) TGGCGT (1.23) TACGGT (1.51)
7 TGTTTT (1.43) TATGGT (1.01) AGTGGC (1.52) TGTGGC (1.44) TTTGGC (0.9) TCCGGT (0.78) TGTCGC (1.19) TGAGGC (1.44)
8 TTTCGT (1.38) AGTGGC (0.92) AGAGGT (1.22) TGTGGG (1.36) TCTGTT (0.89) TGCCGT (0.65) CGTCGT (1.11) TAAGGT (1.41)
9 TGTCGC (1.35) TGTGAT (0.89) GGTAGT (1.22) GGCGGT (1.33) TGTTGT (0.77) TGCGCT (0.57) TGTGCC (1.11) CGAGGT (1.4)
In [12]:
core_motif = 'TGTGGTTT'
motif_mutations_1bp = generate_all_1bp_substitutions(core_motif)
motif_mutations_1bp
Out[12]:
['TGTGGTTT',
 'AGTGGTTT',
 'CGTGGTTT',
 'GGTGGTTT',
 'TATGGTTT',
 'TCTGGTTT',
 'TTTGGTTT',
 'TGAGGTTT',
 'TGCGGTTT',
 'TGGGGTTT',
 'TGTAGTTT',
 'TGTCGTTT',
 'TGTTGTTT',
 'TGTGATTT',
 'TGTGCTTT',
 'TGTGTTTT',
 'TGTGGATT',
 'TGTGGCTT',
 'TGTGGGTT',
 'TGTGGTAT',
 'TGTGGTCT',
 'TGTGGTGT',
 'TGTGGTTA',
 'TGTGGTTC',
 'TGTGGTTG']
In [13]:
core_motif = 'TGTGGTTT'
motif_mutations_2bp = generate_all_2bp_substitutions(core_motif)
motif_mutations_2bp
Out[13]:
['TGTGGTTT',
 'AATGGTTT',
 'ACTGGTTT',
 'ATTGGTTT',
 'CATGGTTT',
 'CCTGGTTT',
 'CTTGGTTT',
 'GATGGTTT',
 'GCTGGTTT',
 'GTTGGTTT',
 'AGAGGTTT',
 'AGCGGTTT',
 'AGGGGTTT',
 'CGAGGTTT',
 'CGCGGTTT',
 'CGGGGTTT',
 'GGAGGTTT',
 'GGCGGTTT',
 'GGGGGTTT',
 'AGTAGTTT',
 'AGTCGTTT',
 'AGTTGTTT',
 'CGTAGTTT',
 'CGTCGTTT',
 'CGTTGTTT',
 'GGTAGTTT',
 'GGTCGTTT',
 'GGTTGTTT',
 'AGTGATTT',
 'AGTGCTTT',
 'AGTGTTTT',
 'CGTGATTT',
 'CGTGCTTT',
 'CGTGTTTT',
 'GGTGATTT',
 'GGTGCTTT',
 'GGTGTTTT',
 'AGTGGATT',
 'AGTGGCTT',
 'AGTGGGTT',
 'CGTGGATT',
 'CGTGGCTT',
 'CGTGGGTT',
 'GGTGGATT',
 'GGTGGCTT',
 'GGTGGGTT',
 'AGTGGTAT',
 'AGTGGTCT',
 'AGTGGTGT',
 'CGTGGTAT',
 'CGTGGTCT',
 'CGTGGTGT',
 'GGTGGTAT',
 'GGTGGTCT',
 'GGTGGTGT',
 'AGTGGTTA',
 'AGTGGTTC',
 'AGTGGTTG',
 'CGTGGTTA',
 'CGTGGTTC',
 'CGTGGTTG',
 'GGTGGTTA',
 'GGTGGTTC',
 'GGTGGTTG',
 'TAAGGTTT',
 'TACGGTTT',
 'TAGGGTTT',
 'TCAGGTTT',
 'TCCGGTTT',
 'TCGGGTTT',
 'TTAGGTTT',
 'TTCGGTTT',
 'TTGGGTTT',
 'TATAGTTT',
 'TATCGTTT',
 'TATTGTTT',
 'TCTAGTTT',
 'TCTCGTTT',
 'TCTTGTTT',
 'TTTAGTTT',
 'TTTCGTTT',
 'TTTTGTTT',
 'TATGATTT',
 'TATGCTTT',
 'TATGTTTT',
 'TCTGATTT',
 'TCTGCTTT',
 'TCTGTTTT',
 'TTTGATTT',
 'TTTGCTTT',
 'TTTGTTTT',
 'TATGGATT',
 'TATGGCTT',
 'TATGGGTT',
 'TCTGGATT',
 'TCTGGCTT',
 'TCTGGGTT',
 'TTTGGATT',
 'TTTGGCTT',
 'TTTGGGTT',
 'TATGGTAT',
 'TATGGTCT',
 'TATGGTGT',
 'TCTGGTAT',
 'TCTGGTCT',
 'TCTGGTGT',
 'TTTGGTAT',
 'TTTGGTCT',
 'TTTGGTGT',
 'TATGGTTA',
 'TATGGTTC',
 'TATGGTTG',
 'TCTGGTTA',
 'TCTGGTTC',
 'TCTGGTTG',
 'TTTGGTTA',
 'TTTGGTTC',
 'TTTGGTTG',
 'TGAAGTTT',
 'TGACGTTT',
 'TGATGTTT',
 'TGCAGTTT',
 'TGCCGTTT',
 'TGCTGTTT',
 'TGGAGTTT',
 'TGGCGTTT',
 'TGGTGTTT',
 'TGAGATTT',
 'TGAGCTTT',
 'TGAGTTTT',
 'TGCGATTT',
 'TGCGCTTT',
 'TGCGTTTT',
 'TGGGATTT',
 'TGGGCTTT',
 'TGGGTTTT',
 'TGAGGATT',
 'TGAGGCTT',
 'TGAGGGTT',
 'TGCGGATT',
 'TGCGGCTT',
 'TGCGGGTT',
 'TGGGGATT',
 'TGGGGCTT',
 'TGGGGGTT',
 'TGAGGTAT',
 'TGAGGTCT',
 'TGAGGTGT',
 'TGCGGTAT',
 'TGCGGTCT',
 'TGCGGTGT',
 'TGGGGTAT',
 'TGGGGTCT',
 'TGGGGTGT',
 'TGAGGTTA',
 'TGAGGTTC',
 'TGAGGTTG',
 'TGCGGTTA',
 'TGCGGTTC',
 'TGCGGTTG',
 'TGGGGTTA',
 'TGGGGTTC',
 'TGGGGTTG',
 'TGTAATTT',
 'TGTACTTT',
 'TGTATTTT',
 'TGTCATTT',
 'TGTCCTTT',
 'TGTCTTTT',
 'TGTTATTT',
 'TGTTCTTT',
 'TGTTTTTT',
 'TGTAGATT',
 'TGTAGCTT',
 'TGTAGGTT',
 'TGTCGATT',
 'TGTCGCTT',
 'TGTCGGTT',
 'TGTTGATT',
 'TGTTGCTT',
 'TGTTGGTT',
 'TGTAGTAT',
 'TGTAGTCT',
 'TGTAGTGT',
 'TGTCGTAT',
 'TGTCGTCT',
 'TGTCGTGT',
 'TGTTGTAT',
 'TGTTGTCT',
 'TGTTGTGT',
 'TGTAGTTA',
 'TGTAGTTC',
 'TGTAGTTG',
 'TGTCGTTA',
 'TGTCGTTC',
 'TGTCGTTG',
 'TGTTGTTA',
 'TGTTGTTC',
 'TGTTGTTG',
 'TGTGAATT',
 'TGTGACTT',
 'TGTGAGTT',
 'TGTGCATT',
 'TGTGCCTT',
 'TGTGCGTT',
 'TGTGTATT',
 'TGTGTCTT',
 'TGTGTGTT',
 'TGTGATAT',
 'TGTGATCT',
 'TGTGATGT',
 'TGTGCTAT',
 'TGTGCTCT',
 'TGTGCTGT',
 'TGTGTTAT',
 'TGTGTTCT',
 'TGTGTTGT',
 'TGTGATTA',
 'TGTGATTC',
 'TGTGATTG',
 'TGTGCTTA',
 'TGTGCTTC',
 'TGTGCTTG',
 'TGTGTTTA',
 'TGTGTTTC',
 'TGTGTTTG',
 'TGTGGAAT',
 'TGTGGACT',
 'TGTGGAGT',
 'TGTGGCAT',
 'TGTGGCCT',
 'TGTGGCGT',
 'TGTGGGAT',
 'TGTGGGCT',
 'TGTGGGGT',
 'TGTGGATA',
 'TGTGGATC',
 'TGTGGATG',
 'TGTGGCTA',
 'TGTGGCTC',
 'TGTGGCTG',
 'TGTGGGTA',
 'TGTGGGTC',
 'TGTGGGTG',
 'TGTGGTAA',
 'TGTGGTAC',
 'TGTGGTAG',
 'TGTGGTCA',
 'TGTGGTCC',
 'TGTGGTCG',
 'TGTGGTGA',
 'TGTGGTGC',
 'TGTGGTGG']
In [14]:
len(set(motif_mutations_1bp)), len(set(motif_mutations_2bp))
Out[14]:
(25, 253)
In [15]:
set(motif_mutations_1bp) & set(motif_mutations_2bp)
Out[15]:
{'TGTGGTTT'}
In [16]:
motif_mutations = set(motif_mutations_1bp + motif_mutations_2bp)
len(motif_mutations)
Out[16]:
277
In [17]:
models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA', 'Reh_DMSO_RUNX1', 
          'REH_ETV6-RUNX1_rep1', 'REH_ETV6-RUNX1_rep2', 'Reh_ETV6_Atlas',
          'Reh_RUNX1_B', 'REH_RUNX1_rep1', 'REH_RUNX1_rep2']

motif_affinities = []
for i, motif in enumerate(motif_mutations):
    if i%10==0: print(i)
    affinities = []
    for model in models:
        affinity = affinity_distillation(motif, model)
        affinities.append(affinity)
    motif_affinities.append(affinities)

motif_affinities_df = pd.DataFrame(motif_affinities)
motif_affinities_df.columns = models
motif_affinities_df.index = motif_mutations
motif_affinities_df
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
Out[17]:
Reh_HPA Reh_PAS Reh_DMSO_HPA Reh_DMSO_RUNX1 REH_ETV6-RUNX1_rep1 REH_ETV6-RUNX1_rep2 Reh_ETV6_Atlas Reh_RUNX1_B REH_RUNX1_rep1 REH_RUNX1_rep2
TTCGGTTT 0.026151 0.031984 0.087783 0.041492 -0.001437 0.094285 0.033748 0.018817 0.013151 0.094560
TTTGGCTT 0.008133 0.009144 0.022339 0.004334 -0.001524 0.018462 0.022594 -0.005896 0.007657 0.029627
TATGGATT -0.003708 -0.003552 -0.017911 -0.004375 -0.002005 -0.002423 -0.007832 -0.017566 -0.003377 -0.010938
TGCGTTTT 0.017434 0.023183 0.039789 0.025254 -0.001360 0.043630 0.024720 0.014077 0.011221 0.025533
TGCTGTTT 0.016018 0.022195 0.020867 0.028237 -0.001536 0.022974 0.029029 0.006019 0.010374 0.020505
... ... ... ... ... ... ... ... ... ... ...
TGTGGTGA 0.024082 0.039744 0.174464 0.132503 0.000720 0.163614 0.020788 0.091373 0.008147 0.191925
GGGGGTTT 0.015742 0.021883 0.023532 0.008895 0.001181 0.047843 0.043184 -0.001390 0.016098 0.049890
TATGATTT 0.000564 -0.001104 -0.018721 -0.003216 -0.003076 -0.005605 -0.006579 -0.020833 -0.004913 -0.011075
TGTAGGTT 0.011757 -0.000436 -0.006819 -0.001689 -0.000696 0.004141 0.015032 -0.009839 0.005587 0.008547
CGTGGTAT 0.010704 0.009325 0.026374 0.000786 0.000110 0.041922 0.001014 0.018507 0.002891 0.058079

277 rows × 10 columns

In [18]:
motif_affinities_df.to_csv('affinity_distillation/motif_affinities_TGTGGTTT.csv')
In [17]:
motif_affinities_df = pd.read_csv('affinity_distillation/motif_affinities_TGTGGTTT.csv', 
                                  index_col='Unnamed: 0')
motif_affinities_df = motif_affinities_df[models]
motif_affinities_df
Out[17]:
Reh_HPA Reh_PAS Reh_DMSO_HPA REH_ETV6-RUNX1_rep2 Reh_ETV6_Atlas Reh_DMSO_RUNX1 Reh_RUNX1_B REH_RUNX1_rep2
TTCGGTTT 0.026151 0.031984 0.087783 0.094285 0.033748 0.041492 0.018817 0.094560
TTTGGCTT 0.008133 0.009144 0.022339 0.018462 0.022594 0.004334 -0.005896 0.029627
TATGGATT -0.003708 -0.003552 -0.017911 -0.002423 -0.007832 -0.004375 -0.017566 -0.010938
TGCGTTTT 0.017434 0.023183 0.039789 0.043630 0.024720 0.025254 0.014077 0.025533
TGCTGTTT 0.016018 0.022195 0.020867 0.022974 0.029029 0.028237 0.006019 0.020505
... ... ... ... ... ... ... ... ...
TGTGGTGA 0.024082 0.039744 0.174464 0.163614 0.020788 0.132503 0.091373 0.191925
GGGGGTTT 0.015742 0.021883 0.023532 0.047843 0.043184 0.008895 -0.001390 0.049890
TATGATTT 0.000564 -0.001104 -0.018721 -0.005605 -0.006579 -0.003216 -0.020833 -0.011075
TGTAGGTT 0.011757 -0.000436 -0.006819 0.004141 0.015032 -0.001689 -0.009839 0.008547
CGTGGTAT 0.010704 0.009325 0.026374 0.041922 0.001014 0.000786 0.018507 0.058079

277 rows × 8 columns

In [19]:
model_1 = 'Reh_HPA'
model_2 = 'Reh_RUNX1_B'
fig, ax = plt.subplots(figsize=(5, 5))
create_pairwise_affinities_scatter_plot(model_1, model_2, ax, 
                                        point_size_scale_residual=1000)
No description has been provided for this image
In [21]:
fusion_ER_models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA','REH_ETV6-RUNX1_rep2', 
                    'Reh_ETV6_Atlas']
native_ER_models = ['Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 'REH_RUNX1_rep2']


fig, axs = plt.subplots(3, 5, figsize=(20, 10))
for row, model_1 in enumerate(native_ER_models):
    for col, model_2 in enumerate(fusion_ER_models):
        create_pairwise_affinities_scatter_plot(model_2, model_1, axs[row, col],
                                                point_size_scale_residual=1000)
plt.tight_layout()
No description has been provided for this image
In [18]:
model = 'Reh_HPA'
background_models = native_ER_models
differential_affinity_motifs = calculate_differential_affinity_motifs(model, background_models, 
                                                                      motif_affinities_df)
differential_affinity_motifs.head(10)
Out[18]:
TGTCGTTT    2.396878
TGTCGGTT    1.953531
TGCCGTTT    1.618807
TTTCGTTT    1.511114
TGTTGTTT    1.413762
TGTCGTTA    1.406167
TTTGGTTT    1.375332
TGTAGTTT    1.320371
TGACGTTT    1.273735
TGTTGGTT    1.182333
Name: Reh_HPA, dtype: float64
In [19]:
model = 'Reh_RUNX1_B'
background_models = fusion_ER_models
differential_affinity_motifs = calculate_differential_affinity_motifs(model, background_models, 
                                                                      motif_affinities_df)
differential_affinity_motifs.head(10)
Out[19]:
TGCGGTTA    6.924182
TGCGGTTG    6.398754
TGCGGTTT    4.996113
TGCGGTCT    4.936024
TGTGGTCG    3.882444
TGTGGTCA    3.263083
TGTGGTTA    3.092863
TGTGGTTG    3.013317
TGCGGTTC    3.001735
TGCGGTGT    2.543941
Name: Reh_RUNX1_B, dtype: float64
In [20]:
fusion_ER_models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA','REH_ETV6-RUNX1_rep2', 
                    'Reh_ETV6_Atlas']
native_ER_models = ['Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 'REH_RUNX1_rep2']

model_diff_affinity_motifs_df = pd.DataFrame()
for model in fusion_ER_models:
    differential_affinity_motifs = calculate_differential_affinity_motifs(model, native_ER_models, 
                                                                          motif_affinities_df)
    differential_affinity_motifs = (differential_affinity_motifs.index + ' (' + 
                                    differential_affinity_motifs.round(2).values.astype(str) + ')')
    model_diff_affinity_motifs_df[model] = differential_affinity_motifs[:10]
    
for model in native_ER_models:
    differential_affinity_motifs = calculate_differential_affinity_motifs(model, fusion_ER_models, 
                                                                          motif_affinities_df)
    differential_affinity_motifs = (differential_affinity_motifs.index + ' (' + 
                                    differential_affinity_motifs.round(2).values.astype(str) + ')')
    model_diff_affinity_motifs_df[model] = differential_affinity_motifs[:10]

multi_index_columns = pd.MultiIndex.from_tuples(
    [('Fusion ETV6-RUNX1 models', model + ' (score)') for model in fusion_ER_models] + 
    [('Native RUNX1 models', model + ' (score)') for model in native_ER_models]
)
model_diff_affinity_motifs_df.columns = multi_index_columns
model_diff_affinity_motifs_df
Out[20]:
Fusion ETV6-RUNX1 models Native RUNX1 models
Reh_HPA (score) Reh_PAS (score) Reh_DMSO_HPA (score) REH_ETV6-RUNX1_rep2 (score) Reh_ETV6_Atlas (score) Reh_DMSO_RUNX1 (score) Reh_RUNX1_B (score) REH_RUNX1_rep2 (score)
0 TGTCGTTT (2.4) TGTCGTTT (0.84) AGTGGTTT (4.98) TGTGGTTT (6.93) TGTTGGTT (0.77) TGCGGTTA (7.02) TGCGGTTA (6.92) TGCGGTTA (6.17)
1 TGTCGGTT (1.95) TGTCGTTA (0.69) AGCGGTTT (3.7) TGCGGTTT (5.09) TGTTGTTG (0.7) TGCGGTTT (6.46) TGCGGTTG (6.4) CGTGGTTA (3.58)
2 TGCCGTTT (1.62) TGTCGTAT (0.68) GGTGGTTT (3.4) TGTGGTTA (4.18) TGTCGGTT (0.7) TGCGGTTG (5.74) TGCGGTTT (5.0) TGCGGTTT (3.16)
3 TTTCGTTT (1.51) TGTCGTCT (0.65) TGAGGTTT (3.35) TGAGGTTT (3.86) TGTTGTCT (0.69) TGTGGTTA (5.51) TGCGGTCT (4.94) CGTGGTTG (3.08)
4 TGTTGTTT (1.41) TGTCGTTC (0.65) TGCGGCTT (2.94) TGCGGTTG (3.75) TGGGGGTT (0.68) TGTGGTTT (4.77) TGTGGTCG (3.88) TGCGGTGT (2.93)
5 TGTCGTTA (1.41) TGTCGTTG (0.65) TGTGGCTT (2.52) TGCGGTTC (3.44) TGTCGTCT (0.65) TGCGGTCT (4.04) TGTGGTCA (3.26) CGTGGTTC (2.86)
6 TTTGGTTT (1.38) TGTCGCTT (0.63) CGTGGTTT (2.2) GGCGGTTT (2.88) TGTCGTTG (0.63) TGTGGTTG (3.55) TGTGGTTA (3.09) TGCGGGTT (2.61)
7 TGTAGTTT (1.32) TATCGTTT (0.62) AGTGGTTG (2.07) TGTGGTTG (2.76) CGTTGTTT (0.63) TGTGGTCA (3.1) TGTGGTTG (3.01) TGCGGTAT (2.61)
8 TGACGTTT (1.27) TGTCGTGT (0.56) AGTGGCTT (2.03) CGCGGTTT (2.74) TGGCGTTT (0.62) TGCGGTTC (3.01) TGCGGTTC (3.0) CGTGGTGT (2.54)
9 TGTTGGTT (1.18) TATGGTAT (0.55) TCTGGTTT (2.02) CGTGGTTT (2.65) TGTTGTGT (0.62) TGTGGTCG (2.09) TGCGGTGT (2.54) CGCGGTTT (2.48)
In [23]:
def generate_random_kmers(motif_string, n_kmers):
    random_kmers = set()
    nucleotides = ['A', 'C', 'G', 'T']
    seed = 0
    while len(random_kmers) < n_kmers:
        random_kmer_nucleotides = []
        for letter in motif_string:
            if letter in nucleotides:
                random_kmer_nucleotides.append(letter)
            elif letter == 'X':
                random.seed(seed)
                random_nucleotide = random.choice(nucleotides)
                random_kmer_nucleotides.append(random_nucleotide)
                seed += 1
        random_kmer = ''.join(random_kmer_nucleotides)
        random_kmers.add(random_kmer)
    return random_kmers
In [24]:
# core_motif = 'TGTGGT'
# motif_mutations_1bp = generate_all_1bp_substitutions(core_motif)
# motif_mutations_2bp = generate_all_2bp_substitutions(core_motif)
# motif_mutations = set(motif_mutations_1bp + motif_mutations_2bp)
# len(motif_mutations)
In [25]:
# motif_mutations_with_flanks_all = []
# for motif_mutation in motif_mutations:
#     motif_mutation_template = f'XXXXX{motif_mutation}XXXXX'
#     motif_mutations_with_flanks = generate_random_kmers(motif_mutation_template, n_kmers=20)
#     motif_mutations_with_flanks_all.extend(list(motif_mutations_with_flanks))

# len(set(motif_mutations_with_flanks_all))
In [84]:
core_motif = 'TGCGGT'

motif_mutation_template = f'XXXXX{core_motif}XXXXX'
motif_mutations_with_flanks_all = generate_random_kmers(motif_mutation_template, n_kmers=5000)
motif_mutations_with_flanks_all = list(motif_mutations_with_flanks_all)
len(set(motif_mutations_with_flanks_all))
Out[84]:
5000
In [86]:
models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA', 'Reh_DMSO_RUNX1', 
          'REH_ETV6-RUNX1_rep1', 'REH_ETV6-RUNX1_rep2', 'Reh_ETV6_Atlas',
          'Reh_RUNX1_B', 'REH_RUNX1_rep1', 'REH_RUNX1_rep2']
output_file_name = 'affinity_distillation/motif_affinities_TGCGGT_with_flanks.csv'

motif_affinities = []
for i, motif in enumerate(motif_mutations_with_flanks_all):
    if i%10==0: print(i)
    affinities = []
    for model in models:
        affinity = affinity_distillation(motif, model)
        affinities.append(affinity)
    motif_affinities.append(affinities)
    if i%100==0:
        motif_affinities_df = pd.DataFrame(motif_affinities)
        motif_affinities_df.columns = models
        motif_affinities_df.index = motif_mutations_with_flanks_all[:i+1]
        motif_affinities_df.to_csv(output_file_name)

motif_affinities_df = pd.DataFrame(motif_affinities)
motif_affinities_df.columns = models
motif_affinities_df.index = motif_mutations_with_flanks_all
motif_affinities_df.to_csv(output_file_name)
motif_affinities_df
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
---------------------------------------------------------------------------
PermissionError                           Traceback (most recent call last)
/oak/stanford/groups/akundaje/shouvikm/miniconda3/envs/bpnet/lib/python3.7/site-packages/pyfaidx/__init__.py in read_fai(self)

/oak/stanford/groups/akundaje/shouvikm/miniconda3/envs/bpnet/lib/python3.7/site-packages/pyfaidx/__init__.py in _open_fai(self, mode)

PermissionError: [Errno 13] Permission denied: '/users/shouvikm/BPNet/data/hg38.genome.fa.fai'

During handling of the above exception, another exception occurred:

IndexNotFoundError                        Traceback (most recent call last)
/tmp/ipykernel_2116228/1407708525.py in <module>
      8     affinities = []
      9     for model in models:
---> 10         affinity = affinity_distillation(motif, model)
     11         affinities.append(affinity)
     12     motif_affinities.append(affinities)

/tmp/ipykernel_2116228/796358718.py in affinity_distillation(motif, model, fold)
     11     input_seq_len = 2114
     12     num_marginalization_samples = 100
---> 13     genome = Fasta("/users/shouvikm/BPNet/data/hg38.genome.fa")
     14     nucleotides = ['A', 'T', 'C', 'G']
     15 

/oak/stanford/groups/akundaje/shouvikm/miniconda3/envs/bpnet/lib/python3.7/site-packages/pyfaidx/__init__.py in __init__(self, filename, indexname, default_seq, key_function, as_raw, strict_bounds, read_ahead, mutable, split_char, filt_function, one_based_attributes, read_long_names, duplicate_action, sequence_always_upper, rebuild, build_index, gzi_indexname)

/oak/stanford/groups/akundaje/shouvikm/miniconda3/envs/bpnet/lib/python3.7/site-packages/pyfaidx/__init__.py in __init__(self, filename, indexname, default_seq, key_function, as_raw, strict_bounds, read_ahead, mutable, split_char, duplicate_action, filt_function, one_based_attributes, read_long_names, sequence_always_upper, rebuild, build_index, gzi_indexname)

/oak/stanford/groups/akundaje/shouvikm/miniconda3/envs/bpnet/lib/python3.7/site-packages/pyfaidx/__init__.py in read_fai(self)

IndexNotFoundError: Could not read index file /users/shouvikm/BPNet/data/hg38.genome.fa.fai
In [21]:
motif_affinities_df = pd.read_csv('affinity_distillation/motif_affinities_TGCGGT_with_flanks.csv', 
                                  index_col='Unnamed: 0')
motif_affinities_df = motif_affinities_df[models]
motif_affinities_df
Out[21]:
Reh_HPA Reh_PAS Reh_DMSO_HPA REH_ETV6-RUNX1_rep2 Reh_ETV6_Atlas Reh_DMSO_RUNX1 Reh_RUNX1_B REH_RUNX1_rep2
GTACATGCGGTAGGAT 0.008628 0.016436 0.078145 0.101990 0.005310 0.045214 0.036243 0.037676
GTAGTTGCGGTTGAGT 0.083942 0.092621 0.428186 0.679391 0.048088 0.398709 0.290851 0.504155
GATGCTGCGGTTTGAT 0.063410 0.101413 0.454801 0.578341 0.044561 0.479219 0.268625 0.581043
CGTAGTGCGGTGACTT 0.029609 0.010599 0.036979 0.169667 0.019989 0.061763 0.085255 0.107780
CGAGGTGCGGTGAGAG 0.026174 0.015470 0.078958 0.194234 0.008045 0.089969 0.069520 0.182262
... ... ... ... ... ... ... ... ...
TTACATGCGGTCGCGG 0.043044 0.028721 0.270344 0.395935 0.035427 0.396537 0.277153 0.235117
CCGCGTGCGGTATGGT 0.049635 0.083087 0.136061 0.400179 0.056562 0.118320 0.152819 0.197538
GGCGCTGCGGTGGTAA 0.058613 0.055354 0.239545 0.241693 0.052754 0.372992 0.180237 0.313991
TCACATGCGGTCTTAT 0.023508 0.036430 0.132633 0.124948 0.016026 0.149592 0.045602 0.173644
CAGATTGCGGTCCAAT 0.019122 0.026719 0.116408 0.251429 0.004488 0.120020 0.062137 0.224662

5000 rows × 8 columns

In [32]:
model_1 = 'Reh_HPA'
model_2 = 'Reh_RUNX1_B'
fig, ax = plt.subplots(figsize=(5, 5))
create_pairwise_affinities_scatter_plot(model_1, model_2, ax, 
                                        point_size_scale_residual=100,
                                        label_points_residual_percentile=99.8)
No description has been provided for this image
In [33]:
fusion_ER_models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA','REH_ETV6-RUNX1_rep2', 
                    'Reh_ETV6_Atlas']
native_ER_models = ['Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 'REH_RUNX1_rep2']


fig, axs = plt.subplots(3, 5, figsize=(20, 10))
for row, model_1 in enumerate(native_ER_models):
    for col, model_2 in enumerate(fusion_ER_models):
        create_pairwise_affinities_scatter_plot(model_2, model_1, axs[row, col], 
                                                point_size_scale_residual=100,
                                                label_points_residual_percentile=99.8)
plt.tight_layout()
No description has been provided for this image
In [22]:
model = 'Reh_HPA'
background_models = native_ER_models
differential_affinity_motifs = calculate_differential_affinity_motifs(model, background_models, 
                                                                      motif_affinities_df)
differential_affinity_motifs.head(10)
Out[22]:
CCGGTTGCGGTTATAC    3.977320
GTCGTTGCGGTTACGT    3.890415
CGCGCTGCGGTGGTTT    3.872483
CGGTATGCGGTAGTTA    3.848301
GCAGTTGCGGTTTTAA    3.791177
GCTGCTGCGGTTTTAT    3.552440
GCGGCTGCGGTTTAAG    3.494940
GCACATGCGGTTGCGG    3.479971
TCGCCTGCGGTAAACC    3.475359
ACCGGTGCGGTTGCTA    3.351693
Name: Reh_HPA, dtype: float64
In [23]:
model = 'Reh_RUNX1_B'
background_models = fusion_ER_models
differential_affinity_motifs = calculate_differential_affinity_motifs(model, background_models, 
                                                                      motif_affinities_df)
differential_affinity_motifs.head(10)
Out[23]:
GCGACTGCGGTCACGC    4.709977
AACCTTGCGGTCACGG    4.643236
GTCACTGCGGTCAGGC    4.454007
CACCTTGCGGTCACGG    4.273786
GAGGTTGCGGTCGGGC    4.244653
CGACTTGCGGTCGCGC    4.119065
ATGTTTGCGGTCGGCG    4.030583
AACGCTGCGGTCACGG    4.010167
CGCTTTGCGGTCGAGG    3.990771
AGGCTTGCGGTCACCG    3.958911
Name: Reh_RUNX1_B, dtype: float64
In [24]:
fusion_ER_models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA','REH_ETV6-RUNX1_rep2', 
                    'Reh_ETV6_Atlas']
native_ER_models = ['Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 'REH_RUNX1_rep2']

model_diff_affinity_motifs_df = pd.DataFrame()
for model in fusion_ER_models:
    differential_affinity_motifs = calculate_differential_affinity_motifs(model, native_ER_models, 
                                                                          motif_affinities_df)
    differential_affinity_motifs = (differential_affinity_motifs.index + ' (' + 
                                    differential_affinity_motifs.round(2).values.astype(str) + ')')
    model_diff_affinity_motifs_df[model] = differential_affinity_motifs[:10]
    
for model in native_ER_models:
    differential_affinity_motifs = calculate_differential_affinity_motifs(model, fusion_ER_models, 
                                                                          motif_affinities_df)
    differential_affinity_motifs = (differential_affinity_motifs.index + ' (' + 
                                    differential_affinity_motifs.round(2).values.astype(str) + ')')
    model_diff_affinity_motifs_df[model] = differential_affinity_motifs[:10]

multi_index_columns = pd.MultiIndex.from_tuples(
    [('Fusion ETV6-RUNX1 models', model + ' (score)') for model in fusion_ER_models] + 
    [('Native RUNX1 models', model + ' (score)') for model in native_ER_models]
)
model_diff_affinity_motifs_df.columns = multi_index_columns
model_diff_affinity_motifs_df
Out[24]:
Fusion ETV6-RUNX1 models Native RUNX1 models
Reh_HPA (score) Reh_PAS (score) Reh_DMSO_HPA (score) REH_ETV6-RUNX1_rep2 (score) Reh_ETV6_Atlas (score) Reh_DMSO_RUNX1 (score) Reh_RUNX1_B (score) REH_RUNX1_rep2 (score)
0 CCGGTTGCGGTTATAC (3.98) CGGTTTGCGGTTTGTG (5.59) TCAAATGCGGTAAACC (4.83) GGAAGTGCGGTTAGCC (5.14) TCCGCTGCGGTGGTTG (3.98) AGCTCTGCGGTCGGCC (4.28) GCGACTGCGGTCACGC (4.71) AACGCTGCGGTCACGG (4.09)
1 GTCGTTGCGGTTACGT (3.89) CGGTATGCGGTTCCAC (5.0) ACTAATGCGGTTTACC (3.71) CAGGATGCGGTTTTTA (4.98) ACACCTGCGGTGGGTA (3.69) CGACTTGCGGTCGCGC (4.23) AACCTTGCGGTCACGG (4.64) GATTTTGCGGTCAAAC (4.07)
2 CGCGCTGCGGTGGTTT (3.87) CGGTTTGCGGTTTCTG (4.92) CCACATGCGGTGGGAC (3.64) CAGGATGCGGTTTTTC (4.96) CCCCCTGCGGTTGTCG (3.5) AGCCCTGCGGTTATAC (3.8) GTCACTGCGGTCAGGC (4.45) CCATTTGCGGTCAAGG (3.93)
3 CGGTATGCGGTAGTTA (3.85) CGGTATGCGGTATCCT (4.78) AGCAATGCGGTTCACC (3.35) GCGGATGCGGTACCGC (4.76) ACCGGTGCGGTGTTGG (3.49) AGCCTTGCGGTCGGCC (3.59) CACCTTGCGGTCACGG (4.27) ATTTCTGCGGTCAGCG (3.5)
4 GCAGTTGCGGTTTTAA (3.79) CGGTTTGCGGTTACCC (4.23) TTCAATGCGGTTAATC (3.32) GAGGATGCGGTAGGAA (4.4) TCGTTTGCGGTAGCCA (3.4) CGACCTGCGGTTACAG (3.52) GAGGTTGCGGTCGGGC (4.24) GCTTCTGCGGTCAAAA (3.48)
5 GCTGCTGCGGTTTTAT (3.55) CCGCATGCGGTTACCA (4.13) GCACATGCGGTTGCGG (3.31) CCGGATGCGGTGTGGT (4.31) CCGGTTGCGGTTGTCT (3.25) AGGTCTGCGGTCAGCT (3.5) CGACTTGCGGTCGCGC (4.12) ATATCTGCGGTCACGG (3.43)
6 GCGGCTGCGGTTTAAG (3.49) CGGCATGCGGTTCCAC (4.06) TTATATGCGGTTTGAC (3.27) TCAAATGCGGTAAACC (4.04) CCGTTTGCGGTTCGGT (3.17) GGGCTTGCGGTCAAAT (3.49) ATGTTTGCGGTCGGCG (4.03) GTTCCTGCGGTCACAG (3.41)
7 GCACATGCGGTTGCGG (3.48) CGGCATGCGGTATATG (3.94) CAGGATGCGGTTTTTC (3.26) TGGTTTGCGGTTTGGT (3.61) CCGGATGCGGTGTGGT (3.17) AACCTTGCGGTCACGG (3.43) AACGCTGCGGTCACGG (4.01) TTTTCTGCGGTCACGG (3.34)
8 TCGCCTGCGGTAAACC (3.48) CGGCATGCGGTATATC (3.7) CCGAATGCGGTTGATC (3.22) TCGGATGCGGTTCAGT (3.58) GGGGTTGCGGTTGGTT (3.15) CGCTTTGCGGTCGAGG (3.41) CGCTTTGCGGTCGAGG (3.99) GTTTCTGCGGTAACAC (3.24)
9 ACCGGTGCGGTTGCTA (3.35) TGGTTTGCGGTTTGGT (3.66) CCACATGCGGTAGACT (3.2) CCAAGTGCGGTTTGGT (3.56) CCCGCTGCGGTGTCTT (3.13) CGCGCTGCGGTCAAAG (3.4) AGGCTTGCGGTCACCG (3.96) GGAGCTGCGGTCTCTT (3.22)
In [116]:
def affinity_distillation_profiles(motif, model, fold=0):
    model_file = f"/users/shouvikm/BPNet/models/{model}/fold_{fold}/model_split000"
    bpnet = load_model(model_file)

    # TODO: change for each fold
    test_chromosomes = ["chr1", "chr3", "chr6"]
    peak_file = glob.glob(f"/users/shouvikm/data/Cell_Lines_Ped_Leukemia/Reh/*/{model}/idr_peaks/peaks_inliers.bed")[0]
    peaks_df = pd.read_csv(peak_file, sep='\t', header=None)
    held_out_peaks_df = peaks_df[peaks_df[0].isin(test_chromosomes)]

    input_seq_len = 2114
    num_marginalization_samples = 100
    genome = Fasta("/users/shouvikm/BPNet/data/hg38.genome.fa")
    nucleotides = ['A', 'T', 'C', 'G']

    seed = 0
    random_sequences = []
    motif_inserted_sequences = []
    while(len(random_sequences) < num_marginalization_samples):
        random.seed(seed)

        random_peak_index = random.randint(0, held_out_peaks_df.shape[0]-1)
        random_peak = held_out_peaks_df.iloc[random_peak_index]
        chr, start, end = random_peak[0], random_peak[1], random_peak[2]
        peak_center = start + ((end - start) // 2)
        random_seq_start = peak_center - (input_seq_len // 2)
        random_seq_end = peak_center + (input_seq_len // 2)
        random_seq = genome[chr][random_seq_start:random_seq_end].seq

        all_valid_nucleotides = all([letter in nucleotides for letter in random_seq])
        if not all_valid_nucleotides: 
            seed += 1
            continue

        random_seq_one_hot = one_hot_encode(random_seq).unsqueeze(0)
        random_seq_shuf_one_hot = dinucleotide_shuffle(random_seq_one_hot, n=1, random_state=seed)
        random_seq_bpnet_input = random_seq_shuf_one_hot[0].numpy().transpose((0, 2, 1))\
                                    .astype("float32")
        random_sequences.append(random_seq_bpnet_input)
        
        motif_inserted_seq_one_hot = substitute(random_seq_shuf_one_hot[0], motif)
        motif_inserted_seq_bpnet_input = motif_inserted_seq_one_hot.numpy().transpose((0, 2, 1))\
                                            .astype("float32")
        motif_inserted_sequences.append(motif_inserted_seq_bpnet_input)
        
        seed += 1
    
    random_sequences = np.vstack(random_sequences)
    motif_inserted_sequences = np.vstack(motif_inserted_sequences)
        
    profile_bias = np.zeros((num_marginalization_samples, 1000, 2), dtype=np.float32)
    counts_bias = np.zeros((num_marginalization_samples, 2), dtype=np.float32)
    pred_profiles_random_sequences, _ = bpnet.predict([random_sequences, 
                                                       profile_bias, 
                                                       counts_bias])
    pred_profiles_motif_inserted_sequences, _ = bpnet.predict([motif_inserted_sequences, 
                                                               profile_bias, 
                                                               counts_bias])

    pred_profile_diffs = pred_profiles_motif_inserted_sequences - pred_profiles_random_sequences
    return pred_profile_diffs
In [117]:
model = 'Reh_HPA'
motif = 'TGCGGT'

profiles = affinity_distillation_profiles(motif, model)
profiles.shape
Out[117]:
(100, 1000, 2)
In [118]:
avg_profile = profiles.mean(axis=0)

plt.figure(figsize=(4, 2))
plt.plot(avg_profile[:, 1], color='lightblue')
plt.plot(avg_profile[:, 0], color='tab:blue')
Out[118]:
[<matplotlib.lines.Line2D at 0x7f8f154e8ed0>]
No description has been provided for this image
In [132]:
models = ['Reh_HPA', 'Reh_PAS', 'Reh_DMSO_HPA','REH_ETV6-RUNX1_rep2', 
          'Reh_ETV6_Atlas', 'Reh_DMSO_RUNX1', 'Reh_RUNX1_B', 
          'REH_RUNX1_rep2']
motifs = ['AAAAAA', 'CCCCCC', 'TGCGGT', 'TGCGGTT', 'TGCGGTTT', 
          'TGTGGTTGTGGT', 'GGAATGTGGT']

avg_profiles = np.zeros(shape=(len(motifs), len(models), 1000, 2))
for row, motif in enumerate(motifs):
    for col, model in enumerate(models):
        print(row, col)
        profiles = affinity_distillation_profiles(motif, model)
        avg_profile = profiles.mean(axis=0)
        avg_profiles[row, col] = avg_profile
0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
1 0
1 1
1 2
1 3
1 4
1 5
1 6
1 7
2 0
2 1
2 2
2 3
2 4
2 5
2 6
2 7
3 0
3 1
3 2
3 3
3 4
3 5
3 6
3 7
4 0
4 1
4 2
4 3
4 4
4 5
4 6
4 7
5 0
5 1
5 2
5 3
5 4
5 5
5 6
5 7
6 0
6 1
6 2
6 3
6 4
6 5
6 6
6 7
In [134]:
fig, axs = plt.subplots(len(motifs), len(models), figsize=(20, 1*len(motifs)))
for row, motif in enumerate(motifs):
    for col, model in enumerate(models):
        ax = axs[row, col]
        ax.plot(avg_profiles[row, col, 0:1000, 1], color='lightblue')
        ax.plot(avg_profiles[row, col, 0:1000, 0], color='tab:blue')
        # Set y-axis min/max to compare motif profiles across a model
        y_min, y_max = avg_profiles[:, col, :, :].min(), avg_profiles[:, col, :, :].max()
        ax.set_ylim([y_min, y_max])
        if row < len(motifs)-1: ax.set_xticks([])
        if row == 0: ax.set_title(model)
        if col == 0: ax.set_ylabel(motif, rotation=0, labelpad=50, fontsize=12)
No description has been provided for this image
In [ ]: