finemo.visualization

Visualization module for generating plots and reports for Fi-NeMo results.

This module provides functions for:

  • Plotting motif contribution weight matrices (CWMs) as sequence logos
  • Generating distribution plots for hit statistics
  • Creating co-occurrence heatmaps
  • Producing HTML reports with interactive visualizations
  • Plotting confusion matrices and performance metrics
  1"""Visualization module for generating plots and reports for Fi-NeMo results.
  2
  3This module provides functions for:
  4- Plotting motif contribution weight matrices (CWMs) as sequence logos
  5- Generating distribution plots for hit statistics
  6- Creating co-occurrence heatmaps
  7- Producing HTML reports with interactive visualizations
  8- Plotting confusion matrices and performance metrics
  9"""
 10
 11import os
 12import importlib.resources
 13from typing import List, Optional, Dict, Any, Tuple, Union, Mapping, Iterable
 14
 15import numpy as np
 16from numpy import ndarray
 17import matplotlib.pyplot as plt
 18from matplotlib.axes import Axes
 19from matplotlib.patheffects import AbstractPathEffect
 20from matplotlib.textpath import TextPath
 21from matplotlib.transforms import Affine2D
 22from matplotlib.font_manager import FontProperties
 23from jinja2 import Template
 24import polars as pl
 25from jaxtyping import Float, Int
 26
 27from . import templates
 28
 29
 30def abbreviate_motif_name(name: str) -> str:
 31    """Convert TF-MoDISco motif names to abbreviated format.
 32
 33    Converts full TF-MoDISco pattern names to shorter, more readable format
 34    for display in plots and reports.
 35
 36    Parameters
 37    ----------
 38    name : str
 39        Full motif name (e.g., 'pos_patterns.pattern_0').
 40
 41    Returns
 42    -------
 43    str
 44        Abbreviated name (e.g., '+/0') or original name if parsing fails.
 45
 46    Examples
 47    --------
 48    >>> abbreviate_motif_name('pos_patterns.pattern_0')
 49    '+/0'
 50    >>> abbreviate_motif_name('neg_patterns.pattern_1')
 51    '-/1'
 52    >>> abbreviate_motif_name('invalid_name')
 53    'invalid_name'
 54    """
 55    try:
 56        group, motif = name.split(".")
 57        if group == "pos_patterns":
 58            group_short = "+"
 59        elif group == "neg_patterns":
 60            group_short = "-"
 61        else:
 62            raise Exception
 63        motif_num = motif.split("_")[1]
 64        return f"{group_short}/{motif_num}"
 65    except Exception:
 66        return name
 67
 68
 69def plot_hit_stat_distributions(
 70    hits_df: pl.LazyFrame, motif_names: List[str], plot_dir: str
 71) -> None:
 72    """Plot distributions of hit statistics for each motif.
 73
 74    Creates separate histogram plots for coefficient, similarity, and importance
 75    score distributions for each motif. Saves plots in both PNG (high-res) and
 76    SVG (vector) formats.
 77
 78    Parameters
 79    ----------
 80    hits_df : pl.LazyFrame
 81        Lazy DataFrame containing hit data with required columns:
 82        - motif_name : str, name of the motif
 83        - hit_coefficient_global : float, global coefficient values
 84        - hit_similarity : float, similarity scores to motif CWM
 85        - hit_importance : float, importance scores from attribution
 86    motif_names : List[str]
 87        List of motif names to generate plots for. Motifs not present
 88        in hits_df will result in empty histograms.
 89    plot_dir : str
 90        Directory path where plots will be saved. Creates subdirectory
 91        'motif_stat_distributions' if it doesn't exist.
 92
 93    Notes
 94    -----
 95    For each motif, creates three separate plots:
 96    - {motif_name}_coefficients.{png,svg} : coefficient distribution
 97    - {motif_name}_similarities.{png,svg} : similarity distribution
 98    - {motif_name}_importances.{png,svg} : importance distribution
 99    """
100    hits_df_collected = hits_df.collect()
101    hits_by_motif = hits_df_collected.partition_by("motif_name", as_dict=True)
102    dummy_df = hits_df_collected.clear()
103
104    motifs_dir = os.path.join(plot_dir, "motif_stat_distributions")
105    os.makedirs(motifs_dir, exist_ok=True)
106    for m in motif_names:
107        hits = hits_by_motif.get((m,), dummy_df)
108        coefficients = hits.get_column("hit_coefficient_global").to_numpy()
109        similarities = hits.get_column("hit_similarity").to_numpy()
110        importances = hits.get_column("hit_importance").to_numpy()
111
112        fig, ax = plt.subplots(figsize=(5, 2))
113
114        # Plot coefficient distribution
115        try:
116            ax.hist(coefficients, bins=50, density=True)
117        except ValueError:
118            ax.hist(coefficients, bins=1, density=True)
119
120        output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png")
121        plt.savefig(output_path_png, dpi=300)
122        output_path_svg = os.path.join(motifs_dir, f"{m}_coefficients.svg")
123        plt.savefig(output_path_svg)
124        plt.close(fig)
125
126        fig, ax = plt.subplots(figsize=(5, 2))
127
128        # Plot similarity distribution
129        try:
130            ax.hist(similarities, bins=50, density=True)
131        except ValueError:
132            ax.hist(similarities, bins=1, density=True)
133
134        output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png")
135        plt.savefig(output_path_png, dpi=300)
136        output_path_svg = os.path.join(motifs_dir, f"{m}_similarities.svg")
137        plt.savefig(output_path_svg)
138        plt.close(fig)
139
140        fig, ax = plt.subplots(figsize=(5, 2))
141
142        # Plot importance distribution
143        try:
144            ax.hist(importances, bins=50, density=True)
145        except ValueError:
146            ax.hist(importances, bins=1, density=True)
147
148        output_path_png = os.path.join(motifs_dir, f"{m}_importances.png")
149        plt.savefig(output_path_png, dpi=300)
150        output_path_svg = os.path.join(motifs_dir, f"{m}_importances.svg")
151        plt.savefig(output_path_svg)
152        plt.close(fig)
153
154
def _hit_count_frequencies(counts_per_region) -> Tuple[ndarray, ndarray]:
    """Return ``(x, y)`` where ``y[k]`` is the fraction of regions with ``k`` hits.

    ``x`` spans ``0..max(count)``; hit counts that never occur get frequency 0.
    Empty input yields ``x = [0]`` and ``y = [0.0]``.
    """
    unique, counts = np.unique(np.asarray(counts_per_region), return_counts=True)
    # Counts are used as indices below, so make them integers even if the
    # source column has a float dtype.
    unique = unique.astype(np.int64)
    num_bins = int(np.amax(unique, initial=0)) + 1
    x = np.arange(num_bins)
    y = np.zeros(num_bins)
    if counts.size:
        y[unique] = counts / counts.sum()
    return x, y


def plot_hit_peak_distributions(
    occ_df: pl.DataFrame, motif_names: List[str], plot_dir: str
) -> None:
    """Plot distribution of hits per peak for each motif.

    Creates bar plots showing the frequency distribution of hit counts per peak
    for each motif, plus an overall distribution of total hits per peak.

    Parameters
    ----------
    occ_df : pl.DataFrame
        DataFrame containing motif occurrence counts per peak. Expected to have:
        - One column per motif name with integer hit counts
        - 'total' column with sum of all motif hits per peak
        - Each row represents a peak/genomic region
    motif_names : List[str]
        List of motif names corresponding to columns in occ_df.
    plot_dir : str
        Directory to save plots. Creates 'motif_hit_distributions' subdirectory.

    Notes
    -----
    Generates the following plots:
    - Individual motif hit distributions:
      motif_hit_distributions/{motif_name}.{png,svg}
    - Overall hit distribution (directly in plot_dir):
      total_hit_distribution.{png,svg}

    Bar plots show frequency (proportion) on y-axis and hit count on x-axis.
    """
    motifs_dir = os.path.join(plot_dir, "motif_hit_distributions")
    os.makedirs(motifs_dir, exist_ok=True)

    for m in motif_names:
        x, y = _hit_count_frequencies(occ_df.get_column(m))
        fig, ax = plt.subplots(figsize=(5, 2))
        try:
            ax.bar(x, y)
            fig.savefig(os.path.join(motifs_dir, f"{m}.png"), dpi=300)
            fig.savefig(os.path.join(motifs_dir, f"{m}.svg"))
        finally:
            plt.close(fig)

    x, y = _hit_count_frequencies(occ_df.get_column("total"))
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.bar(x, y)
        ax.set_xlabel("Total hits per region")
        ax.set_ylabel("Frequency")
        fig.savefig(os.path.join(plot_dir, "total_hit_distribution.png"), dpi=300)
        fig.savefig(os.path.join(plot_dir, "total_hit_distribution.svg"), dpi=300)
    finally:
        plt.close(fig)
223
224
def plot_peak_motif_indicator_heatmap(
    peak_hit_counts: Int[ndarray, "M M"], motif_names: List[str], output_dir: str
) -> None:
    """Plot co-occurrence heatmap showing motif associations across peaks.

    Creates a normalized correlation heatmap showing how frequently pairs of
    motifs co-occur within the same genomic peaks. Values are normalized by
    the geometric mean of individual motif frequencies.

    Parameters
    ----------
    peak_hit_counts : Int[ndarray, "M M"]
        Co-occurrence matrix where M = len(motif_names).
        Entry (i,j) represents the number of peaks containing both motif i and j.
        Diagonal entries represent total peaks containing each individual motif.
    motif_names : List[str]
        List of motif names for axis labels. Order must match matrix dimensions.
    output_dir : str
        Directory path where the heatmap plots will be saved.

    Notes
    -----
    Saves plots as (filenames are spelled exactly like this):
    - motif_cooocurrence.png : High-resolution raster format
    - motif_cooocurrence.svg : Vector format

    The heatmap uses correlation normalization: matrix[i,j] / sqrt(matrix[i,i] * matrix[j,j])
    Motifs whose diagonal entry is 0 (present in no peak) get 0 in their row
    and column instead of NaN.
    Colors use the 'Greens' colormap with values typically in [0, 1] range.
    """
    counts = np.asarray(peak_hit_counts, dtype=float)
    diag = np.diag(counts)
    # 1 / sqrt(diag), but leave 0 where diag is 0 so absent motifs do not
    # trigger a divide-by-zero warning and NaN-filled rows/columns.
    cov_norm = np.zeros_like(diag)
    np.divide(1.0, np.sqrt(diag), out=cov_norm, where=diag > 0)
    matrix = counts * cov_norm[:, None] * cov_norm[None, :]

    motif_keys = [abbreviate_motif_name(m) for m in motif_names]
    ticks = np.arange(len(motif_keys))

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")
    try:
        # Plot the heatmap
        cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens")

        # Set axes on heatmap
        ax.set_yticks(ticks)
        ax.set_yticklabels(motif_keys)
        ax.set_xticks(ticks)
        ax.set_xticklabels(motif_keys, rotation=90)
        ax.set_xlabel("Motif i")
        ax.set_ylabel("Motif j")

        ax.tick_params(axis="both", labelsize=8)

        cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30)
        cbar.ax.tick_params(labelsize=8)

        fig.savefig(os.path.join(output_dir, "motif_cooocurrence.png"), dpi=300)
        fig.savefig(os.path.join(output_dir, "motif_cooocurrence.svg"), dpi=300)
    finally:
        plt.close(fig)
282
283
def plot_seqlet_confusion_heatmap(
    seqlet_confusion: Int[ndarray, "M M"], motif_names: List[str], output_dir: str
) -> None:
    """Plot confusion matrix heatmap comparing seqlets to hit calls.

    Creates a heatmap showing the overlap between TF-MoDISco seqlets and
    Fi-NeMo hit calls. Rows represent seqlet motifs, columns represent hit motifs.

    Parameters
    ----------
    seqlet_confusion : Int[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) represents the number of seqlets of motif i that overlap
        with hits called for motif j.
    motif_names : List[str]
        List of motif names for axis labels. Order must match matrix dimensions.
    output_dir : str
        Directory path where the confusion matrix plots will be saved.

    Notes
    -----
    Saves plots as:
    - seqlet_confusion.png : High-resolution raster format
    - seqlet_confusion.svg : Vector format

    The heatmap uses 'Blues' colormap. Perfect agreement would show a diagonal
    pattern with high values along the diagonal and low off-diagonal values.
    The figure is closed even if saving fails.
    """
    motif_keys = [abbreviate_motif_name(m) for m in motif_names]
    ticks = np.arange(len(motif_keys))

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")
    try:
        # Plot the heatmap
        cax = ax.imshow(
            seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues"
        )

        # Set axes on heatmap
        ax.set_yticks(ticks)
        ax.set_yticklabels(motif_keys)
        ax.set_xticks(ticks)
        ax.set_xticklabels(motif_keys, rotation=90)
        ax.set_xlabel("Hit motif")
        ax.set_ylabel("Seqlet motif")

        ax.tick_params(axis="both", labelsize=8)

        cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30)
        cbar.ax.tick_params(labelsize=8)

        fig.savefig(os.path.join(output_dir, "seqlet_confusion.png"), dpi=300)
        fig.savefig(os.path.join(output_dir, "seqlet_confusion.svg"), dpi=300)
    finally:
        plt.close(fig)
340
341
class LogoGlyph(AbstractPathEffect):
    """Path effect for creating sequence logo glyphs with normalized dimensions.

    Instead of the artist's own path, this effect draws a single character
    outline. The outline is translated and scaled so that it lies in the unit
    square: height 1, width at most 1, horizontally centred. It is then
    rendered with the transform the renderer supplies for the artist. For a bar
    patch, this stretches the glyph over the bar.

    Parameters
    ----------
    glyph : str
        Single character to render (e.g., 'A', 'C', 'G', 'T').
    ref_glyph : str, default 'E'
        Reference character used for width normalization. Glyphs narrower
        than the reference are centred rather than stretched to full width.
    font_props : FontProperties, optional
        Font properties for the glyph rendering.
    offset : Tuple[float, float], default (0., 0.)
        Offset of the drawn glyph, in points.
    **kwargs
        Additional graphics context parameters applied when drawing, e.g.
        ``linewidth`` or ``alpha``. Each key ``k`` is passed to the
        corresponding ``GraphicsContext.set_k`` method.
    """

    def __init__(
        self,
        glyph: str,
        ref_glyph: str = "E",
        font_props: Optional[FontProperties] = None,
        offset: Tuple[float, float] = (0.0, 0.0),
        **kwargs,
    ) -> None:
        super().__init__(offset)

        path_orig = TextPath((0, 0), glyph, size=1, prop=font_props)
        dims = path_orig.get_extents()
        ref_dims = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents()

        # Normalize height to 1 and width relative to the wider of the glyph
        # and the reference glyph; w_shift centres narrower glyphs.
        h_scale = 1 / dims.height
        ref_width = max(dims.width, ref_dims.width)
        w_scale = 1 / ref_width
        w_shift = (1 - dims.width / ref_width) / 2
        # Move the glyph's bounding-box corner to the origin before scaling.
        x_shift = -dims.x0
        y_shift = -dims.y0
        stretch = (
            Affine2D()
            .translate(tx=x_shift, ty=y_shift)
            .scale(sx=w_scale, sy=h_scale)
            .translate(tx=w_shift, ty=0)
        )

        self.path = stretch.transform_path(path_orig)

        #: The dictionary of keywords to update the graphics collection with.
        self._gc = kwargs

    def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any:  # type: ignore[override]
        """Draw the glyph path using the renderer.

        Parameters
        ----------
        renderer : matplotlib renderer
            The renderer to draw with.
        gc : GraphicsContext
            Graphics context for drawing properties. It is copied, and the
            constructor's ``**kwargs`` are applied to the copy.
        tpath : Path
            Original artist path (unused, using self.path instead).
        affine : Transform
            Affine transformation to apply.
        rgbFace : color
            Face color for the glyph.

        Returns
        -------
        Any
            Result from renderer.draw_path.
        """
        # Work on a copy so the caller's graphics context is left untouched.
        gc0 = renderer.new_gc()
        gc0.copy_properties(gc)
        gc0 = self._update_gc(gc0, self._gc)
        # Apply the constructor offset (given in points) in display space.
        offset_x, offset_y = (renderer.points_to_pixels(v) for v in self._offset)
        transform = affine + Affine2D().translate(offset_x, offset_y)
        try:
            return renderer.draw_path(gc0, self.path, transform, rgbFace)
        finally:
            gc0.restore()
416
417
def plot_logo(
    ax: Axes,
    heights: Float[ndarray, "B W"],
    glyphs: Iterable[str],
    colors: Optional[Mapping[str, Optional[str]]] = None,
    font_props: Optional[FontProperties] = None,
    shade_bounds: Optional[Tuple[int, int]] = None,
) -> None:
    """Plot sequence logo from contribution weight matrix.

    Creates a sequence logo visualization where letter heights represent
    the contribution or information content at each position. Supports
    both positive and negative contributions with proper stacking.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes object to plot on.
    heights : Float[ndarray, "B W"]
        Height matrix where B = len(glyphs) and W = motif width.
        Entry (i,j) represents the height/contribution of base i at position j.
        Can contain both positive and negative values.
    glyphs : Iterable[str]
        Base characters corresponding to rows in the heights matrix. Any
        iterable is accepted, including a one-shot iterator. Typically
        ['A', 'C', 'G', 'T'] for DNA. Rows are paired with glyphs via ``zip``,
        so only the first ``min(len(glyphs), B)`` rows are drawn.
    colors : Mapping[str, Optional[str]], optional
        Color mapping for each base. If None, or if a glyph is missing from
        the mapping, that glyph uses matplotlib's default color.
    font_props : FontProperties, optional
        Font properties for letter rendering. If None, uses default font.
    shade_bounds : Tuple[int, int], optional
        (start, end) position indices to shade in background, covering
        positions start..end-1. Useful for highlighting core motif regions.

    Notes
    -----
    Positive and negative contributions are handled separately at each position:
    - Positive values are stacked upward from zero, smallest value nearest the axis
    - Negative values are stacked downward from zero, smallest magnitude nearest the axis
    - A horizontal line is drawn at y=0 for reference

    The resulting plot has:
    - X-axis: Position in motif (0-indexed)
    - Y-axis: Contribution magnitude
    - Bar width: 0.95 (small gaps between positions)
    """
    # Materialize once: glyphs are consumed both when building the default
    # colour map and when pairing them with rows below, so a one-shot
    # iterator would otherwise leave nothing to draw.
    glyphs = list(glyphs)
    if colors is None:
        colors = {g: None for g in glyphs}

    ax.margins(x=0, y=0)

    pos_values = np.clip(heights, 0, None)
    neg_values = np.clip(heights, None, 0)
    # Per column: positives ascending, negatives from 0 downward, so the
    # cumulative sums give each glyph's stack top (positive) or bottom (negative).
    pos_order = np.argsort(pos_values, axis=0)
    neg_order = np.argsort(neg_values, axis=0)[::-1, :]
    # Inverse permutations map the cumulative sums back to original row order.
    pos_reorder = np.argsort(pos_order, axis=0)
    neg_reorder = np.argsort(neg_order, axis=0)
    pos_offsets = np.take_along_axis(
        np.cumsum(np.take_along_axis(pos_values, pos_order, axis=0), axis=0),
        pos_reorder,
        axis=0,
    )
    neg_offsets = np.take_along_axis(
        np.cumsum(np.take_along_axis(neg_values, neg_order, axis=0), axis=0),
        neg_reorder,
        axis=0,
    )
    # Bar start point: a bar drawn from `bottom` with signed `height` ends at
    # its cumulative offset.
    bottoms = pos_offsets + neg_offsets - heights

    x = np.arange(heights.shape[1])

    for glyph, height, bottom in zip(glyphs, heights, bottoms):
        ax.bar(
            x,
            height,
            0.95,
            bottom=bottom,
            path_effects=[LogoGlyph(glyph, font_props=font_props)],
            # Glyphs absent from `colors` get None, i.e. matplotlib's default.
            color=colors.get(glyph),
        )

    if shade_bounds is not None:
        start, end = shade_bounds
        ax.axvspan(start - 0.5, end - 0.5, color="0.9", zorder=-1)

    ax.axhline(zorder=-1, linewidth=0.5, color="black")
504
505
# Default glyph order for logos; plot_cwms pairs row i of each CWM with
# LOGO_ALPHABET[i] when drawing.
LOGO_ALPHABET = "ACGT"

# Default fill colour for each nucleotide glyph in sequence logos.
LOGO_COLORS = {
    "A": "#109648",
    "C": "#255C99",
    "G": "#F7B32B",
    "T": "#D62839",
}

# Default (bold) font from which the logo glyph outlines are built.
LOGO_FONT = FontProperties(weight="bold")
509
510
def plot_cwms(
    cwms: Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    trim_bounds: Dict[str, Dict[str, Tuple[int, int]]],
    out_dir: str,
    alphabet: str = LOGO_ALPHABET,
    colors: Dict[str, str] = LOGO_COLORS,
    font: FontProperties = LOGO_FONT,
) -> None:
    """Plot contribution weight matrices as sequence logos.

    Creates sequence logo plots for all motifs and CWM types, with optional
    shading to highlight trimmed regions. Saves plots in both PNG and SVG formats.

    Parameters
    ----------
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Nested dictionary structure: {motif_name: {cwm_type: cwm_array}}.
        Each cwm_array has shape (4, W) where W is motif width.
        Rows correspond to bases in alphabet order.
    trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        Nested dictionary: {motif_name: {cwm_type: (start, end)}}.
        Defines regions to shade in the sequence logos. A motif or CWM type
        missing from this mapping is plotted without shading.
    out_dir : str
        Output directory where motif subdirectories will be created.
    alphabet : str, default LOGO_ALPHABET
        DNA alphabet string, typically 'ACGT'.
    colors : Dict[str, str], default LOGO_COLORS
        Color mapping for DNA bases. Keys should match alphabet characters.
        Not modified by this function.
    font : FontProperties, default LOGO_FONT
        Font properties for sequence logo rendering.

    Notes
    -----
    Directory structure created:
    ```
    out_dir/
    ├── motif1/
    │   ├── cwm_type1.png
    │   ├── cwm_type1.svg
    │   └── ...
    └── motif2/
        └── ...
    ```

    Each plot is 10x2 inches (PNG at 100 dpi) with trimmed regions shaded if
    specified. Spines (plot borders) are hidden for cleaner appearance.
    """
    for motif_name, cwms_by_type in cwms.items():
        motif_dir = os.path.join(out_dir, motif_name)
        os.makedirs(motif_dir, exist_ok=True)
        motif_bounds = trim_bounds.get(motif_name, {})
        for cwm_type, cwm in cwms_by_type.items():
            fig, ax = plt.subplots(figsize=(10, 2))
            try:
                plot_logo(
                    ax,
                    cwm,
                    alphabet,
                    colors=colors,
                    font_props=font,
                    # None (no entry) disables shading for this logo.
                    shade_bounds=motif_bounds.get(cwm_type),
                )

                for spine in ax.spines.values():
                    spine.set_visible(False)

                fig.savefig(os.path.join(motif_dir, f"{cwm_type}.png"), dpi=100)
                fig.savefig(os.path.join(motif_dir, f"{cwm_type}.svg"))
            finally:
                plt.close(fig)
582
583
def plot_hit_vs_seqlet_counts(
    recall_data: Dict[str, Dict[str, Union[int, float]]], output_dir: str
) -> None:
    """Plot scatter plot comparing hit counts to seqlet counts per motif.

    Creates a log-log scatter plot showing the relationship between the number
    of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISco
    for each motif. Includes diagonal reference line and motif annotations.

    Parameters
    ----------
    recall_data : Dict[str, Dict[str, Union[int, float]]]
        Dictionary with motif names as keys and metrics dictionaries as values.
        Each metrics dictionary must contain:
        - 'num_hits_total' : int, total number of hits for the motif
        - 'num_seqlets' : int, total number of seqlets for the motif
        May be empty, in which case an empty plot is written.
    output_dir : str
        Directory path where the scatter plot will be saved.

    Notes
    -----
    Saves plots as:
    - hit_vs_seqlet_counts.png : High-resolution raster format
    - hit_vs_seqlet_counts.svg : Vector format

    Plot features:
    - Log-log scale on both axes
    - Diagonal reference line (y = x) as dashed line, drawn only when at
      least one count is positive
    - Points annotated with abbreviated motif names
    """
    motif_keys = list(recall_data)
    x = [recall_data[k]["num_hits_total"] for k in motif_keys]
    y = [recall_data[k]["num_seqlets"] for k in motif_keys]

    # Largest count on either axis; `default` keeps empty input from raising.
    lim = max(max(x, default=0), max(y, default=0))

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")
    try:
        if lim > 0:
            # y = x reference line. With lim == 0 both points would coincide,
            # which axline cannot draw.
            ax.axline(
                (0, 0), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5))
            )
        ax.scatter(x, y, s=5)
        for name, xi, yi in zip(motif_keys, x, y):
            ax.annotate(
                abbreviate_motif_name(name), (xi, yi), fontsize=8, weight="bold"
            )

        ax.set_yscale("log")
        ax.set_xscale("log")

        ax.set_xlabel("Hits per motif")
        ax.set_ylabel("Seqlets per motif")

        fig.savefig(os.path.join(output_dir, "hit_vs_seqlet_counts.png"), dpi=300)
        fig.savefig(os.path.join(output_dir, "hit_vs_seqlet_counts.svg"))
    finally:
        plt.close(fig)
643
644
def write_report(
    report_df: pl.DataFrame,
    motif_names: List[str],
    out_path: str,
    compute_recall: bool,
    use_seqlets: bool,
) -> None:
    """Generate and write HTML report from motif analysis results.

    Creates a comprehensive HTML report with tables and visualizations
    summarizing the Fi-NeMo motif discovery and hit calling results.

    Parameters
    ----------
    report_df : pl.DataFrame
        DataFrame containing motif statistics and performance metrics.
        Expected columns depend on compute_recall and use_seqlets flags.
    motif_names : List[str]
        List of motif names to include in the report.
        Order determines presentation sequence in the report.
    out_path : str
        File path where the HTML report will be written (UTF-8 encoded).
        Parent directory must exist.
    compute_recall : bool
        Whether recall metrics were computed and should be included
        in the report template.
    use_seqlets : bool
        Whether TF-MoDISco seqlet data was used in the analysis
        and should be referenced in the report.

    Notes
    -----
    Uses Jinja2 templating with the report.html template from the
    templates package, read as UTF-8. The template receives:
    - report_data: iterator over report_df rows, each a dict mapping
      column name to value (``iter_rows(named=True)``)
    - motif_names: List of motif names
    - compute_recall: Boolean flag for recall metrics
    - use_seqlets: Boolean flag for seqlet usage

    The template is constructed with Jinja2's defaults (no autoescaping), so
    values are inserted into the HTML verbatim.

    Raises
    ------
    OSError
        If the output path cannot be written.
    """
    # Explicit UTF-8 on both read and write so the result does not depend on
    # the platform's default locale encoding.
    template_str = (
        importlib.resources.files(templates)
        .joinpath("report.html")
        .read_text(encoding="utf-8")
    )
    template = Template(template_str)
    report = template.render(
        report_data=report_df.iter_rows(named=True),
        motif_names=motif_names,
        compute_recall=compute_recall,
        use_seqlets=use_seqlets,
    )
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(report)
def abbreviate_motif_name(name: str) -> str:
def abbreviate_motif_name(name: str) -> str:
    """Shorten a TF-MoDISco pattern name for display in plots and reports.

    A name of the form ``'<group>.<prefix>_<index>...'`` whose group is
    ``'pos_patterns'`` or ``'neg_patterns'`` becomes ``'+/<index>'`` or
    ``'-/<index>'`` respectively, where ``<index>`` is the token following
    the first underscore. Any name that does not fit this shape is returned
    unchanged.

    Parameters
    ----------
    name : str
        Full motif name (e.g., 'pos_patterns.pattern_0').

    Returns
    -------
    str
        Abbreviated name (e.g., '+/0'), or ``name`` itself if it cannot be
        parsed.

    Examples
    --------
    >>> abbreviate_motif_name('pos_patterns.pattern_0')
    '+/0'
    >>> abbreviate_motif_name('neg_patterns.pattern_1')
    '-/1'
    >>> abbreviate_motif_name('invalid_name')
    'invalid_name'
    """
    sign_for_group = {"pos_patterns": "+", "neg_patterns": "-"}
    try:
        # Exactly one "." is required; any other count fails to unpack.
        group, pattern = name.split(".")
        # Unrecognised groups raise KeyError and take the fallback path.
        sign = sign_for_group[group]
        # Token after the first "_" (IndexError if there is no underscore).
        index = pattern.split("_")[1]
        abbreviated = f"{sign}/{index}"
    except Exception:
        return name
    return abbreviated

Convert TF-MoDISco motif names to abbreviated format.

Converts full TF-MoDISco pattern names to shorter, more readable format for display in plots and reports.

Parameters
  • name (str): Full motif name (e.g., 'pos_patterns.pattern_0').
Returns
  • str: Abbreviated name (e.g., '+/0') or original name if parsing fails.
Examples
>>> abbreviate_motif_name('pos_patterns.pattern_0')
'+/0'
>>> abbreviate_motif_name('neg_patterns.pattern_1')
'-/1'
>>> abbreviate_motif_name('invalid_name')
'invalid_name'
def plot_hit_stat_distributions( hits_df: polars.lazyframe.frame.LazyFrame, motif_names: List[str], plot_dir: str) -> None:
 70def plot_hit_stat_distributions(
 71    hits_df: pl.LazyFrame, motif_names: List[str], plot_dir: str
 72) -> None:
 73    """Plot distributions of hit statistics for each motif.
 74
 75    Creates separate histogram plots for coefficient, similarity, and importance
 76    score distributions for each motif. Saves plots in both PNG (high-res) and
 77    SVG (vector) formats.
 78
 79    Parameters
 80    ----------
 81    hits_df : pl.LazyFrame
 82        Lazy DataFrame containing hit data with required columns:
 83        - motif_name : str, name of the motif
 84        - hit_coefficient_global : float, global coefficient values
 85        - hit_similarity : float, similarity scores to motif CWM
 86        - hit_importance : float, importance scores from attribution
 87    motif_names : List[str]
 88        List of motif names to generate plots for. Motifs not present
 89        in hits_df will result in empty histograms.
 90    plot_dir : str
 91        Directory path where plots will be saved. Creates subdirectory
 92        'motif_stat_distributions' if it doesn't exist.
 93
 94    Notes
 95    -----
 96    For each motif, creates three separate plots:
 97    - {motif_name}_coefficients.{png,svg} : coefficient distribution
 98    - {motif_name}_similarities.{png,svg} : similarity distribution
 99    - {motif_name}_importances.{png,svg} : importance distribution
100    """
101    hits_df_collected = hits_df.collect()
102    hits_by_motif = hits_df_collected.partition_by("motif_name", as_dict=True)
103    dummy_df = hits_df_collected.clear()
104
105    motifs_dir = os.path.join(plot_dir, "motif_stat_distributions")
106    os.makedirs(motifs_dir, exist_ok=True)
107    for m in motif_names:
108        hits = hits_by_motif.get((m,), dummy_df)
109        coefficients = hits.get_column("hit_coefficient_global").to_numpy()
110        similarities = hits.get_column("hit_similarity").to_numpy()
111        importances = hits.get_column("hit_importance").to_numpy()
112
113        fig, ax = plt.subplots(figsize=(5, 2))
114
115        # Plot coefficient distribution
116        try:
117            ax.hist(coefficients, bins=50, density=True)
118        except ValueError:
119            ax.hist(coefficients, bins=1, density=True)
120
121        output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png")
122        plt.savefig(output_path_png, dpi=300)
123        output_path_svg = os.path.join(motifs_dir, f"{m}_coefficients.svg")
124        plt.savefig(output_path_svg)
125        plt.close(fig)
126
127        fig, ax = plt.subplots(figsize=(5, 2))
128
129        # Plot similarity distribution
130        try:
131            ax.hist(similarities, bins=50, density=True)
132        except ValueError:
133            ax.hist(similarities, bins=1, density=True)
134
135        output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png")
136        plt.savefig(output_path_png, dpi=300)
137        output_path_svg = os.path.join(motifs_dir, f"{m}_similarities.svg")
138        plt.savefig(output_path_svg)
139        plt.close(fig)
140
141        fig, ax = plt.subplots(figsize=(5, 2))
142
143        # Plot importance distribution
144        try:
145            ax.hist(importances, bins=50, density=True)
146        except ValueError:
147            ax.hist(importances, bins=1, density=True)
148
149        output_path_png = os.path.join(motifs_dir, f"{m}_importances.png")
150        plt.savefig(output_path_png, dpi=300)
151        output_path_svg = os.path.join(motifs_dir, f"{m}_importances.svg")
152        plt.savefig(output_path_svg)
153        plt.close(fig)

Plot distributions of hit statistics for each motif.

Creates separate histogram plots for coefficient, similarity, and importance score distributions for each motif. Saves plots in both PNG (high-res) and SVG (vector) formats.

Parameters
  • hits_df (pl.LazyFrame): Lazy DataFrame containing hit data with required columns:
    • motif_name : str, name of the motif
    • hit_coefficient_global : float, global coefficient values
    • hit_similarity : float, similarity scores to motif CWM
    • hit_importance : float, importance scores from attribution
  • motif_names (List[str]): List of motif names to generate plots for. Motifs not present in hits_df will result in empty histograms.
  • plot_dir (str): Directory path where plots will be saved. Creates subdirectory 'motif_stat_distributions' if it doesn't exist.
Notes

For each motif, creates three separate plots:

  • {motif_name}_coefficients.{png,svg} : coefficient distribution
  • {motif_name}_similarities.{png,svg} : similarity distribution
  • {motif_name}_importances.{png,svg} : importance distribution
def plot_hit_peak_distributions( occ_df: polars.dataframe.frame.DataFrame, motif_names: List[str], plot_dir: str) -> None:
def _plot_count_frequencies(ax: Axes, values: pl.Series) -> None:
    """Draw bars giving the proportion of rows that take each integer count.

    One bar is drawn for every integer from 0 to the largest observed value.
    Values that never occur get a zero-height bar, so the x-axis is contiguous.

    Parameters
    ----------
    ax : Axes
        Axes to draw the bar chart on.
    values : pl.Series
        Non-negative integer counts, one per region.
    """
    observed, occurrences = np.unique(values, return_counts=True)
    heights = np.zeros(np.amax(observed, initial=0) + 1)
    # Index by the observed count values; counts that never occur stay at 0.
    heights[observed] = occurrences / occurrences.sum()
    ax.bar(np.arange(heights.size), heights)


def plot_hit_peak_distributions(
    occ_df: pl.DataFrame, motif_names: List[str], plot_dir: str
) -> None:
    """Plot distribution of hits per peak for each motif.

    Creates bar plots showing the frequency distribution of hit counts per peak
    for each motif, plus an overall distribution of total hits per peak.

    Parameters
    ----------
    occ_df : pl.DataFrame
        DataFrame containing motif occurrence counts per peak. Expected to have:
        - One column per motif name with integer hit counts
        - 'total' column with sum of all motif hits per peak
        - Each row represents a peak/genomic region
    motif_names : List[str]
        List of motif names corresponding to columns in occ_df.
    plot_dir : str
        Directory to save plots. Creates 'motif_hit_distributions' subdirectory.

    Notes
    -----
    Generates the following plots:
    - Individual motif hit distributions: {motif_name}.{png,svg}
      (inside the 'motif_hit_distributions' subdirectory)
    - Overall hit distribution: total_hit_distribution.{png,svg}
      (directly inside plot_dir)

    Bar plots show frequency (proportion) on y-axis and hit count on x-axis.
    """
    motifs_dir = os.path.join(plot_dir, "motif_hit_distributions")
    os.makedirs(motifs_dir, exist_ok=True)

    for m in motif_names:
        fig, ax = plt.subplots(figsize=(5, 2))
        _plot_count_frequencies(ax, occ_df.get_column(m))

        fig.savefig(os.path.join(motifs_dir, f"{m}.png"), dpi=300)
        fig.savefig(os.path.join(motifs_dir, f"{m}.svg"))
        plt.close(fig)

    fig, ax = plt.subplots(figsize=(8, 4))
    _plot_count_frequencies(ax, occ_df.get_column("total"))
    ax.set_xlabel("Total hits per region")
    ax.set_ylabel("Frequency")

    fig.savefig(os.path.join(plot_dir, "total_hit_distribution.png"), dpi=300)
    fig.savefig(os.path.join(plot_dir, "total_hit_distribution.svg"), dpi=300)
    plt.close(fig)

Plot distribution of hits per peak for each motif.

Creates bar plots showing the frequency distribution of hit counts per peak for each motif, plus an overall distribution of total hits per peak.

Parameters
  • occ_df (pl.DataFrame): DataFrame containing motif occurrence counts per peak. Expected to have:
    • One column per motif name with integer hit counts
    • 'total' column with sum of all motif hits per peak
    • Each row represents a peak/genomic region
  • motif_names (List[str]): List of motif names corresponding to columns in occ_df.
  • plot_dir (str): Directory to save plots. Creates 'motif_hit_distributions' subdirectory.
Notes

Generates the following plots:

  • Individual motif hit distributions: {motif_name}.{png,svg}
  • Overall hit distribution: total_hit_distribution.{png,svg}

Bar plots show frequency (proportion) on y-axis and hit count on x-axis.

def plot_peak_motif_indicator_heatmap( peak_hit_counts: jaxtyping.Int[ndarray, 'M M'], motif_names: List[str], output_dir: str) -> None:
def plot_peak_motif_indicator_heatmap(
    peak_hit_counts: Int[ndarray, "M M"], motif_names: List[str], output_dir: str
) -> None:
    """Plot co-occurrence heatmap showing motif associations across peaks.

    Creates a normalized correlation heatmap showing how frequently pairs of
    motifs co-occur within the same genomic peaks. Values are normalized by
    the geometric mean of individual motif frequencies.

    Parameters
    ----------
    peak_hit_counts : Int[ndarray, "M M"]
        Co-occurrence matrix where M = len(motif_names).
        Entry (i,j) represents the number of peaks containing both motif i and j.
        Diagonal entries represent total peaks containing each individual motif.
    motif_names : List[str]
        List of motif names for axis labels. Order must match matrix dimensions.
    output_dir : str
        Directory path where the heatmap plots will be saved.

    Notes
    -----
    Saves plots as:
    - motif_cooocurrence.png : High-resolution raster format
    - motif_cooocurrence.svg : Vector format

    The heatmap uses correlation normalization:
    matrix[i,j] / sqrt(matrix[i,i] * matrix[j,j]).
    Rows and columns of motifs with a zero diagonal entry (motifs found in no
    peak) are shown as 0 instead of inf/NaN.
    Colors use the 'Greens' colormap with values typically in [0, 1] range.
    """
    diag = np.diag(peak_hit_counts).astype(float)
    # A motif present in no peak has a zero diagonal entry. Dividing by it
    # would produce inf and then NaN (0 * inf) cells. Leave those factors at 0.
    inv_sqrt_diag = np.zeros_like(diag)
    np.divide(1.0, np.sqrt(diag), out=inv_sqrt_diag, where=diag > 0)
    matrix = peak_hit_counts * inv_sqrt_diag[:, None] * inv_sqrt_diag[None, :]
    motif_keys = [abbreviate_motif_name(m) for m in motif_names]
    ticks = np.arange(len(motif_keys))

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")

    # Plot the heatmap
    cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens")

    # Set axes on heatmap
    ax.set_yticks(ticks)
    ax.set_yticklabels(motif_keys)
    ax.set_xticks(ticks)
    ax.set_xticklabels(motif_keys, rotation=90)
    ax.set_xlabel("Motif i")
    ax.set_ylabel("Motif j")

    ax.tick_params(axis="both", labelsize=8)

    cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30)
    cbar.ax.tick_params(labelsize=8)

    # NOTE(review): the "cooocurrence" spelling is kept as-is, because other
    # code (e.g. the report template) may reference these filenames. Confirm
    # before renaming.
    output_path_png = os.path.join(output_dir, "motif_cooocurrence.png")
    fig.savefig(output_path_png, dpi=300)
    output_path_svg = os.path.join(output_dir, "motif_cooocurrence.svg")
    fig.savefig(output_path_svg, dpi=300)

    plt.close(fig)

Plot co-occurrence heatmap showing motif associations across peaks.

Creates a normalized correlation heatmap showing how frequently pairs of motifs co-occur within the same genomic peaks. Values are normalized by the geometric mean of individual motif frequencies.

Parameters
  • peak_hit_counts (Int[ndarray, "M M"]): Co-occurrence matrix where M = len(motif_names). Entry (i,j) represents the number of peaks containing both motif i and j. Diagonal entries represent total peaks containing each individual motif.
  • motif_names (List[str]): List of motif names for axis labels. Order must match matrix dimensions.
  • output_dir (str): Directory path where the heatmap plots will be saved.
Notes

Saves plots as:

  • motif_cooocurrence.png : High-resolution raster format
  • motif_cooocurrence.svg : Vector format

The heatmap uses correlation normalization: matrix[i,j] / sqrt(matrix[i,i] * matrix[j,j]). Colors use the 'Greens' colormap with values typically in the [0, 1] range.

def plot_seqlet_confusion_heatmap( seqlet_confusion: jaxtyping.Int[ndarray, 'M M'], motif_names: List[str], output_dir: str) -> None:
def plot_seqlet_confusion_heatmap(
    seqlet_confusion: Int[ndarray, "M M"], motif_names: List[str], output_dir: str
) -> None:
    """Render a heatmap of how TF-MoDISco seqlets overlap Fi-NeMo hit calls.

    The y-axis (rows) indexes seqlet motifs and the x-axis (columns) indexes
    hit motifs. Both axes are labelled with abbreviated motif names.

    Parameters
    ----------
    seqlet_confusion : Int[ndarray, "M M"]
        Square matrix with M = len(motif_names). Cell (i, j) is the number of
        seqlets of motif i overlapping hits called for motif j.
    motif_names : List[str]
        Motif names in the same order as the matrix rows/columns.
    output_dir : str
        Directory in which the figure files are written.

    Notes
    -----
    Writes ``seqlet_confusion.png`` (raster) and ``seqlet_confusion.svg``
    (vector), both saved with dpi=300. The 'Blues' colormap is used. Good
    agreement between seqlets and hits shows up as a strong diagonal.
    """
    labels = list(map(abbreviate_motif_name, motif_names))
    positions = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")
    image = ax.imshow(
        seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues"
    )

    # Rows are seqlet motifs (y); columns are hit motifs (x).
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_xlabel("Hit motif")
    ax.set_ylabel("Seqlet motif")
    ax.tick_params(axis="both", labelsize=8)

    colorbar = fig.colorbar(
        image, ax=ax, orientation="vertical", shrink=0.6, aspect=30
    )
    colorbar.ax.tick_params(labelsize=8)

    for filename in ("seqlet_confusion.png", "seqlet_confusion.svg"):
        plt.savefig(os.path.join(output_dir, filename), dpi=300)

    plt.close()

Plot confusion matrix heatmap comparing seqlets to hit calls.

Creates a heatmap showing the overlap between TF-MoDISco seqlets and Fi-NeMo hit calls. Rows represent seqlet motifs, columns represent hit motifs.

Parameters
  • seqlet_confusion (Int[ndarray, "M M"]): Confusion matrix where M = len(motif_names). Entry (i,j) represents the number of seqlets of motif i that overlap with hits called for motif j.
  • motif_names (List[str]): List of motif names for axis labels. Order must match matrix dimensions.
  • output_dir (str): Directory path where the confusion matrix plots will be saved.
Notes

Saves plots as:

  • seqlet_confusion.png : High-resolution raster format
  • seqlet_confusion.svg : Vector format

The heatmap uses 'Blues' colormap. Perfect agreement would show a diagonal pattern with high values along the diagonal and low off-diagonal values.

class LogoGlyph(matplotlib.patheffects.AbstractPathEffect):
class LogoGlyph(AbstractPathEffect):
    """Path effect for creating sequence logo glyphs with normalized dimensions.

    The glyph outline is first translated so its bounding box starts at the
    origin. It is then scaled to unit height. Its width is scaled relative to
    the wider of the glyph itself and ``ref_glyph``, so glyphs narrower than
    the reference are not stretched and are centered horizontally in the unit
    box.

    Parameters
    ----------
    glyph : str
        Single character to render (e.g., 'A', 'C', 'G', 'T').
    ref_glyph : str, default 'E'
        Reference character used for width normalization.
    font_props : FontProperties, optional
        Font properties for the glyph rendering.
    offset : Tuple[float, float], default (0., 0.)
        (x, y) offset, in points, applied when the glyph is drawn.
    **kwargs
        Graphics-context properties (e.g. ``linewidth``). They are applied to
        a copy of the graphics context when the glyph is drawn.
    """

    def __init__(
        self,
        glyph: str,
        ref_glyph: str = "E",
        font_props: Optional[FontProperties] = None,
        offset: Tuple[float, float] = (0.0, 0.0),
        **kwargs,
    ) -> None:
        super().__init__(offset)

        # Glyph outline at unit font size, plus the reference glyph's extents.
        path_orig = TextPath((0, 0), glyph, size=1, prop=font_props)
        dims = path_orig.get_extents()
        ref_dims = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents()

        # Unit height. Width is scaled by the wider of glyph/reference, so a
        # narrower glyph ends up with width dims.width / ref_width (< 1).
        h_scale = 1 / dims.height
        ref_width = max(dims.width, ref_dims.width)
        w_scale = 1 / ref_width
        # Post-scaling shift that centers a narrower glyph within width 1.
        w_shift = (1 - dims.width / ref_width) / 2
        # Move the bounding box's lower-left corner to the origin before scaling.
        x_shift = -dims.x0
        y_shift = -dims.y0
        stretch = (
            Affine2D()
            .translate(tx=x_shift, ty=y_shift)
            .scale(sx=w_scale, sy=h_scale)
            .translate(tx=w_shift, ty=0)
        )

        self.path = stretch.transform_path(path_orig)

        #: The dictionary of keywords to update the graphics context with.
        self._gc = kwargs

    def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any:  # type: ignore[override]
        """Draw the normalized glyph path in place of the original text path.

        A copy of ``gc`` is updated with the graphics-context keywords given
        at construction. The configured ``offset`` is added to ``affine``.
        The caller's ``gc`` itself is not modified.

        Parameters
        ----------
        renderer : matplotlib renderer
            The renderer to draw with.
        gc : GraphicsContext
            Graphics context for drawing properties.
        tpath : Path
            Original text path (unused, using self.path instead).
        affine : Transform
            Affine transformation to apply.
        rgbFace : color
            Face color for the glyph.

        Returns
        -------
        Any
            Result from renderer.draw_path.
        """
        # Work on a copy so the caller's graphics context is left untouched.
        gc0 = renderer.new_gc()
        gc0.copy_properties(gc)
        gc0 = self._update_gc(gc0, self._gc)
        try:
            return renderer.draw_path(
                gc0, self.path, affine + self._offset_transform(renderer), rgbFace
            )
        finally:
            gc0.restore()

Path effect for creating sequence logo glyphs with normalized dimensions.

This class creates properly scaled and positioned text glyphs for sequence logos by normalizing character dimensions and applying appropriate transforms.

Parameters
  • glyph (str): Single character to render (e.g., 'A', 'C', 'G', 'T').
  • ref_glyph (str, default 'E'): Reference character used for width normalization.
  • font_props (FontProperties, optional): Font properties for the glyph rendering.
  • offset (Tuple[float, float], default (0., 0.)): Offset for glyph positioning.
  • **kwargs: Additional graphics collection parameters.
LogoGlyph( glyph: str, ref_glyph: str = 'E', font_props: Optional[matplotlib.font_manager.FontProperties] = None, offset: Tuple[float, float] = (0.0, 0.0), **kwargs)
    def __init__(
        self,
        glyph: str,
        ref_glyph: str = "E",
        font_props: Optional[FontProperties] = None,
        offset: Tuple[float, float] = (0.0, 0.0),
        **kwargs,
    ) -> None:
        """Build the normalized outline for ``glyph``.

        The outline is translated so its bounding box starts at the origin and
        scaled to unit height. Its width is scaled relative to the wider of
        ``glyph`` and ``ref_glyph``, and a narrower glyph is centered
        horizontally.
        """
        super().__init__(offset)

        # Glyph outline at unit font size, plus the reference glyph's extents.
        path_orig = TextPath((0, 0), glyph, size=1, prop=font_props)
        dims = path_orig.get_extents()
        ref_dims = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents()

        # Unit height. Width is scaled by the wider of glyph/reference, so a
        # narrower glyph ends up with width dims.width / ref_width (< 1).
        h_scale = 1 / dims.height
        ref_width = max(dims.width, ref_dims.width)
        w_scale = 1 / ref_width
        # Post-scaling shift that centers a narrower glyph within width 1.
        w_shift = (1 - dims.width / ref_width) / 2
        # Move the bounding box's lower-left corner to the origin before scaling.
        x_shift = -dims.x0
        y_shift = -dims.y0
        stretch = (
            Affine2D()
            .translate(tx=x_shift, ty=y_shift)
            .scale(sx=w_scale, sy=h_scale)
            .translate(tx=w_shift, ty=0)
        )

        self.path = stretch.transform_path(path_orig)

        #: The dictionary of keywords to update the graphics collection with.
        self._gc = kwargs
Parameters
  • offset ((float, float), default (0, 0)): The (x, y) offset to apply to the path, measured in points.
path
def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any:
395    def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any:  # type: ignore[override]
396        """Draw the glyph path using the renderer.
397
398        Parameters
399        ----------
400        renderer : matplotlib renderer
401            The renderer to draw with.
402        gc : GraphicsContext
403            Graphics context for drawing properties.
404        tpath : Path
405            Original text path (unused, using self.path instead).
406        affine : Transform
407            Affine transformation to apply.
408        rgbFace : color
409            Face color for the glyph.
410
411        Returns
412        -------
413        Any
414            Result from renderer.draw_path.
415        """
416        return renderer.draw_path(gc, self.path, affine, rgbFace)

Draw the glyph path using the renderer.

Parameters
  • renderer (matplotlib renderer): The renderer to draw with.
  • gc (GraphicsContext): Graphics context for drawing properties.
  • tpath (Path): Original text path (unused, using self.path instead).
  • affine (Transform): Affine transformation to apply.
  • rgbFace (color): Face color for the glyph.
Returns
  • Any: Result from renderer.draw_path.
LOGO_ALPHABET = 'ACGT'
LOGO_COLORS = {'A': '#109648', 'C': '#255C99', 'G': '#F7B32B', 'T': '#D62839'}
LOGO_FONT = <matplotlib.font_manager.FontProperties object>
def plot_cwms( cwms: Dict[str, Dict[str, jaxtyping.Float[ndarray, '4 W']]], trim_bounds: Dict[str, Dict[str, Tuple[int, int]]], out_dir: str, alphabet: str = 'ACGT', colors: Dict[str, str] = {'A': '#109648', 'C': '#255C99', 'G': '#F7B32B', 'T': '#D62839'}, font: matplotlib.font_manager.FontProperties = <matplotlib.font_manager.FontProperties object>) -> None:
def plot_cwms(
    cwms: Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    trim_bounds: Dict[str, Dict[str, Tuple[int, int]]],
    out_dir: str,
    alphabet: str = LOGO_ALPHABET,
    colors: Dict[str, str] = LOGO_COLORS,
    font: FontProperties = LOGO_FONT,
) -> None:
    """Render every CWM as a sequence logo and save it to disk.

    For each motif a subdirectory ``out_dir/<motif_name>/`` is created. Each
    CWM type of that motif is drawn with ``plot_logo`` on a 10x2-inch figure,
    with the region given by ``trim_bounds`` shaded. The figure is saved as
    ``<cwm_type>.png`` (dpi=100) and ``<cwm_type>.svg``.

    Parameters
    ----------
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Mapping ``{motif_name: {cwm_type: cwm}}``. Each ``cwm`` has shape
        (4, W), with rows ordered as in ``alphabet``.
    trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        Mapping ``{motif_name: {cwm_type: (start, end)}}`` with the region to
        shade. An entry is looked up for every plotted CWM; a missing entry
        raises ``KeyError``.
    out_dir : str
        Root directory for the per-motif subdirectories.
    alphabet : str, default LOGO_ALPHABET
        Base alphabet, typically 'ACGT'.
    colors : Dict[str, str], default LOGO_COLORS
        Base-to-color mapping; keys should match ``alphabet``.
    font : FontProperties, default LOGO_FONT
        Font used to draw the logo glyphs.

    Notes
    -----
    All plot spines are hidden so only the logo is visible.
    """
    for motif_name, cwms_by_type in cwms.items():
        motif_dir = os.path.join(out_dir, motif_name)
        os.makedirs(motif_dir, exist_ok=True)

        for cwm_type, cwm in cwms_by_type.items():
            fig, ax = plt.subplots(figsize=(10, 2))

            plot_logo(
                ax,
                cwm,
                alphabet,
                colors=colors,
                font_props=font,
                shade_bounds=trim_bounds[motif_name][cwm_type],
            )

            # Hide the frame around the logo.
            for spine in ax.spines.values():
                spine.set_visible(False)

            plt.savefig(os.path.join(motif_dir, f"{cwm_type}.png"), dpi=100)
            plt.savefig(os.path.join(motif_dir, f"{cwm_type}.svg"))

            plt.close(fig)

Plot contribution weight matrices as sequence logos.

Creates sequence logo plots for all motifs and CWM types, with optional shading to highlight trimmed regions. Saves plots in both PNG and SVG formats.

Parameters
  • cwms (Dict[str, Dict[str, Float[ndarray, "4 W"]]]): Nested dictionary structure: {motif_name: {cwm_type: cwm_array}}. Each cwm_array has shape (4, W) where W is motif width. Rows correspond to bases in alphabet order.
  • trim_bounds (Dict[str, Dict[str, Tuple[int, int]]]): Nested dictionary: {motif_name: {cwm_type: (start, end)}}. Defines regions to shade in the sequence logos.
  • out_dir (str): Output directory where motif subdirectories will be created.
  • alphabet (str, default LOGO_ALPHABET): DNA alphabet string, typically 'ACGT'.
  • colors (Dict[str, str], default LOGO_COLORS): Color mapping for DNA bases. Keys should match alphabet characters.
  • font (FontProperties, default LOGO_FONT): Font properties for sequence logo rendering.
Notes

Directory structure created:

out_dir/
├── motif1/
│   ├── cwm_type1.png
│   ├── cwm_type1.svg
│   └── ...
└── motif2/
    └── ...

Each plot is 10x2 inches with trimmed regions shaded if specified. Spines (plot borders) are hidden for cleaner appearance.

def plot_hit_vs_seqlet_counts( recall_data: Dict[str, Dict[str, Union[int, float]]], output_dir: str) -> None:
def plot_hit_vs_seqlet_counts(
    recall_data: Dict[str, Dict[str, Union[int, float]]], output_dir: str
) -> None:
    """Plot scatter plot comparing hit counts to seqlet counts per motif.

    Creates a log-log scatter plot showing the relationship between the number
    of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISco
    for each motif. Includes diagonal reference line and motif annotations.

    Parameters
    ----------
    recall_data : Dict[str, Dict[str, Union[int, float]]]
        Dictionary with motif names as keys and metrics dictionaries as values.
        Each metrics dictionary must contain:
        - 'num_hits_total' : int, total number of hits for the motif
        - 'num_seqlets' : int, total number of seqlets for the motif
    output_dir : str
        Directory path where the scatter plot will be saved.

    Notes
    -----
    Saves plots as:
    - hit_vs_seqlet_counts.png : High-resolution raster format
    - hit_vs_seqlet_counts.svg : Vector format

    Plot features:
    - Log-log scale on both axes
    - Diagonal reference line (y = x) as dashed line
    - Points annotated with abbreviated motif names

    Motifs with a zero count cannot be shown on log axes and will not appear.
    """
    hit_counts = [v["num_hits_total"] for v in recall_data.values()]
    seqlet_counts = [v["num_seqlets"] for v in recall_data.values()]
    labels = [abbreviate_motif_name(k) for k in recall_data]

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")
    # y = x reference line. Anchor it at (1, 1) and (10, 10): both points are
    # valid on log axes (the origin is not). axline extends the line across
    # the whole view, so no data-dependent endpoint is needed, which also
    # avoids failing on an empty recall_data.
    ax.axline((1, 1), (10, 10), color="0.3", linewidth=0.7, linestyle=(0, (5, 5)))
    ax.scatter(hit_counts, seqlet_counts, s=5)
    for label, hits, seqlets in zip(labels, hit_counts, seqlet_counts):
        ax.annotate(label, (hits, seqlets), fontsize=8, weight="bold")

    ax.set_yscale("log")
    ax.set_xscale("log")

    ax.set_xlabel("Hits per motif")
    ax.set_ylabel("Seqlets per motif")

    output_path_png = os.path.join(output_dir, "hit_vs_seqlet_counts.png")
    fig.savefig(output_path_png, dpi=300)
    output_path_svg = os.path.join(output_dir, "hit_vs_seqlet_counts.svg")
    fig.savefig(output_path_svg)

    plt.close(fig)

Plot scatter plot comparing hit counts to seqlet counts per motif.

Creates a log-log scatter plot showing the relationship between the number of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISco for each motif. Includes diagonal reference line and motif annotations.

Parameters
  • recall_data (Dict[str, Dict[str, Union[int, float]]]): Dictionary with motif names as keys and metrics dictionaries as values. Each metrics dictionary must contain:
    • 'num_hits_total' : int, total number of hits for the motif
    • 'num_seqlets' : int, total number of seqlets for the motif
  • output_dir (str): Directory path where the scatter plot will be saved.
Notes

Saves plots as:

  • hit_vs_seqlet_counts.png : High-resolution raster format
  • hit_vs_seqlet_counts.svg : Vector format

Plot features:

  • Log-log scale on both axes
  • Diagonal reference line (y = x) as dashed line
  • Points annotated with abbreviated motif names
def write_report( report_df: polars.dataframe.frame.DataFrame, motif_names: List[str], out_path: str, compute_recall: bool, use_seqlets: bool) -> None:
def write_report(
    report_df: pl.DataFrame,
    motif_names: List[str],
    out_path: str,
    compute_recall: bool,
    use_seqlets: bool,
) -> None:
    """Generate and write HTML report from motif analysis results.

    Creates a comprehensive HTML report with tables and visualizations
    summarizing the Fi-NeMo motif discovery and hit calling results.

    Parameters
    ----------
    report_df : pl.DataFrame
        DataFrame containing motif statistics and performance metrics.
        Expected columns depend on compute_recall and use_seqlets flags.
    motif_names : List[str]
        List of motif names to include in the report.
        Order determines presentation sequence in the report.
    out_path : str
        File path where the HTML report will be written (UTF-8 encoded).
        Parent directory must exist.
    compute_recall : bool
        Whether recall metrics were computed and should be included
        in the report template.
    use_seqlets : bool
        Whether TF-MoDISco seqlet data was used in the analysis
        and should be referenced in the report.

    Notes
    -----
    Uses Jinja2 templating with the report.html template from the
    templates package. The template receives:
    - report_data: List of DataFrame rows, each a dict mapping column name
      to value
    - motif_names: List of motif names
    - compute_recall: Boolean flag for recall metrics
    - use_seqlets: Boolean flag for seqlet usage

    Raises
    ------
    OSError
        If the output path cannot be written.
    """
    template_str = (
        importlib.resources.files(templates).joinpath("report.html").read_text()
    )
    template = Template(template_str)
    report = template.render(
        # Materialize the rows. A bare iterator would be exhausted after the
        # first loop if the template iterates report_data more than once.
        report_data=list(report_df.iter_rows(named=True)),
        motif_names=motif_names,
        compute_recall=compute_recall,
        use_seqlets=use_seqlets,
    )
    # Explicit encoding. The platform default (e.g. cp1252 on Windows) may not
    # be able to represent every character in the rendered HTML.
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(report)

Generate and write HTML report from motif analysis results.

Creates a comprehensive HTML report with tables and visualizations summarizing the Fi-NeMo motif discovery and hit calling results.

Parameters
  • report_df (pl.DataFrame): DataFrame containing motif statistics and performance metrics. Expected columns depend on compute_recall and use_seqlets flags.
  • motif_names (List[str]): List of motif names to include in the report. Order determines presentation sequence in the report.
  • out_path (str): File path where the HTML report will be written. Parent directory must exist.
  • compute_recall (bool): Whether recall metrics were computed and should be included in the report template.
  • use_seqlets (bool): Whether TF-MoDISco seqlet data was used in the analysis and should be referenced in the report.
Notes

Uses Jinja2 templating with the report.html template from the templates package. The template receives:

  • report_data: Iterable of DataFrame rows, each a dict mapping column name to value
  • motif_names: List of motif names
  • compute_recall: Boolean flag for recall metrics
  • use_seqlets: Boolean flag for seqlet usage
Raises
  • OSError: If the output path cannot be written.