finemo.hitcaller
Hit caller module implementing the Fi-NeMo motif instance calling algorithm.
This module provides the core functionality for identifying transcription factor binding motif instances in neural network contribution scores using a competitive optimization approach based on proximal gradient descent.
The main algorithm fits a sparse linear model where contribution scores are reconstructed as a weighted combination of motif contribution weight matrices (CWMs) at specific genomic positions. The sparsity constraint ensures that only the most significant motif instances are called.
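A minimal usage sketch (not taken from the package's own examples): the array names, shapes, and λ values below are illustrative placeholders, with random arrays standing in for real CWMs, contribution scores, and one-hot sequences.

import numpy as np
from finemo.hitcaller import fit_contribs

# Illustrative sizes only: 2 motifs of width 8; 10 regions of length 500.
M, W, N, L = 2, 8, 10, 500
cwms = np.random.randn(M, 4, W).astype(np.float32)        # motif CWMs
contribs = np.random.randn(N, 4, L).astype(np.float32)    # hypothetical contribution scores
bases = np.random.randint(4, size=(N, L))
sequences = np.eye(4, dtype=np.int8)[bases].transpose(0, 2, 1)  # one-hot, (N, 4, L)
cwm_trim_mask = np.ones((M, W), dtype=np.float32)         # keep the full motif width

hits_df, qc_df = fit_contribs(
    cwms=cwms,
    contribs=contribs,
    sequences=sequences,
    cwm_trim_mask=cwm_trim_mask,
    use_hypothetical=True,
    lambdas=np.full(M, 0.7, dtype=np.float32),
)
print(hits_df.columns)  # peak_id, motif_id, hit_start, hit_coefficient, ...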
import warnings
from typing import Tuple, Union, Optional, Dict, List
from abc import ABC, abstractmethod

import numpy as np
from numpy import ndarray
import torch
import torch.nn.functional as F
from torch import Tensor
import polars as pl
from jaxtyping import Float, Int, Bool

from tqdm import tqdm


def _to_channel_last_layout(tensor: Tensor, **kwargs) -> torch.Tensor:
    """Convert tensor to channel-last memory layout for optimized convolution operations.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor to convert.
    **kwargs
        Additional keyword arguments passed to tensor.to().

    Returns
    -------
    torch.Tensor
        Tensor with channel-last memory layout.
    """
    return (
        tensor[:, :, :, None].to(memory_format=torch.channels_last, **kwargs).squeeze(3)
    )


def _signed_sqrt(x: torch.Tensor) -> torch.Tensor:
    """Apply signed square root transformation to input tensor.

    This transformation preserves the sign while applying square root to the
    absolute value, which can help with numerical stability and gradient flow.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Transformed tensor with same shape as input.
    """
    return torch.sign(x) * torch.sqrt(torch.abs(x))
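A quick illustration of `_signed_sqrt` on arbitrary values; when `sqrt_transform=True`, `fit_contribs` applies this transform to both the CWMs and the contribution scores before fitting.

import torch

x = torch.tensor([-4.0, -0.25, 0.0, 0.25, 9.0])
print(torch.sign(x) * torch.sqrt(torch.abs(x)))
# tensor([-2.0000, -0.5000,  0.0000,  0.5000,  3.0000])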
def prox_grad_step(
    coefficients: Float[Tensor, "B M P"],
    importance_scale: Float[Tensor, "B 1 P"],
    cwms: Float[Tensor, "M 4 W"],
    contribs: Float[Tensor, "B 4 L"],
    sequences: Union[Int[Tensor, "B 4 L"], int],
    lambdas: Float[Tensor, "1 M 1"],
    step_sizes: Float[Tensor, "B 1 1"],
) -> Tuple[Float[Tensor, "B M P"], Float[Tensor, " B"], Float[Tensor, " B"]]:
    """Perform a proximal gradient descent optimization step for non-negative lasso."""
    # Forward pass: convolution operations require specific tensor layouts
    coef_adj = coefficients * importance_scale
    pred_unmasked = F.conv_transpose1d(coef_adj, cwms)  # (B, 4, L)
    pred = (
        pred_unmasked * sequences
    )  # (B, 4, L), element-wise masking for projected mode

    # Compute gradient * -1
    residuals = contribs - pred  # (B, 4, L)
    ngrad = F.conv1d(residuals, cwms) * importance_scale  # (B, M, P)

    # Negative log likelihood (proportional to MSE)
    nll = (residuals**2).sum(dim=(1, 2))  # (B)

    # Compute duality gap for convergence assessment
    dual_norm = (ngrad / lambdas).amax(dim=(1, 2))  # (B)
    dual_scale = (torch.clamp(1 / dual_norm, max=1.0) ** 2 + 1) / 2  # (B)
    nll_scaled = nll * dual_scale  # (B)

    dual_diff = (residuals * contribs).sum(dim=(1, 2))  # (B)
    l1_term = (torch.abs(coefficients).sum(dim=2, keepdim=True) * lambdas).sum(
        dim=(1, 2)
    )  # (B)
    dual_gap = (nll_scaled - dual_diff + l1_term).abs()  # (B)

    # Compute proximal gradient descent step
    c_next = coefficients + step_sizes * (ngrad - lambdas)  # (B, M, P)
    c_next = F.relu(c_next)  # Ensure non-negativity constraint

    return c_next, dual_gap, nll
Perform a proximal gradient descent optimization step for non-negative lasso.
This function implements a single optimization step of the Fi-NeMo algorithm, which uses proximal gradient descent to solve a sparse reconstruction problem. The goal is to represent contribution scores as a sparse linear combination of motif contribution weight matrices (CWMs).
Dimension notation:
- B = batch size (number of regions processed simultaneously)
- M = number of motifs
- L = sequence length
- W = motif width (length of each motif)
- P = L - W + 1 (number of valid motif positions)
Parameters
- coefficients (Float[Tensor, "B M P"]): Current coefficient matrix representing motif instance strengths.
- importance_scale (Float[Tensor, "B 1 P"]): Scaling factors for importance-weighted reconstruction.
- cwms (Float[Tensor, "M 4 W"]): Motif contribution weight matrices for all motifs. 4 represents the DNA bases (A, C, G, T).
- contribs (Float[Tensor, "B 4 L"]): Target contribution scores to reconstruct.
- sequences (Union[Int[Tensor, "B 4 L"], int]): One-hot encoded DNA sequences. Can be the scalar 1 in hypothetical mode (no sequence masking).
- lambdas (Float[Tensor, "1 M 1"]): L1 regularization weights for each motif.
- step_sizes (Float[Tensor, "B 1 1"]): Optimization step sizes for each batch element.
Returns
- c_next (Float[Tensor, "B M P"]): Updated coefficient matrix after the optimization step (shape: batch_size × motifs × positions).
- dual_gap (Float[Tensor, " B"]): Duality gap for convergence assessment (shape: batch_size).
- nll (Float[Tensor, " B"]): Negative log likelihood (proportional to MSE, shape: batch_size).
Notes
The algorithm uses proximal gradient descent to solve:
minimize_c: ||contribs - conv_transpose(c * importance_scale, cwms) * sequences||²₂ + λ||c||₁
subject to: c ≥ 0
References
- Proximal gradient descent: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slide 22
- Duality gap computation: https://stanford.edu/~boyd/papers/pdf/l1_ls.pdf, Section III
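A toy sketch of the proximal update described above, using random placeholder tensors in place of the real gradient computation: a gradient step on the smooth term followed by a ReLU is exactly the proximal operator of the non-negative L1 penalty.

import torch
import torch.nn.functional as F

B, M, P = 2, 3, 10
coefficients = torch.rand(B, M, P)      # current non-negative coefficients
ngrad = torch.randn(B, M, P)            # stand-in for the negative gradient of the squared error
lambdas = torch.full((1, M, 1), 0.7)    # per-motif L1 weights
step_sizes = torch.full((B, 1, 1), 3.0)

c_next = F.relu(coefficients + step_sizes * (ngrad - lambdas))
assert (c_next >= 0).all()              # non-negativity holds by construction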
def optimizer_step(
    cwms: Float[Tensor, "M 4 W"],
    contribs: Float[Tensor, "B 4 L"],
    importance_scale: Float[Tensor, "B 1 P"],
    sequences: Union[Int[Tensor, "B 4 L"], int],
    coef_inter: Float[Tensor, "B M P"],
    coef: Float[Tensor, "B M P"],
    i: Float[Tensor, "B 1 1"],
    step_sizes: Float[Tensor, "B 1 1"],
    L: int,
    lambdas: Float[Tensor, "1 M 1"],
) -> Tuple[
    Float[Tensor, "B M P"],
    Float[Tensor, "B M P"],
    Float[Tensor, " B"],
    Float[Tensor, " B"],
]:
    """Perform a non-negative lasso optimizer step with Nesterov momentum."""
    coef_prev = coef

    # Proximal gradient descent step
    coef, gap, nll = prox_grad_step(
        coef_inter, importance_scale, cwms, contribs, sequences, lambdas, step_sizes
    )
    gap = gap / L
    nll = nll / (2 * L)

    # Compute updated coefficients with Nesterov momentum
    mom_term = i / (i + 3.0)
    coef_inter = (1 + mom_term) * coef - mom_term * coef_prev

    return coef_inter, coef, gap, nll
Perform a non-negative lasso optimizer step with Nesterov momentum.
This function combines proximal gradient descent with momentum acceleration to improve convergence speed while maintaining the non-negative constraint on coefficients.
Dimension notation:
- B = batch size (number of regions processed simultaneously)
- M = number of motifs
- L = sequence length
- W = motif width (length of each motif)
- P = L - W + 1 (number of valid motif positions)
Parameters
- cwms (Float[Tensor, "M 4 W"]): Motif contribution weight matrices.
- contribs (Float[Tensor, "B 4 L"]): Target contribution scores.
- importance_scale (Float[Tensor, "B 1 P"]): Importance scaling factors.
- sequences (Union[Int[Tensor, "B 4 L"], int]): One-hot encoded sequences or scalar for hypothetical mode.
- coef_inter (Float[Tensor, "B M P"]): Intermediate coefficient matrix (with momentum).
- coef (Float[Tensor, "B M P"]): Current coefficient matrix.
- i (Float[Tensor, "B 1 1"]): Iteration counter for each batch element.
- step_sizes (Float[Tensor, "B 1 1"]): Step sizes for optimization.
- L (int): Sequence length for normalization.
- lambdas (Float[Tensor, "1 M 1"]): Regularization parameters.
Returns
- coef_inter (Float[Tensor, "B M P"]): Updated intermediate coefficients with momentum (shape: batch_size × motifs × positions).
- coef (Float[Tensor, "B M P"]): Updated coefficient matrix (shape: batch_size × motifs × positions).
- gap (Float[Tensor, " B"]): Normalized duality gap (shape: batch_size).
- nll (Float[Tensor, " B"]): Normalized negative log likelihood (shape: batch_size).
Notes
Uses Nesterov momentum with momentum coefficient i/(i+3) for improved convergence properties. The duality gap and NLL are normalized by sequence length for scale-invariant convergence assessment.
References
https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slides 22, 27
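A toy sketch of the momentum extrapolation on placeholder tensors; `i` is the per-region iteration counter, so the momentum coefficient i/(i+3) grows toward 1 as optimization proceeds.

import torch

def nesterov_extrapolate(coef, coef_prev, i):
    # y_{k+1} = c_k + (i / (i + 3)) * (c_k - c_{k-1})
    mom_term = i / (i + 3.0)
    return (1 + mom_term) * coef - mom_term * coef_prev

coef_prev = torch.zeros(2, 3, 10)
coef = torch.rand(2, 3, 10)
i = torch.full((2, 1, 1), 5.0)          # per-region iteration counters
coef_inter = nesterov_extrapolate(coef, coef_prev, i)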
class BatchLoaderBase(ABC):
    """Base class for loading batches of contribution scores and sequences."""

    def __init__(
        self,
        contribs: Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]],
        sequences: Int[Tensor, "N 4 L"],
        L: int,
        device: torch.device,
    ) -> None:
        self.contribs = contribs
        self.sequences = sequences
        self.L = L
        self.device = device

    def _get_inds_and_pad_lens(
        self, start: int, end: int
    ) -> Tuple[Int[Tensor, " Z"], Tuple[int, ...]]:
        """Get indices and padding lengths for batch loading.

        Parameters
        ----------
        start : int
            Start index for batch.
        end : int
            End index for batch.

        Returns
        -------
        inds : Int[Tensor, " Z"]
            Padded indices tensor with -1 for padding positions (shape: padded_batch_size).
        pad_lens : tuple
            Padding specification for F.pad (left, right, top, bottom, front, back).
        """
        N = end - start
        end = min(end, self.contribs.shape[0])
        overhang = N - (end - start)
        pad_lens = (0, 0, 0, 0, 0, overhang)

        inds = F.pad(
            torch.arange(start, end, dtype=torch.int), (0, overhang), value=-1
        ).to(device=self.device)

        return inds, pad_lens
Base class for loading batches of contribution scores and sequences.
This class provides common functionality for different input formats including batch indexing and padding for consistent batch sizes.
Dimension notation:
- N = number of sequences/regions in dataset
- L = sequence length
- B = batch size (number of regions processed simultaneously)
Parameters
- contribs (Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]]): Contribution scores array.
- sequences (Int[Tensor, "N 4 L"]): One-hot encoded sequences array.
- L (int): Sequence length.
- device (torch.device): Target device for tensor operations.
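A small illustration of the index padding performed by `_get_inds_and_pad_lens` when the final batch runs past the end of the dataset (the sizes here are made up): real regions keep their dataset index, padded slots are marked with -1, and `pad_lens` zero-pads the batch dimension of a (B, 4, L) tensor by the same amount.

import torch
import torch.nn.functional as F

start, end, n_regions = 8, 12, 10       # requested batch [8, 12) on a 10-region dataset
N = end - start                         # requested batch size (4)
end = min(end, n_regions)
overhang = N - (end - start)            # 2 slots have no real region

inds = F.pad(torch.arange(start, end, dtype=torch.int), (0, overhang), value=-1)
print(inds)                             # tensor([ 8,  9, -1, -1], dtype=torch.int32)
pad_lens = (0, 0, 0, 0, 0, overhang)    # (left, right, top, bottom, front, back) for F.pad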
    @abstractmethod
    def load_batch(
        self, start: int, end: int
    ) -> Tuple[
        Float[Tensor, "B 4 L"], Union[Int[Tensor, "B 4 L"], int], Int[Tensor, " B"]
    ]:
        """Load a batch of data."""
        pass
Load a batch of data.
Dimension notation:
- B = batch size (number of regions in this batch)
- L = sequence length
Parameters
- start (int): Start index (used by subclasses).
- end (int): End index (used by subclasses).
Returns
- contribs_batch (Float[Tensor, "B 4 L"]): Batch of contribution scores (shape: batch_size × 4_bases × L).
- sequences_batch (Union[Int[Tensor, "B 4 L"], int]): Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L) or scalar 1 for hypothetical mode.
- inds_batch (Int[Tensor, " B"]): Batch indices mapping to original sequence indices (shape: batch_size).
Notes
This is an abstract method that must be implemented by subclasses. Parameters are intentionally unused in the base implementation.
class BatchLoaderCompactFmt(BatchLoaderBase):
    """Batch loader for compact format contribution scores.

    Handles contribution scores in shape (N, L) representing projected
    scores that need to be broadcasted to (N, 4, L) format.
    """

    def load_batch(
        self, start: int, end: int
    ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]:
        inds, pad_lens = self._get_inds_and_pad_lens(start, end)

        contribs_compact = F.pad(self.contribs[start:end, None, :], pad_lens)
        contribs_batch = _to_channel_last_layout(
            contribs_compact, device=self.device, dtype=torch.float32
        )
        sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens)  # (B, 4, L)
        sequences_batch = _to_channel_last_layout(
            sequences_batch, device=self.device, dtype=torch.int8
        )

        contribs_batch = contribs_batch * sequences_batch  # (B, 4, L)

        return contribs_batch, sequences_batch, inds
Batch loader for compact format contribution scores.
Handles contribution scores in shape (N, L) representing projected scores that need to be broadcasted to (N, 4, L) format.
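A toy NumPy sketch of that broadcast, with random placeholder arrays: the compact (N, L) scores are expanded against the one-hot sequence so that each position's score lands on the observed base.

import numpy as np

N, L = 1, 6
onehot = np.eye(4, dtype=np.int8)[np.random.randint(4, size=(N, L))].transpose(0, 2, 1)
compact = np.random.randn(N, L).astype(np.float32)

projected = compact[:, None, :] * onehot   # (N, 4, L); nonzero only at the observed base
assert np.allclose(projected.sum(axis=1), compact)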
class BatchLoaderProj(BatchLoaderBase):
    """Batch loader for projected contribution scores.

    Handles contribution scores in shape (N, 4, L) where scores are
    element-wise multiplied by one-hot sequences to get projected contributions.
    """

    def load_batch(
        self, start: int, end: int
    ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]:
        inds, pad_lens = self._get_inds_and_pad_lens(start, end)

        contribs_hyp = F.pad(self.contribs[start:end, :, :], pad_lens)
        contribs_hyp = _to_channel_last_layout(
            contribs_hyp, device=self.device, dtype=torch.float32
        )
        sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens)  # (B, 4, L)
        sequences_batch = _to_channel_last_layout(
            sequences_batch, device=self.device, dtype=torch.int8
        )
        contribs_batch = contribs_hyp * sequences_batch

        return contribs_batch, sequences_batch, inds
Batch loader for projected contribution scores.
Handles contribution scores in shape (N, 4, L) where scores are element-wise multiplied by one-hot sequences to get projected contributions.
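A toy sketch of that projection with placeholder arrays: the (N, 4, L) scores are masked by the one-hot sequence, zeroing the three bases not present at each position.

import numpy as np

hyp = np.random.randn(1, 4, 6).astype(np.float32)   # hypothetical-format scores
onehot = np.eye(4, dtype=np.int8)[np.random.randint(4, size=(1, 6))].transpose(0, 2, 1)
projected = hyp * onehot                             # projected contributions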
class BatchLoaderHyp(BatchLoaderBase):
    """Batch loader for hypothetical contribution scores.

    Handles hypothetical contribution scores in shape (N, 4, L) where
    scores represent counterfactual effects of base substitutions.
    """

    def load_batch(
        self, start: int, end: int
    ) -> Tuple[Float[Tensor, "B 4 L"], int, Int[Tensor, " B"]]:
        inds, pad_lens = self._get_inds_and_pad_lens(start, end)

        contribs_batch = F.pad(self.contribs[start:end, :, :], pad_lens)
        contribs_batch = _to_channel_last_layout(
            contribs_batch, device=self.device, dtype=torch.float32
        )

        return contribs_batch, 1, inds
Batch loader for hypothetical contribution scores.
Handles hypothetical contribution scores in shape (N, 4, L) where scores represent counterfactual effects of base substitutions.
def fit_contribs(
    cwms: Float[ndarray, "M 4 W"],
    contribs: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]],
    sequences: Int[ndarray, "N 4 L"],
    cwm_trim_mask: Float[ndarray, "M W"],
    use_hypothetical: bool,
    lambdas: Float[ndarray, " M"],
    step_size_max: float = 3.0,
    step_size_min: float = 0.08,
    sqrt_transform: bool = False,
    convergence_tol: float = 0.0005,
    max_steps: int = 10000,
    batch_size: int = 2000,
    step_adjust: float = 0.7,
    post_filter: bool = True,
    device: Optional[torch.device] = None,
    compile_optimizer: bool = False,
    eps: float = 1.0,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """Call motif hits by fitting sparse linear model to contribution scores.

    This is the main function implementing the Fi-NeMo algorithm. It identifies
    motif instances by solving a sparse reconstruction problem where contribution
    scores are approximated as a linear combination of motif CWMs at specific
    positions. The optimization uses proximal gradient descent with momentum.

    Parameters
    ----------
    cwms : Float[ndarray, "M 4 W"]
        Motif contribution weight matrices where:
        - M = number of motifs (transcription factor binding patterns)
        - 4 = DNA bases (A, C, G, T dimensions)
        - W = motif width (length of each motif pattern)
    contribs : Float[ndarray, "N 4 L"] | Float[ndarray, "N L"]
        Neural network contribution scores where:
        - N = number of regions in dataset
        - L = sequence length
        Can be hypothetical (N, 4, L) or projected (N, L) format.
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded DNA sequences (shape: num_regions × 4_bases × L).
    cwm_trim_mask : Float[ndarray, "M W"]
        Binary mask indicating which positions of each CWM to use (shape: num_motifs × motif_width).
    use_hypothetical : bool
        Whether to use hypothetical contribution scores (True) or
        projected scores (False).
    lambdas : Float[ndarray, " M"]
        L1 regularization weights for each motif (shape: num_motifs).
    step_size_max : float, default 3.0
        Maximum optimization step size.
    step_size_min : float, default 0.08
        Minimum optimization step size (for convergence failure detection).
    sqrt_transform : bool, default False
        Whether to apply signed square root transformation to inputs.
    convergence_tol : float, default 0.0005
        Convergence tolerance based on duality gap.
    max_steps : int, default 10000
        Maximum number of optimization steps.
    batch_size : int, default 2000
        Number of regions to process simultaneously.
    step_adjust : float, default 0.7
        Factor to reduce step size when optimization diverges.
    post_filter : bool, default True
        Whether to filter hits based on similarity threshold.
    device : torch.device, optional
        Target device for computation. Auto-detected if None.
    compile_optimizer : bool, default False
        Whether to JIT compile the optimizer for speed.
    eps : float, default 1.0
        Small constant for numerical stability.

    Returns
    -------
    hits_df : pl.DataFrame
        DataFrame containing called motif hits with columns:
        - peak_id: Region index
        - motif_id: Motif index
        - hit_start: Start position of hit
        - hit_coefficient: Hit strength coefficient
        - hit_similarity: Cosine similarity with motif
        - hit_importance: Total contribution score in hit region
        - hit_importance_sq: Sum of squared contributions (for normalization)
    qc_df : pl.DataFrame
        DataFrame containing quality control metrics with columns:
        - peak_id: Region index
        - nll: Final negative log likelihood
        - dual_gap: Final duality gap
        - num_steps: Number of optimization steps
        - step_size: Final step size
        - global_scale: Region-level scaling factor

    Notes
    -----
    The algorithm solves the optimization problem:

    minimize_c: ||contribs - Σⱼ convolve(c * scale, cwms[j]) * sequences||²₂ + Σⱼ λⱼ||c[:,j]||₁

    subject to: c ≥ 0

    where c[i,j] represents the strength of motif j at position i.

    The importance scaling balances reconstruction across different
    motifs and positions based on the local contribution magnitude.

    Examples
    --------
    >>> hits_df, qc_df = fit_contribs(
    ...     cwms=motif_cwms,
    ...     contribs=contrib_scores,
    ...     sequences=onehot_seqs,
    ...     cwm_trim_mask=trim_masks,
    ...     use_hypothetical=False,
    ...     lambdas=np.array([0.7, 0.8]),
    ...     step_size_max=3.0,
    ...     step_size_min=0.08,
    ...     sqrt_transform=False,
    ...     convergence_tol=0.0005,
    ...     max_steps=10000,
    ...     batch_size=1000,
    ...     step_adjust=0.7,
    ...     post_filter=True,
    ...     device=None,
    ...     compile_optimizer=False
    ... )
    """
    M, _, W = cwms.shape
    N, _, L = sequences.shape

    B = batch_size  # Using uppercase for consistency with dimension notation

    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
            warnings.warn("No GPU available. Running on CPU.", RuntimeWarning)

    # Compile optimizer if requested
    global optimizer_step
    if compile_optimizer:
        optimizer_step = torch.compile(optimizer_step, fullgraph=True)

    # Convert inputs to PyTorch tensors with proper device placement
    cwms_tensor: torch.Tensor = torch.from_numpy(cwms)
    contribs_tensor: torch.Tensor = torch.from_numpy(contribs)
    sequences_tensor: torch.Tensor = torch.from_numpy(sequences)
    cwm_trim_mask_tensor = torch.from_numpy(cwm_trim_mask)[:, None, :].repeat(1, 4, 1)
    lambdas_tensor: torch.Tensor = torch.from_numpy(lambdas)[None, :, None].to(
        device=device, dtype=torch.float32
    )

    # Convert to channel-last layout for optimized convolution operations
    cwms_tensor = _to_channel_last_layout(
        cwms_tensor, device=device, dtype=torch.float32
    )
    cwm_trim_mask_tensor = _to_channel_last_layout(
        cwm_trim_mask_tensor, device=device, dtype=torch.float32
    )
    cwms_tensor = cwms_tensor * cwm_trim_mask_tensor  # Apply trimming mask

    if sqrt_transform:
        cwms_tensor = _signed_sqrt(cwms_tensor)
        cwm_norm = (cwms_tensor**2).sum(dim=(1, 2)).sqrt()
        cwms_tensor = cwms_tensor / cwm_norm[:, None, None]

    # Initialize batch loader
    if len(contribs_tensor.shape) == 3:
        if use_hypothetical:
            batch_loader = BatchLoaderHyp(contribs_tensor, sequences_tensor, L, device)
        else:
            batch_loader = BatchLoaderProj(contribs_tensor, sequences_tensor, L, device)
    elif len(contribs_tensor.shape) == 2:
        if use_hypothetical:
            raise ValueError(
                "Input regions do not contain hypothetical contribution scores"
            )
        else:
            batch_loader = BatchLoaderCompactFmt(
                contribs_tensor, sequences_tensor, L, device
            )
    else:
        raise ValueError(
            f"Input contributions array is of incorrect shape {contribs_tensor.shape}"
        )

    # Initialize output container objects
    hit_idxs_lst: List[ndarray] = []
    coefficients_lst: List[ndarray] = []
    similarity_lst: List[ndarray] = []
    importance_lst: List[ndarray] = []
    importance_sq_lst: List[ndarray] = []
    qc_lsts: Dict[str, List[ndarray]] = {
        "nll": [],
        "dual_gap": [],
        "num_steps": [],
        "step_size": [],
        "global_scale": [],
        "peak_id": [],
    }

    # Initialize buffers for optimizer
    coef_inter: Float[Tensor, "B M P"] = torch.zeros(
        (B, M, L - W + 1)
    )  # (B, M, P) where P = L - W + 1
    coef_inter = _to_channel_last_layout(coef_inter, device=device, dtype=torch.float32)
    coef: Float[Tensor, "B M P"] = torch.zeros_like(coef_inter)
    i: Float[Tensor, "B 1 1"] = torch.zeros((B, 1, 1), dtype=torch.int, device=device)
    step_sizes: Float[Tensor, "B 1 1"] = torch.full(
        (B, 1, 1), step_size_max, dtype=torch.float32, device=device
    )

    converged: Bool[Tensor, " B"] = torch.full(
        (B,), True, dtype=torch.bool, device=device
    )
    num_load = B

    contribs_buf: Float[Tensor, "B 4 L"] = torch.zeros((B, 4, L))
    contribs_buf = _to_channel_last_layout(
        contribs_buf, device=device, dtype=torch.float32
    )

    seqs_buf: Union[Int[Tensor, "B 4 L"], int]
    if use_hypothetical:
        seqs_buf = 1
    else:
        seqs_buf = torch.zeros((B, 4, L))
        seqs_buf = _to_channel_last_layout(seqs_buf, device=device, dtype=torch.int8)

    importance_scale_buf: Float[Tensor, "B M P"] = torch.zeros((B, M, L - W + 1))
    importance_scale_buf = _to_channel_last_layout(
        importance_scale_buf, device=device, dtype=torch.float32
    )

    inds_buf: Int[Tensor, " B"] = torch.zeros((B,), dtype=torch.int, device=device)
    global_scale_buf: Float[Tensor, " B"] = torch.zeros(
        (B,), dtype=torch.float, device=device
    )

    with tqdm(disable=None, unit="regions", total=N, ncols=120) as pbar:
        num_complete = 0
        next_ind = 0
        while num_complete < N:
            # Retire converged peaks and fill buffer with new data
            if num_load > 0:
                load_start = next_ind
                load_end = load_start + num_load
                next_ind = min(load_end, contribs_tensor.shape[0])

                batch_data = batch_loader.load_batch(int(load_start), int(load_end))
                contribs_batch, seqs_batch, inds_batch = batch_data

                if sqrt_transform:
                    contribs_batch = _signed_sqrt(contribs_batch)

                global_scale_batch = ((contribs_batch**2).sum(dim=(1, 2)) / L).sqrt()
                contribs_batch = torch.nan_to_num(
                    contribs_batch / global_scale_batch[:, None, None]
                )

                importance_scale_batch = (
                    F.conv1d(contribs_batch**2, cwm_trim_mask_tensor) + eps
                ) ** (-0.5)
                importance_scale_batch = importance_scale_batch.clamp(max=10)

                contribs_buf[converged, :, :] = contribs_batch
                if not use_hypothetical:
                    seqs_buf[converged, :, :] = seqs_batch  # type: ignore

                importance_scale_buf[converged, :, :] = importance_scale_batch

                inds_buf[converged] = inds_batch
                global_scale_buf[converged] = global_scale_batch

                coef_inter[converged, :, :] *= 0
                coef[converged, :, :] *= 0
                i[converged] *= 0

                step_sizes[converged] = step_size_max

            # Optimization step
            coef_inter, coef, gap, nll = optimizer_step(
                cwms_tensor,
                contribs_buf,
                importance_scale_buf,
                seqs_buf,
                coef_inter,
                coef,
                i,
                step_sizes,
                L,
                lambdas_tensor,
            )
            i += 1

            # Assess convergence of each peak being optimized. Reset diverged peaks with lower step size.
719 active = inds_buf >= 0 720 721 diverged = ~torch.isfinite(gap) & active 722 coef_inter[diverged, :, :] *= 0 723 coef[diverged, :, :] *= 0 724 i[diverged] *= 0 725 step_sizes[diverged, :, :] *= step_adjust 726 727 timeouts = (i > max_steps).squeeze() & active 728 if timeouts.sum().item() > 0: 729 timeout_inds = inds_buf[timeouts] 730 for ind in timeout_inds: 731 warnings.warn( 732 f"Region {ind} has not converged within max_steps={max_steps} iterations.", 733 RuntimeWarning, 734 ) 735 736 fails = (step_sizes < step_size_min).squeeze() & active 737 if fails.sum().item() > 0: 738 fail_inds = inds_buf[fails] 739 for ind in fail_inds: 740 warnings.warn(f"Optimizer failed for region {ind}.", RuntimeWarning) 741 742 converged = ((gap <= convergence_tol) | timeouts | fails) & active 743 num_load = converged.sum().item() 744 745 # Extract hits from converged peaks 746 if num_load > 0: 747 inds_out = inds_buf[converged] 748 global_scale_out = global_scale_buf[converged] 749 750 # Compute hit scores 751 coef_out = coef[converged, :, :] 752 importance_scale_out_dense = importance_scale_buf[converged, :, :] 753 importance_sq = importance_scale_out_dense ** (-2) - eps 754 xcor_scale = importance_sq.sqrt() 755 756 contribs_converged = contribs_buf[converged, :, :] 757 importance_sum_out_dense = F.conv1d( 758 torch.abs(contribs_converged), cwm_trim_mask_tensor 759 ) 760 xcov_out_dense = F.conv1d(contribs_converged, cwms_tensor) 761 # xcov_out_dense = F.conv1d(torch.abs(contribs_converged), cwms_tensor) 762 xcor_out_dense = xcov_out_dense / xcor_scale 763 764 if post_filter: 765 coef_out = coef_out * (xcor_out_dense >= lambdas_tensor) 766 767 # Extract hit coordinates using sparse tensor representation 768 coef_out = coef_out.to_sparse() 769 770 # Tensor indexing operations for hit extraction 771 hit_idxs_out = torch.clone(coef_out.indices()) # Sparse tensor indices 772 hit_idxs_out[0, :] = F.embedding( 773 hit_idxs_out[0, :], inds_out[:, None] 774 ).squeeze() # Embedding lookup with complex indexing 775 # Map buffer index to peak index 776 777 ind_tuple = torch.unbind(coef_out.indices()) 778 importance_out = importance_sum_out_dense[ind_tuple] 779 importance_sq_out = importance_sq[ind_tuple] 780 xcor_out = xcor_out_dense[ind_tuple] 781 782 scores_out_raw = coef_out.values() 783 784 # Store outputs 785 gap_out = gap[converged] 786 nll_out = nll[converged] 787 step_out = i[converged, 0, 0] 788 step_sizes_out = step_sizes[converged, 0, 0] 789 790 hit_idxs_lst.append(hit_idxs_out.numpy(force=True).T) 791 coefficients_lst.append(scores_out_raw.numpy(force=True)) 792 similarity_lst.append(xcor_out.numpy(force=True)) 793 importance_lst.append(importance_out.numpy(force=True)) 794 importance_sq_lst.append(importance_sq_out.numpy(force=True)) 795 796 qc_lsts["nll"].append(nll_out.numpy(force=True)) 797 qc_lsts["dual_gap"].append(gap_out.numpy(force=True)) 798 qc_lsts["num_steps"].append(step_out.numpy(force=True)) 799 qc_lsts["global_scale"].append(global_scale_out.numpy(force=True)) 800 qc_lsts["step_size"].append(step_sizes_out.numpy(force=True)) 801 qc_lsts["peak_id"].append(inds_out.numpy(force=True).astype(np.uint32)) 802 803 num_complete += num_load 804 pbar.update(num_load) 805 806 # Merge outputs into arrays 807 hit_idxs = np.concatenate(hit_idxs_lst, axis=0) 808 scores_coefficient = np.concatenate(coefficients_lst, axis=0) 809 scores_similarity = np.concatenate(similarity_lst, axis=0) 810 scores_importance = np.concatenate(importance_lst, axis=0) 811 scores_importance_sq = 
np.concatenate(importance_sq_lst, axis=0) 812 813 hits: Dict[str, ndarray] = { 814 "peak_id": hit_idxs[:, 0].astype(np.uint32), 815 "motif_id": hit_idxs[:, 1].astype(np.uint32), 816 "hit_start": hit_idxs[:, 2], 817 "hit_coefficient": scores_coefficient, 818 "hit_similarity": scores_similarity, 819 "hit_importance": scores_importance, 820 "hit_importance_sq": scores_importance_sq, 821 } 822 823 qc: Dict[str, ndarray] = {k: np.concatenate(v, axis=0) for k, v in qc_lsts.items()} 824 825 hits_df = pl.DataFrame(hits) 826 qc_df = pl.DataFrame(qc) 827 828 return hits_df, qc_df
Call motif hits by fitting sparse linear model to contribution scores.
This is the main function implementing the Fi-NeMo algorithm. It identifies motif instances by solving a sparse reconstruction problem where contribution scores are approximated as a linear combination of motif CWMs at specific positions. The optimization uses proximal gradient descent with momentum.
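To make the update concrete, below is a minimal, stand-alone sketch of one proximal gradient step for a non-negative lasso on a toy problem. It is illustrative only: the variable names and the plain (non-accelerated) step are assumptions, and the optimizer in this module additionally uses momentum and importance scaling, operating on convolutions of CWMs rather than an explicit design matrix.

    import torch

    def toy_prox_step(c, X, y, lam, step):
        """One proximal gradient step for min_c ||X @ c - y||^2 + lam * ||c||_1, c >= 0."""
        grad = 2.0 * X.T @ (X @ c - y)        # gradient of the squared error
        c_new = c - step * grad               # plain gradient step
        # Proximal operator of step * lam * ||.||_1 under the c >= 0 constraint:
        # soft-threshold, then clip negative values to zero.
        return torch.clamp(c_new - step * lam, min=0.0)

    # Toy problem: one "motif" of width 3 that can start at any of P positions
    # in a length-8 signal containing a single true instance at position 2.
    L, W = 8, 3
    P = L - W + 1
    cwm = torch.tensor([1.0, 2.0, 1.0])
    X = torch.zeros(L, P)
    for p in range(P):
        X[p:p + W, p] = cwm                   # column p = motif placed at position p
    y = torch.zeros(L)
    y[2:5] = 1.5 * cwm
    c = torch.zeros(P)
    for _ in range(300):
        c = toy_prox_step(c, X, y, lam=0.5, step=0.02)
    print(c)                                  # the coefficient at index 2 should dominate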
Parameters
- cwms (Float[ndarray, "M 4 W"]):
Motif contribution weight matrices where:
- M = number of motifs (transcription factor binding patterns)
- 4 = DNA bases (A, C, G, T dimensions)
- W = motif width (length of each motif pattern)
- contribs (Float[ndarray, "N 4 L"] | Float[ndarray, "N L"]):
Neural network contribution scores where:
- N = number of regions in dataset
- L = sequence length
Can be hypothetical (N, 4, L) or projected (N, L) format.
- sequences (Int[ndarray, "N 4 L"]): One-hot encoded DNA sequences (shape: num_regions × 4_bases × L).
- cwm_trim_mask (Float[ndarray, "M W"]): Binary mask indicating which positions of each CWM to use (shape: num_motifs × motif_width).
- use_hypothetical (bool): Whether to use hypothetical contribution scores (True) or projected scores (False).
- lambdas (Float[ndarray, " M"]): L1 regularization weights for each motif (shape: num_motifs); see the toy construction sketch after this list.
- step_size_max (float, default 3.0): Maximum optimization step size.
- step_size_min (float, default 0.08): Minimum optimization step size (for convergence failure detection).
- sqrt_transform (bool, default False): Whether to apply signed square root transformation to inputs.
- convergence_tol (float, default 0.0005): Convergence tolerance based on duality gap.
- max_steps (int, default 10000): Maximum number of optimization steps.
- batch_size (int, default 2000): Number of regions to process simultaneously.
- step_adjust (float, default 0.7): Factor to reduce step size when optimization diverges.
- post_filter (bool, default True): Whether to filter hits based on similarity threshold.
- device (torch.device, optional): Target device for computation. Auto-detected if None.
- compile_optimizer (bool, default False): Whether to JIT compile the optimizer for speed.
- eps (float, default 1.0): Small constant for numerical stability.
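For concreteness, here is a toy sketch of how the array inputs above might be assembled. The uniform 0.7 penalty, the 0.3 trimming threshold, and the helper name build_trim_mask are illustrative assumptions, not part of this module; real workflows typically derive CWMs and trim masks from upstream motif discovery output.

    import numpy as np

    def build_trim_mask(cwms: np.ndarray, keep_frac: float = 0.3) -> np.ndarray:
        # Illustrative trimming rule: keep positions whose total |contribution|
        # is at least keep_frac of the motif's strongest position.
        per_pos = np.abs(cwms).sum(axis=1)                    # (M, W) magnitude per position
        thresh = keep_frac * per_pos.max(axis=1, keepdims=True)
        return (per_pos >= thresh).astype(np.float32)         # (M, W) binary mask

    rng = np.random.default_rng(0)
    cwms = rng.normal(size=(2, 4, 12)).astype(np.float32)     # two toy motifs of width 12
    cwm_trim_mask = build_trim_mask(cwms)
    lambdas = np.full(2, 0.7, dtype=np.float32)               # one L1 weight per motif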
Returns
- hits_df (pl.DataFrame):
DataFrame containing called motif hits with columns:
- peak_id: Region index
- motif_id: Motif index
- hit_start: Start position of hit
- hit_coefficient: Hit strength coefficient
- hit_similarity: Cosine similarity with motif
- hit_importance: Total contribution score in hit region
- hit_importance_sq: Sum of squared contributions (for normalization)
- qc_df (pl.DataFrame):
DataFrame containing quality control metrics with columns:
- peak_id: Region index
- nll: Final negative log likelihood
- dual_gap: Final duality gap
- num_steps: Number of optimization steps
- step_size: Final step size
- global_scale: Region-level scaling factor
Notes
The algorithm solves the optimization problem:
minimize_c: ||contribs - Σⱼ convolve(c * scale, cwms[j]) * sequences||²₂ + Σⱼ λⱼ||c[:,j]||₁
subject to: c ≥ 0
where c[i,j] represents the strength of motif j at position i.
The importance scaling balances reconstruction across different motifs and positions based on the local contribution magnitude.
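That scaling corresponds to the per-position factor computed inside fit_contribs. A stand-alone sketch of the same computation (tensor names here are illustrative; eps plays the same numerical-stability role as the eps parameter):

    import torch
    import torch.nn.functional as F

    B, L, M, W, eps = 2, 50, 3, 8, 1.0
    contribs = torch.randn(B, 4, L)                 # normalized contribution scores
    trim_mask = torch.ones(M, 4, W)                 # trimmed-CWM footprint per motif

    # Sliding-window sum of squared contributions under each motif footprint,
    # followed by an inverse square root. Positions with weak local signal get
    # a larger (but clamped) scale, balancing reconstruction across positions.
    importance_scale = (F.conv1d(contribs ** 2, trim_mask) + eps) ** (-0.5)
    importance_scale = importance_scale.clamp(max=10)   # shape (B, M, L - W + 1)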
Examples
>>> hits_df, qc_df = fit_contribs(
... cwms=motif_cwms,
... contribs=contrib_scores,
... sequences=onehot_seqs,
... cwm_trim_mask=trim_masks,
... use_hypothetical=False,
... lambdas=np.array([0.7, 0.8]),
... step_size_max=3.0,
... step_size_min=0.08,
... sqrt_transform=False,
... convergence_tol=0.0005,
... max_steps=10000,
... batch_size=1000,
... step_adjust=0.7,
... post_filter=True,
... device=None,
... compile_optimizer=False
... )
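Building on the call above, a short follow-up sketch of inspecting the outputs with polars. It assumes hits_df and qc_df are the two DataFrames returned by that call; the motif width of 12 and the 0.8 similarity cutoff are arbitrary illustrative choices.

    import polars as pl

    motif_width = 12  # width W of the CWMs used in the call above (illustrative)

    hits_clean = (
        hits_df
        .with_columns((pl.col("hit_start") + motif_width).alias("hit_end"))
        .filter(pl.col("hit_similarity") >= 0.8)    # keep higher-confidence hits
        .sort(["peak_id", "hit_start"])
    )

    # Regions whose optimization stopped at the step limit, from the QC table
    unconverged = qc_df.filter(pl.col("num_steps") >= 10000)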