finemo.evaluation

Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance.

This module provides functions for:

  • Computing motif occurrence statistics and co-occurrence patterns
  • Evaluating motif discovery quality against TF-MoDISco results
  • Analyzing hit calling performance and recall metrics
  • Generating confusion matrices for seqlet-hit comparisons
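The recall and CWM-similarity metrics listed above come down to simple arithmetic. The stdlib-only sketch below mirrors the logic of `tfmodisco_comparison` under simplifying assumptions. Hits and seqlets are plain sets of hypothetical `(chr_id, start, is_revcomp, motif_name)` keys rather than polars frames, and CWMs are flattened lists. The helper names `in_central_window`, `seqlet_recall` and `cosine_similarity` are illustrative and are not part of the module. Unlike the NumPy version, a zero-norm CWM raises `ZeroDivisionError` here instead of producing NaN.

```python
import math
from typing import List, Set, Tuple

# Hypothetical key mirroring the join columns: (chr_id, start, is_revcomp, motif_name).
Key = Tuple[str, int, bool, str]


def in_central_window(start: int, end: int, region_start: int,
                      region_len: int, half_width: int) -> bool:
    """True if the hit lies entirely within center +/- half_width of its region."""
    center = region_len / 2
    return (start - region_start >= center - half_width
            and end - region_start <= center + half_width)


def seqlet_recall(restricted_hits: Set[Key], seqlets: Set[Key]) -> float:
    """Fraction of seqlets exactly matched by a restricted hit (0.0 if no seqlets)."""
    if not seqlets:
        return 0.0
    return len(restricted_hits & seqlets) / len(seqlets)


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Normalized dot product of two flattened CWMs."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)
```

Sets assume unique keys. The module gets the same effect for hits by calling `unique(...)` on the join columns before matching them against seqlets.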
"""Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance.

This module provides functions for:
- Computing motif occurrence statistics and co-occurrence patterns
- Evaluating motif discovery quality against TF-MoDISco results
- Analyzing hit calling performance and recall metrics
- Generating confusion matrices for seqlet-hit comparisons
"""

import warnings
from typing import List, Tuple, Dict, Any, Union

import numpy as np
from numpy import ndarray
import polars as pl
from jaxtyping import Float, Int


def get_motif_occurences(
    hits_df: pl.LazyFrame, motif_names: List[str]
) -> Tuple[pl.DataFrame, Int[ndarray, "M M"]]:
    """Compute motif occurrence statistics and co-occurrence matrix.

    This function analyzes motif occurrence patterns across peaks by creating
    a pivot table of hit counts and computing pairwise co-occurrence statistics.

    Parameters
    ----------
    hits_df : pl.LazyFrame
        Lazy DataFrame containing hit data with required columns:
        - peak_id : Peak identifier
        - motif_name : Name of the motif
        Additional columns are ignored.
    motif_names : List[str]
        List of motif names to include in analysis. Missing motifs
        will be added as columns with zero counts.

    Returns
    -------
    occ_df : pl.DataFrame
        DataFrame with motif occurrence counts per peak. Contains:
        - peak_id column
        - One column per motif with hit counts
        - 'total' column summing all motif counts per peak
    coocc : Int[ndarray, "M M"]
        Co-occurrence matrix where M = len(motif_names).
        Entry (i,j) indicates number of peaks containing both motif i and motif j.
        Diagonal entries show total peaks containing each motif.

    Notes
    -----
    The co-occurrence matrix is computed using binary occurrence indicators,
    so multiple hits of the same motif in a peak are treated as a single occurrence.
    """
    occ_df = (
        hits_df.collect()
        .with_columns(pl.lit(1).alias("count"))
        .pivot(
            on="motif_name", index="peak_id", values="count", aggregate_function="sum"
        )
        .fill_null(0)
    )

    missing_cols = set(motif_names) - set(occ_df.columns)
    occ_df = (
        occ_df.with_columns([pl.lit(0).alias(m) for m in missing_cols])
        .with_columns(total=pl.sum_horizontal(*motif_names))
        .sort(["peak_id"])
    )

    num_peaks = occ_df.height
    num_motifs = len(motif_names)

    occ_mat = np.zeros((num_peaks, num_motifs), dtype=np.int16)
    for i, m in enumerate(motif_names):
        occ_mat[:, i] = occ_df.get_column(m).to_numpy()

    occ_bin = (occ_mat > 0).astype(np.int32)
    coocc = occ_bin.T @ occ_bin

    return occ_df, coocc


def get_cwms(
    regions: Float[ndarray, "N 4 L"], positions_df: pl.DataFrame, motif_width: int
) -> Float[ndarray, "4 W"]:
    """Extract the mean contribution weight matrix (CWM) at the given hit positions.

    This function extracts motif-sized windows from contribution score regions
    at positions specified by hit coordinates, orients each window according to
    its strand, and averages the windows into a single CWM. Hits whose windows
    fall outside the region are skipped.

    Parameters
    ----------
    regions : Float[ndarray, "N 4 L"]
        Input contribution score regions multiplied by one-hot sequences.
        Shape: (n_peaks, 4, region_width), where the 4 rows are the DNA bases (A, C, G, T).
    positions_df : pl.DataFrame
        DataFrame containing hit positions with required columns:
        - peak_id : int, Peak index (0-based) into the first axis of regions
        - start_untrimmed : int, Start position in genomic coordinates
        - peak_region_start : int, Peak region start coordinate
        - is_revcomp : bool, Whether hit is on reverse complement strand
    motif_width : int
        Width of motifs to extract. Must be positive.

    Returns
    -------
    cwms : Float[ndarray, "4 W"]
        Mean contribution weight matrix over all valid hits.
        Shape: (4, motif_width). Hits outside the region boundaries are
        excluded; if no valid hits remain, every entry is NaN.

    Notes
    -----
    - Start positions are converted from genomic to region-relative coordinates
    - Reverse complement hits are read with both the position order and the
      base order (A, C, G, T -> T, G, C, A) reversed
    - Hits extending beyond region boundaries are excluded
    - NumPy warnings about empty slices or invalid division raised while
      averaging are suppressed

    Raises
    ------
    ValueError
        If motif_width is non-positive.
    polars.exceptions.ColumnNotFoundError
        If positions_df lacks a required column.
    """
    # Validate the width up front, as promised in the docstring
    if motif_width <= 0:
        raise ValueError(f"motif_width must be positive, got {motif_width}")

    idx_df = positions_df.select(
        peak_idx=pl.col("peak_id"),
        start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"),
        is_revcomp=pl.col("is_revcomp"),
    )
    peak_idx = idx_df.get_column("peak_idx").to_numpy()
    start_idx = idx_df.get_column("start_idx").to_numpy()
    is_revcomp = idx_df.get_column("is_revcomp").to_numpy().astype(bool)

    # Filter hits that fall outside the region boundaries
    valid_mask = (start_idx >= 0) & (start_idx + motif_width <= regions.shape[2])
    peak_idx = peak_idx[valid_mask]
    start_idx = start_idx[valid_mask]
    is_revcomp = is_revcomp[valid_mask]

    # Build broadcastable (H, 1, 1), (H, 4, 1) and (H, 1, W) index arrays;
    # reverse-complement hits read positions and bases in reverse order
    row_idx = peak_idx[:, None, None]
    pos_idx = start_idx[:, None, None] + np.zeros((1, 1, motif_width), dtype=int)
    pos_idx[~is_revcomp, :, :] += np.arange(motif_width)[None, None, :]
    pos_idx[is_revcomp, :, :] += np.arange(motif_width)[None, None, ::-1]
    nuc_idx = np.zeros((peak_idx.shape[0], 4, 1), dtype=int)
    nuc_idx[~is_revcomp, :, :] += np.arange(4)[None, :, None]
    nuc_idx[is_revcomp, :, :] += np.arange(4)[None, ::-1, None]

    seqs = regions[row_idx, nuc_idx, pos_idx]

    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore", message="invalid value encountered in divide"
        )
        warnings.filterwarnings(action="ignore", message="Mean of empty slice")
        cwms = seqs.mean(axis=0)

    return cwms


def tfmodisco_comparison(
    regions: Float[ndarray, "N 4 L"],
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None],
    motifs_df: pl.DataFrame,
    cwms_modisco: Float[ndarray, "M 4 W"],
    motif_names: List[str],
    modisco_half_width: int,
    motif_width: int,
    compute_recall: bool,
) -> Tuple[
    Dict[str, Dict[str, Any]],
    pl.DataFrame,
    Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    Dict[str, Dict[str, Tuple[int, int]]],
]:
    """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.

    This function compares Fi-NeMo hit calls with TF-MoDISco seqlets,
    computes recall metrics and CWM similarities, and extracts contribution
    weight matrices for visualization.

    Parameters
    ----------
    regions : Float[ndarray, "N 4 L"]
        Contribution score regions multiplied by one-hot sequences.
        Shape: (n_peaks, 4, region_length)
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    peaks_df : pl.DataFrame
        Peak metadata with columns:
        - peak_id, chr_id, peak_region_start
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame, None]
        TF-MoDISco seqlets with columns:
        - chr_id, start_untrimmed, is_revcomp, motif_name
        - peak_id, peak_region_start (also needed when compute_recall is True,
          to extract CWMs of seqlets without a matching hit)
        If None, only basic hit statistics are computed.
    motifs_df : pl.DataFrame
        Motif metadata with one "+" and one "-" row per motif and columns:
        - motif_name, strand, motif_id, motif_start, motif_end
    cwms_modisco : Float[ndarray, "M 4 W"]
        TF-MoDISco contribution weight matrices, indexed by motif_id.
        Shape: (n_modisco_motifs, 4, motif_width)
    motif_names : List[str]
        Names of motifs to analyze.
    modisco_half_width : int
        Half-width of the window around the region center that a hit must lie
        entirely within to count as restricted. Restricting hits this way
        allows a fair comparison with TF-MoDISco seqlets.
    motif_width : int
        Width of motifs for CWM extraction.
    compute_recall : bool
        Whether to compute recall metrics. Ignored if seqlets_df is None.

    Returns
    -------
    report_data : Dict[str, Dict[str, Any]]
        Per-motif evaluation metrics including:
        - num_hits_total, num_hits_restricted, num_seqlets
        - num_overlaps, num_seqlets_only, num_hits_restricted_only
        - seqlet_recall, cwm_similarity
    report_df : pl.DataFrame
        report_data in tabular form, one row per motif.
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Extracted CWMs for each motif and condition:
        - hits_fc, hits_rc: Mean CWM of all hits and its reverse complement
        - modisco_fc, modisco_rc: TF-MoDISco CWMs for the forward and reverse strand
        - seqlets_only, hits_restricted_only: Seqlets with no matching hit, and
          restricted hits with no matching seqlet (only when recall is computed)
    cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        Trimming boundaries (motif_start, motif_end) for each motif and CWM type.

    Notes
    -----
    - Restricted hits are those within the central region defined by modisco_half_width
    - A hit and a seqlet overlap only if chr_id, start_untrimmed, is_revcomp and
      motif_name all match exactly
    - seqlet_recall = num_overlaps / num_seqlets, or 0.0 for motifs with no seqlets
    - CWM similarity is the cosine similarity (normalized dot product) between
      the hit CWM and the forward TF-MoDISco CWM
    - Recall metrics are computed only when compute_recall is True and seqlets_df
      is not None
    - A motif with no hits gets zero hit counts and a NaN hits CWM, and hence
      a NaN cwm_similarity

    Raises
    ------
    polars.exceptions.ColumnNotFoundError
        If required columns are missing from input DataFrames.
    polars.exceptions.NoRowsReturnedError
        If motifs_df lacks a "+" or "-" row for a motif in motif_names.
    """

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_df = (
        hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32))
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .select(
            chr_id=pl.col("chr_id"),
            start_untrimmed=pl.col("start_untrimmed"),
            end_untrimmed=pl.col("end_untrimmed"),
            is_revcomp=pl.col("strand") == "-",
            motif_name=pl.col("motif_name"),
            peak_region_start=pl.col("peak_region_start"),
            peak_id=pl.col("peak_id"),
        )
    )

    hits_unique = hits_df.unique(
        subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]
    )

    region_len = regions.shape[2]
    center = region_len / 2
    hits_filtered = hits_df.filter(
        (
            (pl.col("start_untrimmed") - pl.col("peak_region_start"))
            >= (center - modisco_half_width)
        )
        & (
            (pl.col("end_untrimmed") - pl.col("peak_region_start"))
            <= (center + modisco_half_width)
        )
    ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])

    hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True)
    hits_filtered_by_motif = hits_filtered.collect().partition_by(
        "motif_name", as_dict=True
    )

    if seqlets_df is None:
        seqlets_collected = None
        seqlets_lazy = None
    elif isinstance(seqlets_df, pl.LazyFrame):
        seqlets_collected = seqlets_df.collect()
        seqlets_lazy = seqlets_df
    else:
        seqlets_collected = seqlets_df
        seqlets_lazy = seqlets_df.lazy()

    if seqlets_collected is not None:
        seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True)
    else:
        seqlets_by_motif = {}

    if compute_recall and seqlets_lazy is not None:
        overlaps_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="inner",
        ).collect()

        seqlets_only_df = seqlets_lazy.join(
            hits_df,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        hits_only_filtered_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        # Create partition dictionaries
        overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True)
        seqlets_only_by_motif = seqlets_only_df.partition_by("motif_name", as_dict=True)
        hits_only_filtered_by_motif = hits_only_filtered_df.partition_by(
            "motif_name", as_dict=True
        )
    else:
        overlaps_by_motif = {}
        seqlets_only_by_motif = {}
        hits_only_filtered_by_motif = {}

    report_data = {}
    cwms = {}
    cwm_trim_bounds = {}
    dummy_df = hits_df.clear().collect()
    for m in motif_names:
        hits = hits_by_motif.get((m,), dummy_df)
        hits_filtered = hits_filtered_by_motif.get((m,), dummy_df)

        # Initialize default values
        seqlets = dummy_df
        overlaps = dummy_df
        seqlets_only = dummy_df
        hits_only_filtered = dummy_df

        if seqlets_df is not None:
            seqlets = seqlets_by_motif.get((m,), dummy_df)

        if compute_recall and seqlets_df is not None:
            overlaps = overlaps_by_motif.get((m,), dummy_df)
            seqlets_only = seqlets_only_by_motif.get((m,), dummy_df)
            hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df)

        report_data[m] = {
            "num_hits_total": hits.height,
            "num_hits_restricted": hits_filtered.height,
        }

        if seqlets_df is not None:
            report_data[m]["num_seqlets"] = seqlets.height

        if compute_recall and seqlets_df is not None:
            report_data[m] |= {
                "num_overlaps": overlaps.height,
                "num_seqlets_only": seqlets_only.height,
                "num_hits_restricted_only": hits_only_filtered.height,
                "seqlet_recall": np.float64(overlaps.height) / seqlets.height
                if seqlets.height > 0
                else 0.0,
            }

        motif_data_fc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"),
            named=True,
        )
        motif_data_rc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"),
            named=True,
        )

        cwms[m] = {
            "hits_fc": get_cwms(regions, hits, motif_width),
            "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]],
            "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]],
        }
        cwms[m]["hits_rc"] = cwms[m]["hits_fc"][::-1, ::-1]

        if compute_recall and seqlets_df is not None:
            cwms[m] |= {
                "seqlets_only": get_cwms(regions, seqlets_only, motif_width),
                "hits_restricted_only": get_cwms(
                    regions, hits_only_filtered, motif_width
                ),
            }

        bounds_fc = (motif_data_fc["motif_start"], motif_data_fc["motif_end"])
        bounds_rc = (motif_data_rc["motif_start"], motif_data_rc["motif_end"])

        cwm_trim_bounds[m] = {
            "hits_fc": bounds_fc,
            "modisco_fc": bounds_fc,
            "modisco_rc": bounds_rc,
            "hits_rc": bounds_rc,
        }

        if compute_recall and seqlets_df is not None:
            cwm_trim_bounds[m] |= {
                "seqlets_only": bounds_fc,
                "hits_restricted_only": bounds_fc,
            }

        hits_cwm = cwms[m]["hits_fc"]
        modisco_cwm = cwms[m]["modisco_fc"]
        hnorm = np.sqrt((hits_cwm**2).sum())
        snorm = np.sqrt((modisco_cwm**2).sum())
        cwm_sim = (hits_cwm * modisco_cwm).sum() / (hnorm * snorm)

        report_data[m]["cwm_similarity"] = cwm_sim

    records = [{"motif_name": k} | v for k, v in report_data.items()]
    report_df = pl.from_dicts(records)

    return report_data, report_df, cwms, cwm_trim_bounds


def seqlet_confusion(
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    motif_names: List[str],
    motif_width: int,
) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]:
    """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.

    This function creates a confusion matrix showing the overlap between
    TF-MoDISco seqlets (treated as ground truth) and Fi-NeMo hits across motifs.
    Overlaps are detected by matching binned genomic coordinates.

    Parameters
    ----------
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame]
        TF-MoDISco seqlets with required columns:
        - chr_id, start_untrimmed, end_untrimmed, motif_name
    peaks_df : pl.DataFrame
        Peak metadata for joining coordinates:
        - peak_id, chr_id
    motif_names : List[str]
        Names of motifs to include in the confusion matrix.
        Determines the matrix dimensions and the row/column order.
    motif_width : int
        Bin size for genomic coordinates.
        Positions are binned to motif_width resolution.

    Returns
    -------
    confusion_df : pl.DataFrame
        Confusion matrix in long (tabular) format with columns:
        - motif_name_seqlets : Seqlet motif labels (rows)
        - motif_name_hits : Hit motif labels (columns)
        - frac_overlap : Overlap count divided by the number of seqlets of that motif
    confusion_mat : Float[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits.
        Rows represent seqlet motifs, columns represent hit motifs.

    Notes
    -----
    - Genomic coordinates are binned to motif_width resolution for overlap detection
    - Only exact bin matches are considered (same chr_id, start_bin, end_bin)
    - Fractions are computed as overlaps / total_seqlets_per_motif, where every
      matching (seqlet, hit) pair counts as one overlap
    - Motif pairs with no overlaps are left as zero entries in the confusion matrix

    Raises
    ------
    polars.exceptions.ColumnNotFoundError
        If required columns are missing from the input DataFrames.
    polars.exceptions.InvalidOperationError
        If the data contain motif names that are not in motif_names.
    """
    bin_size = motif_width

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_binned = (
        hits_df.with_columns(
            peak_id=pl.col("peak_id").cast(pl.UInt32),
            is_revcomp=pl.col("strand") == "-",
        )
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
        .select(
            chr_id=pl.col("chr_id"),
            start_bin=pl.col("start_untrimmed") // bin_size,
            end_bin=pl.col("end_untrimmed") // bin_size,
            motif_name=pl.col("motif_name"),
        )
    )

    seqlets_lazy = seqlets_df.lazy()
    seqlets_binned = seqlets_lazy.select(
        chr_id=pl.col("chr_id"),
        start_bin=pl.col("start_untrimmed") // bin_size,
        end_bin=pl.col("end_untrimmed") // bin_size,
        motif_name=pl.col("motif_name"),
    )

    overlaps_df = seqlets_binned.join(
        hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits"
    )

    seqlet_counts = (
        seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect()
    )
    overlap_counts = (
        overlaps_df.group_by(["motif_name", "motif_name_hits"])
        .len(name="num_overlaps")
        .collect()
    )

    num_motifs = len(motif_names)
    confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32)
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    confusion_df = overlap_counts.join(
        seqlet_counts, on="motif_name", how="inner"
    ).select(
        motif_name_seqlets=pl.col("motif_name"),
        motif_name_hits=pl.col("motif_name_hits"),
        frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"),
    )

    confusion_idx_df = confusion_df.select(
        row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx),
        col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx),
        frac_overlap=pl.col("frac_overlap"),
    )

    row_idx = confusion_idx_df["row_idx"].to_numpy()
    col_idx = confusion_idx_df["col_idx"].to_numpy()
    frac_overlap = confusion_idx_df["frac_overlap"].to_numpy()

    confusion_mat[row_idx, col_idx] = frac_overlap

    return confusion_df, confusion_mat
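`seqlet_confusion` treats a seqlet and a hit as overlapping when their floor-divided start and end coordinates fall in the same bins on the same chromosome. The stdlib-only sketch below re-implements that counting on hypothetical `(chr_id, start, end, motif_name)` tuples. The name `binned_confusion` is illustrative. The per-hit deduplication that the module performs before binning is omitted, as is the peak join.

```python
from collections import Counter
from typing import Dict, List, Tuple

# Hypothetical interval record: (chr_id, start, end, motif_name).
Interval = Tuple[str, int, int, str]


def binned_confusion(seqlets: List[Interval], hits: List[Interval],
                     motif_names: List[str], bin_size: int) -> List[List[float]]:
    """Entry (i, j): motif-i seqlet / motif-j hit bin matches per motif-i seqlet."""
    idx = {m: i for i, m in enumerate(motif_names)}

    # Index hit motifs by (chr_id, start_bin, end_bin), using floor division.
    hits_by_bin: Dict[Tuple[str, int, int], List[str]] = {}
    for chrom, start, end, motif in hits:
        key = (chrom, start // bin_size, end // bin_size)
        hits_by_bin.setdefault(key, []).append(motif)

    seqlet_counts = Counter(motif for _, _, _, motif in seqlets)
    overlap_counts: Counter = Counter()
    for chrom, start, end, motif in seqlets:
        key = (chrom, start // bin_size, end // bin_size)
        # Inner-join semantics: every matching hit adds one overlap.
        for hit_motif in hits_by_bin.get(key, []):
            overlap_counts[(motif, hit_motif)] += 1

    # Unmatched motif pairs stay at zero; unknown motif names raise KeyError.
    mat = [[0.0] * len(motif_names) for _ in motif_names]
    for (seqlet_motif, hit_motif), n in overlap_counts.items():
        mat[idx[seqlet_motif]][idx[hit_motif]] = n / seqlet_counts[seqlet_motif]
    return mat
```

Take `bin_size=10` as an example. A seqlet spanning 100 to 110 and a hit spanning 103 to 113 both map to bins (10, 11), so they count as one overlap even though their exact starts differ.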
def get_motif_occurences( hits_df: polars.lazyframe.frame.LazyFrame, motif_names: List[str]) -> Tuple[polars.dataframe.frame.DataFrame, jaxtyping.Int[ndarray, 'M M']]:

Compute motif occurrence statistics and co-occurrence matrix.

This function analyzes motif occurrence patterns across peaks by creating a pivot table of hit counts and computing pairwise co-occurrence statistics.

Parameters
  • hits_df (pl.LazyFrame): Lazy DataFrame containing hit data with required columns:
    • peak_id : Peak identifier
    • motif_name : Name of the motif
    Additional columns are ignored.
  • motif_names (List[str]): List of motif names to include in analysis. Missing motifs will be added as columns with zero counts.
Returns
  • occ_df (pl.DataFrame): DataFrame with motif occurrence counts per peak. Contains:
    • peak_id column
    • One column per motif with hit counts
    • 'total' column summing all motif counts per peak
  • coocc (Int[ndarray, "M M"]): Co-occurrence matrix where M = len(motif_names). Entry (i,j) indicates number of peaks containing both motif i and motif j. Diagonal entries show total peaks containing each motif.
Notes

The co-occurrence matrix is computed using binary occurrence indicators, so multiple hits of the same motif in a peak are treated as a single occurrence.
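Because counts are binarized first, the co-occurrence matrix equals B^T B, where B is the peak-by-motif 0/1 indicator matrix. The dependency-free sketch below makes this explicit for hypothetical per-peak counts. The `cooccurrence` helper is illustrative and is not part of the module, which uses `occ_bin.T @ occ_bin` in NumPy.

```python
from typing import List


def cooccurrence(counts: List[List[int]]) -> List[List[int]]:
    """Co-occurrence of binarized peak x motif counts, i.e. B^T B."""
    # Multiple hits of a motif in one peak count as a single occurrence.
    binary = [[1 if c > 0 else 0 for c in row] for row in counts]
    n_motifs = len(binary[0]) if binary else 0
    return [
        [sum(row[i] * row[j] for row in binary) for j in range(n_motifs)]
        for i in range(n_motifs)
    ]
```

For counts `[[2, 1, 0], [0, 3, 1], [1, 0, 0]]` the result is `[[2, 1, 0], [1, 2, 1], [0, 1, 1]]`. Each diagonal entry is the number of peaks containing that motif.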

def get_cwms( regions: jaxtyping.Float[ndarray, 'N 4 L'], positions_df: polars.dataframe.frame.DataFrame, motif_width: int) -> jaxtyping.Float[ndarray, '4 W']:

Extract the mean contribution weight matrix (CWM) at the given hit positions.

This function extracts motif-sized windows from contribution score regions at positions specified by hit coordinates, orients each window according to its strand, and averages the windows into a single CWM. Hits whose windows fall outside the region are skipped.

Parameters
  • regions (Float[ndarray, "N 4 L"]): Input contribution score regions multiplied by one-hot sequences. Shape: (n_peaks, 4, region_width), where the 4 rows are the DNA bases (A, C, G, T).
  • positions_df (pl.DataFrame): DataFrame containing hit positions with required columns:
    • peak_id : int, Peak index (0-based) into the first axis of regions
    • start_untrimmed : int, Start position in genomic coordinates
    • peak_region_start : int, Peak region start coordinate
    • is_revcomp : bool, Whether hit is on reverse complement strand
  • motif_width (int): Width of motifs to extract. Must be positive.
Returns
  • cwms (Float[ndarray, "4 W"]): Mean contribution weight matrix over all valid hits. Shape: (4, motif_width). Hits outside the region boundaries are excluded; if no valid hits remain, every entry is NaN.
Notes
  • Start positions are converted from genomic to region-relative coordinates
  • Reverse complement hits are read with both the position order and the base order (A, C, G, T -> T, G, C, A) reversed
  • Hits extending beyond region boundaries are excluded
  • NumPy warnings about empty slices or invalid division raised while averaging are suppressed
Raises
  • ValueError: If motif_width is non-positive.
  • polars.exceptions.ColumnNotFoundError: If positions_df lacks a required column.
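The gather performed by `get_cwms` can be sketched without NumPy for 4 x L regions with rows in A, C, G, T order. The sketch assumes that starts are already region-relative, meaning `peak_region_start` has been subtracted. Reading a reverse-complement hit reverses both the positions and the base rows, which is why `tfmodisco_comparison` can obtain `hits_rc` by flipping both axes of `hits_fc`. The helpers `extract_window` and `mean_cwm` are hypothetical names for this illustration.

```python
import math
from typing import List, Optional, Tuple

# Hypothetical type: a 4 x L matrix with rows in A, C, G, T order.
Matrix = List[List[float]]


def extract_window(region: Matrix, start: int, width: int,
                   is_revcomp: bool) -> Optional[Matrix]:
    """Return the 4 x width window at `start`, or None if it leaves the region."""
    if width <= 0:
        raise ValueError("width must be positive")
    if start < 0 or start + width > len(region[0]):
        return None  # get_cwms drops such hits with its valid_mask
    window = [row[start:start + width] for row in region]
    if is_revcomp:
        # Reversing the row order maps A,C,G,T to T,G,C,A (the complement);
        # reversing each row reverses the positions.
        window = [row[::-1] for row in window[::-1]]
    return window


def mean_cwm(regions: List[Matrix], hits: List[Tuple[int, int, bool]],
             width: int) -> Matrix:
    """Average the valid windows of (peak_index, start, is_revcomp) hits."""
    windows = []
    for peak, start, is_revcomp in hits:
        window = extract_window(regions[peak], start, width, is_revcomp)
        if window is not None:
            windows.append(window)
    if not windows:
        # get_cwms yields NaNs here (NumPy's empty-mean warnings are suppressed).
        return [[math.nan] * width for _ in range(4)]
    n = len(windows)
    return [[sum(w[b][p] for w in windows) / n for p in range(width)]
            for b in range(4)]
```

The NumPy version builds all index arrays at once and uses broadcasting, which avoids the per-hit Python loop. The per-element result is the same as in this sketch.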
def tfmodisco_comparison( regions: jaxtyping.Float[ndarray, 'N 4 L'], hits_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame], peaks_df: polars.dataframe.frame.DataFrame, seqlets_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame, NoneType], motifs_df: polars.dataframe.frame.DataFrame, cwms_modisco: jaxtyping.Float[ndarray, 'M 4 W'], motif_names: List[str], modisco_half_width: int, motif_width: int, compute_recall: bool) -> Tuple[Dict[str, Dict[str, Any]], polars.dataframe.frame.DataFrame, Dict[str, Dict[str, jaxtyping.Float[ndarray, '4 W']]], Dict[str, Dict[str, Tuple[int, int]]]]:
def tfmodisco_comparison(
    regions: Float[ndarray, "N 4 L"],
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None],
    motifs_df: pl.DataFrame,
    cwms_modisco: Float[ndarray, "M 4 W"],
    motif_names: List[str],
    modisco_half_width: int,
    motif_width: int,
    compute_recall: bool,
) -> Tuple[
    Dict[str, Dict[str, Any]],
    pl.DataFrame,
    Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    Dict[str, Dict[str, Tuple[int, int]]],
]:
    """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.

    This function compares Fi-NeMo hit calls with TF-MoDISco seqlets,
    computing recall metrics and CWM similarities, and extracts
    contribution weight matrices (CWMs) for visualization.

    Parameters
    ----------
    regions : Float[ndarray, "N 4 L"]
        Contribution scores multiplied by one-hot encoded sequences.
        Shape: (n_peaks, 4, region_length)
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    peaks_df : pl.DataFrame
        Peak metadata with columns:
        - peak_id, chr_id, peak_region_start
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame, None]
        TF-MoDISco seqlets with columns:
        - chr_id, start_untrimmed, is_revcomp, motif_name
        - peak_id, peak_region_start (used for CWM extraction when
          compute_recall is True)
        If None, only basic hit statistics are computed.
    motifs_df : pl.DataFrame
        Motif metadata with columns:
        - motif_name, strand, motif_id, motif_start, motif_end
    cwms_modisco : Float[ndarray, "M 4 W"]
        TF-MoDISco contribution weight matrices, indexed by motif_id.
        Shape: (n_modisco_motifs, 4, motif_width)
    motif_names : List[str]
        Names of motifs to analyze. Each must have a "+" and a "-" row
        in motifs_df.
    modisco_half_width : int
        Half-width of the central window, around the region midpoint, to
        which hits are restricted so that they are comparable with
        TF-MoDISco seqlets.
    motif_width : int
        Width of motifs for CWM extraction.
    compute_recall : bool
        Whether to compute overlap and recall metrics. Ignored if
        seqlets_df is None.

    Returns
    -------
    report_data : Dict[str, Dict[str, Any]]
        Per-motif evaluation metrics:
        - num_hits_total, num_hits_restricted, cwm_similarity (always)
        - num_seqlets (if seqlets_df is given)
        - num_overlaps, num_seqlets_only, num_hits_restricted_only,
          seqlet_recall (if compute_recall is True and seqlets_df is given)
    report_df : pl.DataFrame
        report_data in tabular form, one row per motif.
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Extracted CWMs for each motif and condition:
        - hits_fc, hits_rc: Mean hit CWM in forward/reverse complement orientation
        - modisco_fc, modisco_rc: TF-MoDISco CWMs in forward/reverse complement orientation
        - seqlets_only, hits_restricted_only: Non-overlapping instances
          (only when recall metrics are computed)
    cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        (motif_start, motif_end) trimming bounds from motifs_df for each
        motif and CWM type.

    Notes
    -----
    - Restricted hits are those whose start and end, relative to
      peak_region_start, lie within center ± modisco_half_width, where
      center = region_length / 2
    - cwm_similarity is the cosine similarity between the mean forward hit
      CWM and the forward TF-MoDISco CWM; it is NaN for motifs with no
      valid hits
    - Recall metrics are computed only when compute_recall is True and
      seqlets_df is not None; seqlet_recall is 0.0 for motifs without seqlets
    - Motifs without hits or seqlets receive zero counts via an empty
      placeholder DataFrame

    Raises
    ------
    ValueError
        If required columns are missing from input DataFrames.
    """

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_df = (
        hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32))
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .select(
            chr_id=pl.col("chr_id"),
            start_untrimmed=pl.col("start_untrimmed"),
            end_untrimmed=pl.col("end_untrimmed"),
            is_revcomp=pl.col("strand") == "-",
            motif_name=pl.col("motif_name"),
            peak_region_start=pl.col("peak_region_start"),
            peak_id=pl.col("peak_id"),
        )
    )

    hits_unique = hits_df.unique(
        subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]
    )

    region_len = regions.shape[2]
    center = region_len / 2
    hits_filtered = hits_df.filter(
        (
            (pl.col("start_untrimmed") - pl.col("peak_region_start"))
            >= (center - modisco_half_width)
        )
        & (
            (pl.col("end_untrimmed") - pl.col("peak_region_start"))
            <= (center + modisco_half_width)
        )
    ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])

    hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True)
    hits_filtered_by_motif = hits_filtered.collect().partition_by(
        "motif_name", as_dict=True
    )

    if seqlets_df is None:
        seqlets_collected = None
        seqlets_lazy = None
    elif isinstance(seqlets_df, pl.LazyFrame):
        seqlets_collected = seqlets_df.collect()
        seqlets_lazy = seqlets_df
    else:
        seqlets_collected = seqlets_df
        seqlets_lazy = seqlets_df.lazy()

    if seqlets_collected is not None:
        seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True)
    else:
        seqlets_by_motif = {}

    if compute_recall and seqlets_lazy is not None:
        overlaps_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="inner",
        ).collect()

        seqlets_only_df = seqlets_lazy.join(
            hits_df,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        hits_only_filtered_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        # Create partition dictionaries
        overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True)
        seqlets_only_by_motif = seqlets_only_df.partition_by("motif_name", as_dict=True)
        hits_only_filtered_by_motif = hits_only_filtered_df.partition_by(
            "motif_name", as_dict=True
        )
    else:
        overlaps_by_motif = {}
        seqlets_only_by_motif = {}
        hits_only_filtered_by_motif = {}

    report_data = {}
    cwms = {}
    cwm_trim_bounds = {}
    dummy_df = hits_df.clear().collect()
    for m in motif_names:
        hits = hits_by_motif.get((m,), dummy_df)
        hits_filtered = hits_filtered_by_motif.get((m,), dummy_df)

        # Initialize default values
        seqlets = dummy_df
        overlaps = dummy_df
        seqlets_only = dummy_df
        hits_only_filtered = dummy_df

        if seqlets_df is not None:
            seqlets = seqlets_by_motif.get((m,), dummy_df)

        if compute_recall and seqlets_df is not None:
            overlaps = overlaps_by_motif.get((m,), dummy_df)
            seqlets_only = seqlets_only_by_motif.get((m,), dummy_df)
            hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df)

        report_data[m] = {
            "num_hits_total": hits.height,
            "num_hits_restricted": hits_filtered.height,
        }

        if seqlets_df is not None:
            report_data[m]["num_seqlets"] = seqlets.height

        if compute_recall and seqlets_df is not None:
            report_data[m] |= {
                "num_overlaps": overlaps.height,
                "num_seqlets_only": seqlets_only.height,
                "num_hits_restricted_only": hits_only_filtered.height,
                "seqlet_recall": np.float64(overlaps.height) / seqlets.height
                if seqlets.height > 0
                else 0.0,
            }

        motif_data_fc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"),
            named=True,
        )
        motif_data_rc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"),
            named=True,
        )

        cwms[m] = {
            "hits_fc": get_cwms(regions, hits, motif_width),
            "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]],
            "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]],
        }
        cwms[m]["hits_rc"] = cwms[m]["hits_fc"][::-1, ::-1]

        if compute_recall and seqlets_df is not None:
            cwms[m] |= {
                "seqlets_only": get_cwms(regions, seqlets_only, motif_width),
                "hits_restricted_only": get_cwms(
                    regions, hits_only_filtered, motif_width
                ),
            }

        bounds_fc = (motif_data_fc["motif_start"], motif_data_fc["motif_end"])
        bounds_rc = (motif_data_rc["motif_start"], motif_data_rc["motif_end"])

        cwm_trim_bounds[m] = {
            "hits_fc": bounds_fc,
            "modisco_fc": bounds_fc,
            "modisco_rc": bounds_rc,
            "hits_rc": bounds_rc,
        }

        if compute_recall and seqlets_df is not None:
            cwm_trim_bounds[m] |= {
                "seqlets_only": bounds_fc,
                "hits_restricted_only": bounds_fc,
            }

        hits_cwm = cwms[m]["hits_fc"]
        modisco_cwm = cwms[m]["modisco_fc"]
        hnorm = np.sqrt((hits_cwm**2).sum())
        snorm = np.sqrt((modisco_cwm**2).sum())
        cwm_sim = (hits_cwm * modisco_cwm).sum() / (hnorm * snorm)

        report_data[m]["cwm_similarity"] = cwm_sim

    records = [{"motif_name": k} | v for k, v in report_data.items()]
    report_df = pl.from_dicts(records)

    return report_data, report_df, cwms, cwm_trim_bounds

Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.

This function compares Fi-NeMo hit calls with TF-MoDISco seqlets, computing recall metrics and CWM similarities, and extracts contribution weight matrices (CWMs) for visualization.

Parameters
  • regions (Float[ndarray, "N 4 L"]): Contribution scores multiplied by one-hot encoded sequences. Shape: (n_peaks, 4, region_length)
  • hits_df (Union[pl.DataFrame, pl.LazyFrame]): Fi-NeMo hit calls with required columns:
    • peak_id, start_untrimmed, end_untrimmed, strand, motif_name
  • peaks_df (pl.DataFrame): Peak metadata with columns:
    • peak_id, chr_id, peak_region_start
  • seqlets_df (Union[pl.DataFrame, pl.LazyFrame, None]): TF-MoDISco seqlets, or None to compute only basic hit statistics. Required columns:
    • chr_id, start_untrimmed, is_revcomp, motif_name
    • peak_id, peak_region_start (used for CWM extraction when compute_recall is True)
  • motifs_df (pl.DataFrame): Motif metadata with columns:
    • motif_name, strand, motif_id, motif_start, motif_end
  • cwms_modisco (Float[ndarray, "M 4 W"]): TF-MoDISco contribution weight matrices, indexed by motif_id. Shape: (n_modisco_motifs, 4, motif_width)
  • motif_names (List[str]): Names of motifs to analyze. Each must have a "+" and a "-" row in motifs_df.
  • modisco_half_width (int): Half-width of the central window, around the region midpoint, to which hits are restricted so that they are comparable with TF-MoDISco seqlets.
  • motif_width (int): Width of motifs for CWM extraction.
  • compute_recall (bool): Whether to compute overlap and recall metrics. Ignored if seqlets_df is None.
Returns
  • report_data (Dict[str, Dict[str, Any]]): Per-motif evaluation metrics:
    • num_hits_total, num_hits_restricted, cwm_similarity (always)
    • num_seqlets (if seqlets_df is given)
    • num_overlaps, num_seqlets_only, num_hits_restricted_only, seqlet_recall (if compute_recall is True and seqlets_df is given)
  • report_df (pl.DataFrame): report_data in tabular form, one row per motif.
  • cwms (Dict[str, Dict[str, Float[ndarray, "4 W"]]]): Extracted CWMs for each motif and condition:
    • hits_fc, hits_rc: Mean hit CWM in forward/reverse complement orientation
    • modisco_fc, modisco_rc: TF-MoDISco CWMs in forward/reverse complement orientation
    • seqlets_only, hits_restricted_only: Non-overlapping instances (only when recall metrics are computed)
  • cwm_trim_bounds (Dict[str, Dict[str, Tuple[int, int]]]): (motif_start, motif_end) trimming bounds from motifs_df for each motif and CWM type.
Notes
  • Restricted hits are those whose start and end, relative to peak_region_start, lie within center ± modisco_half_width, where center = region_length / 2
  • cwm_similarity is the cosine similarity between the mean forward hit CWM and the forward TF-MoDISco CWM; it is NaN for motifs with no valid hits
  • Recall metrics are computed only when compute_recall is True and seqlets_df is not None; seqlet_recall is 0.0 for motifs without seqlets
  • Motifs without hits or seqlets receive zero counts via an empty placeholder DataFrame
Raises
  • ValueError: If required columns are missing from input DataFrames.
def seqlet_confusion( hits_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame], seqlets_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame], peaks_df: polars.dataframe.frame.DataFrame, motif_names: List[str], motif_width: int) -> Tuple[polars.dataframe.frame.DataFrame, jaxtyping.Float[ndarray, 'M M']]:
def seqlet_confusion(
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    motif_names: List[str],
    motif_width: int,
) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]:
    """Compute a confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.

    This function creates a confusion matrix showing the overlap between
    TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs.
    Overlap frequencies are estimated using binned genomic coordinates.

    Parameters
    ----------
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame]
        TF-MoDISco seqlets with required columns:
        - chr_id, start_untrimmed, end_untrimmed, motif_name
    peaks_df : pl.DataFrame
        Peak metadata for joining coordinates:
        - peak_id, chr_id
    motif_names : List[str]
        Names of motifs to include in confusion matrix.
        Determines matrix dimensions.
    motif_width : int
        Width used for binning genomic coordinates.
        Positions are binned to motif_width resolution.

    Returns
    -------
    confusion_df : pl.DataFrame
        Detailed confusion matrix in tabular format with columns:
        - motif_name_seqlets : Seqlet motif labels (rows)
        - motif_name_hits : Hit motif labels (columns)
        - frac_overlap : Fraction of seqlets overlapping with hits
    confusion_mat : Float[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits.
        Rows represent seqlet motifs, columns represent hit motifs.

    Notes
    -----
    - Genomic coordinates are binned to motif_width resolution for overlap detection
    - An overlap requires an exact match on (chr_id, start_bin, end_bin);
      strand is ignored
    - Fractions are computed as: overlaps / total_seqlets_per_motif
    - Rows are not normalized to sum to 1: a seqlet that matches several
      hits is counted once per matching hit
    - Missing motif combinations result in zero entries in the confusion matrix

    Raises
    ------
    ValueError
        If required columns are missing from input DataFrames.
    KeyError
        If motif names in data don't match those in motif_names list.
    """
    bin_size = motif_width

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_binned = (
        hits_df.with_columns(
            peak_id=pl.col("peak_id").cast(pl.UInt32),
            is_revcomp=pl.col("strand") == "-",
        )
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
        .select(
            chr_id=pl.col("chr_id"),
            start_bin=pl.col("start_untrimmed") // bin_size,
            end_bin=pl.col("end_untrimmed") // bin_size,
            motif_name=pl.col("motif_name"),
        )
    )

    seqlets_lazy = seqlets_df.lazy()
    seqlets_binned = seqlets_lazy.select(
        chr_id=pl.col("chr_id"),
        start_bin=pl.col("start_untrimmed") // bin_size,
        end_bin=pl.col("end_untrimmed") // bin_size,
        motif_name=pl.col("motif_name"),
    )

    overlaps_df = seqlets_binned.join(
        hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits"
    )

    seqlet_counts = (
        seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect()
    )
    overlap_counts = (
        overlaps_df.group_by(["motif_name", "motif_name_hits"])
        .len(name="num_overlaps")
        .collect()
    )

    num_motifs = len(motif_names)
    confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32)
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    confusion_df = overlap_counts.join(
        seqlet_counts, on="motif_name", how="inner"
    ).select(
        motif_name_seqlets=pl.col("motif_name"),
        motif_name_hits=pl.col("motif_name_hits"),
        frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"),
    )

    confusion_idx_df = confusion_df.select(
        row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx),
        col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx),
        frac_overlap=pl.col("frac_overlap"),
    )

    row_idx = confusion_idx_df["row_idx"].to_numpy()
    col_idx = confusion_idx_df["col_idx"].to_numpy()
    frac_overlap = confusion_idx_df["frac_overlap"].to_numpy()

    confusion_mat[row_idx, col_idx] = frac_overlap

    return confusion_df, confusion_mat

Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.

This function creates a confusion matrix showing the overlap between TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs. Overlap frequencies are estimated using binned genomic coordinates.

Parameters
  • hits_df (Union[pl.DataFrame, pl.LazyFrame]): Fi-NeMo hit calls with required columns:
    • peak_id, start_untrimmed, end_untrimmed, strand, motif_name
  • seqlets_df (Union[pl.DataFrame, pl.LazyFrame]): TF-MoDISco seqlets with required columns:
    • chr_id, start_untrimmed, end_untrimmed, motif_name
  • peaks_df (pl.DataFrame): Peak metadata for joining coordinates:
    • peak_id, chr_id
  • motif_names (List[str]): Names of motifs to include in confusion matrix. Determines matrix dimensions.
  • motif_width (int): Width used for binning genomic coordinates. Positions are binned to motif_width resolution.
Returns
  • confusion_df (pl.DataFrame): Detailed confusion matrix in tabular format with columns:
    • motif_name_seqlets : Seqlet motif labels (rows)
    • motif_name_hits : Hit motif labels (columns)
    • frac_overlap : Fraction of seqlets overlapping with hits
  • confusion_mat (Float[ndarray, "M M"]): Confusion matrix where M = len(motif_names). Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits. Rows represent seqlet motifs, columns represent hit motifs.
Notes
  • Genomic coordinates are binned to motif_width resolution for overlap detection
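  • An overlap requires an exact match on (chr_id, start_bin, end_bin); strand is ignored
  • Fractions are computed as: overlaps / total_seqlets_per_motif
  • Rows are not normalized to sum to 1: a seqlet that matches several hits is counted once per matching hit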
  • Missing motif combinations result in zero entries in the confusion matrix
Raises
  • ValueError: If required columns are missing from input DataFrames.
  • KeyError: If motif names in data don't match those in motif_names list.