finemo.hitcaller

Hit caller module implementing the Fi-NeMo motif instance calling algorithm.

This module provides the core functionality for identifying transcription factor binding motif instances in neural network contribution scores using a competitive optimization approach based on proximal gradient descent.

The main algorithm fits a sparse linear model where contribution scores are reconstructed as a weighted combination of motif contribution weight matrices (CWMs) at specific genomic positions. The sparsity constraint ensures that only the most significant motif instances are called.
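For orientation, the following is a minimal usage sketch of the module's main entry point, `fit_contribs`, on in-memory NumPy arrays. The shapes, random inputs, and lambda values are illustrative assumptions chosen for this sketch, not values prescribed by the module; see the `fit_contribs` docstring below for the full parameter list and defaults.

    # Minimal usage sketch (illustrative shapes and values; not part of the module source).
    import numpy as np
    from finemo.hitcaller import fit_contribs

    M, W = 2, 8      # number of motifs and motif width (assumed for this sketch)
    N, L = 16, 100   # number of regions and sequence length (assumed for this sketch)

    rng = np.random.default_rng(0)
    cwms = rng.normal(size=(M, 4, W)).astype(np.float32)       # motif contribution weight matrices
    contribs = rng.normal(size=(N, 4, L)).astype(np.float32)   # contribution scores, (N, 4, L) format
    bases = rng.integers(0, 4, size=(N, L))
    sequences = np.eye(4, dtype=np.int8)[bases].transpose(0, 2, 1)  # one-hot sequences, (N, 4, L)
    cwm_trim_mask = np.ones((M, W), dtype=np.float32)          # keep all motif positions

    hits_df, qc_df = fit_contribs(
        cwms=cwms,
        contribs=contribs,
        sequences=sequences,
        cwm_trim_mask=cwm_trim_mask,
        use_hypothetical=False,
        lambdas=np.array([0.7, 0.8], dtype=np.float32),
    )
    print(hits_df.head())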

  1"""Hit caller module implementing the Fi-NeMo motif instance calling algorithm.
  2
  3This module provides the core functionality for identifying transcription factor
  4binding motif instances in neural network contribution scores using a competitive
  5optimization approach based on proximal gradient descent.
  6
  7The main algorithm fits a sparse linear model where contribution scores are
  8reconstructed as a weighted combination of motif contribution weight matrices (CWMs)
  9at specific genomic positions. The sparsity constraint ensures that only the most
 10significant motif instances are called.
 11"""
 12
 13import warnings
 14from typing import Tuple, Union, Optional, Dict, List
 15from abc import ABC, abstractmethod
 16
 17import numpy as np
 18from numpy import ndarray
 19import torch
 20import torch.nn.functional as F
 21from torch import Tensor
 22import polars as pl
 23from jaxtyping import Float, Int, Bool
 24
 25from tqdm import tqdm
 26
 27
 28def prox_grad_step(
 29    coefficients: Float[Tensor, "B M P"],
 30    importance_scale: Float[Tensor, "B 1 P"],
 31    cwms: Float[Tensor, "M 4 W"],
 32    contribs: Float[Tensor, "B 4 L"],
 33    sequences: Union[Int[Tensor, "B 4 L"], int],
 34    lambdas: Float[Tensor, "1 M 1"],
 35    step_sizes: Float[Tensor, "B 1 1"],
 36) -> Tuple[Float[Tensor, "B M P"], Float[Tensor, " B"], Float[Tensor, " B"]]:
 37    """Perform a proximal gradient descent optimization step for non-negative lasso.
 38
 39    This function implements a single optimization step of the Fi-NeMo algorithm,
 40    which uses proximal gradient descent to solve a sparse reconstruction problem.
 41    The goal is to represent contribution scores as a sparse linear combination
 42    of motif contribution weight matrices (CWMs).
 43
 44    Dimension notation:
 45    - B = batch size (number of regions processed simultaneously)
 46    - M = number of motifs
 47    - L = sequence length
 48    - W = motif width (length of each motif)
 49    - P = L - W + 1 (number of valid motif positions)
 50
 51    Parameters
 52    ----------
 53    coefficients : Float[Tensor, "B M P"]
 54        Current coefficient matrix representing motif instance strengths.
 55    importance_scale : Float[Tensor, "B 1 P"]
 56        Scaling factors for importance-weighted reconstruction.
 57    cwms : Float[Tensor, "M 4 W"]
 58        Motif contribution weight matrices for all motifs.
 59        4 represents the DNA bases (A, C, G, T).
 60    contribs : Float[Tensor, "B 4 L"]
 61        Target contribution scores to reconstruct.
  62    sequences : Int[Tensor, "B 4 L"] | int
 63        One-hot encoded DNA sequences. Can be a scalar (1) for hypothetical mode.
 64    lambdas : Float[Tensor, "1 M 1"]
 65        L1 regularization weights for each motif.
 66    step_sizes : Float[Tensor, "B 1 1"]
 67        Optimization step sizes for each batch element.
 68
 69    Returns
 70    -------
 71    c_next : Float[Tensor, "B M P"]
 72        Updated coefficient matrix after the optimization step (shape: batch_size × motifs × positions).
 73    dual_gap : Float[Tensor, " B"]
 74        Duality gap for convergence assessment (shape: batch_size).
 75    nll : Float[Tensor, " B"]
 76        Negative log likelihood (proportional to MSE, shape: batch_size).
 77
 78    Notes
 79    -----
 80    The algorithm uses proximal gradient descent to solve:
 81
 82    minimize_c: ||contribs - conv_transpose(c * importance_scale, cwms) * sequences||²₂ + λ||c||₁
 83
 84    subject to: c ≥ 0
 85
 86    References
 87    ----------
 88    - Proximal gradient descent: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slide 22
 89    - Duality gap computation: https://stanford.edu/~boyd/papers/pdf/l1_ls.pdf, Section III
 90    """
 91    # Forward pass: convolution operations require specific tensor layouts
 92    coef_adj = coefficients * importance_scale
 93    pred_unmasked = F.conv_transpose1d(coef_adj, cwms)  # (B, 4, L)
 94    pred = (
 95        pred_unmasked * sequences
 96    )  # (B, 4, L), element-wise masking for projected mode
 97
  99    # Compute the negative gradient of the reconstruction loss
 99    residuals = contribs - pred  # (B, 4, L)
100    ngrad = F.conv1d(residuals, cwms) * importance_scale  # (B, M, P)
101
102    # Negative log likelihood (proportional to MSE)
103    nll = (residuals**2).sum(dim=(1, 2))  # (B)
104
105    # Compute duality gap for convergence assessment
106    dual_norm = (ngrad / lambdas).amax(dim=(1, 2))  # (B)
107    dual_scale = (torch.clamp(1 / dual_norm, max=1.0) ** 2 + 1) / 2  # (B)
108    nll_scaled = nll * dual_scale  # (B)
109
110    dual_diff = (residuals * contribs).sum(dim=(1, 2))  # (B)
111    l1_term = (torch.abs(coefficients).sum(dim=2, keepdim=True) * lambdas).sum(
112        dim=(1, 2)
113    )  # (B)
114    dual_gap = (nll_scaled - dual_diff + l1_term).abs()  # (B)
115
116    # Compute proximal gradient descent step
117    c_next = coefficients + step_sizes * (ngrad - lambdas)  # (B, M, P)
118    c_next = F.relu(c_next)  # Ensure non-negativity constraint
119
120    return c_next, dual_gap, nll
121
122
123def optimizer_step(
124    cwms: Float[Tensor, "M 4 W"],
125    contribs: Float[Tensor, "B 4 L"],
126    importance_scale: Float[Tensor, "B 1 P"],
127    sequences: Union[Int[Tensor, "B 4 L"], int],
128    coef_inter: Float[Tensor, "B M P"],
129    coef: Float[Tensor, "B M P"],
130    i: Float[Tensor, "B 1 1"],
131    step_sizes: Float[Tensor, "B 1 1"],
132    L: int,
133    lambdas: Float[Tensor, "1 M 1"],
134) -> Tuple[
135    Float[Tensor, "B M P"],
136    Float[Tensor, "B M P"],
137    Float[Tensor, " B"],
138    Float[Tensor, " B"],
139]:
140    """Perform a non-negative lasso optimizer step with Nesterov momentum.
141
142    This function combines proximal gradient descent with momentum acceleration
143    to improve convergence speed while maintaining the non-negative constraint
144    on coefficients.
145
146    Dimension notation:
147    - B = batch size (number of regions processed simultaneously)
148    - M = number of motifs
149    - L = sequence length
150    - W = motif width (length of each motif)
151    - P = L - W + 1 (number of valid motif positions)
152
153    Parameters
154    ----------
155    cwms : Float[Tensor, "M 4 W"]
156        Motif contribution weight matrices.
157    contribs : Float[Tensor, "B 4 L"]
158        Target contribution scores.
159    importance_scale : Float[Tensor, "B 1 P"]
160        Importance scaling factors.
161    sequences : Union[Int[Tensor, "B 4 L"], int]
162        One-hot encoded sequences or scalar for hypothetical mode.
163    coef_inter : Float[Tensor, "B M P"]
164        Intermediate coefficient matrix (with momentum).
165    coef : Float[Tensor, "B M P"]
166        Current coefficient matrix.
167    i : Float[Tensor, "B 1 1"]
168        Iteration counter for each batch element.
169    step_sizes : Float[Tensor, "B 1 1"]
170        Step sizes for optimization.
171    L : int
172        Sequence length for normalization.
173    lambdas : Float[Tensor, "1 M 1"]
174        Regularization parameters.
175
176    Returns
177    -------
178    coef_inter : Float[Tensor, "B M P"]
179        Updated intermediate coefficients with momentum (shape: batch_size × motifs × positions).
180    coef : Float[Tensor, "B M P"]
181        Updated coefficient matrix (shape: batch_size × motifs × positions).
182    gap : Float[Tensor, " B"]
183        Normalized duality gap (shape: batch_size).
184    nll : Float[Tensor, " B"]
185        Normalized negative log likelihood (shape: batch_size).
186
187    Notes
188    -----
189    Uses Nesterov momentum with momentum coefficient i/(i+3) for improved
190    convergence properties. The duality gap and NLL are normalized by
191    sequence length for scale-invariant convergence assessment.
192
193    References
194    ----------
195    https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slides 22, 27
196    """
197    coef_prev = coef
198
199    # Proximal gradient descent step
200    coef, gap, nll = prox_grad_step(
201        coef_inter, importance_scale, cwms, contribs, sequences, lambdas, step_sizes
202    )
203    gap = gap / L
204    nll = nll / (2 * L)
205
206    # Compute updated coefficients with Nesterov momentum
207    mom_term = i / (i + 3.0)
208    coef_inter = (1 + mom_term) * coef - mom_term * coef_prev
209
210    return coef_inter, coef, gap, nll
211
212
213def _to_channel_last_layout(tensor: Tensor, **kwargs) -> torch.Tensor:
214    """Convert tensor to channel-last memory layout for optimized convolution operations.
215
216    Parameters
217    ----------
218    tensor : torch.Tensor
219        Input tensor to convert.
220    **kwargs
221        Additional keyword arguments passed to tensor.to().
222
223    Returns
224    -------
225    torch.Tensor
226        Tensor with channel-last memory layout.
227    """
228    return (
229        tensor[:, :, :, None].to(memory_format=torch.channels_last, **kwargs).squeeze(3)
230    )
231
232
233def _signed_sqrt(x: torch.Tensor) -> torch.Tensor:
234    """Apply signed square root transformation to input tensor.
235
236    This transformation preserves the sign while applying square root to the
237    absolute value, which can help with numerical stability and gradient flow.
238
239    Parameters
240    ----------
241    x : torch.Tensor
242        Input tensor.
243
244    Returns
245    -------
246    torch.Tensor
247        Transformed tensor with same shape as input.
248    """
249    return torch.sign(x) * torch.sqrt(torch.abs(x))
250
251
252class BatchLoaderBase(ABC):
253    """Base class for loading batches of contribution scores and sequences.
254
255    This class provides common functionality for different input formats
256    including batch indexing and padding for consistent batch sizes.
257
258    Dimension notation:
259    - N = number of sequences/regions in dataset
260    - L = sequence length
261    - B = batch size (number of regions processed simultaneously)
262
263    Parameters
264    ----------
265    contribs : Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]]
266        Contribution scores array.
267    sequences : Int[Tensor, "N 4 L"]
268        One-hot encoded sequences array.
269    L : int
270        Sequence length.
271    device : torch.device
272        Target device for tensor operations.
273    """
274
275    def __init__(
276        self,
277        contribs: Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]],
278        sequences: Int[Tensor, "N 4 L"],
279        L: int,
280        device: torch.device,
281    ) -> None:
282        self.contribs = contribs
283        self.sequences = sequences
284        self.L = L
285        self.device = device
286
287    def _get_inds_and_pad_lens(
288        self, start: int, end: int
289    ) -> Tuple[Int[Tensor, " Z"], Tuple[int, ...]]:
290        """Get indices and padding lengths for batch loading.
291
292        Parameters
293        ----------
294        start : int
295            Start index for batch.
296        end : int
297            End index for batch.
298
299        Returns
300        -------
301        inds : Int[Tensor, " Z"]
302            Padded indices tensor with -1 for padding positions (shape: padded_batch_size).
303        pad_lens : tuple
304            Padding specification for F.pad (left, right, top, bottom, front, back).
305        """
306        N = end - start
307        end = min(end, self.contribs.shape[0])
308        overhang = N - (end - start)
309        pad_lens = (0, 0, 0, 0, 0, overhang)
310
311        inds = F.pad(
312            torch.arange(start, end, dtype=torch.int), (0, overhang), value=-1
313        ).to(device=self.device)
314
315        return inds, pad_lens
316
317    @abstractmethod
318    def load_batch(
319        self, start: int, end: int
320    ) -> Tuple[
321        Float[Tensor, "B 4 L"], Union[Int[Tensor, "B 4 L"], int], Int[Tensor, " B"]
322    ]:
323        """Load a batch of data.
324
325        Dimension notation:
326        - B = batch size (number of regions in this batch)
327        - L = sequence length
328
329        Parameters
330        ----------
331        start : int
332            Start index (used by subclasses).
333        end : int
334            End index (used by subclasses).
335
336        Returns
337        -------
338        contribs_batch : Float[Tensor, "B 4 L"]
339            Batch of contribution scores (shape: batch_size × 4_bases × L).
340        sequences_batch : Union[Int[Tensor, "B 4 L"], int]
341            Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L) or scalar 1 for hypothetical mode.
342        inds_batch : Int[Tensor, " B"]
343            Batch indices mapping to original sequence indices (shape: batch_size).
344
345        Notes
346        -----
347        This is an abstract method that must be implemented by subclasses.
348        Parameters are intentionally unused in the base implementation.
349        """
350        pass
351
352
353class BatchLoaderCompactFmt(BatchLoaderBase):
354    """Batch loader for compact format contribution scores.
355
356    Handles contribution scores in shape (N, L) representing projected
357    scores that need to be broadcasted to (N, 4, L) format.
358    """
359
360    def load_batch(
361        self, start: int, end: int
362    ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]:
363        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
364
365        contribs_compact = F.pad(self.contribs[start:end, None, :], pad_lens)
366        contribs_batch = _to_channel_last_layout(
367            contribs_compact, device=self.device, dtype=torch.float32
368        )
369        sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens)  # (B, 4, L)
370        sequences_batch = _to_channel_last_layout(
371            sequences_batch, device=self.device, dtype=torch.int8
372        )
373
374        contribs_batch = contribs_batch * sequences_batch  # (B, 4, L)
375
376        return contribs_batch, sequences_batch, inds
377
378
379class BatchLoaderProj(BatchLoaderBase):
380    """Batch loader for projected contribution scores.
381
382    Handles contribution scores in shape (N, 4, L) where scores are
383    element-wise multiplied by one-hot sequences to get projected contributions.
384    """
385
386    def load_batch(
387        self, start: int, end: int
388    ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]:
389        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
390
391        contribs_hyp = F.pad(self.contribs[start:end, :, :], pad_lens)
392        contribs_hyp = _to_channel_last_layout(
393            contribs_hyp, device=self.device, dtype=torch.float32
394        )
395        sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens)  # (B, 4, L)
396        sequences_batch = _to_channel_last_layout(
397            sequences_batch, device=self.device, dtype=torch.int8
398        )
399        contribs_batch = contribs_hyp * sequences_batch
400
401        return contribs_batch, sequences_batch, inds
402
403
404class BatchLoaderHyp(BatchLoaderBase):
405    """Batch loader for hypothetical contribution scores.
406
407    Handles hypothetical contribution scores in shape (N, 4, L) where
408    scores represent counterfactual effects of base substitutions.
409    """
410
411    def load_batch(
412        self, start: int, end: int
413    ) -> Tuple[Float[Tensor, "B 4 L"], int, Int[Tensor, " B"]]:
414        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
415
416        contribs_batch = F.pad(self.contribs[start:end, :, :], pad_lens)
417        contribs_batch = _to_channel_last_layout(
418            contribs_batch, device=self.device, dtype=torch.float32
419        )
420
421        return contribs_batch, 1, inds
422
423
424def fit_contribs(
425    cwms: Float[ndarray, "M 4 W"],
426    contribs: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]],
427    sequences: Int[ndarray, "N 4 L"],
428    cwm_trim_mask: Float[ndarray, "M W"],
429    use_hypothetical: bool,
430    lambdas: Float[ndarray, " M"],
431    step_size_max: float = 3.0,
432    step_size_min: float = 0.08,
433    sqrt_transform: bool = False,
434    convergence_tol: float = 0.0005,
435    max_steps: int = 10000,
436    batch_size: int = 2000,
437    step_adjust: float = 0.7,
438    post_filter: bool = True,
439    device: Optional[torch.device] = None,
440    compile_optimizer: bool = False,
441    eps: float = 1.0,
442) -> Tuple[pl.DataFrame, pl.DataFrame]:
443    """Call motif hits by fitting sparse linear model to contribution scores.
444
445    This is the main function implementing the Fi-NeMo algorithm. It identifies
446    motif instances by solving a sparse reconstruction problem where contribution
447    scores are approximated as a linear combination of motif CWMs at specific
448    positions. The optimization uses proximal gradient descent with momentum.
449
450    Parameters
451    ----------
452    cwms : Float[ndarray, "M 4 W"]
453        Motif contribution weight matrices where:
454        - M = number of motifs (transcription factor binding patterns)
455        - 4 = DNA bases (A, C, G, T dimensions)
456        - W = motif width (length of each motif pattern)
457    contribs : Float[ndarray, "N 4 L"] | Float[ndarray, "N L"]
458        Neural network contribution scores where:
459        - N = number of regions in dataset
460        - L = sequence length
461        Can be hypothetical (N, 4, L) or projected (N, L) format.
462    sequences : Int[ndarray, "N 4 L"]
463        One-hot encoded DNA sequences (shape: num_regions × 4_bases × L).
464    cwm_trim_mask : Float[ndarray, "M W"]
465        Binary mask indicating which positions of each CWM to use (shape: num_motifs × motif_width).
466    use_hypothetical : bool
467        Whether to use hypothetical contribution scores (True) or
468        projected scores (False).
469    lambdas : Float[ndarray, " M"]
470        L1 regularization weights for each motif (shape: num_motifs).
471    step_size_max : float, default 3.0
472        Maximum optimization step size.
473    step_size_min : float, default 0.08
474        Minimum optimization step size (for convergence failure detection).
475    sqrt_transform : bool, default False
476        Whether to apply signed square root transformation to inputs.
477    convergence_tol : float, default 0.0005
478        Convergence tolerance based on duality gap.
479    max_steps : int, default 10000
480        Maximum number of optimization steps.
481    batch_size : int, default 2000
482        Number of regions to process simultaneously.
483    step_adjust : float, default 0.7
484        Factor to reduce step size when optimization diverges.
485    post_filter : bool, default True
486        Whether to filter hits based on similarity threshold.
487    device : torch.device, optional
488        Target device for computation. Auto-detected if None.
489    compile_optimizer : bool, default False
490        Whether to JIT compile the optimizer for speed.
491    eps : float, default 1.0
492        Small constant for numerical stability.
493
494    Returns
495    -------
496    hits_df : pl.DataFrame
497        DataFrame containing called motif hits with columns:
498        - peak_id: Region index
499        - motif_id: Motif index
500        - hit_start: Start position of hit
501        - hit_coefficient: Hit strength coefficient
502        - hit_similarity: Cosine similarity with motif
503        - hit_importance: Total contribution score in hit region
504        - hit_importance_sq: Sum of squared contributions (for normalization)
505    qc_df : pl.DataFrame
506        DataFrame containing quality control metrics with columns:
507        - peak_id: Region index
508        - nll: Final negative log likelihood
509        - dual_gap: Final duality gap
510        - num_steps: Number of optimization steps
511        - step_size: Final step size
512        - global_scale: Region-level scaling factor
513
514    Notes
515    -----
516    The algorithm solves the optimization problem:
517
518    minimize_c: ||contribs - Σⱼ convolve(c * scale, cwms[j]) * sequences||²₂ + Σⱼ λⱼ||c[:,j]||₁
519
520    subject to: c ≥ 0
521
522    where c[i,j] represents the strength of motif j at position i.
523
524    The importance scaling balances reconstruction across different
525    motifs and positions based on the local contribution magnitude.
526
527    Examples
528    --------
529    >>> hits_df, qc_df = fit_contribs(
530    ...     cwms=motif_cwms,
531    ...     contribs=contrib_scores,
532    ...     sequences=onehot_seqs,
533    ...     cwm_trim_mask=trim_masks,
534    ...     use_hypothetical=False,
535    ...     lambdas=np.array([0.7, 0.8]),
536    ...     step_size_max=3.0,
537    ...     step_size_min=0.08,
538    ...     sqrt_transform=False,
539    ...     convergence_tol=0.0005,
540    ...     max_steps=10000,
541    ...     batch_size=1000,
542    ...     step_adjust=0.7,
543    ...     post_filter=True,
544    ...     device=None,
545    ...     compile_optimizer=False
546    ... )
547    """
548    M, _, W = cwms.shape
549    N, _, L = sequences.shape
550
551    B = batch_size  # Using uppercase for consistency with dimension notation
552
553    if device is None:
554        if torch.cuda.is_available():
555            device = torch.device("cuda")
556        else:
557            device = torch.device("cpu")
558            warnings.warn("No GPU available. Running on CPU.", RuntimeWarning)
559
560    # Compile optimizer if requested
561    global optimizer_step
562    if compile_optimizer:
563        optimizer_step = torch.compile(optimizer_step, fullgraph=True)
564
565    # Convert inputs to PyTorch tensors with proper device placement
566    cwms_tensor: torch.Tensor = torch.from_numpy(cwms)
567    contribs_tensor: torch.Tensor = torch.from_numpy(contribs)
568    sequences_tensor: torch.Tensor = torch.from_numpy(sequences)
569    cwm_trim_mask_tensor = torch.from_numpy(cwm_trim_mask)[:, None, :].repeat(1, 4, 1)
570    lambdas_tensor: torch.Tensor = torch.from_numpy(lambdas)[None, :, None].to(
571        device=device, dtype=torch.float32
572    )
573
574    # Convert to channel-last layout for optimized convolution operations
575    cwms_tensor = _to_channel_last_layout(
576        cwms_tensor, device=device, dtype=torch.float32
577    )
578    cwm_trim_mask_tensor = _to_channel_last_layout(
579        cwm_trim_mask_tensor, device=device, dtype=torch.float32
580    )
581    cwms_tensor = cwms_tensor * cwm_trim_mask_tensor  # Apply trimming mask
582
583    if sqrt_transform:
584        cwms_tensor = _signed_sqrt(cwms_tensor)
585        cwm_norm = (cwms_tensor**2).sum(dim=(1, 2)).sqrt()
586        cwms_tensor = cwms_tensor / cwm_norm[:, None, None]
587
588    # Initialize batch loader
589    if len(contribs_tensor.shape) == 3:
590        if use_hypothetical:
591            batch_loader = BatchLoaderHyp(contribs_tensor, sequences_tensor, L, device)
592        else:
593            batch_loader = BatchLoaderProj(contribs_tensor, sequences_tensor, L, device)
594    elif len(contribs_tensor.shape) == 2:
595        if use_hypothetical:
596            raise ValueError(
597                "Input regions do not contain hypothetical contribution scores"
598            )
599        else:
600            batch_loader = BatchLoaderCompactFmt(
601                contribs_tensor, sequences_tensor, L, device
602            )
603    else:
604        raise ValueError(
605            f"Input contributions array is of incorrect shape {contribs_tensor.shape}"
606        )
607
608    # Initialize output container objects
609    hit_idxs_lst: List[ndarray] = []
610    coefficients_lst: List[ndarray] = []
611    similarity_lst: List[ndarray] = []
612    importance_lst: List[ndarray] = []
613    importance_sq_lst: List[ndarray] = []
614    qc_lsts: Dict[str, List[ndarray]] = {
615        "nll": [],
616        "dual_gap": [],
617        "num_steps": [],
618        "step_size": [],
619        "global_scale": [],
620        "peak_id": [],
621    }
622
623    # Initialize buffers for optimizer
624    coef_inter: Float[Tensor, "B M P"] = torch.zeros(
625        (B, M, L - W + 1)
626    )  # (B, M, P) where P = L - W + 1
627    coef_inter = _to_channel_last_layout(coef_inter, device=device, dtype=torch.float32)
628    coef: Float[Tensor, "B M P"] = torch.zeros_like(coef_inter)
629    i: Float[Tensor, "B 1 1"] = torch.zeros((B, 1, 1), dtype=torch.int, device=device)
630    step_sizes: Float[Tensor, "B 1 1"] = torch.full(
631        (B, 1, 1), step_size_max, dtype=torch.float32, device=device
632    )
633
634    converged: Bool[Tensor, " B"] = torch.full(
635        (B,), True, dtype=torch.bool, device=device
636    )
637    num_load = B
638
639    contribs_buf: Float[Tensor, "B 4 L"] = torch.zeros((B, 4, L))
640    contribs_buf = _to_channel_last_layout(
641        contribs_buf, device=device, dtype=torch.float32
642    )
643
644    seqs_buf: Union[Int[Tensor, "B 4 L"], int]
645    if use_hypothetical:
646        seqs_buf = 1
647    else:
648        seqs_buf = torch.zeros((B, 4, L))
649        seqs_buf = _to_channel_last_layout(seqs_buf, device=device, dtype=torch.int8)
650
651    importance_scale_buf: Float[Tensor, "B M P"] = torch.zeros((B, M, L - W + 1))
652    importance_scale_buf = _to_channel_last_layout(
653        importance_scale_buf, device=device, dtype=torch.float32
654    )
655
656    inds_buf: Int[Tensor, " B"] = torch.zeros((B,), dtype=torch.int, device=device)
657    global_scale_buf: Float[Tensor, " B"] = torch.zeros(
658        (B,), dtype=torch.float, device=device
659    )
660
661    with tqdm(disable=None, unit="regions", total=N, ncols=120) as pbar:
662        num_complete = 0
663        next_ind = 0
664        while num_complete < N:
665            # Retire converged peaks and fill buffer with new data
666            if num_load > 0:
667                load_start = next_ind
668                load_end = load_start + num_load
669                next_ind = min(load_end, contribs_tensor.shape[0])
670
671                batch_data = batch_loader.load_batch(int(load_start), int(load_end))
672                contribs_batch, seqs_batch, inds_batch = batch_data
673
674                if sqrt_transform:
675                    contribs_batch = _signed_sqrt(contribs_batch)
676
677                global_scale_batch = ((contribs_batch**2).sum(dim=(1, 2)) / L).sqrt()
678                contribs_batch = torch.nan_to_num(
679                    contribs_batch / global_scale_batch[:, None, None]
680                )
681
682                importance_scale_batch = (
683                    F.conv1d(contribs_batch**2, cwm_trim_mask_tensor) + eps
684                ) ** (-0.5)
685                importance_scale_batch = importance_scale_batch.clamp(max=10)
686
687                contribs_buf[converged, :, :] = contribs_batch
688                if not use_hypothetical:
689                    seqs_buf[converged, :, :] = seqs_batch  # type: ignore
690
691                importance_scale_buf[converged, :, :] = importance_scale_batch
692
693                inds_buf[converged] = inds_batch
694                global_scale_buf[converged] = global_scale_batch
695
696                coef_inter[converged, :, :] *= 0
697                coef[converged, :, :] *= 0
698                i[converged] *= 0
699
700                step_sizes[converged] = step_size_max
701
702            # Optimization step
703            coef_inter, coef, gap, nll = optimizer_step(
704                cwms_tensor,
705                contribs_buf,
706                importance_scale_buf,
707                seqs_buf,
708                coef_inter,
709                coef,
710                i,
711                step_sizes,
712                L,
713                lambdas_tensor,
714            )
715            i += 1
716
717            # Assess convergence of each peak being optimized. Reset diverged peaks with lower step size.
718            active = inds_buf >= 0
719
720            diverged = ~torch.isfinite(gap) & active
721            coef_inter[diverged, :, :] *= 0
722            coef[diverged, :, :] *= 0
723            i[diverged] *= 0
724            step_sizes[diverged, :, :] *= step_adjust
725
726            timeouts = (i > max_steps).squeeze() & active
727            if timeouts.sum().item() > 0:
728                timeout_inds = inds_buf[timeouts]
729                for ind in timeout_inds:
730                    warnings.warn(
731                        f"Region {ind} has not converged within max_steps={max_steps} iterations.",
732                        RuntimeWarning,
733                    )
734
735            fails = (step_sizes < step_size_min).squeeze() & active
736            if fails.sum().item() > 0:
737                fail_inds = inds_buf[fails]
738                for ind in fail_inds:
739                    warnings.warn(f"Optimizer failed for region {ind}.", RuntimeWarning)
740
741            converged = ((gap <= convergence_tol) | timeouts | fails) & active
742            num_load = converged.sum().item()
743
744            # Extract hits from converged peaks
745            if num_load > 0:
746                inds_out = inds_buf[converged]
747                global_scale_out = global_scale_buf[converged]
748
749                # Compute hit scores
750                coef_out = coef[converged, :, :]
751                importance_scale_out_dense = importance_scale_buf[converged, :, :]
752                importance_sq = importance_scale_out_dense ** (-2) - eps
753                xcor_scale = importance_sq.sqrt()
754
755                contribs_converged = contribs_buf[converged, :, :]
756                importance_sum_out_dense = F.conv1d(
757                    torch.abs(contribs_converged), cwm_trim_mask_tensor
758                )
759                xcov_out_dense = F.conv1d(contribs_converged, cwms_tensor)
760                # xcov_out_dense = F.conv1d(torch.abs(contribs_converged), cwms_tensor)
761                xcor_out_dense = xcov_out_dense / xcor_scale
762
763                if post_filter:
764                    coef_out = coef_out * (xcor_out_dense >= lambdas_tensor)
765
766                # Extract hit coordinates using sparse tensor representation
767                coef_out = coef_out.to_sparse()
768
769                # Tensor indexing operations for hit extraction
770                hit_idxs_out = torch.clone(coef_out.indices())  # Sparse tensor indices
771                hit_idxs_out[0, :] = F.embedding(
772                    hit_idxs_out[0, :], inds_out[:, None]
773                ).squeeze()  # Embedding lookup with complex indexing
774                # Map buffer index to peak index
775
776                ind_tuple = torch.unbind(coef_out.indices())
777                importance_out = importance_sum_out_dense[ind_tuple]
778                importance_sq_out = importance_sq[ind_tuple]
779                xcor_out = xcor_out_dense[ind_tuple]
780
781                scores_out_raw = coef_out.values()
782
783                # Store outputs
784                gap_out = gap[converged]
785                nll_out = nll[converged]
786                step_out = i[converged, 0, 0]
787                step_sizes_out = step_sizes[converged, 0, 0]
788
789                hit_idxs_lst.append(hit_idxs_out.numpy(force=True).T)
790                coefficients_lst.append(scores_out_raw.numpy(force=True))
791                similarity_lst.append(xcor_out.numpy(force=True))
792                importance_lst.append(importance_out.numpy(force=True))
793                importance_sq_lst.append(importance_sq_out.numpy(force=True))
794
795                qc_lsts["nll"].append(nll_out.numpy(force=True))
796                qc_lsts["dual_gap"].append(gap_out.numpy(force=True))
797                qc_lsts["num_steps"].append(step_out.numpy(force=True))
798                qc_lsts["global_scale"].append(global_scale_out.numpy(force=True))
799                qc_lsts["step_size"].append(step_sizes_out.numpy(force=True))
800                qc_lsts["peak_id"].append(inds_out.numpy(force=True).astype(np.uint32))
801
802                num_complete += num_load
803                pbar.update(num_load)
804
805    # Merge outputs into arrays
806    hit_idxs = np.concatenate(hit_idxs_lst, axis=0)
807    scores_coefficient = np.concatenate(coefficients_lst, axis=0)
808    scores_similarity = np.concatenate(similarity_lst, axis=0)
809    scores_importance = np.concatenate(importance_lst, axis=0)
810    scores_importance_sq = np.concatenate(importance_sq_lst, axis=0)
811
812    hits: Dict[str, ndarray] = {
813        "peak_id": hit_idxs[:, 0].astype(np.uint32),
814        "motif_id": hit_idxs[:, 1].astype(np.uint32),
815        "hit_start": hit_idxs[:, 2],
816        "hit_coefficient": scores_coefficient,
817        "hit_similarity": scores_similarity,
818        "hit_importance": scores_importance,
819        "hit_importance_sq": scores_importance_sq,
820    }
821
822    qc: Dict[str, ndarray] = {k: np.concatenate(v, axis=0) for k, v in qc_lsts.items()}
823
824    hits_df = pl.DataFrame(hits)
825    qc_df = pl.DataFrame(qc)
826
827    return hits_df, qc_df
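
The sparse reconstruction described above ultimately reduces to the non-negative proximal update inside `prox_grad_step`: a gradient step on the reconstruction term, shifted down by the L1 penalty, then projected onto the non-negative orthant. The standalone sketch below replays that update on toy tensors (the coefficient, gradient, penalty, and step-size values are assumptions chosen for illustration) to show how small or negative updates are clipped to zero, which is what produces sparse hit calls.

    # Standalone sketch of the proximal update used in prox_grad_step, on toy values.
    import torch
    import torch.nn.functional as F

    coefficients = torch.tensor([[0.5, 0.0, 1.2]])  # current coefficients (one region, three positions)
    ngrad = torch.tensor([[0.3, -0.2, -1.5]])       # negative gradient of the reconstruction loss
    lam = 0.7                                       # L1 regularization weight
    step = 1.0                                      # optimization step size

    # Same update as `F.relu(coefficients + step_sizes * (ngrad - lambdas))` above.
    c_next = F.relu(coefficients + step * (ngrad - lam))
    print(c_next)  # tensor([[0.1000, 0.0000, 0.0000]])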
def prox_grad_step( coefficients: jaxtyping.Float[Tensor, 'B M P'], importance_scale: jaxtyping.Float[Tensor, 'B 1 P'], cwms: jaxtyping.Float[Tensor, 'M 4 W'], contribs: jaxtyping.Float[Tensor, 'B 4 L'], sequences: Union[jaxtyping.Int[Tensor, 'B 4 L'], int], lambdas: jaxtyping.Float[Tensor, '1 M 1'], step_sizes: jaxtyping.Float[Tensor, 'B 1 1']) -> Tuple[jaxtyping.Float[Tensor, 'B M P'], jaxtyping.Float[Tensor, 'B'], jaxtyping.Float[Tensor, 'B']]:
 29def prox_grad_step(
 30    coefficients: Float[Tensor, "B M P"],
 31    importance_scale: Float[Tensor, "B 1 P"],
 32    cwms: Float[Tensor, "M 4 W"],
 33    contribs: Float[Tensor, "B 4 L"],
 34    sequences: Union[Int[Tensor, "B 4 L"], int],
 35    lambdas: Float[Tensor, "1 M 1"],
 36    step_sizes: Float[Tensor, "B 1 1"],
 37) -> Tuple[Float[Tensor, "B M P"], Float[Tensor, " B"], Float[Tensor, " B"]]:
 38    """Perform a proximal gradient descent optimization step for non-negative lasso.
 39
 40    This function implements a single optimization step of the Fi-NeMo algorithm,
 41    which uses proximal gradient descent to solve a sparse reconstruction problem.
 42    The goal is to represent contribution scores as a sparse linear combination
 43    of motif contribution weight matrices (CWMs).
 44
 45    Dimension notation:
 46    - B = batch size (number of regions processed simultaneously)
 47    - M = number of motifs
 48    - L = sequence length
 49    - W = motif width (length of each motif)
 50    - P = L - W + 1 (number of valid motif positions)
 51
 52    Parameters
 53    ----------
 54    coefficients : Float[Tensor, "B M P"]
 55        Current coefficient matrix representing motif instance strengths.
 56    importance_scale : Float[Tensor, "B 1 P"]
 57        Scaling factors for importance-weighted reconstruction.
 58    cwms : Float[Tensor, "M 4 W"]
 59        Motif contribution weight matrices for all motifs.
 60        4 represents the DNA bases (A, C, G, T).
 61    contribs : Float[Tensor, "B 4 L"]
 62        Target contribution scores to reconstruct.
 63    sequences : Float[Tensor, "B 4 L"] | int
 64        One-hot encoded DNA sequences. Can be a scalar (1) for hypothetical mode.
 65    lambdas : Float[Tensor, "1 M 1"]
 66        L1 regularization weights for each motif.
 67    step_sizes : Float[Tensor, "B 1 1"]
 68        Optimization step sizes for each batch element.
 69
 70    Returns
 71    -------
 72    c_next : Float[Tensor, "B M P"]
 73        Updated coefficient matrix after the optimization step (shape: batch_size × motifs × positions).
 74    dual_gap : Float[Tensor, " B"]
 75        Duality gap for convergence assessment (shape: batch_size).
 76    nll : Float[Tensor, " B"]
 77        Negative log likelihood (proportional to MSE, shape: batch_size).
 78
 79    Notes
 80    -----
 81    The algorithm uses proximal gradient descent to solve:
 82
 83    minimize_c: ||contribs - conv_transpose(c * importance_scale, cwms) * sequences||²₂ + λ||c||₁
 84
 85    subject to: c ≥ 0
 86
 87    References
 88    ----------
 89    - Proximal gradient descent: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slide 22
 90    - Duality gap computation: https://stanford.edu/~boyd/papers/pdf/l1_ls.pdf, Section III
 91    """
 92    # Forward pass: convolution operations require specific tensor layouts
 93    coef_adj = coefficients * importance_scale
 94    pred_unmasked = F.conv_transpose1d(coef_adj, cwms)  # (B, 4, L)
 95    pred = (
 96        pred_unmasked * sequences
 97    )  # (B, 4, L), element-wise masking for projected mode
 98
 99    # Compute gradient * -1
100    residuals = contribs - pred  # (B, 4, L)
101    ngrad = F.conv1d(residuals, cwms) * importance_scale  # (B, M, P)
102
103    # Negative log likelihood (proportional to MSE)
104    nll = (residuals**2).sum(dim=(1, 2))  # (B)
105
106    # Compute duality gap for convergence assessment
107    dual_norm = (ngrad / lambdas).amax(dim=(1, 2))  # (B)
108    dual_scale = (torch.clamp(1 / dual_norm, max=1.0) ** 2 + 1) / 2  # (B)
109    nll_scaled = nll * dual_scale  # (B)
110
111    dual_diff = (residuals * contribs).sum(dim=(1, 2))  # (B)
112    l1_term = (torch.abs(coefficients).sum(dim=2, keepdim=True) * lambdas).sum(
113        dim=(1, 2)
114    )  # (B)
115    dual_gap = (nll_scaled - dual_diff + l1_term).abs()  # (B)
116
117    # Compute proximal gradient descent step
118    c_next = coefficients + step_sizes * (ngrad - lambdas)  # (B, M, P)
119    c_next = F.relu(c_next)  # Ensure non-negativity constraint
120
121    return c_next, dual_gap, nll

Perform a proximal gradient descent optimization step for non-negative lasso.

This function implements a single optimization step of the Fi-NeMo algorithm, which uses proximal gradient descent to solve a sparse reconstruction problem. The goal is to represent contribution scores as a sparse linear combination of motif contribution weight matrices (CWMs).

Dimension notation:

  • B = batch size (number of regions processed simultaneously)
  • M = number of motifs
  • L = sequence length
  • W = motif width (length of each motif)
  • P = L - W + 1 (number of valid motif positions)
Parameters
  • coefficients (Float[Tensor, "B M P"]): Current coefficient matrix representing motif instance strengths.
  • importance_scale (Float[Tensor, "B 1 P"]): Scaling factors for importance-weighted reconstruction.
  • cwms (Float[Tensor, "M 4 W"]): Motif contribution weight matrices for all motifs. 4 represents the DNA bases (A, C, G, T).
  • contribs (Float[Tensor, "B 4 L"]): Target contribution scores to reconstruct.
  • sequences (Float[Tensor, "B 4 L"] | int): One-hot encoded DNA sequences. Can be a scalar (1) for hypothetical mode.
  • lambdas (Float[Tensor, "1 M 1"]): L1 regularization weights for each motif.
  • step_sizes (Float[Tensor, "B 1 1"]): Optimization step sizes for each batch element.
Returns
  • c_next (Float[Tensor, "B M P"]): Updated coefficient matrix after the optimization step (shape: batch_size × motifs × positions).
  • dual_gap (Float[Tensor, " B"]): Duality gap for convergence assessment (shape: batch_size).
  • nll (Float[Tensor, " B"]): Negative log likelihood (proportional to MSE, shape: batch_size).
Notes

The algorithm uses proximal gradient descent to solve:

minimize_c: ||contribs - conv_transpose(c * importance_scale, cwms) * sequences||²₂ + λ||c||₁

subject to: c ≥ 0

References
def optimizer_step( cwms: jaxtyping.Float[Tensor, 'M 4 W'], contribs: jaxtyping.Float[Tensor, 'B 4 L'], importance_scale: jaxtyping.Float[Tensor, 'B 1 P'], sequences: Union[jaxtyping.Int[Tensor, 'B 4 L'], int], coef_inter: jaxtyping.Float[Tensor, 'B M P'], coef: jaxtyping.Float[Tensor, 'B M P'], i: jaxtyping.Float[Tensor, 'B 1 1'], step_sizes: jaxtyping.Float[Tensor, 'B 1 1'], L: int, lambdas: jaxtyping.Float[Tensor, '1 M 1']) -> Tuple[jaxtyping.Float[Tensor, 'B M P'], jaxtyping.Float[Tensor, 'B M P'], jaxtyping.Float[Tensor, 'B'], jaxtyping.Float[Tensor, 'B']]:
124def optimizer_step(
125    cwms: Float[Tensor, "M 4 W"],
126    contribs: Float[Tensor, "B 4 L"],
127    importance_scale: Float[Tensor, "B 1 P"],
128    sequences: Union[Int[Tensor, "B 4 L"], int],
129    coef_inter: Float[Tensor, "B M P"],
130    coef: Float[Tensor, "B M P"],
131    i: Float[Tensor, "B 1 1"],
132    step_sizes: Float[Tensor, "B 1 1"],
133    L: int,
134    lambdas: Float[Tensor, "1 M 1"],
135) -> Tuple[
136    Float[Tensor, "B M P"],
137    Float[Tensor, "B M P"],
138    Float[Tensor, " B"],
139    Float[Tensor, " B"],
140]:
141    """Perform a non-negative lasso optimizer step with Nesterov momentum.
142
143    This function combines proximal gradient descent with momentum acceleration
144    to improve convergence speed while maintaining the non-negative constraint
145    on coefficients.
146
147    Dimension notation:
148    - B = batch size (number of regions processed simultaneously)
149    - M = number of motifs
150    - L = sequence length
151    - W = motif width (length of each motif)
152    - P = L - W + 1 (number of valid motif positions)
153
154    Parameters
155    ----------
156    cwms : Float[Tensor, "M 4 W"]
157        Motif contribution weight matrices.
158    contribs : Float[Tensor, "B 4 L"]
159        Target contribution scores.
160    importance_scale : Float[Tensor, "B 1 P"]
161        Importance scaling factors.
162    sequences : Union[Int[Tensor, "B 4 L"], int]
163        One-hot encoded sequences or scalar for hypothetical mode.
164    coef_inter : Float[Tensor, "B M P"]
165        Intermediate coefficient matrix (with momentum).
166    coef : Float[Tensor, "B M P"]
167        Current coefficient matrix.
168    i : Float[Tensor, "B 1 1"]
169        Iteration counter for each batch element.
170    step_sizes :  Float[Tensor, "B 1 1"]
171        Step sizes for optimization.
172    L : int
173        Sequence length for normalization.
174    lambdas : Float[Tensor, "1 M 1"]
175        Regularization parameters.
176
177    Returns
178    -------
179    coef_inter : Float[Tensor, "B M P"]
180        Updated intermediate coefficients with momentum (shape: batch_size × motifs × positions).
181    coef : Float[Tensor, "B M P"]
182        Updated coefficient matrix (shape: batch_size × motifs × positions).
183    gap : Float[Tensor, " B"]
184        Normalized duality gap (shape: batch_size).
185    nll : Float[Tensor, " B"]
186        Normalized negative log likelihood (shape: batch_size).
187
188    Notes
189    -----
190    Uses Nesterov momentum with momentum coefficient i/(i+3) for improved
191    convergence properties. The duality gap and NLL are normalized by
192    sequence length for scale-invariant convergence assessment.
193
194    References
195    ----------
196    https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slides 22, 27
197    """
198    coef_prev = coef
199
200    # Proximal gradient descent step
201    coef, gap, nll = prox_grad_step(
202        coef_inter, importance_scale, cwms, contribs, sequences, lambdas, step_sizes
203    )
204    gap = gap / L
205    nll = nll / (2 * L)
206
207    # Compute updated coefficients with Nesterov momentum
208    mom_term = i / (i + 3.0)
209    coef_inter = (1 + mom_term) * coef - mom_term * coef_prev
210
211    return coef_inter, coef, gap, nll

Perform a non-negative lasso optimizer step with Nesterov momentum.

This function combines proximal gradient descent with momentum acceleration to improve convergence speed while maintaining the non-negative constraint on coefficients.

Dimension notation:

  • B = batch size (number of regions processed simultaneously)
  • M = number of motifs
  • L = sequence length
  • W = motif width (length of each motif)
  • P = L - W + 1 (number of valid motif positions)
Parameters
  • cwms (Float[Tensor, "M 4 W"]): Motif contribution weight matrices.
  • contribs (Float[Tensor, "B 4 L"]): Target contribution scores.
  • importance_scale (Float[Tensor, "B 1 P"]): Importance scaling factors.
  • sequences (Union[Int[Tensor, "B 4 L"], int]): One-hot encoded sequences or scalar for hypothetical mode.
  • coef_inter (Float[Tensor, "B M P"]): Intermediate coefficient matrix (with momentum).
  • coef (Float[Tensor, "B M P"]): Current coefficient matrix.
  • i (Float[Tensor, "B 1 1"]): Iteration counter for each batch element.
  • step_sizes (Float[Tensor, "B 1 1"]): Step sizes for optimization.
  • L (int): Sequence length for normalization.
  • lambdas (Float[Tensor, "1 M 1"]): Regularization parameters.
Returns
  • coef_inter (Float[Tensor, "B M P"]): Updated intermediate coefficients with momentum (shape: batch_size × motifs × positions).
  • coef (Float[Tensor, "B M P"]): Updated coefficient matrix (shape: batch_size × motifs × positions).
  • gap (Float[Tensor, " B"]): Normalized duality gap (shape: batch_size).
  • nll (Float[Tensor, " B"]): Normalized negative log likelihood (shape: batch_size).
Notes

Uses Nesterov momentum with momentum coefficient i/(i+3) for improved convergence properties. The duality gap and NLL are normalized by sequence length for scale-invariant convergence assessment.

References

https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slides 22, 27

class BatchLoaderBase(abc.ABC):
253class BatchLoaderBase(ABC):
254    """Base class for loading batches of contribution scores and sequences.
255
256    This class provides common functionality for different input formats
257    including batch indexing and padding for consistent batch sizes.
258
259    Dimension notation:
260    - N = number of sequences/regions in dataset
261    - L = sequence length
262    - B = batch size (number of regions processed simultaneously)
263
264    Parameters
265    ----------
266    contribs : Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]]
267        Contribution scores array.
268    sequences : Int[Tensor, "N 4 L"]
269        One-hot encoded sequences array.
270    L : int
271        Sequence length.
272    device : torch.device
273        Target device for tensor operations.
274    """
275
276    def __init__(
277        self,
278        contribs: Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]],
279        sequences: Int[Tensor, "N 4 L"],
280        L: int,
281        device: torch.device,
282    ) -> None:
283        self.contribs = contribs
284        self.sequences = sequences
285        self.L = L
286        self.device = device
287
288    def _get_inds_and_pad_lens(
289        self, start: int, end: int
290    ) -> Tuple[Int[Tensor, " Z"], Tuple[int, ...]]:
291        """Get indices and padding lengths for batch loading.
292
293        Parameters
294        ----------
295        start : int
296            Start index for batch.
297        end : int
298            End index for batch.
299
300        Returns
301        -------
302        inds : Int[Tensor, " Z"]
303            Padded indices tensor with -1 for padding positions (shape: padded_batch_size).
304        pad_lens : tuple
305            Padding specification for F.pad (left, right, top, bottom, front, back).
306        """
307        N = end - start
308        end = min(end, self.contribs.shape[0])
309        overhang = N - (end - start)
310        pad_lens = (0, 0, 0, 0, 0, overhang)
311
312        inds = F.pad(
313            torch.arange(start, end, dtype=torch.int), (0, overhang), value=-1
314        ).to(device=self.device)
315
316        return inds, pad_lens
317
318    @abstractmethod
319    def load_batch(
320        self, start: int, end: int
321    ) -> Tuple[
322        Float[Tensor, "B 4 L"], Union[Int[Tensor, "B 4 L"], int], Int[Tensor, " B"]
323    ]:
324        """Load a batch of data.
325
326        Dimension notation:
327        - B = batch size (number of regions in this batch)
328        - L = sequence length
329
330        Parameters
331        ----------
332        start : int
333            Start index (used by subclasses).
334        end : int
335            End index (used by subclasses).
336
337        Returns
338        -------
339        contribs_batch : Float[Tensor, "B 4 L"]
340            Batch of contribution scores (shape: batch_size × 4_bases × L).
341        sequences_batch : Union[Int[Tensor, "B 4 L"], int]
342            Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L) or scalar 1 for hypothetical mode.
343        inds_batch : Int[Tensor, " B"]
344            Batch indices mapping to original sequence indices (shape: batch_size).
345
346        Notes
347        -----
348        This is an abstract method that must be implemented by subclasses.
349        Parameters are intentionally unused in the base implementation.
350        """
351        pass

Base class for loading batches of contribution scores and sequences.

This class provides common functionality for different input formats including batch indexing and padding for consistent batch sizes.

Dimension notation:

  • N = number of sequences/regions in dataset
  • L = sequence length
  • B = batch size (number of regions processed simultaneously)
Parameters
  • contribs (Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]]): Contribution scores array.
  • sequences (Int[Tensor, "N 4 L"]): One-hot encoded sequences array.
  • L (int): Sequence length.
  • device (torch.device): Target device for tensor operations.
contribs
sequences
L
device
@abstractmethod
def load_batch( self, start: int, end: int) -> Tuple[jaxtyping.Float[Tensor, 'B 4 L'], Union[jaxtyping.Int[Tensor, 'B 4 L'], int], jaxtyping.Int[Tensor, 'B']]:
318    @abstractmethod
319    def load_batch(
320        self, start: int, end: int
321    ) -> Tuple[
322        Float[Tensor, "B 4 L"], Union[Int[Tensor, "B 4 L"], int], Int[Tensor, " B"]
323    ]:
324        """Load a batch of data.
325
326        Dimension notation:
327        - B = batch size (number of regions in this batch)
328        - L = sequence length
329
330        Parameters
331        ----------
332        start : int
333            Start index (used by subclasses).
334        end : int
335            End index (used by subclasses).
336
337        Returns
338        -------
339        contribs_batch : Float[Tensor, "B 4 L"]
340            Batch of contribution scores (shape: batch_size × 4_bases × L).
341        sequences_batch : Union[Int[Tensor, "B 4 L"], int]
342            Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L) or scalar 1 for hypothetical mode.
343        inds_batch : Int[Tensor, " B"]
344            Batch indices mapping to original sequence indices (shape: batch_size).
345
346        Notes
347        -----
348        This is an abstract method that must be implemented by subclasses.
349        Parameters are intentionally unused in the base implementation.
350        """
351        pass

Load a batch of data.

Dimension notation:

  • B = batch size (number of regions in this batch)
  • L = sequence length
Parameters
  • start (int): Start index (used by subclasses).
  • end (int): End index (used by subclasses).
Returns
  • contribs_batch (Float[Tensor, "B 4 L"]): Batch of contribution scores (shape: batch_size × 4_bases × L).
  • sequences_batch (Union[Int[Tensor, "B 4 L"], int]): Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L) or scalar 1 for hypothetical mode.
  • inds_batch (Int[Tensor, " B"]): Batch indices mapping to original sequence indices (shape: batch_size).
Notes

This is an abstract method that must be implemented by subclasses. Parameters are intentionally unused in the base implementation.

class BatchLoaderCompactFmt(BatchLoaderBase):
354class BatchLoaderCompactFmt(BatchLoaderBase):
355    """Batch loader for compact format contribution scores.
356
357    Handles contribution scores in shape (N, L) representing projected
358    scores that need to be broadcasted to (N, 4, L) format.
359    """
360
361    def load_batch(
362        self, start: int, end: int
363    ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]:
364        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
365
366        contribs_compact = F.pad(self.contribs[start:end, None, :], pad_lens)
367        contribs_batch = _to_channel_last_layout(
368            contribs_compact, device=self.device, dtype=torch.float32
369        )
370        sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens)  # (B, 4, L)
371        sequences_batch = _to_channel_last_layout(
372            sequences_batch, device=self.device, dtype=torch.int8
373        )
374
375        contribs_batch = contribs_batch * sequences_batch  # (B, 4, L)
376
377        return contribs_batch, sequences_batch, inds

Batch loader for compact format contribution scores.

Handles contribution scores in shape (N, L) representing projected scores that need to be broadcasted to (N, 4, L) format.

def load_batch( self, start: int, end: int) -> Tuple[jaxtyping.Float[Tensor, 'B 4 L'], jaxtyping.Int[Tensor, 'B 4 L'], jaxtyping.Int[Tensor, 'B']]:
361    def load_batch(
362        self, start: int, end: int
363    ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]:
364        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
365
366        contribs_compact = F.pad(self.contribs[start:end, None, :], pad_lens)
367        contribs_batch = _to_channel_last_layout(
368            contribs_compact, device=self.device, dtype=torch.float32
369        )
370        sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens)  # (B, 4, L)
371        sequences_batch = _to_channel_last_layout(
372            sequences_batch, device=self.device, dtype=torch.int8
373        )
374
375        contribs_batch = contribs_batch * sequences_batch  # (B, 4, L)
376
377        return contribs_batch, sequences_batch, inds

Load a batch of data.

Dimension notation:

  • B = batch size (number of regions in this batch)
  • L = sequence length
Parameters
  • start (int): Start index (used by subclasses).
  • end (int): End index (used by subclasses).
Returns
  • contribs_batch (Float[Tensor, "B 4 L"]): Batch of contribution scores (shape: batch_size × 4_bases × L).
  • sequences_batch (Union[Int[Tensor, "B 4 L"], int]): Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L) or scalar 1 for hypothetical mode.
  • inds_batch (Int[Tensor, " B"]): Batch indices mapping to original sequence indices (shape: batch_size).
Notes

This is an abstract method that must be implemented by subclasses. Parameters are intentionally unused in the base implementation.

class BatchLoaderProj(BatchLoaderBase):
380class BatchLoaderProj(BatchLoaderBase):
381    """Batch loader for projected contribution scores.
382
383    Handles contribution scores in shape (N, 4, L) where scores are
384    element-wise multiplied by one-hot sequences to get projected contributions.
385    """
386
387    def load_batch(
388        self, start: int, end: int
389    ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]:
390        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
391
392        contribs_hyp = F.pad(self.contribs[start:end, :, :], pad_lens)
393        contribs_hyp = _to_channel_last_layout(
394            contribs_hyp, device=self.device, dtype=torch.float32
395        )
396        sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens)  # (B, 4, L)
397        sequences_batch = _to_channel_last_layout(
398            sequences_batch, device=self.device, dtype=torch.int8
399        )
400        contribs_batch = contribs_hyp * sequences_batch
401
402        return contribs_batch, sequences_batch, inds

Batch loader for projected contribution scores.

Handles contribution scores in shape (N, 4, L) where scores are element-wise multiplied by one-hot sequences to get projected contributions.

def load_batch( self, start: int, end: int) -> Tuple[jaxtyping.Float[Tensor, 'B 4 L'], jaxtyping.Int[Tensor, 'B 4 L'], jaxtyping.Int[Tensor, 'B']]:
387    def load_batch(
388        self, start: int, end: int
389    ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]:
390        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
391
392        contribs_hyp = F.pad(self.contribs[start:end, :, :], pad_lens)
393        contribs_hyp = _to_channel_last_layout(
394            contribs_hyp, device=self.device, dtype=torch.float32
395        )
396        sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens)  # (B, 4, L)
397        sequences_batch = _to_channel_last_layout(
398            sequences_batch, device=self.device, dtype=torch.int8
399        )
400        contribs_batch = contribs_hyp * sequences_batch
401
402        return contribs_batch, sequences_batch, inds

Load a batch of data.

Dimension notation:

  • B = batch size (number of regions in this batch)
  • L = sequence length
Parameters
  • start (int): Start index of the batch slice (inclusive).
  • end (int): End index of the batch slice (exclusive).
Returns
  • contribs_batch (Float[Tensor, "B 4 L"]): Batch of projected contribution scores (shape: batch_size × 4_bases × L).
  • sequences_batch (Int[Tensor, "B 4 L"]): Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L).
  • inds_batch (Int[Tensor, " B"]): Batch indices mapping to original sequence indices (shape: batch_size).
Notes

Concrete implementation of the base class's abstract load_batch. Hypothetical-format scores of shape (B, 4, L) are padded and element-wise multiplied by the one-hot sequences to yield projected contributions.

class BatchLoaderHyp(BatchLoaderBase):
405class BatchLoaderHyp(BatchLoaderBase):
406    """Batch loader for hypothetical contribution scores.
407
408    Handles hypothetical contribution scores in shape (N, 4, L) where
409    scores represent counterfactual effects of base substitutions.
410    """
411
412    def load_batch(
413        self, start: int, end: int
414    ) -> Tuple[Float[Tensor, "B 4 L"], int, Int[Tensor, " B"]]:
415        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
416
417        contribs_batch = F.pad(self.contribs[start:end, :, :], pad_lens)
418        contribs_batch = _to_channel_last_layout(
419            contribs_batch, device=self.device, dtype=torch.float32
420        )
421
422        return contribs_batch, 1, inds

Batch loader for hypothetical contribution scores.

Handles hypothetical contribution scores in shape (N, 4, L) where scores represent counterfactual effects of base substitutions.
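
Downstream, the reconstruction is multiplied by whatever load_batch returns in place of the sequences, so returning the scalar 1 here leaves all four base channels unmasked. A toy contrast (illustrative tensors only, not the module's code path):

    import torch

    recon = torch.randn(1, 4, 6)                       # toy reconstruction (B, 4, L)
    onehot = torch.zeros(1, 4, 6, dtype=torch.int8)
    onehot[:, 2, :] = 1                                # pretend every observed base is "G"

    masked = recon * onehot                            # projected/compact modes: only observed bases compared
    unmasked = recon * 1                               # hypothetical mode: all base channels compared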

def load_batch(self, start: int, end: int) -> Tuple[jaxtyping.Float[Tensor, 'B 4 L'], int, jaxtyping.Int[Tensor, 'B']]:
412    def load_batch(
413        self, start: int, end: int
414    ) -> Tuple[Float[Tensor, "B 4 L"], int, Int[Tensor, " B"]]:
415        inds, pad_lens = self._get_inds_and_pad_lens(start, end)
416
417        contribs_batch = F.pad(self.contribs[start:end, :, :], pad_lens)
418        contribs_batch = _to_channel_last_layout(
419            contribs_batch, device=self.device, dtype=torch.float32
420        )
421
422        return contribs_batch, 1, inds

Load a batch of data.

Dimension notation:

  • B = batch size (number of regions in this batch)
  • L = sequence length
Parameters
  • start (int): Start index of the batch slice (inclusive).
  • end (int): End index of the batch slice (exclusive).
Returns
  • contribs_batch (Float[Tensor, "B 4 L"]): Batch of hypothetical contribution scores (shape: batch_size × 4_bases × L).
  • sequences_batch (int): The scalar 1, signalling hypothetical mode (no one-hot sequence masking downstream).
  • inds_batch (Int[Tensor, " B"]): Batch indices mapping to original sequence indices (shape: batch_size).
Notes

Concrete implementation of the base class's abstract load_batch. Hypothetical scores are padded and returned unmasked; the scalar 1 stands in for the one-hot sequences.

def fit_contribs(cwms: jaxtyping.Float[ndarray, 'M 4 W'], contribs: Union[jaxtyping.Float[ndarray, 'N 4 L'], jaxtyping.Float[ndarray, 'N L']], sequences: jaxtyping.Int[ndarray, 'N 4 L'], cwm_trim_mask: jaxtyping.Float[ndarray, 'M W'], use_hypothetical: bool, lambdas: jaxtyping.Float[ndarray, 'M'], step_size_max: float = 3.0, step_size_min: float = 0.08, sqrt_transform: bool = False, convergence_tol: float = 0.0005, max_steps: int = 10000, batch_size: int = 2000, step_adjust: float = 0.7, post_filter: bool = True, device: Optional[torch.device] = None, compile_optimizer: bool = False, eps: float = 1.0) -> Tuple[polars.dataframe.frame.DataFrame, polars.dataframe.frame.DataFrame]:
425def fit_contribs(
426    cwms: Float[ndarray, "M 4 W"],
427    contribs: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]],
428    sequences: Int[ndarray, "N 4 L"],
429    cwm_trim_mask: Float[ndarray, "M W"],
430    use_hypothetical: bool,
431    lambdas: Float[ndarray, " M"],
432    step_size_max: float = 3.0,
433    step_size_min: float = 0.08,
434    sqrt_transform: bool = False,
435    convergence_tol: float = 0.0005,
436    max_steps: int = 10000,
437    batch_size: int = 2000,
438    step_adjust: float = 0.7,
439    post_filter: bool = True,
440    device: Optional[torch.device] = None,
441    compile_optimizer: bool = False,
442    eps: float = 1.0,
443) -> Tuple[pl.DataFrame, pl.DataFrame]:
444    """Call motif hits by fitting sparse linear model to contribution scores.
445
446    This is the main function implementing the Fi-NeMo algorithm. It identifies
447    motif instances by solving a sparse reconstruction problem where contribution
448    scores are approximated as a linear combination of motif CWMs at specific
449    positions. The optimization uses proximal gradient descent with momentum.
450
451    Parameters
452    ----------
453    cwms : Float[ndarray, "M 4 W"]
454        Motif contribution weight matrices where:
455        - M = number of motifs (transcription factor binding patterns)
456        - 4 = DNA bases (A, C, G, T dimensions)
457        - W = motif width (length of each motif pattern)
458    contribs : Float[ndarray, "N 4 L"] | Float[ndarray, "N L"]
459        Neural network contribution scores where:
460        - N = number of regions in dataset
461        - L = sequence length
462        Can be hypothetical (N, 4, L) or projected (N, L) format.
463    sequences : Int[ndarray, "N 4 L"]
464        One-hot encoded DNA sequences (shape: num_regions × 4_bases × L).
465    cwm_trim_mask : Float[ndarray, "M W"]
466        Binary mask indicating which positions of each CWM to use (shape: num_motifs × motif_width).
467    use_hypothetical : bool
468        Whether to use hypothetical contribution scores (True) or
469        projected scores (False).
470    lambdas : Float[ndarray, " M"]
471        L1 regularization weights for each motif (shape: num_motifs).
472    step_size_max : float, default 3.0
473        Maximum optimization step size.
474    step_size_min : float, default 0.08
475        Minimum optimization step size (for convergence failure detection).
476    sqrt_transform : bool, default False
477        Whether to apply signed square root transformation to inputs.
478    convergence_tol : float, default 0.0005
479        Convergence tolerance based on duality gap.
480    max_steps : int, default 10000
481        Maximum number of optimization steps.
482    batch_size : int, default 2000
483        Number of regions to process simultaneously.
484    step_adjust : float, default 0.7
485        Factor to reduce step size when optimization diverges.
486    post_filter : bool, default True
487        Whether to filter hits based on similarity threshold.
488    device : torch.device, optional
489        Target device for computation. Auto-detected if None.
490    compile_optimizer : bool, default False
491        Whether to JIT compile the optimizer for speed.
492    eps : float, default 1.0
493        Small constant for numerical stability.
494
495    Returns
496    -------
497    hits_df : pl.DataFrame
498        DataFrame containing called motif hits with columns:
499        - peak_id: Region index
500        - motif_id: Motif index
501        - hit_start: Start position of hit
502        - hit_coefficient: Hit strength coefficient
503        - hit_similarity: Cosine similarity with motif
504        - hit_importance: Total contribution score in hit region
505        - hit_importance_sq: Sum of squared contributions (for normalization)
506    qc_df : pl.DataFrame
507        DataFrame containing quality control metrics with columns:
508        - peak_id: Region index
509        - nll: Final negative log likelihood
510        - dual_gap: Final duality gap
511        - num_steps: Number of optimization steps
512        - step_size: Final step size
513        - global_scale: Region-level scaling factor
514
515    Notes
516    -----
517    The algorithm solves the optimization problem:
518
519    minimize_c: ||contribs - Σⱼ convolve(c * scale, cwms[j]) * sequences||²₂ + Σⱼ λⱼ||c[:,j]||₁
520
521    subject to: c ≥ 0
522
523    where c[i,j] represents the strength of motif j at position i.
524
525    The importance scaling balances reconstruction across different
526    motifs and positions based on the local contribution magnitude.
527
528    Examples
529    --------
530    >>> hits_df, qc_df = fit_contribs(
531    ...     cwms=motif_cwms,
532    ...     contribs=contrib_scores,
533    ...     sequences=onehot_seqs,
534    ...     cwm_trim_mask=trim_masks,
535    ...     use_hypothetical=False,
536    ...     lambdas=np.array([0.7, 0.8]),
537    ...     step_size_max=3.0,
538    ...     step_size_min=0.08,
539    ...     sqrt_transform=False,
540    ...     convergence_tol=0.0005,
541    ...     max_steps=10000,
542    ...     batch_size=1000,
543    ...     step_adjust=0.7,
544    ...     post_filter=True,
545    ...     device=None,
546    ...     compile_optimizer=False
547    ... )
548    """
549    M, _, W = cwms.shape
550    N, _, L = sequences.shape
551
552    B = batch_size  # Using uppercase for consistency with dimension notation
553
554    if device is None:
555        if torch.cuda.is_available():
556            device = torch.device("cuda")
557        else:
558            device = torch.device("cpu")
559            warnings.warn("No GPU available. Running on CPU.", RuntimeWarning)
560
561    # Compile optimizer if requested
562    global optimizer_step
563    if compile_optimizer:
564        optimizer_step = torch.compile(optimizer_step, fullgraph=True)
565
566    # Convert inputs to PyTorch tensors with proper device placement
567    cwms_tensor: torch.Tensor = torch.from_numpy(cwms)
568    contribs_tensor: torch.Tensor = torch.from_numpy(contribs)
569    sequences_tensor: torch.Tensor = torch.from_numpy(sequences)
570    cwm_trim_mask_tensor = torch.from_numpy(cwm_trim_mask)[:, None, :].repeat(1, 4, 1)
571    lambdas_tensor: torch.Tensor = torch.from_numpy(lambdas)[None, :, None].to(
572        device=device, dtype=torch.float32
573    )
574
575    # Convert to channel-last layout for optimized convolution operations
576    cwms_tensor = _to_channel_last_layout(
577        cwms_tensor, device=device, dtype=torch.float32
578    )
579    cwm_trim_mask_tensor = _to_channel_last_layout(
580        cwm_trim_mask_tensor, device=device, dtype=torch.float32
581    )
582    cwms_tensor = cwms_tensor * cwm_trim_mask_tensor  # Apply trimming mask
583
584    if sqrt_transform:
585        cwms_tensor = _signed_sqrt(cwms_tensor)
586        cwm_norm = (cwms_tensor**2).sum(dim=(1, 2)).sqrt()
587        cwms_tensor = cwms_tensor / cwm_norm[:, None, None]
588
589    # Initialize batch loader
590    if len(contribs_tensor.shape) == 3:
591        if use_hypothetical:
592            batch_loader = BatchLoaderHyp(contribs_tensor, sequences_tensor, L, device)
593        else:
594            batch_loader = BatchLoaderProj(contribs_tensor, sequences_tensor, L, device)
595    elif len(contribs_tensor.shape) == 2:
596        if use_hypothetical:
597            raise ValueError(
598                "Input regions do not contain hypothetical contribution scores"
599            )
600        else:
601            batch_loader = BatchLoaderCompactFmt(
602                contribs_tensor, sequences_tensor, L, device
603            )
604    else:
605        raise ValueError(
606            f"Input contributions array is of incorrect shape {contribs_tensor.shape}"
607        )
608
609    # Initialize output container objects
610    hit_idxs_lst: List[ndarray] = []
611    coefficients_lst: List[ndarray] = []
612    similarity_lst: List[ndarray] = []
613    importance_lst: List[ndarray] = []
614    importance_sq_lst: List[ndarray] = []
615    qc_lsts: Dict[str, List[ndarray]] = {
616        "nll": [],
617        "dual_gap": [],
618        "num_steps": [],
619        "step_size": [],
620        "global_scale": [],
621        "peak_id": [],
622    }
623
624    # Initialize buffers for optimizer
625    coef_inter: Float[Tensor, "B M P"] = torch.zeros(
626        (B, M, L - W + 1)
627    )  # (B, M, P) where P = L - W + 1
628    coef_inter = _to_channel_last_layout(coef_inter, device=device, dtype=torch.float32)
629    coef: Float[Tensor, "B M P"] = torch.zeros_like(coef_inter)
630    i: Int[Tensor, "B 1 1"] = torch.zeros((B, 1, 1), dtype=torch.int, device=device)
631    step_sizes: Float[Tensor, "B 1 1"] = torch.full(
632        (B, 1, 1), step_size_max, dtype=torch.float32, device=device
633    )
634
635    converged: Bool[Tensor, " B"] = torch.full(
636        (B,), True, dtype=torch.bool, device=device
637    )
638    num_load = B
639
640    contribs_buf: Float[Tensor, "B 4 L"] = torch.zeros((B, 4, L))
641    contribs_buf = _to_channel_last_layout(
642        contribs_buf, device=device, dtype=torch.float32
643    )
644
645    seqs_buf: Union[Int[Tensor, "B 4 L"], int]
646    if use_hypothetical:
647        seqs_buf = 1
648    else:
649        seqs_buf = torch.zeros((B, 4, L))
650        seqs_buf = _to_channel_last_layout(seqs_buf, device=device, dtype=torch.int8)
651
652    importance_scale_buf: Float[Tensor, "B M P"] = torch.zeros((B, M, L - W + 1))
653    importance_scale_buf = _to_channel_last_layout(
654        importance_scale_buf, device=device, dtype=torch.float32
655    )
656
657    inds_buf: Int[Tensor, " B"] = torch.zeros((B,), dtype=torch.int, device=device)
658    global_scale_buf: Float[Tensor, " B"] = torch.zeros(
659        (B,), dtype=torch.float, device=device
660    )
661
662    with tqdm(disable=None, unit="regions", total=N, ncols=120) as pbar:
663        num_complete = 0
664        next_ind = 0
665        while num_complete < N:
666            # Retire converged peaks and fill buffer with new data
667            if num_load > 0:
668                load_start = next_ind
669                load_end = load_start + num_load
670                next_ind = min(load_end, contribs_tensor.shape[0])
671
672                batch_data = batch_loader.load_batch(int(load_start), int(load_end))
673                contribs_batch, seqs_batch, inds_batch = batch_data
674
675                if sqrt_transform:
676                    contribs_batch = _signed_sqrt(contribs_batch)
677
678                global_scale_batch = ((contribs_batch**2).sum(dim=(1, 2)) / L).sqrt()
679                contribs_batch = torch.nan_to_num(
680                    contribs_batch / global_scale_batch[:, None, None]
681                )
682
683                importance_scale_batch = (
684                    F.conv1d(contribs_batch**2, cwm_trim_mask_tensor) + eps
685                ) ** (-0.5)
686                importance_scale_batch = importance_scale_batch.clamp(max=10)
687
688                contribs_buf[converged, :, :] = contribs_batch
689                if not use_hypothetical:
690                    seqs_buf[converged, :, :] = seqs_batch  # type: ignore
691
692                importance_scale_buf[converged, :, :] = importance_scale_batch
693
694                inds_buf[converged] = inds_batch
695                global_scale_buf[converged] = global_scale_batch
696
697                coef_inter[converged, :, :] *= 0
698                coef[converged, :, :] *= 0
699                i[converged] *= 0
700
701                step_sizes[converged] = step_size_max
702
703            # Optimization step
704            coef_inter, coef, gap, nll = optimizer_step(
705                cwms_tensor,
706                contribs_buf,
707                importance_scale_buf,
708                seqs_buf,
709                coef_inter,
710                coef,
711                i,
712                step_sizes,
713                L,
714                lambdas_tensor,
715            )
716            i += 1
717
718            # Assess convergence of each peak being optimized. Reset diverged peaks with lower step size.
719            active = inds_buf >= 0
720
721            diverged = ~torch.isfinite(gap) & active
722            coef_inter[diverged, :, :] *= 0
723            coef[diverged, :, :] *= 0
724            i[diverged] *= 0
725            step_sizes[diverged, :, :] *= step_adjust
726
727            timeouts = (i > max_steps).squeeze() & active
728            if timeouts.sum().item() > 0:
729                timeout_inds = inds_buf[timeouts]
730                for ind in timeout_inds:
731                    warnings.warn(
732                        f"Region {ind} has not converged within max_steps={max_steps} iterations.",
733                        RuntimeWarning,
734                    )
735
736            fails = (step_sizes < step_size_min).squeeze() & active
737            if fails.sum().item() > 0:
738                fail_inds = inds_buf[fails]
739                for ind in fail_inds:
740                    warnings.warn(f"Optimizer failed for region {ind}.", RuntimeWarning)
741
742            converged = ((gap <= convergence_tol) | timeouts | fails) & active
743            num_load = converged.sum().item()
744
745            # Extract hits from converged peaks
746            if num_load > 0:
747                inds_out = inds_buf[converged]
748                global_scale_out = global_scale_buf[converged]
749
750                # Compute hit scores
751                coef_out = coef[converged, :, :]
752                importance_scale_out_dense = importance_scale_buf[converged, :, :]
753                importance_sq = importance_scale_out_dense ** (-2) - eps
754                xcor_scale = importance_sq.sqrt()
755
756                contribs_converged = contribs_buf[converged, :, :]
757                importance_sum_out_dense = F.conv1d(
758                    torch.abs(contribs_converged), cwm_trim_mask_tensor
759                )
760                xcov_out_dense = F.conv1d(contribs_converged, cwms_tensor)
761                # xcov_out_dense = F.conv1d(torch.abs(contribs_converged), cwms_tensor)
762                xcor_out_dense = xcov_out_dense / xcor_scale
763
764                if post_filter:
765                    coef_out = coef_out * (xcor_out_dense >= lambdas_tensor)
766
767                # Extract hit coordinates using sparse tensor representation
768                coef_out = coef_out.to_sparse()
769
770                # Tensor indexing operations for hit extraction
771                hit_idxs_out = torch.clone(coef_out.indices())  # Sparse tensor indices
772                hit_idxs_out[0, :] = F.embedding(
773                    hit_idxs_out[0, :], inds_out[:, None]
774                ).squeeze()  # Embedding lookup with complex indexing
775                # Map buffer index to peak index
776
777                ind_tuple = torch.unbind(coef_out.indices())
778                importance_out = importance_sum_out_dense[ind_tuple]
779                importance_sq_out = importance_sq[ind_tuple]
780                xcor_out = xcor_out_dense[ind_tuple]
781
782                scores_out_raw = coef_out.values()
783
784                # Store outputs
785                gap_out = gap[converged]
786                nll_out = nll[converged]
787                step_out = i[converged, 0, 0]
788                step_sizes_out = step_sizes[converged, 0, 0]
789
790                hit_idxs_lst.append(hit_idxs_out.numpy(force=True).T)
791                coefficients_lst.append(scores_out_raw.numpy(force=True))
792                similarity_lst.append(xcor_out.numpy(force=True))
793                importance_lst.append(importance_out.numpy(force=True))
794                importance_sq_lst.append(importance_sq_out.numpy(force=True))
795
796                qc_lsts["nll"].append(nll_out.numpy(force=True))
797                qc_lsts["dual_gap"].append(gap_out.numpy(force=True))
798                qc_lsts["num_steps"].append(step_out.numpy(force=True))
799                qc_lsts["global_scale"].append(global_scale_out.numpy(force=True))
800                qc_lsts["step_size"].append(step_sizes_out.numpy(force=True))
801                qc_lsts["peak_id"].append(inds_out.numpy(force=True).astype(np.uint32))
802
803                num_complete += num_load
804                pbar.update(num_load)
805
806    # Merge outputs into arrays
807    hit_idxs = np.concatenate(hit_idxs_lst, axis=0)
808    scores_coefficient = np.concatenate(coefficients_lst, axis=0)
809    scores_similarity = np.concatenate(similarity_lst, axis=0)
810    scores_importance = np.concatenate(importance_lst, axis=0)
811    scores_importance_sq = np.concatenate(importance_sq_lst, axis=0)
812
813    hits: Dict[str, ndarray] = {
814        "peak_id": hit_idxs[:, 0].astype(np.uint32),
815        "motif_id": hit_idxs[:, 1].astype(np.uint32),
816        "hit_start": hit_idxs[:, 2],
817        "hit_coefficient": scores_coefficient,
818        "hit_similarity": scores_similarity,
819        "hit_importance": scores_importance,
820        "hit_importance_sq": scores_importance_sq,
821    }
822
823    qc: Dict[str, ndarray] = {k: np.concatenate(v, axis=0) for k, v in qc_lsts.items()}
824
825    hits_df = pl.DataFrame(hits)
826    qc_df = pl.DataFrame(qc)
827
828    return hits_df, qc_df

Call motif hits by fitting sparse linear model to contribution scores.

This is the main function implementing the Fi-NeMo algorithm. It identifies motif instances by solving a sparse reconstruction problem where contribution scores are approximated as a linear combination of motif CWMs at specific positions. The optimization uses proximal gradient descent with momentum.

Parameters
  • cwms (Float[ndarray, "M 4 W"]): Motif contribution weight matrices where:
    • M = number of motifs (transcription factor binding patterns)
    • 4 = DNA bases (A, C, G, T dimensions)
    • W = motif width (length of each motif pattern)
  • contribs (Float[ndarray, "N 4 L"] | Float[ndarray, "N L"]): Neural network contribution scores, in hypothetical (N, 4, L) or projected (N, L) format, where:
    • N = number of regions in dataset
    • L = sequence length
  • sequences (Int[ndarray, "N 4 L"]): One-hot encoded DNA sequences (shape: num_regions × 4_bases × L).
  • cwm_trim_mask (Float[ndarray, "M W"]): Binary mask indicating which positions of each CWM to use (shape: num_motifs × motif_width).
  • use_hypothetical (bool): Whether to use hypothetical contribution scores (True) or projected scores (False).
  • lambdas (Float[ndarray, " M"]): L1 regularization weights for each motif (shape: num_motifs).
  • step_size_max (float, default 3.0): Maximum optimization step size.
  • step_size_min (float, default 0.08): Minimum optimization step size (for convergence failure detection).
  • sqrt_transform (bool, default False): Whether to apply signed square root transformation to inputs.
  • convergence_tol (float, default 0.0005): Convergence tolerance based on duality gap.
  • max_steps (int, default 10000): Maximum number of optimization steps.
  • batch_size (int, default 2000): Number of regions to process simultaneously.
  • step_adjust (float, default 0.7): Factor to reduce step size when optimization diverges.
  • post_filter (bool, default True): Whether to filter hits based on similarity threshold.
  • device (torch.device, optional): Target device for computation. Auto-detected if None.
  • compile_optimizer (bool, default False): Whether to JIT compile the optimizer for speed.
  • eps (float, default 1.0): Small constant for numerical stability.
Returns
  • hits_df (pl.DataFrame): DataFrame containing called motif hits with columns:
    • peak_id: Region index
    • motif_id: Motif index
    • hit_start: Start position of hit
    • hit_coefficient: Hit strength coefficient
    • hit_similarity: Cosine similarity with motif
    • hit_importance: Total contribution score in hit region
    • hit_importance_sq: Sum of squared contributions (for normalization)
  • qc_df (pl.DataFrame): DataFrame containing quality control metrics with columns:
    • peak_id: Region index
    • nll: Final negative log likelihood
    • dual_gap: Final duality gap
    • num_steps: Number of optimization steps
    • step_size: Final step size
    • global_scale: Region-level scaling factor
Notes

The algorithm solves the optimization problem:

minimize_c: ||contribs - Σⱼ convolve(c * scale, cwms[j]) * sequences||²₂ + Σⱼ λⱼ||c[:,j]||₁

subject to: c ≥ 0

where c[i,j] represents the strength of motif j at position i.

The importance scaling balances reconstruction across different motifs and positions based on the local contribution magnitude.
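
As a minimal illustration of this objective (a sketch only, not the module's internal optimizer; the function name penalized_objective and the exact broadcasting of the importance scale are assumptions), the reconstruction term can be written as a transposed convolution over the coefficient matrix:

    import torch
    import torch.nn.functional as F

    def penalized_objective(coef, scale, cwms, contribs, seqs, lambdas):
        """Squared reconstruction error plus weighted L1 penalty (coef is non-negative)."""
        # coef: (B, M, P), scale: broadcastable to coef, cwms: (M, 4, W),
        # contribs: (B, 4, L), seqs: one-hot (B, 4, L) or scalar 1, lambdas: (1, M, 1)
        recon = F.conv_transpose1d(coef * scale, cwms) * seqs   # place CWMs at coefficient positions
        sq_err = ((contribs - recon) ** 2).sum(dim=(1, 2))      # ||contribs - reconstruction||^2 per region
        l1 = (lambdas * coef).sum(dim=(1, 2))                   # sum over motifs of lambda_j * ||c[:, j]||_1
        return sq_err + l1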

Examples
>>> hits_df, qc_df = fit_contribs(
...     cwms=motif_cwms,
...     contribs=contrib_scores,
...     sequences=onehot_seqs,
...     cwm_trim_mask=trim_masks,
...     use_hypothetical=False,
...     lambdas=np.array([0.7, 0.8]),
...     step_size_max=3.0,
...     step_size_min=0.08,
...     sqrt_transform=False,
...     convergence_tol=0.0005,
...     max_steps=10000,
...     batch_size=1000,
...     step_adjust=0.7,
...     post_filter=True,
...     device=None,
...     compile_optimizer=False
... )