finemo.visualization
Visualization module for generating plots and reports for Fi-NeMo results.
This module provides functions for:
- Plotting motif contribution weight matrices (CWMs) as sequence logos
- Generating distribution plots for hit statistics
- Creating co-occurrence heatmaps
- Producing HTML reports with interactive visualizations
- Plotting confusion matrices and performance metrics
"""Visualization module for generating plots and reports for Fi-NeMo results.

This module provides functions for:
- Plotting motif contribution weight matrices (CWMs) as sequence logos
- Generating distribution plots for hit statistics
- Creating co-occurrence heatmaps
- Producing HTML reports with interactive visualizations
- Plotting confusion matrices and performance metrics
"""

import os
import importlib.resources
from typing import List, Optional, Dict, Any, Tuple, Union, Mapping, Iterable

import numpy as np
from numpy import ndarray
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.patheffects import AbstractPathEffect
from matplotlib.textpath import TextPath
from matplotlib.transforms import Affine2D
from matplotlib.font_manager import FontProperties
from jinja2 import Template
import polars as pl
from jaxtyping import Float, Int

from . import templates


def abbreviate_motif_name(name: str) -> str:
    """Convert TF-MoDISco motif names to abbreviated format.

    Converts full TF-MoDISco pattern names to shorter, more readable format
    for display in plots and reports.

    Parameters
    ----------
    name : str
        Full motif name (e.g., 'pos_patterns.pattern_0').

    Returns
    -------
    str
        Abbreviated name (e.g., '+/0') or original name if parsing fails.

    Examples
    --------
    >>> abbreviate_motif_name('pos_patterns.pattern_0')
    '+/0'
    >>> abbreviate_motif_name('neg_patterns.pattern_1')
    '-/1'
    >>> abbreviate_motif_name('invalid_name')
    'invalid_name'
    """
    group_symbols = {"pos_patterns": "+", "neg_patterns": "-"}

    # Anything that is not a string cannot be parsed; hand it back untouched.
    if not isinstance(name, str):
        return name

    # A parsable name has exactly one dot: '<group>.<motif>'.
    parts = name.split(".")
    if len(parts) != 2:
        return name
    group, motif = parts

    symbol = group_symbols.get(group)
    if symbol is None:
        return name

    # The motif number is the second underscore-separated field.
    fields = motif.split("_")
    if len(fields) < 2:
        return name
    return f"{symbol}/{fields[1]}"


def plot_hit_stat_distributions(
    hits_df: pl.LazyFrame, motif_names: List[str], plot_dir: str
) -> None:
    """Plot distributions of hit statistics for each motif.

    Creates separate histogram plots for coefficient, similarity, and importance
    score distributions for each motif. Saves plots in both PNG (high-res) and
    SVG (vector) formats.

    Parameters
    ----------
    hits_df : pl.LazyFrame
        Lazy DataFrame containing hit data with required columns:
        - motif_name : str, name of the motif
        - hit_coefficient_global : float, global coefficient values
        - hit_similarity : float, similarity scores to motif CWM
        - hit_importance : float, importance scores from attribution
    motif_names : List[str]
        List of motif names to generate plots for. Motifs not present
        in hits_df will result in empty histograms.
    plot_dir : str
        Directory path where plots will be saved. Creates subdirectory
        'motif_stat_distributions' if it doesn't exist.

    Notes
    -----
    For each motif, creates three separate plots:
    - {motif_name}_coefficients.{png,svg} : coefficient distribution
    - {motif_name}_similarities.{png,svg} : similarity distribution
    - {motif_name}_importances.{png,svg} : importance distribution
    """
    hits_df_collected = hits_df.collect()
    hits_by_motif = hits_df_collected.partition_by("motif_name", as_dict=True)
    # Zero-row frame with the same schema, used for motifs without any hits.
    dummy_df = hits_df_collected.clear()

    motifs_dir = os.path.join(plot_dir, "motif_stat_distributions")
    os.makedirs(motifs_dir, exist_ok=True)

    # (source column, output file suffix), in the order the files are written.
    stat_columns = (
        ("hit_coefficient_global", "coefficients"),
        ("hit_similarity", "similarities"),
        ("hit_importance", "importances"),
    )

    for m in motif_names:
        hits = hits_by_motif.get((m,), dummy_df)
        # Extract every column up front so a missing column fails before any
        # file for this motif is written.
        stats = [
            (hits.get_column(column).to_numpy(), suffix)
            for column, suffix in stat_columns
        ]

        for values, suffix in stats:
            fig, ax = plt.subplots(figsize=(5, 2))

            # Normalizing an empty sample divides by a zero total, so only
            # request a density when there is data to normalize.
            density = values.size > 0
            try:
                ax.hist(values, bins=50, density=density)
            except ValueError:
                # Fall back to a single bin when 50 bins cannot be built.
                ax.hist(values, bins=1, density=density)

            fig.savefig(os.path.join(motifs_dir, f"{m}_{suffix}.png"), dpi=300)
            fig.savefig(os.path.join(motifs_dir, f"{m}_{suffix}.svg"))
            plt.close(fig)


def _plot_count_frequencies(ax, hit_counts) -> None:
    """Draw, on ``ax``, the fraction of rows observed at each integer hit count."""
    unique, counts = np.unique(hit_counts, return_counts=True)
    freq = counts / counts.sum()
    # One bar per integer from 0 to the largest observed count.
    num_bins = np.amax(unique, initial=0) + 1
    y = np.zeros(num_bins)
    y[unique] = freq
    ax.bar(np.arange(num_bins), y)


def plot_hit_peak_distributions(
    occ_df: pl.DataFrame, motif_names: List[str], plot_dir: str
) -> None:
    """Plot distribution of hits per peak for each motif.

    Creates bar plots showing the frequency distribution of hit counts per peak
    for each motif, plus an overall distribution of total hits per peak.

    Parameters
    ----------
    occ_df : pl.DataFrame
        DataFrame containing motif occurrence counts per peak. Expected to have:
        - One column per motif name with integer hit counts
        - 'total' column with sum of all motif hits per peak
        - Each row represents a peak/genomic region
    motif_names : List[str]
        List of motif names corresponding to columns in occ_df.
    plot_dir : str
        Directory to save plots. Creates 'motif_hit_distributions' subdirectory.

    Notes
    -----
    Generates the following plots:
    - Individual motif hit distributions: {motif_name}.{png,svg}
    - Overall hit distribution: total_hit_distribution.{png,svg}

    Bar plots show frequency (proportion) on y-axis and hit count on x-axis.
    """
    motifs_dir = os.path.join(plot_dir, "motif_hit_distributions")
    os.makedirs(motifs_dir, exist_ok=True)

    for m in motif_names:
        fig, ax = plt.subplots(figsize=(5, 2))
        _plot_count_frequencies(ax, occ_df.get_column(m))

        fig.savefig(os.path.join(motifs_dir, f"{m}.png"), dpi=300)
        fig.savefig(os.path.join(motifs_dir, f"{m}.svg"))
        plt.close(fig)

    # The overall distribution is written to plot_dir itself.
    fig, ax = plt.subplots(figsize=(8, 4))
    _plot_count_frequencies(ax, occ_df.get_column("total"))

    ax.set_xlabel("Total hits per region")
    ax.set_ylabel("Frequency")

    fig.savefig(os.path.join(plot_dir, "total_hit_distribution.png"), dpi=300)
    fig.savefig(os.path.join(plot_dir, "total_hit_distribution.svg"), dpi=300)
    plt.close(fig)


def plot_peak_motif_indicator_heatmap(
    peak_hit_counts: Int[ndarray, "M M"], motif_names: List[str], output_dir: str
) -> None:
    """Plot co-occurrence heatmap showing motif associations across peaks.

    Creates a normalized correlation heatmap showing how frequently pairs of
    motifs co-occur within the same genomic peaks. Values are normalized by
    the geometric mean of individual motif frequencies.

    Parameters
    ----------
    peak_hit_counts : Int[ndarray, "M M"]
        Co-occurrence matrix where M = len(motif_names).
        Entry (i,j) represents the number of peaks containing both motif i and j.
        Diagonal entries represent total peaks containing each individual motif.
    motif_names : List[str]
        List of motif names for axis labels. Order must match matrix dimensions.
    output_dir : str
        Directory path where the heatmap plots will be saved.

    Notes
    -----
    Saves plots as:
    - motif_cooocurrence.png : High-resolution raster format
    - motif_cooocurrence.svg : Vector format

    The heatmap uses correlation normalization: matrix[i,j] / sqrt(matrix[i,i] * matrix[j,j])
    Colors use the 'Greens' colormap with values typically in [0, 1] range.
    A zero diagonal entry makes the normalization for that motif undefined;
    the affected cells are left as non-finite values.
    """
    # A zero diagonal entry makes the normalization undefined. The resulting
    # non-finite cells are kept as-is; only the floating-point warnings that
    # the division would otherwise emit are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        cov_norm = 1 / np.sqrt(np.diag(peak_hit_counts))
        matrix = peak_hit_counts * cov_norm[:, None] * cov_norm[None, :]
    motif_keys = [abbreviate_motif_name(m) for m in motif_names]

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")

    # Plot the heatmap
    cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens")

    # Set axes on heatmap
    ax.set_yticks(np.arange(len(motif_keys)))
    ax.set_yticklabels(motif_keys)
    ax.set_xticks(np.arange(len(motif_keys)))
    ax.set_xticklabels(motif_keys, rotation=90)
    ax.set_xlabel("Motif i")
    ax.set_ylabel("Motif j")

    ax.tick_params(axis="both", labelsize=8)

    cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30)
    cbar.ax.tick_params(labelsize=8)

    # File names are kept exactly as before so existing consumers still find them.
    fig.savefig(os.path.join(output_dir, "motif_cooocurrence.png"), dpi=300)
    fig.savefig(os.path.join(output_dir, "motif_cooocurrence.svg"), dpi=300)

    plt.close(fig)


def plot_seqlet_confusion_heatmap(
    seqlet_confusion: Int[ndarray, "M M"], motif_names: List[str], output_dir: str
) -> None:
    """Plot confusion matrix heatmap comparing seqlets to hit calls.

    Creates a heatmap showing the overlap between TF-MoDISco seqlets and
    Fi-NeMo hit calls. Rows represent seqlet motifs, columns represent hit motifs.

    Parameters
    ----------
    seqlet_confusion : Int[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) represents the number of seqlets of motif i that overlap
        with hits called for motif j.
    motif_names : List[str]
        List of motif names for axis labels. Order must match matrix dimensions.
    output_dir : str
        Directory path where the confusion matrix plots will be saved.

    Notes
    -----
    Saves plots as:
    - seqlet_confusion.png : High-resolution raster format
    - seqlet_confusion.svg : Vector format

    The heatmap uses 'Blues' colormap. Perfect agreement would show a diagonal
    pattern with high values along the diagonal and low off-diagonal values.
    """
    motif_keys = [abbreviate_motif_name(m) for m in motif_names]

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")

    # Plot the heatmap
    cax = ax.imshow(
        seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues"
    )

    # Set axes on heatmap
    ax.set_yticks(np.arange(len(motif_keys)))
    ax.set_yticklabels(motif_keys)
    ax.set_xticks(np.arange(len(motif_keys)))
    ax.set_xticklabels(motif_keys, rotation=90)
    ax.set_xlabel("Hit motif")
    ax.set_ylabel("Seqlet motif")

    ax.tick_params(axis="both", labelsize=8)

    cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30)
    cbar.ax.tick_params(labelsize=8)

    fig.savefig(os.path.join(output_dir, "seqlet_confusion.png"), dpi=300)
    fig.savefig(os.path.join(output_dir, "seqlet_confusion.svg"), dpi=300)

    plt.close(fig)


class LogoGlyph(AbstractPathEffect):
    """Path effect for creating sequence logo glyphs with normalized dimensions.

    This class creates properly scaled and positioned text glyphs for sequence
    logos by normalizing character dimensions and applying appropriate transforms.

    Parameters
    ----------
    glyph : str
        Single character to render (e.g., 'A', 'C', 'G', 'T').
    ref_glyph : str, default 'E'
        Reference character used for width normalization.
    font_props : FontProperties, optional
        Font properties for the glyph rendering.
    offset : Tuple[float, float], default (0., 0.)
        Offset for glyph positioning.
    **kwargs
        Additional graphics collection parameters.
    """

    def __init__(
        self,
        glyph: str,
        ref_glyph: str = "E",
        font_props: Optional[FontProperties] = None,
        offset: Tuple[float, float] = (0.0, 0.0),
        **kwargs,
    ) -> None:
        super().__init__(offset)

        path_orig = TextPath((0, 0), glyph, size=1, prop=font_props)
        dims = path_orig.get_extents()
        ref_dims = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents()

        # Stretch the glyph to unit height; scale the width against the wider
        # of the glyph and the reference glyph, then centre it horizontally.
        h_scale = 1 / dims.height
        ref_width = max(dims.width, ref_dims.width)
        w_scale = 1 / ref_width
        w_shift = (1 - dims.width / ref_width) / 2
        x_shift = -dims.x0
        y_shift = -dims.y0
        stretch = (
            Affine2D()
            .translate(tx=x_shift, ty=y_shift)
            .scale(sx=w_scale, sy=h_scale)
            .translate(tx=w_shift, ty=0)
        )

        self.path = stretch.transform_path(path_orig)

        #: The dictionary of keywords to update the graphics collection with.
        self._gc = kwargs

    def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any:  # type: ignore[override]
        """Draw the glyph path using the renderer.

        Parameters
        ----------
        renderer : matplotlib renderer
            The renderer to draw with.
        gc : GraphicsContext
            Graphics context for drawing properties.
        tpath : Path
            Original text path (unused, using self.path instead).
        affine : Transform
            Affine transformation to apply.
        rgbFace : color
            Face color for the glyph.

        Returns
        -------
        Any
            Result from renderer.draw_path.
        """
        return renderer.draw_path(gc, self.path, affine, rgbFace)


def plot_logo(
    ax: Axes,
    heights: Float[ndarray, "B W"],
    glyphs: Iterable[str],
    colors: Optional[Mapping[str, Optional[str]]] = None,
    font_props: Optional[FontProperties] = None,
    shade_bounds: Optional[Tuple[int, int]] = None,
) -> None:
    """Plot sequence logo from contribution weight matrix.

    Creates a sequence logo visualization where letter heights represent
    the contribution or information content at each position. Supports
    both positive and negative contributions with proper stacking.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes object to plot on.
    heights : Float[ndarray, "B W"]
        Height matrix where B = len(glyphs) and W = motif width.
        Entry (i,j) represents the height/contribution of base i at position j.
        Can contain both positive and negative values.
    glyphs : Iterable[str]
        Sequence of base characters corresponding to rows in heights matrix.
        Typically ['A', 'C', 'G', 'T'] for DNA. One-shot iterables are accepted.
    colors : Mapping[str, Optional[str]], optional
        Color mapping for each base. Keys should match glyphs.
        If None, all bases will use default matplotlib colors.
    font_props : FontProperties, optional
        Font properties for letter rendering. If None, uses default font.
    shade_bounds : Tuple[int, int], optional
        (start, end) position indices to shade in background.
        Useful for highlighting core motif regions.

    Notes
    -----
    Positive and negative contributions are handled separately:
    - Positive values are stacked above the zero line, smallest nearest the line
    - Negative values are stacked below the zero line, smallest magnitude nearest the line
    - A horizontal line is drawn at y=0 for reference

    The resulting plot has:
    - X-axis: Position in motif (0-indexed)
    - Y-axis: Contribution magnitude
    - Bar width: 0.95 (small gaps between positions)
    """
    # glyphs is walked twice below (default colors, then plotting); materialize
    # it so a one-shot iterator is not exhausted by the first pass.
    glyphs = list(glyphs)

    if colors is None:
        colors = {g: None for g in glyphs}

    ax.margins(x=0, y=0)

    pos_values = np.clip(heights, 0, None)
    neg_values = np.clip(heights, None, 0)
    # Per-column stacking order, and the inverse permutation that maps the
    # cumulative sums back to the original row order.
    pos_order = np.argsort(pos_values, axis=0)
    neg_order = np.argsort(neg_values, axis=0)[::-1, :]
    pos_reorder = np.argsort(pos_order, axis=0)
    neg_reorder = np.argsort(neg_order, axis=0)
    pos_offsets = np.take_along_axis(
        np.cumsum(np.take_along_axis(pos_values, pos_order, axis=0), axis=0),
        pos_reorder,
        axis=0,
    )
    neg_offsets = np.take_along_axis(
        np.cumsum(np.take_along_axis(neg_values, neg_order, axis=0), axis=0),
        neg_reorder,
        axis=0,
    )
    # The offsets are the far edge of each bar; subtracting the bar's own
    # height gives the edge it starts from.
    bottoms = pos_offsets + neg_offsets - heights

    x = np.arange(heights.shape[1])

    for glyph, height, bottom in zip(glyphs, heights, bottoms):
        ax.bar(
            x,
            height,
            0.95,
            bottom=bottom,
            path_effects=[LogoGlyph(glyph, font_props=font_props)],
            color=colors[glyph],
        )

    if shade_bounds is not None:
        start, end = shade_bounds
        # Bars are centred on integer positions, hence the half-unit shift.
        ax.axvspan(start - 0.5, end - 0.5, color="0.9", zorder=-1)

    ax.axhline(zorder=-1, linewidth=0.5, color="black")


LOGO_ALPHABET = "ACGT"
LOGO_COLORS = {"A": "#109648", "C": "#255C99", "G": "#F7B32B", "T": "#D62839"}
LOGO_FONT = FontProperties(weight="bold")


def plot_cwms(
    cwms: Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    trim_bounds: Dict[str, Dict[str, Tuple[int, int]]],
    out_dir: str,
    alphabet: str = LOGO_ALPHABET,
    colors: Dict[str, str] = LOGO_COLORS,
    font: FontProperties = LOGO_FONT,
) -> None:
    """Plot contribution weight matrices as sequence logos.

    Creates sequence logo plots for all motifs and CWM types, with
    shading to highlight trimmed regions. Saves plots in both PNG and SVG formats.

    Parameters
    ----------
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Nested dictionary structure: {motif_name: {cwm_type: cwm_array}}.
        Each cwm_array has shape (4, W) where W is motif width.
        Rows correspond to bases in alphabet order.
    trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        Nested dictionary: {motif_name: {cwm_type: (start, end)}}.
        Defines regions to shade in the sequence logos. Must contain an
        entry for every (motif_name, cwm_type) pair present in cwms.
    out_dir : str
        Output directory where motif subdirectories will be created.
    alphabet : str, default LOGO_ALPHABET
        DNA alphabet string, typically 'ACGT'.
    colors : Dict[str, str], default LOGO_COLORS
        Color mapping for DNA bases. Keys should match alphabet characters.
    font : FontProperties, default LOGO_FONT
        Font properties for sequence logo rendering.

    Notes
    -----
    Directory structure created:
    ```
    out_dir/
    ├── motif1/
    │   ├── cwm_type1.png
    │   ├── cwm_type1.svg
    │   └── ...
    └── motif2/
        └── ...
    ```

    Each plot is 10x2 inches with trimmed regions shaded.
    Spines (plot borders) are hidden for cleaner appearance.
    """
    for m, v in cwms.items():
        motif_dir = os.path.join(out_dir, m)
        os.makedirs(motif_dir, exist_ok=True)
        for cwm_type, cwm in v.items():
            fig, ax = plt.subplots(figsize=(10, 2))

            plot_logo(
                ax,
                cwm,
                alphabet,
                colors=colors,
                font_props=font,
                shade_bounds=trim_bounds[m][cwm_type],
            )

            for spine in ax.spines.values():
                spine.set_visible(False)

            fig.savefig(os.path.join(motif_dir, f"{cwm_type}.png"), dpi=100)
            fig.savefig(os.path.join(motif_dir, f"{cwm_type}.svg"))

            plt.close(fig)


def plot_hit_vs_seqlet_counts(
    recall_data: Dict[str, Dict[str, Union[int, float]]], output_dir: str
) -> None:
    """Plot scatter plot comparing hit counts to seqlet counts per motif.

    Creates a log-log scatter plot showing the relationship between the number
    of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISco
    for each motif. Includes diagonal reference line and motif annotations.

    Parameters
    ----------
    recall_data : Dict[str, Dict[str, Union[int, float]]]
        Dictionary with motif names as keys and metrics dictionaries as values.
        Each metrics dictionary must contain:
        - 'num_hits_total' : int, total number of hits for the motif
        - 'num_seqlets' : int, total number of seqlets for the motif
    output_dir : str
        Directory path where the scatter plot will be saved.

    Notes
    -----
    Saves plots as:
    - hit_vs_seqlet_counts.png : High-resolution raster format
    - hit_vs_seqlet_counts.svg : Vector format

    Plot features:
    - Log-log scale on both axes
    - Diagonal reference line (y = x) as dashed line
    - Points annotated with abbreviated motif names
    """
    motifs = list(recall_data)
    x = [recall_data[k]["num_hits_total"] for k in motifs]
    y = [recall_data[k]["num_seqlets"] for k in motifs]

    # initial=1 keeps the reduction defined when recall_data is empty and keeps
    # the second reference point away from the origin when every count is zero.
    lim = max(np.amax(x, initial=1), np.amax(y, initial=1))

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")
    ax.axline((0, 0), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5)))
    ax.scatter(x, y, s=5)
    for name, x_val, y_val in zip(motifs, x, y):
        ax.annotate(
            abbreviate_motif_name(name), (x_val, y_val), fontsize=8, weight="bold"
        )

    ax.set_yscale("log")
    ax.set_xscale("log")

    ax.set_xlabel("Hits per motif")
    ax.set_ylabel("Seqlets per motif")

    fig.savefig(os.path.join(output_dir, "hit_vs_seqlet_counts.png"), dpi=300)
    fig.savefig(os.path.join(output_dir, "hit_vs_seqlet_counts.svg"))

    plt.close(fig)


def write_report(
    report_df: pl.DataFrame,
    motif_names: List[str],
    out_path: str,
    compute_recall: bool,
    use_seqlets: bool,
) -> None:
    """Generate and write HTML report from motif analysis results.

    Renders the packaged ``report.html`` Jinja2 template with the supplied
    data and writes the result to ``out_path``.

    Parameters
    ----------
    report_df : pl.DataFrame
        DataFrame containing motif statistics and performance metrics.
        Expected columns depend on compute_recall and use_seqlets flags.
    motif_names : List[str]
        List of motif names passed to the report template.
    out_path : str
        File path where the HTML report will be written.
        Parent directory must exist.
    compute_recall : bool
        Whether recall metrics were computed and should be included
        in the report template.
    use_seqlets : bool
        Whether TF-MoDISco seqlet data was used in the analysis
        and should be referenced in the report.

    Notes
    -----
    Uses Jinja2 templating with the report.html template from the
    templates package. The template receives:
    - report_data: Iterator of DataFrame rows as dictionaries
    - motif_names: List of motif names
    - compute_recall: Boolean flag for recall metrics
    - use_seqlets: Boolean flag for seqlet usage

    The template is read and the report is written as UTF-8, independent of
    the platform's locale encoding.

    Raises
    ------
    OSError
        If the output path cannot be written.
    """
    template_str = (
        importlib.resources.files(templates)
        .joinpath("report.html")
        .read_text(encoding="utf-8")
    )
    template = Template(template_str)
    report = template.render(
        report_data=report_df.iter_rows(named=True),
        motif_names=motif_names,
        compute_recall=compute_recall,
        use_seqlets=use_seqlets,
    )
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(report)
def abbreviate_motif_name(name: str) -> str:
    """Convert TF-MoDISco motif names to abbreviated format.

    Converts full TF-MoDISco pattern names to shorter, more readable format
    for display in plots and reports.

    Parameters
    ----------
    name : str
        Full motif name (e.g., 'pos_patterns.pattern_0').

    Returns
    -------
    str
        Abbreviated name (e.g., '+/0') or original name if parsing fails.

    Examples
    --------
    >>> abbreviate_motif_name('pos_patterns.pattern_0')
    '+/0'
    >>> abbreviate_motif_name('neg_patterns.pattern_1')
    '-/1'
    >>> abbreviate_motif_name('invalid_name')
    'invalid_name'
    """
    group_symbols = {"pos_patterns": "+", "neg_patterns": "-"}

    # Anything that is not a string cannot be parsed; hand it back untouched.
    if not isinstance(name, str):
        return name

    # A parsable name has exactly one dot: '<group>.<motif>'.
    parts = name.split(".")
    if len(parts) != 2:
        return name
    group, motif = parts

    symbol = group_symbols.get(group)
    if symbol is None:
        return name

    # The motif number is the second underscore-separated field.
    fields = motif.split("_")
    if len(fields) < 2:
        return name
    return f"{symbol}/{fields[1]}"
Convert TF-MoDISco motif names to abbreviated format.
Converts full TF-MoDISco pattern names to shorter, more readable format for display in plots and reports.
Parameters
- name (str): Full motif name (e.g., 'pos_patterns.pattern_0').
Returns
- str: Abbreviated name (e.g., '+/0') or original name if parsing fails.
Examples
>>> abbreviate_motif_name('pos_patterns.pattern_0')
'+/0'
>>> abbreviate_motif_name('neg_patterns.pattern_1')
'-/1'
>>> abbreviate_motif_name('invalid_name')
'invalid_name'
def plot_hit_stat_distributions(
    hits_df: pl.LazyFrame, motif_names: List[str], plot_dir: str
) -> None:
    """Plot distributions of hit statistics for each motif.

    Creates separate histogram plots for coefficient, similarity, and importance
    score distributions for each motif. Saves plots in both PNG (high-res) and
    SVG (vector) formats.

    Parameters
    ----------
    hits_df : pl.LazyFrame
        Lazy DataFrame containing hit data with required columns:
        - motif_name : str, name of the motif
        - hit_coefficient_global : float, global coefficient values
        - hit_similarity : float, similarity scores to motif CWM
        - hit_importance : float, importance scores from attribution
    motif_names : List[str]
        List of motif names to generate plots for. Motifs not present
        in hits_df will result in empty histograms.
    plot_dir : str
        Directory path where plots will be saved. Creates subdirectory
        'motif_stat_distributions' if it doesn't exist.

    Notes
    -----
    For each motif, creates three separate plots:
    - {motif_name}_coefficients.{png,svg} : coefficient distribution
    - {motif_name}_similarities.{png,svg} : similarity distribution
    - {motif_name}_importances.{png,svg} : importance distribution
    """
    hits_df_collected = hits_df.collect()
    hits_by_motif = hits_df_collected.partition_by("motif_name", as_dict=True)
    # Zero-row frame with the same schema, used for motifs without any hits.
    dummy_df = hits_df_collected.clear()

    motifs_dir = os.path.join(plot_dir, "motif_stat_distributions")
    os.makedirs(motifs_dir, exist_ok=True)

    # (source column, output file suffix), in the order the files are written.
    stat_columns = (
        ("hit_coefficient_global", "coefficients"),
        ("hit_similarity", "similarities"),
        ("hit_importance", "importances"),
    )

    for m in motif_names:
        hits = hits_by_motif.get((m,), dummy_df)
        # Extract every column up front so a missing column fails before any
        # file for this motif is written.
        stats = [
            (hits.get_column(column).to_numpy(), suffix)
            for column, suffix in stat_columns
        ]

        for values, suffix in stats:
            fig, ax = plt.subplots(figsize=(5, 2))

            # Normalizing an empty sample divides by a zero total, so only
            # request a density when there is data to normalize.
            density = values.size > 0
            try:
                ax.hist(values, bins=50, density=density)
            except ValueError:
                # Fall back to a single bin when 50 bins cannot be built.
                ax.hist(values, bins=1, density=density)

            fig.savefig(os.path.join(motifs_dir, f"{m}_{suffix}.png"), dpi=300)
            fig.savefig(os.path.join(motifs_dir, f"{m}_{suffix}.svg"))
            plt.close(fig)
Plot distributions of hit statistics for each motif.
Creates separate histogram plots for coefficient, similarity, and importance score distributions for each motif. Saves plots in both PNG (high-res) and SVG (vector) formats.
Parameters
- hits_df (pl.LazyFrame):
Lazy DataFrame containing hit data with required columns:
- motif_name : str, name of the motif
- hit_coefficient_global : float, global coefficient values
- hit_similarity : float, similarity scores to motif CWM
- hit_importance : float, importance scores from attribution
- motif_names (List[str]): List of motif names to generate plots for. Motifs not present in hits_df will result in empty histograms.
- plot_dir (str): Directory path where plots will be saved. Creates subdirectory 'motif_stat_distributions' if it doesn't exist.
Notes
For each motif, creates three separate plots:
- {motif_name}_coefficients.{png,svg} : coefficient distribution
- {motif_name}_similarities.{png,svg} : similarity distribution
- {motif_name}_importances.{png,svg} : importance distribution
def _plot_count_frequencies(ax, hit_counts) -> None:
    """Draw, on ``ax``, the fraction of rows observed at each integer hit count."""
    values, tallies = np.unique(hit_counts, return_counts=True)
    fractions = tallies / tallies.sum()
    # One bar per integer from 0 to the largest observed count.
    n_bins = np.amax(values, initial=0) + 1
    positions = np.arange(n_bins)
    bar_heights = np.zeros(n_bins)
    bar_heights[values] = fractions
    ax.bar(positions, bar_heights)


def plot_hit_peak_distributions(
    occ_df: pl.DataFrame, motif_names: List[str], plot_dir: str
) -> None:
    """Plot how many hits each peak receives, per motif and in total.

    Writes one bar chart per motif plus one chart for the ``total`` column.
    Each chart shows, for every integer hit count, the proportion of rows
    of ``occ_df`` having that count.

    Parameters
    ----------
    occ_df : pl.DataFrame
        DataFrame containing motif occurrence counts per peak. Expected to have:
        - One column per motif name with integer hit counts
        - 'total' column with sum of all motif hits per peak
        - Each row represents a peak/genomic region
    motif_names : List[str]
        Motif names; each must be a column of occ_df.
    plot_dir : str
        Directory to save plots. The 'motif_hit_distributions' subdirectory
        is created inside it if needed.

    Notes
    -----
    Files written:
    - {plot_dir}/motif_hit_distributions/{motif_name}.{png,svg}
    - {plot_dir}/total_hit_distribution.{png,svg}

    Bar plots show frequency (proportion) on y-axis and hit count on x-axis.
    """
    per_motif_dir = os.path.join(plot_dir, "motif_hit_distributions")
    os.makedirs(per_motif_dir, exist_ok=True)

    for motif in motif_names:
        fig, ax = plt.subplots(figsize=(5, 2))
        _plot_count_frequencies(ax, occ_df.get_column(motif))

        fig.savefig(os.path.join(per_motif_dir, f"{motif}.png"), dpi=300)
        fig.savefig(os.path.join(per_motif_dir, f"{motif}.svg"))
        plt.close(fig)

    # The overall distribution is written to plot_dir itself.
    fig, ax = plt.subplots(figsize=(8, 4))
    _plot_count_frequencies(ax, occ_df.get_column("total"))

    ax.set_xlabel("Total hits per region")
    ax.set_ylabel("Frequency")

    fig.savefig(os.path.join(plot_dir, "total_hit_distribution.png"), dpi=300)
    fig.savefig(os.path.join(plot_dir, "total_hit_distribution.svg"), dpi=300)
    plt.close(fig)
Plot distribution of hits per peak for each motif.
Creates bar plots showing the frequency distribution of hit counts per peak for each motif, plus an overall distribution of total hits per peak.
Parameters
- occ_df (pl.DataFrame):
DataFrame containing motif occurrence counts per peak. Expected to have:
- One column per motif name with integer hit counts
- 'total' column with sum of all motif hits per peak
- Each row represents a peak/genomic region
- motif_names (List[str]): List of motif names corresponding to columns in occ_df.
- plot_dir (str): Directory to save plots. Creates 'motif_hit_distributions' subdirectory.
Notes
Generates the following plots:
- Individual motif hit distributions: {motif_name}.{png,svg}
- Overall hit distribution: total_hit_distribution.{png,svg}
Bar plots show frequency (proportion) on y-axis and hit count on x-axis.
def plot_peak_motif_indicator_heatmap(
    peak_hit_counts: Int[ndarray, "M M"], motif_names: List[str], output_dir: str
) -> None:
    """Plot co-occurrence heatmap showing motif associations across peaks.

    Creates a normalized correlation heatmap showing how frequently pairs of
    motifs co-occur within the same genomic peaks. Values are normalized by
    the geometric mean of individual motif frequencies.

    Parameters
    ----------
    peak_hit_counts : Int[ndarray, "M M"]
        Co-occurrence matrix where M = len(motif_names).
        Entry (i,j) represents the number of peaks containing both motif i and j.
        Diagonal entries represent total peaks containing each individual motif.
    motif_names : List[str]
        List of motif names for axis labels. Order must match matrix dimensions.
    output_dir : str
        Directory path where the heatmap plots will be saved.

    Notes
    -----
    Saves plots as:
    - motif_cooocurrence.png : High-resolution raster format
    - motif_cooocurrence.svg : Vector format

    The heatmap uses correlation normalization:
    matrix[i,j] / sqrt(matrix[i,i] * matrix[j,j]).
    A motif whose diagonal entry is 0 gets a normalization factor of 0, so
    its row and column are plotted as 0 rather than NaN.
    Colors use the 'Greens' colormap with values typically in [0, 1] range.
    """
    diag = np.diag(peak_hit_counts).astype(float)
    # Guard the reciprocal: a zero diagonal would otherwise give inf here and
    # NaN (0 * inf) in the normalized matrix, plus RuntimeWarnings.
    cov_norm = np.zeros_like(diag)
    np.divide(1.0, np.sqrt(diag), out=cov_norm, where=diag > 0)
    matrix = peak_hit_counts * cov_norm[:, None] * cov_norm[None, :]
    motif_keys = [abbreviate_motif_name(m) for m in motif_names]

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")

    # Plot the heatmap
    cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens")

    # Set axes on heatmap
    ticks = np.arange(len(motif_keys))
    ax.set_yticks(ticks)
    ax.set_yticklabels(motif_keys)
    ax.set_xticks(ticks)
    ax.set_xticklabels(motif_keys, rotation=90)
    ax.set_xlabel("Motif i")
    ax.set_ylabel("Motif j")

    ax.tick_params(axis="both", labelsize=8)

    cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30)
    cbar.ax.tick_params(labelsize=8)

    # NOTE: the triple-"o" spelling is the established output filename.
    fig.savefig(os.path.join(output_dir, "motif_cooocurrence.png"), dpi=300)
    fig.savefig(os.path.join(output_dir, "motif_cooocurrence.svg"), dpi=300)

    plt.close(fig)
Plot co-occurrence heatmap showing motif associations across peaks.
Creates a normalized correlation heatmap showing how frequently pairs of motifs co-occur within the same genomic peaks. Values are normalized by the geometric mean of individual motif frequencies.
Parameters
- peak_hit_counts (Int[ndarray, "M M"]): Co-occurrence matrix where M = len(motif_names). Entry (i,j) represents the number of peaks containing both motif i and j. Diagonal entries represent total peaks containing each individual motif.
- motif_names (List[str]): List of motif names for axis labels. Order must match matrix dimensions.
- output_dir (str): Directory path where the heatmap plots will be saved.
Notes
Saves plots as:
- motif_cooocurrence.png : High-resolution raster format
- motif_cooocurrence.svg : Vector format
The heatmap uses correlation normalization: matrix[i,j] / sqrt(matrix[i,i] * matrix[j,j]). Colors use the 'Greens' colormap with values typically in [0, 1] range.
def plot_seqlet_confusion_heatmap(
    seqlet_confusion: Int[ndarray, "M M"], motif_names: List[str], output_dir: str
) -> None:
    """Render the seqlet-versus-hit confusion matrix as a heatmap.

    The matrix is drawn as-is (no normalization) with seqlet motifs on the
    rows and hit motifs on the columns, labelled with abbreviated motif names.

    Parameters
    ----------
    seqlet_confusion : Int[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) represents the number of seqlets of motif i that overlap
        with hits called for motif j.
    motif_names : List[str]
        Motif names for the axis tick labels, in matrix order.
    output_dir : str
        Directory the two image files are written into.

    Notes
    -----
    Writes ``seqlet_confusion.png`` and ``seqlet_confusion.svg`` into
    ``output_dir``, both saved with ``dpi=300``.

    The heatmap uses the 'Blues' colormap. Perfect agreement would show
    high values along the diagonal and low values elsewhere.
    """
    labels = [abbreviate_motif_name(name) for name in motif_names]
    tick_positions = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")

    image = ax.imshow(
        seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues"
    )

    # One tick per motif on both axes; x labels are rotated to avoid overlap.
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(labels)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_xlabel("Hit motif")
    ax.set_ylabel("Seqlet motif")
    ax.tick_params(axis="both", labelsize=8)

    colorbar = fig.colorbar(
        image, ax=ax, orientation="vertical", shrink=0.6, aspect=30
    )
    colorbar.ax.tick_params(labelsize=8)

    for filename in ("seqlet_confusion.png", "seqlet_confusion.svg"):
        fig.savefig(os.path.join(output_dir, filename), dpi=300)

    plt.close(fig)
Plot confusion matrix heatmap comparing seqlets to hit calls.
Creates a heatmap showing the overlap between TF-MoDISco seqlets and Fi-NeMo hit calls. Rows represent seqlet motifs, columns represent hit motifs.
Parameters
- seqlet_confusion (Int[ndarray, "M M"]): Confusion matrix where M = len(motif_names). Entry (i,j) represents the number of seqlets of motif i that overlap with hits called for motif j.
- motif_names (List[str]): List of motif names for axis labels. Order must match matrix dimensions.
- output_dir (str): Directory path where the confusion matrix plots will be saved.
Notes
Saves plots as:
- seqlet_confusion.png : High-resolution raster format
- seqlet_confusion.svg : Vector format
The heatmap uses 'Blues' colormap. Perfect agreement would show a diagonal pattern with high values along the diagonal and low off-diagonal values.
class LogoGlyph(AbstractPathEffect):
    """Path effect that draws a size-normalized letter glyph.

    At construction the glyph outline is rescaled so that it is exactly one
    unit tall and at most one unit wide; ``draw_path`` then renders that
    stored outline in place of the path it is handed.

    Parameters
    ----------
    glyph : str
        Single character to render (e.g., 'A', 'C', 'G', 'T').
    ref_glyph : str, default 'E'
        Reference character used for width normalization.
    font_props : FontProperties, optional
        Font properties for the glyph rendering.
    offset : Tuple[float, float], default (0., 0.)
        Offset forwarded to ``AbstractPathEffect.__init__``.
    **kwargs
        Stored on ``self._gc``. Note that ``draw_path`` as written here does
        not read them.
    """

    def __init__(
        self,
        glyph: str,
        ref_glyph: str = "E",
        font_props: Optional[FontProperties] = None,
        offset: Tuple[float, float] = (0.0, 0.0),
        **kwargs,
    ) -> None:
        super().__init__(offset)

        glyph_path = TextPath((0, 0), glyph, size=1, prop=font_props)
        glyph_box = glyph_path.get_extents()
        ref_box = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents()

        # Widths are measured against the wider of the glyph and the
        # reference character, so a narrow glyph is not stretched to full width.
        full_width = max(glyph_box.width, ref_box.width)

        normalize = Affine2D()
        # 1) move the glyph's lower-left corner to the origin
        normalize.translate(-glyph_box.x0, -glyph_box.y0)
        # 2) scale to unit height and to (at most) unit width
        normalize.scale(1 / full_width, 1 / glyph_box.height)
        # 3) center the glyph horizontally in the unit-wide slot
        normalize.translate((1 - glyph_box.width / full_width) / 2, 0)

        self.path = normalize.transform_path(glyph_path)

        #: The dictionary of keywords to update the graphics collection with.
        self._gc = kwargs

    def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any:  # type: ignore[override]
        """Draw the stored glyph outline instead of the supplied path.

        Parameters
        ----------
        renderer : matplotlib renderer
            The renderer to draw with.
        gc : GraphicsContext
            Graphics context, passed through unchanged.
        tpath : Path
            Path of the artist being drawn; ignored in favor of ``self.path``.
        affine : Transform
            Affine transformation, passed through unchanged.
        rgbFace : color
            Face color for the glyph.

        Returns
        -------
        Any
            Result from renderer.draw_path.
        """
        glyph_path = self.path
        return renderer.draw_path(gc, glyph_path, affine, rgbFace)
Path effect for creating sequence logo glyphs with normalized dimensions.
This class creates properly scaled and positioned text glyphs for sequence logos by normalizing character dimensions and applying appropriate transforms.
Parameters
- glyph (str): Single character to render (e.g., 'A', 'C', 'G', 'T').
- ref_glyph (str, default 'E'): Reference character used for width normalization.
- font_props (FontProperties, optional): Font properties for the glyph rendering.
- offset (Tuple[float, float], default (0., 0.)): Offset for glyph positioning.
- **kwargs: Additional graphics collection parameters.
def __init__(
    self,
    glyph: str,
    ref_glyph: str = "E",
    font_props: Optional[FontProperties] = None,
    offset: Tuple[float, float] = (0.0, 0.0),
    **kwargs,
) -> None:
    """Build the normalized glyph outline.

    Parameters
    ----------
    glyph : str
        Character whose outline is rendered.
    ref_glyph : str, default 'E'
        Character whose width is the reference for horizontal scaling.
    font_props : FontProperties, optional
        Font used to build both outlines.
    offset : Tuple[float, float], default (0., 0.)
        Forwarded to the parent path-effect initializer.
    **kwargs
        Stored on ``self._gc``.
    """
    super().__init__(offset)

    glyph_path = TextPath((0, 0), glyph, size=1, prop=font_props)
    glyph_box = glyph_path.get_extents()
    ref_box = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents()

    # Normalize width against the wider of the glyph and the reference.
    full_width = max(glyph_box.width, ref_box.width)

    normalize = Affine2D()
    # Lower-left corner to the origin, then unit height / at most unit width,
    # then center horizontally within the unit-wide slot.
    normalize.translate(-glyph_box.x0, -glyph_box.y0)
    normalize.scale(1 / full_width, 1 / glyph_box.height)
    normalize.translate((1 - glyph_box.width / full_width) / 2, 0)

    self.path = normalize.transform_path(glyph_path)

    #: The dictionary of keywords to update the graphics collection with.
    self._gc = kwargs
Parameters
- offset : (float, float), default (0, 0): The (x, y) offset to apply to the path, measured in points.
def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any:  # type: ignore[override]
    """Draw the stored glyph outline instead of the supplied path.

    Parameters
    ----------
    renderer : matplotlib renderer
        The renderer to draw with.
    gc : GraphicsContext
        Graphics context, passed through unchanged.
    tpath : Path
        Path of the artist being drawn; ignored in favor of ``self.path``.
    affine : Transform
        Affine transformation, passed through unchanged.
    rgbFace : color
        Face color for the glyph.

    Returns
    -------
    Any
        Result from renderer.draw_path.
    """
    glyph_path = self.path
    return renderer.draw_path(gc, glyph_path, affine, rgbFace)
Draw the glyph path using the renderer.
Parameters
- renderer (matplotlib renderer): The renderer to draw with.
- gc (GraphicsContext): Graphics context for drawing properties.
- tpath (Path): Original text path (unused, using self.path instead).
- affine (Transform): Affine transformation to apply.
- rgbFace (color): Face color for the glyph.
Returns
- Any: Result from renderer.draw_path.
def plot_logo(
    ax: Axes,
    heights: Float[ndarray, "B W"],
    glyphs: Iterable[str],
    colors: Optional[Mapping[str, Optional[str]]] = None,
    font_props: Optional[FontProperties] = None,
    shade_bounds: Optional[Tuple[int, int]] = None,
) -> None:
    """Plot sequence logo from contribution weight matrix.

    Creates a sequence logo visualization where letter heights represent
    the contribution or information content at each position. Supports
    both positive and negative contributions with proper stacking.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes object to plot on.
    heights : Float[ndarray, "B W"]
        Height matrix where B = len(glyphs) and W = motif width.
        Entry (i,j) represents the height/contribution of base i at position j.
        Can contain both positive and negative values.
    glyphs : Iterable[str]
        Sequence of base characters corresponding to rows in heights matrix.
        Typically ['A', 'C', 'G', 'T'] for DNA. Any iterable is accepted,
        including one-shot iterators; it is materialized once internally.
    colors : Mapping[str, Optional[str]], optional
        Color mapping for each base. Keys must cover every glyph.
        If None, all bases will use default matplotlib colors.
    font_props : FontProperties, optional
        Font properties for letter rendering. If None, uses default font.
    shade_bounds : Tuple[int, int], optional
        (start, end) position indices to shade in background.
        Useful for highlighting core motif regions.

    Raises
    ------
    KeyError
        If ``colors`` is given but lacks an entry for one of the glyphs.

    Notes
    -----
    Positive and negative contributions are handled separately:
    - Positive values are stacked upward from zero in ascending order, so
      the largest contribution ends up on top
    - Negative values are stacked downward from zero in order of ascending
      magnitude, so the most negative contribution ends up at the bottom
    - A horizontal line is drawn at y=0 for reference

    The resulting plot has:
    - X-axis: Position in motif (0-indexed)
    - Y-axis: Contribution magnitude
    - Bar width: 0.95 (small gaps between positions)
    """
    # ``glyphs`` is iterated more than once below; materialize it so that a
    # one-shot iterator is not exhausted before the drawing loop.
    glyphs = list(glyphs)

    if colors is None:
        colors = {g: None for g in glyphs}

    ax.margins(x=0, y=0)

    # Split into a non-negative part and a non-positive part; each entry is
    # zero in the part it does not belong to.
    pos_values = np.clip(heights, 0, None)
    neg_values = np.clip(heights, None, 0)

    # Per column: the order in which rows are stacked outward from zero ...
    pos_order = np.argsort(pos_values, axis=0)
    neg_order = np.argsort(neg_values, axis=0)[::-1, :]
    # ... and the inverse permutation that maps stacked order back to rows.
    pos_reorder = np.argsort(pos_order, axis=0)
    neg_reorder = np.argsort(neg_order, axis=0)

    # Cumulative sums in stacking order give each row's outer edge.
    pos_offsets = np.take_along_axis(
        np.cumsum(np.take_along_axis(pos_values, pos_order, axis=0), axis=0),
        pos_reorder,
        axis=0,
    )
    neg_offsets = np.take_along_axis(
        np.cumsum(np.take_along_axis(neg_values, neg_order, axis=0), axis=0),
        neg_reorder,
        axis=0,
    )
    # Outer edge minus the bar's own (signed) height is the edge nearer zero.
    # The offset from the opposite-sign stack is 0 for any given entry.
    bottoms = pos_offsets + neg_offsets - heights

    x = np.arange(heights.shape[1])

    for glyph, height, bottom in zip(glyphs, heights, bottoms):
        ax.bar(
            x,
            height,
            0.95,
            bottom=bottom,
            path_effects=[LogoGlyph(glyph, font_props=font_props)],
            color=colors[glyph],
        )

    if shade_bounds is not None:
        start, end = shade_bounds
        # Bars are centered on integer positions, hence the half-unit shift.
        ax.axvspan(start - 0.5, end - 0.5, color="0.9", zorder=-1)

    ax.axhline(zorder=-1, linewidth=0.5, color="black")
Plot sequence logo from contribution weight matrix.
Creates a sequence logo visualization where letter heights represent the contribution or information content at each position. Supports both positive and negative contributions with proper stacking.
Parameters
- ax (Axes): Matplotlib axes object to plot on.
- heights (Float[ndarray, "B W"]): Height matrix where B = len(glyphs) and W = motif width. Entry (i,j) represents the height/contribution of base i at position j. Can contain both positive and negative values.
- glyphs (Iterable[str]): Sequence of base characters corresponding to rows in heights matrix. Typically ['A', 'C', 'G', 'T'] for DNA.
- colors (Mapping[str, Optional[str]], optional): Color mapping for each base. Keys should match glyphs. If None, all bases will use default matplotlib colors.
- font_props (FontProperties, optional): Font properties for letter rendering. If None, uses default font.
- shade_bounds (Tuple[int, int], optional): (start, end) position indices to shade in background. Useful for highlighting core motif regions.
Notes
Positive and negative contributions are handled separately:
- Positive values are stacked upward from the zero line in ascending order, so the largest contribution ends up on top
- Negative values are stacked downward from the zero line in order of ascending magnitude, so the most negative contribution ends up at the bottom
- A horizontal line is drawn at y=0 for reference
The resulting plot has:
- X-axis: Position in motif (0-indexed)
- Y-axis: Contribution magnitude
- Bar width: 0.95 (small gaps between positions)
def plot_cwms(
    cwms: Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    trim_bounds: Dict[str, Dict[str, Tuple[int, int]]],
    out_dir: str,
    alphabet: str = LOGO_ALPHABET,
    colors: Dict[str, str] = LOGO_COLORS,
    font: FontProperties = LOGO_FONT,
) -> None:
    """Draw every CWM of every motif as a sequence logo and save it.

    One figure is produced per (motif, CWM type) pair. The trim interval
    looked up in ``trim_bounds`` is shaded behind the logo, and each figure
    is written as both PNG and SVG.

    Parameters
    ----------
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Nested dictionary structure: {motif_name: {cwm_type: cwm_array}}.
        Each cwm_array has shape (4, W) where W is motif width.
        Rows correspond to bases in alphabet order.
    trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        Nested dictionary: {motif_name: {cwm_type: (start, end)}}.
        Must contain an entry for every (motif, CWM type) pair in ``cwms``.
    out_dir : str
        Output directory where motif subdirectories will be created.
    alphabet : str, default LOGO_ALPHABET
        DNA alphabet string, typically 'ACGT'.
    colors : Dict[str, str], default LOGO_COLORS
        Color mapping for DNA bases. Keys should match alphabet characters.
    font : FontProperties, default LOGO_FONT
        Font properties for sequence logo rendering.

    Notes
    -----
    Directory structure created:
    ```
    out_dir/
    ├── motif1/
    │   ├── cwm_type1.png
    │   ├── cwm_type1.svg
    │   └── ...
    └── motif2/
        └── ...
    ```

    Each plot is 10x2 inches; the PNG is saved at 100 dpi.
    Spines (plot borders) are hidden for cleaner appearance.
    """
    for motif_name, cwms_by_type in cwms.items():
        motif_dir = os.path.join(out_dir, motif_name)
        os.makedirs(motif_dir, exist_ok=True)

        for cwm_type, cwm in cwms_by_type.items():
            fig, ax = plt.subplots(figsize=(10, 2))

            bounds = trim_bounds[motif_name][cwm_type]
            plot_logo(
                ax, cwm, alphabet, colors=colors, font_props=font, shade_bounds=bounds
            )

            # Hide the axes frame.
            for spine in ax.spines.values():
                spine.set_visible(False)

            fig.savefig(os.path.join(motif_dir, f"{cwm_type}.png"), dpi=100)
            fig.savefig(os.path.join(motif_dir, f"{cwm_type}.svg"))

            plt.close(fig)
Plot contribution weight matrices as sequence logos.
Creates sequence logo plots for all motifs and CWM types, with optional shading to highlight trimmed regions. Saves plots in both PNG and SVG formats.
Parameters
- cwms (Dict[str, Dict[str, Float[ndarray, "4 W"]]]): Nested dictionary structure: {motif_name: {cwm_type: cwm_array}}. Each cwm_array has shape (4, W) where W is motif width. Rows correspond to bases in alphabet order.
- trim_bounds (Dict[str, Dict[str, Tuple[int, int]]]): Nested dictionary: {motif_name: {cwm_type: (start, end)}}. Defines regions to shade in the sequence logos.
- out_dir (str): Output directory where motif subdirectories will be created.
- alphabet (str, default LOGO_ALPHABET): DNA alphabet string, typically 'ACGT'.
- colors (Dict[str, str], default LOGO_COLORS): Color mapping for DNA bases. Keys should match alphabet characters.
- font (FontProperties, default LOGO_FONT): Font properties for sequence logo rendering.
Notes
Directory structure created:
out_dir/
├── motif1/
│ ├── cwm_type1.png
│ ├── cwm_type1.svg
│ └── ...
└── motif2/
└── ...
Each plot is 10x2 inches with trimmed regions shaded if specified. Spines (plot borders) are hidden for cleaner appearance.
def plot_hit_vs_seqlet_counts(
    recall_data: Dict[str, Dict[str, Union[int, float]]], output_dir: str
) -> None:
    """Plot scatter plot comparing hit counts to seqlet counts per motif.

    Creates a log-log scatter plot showing the relationship between the number
    of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISco
    for each motif. Includes diagonal reference line and motif annotations.

    Parameters
    ----------
    recall_data : Dict[str, Dict[str, Union[int, float]]]
        Dictionary with motif names as keys and metrics dictionaries as values.
        Each metrics dictionary must contain:
        - 'num_hits_total' : int, total number of hits for the motif
        - 'num_seqlets' : int, total number of seqlets for the motif
        May be empty, in which case only the reference line is drawn.
    output_dir : str
        Directory path where the scatter plot will be saved.

    Notes
    -----
    Saves plots as:
    - hit_vs_seqlet_counts.png : High-resolution raster format
    - hit_vs_seqlet_counts.svg : Vector format

    Plot features:
    - Log-log scale on both axes
    - Diagonal reference line (y = x) as dashed line
    - Points annotated with abbreviated motif names
    """
    motifs = list(recall_data)
    x = [recall_data[k]["num_hits_total"] for k in motifs]
    y = [recall_data[k]["num_seqlets"] for k in motifs]

    # Second anchor point for the y = x reference line. Any positive value
    # defines the same line, so clamp to at least 1: this keeps the two anchor
    # points distinct when every count is 0 and gives a usable value when
    # ``recall_data`` is empty.
    lim = max(max(x, default=0), max(y, default=0), 1)

    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")
    ax.axline((0, 0), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5)))
    ax.scatter(x, y, s=5)
    for name, hits, seqlets in zip(motifs, x, y):
        ax.annotate(
            abbreviate_motif_name(name), (hits, seqlets), fontsize=8, weight="bold"
        )

    ax.set_yscale("log")
    ax.set_xscale("log")

    ax.set_xlabel("Hits per motif")
    ax.set_ylabel("Seqlets per motif")

    fig.savefig(os.path.join(output_dir, "hit_vs_seqlet_counts.png"), dpi=300)
    fig.savefig(os.path.join(output_dir, "hit_vs_seqlet_counts.svg"))

    plt.close(fig)
Plot scatter plot comparing hit counts to seqlet counts per motif.
Creates a log-log scatter plot showing the relationship between the number of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISco for each motif. Includes diagonal reference line and motif annotations.
Parameters
- recall_data (Dict[str, Dict[str, Union[int, float]]]):
Dictionary with motif names as keys and metrics dictionaries as values.
Each metrics dictionary must contain:
- 'num_hits_total' : int, total number of hits for the motif
- 'num_seqlets' : int, total number of seqlets for the motif
- output_dir (str): Directory path where the scatter plot will be saved.
Notes
Saves plots as:
- hit_vs_seqlet_counts.png : High-resolution raster format
- hit_vs_seqlet_counts.svg : Vector format
Plot features:
- Log-log scale on both axes
- Diagonal reference line (y = x) as dashed line
- Points annotated with abbreviated motif names
def write_report(
    report_df: pl.DataFrame,
    motif_names: List[str],
    out_path: str,
    compute_recall: bool,
    use_seqlets: bool,
) -> None:
    """Generate and write HTML report from motif analysis results.

    Creates a comprehensive HTML report with tables and visualizations
    summarizing the Fi-NeMo motif discovery and hit calling results.

    Parameters
    ----------
    report_df : pl.DataFrame
        DataFrame containing motif statistics and performance metrics.
        Expected columns depend on compute_recall and use_seqlets flags.
    motif_names : List[str]
        List of motif names to include in the report.
        Order determines presentation sequence in the report.
    out_path : str
        File path where the HTML report will be written.
        Parent directory must exist.
    compute_recall : bool
        Whether recall metrics were computed and should be included
        in the report template.
    use_seqlets : bool
        Whether TF-MoDISco seqlet data was used in the analysis
        and should be referenced in the report.

    Notes
    -----
    Uses Jinja2 templating with the report.html template from the
    templates package. The template receives:
    - report_data: Iterator of DataFrame rows as dictionaries keyed by
      column name (single-pass; it can only be consumed once)
    - motif_names: List of motif names
    - compute_recall: Boolean flag for recall metrics
    - use_seqlets: Boolean flag for seqlet usage

    The template is read and the report is written as UTF-8, independent
    of the platform's default encoding.

    Raises
    ------
    OSError
        If the output path cannot be written.
    """
    template_str = (
        importlib.resources.files(templates)
        .joinpath("report.html")
        .read_text(encoding="utf-8")
    )
    # NOTE(review): Template() is built without autoescaping, so values are
    # inserted into the HTML verbatim; confirm inputs are trusted.
    template = Template(template_str)
    report = template.render(
        report_data=report_df.iter_rows(named=True),
        motif_names=motif_names,
        compute_recall=compute_recall,
        use_seqlets=use_seqlets,
    )
    # Explicit encoding: the locale default may be unable to encode
    # non-ASCII characters in the rendered report.
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(report)
Generate and write HTML report from motif analysis results.
Creates a comprehensive HTML report with tables and visualizations summarizing the Fi-NeMo motif discovery and hit calling results.
Parameters
- report_df (pl.DataFrame): DataFrame containing motif statistics and performance metrics. Expected columns depend on compute_recall and use_seqlets flags.
- motif_names (List[str]): List of motif names to include in the report. Order determines presentation sequence in the report.
- out_path (str): File path where the HTML report will be written. Parent directory must exist.
- compute_recall (bool): Whether recall metrics were computed and should be included in the report template.
- use_seqlets (bool): Whether TF-MoDISco seqlet data was used in the analysis and should be referenced in the report.
Notes
Uses Jinja2 templating with the report.html template from the templates package. The template receives:
- report_data: Iterator of DataFrame rows as dictionaries keyed by column name
- motif_names: List of motif names
- compute_recall: Boolean flag for recall metrics
- use_seqlets: Boolean flag for seqlet usage
Raises
- OSError: If the output path cannot be written.