# Source code for squid.mutagenizer

import os
#os.environ["TQDM_DISABLE"] = "1"
import numpy as np
from tqdm import tqdm


class BaseMutagenesis:
    """Interface for in silico MAVE data generation from a single sequence.

    Concrete strategies subclass this and override ``__call__`` so that one
    reference sequence is expanded into a batch of mutagenized sequences.
    """

    def __call__(self, x, num_sim):
        """Return an in silico MAVE built by mutagenizing ``x``.

        Parameters
        ----------
        x : numpy.ndarray
            One-hot encoded reference sequence, shape (L, A).
        num_sim : int
            Number of sequences to generate.

        Returns
        -------
        numpy.ndarray
            Batch of mutagenized sequences.

        Raises
        ------
        NotImplementedError
            Always -- subclasses must provide the implementation.
        """
        raise NotImplementedError()
class RandomMutagenesis(BaseMutagenesis):
    """Module for performing random mutagenesis.

    Parameters
    ----------
    mut_rate : float
        Per-position mutation rate. The mean number of mutations per sequence
        is ``ceil(mut_rate * L)``.
    uniform : bool, optional
        uniform (True), Poisson (False); sets the number of mutations per
        sequence. With True every sequence gets exactly ``ceil(mut_rate * L)``
        mutations; with False the count is Poisson-distributed with that mean.
        In both cases the count is capped at L. (defaults to False)
    seed : int, optional
        Seed applied to NumPy's global random state at the start of each call.
        If None, results will not be reproducible. (defaults to None)

    Returns
    -------
    numpy.ndarray
        uint8 batch of one-hot sequences, shape (num_sim, L, A), with random
        mutagenesis applied; the first sequence is the unmutated input.
    """

    def __init__(self, mut_rate, uniform=False, seed=None):
        self.mut_rate = mut_rate
        self.uniform = uniform
        self.seed = seed

    def __call__(self, x, num_sim):
        if self.seed is not None:
            np.random.seed(self.seed)
        L, A = x.shape
        avg_num_mut = int(np.ceil(self.mut_rate * L))

        # index of the reference character at each position
        x_index = np.argmax(x, axis=1)

        # sample number of mutations for each sequence
        if self.uniform:
            # one count per sequence; the former int(array) cast raised TypeError
            # for num_sim > 1 and produced a non-iterable int for num_sim == 1
            num_muts = np.full(num_sim, avg_num_mut, dtype=int)
        else:
            num_muts = np.random.poisson(avg_num_mut, (num_sim, 1))[:, 0]
        # positions are sampled without replacement, so at most L can be mutated
        num_muts = np.clip(num_muts, 0, L)

        return apply_mut_by_seq_index(x_index, (num_sim, L, A), num_muts)
class CombinatorialMutagenesis():
    """Module for performing combinatorial mutagenesis.

    Enumerates the reference sequence plus every variant carrying 1..max_order
    substitutions inside the mutation window. The enumeration is deterministic.

    Parameters
    ----------
    max_order : int, optional
        Maximum order of mutations to generate. If -1, generates all possible
        combinations within the window. If 1, generates only single mutations
        (all SNVs). If 2, generates single and double mutations, etc. Must be
        -1 or a non-negative integer no larger than the window length.
        (defaults to -1)
    mut_window : [int, int], optional
        Index of start and stop position along sequence to probe for
        mutations, inclusive on both ends. For example, mut_window=[4,6] will
        generate mutations at positions 4, 5, and 6.
        (defaults to None, which means the entire sequence is considered)
    batch_size : int or None, optional
        Batch size for one-hot encoding conversion. If None, converts all at
        once. For large outputs, using a batch size can help manage memory
        usage. (defaults to 256)
    seed : int, optional
        If given, NumPy's global random state is seeded on each call. The
        enumeration itself uses no randomness, so this does not change the
        output. (defaults to None)

    Returns
    -------
    numpy.ndarray
        int8 array of shape (N, L, A); row 0 is the reference sequence.
        For max_order=-1: N = A^W.
        For max_order=k: N = 1 + sum(C(W, r) * (A-1)^r for r in 1..k),
        where W is the window length (L when no window is given) and the
        leading 1 accounts for the reference sequence.

    Examples
    --------
    For L=4, A=4 (no window):
    - max_order=1: 1 + C(4,1)*(3^1) = 1 + 12 = 13 sequences
    - max_order=2: 1 + C(4,1)*(3^1) + C(4,2)*(3^2) = 1 + 12 + 54 = 67 sequences

    Raises
    ------
    ValueError
        If max_order is less than -1 (at construction); on call, if max_order
        exceeds the window length, or if mut_window is empty or out of bounds.
    """

    def __init__(self, max_order=-1, mut_window=None, batch_size=256, seed=None):
        if max_order < -1:
            raise ValueError("max_order must be -1 or a non-negative integer")
        self.max_order = max_order
        self.mut_window = mut_window
        self.batch_size = batch_size
        self.seed = seed

    def _resolve_window(self, L):
        """Return (start_pos, stop_pos, window_length) with stop_pos exclusive."""
        if self.mut_window is None:
            return 0, L, L
        start_pos, stop_pos = self.mut_window
        stop_pos = stop_pos + 1  # the user-facing stop is inclusive; make it exclusive
        window_length = stop_pos - start_pos
        if window_length <= 0:
            raise ValueError("mut_window stop_pos must be greater than or equal to start_pos")
        if start_pos < 0 or stop_pos > L:
            raise ValueError(f"mut_window must be within sequence bounds [0, {L}]")
        return start_pos, stop_pos, window_length

    def __call__(self, x, num_sim):
        # 'num_sim' is ignored: the output size is fixed by L, A, max_order and mut_window
        import math
        from itertools import combinations, product

        if self.seed is not None:
            np.random.seed(self.seed)
        L, A = x.shape

        start_pos, stop_pos, window_length = self._resolve_window(L)
        if self.max_order > window_length:
            raise ValueError(f"max_order ({self.max_order}) cannot exceed window length ({window_length})")

        x_index = np.argmax(x, axis=1)  # reference sequence as character indices

        # If max_order is -1, enumerate every order up to the full window
        max_order = window_length if self.max_order == -1 else self.max_order

        # C(W, r) * (A-1)^r variants per order r -- counted arithmetically rather
        # than by materializing every combination
        n_variants_per_order = [
            math.comb(window_length, order) * (A - 1) ** order
            for order in range(1, max_order + 1)
        ]
        total_variants = 1 + sum(n_variants_per_order)  # +1 for the reference

        # smallest integer type able to hold indices 0..A-1 (int8 overflowed for A > 128)
        index_dtype = np.min_scalar_type(max(A - 1, 0))
        all_variants = np.zeros((total_variants, L), dtype=index_dtype)
        all_variants[0] = x_index

        window_positions = range(start_pos, stop_pos)
        # alternative characters available at each position in the window
        alt_bases_lookup = {
            p: np.array([b for b in range(A) if b != x_index[p]])
            for p in window_positions
        }

        current_idx = 1  # row 0 holds the reference
        for order, n_variants in enumerate(n_variants_per_order, start=1):
            with tqdm(total=n_variants, desc=f"Order {order} mutations") as pbar:
                for pos in combinations(window_positions, order):
                    # every assignment of alternative characters to this position set
                    alt_combos = np.array(list(product(*(alt_bases_lookup[p] for p in pos))))
                    n_combos = len(alt_combos)
                    if n_combos == 0:  # only possible when A == 1
                        continue
                    rows = all_variants[current_idx:current_idx + n_combos]
                    rows[:] = x_index
                    rows[:, list(pos)] = alt_combos
                    current_idx += n_combos
                    pbar.update(n_combos)

        print("Converting to one-hot encoding...")
        eye = np.eye(A, dtype=np.int8)
        if self.batch_size is None:
            # Convert all at once
            return eye[all_variants]

        # Convert in batches
        n_sequences = len(all_variants)
        one_hot = np.zeros((n_sequences, L, A), dtype=np.int8)
        for i in tqdm(range(0, n_sequences, self.batch_size), desc="One-hot encoding"):
            batch_end = min(i + self.batch_size, n_sequences)
            one_hot[i:batch_end] = eye[all_variants[i:batch_end]]
        return one_hot
class TwoHotMutagenesis(BaseMutagenesis):
    """Module to perform random mutagenesis using two-hot encoding.

    That is, encode each individual nucleotide at a given position using a
    one-hot encoding scheme, then represent the unphased diploid sequence as
    the sum of the two one-hot encoded nucleotides at each position. The
    sequence "AYCR", for example, would be encoded as:
    [[2, 0, 0, 0], [0, 1, 0, 1], [0, 2, 0, 0], [1, 0, 1, 0]].

    Parameters
    ----------
    mut_rate : float
        Mutation rate for random mutagenesis.
    uniform : bool, optional
        uniform (True), Poisson (False); sets the number of mutations per
        sequence. Each mutant attempts ``ceil(mut_rate * L) + 1`` substitutions
        (uniform) or a Poisson draw with mean ``ceil(mut_rate * L)`` plus one,
        capped at L. (defaults to False)
    seed : int, optional
        Seed applied to NumPy's global random state at the start of each call.
        If None, results will not be reproducible. (defaults to None)

    Returns
    -------
    numpy.ndarray
        float64 array of shape (num_sim, L, 4) of two-hot encoded sequences;
        row 0 is the unmutated input. Substituted characters are drawn from
        {A, C, G, T, R (A/G), Y (C/T), S (C/G), W (A/T), K (G/T), M (A/C)},
        so heterozygous positions are represented using IUPAC ambiguity codes.

    Raises
    ------
    ValueError
        If the input alphabet size is not 4.
    """

    def __init__(self, mut_rate, uniform=False, seed=None):
        self.mut_rate = mut_rate
        self.uniform = uniform
        self.seed = seed

    def __call__(self, x, num_sim):
        from numpy.random import choice
        from numpy.random import poisson

        if self.seed is not None:
            np.random.seed(self.seed)

        def swap_elements(chars, replacements):
            """Return a copy of ``chars`` in which ``len(replacements)`` distinct,
            randomly chosen positions are overwritten, in order, by ``replacements``."""
            new_chars = chars[:]
            positions = choice(range(len(chars)), size=len(replacements), replace=False)
            for idx, value in zip(positions, replacements):
                new_chars[idx] = value
            return new_chars

        L, A = x.shape
        if A != 4:
            raise ValueError(f"TwoHotMutagenesis requires a 4-letter alphabet (got A={A})")
        alphabet_pool = ['A', 'C', 'G', 'T', 'R', 'Y', 'S', 'W', 'K', 'M']  # pool for selecting characters
        seq = list(twohot2seq(x))

        # number of substitutions to attempt for each sequence; the +1 means every
        # mutant attempts at least one (a drawn character may equal the existing one)
        avg_num_mut = int(np.ceil(self.mut_rate * L))
        if self.uniform:
            # one count per sequence; the former int(array) cast raised TypeError for num_sim > 1
            num_muts = np.full(num_sim, avg_num_mut + 1, dtype=int)
        else:
            num_muts = poisson(avg_num_mut, (num_sim, 1))[:, 0] + 1
        # positions are chosen without replacement, so more than L would raise in swap_elements
        num_muts = np.minimum(num_muts, L)

        one_hot = np.zeros(shape=(num_sim, L, A))
        for i, num_mut in enumerate(tqdm(num_muts, desc="Mutagenesis")):
            if i == 0:
                # keep the reference sequence as the first entry
                one_hot[i, :, :] = seq2twohot(''.join(seq))
            else:
                # sample 'num_mut' characters from alphabet_pool with replacement
                options_list = choice(alphabet_pool, size=num_mut, replace=True)
                mut_seq = ''.join(swap_elements(seq, options_list))
                one_hot[i, :, :] = seq2twohot(mut_seq)
        return one_hot
""" class CustomMutagenesis(BaseMutagenesis): def __init__(self, param1, param2): self.param1 = param1 self.param2 = param2 def __call__(self, x, num_sim): # code goes here return one_hot """ ################################################################################ # useful functions ################################################################################
def apply_mut_by_seq_index(x_index, shape, num_muts):
    """Function to perform random mutagenesis.

    Parameters
    ----------
    x_index : numpy.ndarray
        Integer index (0..A-1) of the wild-type character at each position,
        shape (L,).
    shape : tuple
        Shape of MAVE array; i.e., (num_sim, L, A).
    num_muts : array-like of int
        Number of mutations for each sequence. Each value must be <= L, since
        positions are drawn without replacement. The entry for the first
        sequence is ignored.

    Returns
    -------
    numpy.ndarray
        uint8 batch of one-hot sequences, shape (num_sim, L, A). Row 0 is the
        wild type; every other row has ``num_muts[i]`` distinct positions
        changed to a character different from the wild type, chosen uniformly
        among the A-1 alternatives.
    """
    num_sim, L, A = shape
    identity = np.eye(A, dtype=np.uint8)
    # allocate the final dtype directly instead of a float64 buffer that is cast afterwards
    one_hot = np.zeros((num_sim, L, A), dtype=np.uint8)

    for i, num_mut in enumerate(tqdm(num_muts, desc="Mutagenesis")):
        if i == 0:
            # keep wild-type sequence
            one_hot[i] = identity[x_index]
            continue
        # distinct positions to mutate, then an offset in 1..A-1 for each
        # (same RNG calls, in the same order, as before -- seeded runs are unchanged)
        mut_index = np.random.choice(range(0, L), num_mut, replace=False)
        mut = np.random.choice(range(1, A), (len(mut_index)))
        seq_index = np.copy(x_index)
        # a non-zero offset modulo A always lands on a different character
        seq_index[mut_index] = np.mod(seq_index[mut_index] + mut, A)
        one_hot[i] = identity[seq_index]
    return one_hot
def twohot2seq(one_hot):
    """Function to convert two-hot encoding to a DNA sequence.

    Each row is mapped to A, C, G, T (homozygous, a single 2), N (all zeros)
    or an IUPAC ambiguity code M, R, W, S, Y, K (two 1s).

    Parameters
    ----------
    one_hot : numpy.ndarray
        Two-hot encoding of the sequence (shape : (L, 4)); integer or float.

    Returns
    -------
    seq : string
        Sequence with length L.

    Raises
    ------
    ValueError
        If a row is not one of the recognized two-hot patterns. Such rows used
        to be silently dropped, which returned a string shorter than L.
    """
    decoding = {
        (2, 0, 0, 0): 'A',
        (0, 2, 0, 0): 'C',
        (0, 0, 2, 0): 'G',
        (0, 0, 0, 2): 'T',
        (0, 0, 0, 0): 'N',
        (1, 1, 0, 0): 'M',
        (1, 0, 1, 0): 'R',
        (1, 0, 0, 1): 'W',
        (0, 1, 1, 0): 'S',
        (0, 1, 0, 1): 'Y',
        (0, 0, 1, 1): 'K',
    }
    chars = []
    for i, row in enumerate(np.asarray(one_hot)):
        # tolist() yields Python scalars; 2.0 == 2 and hash(2.0) == hash(2), so float input matches too
        key = tuple(row.tolist())
        try:
            chars.append(decoding[key])
        except KeyError:
            raise ValueError(f"Unrecognized two-hot encoding {key} at position {i}") from None
    return ''.join(chars)
def seq2twohot(seq):
    """Function to convert heterozygous DNA sequence to two-hot encoding.

    Homozygous bases A, C, G, T become a single 2, N becomes all zeros, and
    the IUPAC ambiguity codes M, R, W, S, Y, K become two 1s. Input is
    case-insensitive.

    Parameters
    ----------
    seq : string
        Input sequence with length L.

    Returns
    -------
    one_hot : numpy.ndarray
        Integer two-hot encoding of the sequence (shape : (L, 4)); an empty
        sequence gives shape (0, 4).

    Raises
    ------
    ValueError
        If ``seq`` contains a character outside A, C, G, T, N, M, R, W, S, Y, K.
        Previously such characters were passed through into ``np.array``,
        producing an opaque error or an object array.
    """
    encoding = {
        "A": (2, 0, 0, 0),
        "C": (0, 2, 0, 0),
        "G": (0, 0, 2, 0),
        "T": (0, 0, 0, 2),
        "N": (0, 0, 0, 0),
        "M": (1, 1, 0, 0),
        "R": (1, 0, 1, 0),
        "W": (1, 0, 0, 1),
        "S": (0, 1, 1, 0),
        "Y": (0, 1, 0, 1),
        "K": (0, 0, 1, 1),
    }
    rows = []
    for i, char in enumerate(seq.upper()):
        try:
            rows.append(encoding[char])
        except KeyError:
            raise ValueError(f"Unsupported character {char!r} at position {i}") from None
    # reshape keeps the (L, 4) layout even for an empty sequence
    return np.array(rows, dtype=int).reshape(len(rows), 4)
def get_alternative_bases(ref_base, A):
    """List the alphabet indices in ``range(A)`` other than ``ref_base``, ascending."""
    return list(filter(lambda candidate: candidate != ref_base, range(A)))
if __name__ == "__main__":
    # Small manual smoke test; set RUN_COMBINATORIAL_DEMO to False to try
    # RandomMutagenesis instead of CombinatorialMutagenesis.
    RUN_COMBINATORIAL_DEMO = True
    NUCLEOTIDES = ['A', 'C', 'G', 'T']

    def _decode(batch):
        """Turn a batch of one-hot sequences into ACGT strings for display."""
        return [''.join(NUCLEOTIDES[k] for k in np.argmax(row, axis=1)) for row in batch]

    if RUN_COMBINATORIAL_DEMO:
        print("\nTesting CombinatorialMutagenesis:")
        L = 10  # Change this value to test different lengths
        A = 4  # Alphabet size (A,C,G,T)
        # reference sequence of all A's
        x = np.zeros((L, A))
        x[:, 0] = 1
        for max_order in [2]:
            mutagenizer = CombinatorialMutagenesis(max_order=max_order, mut_window=[4, 6])
            sequences = _decode(mutagenizer(x, num_sim=None))
            print(f"\nmax_order = {max_order}:")
            print(f"Number of sequences generated: {len(sequences)}")
            if len(sequences) < 50:  # only list them when the output is small
                print("Sequences:")
                for text in sequences:
                    print(text)
    else:
        print("\nTesting RandomMutagenesis:")
        L = 20  # sequence length
        A = 4  # alphabet size
        # reference sequence of all A's
        x = np.zeros((L, A))
        x[:, 0] = 1
        # Poisson-distributed mutation counts at a 10% mutation rate
        mutagenizer = RandomMutagenesis(mut_rate=0.1, uniform=False, seed=42)
        sequences = _decode(mutagenizer(x, num_sim=10))
        print(f"Input sequence: {'A' * L}")
        print("\nMutated sequences:")
        for text in sequences:
            print(text)