finemo.evaluation

Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance.

This module provides functions for:

  • Computing motif occurrence statistics and co-occurrence patterns
  • Evaluating motif discovery quality against TF-MoDISco results
  • Analyzing hit calling performance and recall metrics
  • Generating confusion matrices for seqlet-hit comparisons
"""Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance.

This module provides functions for:
- Computing motif occurrence statistics and co-occurrence patterns
- Evaluating motif discovery quality against TF-MoDISco results
- Analyzing hit calling performance and recall metrics
- Generating confusion matrices for seqlet-hit comparisons
"""

import warnings
from typing import List, Tuple, Dict, Any, Union

import numpy as np
from numpy import ndarray
import polars as pl
from jaxtyping import Float, Int


def get_motif_occurences(
    hits_df: pl.LazyFrame, motif_names: List[str]
) -> Tuple[pl.DataFrame, Int[ndarray, "M M"]]:
    """Compute motif occurrence statistics and co-occurrence matrix.

    This function analyzes motif occurrence patterns across peaks by creating
    a pivot table of hit counts and computing pairwise co-occurrence statistics.

    Parameters
    ----------
    hits_df : pl.LazyFrame
        Lazy DataFrame containing hit data with required columns:
        - peak_id : Peak identifier
        - motif_name : Name of the motif
        Additional columns are ignored.
    motif_names : List[str]
        List of motif names to include in analysis. Missing motifs
        will be added as columns with zero counts.

    Returns
    -------
    occ_df : pl.DataFrame
        DataFrame with motif occurrence counts per peak. Contains:
        - peak_id column
        - One column per motif with hit counts
        - 'total' column summing all motif counts per peak
    coocc : Int[ndarray, "M M"]
        Co-occurrence matrix where M = len(motif_names).
        Entry (i,j) indicates number of peaks containing both motif i and motif j.
        Diagonal entries show total peaks containing each motif.

    Notes
    -----
    The co-occurrence matrix is computed using binary occurrence indicators,
    so multiple hits of the same motif in a peak are treated as a single occurrence.
    """
    occ_df = (
        hits_df.collect()
        .with_columns(pl.lit(1).alias("count"))
        .pivot(
            on="motif_name", index="peak_id", values="count", aggregate_function="sum"
        )
        .fill_null(0)
    )

    missing_cols = set(motif_names) - set(occ_df.columns)
    occ_df = (
        occ_df.with_columns([pl.lit(0).alias(m) for m in missing_cols])
        .with_columns(total=pl.sum_horizontal(*motif_names))
        .sort(["peak_id"])
    )

    num_peaks = occ_df.height
    num_motifs = len(motif_names)

    occ_mat = np.zeros((num_peaks, num_motifs), dtype=np.int16)
    for i, m in enumerate(motif_names):
        occ_mat[:, i] = occ_df.get_column(m).to_numpy()

    occ_bin = (occ_mat > 0).astype(np.int32)
    coocc = occ_bin.T @ occ_bin

    return occ_df, coocc


def get_motifs(
    regions: Float[ndarray, "N 4 L"], positions_df: pl.DataFrame, motif_width: int
) -> Float[ndarray, "4 W"]:
    """Extract the mean contribution weight matrix from regions based on hit positions.

    This function extracts motif-sized windows from contribution score regions
    at positions specified by hit coordinates. It handles both forward and
    reverse complement orientations, filters out invalid positions, and
    averages the remaining windows.

    Parameters
    ----------
    regions : Float[ndarray, "N 4 L"]
        Input contribution score regions multiplied by one-hot sequences,
        OR input one-hot encoded sequences.
        Shape: (n_peaks, 4, region_width) where 4 represents DNA bases (A,C,G,T).
    positions_df : pl.DataFrame
        DataFrame containing hit positions with required columns:
        - peak_id : int, Peak index (0-based)
        - start_untrimmed : int, Start position in genomic coordinates
        - peak_region_start : int, Peak region start coordinate
        - is_revcomp : bool, Whether hit is on reverse complement strand
    motif_width : int
        Width of motifs to extract. Must be positive.

    Returns
    -------
    motifs : Float[ndarray, "4 W"]
        Mean of the extracted windows across valid hits.
        Shape: (4, motif_width)
        Invalid hits (outside region boundaries) are excluded before averaging.

    Notes
    -----
    - Start positions are converted from genomic to region-relative coordinates
    - Reverse complement hits are read in reversed position order with the
      base axis flipped (A<->T, C<->G)
    - Hits extending beyond region boundaries are excluded
    - The mean is computed across all valid hits, with warnings suppressed
      for empty slices or invalid operations; with no valid hits the result
      is all NaN

    Raises
    ------
    ValueError
        If motif_width is non-positive or positions_df lacks required columns.
    """
    # Validate inputs so that the documented ValueError is actually raised
    if motif_width <= 0:
        raise ValueError(f"motif_width must be positive, got {motif_width}")
    required_cols = {"peak_id", "start_untrimmed", "peak_region_start", "is_revcomp"}
    missing_cols = required_cols - set(positions_df.columns)
    if missing_cols:
        raise ValueError(f"positions_df is missing columns: {sorted(missing_cols)}")

    idx_df = positions_df.select(
        peak_idx=pl.col("peak_id"),
        start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"),
        is_revcomp=pl.col("is_revcomp"),
    )
    peak_idx = idx_df.get_column("peak_idx").to_numpy()
    start_idx = idx_df.get_column("start_idx").to_numpy()
    is_revcomp = idx_df.get_column("is_revcomp").to_numpy().astype(bool)

    # Filter hits that fall outside the region boundaries
    valid_mask = (start_idx >= 0) & (start_idx + motif_width <= regions.shape[2])
    peak_idx = peak_idx[valid_mask]
    start_idx = start_idx[valid_mask]
    is_revcomp = is_revcomp[valid_mask]

    # Build broadcastable index arrays; reverse-complement hits read positions
    # right-to-left and bases in T,G,C,A order
    row_idx = peak_idx[:, None, None]
    pos_idx = start_idx[:, None, None] + np.zeros((1, 1, motif_width), dtype=int)
    pos_idx[~is_revcomp, :, :] += np.arange(motif_width)[None, None, :]
    pos_idx[is_revcomp, :, :] += np.arange(motif_width)[None, None, ::-1]
    nuc_idx = np.zeros((peak_idx.shape[0], 4, 1), dtype=int)
    nuc_idx[~is_revcomp, :, :] += np.arange(4)[None, :, None]
    nuc_idx[is_revcomp, :, :] += np.arange(4)[None, ::-1, None]

    seqs = regions[row_idx, nuc_idx, pos_idx]

    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore", message="invalid value encountered in divide"
        )
        warnings.filterwarnings(action="ignore", message="Mean of empty slice")
        motifs = seqs.mean(axis=0)

    return motifs


def tfmodisco_comparison(
    regions: Float[ndarray, "N 4 L"],
    sequences: Int[ndarray, "N 4 L"],
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None],
    motifs_df: pl.DataFrame,
    cwms_modisco: Float[ndarray, "M 4 W"],
    motif_names: List[str],
    modisco_half_width: int,
    motif_width: int,
    compute_recall: bool,
) -> Tuple[
    Dict[str, Dict[str, Any]],
    pl.DataFrame,
    Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    Dict[str, Dict[str, Tuple[int, int]]],
]:
    """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.

    This function compares Fi-NeMo hit calls with TF-MoDISco seqlets,
    computing recall metrics and CWM similarities, and extracting
    contribution weight matrices for visualization.

    Parameters
    ----------
    regions : Float[ndarray, "N 4 L"]
        Contribution score regions multiplied by one-hot sequences.
        Shape: (n_peaks, 4, region_length)
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded sequences corresponding to regions.
        Shape: (n_peaks, 4, region_length)
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    peaks_df : pl.DataFrame
        Peak metadata with columns:
        - peak_id, chr_id, peak_region_start
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame, None]
        TF-MoDISco seqlets with columns:
        - chr_id, start_untrimmed, is_revcomp, motif_name
        - peak_id, peak_region_start (needed when compute_recall is True,
          because CWMs are then extracted at seqlet positions)
        If None, only basic hit statistics are computed.
    motifs_df : pl.DataFrame
        Motif metadata with columns:
        - motif_name, strand, motif_id, motif_start, motif_end
    cwms_modisco : Float[ndarray, "M 4 W"]
        TF-MoDISco contribution weight matrices.
        Shape: (n_modisco_motifs, 4, motif_width)
    motif_names : List[str]
        Names of motifs to analyze.
    modisco_half_width : int
        Half-width for restricting hits to central region for fair comparison.
    motif_width : int
        Width of motifs for CWM extraction.
    compute_recall : bool
        Whether to compute recall metrics requiring seqlets_df.

    Returns
    -------
    report_data : Dict[str, Dict[str, Any]]
        Per-motif evaluation metrics including:
        - num_hits_total, num_hits_restricted, cwm_similarity
        - num_seqlets (only when seqlets_df is given)
        - num_overlaps, num_seqlets_only, num_hits_restricted_only,
          seqlet_recall (only when recall is computed)
    report_df : pl.DataFrame
        Tabular format of report_data for easy analysis.
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Extracted CWMs for each motif and condition:
        - hits_fc, hits_rc: Forward/reverse complement hits
        - hits_ppm_fc, hits_ppm_rc: Forward/reverse complement base
          frequencies at hits, computed from the one-hot sequences
        - modisco_fc, modisco_rc: TF-MoDISco forward/reverse
        - seqlets_only, hits_restricted_only: Non-overlapping instances
          (only when recall is computed)
    cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        Trimming boundaries for each CWM type and motif.

    Notes
    -----
    - Restricted hits are those lying within modisco_half_width of the region center
    - CWM similarity is the cosine similarity (normalized dot product) between
      the mean hit CWM and the forward TF-MoDISco CWM
    - Recall metrics are computed only when compute_recall is True and
      seqlets_df is not None; seqlet_recall is 0.0 for motifs without seqlets
    - Motifs without hits or seqlets are handled with empty DataFrames, but
      every name in motif_names needs a "+" and a "-" row in motifs_df

    Raises
    ------
    polars.exceptions.ColumnNotFoundError
        If required columns are missing from input DataFrames.
    polars.exceptions.NoRowsReturnedError
        If a name in motif_names has no matching row in motifs_df.
    """

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_df = (
        hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32))
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .select(
            chr_id=pl.col("chr_id"),
            start_untrimmed=pl.col("start_untrimmed"),
            end_untrimmed=pl.col("end_untrimmed"),
            is_revcomp=pl.col("strand") == "-",
            motif_name=pl.col("motif_name"),
            peak_region_start=pl.col("peak_region_start"),
            peak_id=pl.col("peak_id"),
        )
    )

    hits_unique = hits_df.unique(
        subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]
    )

    # Keep only hits that lie entirely within the central window
    region_len = regions.shape[2]
    center = region_len / 2
    hits_filtered = hits_df.filter(
        (
            (pl.col("start_untrimmed") - pl.col("peak_region_start"))
            >= (center - modisco_half_width)
        )
        & (
            (pl.col("end_untrimmed") - pl.col("peak_region_start"))
            <= (center + modisco_half_width)
        )
    ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])

    hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True)
    hits_filtered_by_motif = hits_filtered.collect().partition_by(
        "motif_name", as_dict=True
    )

    if seqlets_df is None:
        seqlets_collected = None
        seqlets_lazy = None
    elif isinstance(seqlets_df, pl.LazyFrame):
        seqlets_collected = seqlets_df.collect()
        seqlets_lazy = seqlets_df
    else:
        seqlets_collected = seqlets_df
        seqlets_lazy = seqlets_df.lazy()

    if seqlets_collected is not None:
        seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True)
    else:
        seqlets_by_motif = {}

    if compute_recall and seqlets_lazy is not None:
        # Restricted hits that coincide exactly with a seqlet
        overlaps_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="inner",
        ).collect()

        # Seqlets with no matching hit anywhere
        seqlets_only_df = seqlets_lazy.join(
            hits_df,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        # Restricted hits with no matching seqlet
        hits_only_filtered_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        # Create partition dictionaries
        overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True)
        seqlets_only_by_motif = seqlets_only_df.partition_by("motif_name", as_dict=True)
        hits_only_filtered_by_motif = hits_only_filtered_df.partition_by(
            "motif_name", as_dict=True
        )
    else:
        overlaps_by_motif = {}
        seqlets_only_by_motif = {}
        hits_only_filtered_by_motif = {}

    report_data = {}
    motifs = {}
    cwm_trim_bounds = {}
    dummy_df = hits_df.clear().collect()
    for m in motif_names:
        hits = hits_by_motif.get((m,), dummy_df)
        # Named separately so the hits_filtered LazyFrame above is not shadowed
        hits_restricted = hits_filtered_by_motif.get((m,), dummy_df)

        # Initialize default values
        seqlets = dummy_df
        overlaps = dummy_df
        seqlets_only = dummy_df
        hits_only_filtered = dummy_df

        if seqlets_df is not None:
            seqlets = seqlets_by_motif.get((m,), dummy_df)

        if compute_recall and seqlets_df is not None:
            overlaps = overlaps_by_motif.get((m,), dummy_df)
            seqlets_only = seqlets_only_by_motif.get((m,), dummy_df)
            hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df)

        report_data[m] = {
            "num_hits_total": hits.height,
            "num_hits_restricted": hits_restricted.height,
        }

        if seqlets_df is not None:
            report_data[m]["num_seqlets"] = seqlets.height

        if compute_recall and seqlets_df is not None:
            report_data[m] |= {
                "num_overlaps": overlaps.height,
                "num_seqlets_only": seqlets_only.height,
                "num_hits_restricted_only": hits_only_filtered.height,
                "seqlet_recall": np.float64(overlaps.height) / seqlets.height
                if seqlets.height > 0
                else 0.0,
            }

        motif_data_fc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"),
            named=True,
        )
        motif_data_rc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"),
            named=True,
        )

        motifs[m] = {
            "hits_fc": get_motifs(regions, hits, motif_width),
            "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]],
            "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]],
        }
        motifs[m]["hits_rc"] = motifs[m]["hits_fc"][::-1, ::-1]
        motifs[m]["hits_ppm_fc"] = get_motifs(sequences, hits, motif_width)
        motifs[m]["hits_ppm_rc"] = motifs[m]["hits_ppm_fc"][::-1, ::-1]

        if compute_recall and seqlets_df is not None:
            motifs[m] |= {
                "seqlets_only": get_motifs(regions, seqlets_only, motif_width),
                "hits_restricted_only": get_motifs(
                    regions, hits_only_filtered, motif_width
                ),
            }

        bounds_fc = (motif_data_fc["motif_start"], motif_data_fc["motif_end"])
        bounds_rc = (motif_data_rc["motif_start"], motif_data_rc["motif_end"])

        cwm_trim_bounds[m] = {
            "hits_fc": bounds_fc,
            "modisco_fc": bounds_fc,
            "modisco_rc": bounds_rc,
            "hits_rc": bounds_rc,
            "hits_ppm_fc": bounds_fc,
            "hits_ppm_rc": bounds_rc,
        }

        if compute_recall and seqlets_df is not None:
            cwm_trim_bounds[m] |= {
                "seqlets_only": bounds_fc,
                "hits_restricted_only": bounds_fc,
            }

        # Cosine similarity between the mean hit CWM and the TF-MoDISco CWM
        hits_cwm = motifs[m]["hits_fc"]
        modisco_cwm = motifs[m]["modisco_fc"]
        hnorm = np.sqrt((hits_cwm**2).sum())
        snorm = np.sqrt((modisco_cwm**2).sum())
        cwm_sim = (hits_cwm * modisco_cwm).sum() / (hnorm * snorm)

        report_data[m]["cwm_similarity"] = cwm_sim

    records = [{"motif_name": k} | v for k, v in report_data.items()]
    report_df = pl.from_dicts(records)

    return report_data, report_df, motifs, cwm_trim_bounds


def seqlet_confusion(
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    motif_names: List[str],
    motif_width: int,
) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]:
    """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.

    This function creates a confusion matrix showing the overlap between
    TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs.
    Overlap frequencies are estimated using binned genomic coordinates.

    Parameters
    ----------
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame]
        TF-MoDISco seqlets with required columns:
        - chr_id, start_untrimmed, end_untrimmed, motif_name
    peaks_df : pl.DataFrame
        Peak metadata for joining coordinates:
        - peak_id, chr_id
    motif_names : List[str]
        Names of motifs to include in confusion matrix.
        Determines matrix dimensions.
    motif_width : int
        Width used for binning genomic coordinates.
        Positions are binned to motif_width resolution.

    Returns
    -------
    confusion_df : pl.DataFrame
        Detailed confusion matrix in tabular format with columns:
        - motif_name_seqlets : Seqlet motif labels (rows)
        - motif_name_hits : Hit motif labels (columns)
        - frac_overlap : Fraction of seqlets overlapping with hits
    confusion_mat : Float[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits.
        Rows represent seqlet motifs, columns represent hit motifs.

    Notes
    -----
    - Genomic coordinates are binned to motif_width resolution for overlap detection
    - Only exact bin overlaps are considered (same chr_id, start_bin, end_bin)
    - Fractions are computed as: overlaps / total_seqlets_per_motif
    - Missing motif combinations result in zero entries in the confusion matrix

    Raises
    ------
    polars.exceptions.ColumnNotFoundError
        If required columns are missing from input DataFrames.
    polars.exceptions.InvalidOperationError
        If an overlapping seqlet or hit has a motif name that is not in
        motif_names (raised by replace_strict).
    """
    bin_size = motif_width

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_binned = (
        hits_df.with_columns(
            peak_id=pl.col("peak_id").cast(pl.UInt32),
            is_revcomp=pl.col("strand") == "-",
        )
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
        .select(
            chr_id=pl.col("chr_id"),
            start_bin=pl.col("start_untrimmed") // bin_size,
            end_bin=pl.col("end_untrimmed") // bin_size,
            motif_name=pl.col("motif_name"),
        )
    )

    seqlets_lazy = seqlets_df.lazy()
    seqlets_binned = seqlets_lazy.select(
        chr_id=pl.col("chr_id"),
        start_bin=pl.col("start_untrimmed") // bin_size,
        end_bin=pl.col("end_untrimmed") // bin_size,
        motif_name=pl.col("motif_name"),
    )

    # Pair each seqlet with every hit that falls in the same bins
    overlaps_df = seqlets_binned.join(
        hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits"
    )

    seqlet_counts = (
        seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect()
    )
    overlap_counts = (
        overlaps_df.group_by(["motif_name", "motif_name_hits"])
        .len(name="num_overlaps")
        .collect()
    )

    num_motifs = len(motif_names)
    confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32)
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    confusion_df = overlap_counts.join(
        seqlet_counts, on="motif_name", how="inner"
    ).select(
        motif_name_seqlets=pl.col("motif_name"),
        motif_name_hits=pl.col("motif_name_hits"),
        frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"),
    )

    # Map motif names to matrix indices and scatter the fractions
    confusion_idx_df = confusion_df.select(
        row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx),
        col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx),
        frac_overlap=pl.col("frac_overlap"),
    )

    row_idx = confusion_idx_df["row_idx"].to_numpy()
    col_idx = confusion_idx_df["col_idx"].to_numpy()
    frac_overlap = confusion_idx_df["frac_overlap"].to_numpy()

    confusion_mat[row_idx, col_idx] = frac_overlap

    return confusion_df, confusion_mat
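
As a rough illustration of the binned overlap that seqlet_confusion relies on, the sketch below reproduces the idea in plain Python rather than Polars. The records, the bin_key helper, and the bin size are all hypothetical values chosen for the example; this is not the module's implementation.

```python
from collections import Counter

bin_size = 30  # stands in for motif_width (assumed value)
motif_names = ["A", "B"]

# Hypothetical (chr_id, start, end, motif_name) records.
seqlets = [
    ("chr1", 100, 130, "A"),
    ("chr1", 200, 230, "A"),
    ("chr1", 400, 430, "B"),
]
hits = [
    ("chr1", 95, 125, "A"),
    ("chr1", 205, 235, "B"),
    ("chr1", 390, 420, "B"),
]


def bin_key(chrom, start, end):
    """Illustrative helper: reduce an interval to (chrom, start_bin, end_bin)."""
    return chrom, start // bin_size, end // bin_size


# Index hit motif names by their binned coordinates.
hits_by_bin = {}
for chrom, start, end, name in hits:
    hits_by_bin.setdefault(bin_key(chrom, start, end), []).append(name)

# Count seqlets per motif and (seqlet motif, hit motif) pairs sharing a bin.
seqlet_counts = Counter(name for *_, name in seqlets)
overlap_counts = Counter()
for chrom, start, end, name in seqlets:
    for hit_name in hits_by_bin.get(bin_key(chrom, start, end), []):
        overlap_counts[(name, hit_name)] += 1

# Rows are seqlet motifs, columns are hit motifs.
confusion = [
    [overlap_counts[(s, h)] / seqlet_counts[s] for h in motif_names]
    for s in motif_names
]
print(confusion)
```

Because matching requires identical start and end bins, a hit that overlaps a seqlet but straddles a bin boundary differently is not counted, which is why the resulting fractions are estimates.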
def get_motif_occurences(hits_df: polars.lazyframe.frame.LazyFrame, motif_names: List[str]) -> Tuple[polars.dataframe.frame.DataFrame, jaxtyping.Int[ndarray, 'M M']]:

Compute motif occurrence statistics and co-occurrence matrix.

This function analyzes motif occurrence patterns across peaks by creating a pivot table of hit counts and computing pairwise co-occurrence statistics.

Parameters
  • hits_df (pl.LazyFrame): Lazy DataFrame containing hit data with required columns:
    • peak_id : Peak identifier
    • motif_name : Name of the motif
    Additional columns are ignored.
  • motif_names (List[str]): List of motif names to include in analysis. Missing motifs will be added as columns with zero counts.
Returns
  • occ_df (pl.DataFrame): DataFrame with motif occurrence counts per peak. Contains:
    • peak_id column
    • One column per motif with hit counts
    • 'total' column summing all motif counts per peak
  • coocc (Int[ndarray, "M M"]): Co-occurrence matrix where M = len(motif_names). Entry (i,j) indicates number of peaks containing both motif i and motif j. Diagonal entries show total peaks containing each motif.
Notes

The co-occurrence matrix is computed using binary occurrence indicators, so multiple hits of the same motif in a peak are treated as a single occurrence.
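
The effect of the binary indicators can be seen in a small NumPy sketch. The count matrix below is made up for illustration; only the last two steps mirror what the function does.

```python
import numpy as np

# Hypothetical per-peak hit counts: rows are peaks, columns are motifs A, B, C.
occ_mat = np.array(
    [
        [2, 0, 1],  # peak 0: two hits of A, one hit of C
        [1, 1, 0],  # peak 1: one hit each of A and B
        [0, 3, 0],  # peak 2: three hits of B
    ]
)

# Collapse counts to presence/absence so repeated hits count once per peak.
occ_bin = (occ_mat > 0).astype(np.int32)

# Entry (i, j) is the number of peaks containing both motif i and motif j;
# the diagonal is the number of peaks containing each motif.
coocc = occ_bin.T @ occ_bin
print(coocc)
```

Here motif A occurs in two peaks even though it has three hits in total, so the diagonal entry for A is 2.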

def get_motifs(regions: jaxtyping.Float[ndarray, 'N 4 L'], positions_df: polars.dataframe.frame.DataFrame, motif_width: int) -> jaxtyping.Float[ndarray, '4 W']:

Extract contribution weight matrices from regions based on hit positions.

This function extracts motif-sized windows from contribution score regions at positions specified by hit coordinates. It handles both forward and reverse complement orientations and filters out invalid positions.

Parameters
  • regions (Float[ndarray, "N 4 L"]): Input contribution score regions multiplied by one-hot sequences, OR Input one-hot encoded sequences. Shape: (n_peaks, 4, region_width) where 4 represents DNA bases (A,C,G,T).
  • positions_df (pl.DataFrame): DataFrame containing hit positions with required columns:
    • peak_id : int, Peak index (0-based)
    • start_untrimmed : int, Start position in genomic coordinates
    • peak_region_start : int, Peak region start coordinate
    • is_revcomp : bool, Whether hit is on reverse complement strand
  • motif_width (int): Width of motifs to extract. Must be positive.
Returns
  • motifs (Float[ndarray, "4 W"]): Mean of the extracted windows across valid hits. Shape: (4, motif_width). Invalid hits (outside region boundaries) are excluded before averaging.
Notes
  • Start positions are converted from genomic to region-relative coordinates
  • Reverse complement hits are read in reversed position order with the base axis flipped (A<->T, C<->G)
  • Hits extending beyond region boundaries are excluded
  • The mean is computed across all valid hits, with warnings suppressed for empty slices or invalid operations; with no valid hits the result is all NaN
Raises
  • ValueError: If motif_width is non-positive or positions_df lacks required columns.
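
To make the orientation handling concrete, here is a minimal NumPy sketch of extracting one forward and one reverse-complement window and averaging them. The extract_window helper and the toy region are hypothetical; the function itself does this for all hits at once with index arrays.

```python
import numpy as np


def extract_window(region, start, width, is_revcomp):
    """Illustrative helper: return a (4, width) window in forward orientation."""
    window = region[:, start : start + width]
    # Reverse complement: flip the base axis (A<->T, C<->G) and the positions.
    return window[::-1, ::-1] if is_revcomp else window


# One-hot encoding of the sequence ACGTT, rows ordered A, C, G, T.
region = np.array(
    [
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1],
    ],
    dtype=float,
)

fwd = extract_window(region, 0, 3, is_revcomp=False)  # reads ACG
rev = extract_window(region, 2, 3, is_revcomp=True)  # window GTT reads AAC

# Averaging over hits yields a single (4, width) matrix.
mean_cwm = np.stack([fwd, rev]).mean(axis=0)
print(mean_cwm)
```

Flipping both axes works because the rows are ordered A, C, G, T, so reversing the row order swaps each base with its complement.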
def tfmodisco_comparison(regions: jaxtyping.Float[ndarray, 'N 4 L'], sequences: jaxtyping.Int[ndarray, 'N 4 L'], hits_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame], peaks_df: polars.dataframe.frame.DataFrame, seqlets_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame, NoneType], motifs_df: polars.dataframe.frame.DataFrame, cwms_modisco: jaxtyping.Float[ndarray, 'M 4 W'], motif_names: List[str], modisco_half_width: int, motif_width: int, compute_recall: bool) -> Tuple[Dict[str, Dict[str, Any]], polars.dataframe.frame.DataFrame, Dict[str, Dict[str, jaxtyping.Float[ndarray, '4 W']]], Dict[str, Dict[str, Tuple[int, int]]]]:
def tfmodisco_comparison(
    regions: Float[ndarray, "N 4 L"],
    sequences: Int[ndarray, "N 4 L"],
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None],
    motifs_df: pl.DataFrame,
    cwms_modisco: Float[ndarray, "M 4 W"],
    motif_names: List[str],
    modisco_half_width: int,
    motif_width: int,
    compute_recall: bool,
) -> Tuple[
    Dict[str, Dict[str, Any]],
    pl.DataFrame,
    Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    Dict[str, Dict[str, Tuple[int, int]]],
]:
    """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.

    This function performs comprehensive comparison between Fi-NeMo hit calls
    and TF-MoDISco seqlets, computing recall metrics, CWM similarities,
    and extracting contribution weight matrices for visualization.

    Parameters
    ----------
    regions : Float[ndarray, "N 4 L"]
        Contribution score regions multiplied by one-hot sequences.
        Shape: (n_peaks, 4, region_length)
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded sequences corresponding to regions.
        Shape: (n_peaks, 4, region_length)
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    peaks_df : pl.DataFrame
        Peak metadata with columns:
        - peak_id, chr_id, peak_region_start
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame, None]
        TF-MoDISco seqlets with columns:
        - chr_id, start_untrimmed, is_revcomp, motif_name
        If None, only basic hit statistics are computed.
    motifs_df : pl.DataFrame
        Motif metadata with columns:
        - motif_name, strand, motif_id, motif_start, motif_end
    cwms_modisco : Float[ndarray, "M 4 W"]
        TF-MoDISco contribution weight matrices.
        Shape: (n_modisco_motifs, 4, motif_width)
    motif_names : List[str]
        Names of motifs to analyze.
    modisco_half_width : int
        Half-width for restricting hits to central region for fair comparison.
    motif_width : int
        Width of motifs for CWM extraction.
    compute_recall : bool
        Whether to compute recall metrics requiring seqlets_df.

    Returns
    -------
    report_data : Dict[str, Dict[str, Any]]
        Per-motif evaluation metrics including:
        - num_hits_total, num_hits_restricted, cwm_similarity
        - num_seqlets (only when seqlets_df is provided)
        - num_overlaps, num_seqlets_only, num_hits_restricted_only,
          seqlet_recall (only when compute_recall is True and seqlets_df
          is provided)
    report_df : pl.DataFrame
        Tabular format of report_data, one row per motif.
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Extracted CWMs for each motif and condition:
        - hits_fc, hits_rc: Forward/reverse complement hits
        - hits_ppm_fc, hits_ppm_rc: Same hits, extracted from one-hot sequences
        - modisco_fc, modisco_rc: TF-MoDISco forward/reverse
        - seqlets_only, hits_restricted_only: Non-overlapping instances
          (only when compute_recall is True and seqlets_df is provided)
    cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        Trimming boundaries for each CWM type and motif.

    Notes
    -----
    - Restricted hits are those lying entirely within modisco_half_width of the region center
    - CWM similarity is computed as normalized dot product between hit and TF-MoDISco CWMs
    - Recall metrics are computed only when compute_recall is True and seqlets_df is not None;
      seqlet_recall is 0.0 for motifs with no seqlets
    - Motifs with no hits or seqlets fall back to empty DataFrames, giving zero counts

    Raises
    ------
    ValueError
        If required columns are missing from input DataFrames.
    """

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_df = (
        hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32))
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .select(
            chr_id=pl.col("chr_id"),
            start_untrimmed=pl.col("start_untrimmed"),
            end_untrimmed=pl.col("end_untrimmed"),
            is_revcomp=pl.col("strand") == "-",
            motif_name=pl.col("motif_name"),
            peak_region_start=pl.col("peak_region_start"),
            peak_id=pl.col("peak_id"),
        )
    )

    # Deduplicate hits by genomic position, motif, and strand
    hits_unique = hits_df.unique(
        subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]
    )

    # Keep only hits lying entirely within modisco_half_width of the region center
    region_len = regions.shape[2]
    center = region_len / 2
    hits_filtered = hits_df.filter(
        (
            (pl.col("start_untrimmed") - pl.col("peak_region_start"))
            >= (center - modisco_half_width)
        )
        & (
            (pl.col("end_untrimmed") - pl.col("peak_region_start"))
            <= (center + modisco_half_width)
        )
    ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])

    hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True)
    hits_filtered_by_motif = hits_filtered.collect().partition_by(
        "motif_name", as_dict=True
    )

    if seqlets_df is None:
        seqlets_collected = None
        seqlets_lazy = None
    elif isinstance(seqlets_df, pl.LazyFrame):
        seqlets_collected = seqlets_df.collect()
        seqlets_lazy = seqlets_df
    else:
        seqlets_collected = seqlets_df
        seqlets_lazy = seqlets_df.lazy()

    if seqlets_collected is not None:
        seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True)
    else:
        seqlets_by_motif = {}

    if compute_recall and seqlets_lazy is not None:
        # Restricted hits whose coordinates, strand, and motif exactly match a seqlet
        overlaps_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="inner",
        ).collect()

        # Seqlets with no matching hit among all hits (not only restricted hits)
        seqlets_only_df = seqlets_lazy.join(
            hits_df,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        # Restricted hits with no matching seqlet
        hits_only_filtered_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        # Create partition dictionaries
        overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True)
        seqlets_only_by_motif = seqlets_only_df.partition_by("motif_name", as_dict=True)
        hits_only_filtered_by_motif = hits_only_filtered_df.partition_by(
            "motif_name", as_dict=True
        )
    else:
        overlaps_by_motif = {}
        seqlets_only_by_motif = {}
        hits_only_filtered_by_motif = {}

    report_data = {}
    motifs = {}
    cwm_trim_bounds = {}
    dummy_df = hits_df.clear().collect()
    for m in motif_names:
        hits = hits_by_motif.get((m,), dummy_df)
        hits_filtered = hits_filtered_by_motif.get((m,), dummy_df)

        # Initialize default values
        seqlets = dummy_df
        overlaps = dummy_df
        seqlets_only = dummy_df
        hits_only_filtered = dummy_df

        if seqlets_df is not None:
            seqlets = seqlets_by_motif.get((m,), dummy_df)

        if compute_recall and seqlets_df is not None:
            overlaps = overlaps_by_motif.get((m,), dummy_df)
            seqlets_only = seqlets_only_by_motif.get((m,), dummy_df)
            hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df)

        report_data[m] = {
            "num_hits_total": hits.height,
            "num_hits_restricted": hits_filtered.height,
        }

        if seqlets_df is not None:
            report_data[m]["num_seqlets"] = seqlets.height

        if compute_recall and seqlets_df is not None:
            report_data[m] |= {
                "num_overlaps": overlaps.height,
                "num_seqlets_only": seqlets_only.height,
                "num_hits_restricted_only": hits_only_filtered.height,
                "seqlet_recall": np.float64(overlaps.height) / seqlets.height
                if seqlets.height > 0
                else 0.0,
            }

        motif_data_fc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"),
            named=True,
        )
        motif_data_rc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"),
            named=True,
        )

        motifs[m] = {
            "hits_fc": get_motifs(regions, hits, motif_width),
            "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]],
            "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]],
        }
        motifs[m]["hits_rc"] = motifs[m]["hits_fc"][::-1, ::-1]
        motifs[m]["hits_ppm_fc"] = get_motifs(sequences, hits, motif_width)
        motifs[m]["hits_ppm_rc"] = motifs[m]["hits_ppm_fc"][::-1, ::-1]

        if compute_recall and seqlets_df is not None:
            motifs[m] |= {
                "seqlets_only": get_motifs(regions, seqlets_only, motif_width),
                "hits_restricted_only": get_motifs(
                    regions, hits_only_filtered, motif_width
                ),
            }

        bounds_fc = (motif_data_fc["motif_start"], motif_data_fc["motif_end"])
        bounds_rc = (motif_data_rc["motif_start"], motif_data_rc["motif_end"])

        cwm_trim_bounds[m] = {
            "hits_fc": bounds_fc,
            "modisco_fc": bounds_fc,
            "modisco_rc": bounds_rc,
            "hits_rc": bounds_rc,
            "hits_ppm_fc": bounds_fc,
            "hits_ppm_rc": bounds_rc,
        }

        if compute_recall and seqlets_df is not None:
            cwm_trim_bounds[m] |= {
                "seqlets_only": bounds_fc,
                "hits_restricted_only": bounds_fc,
            }

        # Cosine similarity between the hit-derived and TF-MoDISco forward CWMs
        hits_cwm = motifs[m]["hits_fc"]
        modisco_cwm = motifs[m]["modisco_fc"]
        hnorm = np.sqrt((hits_cwm**2).sum())
        snorm = np.sqrt((modisco_cwm**2).sum())
        cwm_sim = (hits_cwm * modisco_cwm).sum() / (hnorm * snorm)

        report_data[m]["cwm_similarity"] = cwm_sim

    records = [{"motif_name": k} | v for k, v in report_data.items()]
    report_df = pl.from_dicts(records)

    return report_data, report_df, motifs, cwm_trim_bounds

Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.

This function performs comprehensive comparison between Fi-NeMo hit calls and TF-MoDISco seqlets, computing recall metrics, CWM similarities, and extracting contribution weight matrices for visualization.

Parameters
  • regions (Float[ndarray, "N 4 L"]): Contribution score regions multiplied by one-hot sequences. Shape: (n_peaks, 4, region_length)
  • sequences (Int[ndarray, "N 4 L"]): One-hot encoded sequences corresponding to regions. Shape: (n_peaks, 4, region_length)
  • hits_df (Union[pl.DataFrame, pl.LazyFrame]): Fi-NeMo hit calls with required columns:
    • peak_id, start_untrimmed, end_untrimmed, strand, motif_name
  • peaks_df (pl.DataFrame): Peak metadata with columns:
    • peak_id, chr_id, peak_region_start
  • seqlets_df (Union[pl.DataFrame, pl.LazyFrame, None]): TF-MoDISco seqlets with columns:
    • chr_id, start_untrimmed, is_revcomp, motif_name
    If None, only basic hit statistics are computed.
  • motifs_df (pl.DataFrame): Motif metadata with columns:
    • motif_name, strand, motif_id, motif_start, motif_end
  • cwms_modisco (Float[ndarray, "M 4 W"]): TF-MoDISco contribution weight matrices. Shape: (n_modisco_motifs, 4, motif_width)
  • motif_names (List[str]): Names of motifs to analyze.
  • modisco_half_width (int): Half-width for restricting hits to central region for fair comparison.
  • motif_width (int): Width of motifs for CWM extraction.
  • compute_recall (bool): Whether to compute recall metrics requiring seqlets_df.
Returns
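  • report_data (Dict[str, Dict[str, Any]]): Per-motif evaluation metrics including:
    • num_hits_total, num_hits_restricted, cwm_similarity
    • num_seqlets (only when seqlets_df is provided)
    • num_overlaps, num_seqlets_only, num_hits_restricted_only, seqlet_recall (only when compute_recall is True and seqlets_df is provided)
  • report_df (pl.DataFrame): Tabular format of report_data, one row per motif.
  • cwms (Dict[str, Dict[str, Float[ndarray, "4 W"]]]): Extracted CWMs for each motif and condition:
    • hits_fc, hits_rc: Forward/reverse complement hits
    • hits_ppm_fc, hits_ppm_rc: Same hits, extracted from one-hot sequences
    • modisco_fc, modisco_rc: TF-MoDISco forward/reverse
    • seqlets_only, hits_restricted_only: Non-overlapping instances (only when compute_recall is True and seqlets_df is provided)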
  • cwm_trim_bounds (Dict[str, Dict[str, Tuple[int, int]]]): Trimming boundaries for each CWM type and motif.
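The overlap and recall counts in report_data come from exact matches on (chr_id, start_untrimmed, is_revcomp, motif_name). The set-based sketch below illustrates that bookkeeping for a single motif. It is an illustration only: the function name `recall_counts` and the tuple representation are hypothetical, and plain sets stand in for the Polars inner and anti joins used in the source.

```python
def recall_counts(hits_restricted, hits_all, seqlets):
    """Hypothetical helper: count exact-coordinate matches for one motif.

    Each argument is a set of (chr_id, start_untrimmed, is_revcomp, motif_name) keys.
    """
    overlaps = hits_restricted & seqlets  # restricted hits that match a seqlet
    seqlets_only = seqlets - hits_all  # seqlets matched by no hit at all
    hits_restricted_only = hits_restricted - seqlets  # restricted hits with no seqlet
    # Recall falls back to 0.0 when the motif has no seqlets
    recall = len(overlaps) / len(seqlets) if seqlets else 0.0
    return {
        "num_overlaps": len(overlaps),
        "num_seqlets_only": len(seqlets_only),
        "num_hits_restricted_only": len(hits_restricted_only),
        "seqlet_recall": recall,
    }


hits_all = {
    ("chr1", 100, False, "A"),
    ("chr1", 500, False, "A"),
    ("chr2", 50, True, "A"),
}
hits_restricted = {("chr1", 100, False, "A"), ("chr2", 50, True, "A")}
seqlets = {
    ("chr1", 100, False, "A"),
    ("chr1", 500, False, "A"),
    ("chr1", 900, False, "A"),
    ("chr3", 10, False, "A"),
}
counts = recall_counts(hits_restricted, hits_all, seqlets)
```

Note that seqlets_only is measured against all hits while num_overlaps uses only the restricted hits, so the seqlet at chr1:500 counts toward neither.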
Notes
  • Restricted hits are those lying entirely within modisco_half_width of the region center
  • CWM similarity is computed as normalized dot product between hit and TF-MoDISco CWMs
  • Recall metrics are computed only when compute_recall is True and seqlets_df is not None; seqlet_recall is 0.0 for motifs with no seqlets
  • Motifs with no hits or seqlets fall back to empty DataFrames, giving zero counts
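The central-region restriction in the first note can be written as a simple predicate on region-relative coordinates. The sketch below restates the filter from the source in plain Python; the function name `in_central_window` and the example region length of 2114 are hypothetical.

```python
def in_central_window(start, end, region_start, region_len, half_width):
    """Hypothetical helper: True if a hit lies entirely inside the central window."""
    center = region_len / 2
    rel_start = start - region_start  # region-relative start
    rel_end = end - region_start  # region-relative end
    return rel_start >= center - half_width and rel_end <= center + half_width


# Region of length 2114 starting at genomic position 1000; window spans offsets 857..1257
inside = in_central_window(1900, 1930, 1000, 2114, 200)
left_edge = in_central_window(1850, 1880, 1000, 2114, 200)
right_edge = in_central_window(2240, 2270, 1000, 2114, 200)
```

Both bounds are inclusive, and a hit must fit entirely inside the window to count as restricted.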
Raises
  • ValueError: If required columns are missing from input DataFrames.
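The cwm_similarity metric is the cosine similarity between two flattened (4, W) matrices. A minimal sketch of that computation follows; the function name `cwm_similarity` is used here for illustration only, and like the source it does not guard against an all-zero matrix, which would yield NaN.

```python
import numpy as np


def cwm_similarity(hits_cwm, modisco_cwm):
    """Illustrative helper: normalized dot product of two (4, W) matrices."""
    hnorm = np.sqrt((hits_cwm**2).sum())
    snorm = np.sqrt((modisco_cwm**2).sum())
    return float((hits_cwm * modisco_cwm).sum() / (hnorm * snorm))


rng = np.random.default_rng(0)
cwm = rng.normal(size=(4, 30))

same = cwm_similarity(cwm, 3.0 * cwm)  # scale-invariant
opposite = cwm_similarity(cwm, -cwm)  # sign-sensitive
```

Because the score is normalized, it reflects the shape of the CWM rather than its magnitude, which is convenient when hit-derived and TF-MoDISco CWMs are on different scales.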
def seqlet_confusion( hits_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame], seqlets_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame], peaks_df: polars.dataframe.frame.DataFrame, motif_names: List[str], motif_width: int) -> Tuple[polars.dataframe.frame.DataFrame, jaxtyping.Float[ndarray, 'M M']]:
def seqlet_confusion(
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    motif_names: List[str],
    motif_width: int,
) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]:
    """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.

    This function creates a confusion matrix showing the overlap between
    TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs.
    Overlap frequencies are estimated using binned genomic coordinates.

    Parameters
    ----------
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame]
        TF-MoDISco seqlets with required columns:
        - chr_id, start_untrimmed, end_untrimmed, motif_name
    peaks_df : pl.DataFrame
        Peak metadata for joining coordinates:
        - peak_id, chr_id
    motif_names : List[str]
        Names of motifs to include in confusion matrix.
        Determines matrix dimensions.
    motif_width : int
        Width used for binning genomic coordinates.
        Positions are binned to motif_width resolution.

    Returns
    -------
    confusion_df : pl.DataFrame
        Detailed confusion matrix in tabular format with columns:
        - motif_name_seqlets : Seqlet motif labels (rows)
        - motif_name_hits : Hit motif labels (columns)
        - frac_overlap : Fraction of seqlets overlapping with hits
    confusion_mat : Float[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits.
        Rows represent seqlet motifs, columns represent hit motifs.

    Notes
    -----
    - Genomic coordinates are binned to motif_width resolution for overlap detection
    - Only exact bin overlaps are considered (same chr_id, start_bin, end_bin)
    - Fractions are computed as: overlaps / total_seqlets_per_motif
    - Missing motif combinations result in zero entries in the confusion matrix

    Raises
    ------
    ValueError
        If required columns are missing from input DataFrames.
    KeyError
        If motif names in data don't match those in motif_names list.
    """
    bin_size = motif_width

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    # Attach chromosome IDs, deduplicate hits, and bin their coordinates
    hits_binned = (
        hits_df.with_columns(
            peak_id=pl.col("peak_id").cast(pl.UInt32),
            is_revcomp=pl.col("strand") == "-",
        )
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
        .select(
            chr_id=pl.col("chr_id"),
            start_bin=pl.col("start_untrimmed") // bin_size,
            end_bin=pl.col("end_untrimmed") // bin_size,
            motif_name=pl.col("motif_name"),
        )
    )

    seqlets_lazy = seqlets_df.lazy()
    seqlets_binned = seqlets_lazy.select(
        chr_id=pl.col("chr_id"),
        start_bin=pl.col("start_untrimmed") // bin_size,
        end_bin=pl.col("end_untrimmed") // bin_size,
        motif_name=pl.col("motif_name"),
    )

    # Pair each seqlet with every hit that falls in the same start and end bins
    overlaps_df = seqlets_binned.join(
        hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits"
    )

    seqlet_counts = (
        seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect()
    )
    overlap_counts = (
        overlaps_df.group_by(["motif_name", "motif_name_hits"])
        .len(name="num_overlaps")
        .collect()
    )

    num_motifs = len(motif_names)
    confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32)
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    confusion_df = overlap_counts.join(
        seqlet_counts, on="motif_name", how="inner"
    ).select(
        motif_name_seqlets=pl.col("motif_name"),
        motif_name_hits=pl.col("motif_name_hits"),
        frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"),
    )

    # Map motif names to matrix indices and fill in the observed fractions
    confusion_idx_df = confusion_df.select(
        row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx),
        col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx),
        frac_overlap=pl.col("frac_overlap"),
    )

    row_idx = confusion_idx_df["row_idx"].to_numpy()
    col_idx = confusion_idx_df["col_idx"].to_numpy()
    frac_overlap = confusion_idx_df["frac_overlap"].to_numpy()

    confusion_mat[row_idx, col_idx] = frac_overlap

    return confusion_df, confusion_mat

Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.

This function creates a confusion matrix showing the overlap between TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs. Overlap frequencies are estimated using binned genomic coordinates.

Parameters
  • hits_df (Union[pl.DataFrame, pl.LazyFrame]): Fi-NeMo hit calls with required columns:
    • peak_id, start_untrimmed, end_untrimmed, strand, motif_name
  • seqlets_df (Union[pl.DataFrame, pl.LazyFrame]): TF-MoDISco seqlets with required columns:
    • chr_id, start_untrimmed, end_untrimmed, motif_name
  • peaks_df (pl.DataFrame): Peak metadata for joining coordinates:
    • peak_id, chr_id
  • motif_names (List[str]): Names of motifs to include in confusion matrix. Determines matrix dimensions.
  • motif_width (int): Width used for binning genomic coordinates. Positions are binned to motif_width resolution.
Returns
  • confusion_df (pl.DataFrame): Detailed confusion matrix in tabular format with columns:
    • motif_name_seqlets : Seqlet motif labels (rows)
    • motif_name_hits : Hit motif labels (columns)
    • frac_overlap : Fraction of seqlets overlapping with hits
  • confusion_mat (Float[ndarray, "M M"]): Confusion matrix where M = len(motif_names). Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits. Rows represent seqlet motifs, columns represent hit motifs.
Notes
  • Genomic coordinates are binned to motif_width resolution for overlap detection
  • Only exact bin overlaps are considered (same chr_id, start_bin, end_bin)
  • Fractions are computed as: overlaps / total_seqlets_per_motif
  • Missing motif combinations result in zero entries in the confusion matrix
Raises
  • ValueError: If required columns are missing from input DataFrames.
  • KeyError: If motif names in data don't match those in motif_names list.
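The binned overlap scheme described in the notes can be sketched without Polars. The example below is a simplified illustration under stated assumptions: the function name `binned_confusion` and the (chr_id, start, end, motif_name) tuple layout are hypothetical, and it omits the hit deduplication step performed in the source.

```python
from collections import Counter, defaultdict

import numpy as np


def binned_confusion(seqlets, hits, motif_names, bin_size):
    """Hypothetical helper: seqlet-by-hit overlap fractions from binned coordinates."""
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    # Group hit motif labels by their (chr_id, start_bin, end_bin) key
    hits_by_bin = defaultdict(list)
    for chrom, start, end, name in hits:
        hits_by_bin[(chrom, start // bin_size, end // bin_size)].append(name)

    seqlet_counts = Counter(name for *_, name in seqlets)

    # Count every seqlet-hit pair that shares the same bin key
    overlap_counts = Counter()
    for chrom, start, end, name in seqlets:
        key = (chrom, start // bin_size, end // bin_size)
        for hit_name in hits_by_bin.get(key, []):
            overlap_counts[(name, hit_name)] += 1

    mat = np.zeros((len(motif_names), len(motif_names)), dtype=np.float32)
    for (seqlet_name, hit_name), n in overlap_counts.items():
        mat[name_to_idx[seqlet_name], name_to_idx[hit_name]] = n / seqlet_counts[seqlet_name]
    return mat


seqlets = [("chr1", 100, 130, "A"), ("chr1", 400, 430, "A"), ("chr1", 700, 730, "B")]
hits = [("chr1", 95, 125, "A"), ("chr1", 410, 440, "B"), ("chr2", 700, 730, "B")]
mat = binned_confusion(seqlets, hits, ["A", "B"], bin_size=30)
```

In this toy input, half of the motif A seqlets share a bin pair with an A hit and half with a B hit, while the B seqlet matches nothing because the only nearby hit is on a different chromosome. Since pairs are counted rather than seqlets, an entry can exceed 1 in this sketch when several hits of one motif share a bin pair with the same seqlet.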