finemo.evaluation
Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance.
This module provides functions for:
- Computing motif occurrence statistics and co-occurrence patterns
- Evaluating motif discovery quality against TF-MoDISco results
- Analyzing hit calling performance and recall metrics
- Generating confusion matrices for seqlet-hit comparisons
1"""Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance. 2 3This module provides functions for: 4- Computing motif occurrence statistics and co-occurrence patterns 5- Evaluating motif discovery quality against TF-MoDISco results 6- Analyzing hit calling performance and recall metrics 7- Generating confusion matrices for seqlet-hit comparisons 8""" 9 10import warnings 11from typing import List, Tuple, Dict, Any, Union 12 13import numpy as np 14from numpy import ndarray 15import polars as pl 16from jaxtyping import Float, Int 17 18 19def get_motif_occurences( 20 hits_df: pl.LazyFrame, motif_names: List[str] 21) -> Tuple[pl.DataFrame, Int[ndarray, "M M"]]: 22 """Compute motif occurrence statistics and co-occurrence matrix. 23 24 This function analyzes motif occurrence patterns across peaks by creating 25 a pivot table of hit counts and computing pairwise co-occurrence statistics. 26 27 Parameters 28 ---------- 29 hits_df : pl.LazyFrame 30 Lazy DataFrame containing hit data with required columns: 31 - peak_id : Peak identifier 32 - motif_name : Name of the motif 33 Additional columns are ignored. 34 motif_names : List[str] 35 List of motif names to include in analysis. Missing motifs 36 will be added as columns with zero counts. 37 38 Returns 39 ------- 40 occ_df : pl.DataFrame 41 DataFrame with motif occurrence counts per peak. Contains: 42 - peak_id column 43 - One column per motif with hit counts 44 - 'total' column summing all motif counts per peak 45 coocc : Int[ndarray, "M M"] 46 Co-occurrence matrix where M = len(motif_names). 47 Entry (i,j) indicates number of peaks containing both motif i and motif j. 48 Diagonal entries show total peaks containing each motif. 49 50 Notes 51 ----- 52 The co-occurrence matrix is computed using binary occurrence indicators, 53 so multiple hits of the same motif in a peak are treated as a single occurrence. 54 """ 55 occ_df = ( 56 hits_df.collect() 57 .with_columns(pl.lit(1).alias("count")) 58 .pivot( 59 on="motif_name", index="peak_id", values="count", aggregate_function="sum" 60 ) 61 .fill_null(0) 62 ) 63 64 missing_cols = set(motif_names) - set(occ_df.columns) 65 occ_df = ( 66 occ_df.with_columns([pl.lit(0).alias(m) for m in missing_cols]) 67 .with_columns(total=pl.sum_horizontal(*motif_names)) 68 .sort(["peak_id"]) 69 ) 70 71 num_peaks = occ_df.height 72 num_motifs = len(motif_names) 73 74 occ_mat = np.zeros((num_peaks, num_motifs), dtype=np.int16) 75 for i, m in enumerate(motif_names): 76 occ_mat[:, i] = occ_df.get_column(m).to_numpy() 77 78 occ_bin = (occ_mat > 0).astype(np.int32) 79 coocc = occ_bin.T @ occ_bin 80 81 return occ_df, coocc 82 83 84def get_motifs( 85 regions: Float[ndarray, "N 4 L"], positions_df: pl.DataFrame, motif_width: int 86) -> Float[ndarray, "H 4 W"]: 87 """Extract contribution weight matrices from regions based on hit positions. 88 89 This function extracts motif-sized windows from contribution score regions 90 at positions specified by hit coordinates. It handles both forward and 91 reverse complement orientations and filters out invalid positions. 92 93 Parameters 94 ---------- 95 regions : Float[ndarray, "N 4 L"] 96 Input contribution score regions multiplied by one-hot sequences, 97 OR Input one-hot encoded sequences. 98 Shape: (n_peaks, 4, region_width) where 4 represents DNA bases (A,C,G,T). 
99 positions_df : pl.DataFrame 100 DataFrame containing hit positions with required columns: 101 - peak_id : int, Peak index (0-based) 102 - start_untrimmed : int, Start position in genomic coordinates 103 - peak_region_start : int, Peak region start coordinate 104 - is_revcomp : bool, Whether hit is on reverse complement strand 105 motif_width : int 106 Width of motifs to extract. Must be positive. 107 108 Returns 109 ------- 110 motifs : Float[ndarray, "H 4 W"] 111 Extracted motif matrices for valid hits. 112 Shape: (n_valid_hits, 4, motif_width) 113 Invalid hits (outside region boundaries) are filtered out. 114 115 Notes 116 ----- 117 - Start positions are converted from genomic to region-relative coordinates 118 - Reverse complement hits have their sequence order reversed 119 - Hits extending beyond region boundaries are excluded 120 - The mean is computed across all valid hits, with warnings suppressed 121 for empty slices or invalid operations 122 123 Raises 124 ------ 125 ValueError 126 If motif_width is non-positive or positions_df lacks required columns. 127 """ 128 idx_df = positions_df.select( 129 peak_idx=pl.col("peak_id"), 130 start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"), 131 is_revcomp=pl.col("is_revcomp"), 132 ) 133 peak_idx = idx_df.get_column("peak_idx").to_numpy() 134 start_idx = idx_df.get_column("start_idx").to_numpy() 135 is_revcomp = idx_df.get_column("is_revcomp").to_numpy().astype(bool) 136 137 # Filter hits that fall outside the region boundaries 138 valid_mask = (start_idx >= 0) & (start_idx + motif_width <= regions.shape[2]) 139 peak_idx = peak_idx[valid_mask] 140 start_idx = start_idx[valid_mask] 141 is_revcomp = is_revcomp[valid_mask] 142 143 row_idx = peak_idx[:, None, None] 144 pos_idx = start_idx[:, None, None] + np.zeros((1, 1, motif_width), dtype=int) 145 pos_idx[~is_revcomp, :, :] += np.arange(motif_width)[None, None, :] 146 pos_idx[is_revcomp, :, :] += np.arange(motif_width)[None, None, ::-1] 147 nuc_idx = np.zeros((peak_idx.shape[0], 4, 1), dtype=int) 148 nuc_idx[~is_revcomp, :, :] += np.arange(4)[None, :, None] 149 nuc_idx[is_revcomp, :, :] += np.arange(4)[None, ::-1, None] 150 151 seqs = regions[row_idx, nuc_idx, pos_idx] 152 153 with warnings.catch_warnings(): 154 warnings.filterwarnings( 155 action="ignore", message="invalid value encountered in divide" 156 ) 157 warnings.filterwarnings(action="ignore", message="Mean of empty slice") 158 motifs = seqs.mean(axis=0) 159 160 return motifs 161 162 163def tfmodisco_comparison( 164 regions: Float[ndarray, "N 4 L"], 165 sequences: Int[ndarray, "N 4 L"], 166 hits_df: Union[pl.DataFrame, pl.LazyFrame], 167 peaks_df: pl.DataFrame, 168 seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None], 169 motifs_df: pl.DataFrame, 170 cwms_modisco: Float[ndarray, "M 4 W"], 171 motif_names: List[str], 172 modisco_half_width: int, 173 motif_width: int, 174 compute_recall: bool, 175) -> Tuple[ 176 Dict[str, Dict[str, Any]], 177 pl.DataFrame, 178 Dict[str, Dict[str, Float[ndarray, "4 W"]]], 179 Dict[str, Dict[str, Tuple[int, int]]], 180]: 181 """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics. 182 183 This function performs comprehensive comparison between Fi-NeMo hit calls 184 and TF-MoDISco seqlets, computing recall metrics, CWM similarities, 185 and extracting contribution weight matrices for visualization. 186 187 Parameters 188 ---------- 189 regions : Float[ndarray, "N 4 L"] 190 Contribution score regions multiplied by one-hot sequences. 
191 Shape: (n_peaks, 4, region_length) 192 sequences : Int[ndarray, "N 4 L"] 193 One-hot encoded sequences corresponding to regions. 194 Shape: (n_peaks, 4, region_length) 195 hits_df : Union[pl.DataFrame, pl.LazyFrame] 196 Fi-NeMo hit calls with required columns: 197 - peak_id, start_untrimmed, end_untrimmed, strand, motif_name 198 peaks_df : pl.DataFrame 199 Peak metadata with columns: 200 - peak_id, chr_id, peak_region_start 201 seqlets_df : Optional[pl.DataFrame] 202 TF-MoDISco seqlets with columns: 203 - chr_id, start_untrimmed, is_revcomp, motif_name 204 If None, only basic hit statistics are computed. 205 motifs_df : pl.DataFrame 206 Motif metadata with columns: 207 - motif_name, strand, motif_id, motif_start, motif_end 208 cwms_modisco : Float[ndarray, "M 4 W"] 209 TF-MoDISco contribution weight matrices. 210 Shape: (n_modisco_motifs, 4, motif_width) 211 motif_names : List[str] 212 Names of motifs to analyze. 213 modisco_half_width : int 214 Half-width for restricting hits to central region for fair comparison. 215 motif_width : int 216 Width of motifs for CWM extraction. 217 compute_recall : bool 218 Whether to compute recall metrics requiring seqlets_df. 219 220 Returns 221 ------- 222 report_data : Dict[str, Dict[str, Any]] 223 Per-motif evaluation metrics including: 224 - num_hits_total, num_hits_restricted, num_seqlets 225 - num_overlaps, seqlet_recall, cwm_similarity 226 report_df : pl.DataFrame 227 Tabular format of report_data for easy analysis. 228 cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]] 229 Extracted CWMs for each motif and condition: 230 - hits_fc, hits_rc: Forward/reverse complement hits 231 - modisco_fc, modisco_rc: TF-MoDISco forward/reverse 232 - seqlets_only, hits_restricted_only: Non-overlapping instances 233 cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]] 234 Trimming boundaries for each CWM type and motif. 235 236 Notes 237 ----- 238 - Hits are filtered to central region defined by modisco_half_width 239 - CWM similarity is computed as normalized dot product between hit and TF-MoDISco CWMs 240 - Recall metrics require both hits_df and seqlets_df to be non-empty 241 - Missing motifs are handled gracefully with empty DataFrames 242 243 Raises 244 ------ 245 ValueError 246 If required columns are missing from input DataFrames. 
247 """ 248 249 # Ensure hits_df is LazyFrame for consistent operations 250 if isinstance(hits_df, pl.DataFrame): 251 hits_df = hits_df.lazy() 252 253 hits_df = ( 254 hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32)) 255 .join(peaks_df.lazy(), on="peak_id", how="inner") 256 .select( 257 chr_id=pl.col("chr_id"), 258 start_untrimmed=pl.col("start_untrimmed"), 259 end_untrimmed=pl.col("end_untrimmed"), 260 is_revcomp=pl.col("strand") == "-", 261 motif_name=pl.col("motif_name"), 262 peak_region_start=pl.col("peak_region_start"), 263 peak_id=pl.col("peak_id"), 264 ) 265 ) 266 267 hits_unique = hits_df.unique( 268 subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"] 269 ) 270 271 region_len = regions.shape[2] 272 center = region_len / 2 273 hits_filtered = hits_df.filter( 274 ( 275 (pl.col("start_untrimmed") - pl.col("peak_region_start")) 276 >= (center - modisco_half_width) 277 ) 278 & ( 279 (pl.col("end_untrimmed") - pl.col("peak_region_start")) 280 <= (center + modisco_half_width) 281 ) 282 ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) 283 284 hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True) 285 hits_filtered_by_motif = hits_filtered.collect().partition_by( 286 "motif_name", as_dict=True 287 ) 288 289 if seqlets_df is None: 290 seqlets_collected = None 291 seqlets_lazy = None 292 elif isinstance(seqlets_df, pl.LazyFrame): 293 seqlets_collected = seqlets_df.collect() 294 seqlets_lazy = seqlets_df 295 else: 296 seqlets_collected = seqlets_df 297 seqlets_lazy = seqlets_df.lazy() 298 299 if seqlets_collected is not None: 300 seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True) 301 else: 302 seqlets_by_motif = {} 303 304 if compute_recall and seqlets_lazy is not None: 305 overlaps_df = hits_filtered.join( 306 seqlets_lazy, 307 on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], 308 how="inner", 309 ).collect() 310 311 seqlets_only_df = seqlets_lazy.join( 312 hits_df, 313 on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], 314 how="anti", 315 ).collect() 316 317 hits_only_filtered_df = hits_filtered.join( 318 seqlets_lazy, 319 on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], 320 how="anti", 321 ).collect() 322 323 # Create partition dictionaries 324 overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True) 325 seqlets_only_by_motif = seqlets_only_df.partition_by("motif_name", as_dict=True) 326 hits_only_filtered_by_motif = hits_only_filtered_df.partition_by( 327 "motif_name", as_dict=True 328 ) 329 else: 330 overlaps_by_motif = {} 331 seqlets_only_by_motif = {} 332 hits_only_filtered_by_motif = {} 333 334 report_data = {} 335 motifs = {} 336 cwm_trim_bounds = {} 337 dummy_df = hits_df.clear().collect() 338 for m in motif_names: 339 hits = hits_by_motif.get((m,), dummy_df) 340 hits_filtered = hits_filtered_by_motif.get((m,), dummy_df) 341 342 # Initialize default values 343 seqlets = dummy_df 344 overlaps = dummy_df 345 seqlets_only = dummy_df 346 hits_only_filtered = dummy_df 347 348 if seqlets_df is not None: 349 seqlets = seqlets_by_motif.get((m,), dummy_df) 350 351 if compute_recall and seqlets_df is not None: 352 overlaps = overlaps_by_motif.get((m,), dummy_df) 353 seqlets_only = seqlets_only_by_motif.get((m,), dummy_df) 354 hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df) 355 356 report_data[m] = { 357 "num_hits_total": hits.height, 358 "num_hits_restricted": hits_filtered.height, 359 } 360 361 if seqlets_df is not None: 362 
report_data[m]["num_seqlets"] = seqlets.height 363 364 if compute_recall and seqlets_df is not None: 365 report_data[m] |= { 366 "num_overlaps": overlaps.height, 367 "num_seqlets_only": seqlets_only.height, 368 "num_hits_restricted_only": hits_only_filtered.height, 369 "seqlet_recall": np.float64(overlaps.height) / seqlets.height 370 if seqlets.height > 0 371 else 0.0, 372 } 373 374 motif_data_fc = motifs_df.row( 375 by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"), 376 named=True, 377 ) 378 motif_data_rc = motifs_df.row( 379 by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"), 380 named=True, 381 ) 382 383 motifs[m] = { 384 "hits_fc": get_motifs(regions, hits, motif_width), 385 "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]], 386 "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]], 387 } 388 motifs[m]["hits_rc"] = motifs[m]["hits_fc"][::-1, ::-1] 389 motifs[m]["hits_ppm_fc"] = get_motifs(sequences, hits, motif_width) 390 motifs[m]["hits_ppm_rc"] = motifs[m]["hits_ppm_fc"][::-1, ::-1] 391 392 if compute_recall and seqlets_df is not None: 393 motifs[m] |= { 394 "seqlets_only": get_motifs(regions, seqlets_only, motif_width), 395 "hits_restricted_only": get_motifs( 396 regions, hits_only_filtered, motif_width 397 ), 398 } 399 400 bounds_fc = (motif_data_fc["motif_start"], motif_data_fc["motif_end"]) 401 bounds_rc = (motif_data_rc["motif_start"], motif_data_rc["motif_end"]) 402 403 cwm_trim_bounds[m] = { 404 "hits_fc": bounds_fc, 405 "modisco_fc": bounds_fc, 406 "modisco_rc": bounds_rc, 407 "hits_rc": bounds_rc, 408 "hits_ppm_fc": bounds_fc, 409 "hits_ppm_rc": bounds_rc, 410 } 411 412 if compute_recall and seqlets_df is not None: 413 cwm_trim_bounds[m] |= { 414 "seqlets_only": bounds_fc, 415 "hits_restricted_only": bounds_fc, 416 } 417 418 hits_cwm = motifs[m]["hits_fc"] 419 modisco_cwm = motifs[m]["modisco_fc"] 420 hnorm = np.sqrt((hits_cwm**2).sum()) 421 snorm = np.sqrt((modisco_cwm**2).sum()) 422 cwm_sim = (hits_cwm * modisco_cwm).sum() / (hnorm * snorm) 423 424 report_data[m]["cwm_similarity"] = cwm_sim 425 426 records = [{"motif_name": k} | v for k, v in report_data.items()] 427 report_df = pl.from_dicts(records) 428 429 return report_data, report_df, motifs, cwm_trim_bounds 430 431 432def seqlet_confusion( 433 hits_df: Union[pl.DataFrame, pl.LazyFrame], 434 seqlets_df: Union[pl.DataFrame, pl.LazyFrame], 435 peaks_df: pl.DataFrame, 436 motif_names: List[str], 437 motif_width: int, 438) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]: 439 """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits. 440 441 This function creates a confusion matrix showing the overlap between 442 TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs. 443 Overlap frequencies are estimated using binned genomic coordinates. 444 445 Parameters 446 ---------- 447 hits_df : Union[pl.DataFrame, pl.LazyFrame] 448 Fi-NeMo hit calls with required columns: 449 - peak_id, start_untrimmed, end_untrimmed, strand, motif_name 450 seqlets_df : pl.DataFrame 451 TF-MoDISco seqlets with required columns: 452 - chr_id, start_untrimmed, end_untrimmed, motif_name 453 peaks_df : pl.DataFrame 454 Peak metadata for joining coordinates: 455 - peak_id, chr_id 456 motif_names : List[str] 457 Names of motifs to include in confusion matrix. 458 Determines matrix dimensions. 459 motif_width : int 460 Width used for binning genomic coordinates. 461 Positions are binned to motif_width resolution. 
462 463 Returns 464 ------- 465 confusion_df : pl.DataFrame 466 Detailed confusion matrix in tabular format with columns: 467 - motif_name_seqlets : Seqlet motif labels (rows) 468 - motif_name_hits : Hit motif labels (columns) 469 - frac_overlap : Fraction of seqlets overlapping with hits 470 confusion_mat : Float[ndarray, "M M"] 471 Confusion matrix where M = len(motif_names). 472 Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits. 473 Rows represent seqlet motifs, columns represent hit motifs. 474 475 Notes 476 ----- 477 - Genomic coordinates are binned to motif_width resolution for overlap detection 478 - Only exact bin overlaps are considered (same chr_id, start_bin, end_bin) 479 - Fractions are computed as: overlaps / total_seqlets_per_motif 480 - Missing motif combinations result in zero entries in the confusion matrix 481 482 Raises 483 ------ 484 ValueError 485 If required columns are missing from input DataFrames. 486 KeyError 487 If motif names in data don't match those in motif_names list. 488 """ 489 bin_size = motif_width 490 491 # Ensure hits_df is LazyFrame for consistent operations 492 if isinstance(hits_df, pl.DataFrame): 493 hits_df = hits_df.lazy() 494 495 hits_binned = ( 496 hits_df.with_columns( 497 peak_id=pl.col("peak_id").cast(pl.UInt32), 498 is_revcomp=pl.col("strand") == "-", 499 ) 500 .join(peaks_df.lazy(), on="peak_id", how="inner") 501 .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) 502 .select( 503 chr_id=pl.col("chr_id"), 504 start_bin=pl.col("start_untrimmed") // bin_size, 505 end_bin=pl.col("end_untrimmed") // bin_size, 506 motif_name=pl.col("motif_name"), 507 ) 508 ) 509 510 seqlets_lazy = seqlets_df.lazy() 511 seqlets_binned = seqlets_lazy.select( 512 chr_id=pl.col("chr_id"), 513 start_bin=pl.col("start_untrimmed") // bin_size, 514 end_bin=pl.col("end_untrimmed") // bin_size, 515 motif_name=pl.col("motif_name"), 516 ) 517 518 overlaps_df = seqlets_binned.join( 519 hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits" 520 ) 521 522 seqlet_counts = ( 523 seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect() 524 ) 525 overlap_counts = ( 526 overlaps_df.group_by(["motif_name", "motif_name_hits"]) 527 .len(name="num_overlaps") 528 .collect() 529 ) 530 531 num_motifs = len(motif_names) 532 confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32) 533 name_to_idx = {m: i for i, m in enumerate(motif_names)} 534 535 confusion_df = overlap_counts.join( 536 seqlet_counts, on="motif_name", how="inner" 537 ).select( 538 motif_name_seqlets=pl.col("motif_name"), 539 motif_name_hits=pl.col("motif_name_hits"), 540 frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"), 541 ) 542 543 confusion_idx_df = confusion_df.select( 544 row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx), 545 col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx), 546 frac_overlap=pl.col("frac_overlap"), 547 ) 548 549 row_idx = confusion_idx_df["row_idx"].to_numpy() 550 col_idx = confusion_idx_df["col_idx"].to_numpy() 551 frac_overlap = confusion_idx_df["frac_overlap"].to_numpy() 552 553 confusion_mat[row_idx, col_idx] = frac_overlap 554 555 return confusion_df, confusion_mat
```python
def get_motif_occurences(
    hits_df: pl.LazyFrame, motif_names: List[str]
) -> Tuple[pl.DataFrame, Int[ndarray, "M M"]]:
    """Compute motif occurrence statistics and co-occurrence matrix."""
    occ_df = (
        hits_df.collect()
        .with_columns(pl.lit(1).alias("count"))
        .pivot(
            on="motif_name", index="peak_id", values="count", aggregate_function="sum"
        )
        .fill_null(0)
    )

    missing_cols = set(motif_names) - set(occ_df.columns)
    occ_df = (
        occ_df.with_columns([pl.lit(0).alias(m) for m in missing_cols])
        .with_columns(total=pl.sum_horizontal(*motif_names))
        .sort(["peak_id"])
    )

    num_peaks = occ_df.height
    num_motifs = len(motif_names)

    occ_mat = np.zeros((num_peaks, num_motifs), dtype=np.int16)
    for i, m in enumerate(motif_names):
        occ_mat[:, i] = occ_df.get_column(m).to_numpy()

    occ_bin = (occ_mat > 0).astype(np.int32)
    coocc = occ_bin.T @ occ_bin

    return occ_df, coocc
```
Compute motif occurrence statistics and co-occurrence matrix.
This function analyzes motif occurrence patterns across peaks by creating a pivot table of hit counts and computing pairwise co-occurrence statistics.
Parameters
- hits_df (pl.LazyFrame):
Lazy DataFrame containing hit data with required columns:
- peak_id : Peak identifier
- motif_name : Name of the motif
Additional columns are ignored.
- motif_names (List[str]): List of motif names to include in analysis. Missing motifs will be added as columns with zero counts.
Returns
- occ_df (pl.DataFrame):
DataFrame with motif occurrence counts per peak. Contains:
- peak_id column
- One column per motif with hit counts
- 'total' column summing all motif counts per peak
- coocc (Int[ndarray, "M M"]): Co-occurrence matrix where M = len(motif_names). Entry (i,j) indicates number of peaks containing both motif i and motif j. Diagonal entries show total peaks containing each motif.
Notes
The co-occurrence matrix is computed using binary occurrence indicators, so multiple hits of the same motif in a peak are treated as a single occurrence.
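A minimal usage sketch with toy data, assuming the module is importable as finemo.evaluation; the motif names below are made up for illustration:

```python
import polars as pl

from finemo.evaluation import get_motif_occurences

# Toy hit table: peak 0 has two SOX2 hits and one OCT4 hit; peak 1 has one OCT4 hit.
hits = pl.LazyFrame(
    {
        "peak_id": [0, 0, 0, 1],
        "motif_name": ["SOX2", "SOX2", "OCT4", "OCT4"],
    }
)
motif_names = ["SOX2", "OCT4", "KLF4"]  # KLF4 has no hits and becomes a zero column

occ_df, coocc = get_motif_occurences(hits, motif_names)
print(occ_df)  # per-peak hit counts plus a 'total' column
print(coocc)   # 3x3 matrix; coocc[0, 1] == 1 since only peak 0 contains both SOX2 and OCT4
```

Because the co-occurrence matrix is binarized, the two SOX2 hits in peak 0 count only once on the diagonal.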
```python
def get_motifs(
    regions: Float[ndarray, "N 4 L"], positions_df: pl.DataFrame, motif_width: int
) -> Float[ndarray, "H 4 W"]:
    """Extract contribution weight matrices from regions based on hit positions."""
    idx_df = positions_df.select(
        peak_idx=pl.col("peak_id"),
        start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"),
        is_revcomp=pl.col("is_revcomp"),
    )
    peak_idx = idx_df.get_column("peak_idx").to_numpy()
    start_idx = idx_df.get_column("start_idx").to_numpy()
    is_revcomp = idx_df.get_column("is_revcomp").to_numpy().astype(bool)

    # Filter hits that fall outside the region boundaries
    valid_mask = (start_idx >= 0) & (start_idx + motif_width <= regions.shape[2])
    peak_idx = peak_idx[valid_mask]
    start_idx = start_idx[valid_mask]
    is_revcomp = is_revcomp[valid_mask]

    row_idx = peak_idx[:, None, None]
    pos_idx = start_idx[:, None, None] + np.zeros((1, 1, motif_width), dtype=int)
    pos_idx[~is_revcomp, :, :] += np.arange(motif_width)[None, None, :]
    pos_idx[is_revcomp, :, :] += np.arange(motif_width)[None, None, ::-1]
    nuc_idx = np.zeros((peak_idx.shape[0], 4, 1), dtype=int)
    nuc_idx[~is_revcomp, :, :] += np.arange(4)[None, :, None]
    nuc_idx[is_revcomp, :, :] += np.arange(4)[None, ::-1, None]

    seqs = regions[row_idx, nuc_idx, pos_idx]

    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore", message="invalid value encountered in divide"
        )
        warnings.filterwarnings(action="ignore", message="Mean of empty slice")
        motifs = seqs.mean(axis=0)

    return motifs
```
Extract contribution weight matrices from regions based on hit positions.
This function extracts motif-sized windows from contribution score regions at positions specified by hit coordinates. It handles both forward and reverse complement orientations and filters out invalid positions.
Parameters
- regions (Float[ndarray, "N 4 L"]): Contribution score regions multiplied by one-hot sequences, or one-hot encoded sequences. Shape: (n_peaks, 4, region_width), where 4 represents the DNA bases (A, C, G, T).
- positions_df (pl.DataFrame):
DataFrame containing hit positions with required columns:
- peak_id : int, Peak index (0-based)
- start_untrimmed : int, Start position in genomic coordinates
- peak_region_start : int, Peak region start coordinate
- is_revcomp : bool, Whether hit is on reverse complement strand
- motif_width (int): Width of motifs to extract. Must be positive.
Returns
- motifs (Float[ndarray, "4 W"]): Mean contribution weight matrix (or base-frequency matrix, if one-hot sequences were passed) over all valid hits. Shape: (4, motif_width). Hits extending outside the region boundaries are excluded before averaging.
Notes
- Start positions are converted from genomic to region-relative coordinates
- Reverse complement hits have their sequence order reversed
- Hits extending beyond region boundaries are excluded
- The mean is computed across all valid hits, with warnings suppressed for empty slices or invalid operations
Raises
- ValueError: If motif_width is non-positive or positions_df lacks required columns.
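A minimal sketch with random toy inputs (array sizes and coordinates are made up), assuming the module is importable as finemo.evaluation:

```python
import numpy as np
import polars as pl

from finemo.evaluation import get_motifs

rng = np.random.default_rng(0)

# Two toy regions of width 20: contribution scores already multiplied by one-hot sequences.
regions = rng.normal(size=(2, 4, 20))

# Two hits; start_untrimmed is genomic and is made region-relative via peak_region_start.
positions = pl.DataFrame(
    {
        "peak_id": [0, 1],
        "start_untrimmed": [1005, 2010],
        "peak_region_start": [1000, 2000],
        "is_revcomp": [False, True],
    }
)

cwm = get_motifs(regions, positions, motif_width=8)
print(cwm.shape)  # (4, 8): the average over the two extracted (and oriented) windows
```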
```python
def tfmodisco_comparison(
    regions: Float[ndarray, "N 4 L"],
    sequences: Int[ndarray, "N 4 L"],
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None],
    motifs_df: pl.DataFrame,
    cwms_modisco: Float[ndarray, "M 4 W"],
    motif_names: List[str],
    modisco_half_width: int,
    motif_width: int,
    compute_recall: bool,
) -> Tuple[
    Dict[str, Dict[str, Any]],
    pl.DataFrame,
    Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    Dict[str, Dict[str, Tuple[int, int]]],
]:
    """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics."""
    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_df = (
        hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32))
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .select(
            chr_id=pl.col("chr_id"),
            start_untrimmed=pl.col("start_untrimmed"),
            end_untrimmed=pl.col("end_untrimmed"),
            is_revcomp=pl.col("strand") == "-",
            motif_name=pl.col("motif_name"),
            peak_region_start=pl.col("peak_region_start"),
            peak_id=pl.col("peak_id"),
        )
    )

    hits_unique = hits_df.unique(
        subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]
    )

    region_len = regions.shape[2]
    center = region_len / 2
    hits_filtered = hits_df.filter(
        (
            (pl.col("start_untrimmed") - pl.col("peak_region_start"))
            >= (center - modisco_half_width)
        )
        & (
            (pl.col("end_untrimmed") - pl.col("peak_region_start"))
            <= (center + modisco_half_width)
        )
    ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])

    hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True)
    hits_filtered_by_motif = hits_filtered.collect().partition_by(
        "motif_name", as_dict=True
    )

    if seqlets_df is None:
        seqlets_collected = None
        seqlets_lazy = None
    elif isinstance(seqlets_df, pl.LazyFrame):
        seqlets_collected = seqlets_df.collect()
        seqlets_lazy = seqlets_df
    else:
        seqlets_collected = seqlets_df
        seqlets_lazy = seqlets_df.lazy()

    if seqlets_collected is not None:
        seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True)
    else:
        seqlets_by_motif = {}

    if compute_recall and seqlets_lazy is not None:
        overlaps_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="inner",
        ).collect()

        seqlets_only_df = seqlets_lazy.join(
            hits_df,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        hits_only_filtered_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        # Create partition dictionaries
        overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True)
        seqlets_only_by_motif = seqlets_only_df.partition_by("motif_name", as_dict=True)
        hits_only_filtered_by_motif = hits_only_filtered_df.partition_by(
            "motif_name", as_dict=True
        )
    else:
        overlaps_by_motif = {}
        seqlets_only_by_motif = {}
        hits_only_filtered_by_motif = {}

    report_data = {}
    motifs = {}
    cwm_trim_bounds = {}
    dummy_df = hits_df.clear().collect()
    for m in motif_names:
        hits = hits_by_motif.get((m,), dummy_df)
        hits_filtered = hits_filtered_by_motif.get((m,), dummy_df)

        # Initialize default values
        seqlets = dummy_df
        overlaps = dummy_df
        seqlets_only = dummy_df
        hits_only_filtered = dummy_df

        if seqlets_df is not None:
            seqlets = seqlets_by_motif.get((m,), dummy_df)

        if compute_recall and seqlets_df is not None:
            overlaps = overlaps_by_motif.get((m,), dummy_df)
            seqlets_only = seqlets_only_by_motif.get((m,), dummy_df)
            hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df)

        report_data[m] = {
            "num_hits_total": hits.height,
            "num_hits_restricted": hits_filtered.height,
        }

        if seqlets_df is not None:
            report_data[m]["num_seqlets"] = seqlets.height

        if compute_recall and seqlets_df is not None:
            report_data[m] |= {
                "num_overlaps": overlaps.height,
                "num_seqlets_only": seqlets_only.height,
                "num_hits_restricted_only": hits_only_filtered.height,
                "seqlet_recall": np.float64(overlaps.height) / seqlets.height
                if seqlets.height > 0
                else 0.0,
            }

        motif_data_fc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"),
            named=True,
        )
        motif_data_rc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"),
            named=True,
        )

        motifs[m] = {
            "hits_fc": get_motifs(regions, hits, motif_width),
            "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]],
            "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]],
        }
        motifs[m]["hits_rc"] = motifs[m]["hits_fc"][::-1, ::-1]
        motifs[m]["hits_ppm_fc"] = get_motifs(sequences, hits, motif_width)
        motifs[m]["hits_ppm_rc"] = motifs[m]["hits_ppm_fc"][::-1, ::-1]

        if compute_recall and seqlets_df is not None:
            motifs[m] |= {
                "seqlets_only": get_motifs(regions, seqlets_only, motif_width),
                "hits_restricted_only": get_motifs(
                    regions, hits_only_filtered, motif_width
                ),
            }

        bounds_fc = (motif_data_fc["motif_start"], motif_data_fc["motif_end"])
        bounds_rc = (motif_data_rc["motif_start"], motif_data_rc["motif_end"])

        cwm_trim_bounds[m] = {
            "hits_fc": bounds_fc,
            "modisco_fc": bounds_fc,
            "modisco_rc": bounds_rc,
            "hits_rc": bounds_rc,
            "hits_ppm_fc": bounds_fc,
            "hits_ppm_rc": bounds_rc,
        }

        if compute_recall and seqlets_df is not None:
            cwm_trim_bounds[m] |= {
                "seqlets_only": bounds_fc,
                "hits_restricted_only": bounds_fc,
            }

        hits_cwm = motifs[m]["hits_fc"]
        modisco_cwm = motifs[m]["modisco_fc"]
        hnorm = np.sqrt((hits_cwm**2).sum())
        snorm = np.sqrt((modisco_cwm**2).sum())
        cwm_sim = (hits_cwm * modisco_cwm).sum() / (hnorm * snorm)

        report_data[m]["cwm_similarity"] = cwm_sim

    records = [{"motif_name": k} | v for k, v in report_data.items()]
    report_df = pl.from_dicts(records)

    return report_data, report_df, motifs, cwm_trim_bounds
```
Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.
This function performs a comprehensive comparison between Fi-NeMo hit calls and TF-MoDISco seqlets, computing recall metrics and CWM similarities and extracting contribution weight matrices for visualization.
Parameters
- regions (Float[ndarray, "N 4 L"]): Contribution score regions multiplied by one-hot sequences. Shape: (n_peaks, 4, region_length)
- sequences (Int[ndarray, "N 4 L"]): One-hot encoded sequences corresponding to regions. Shape: (n_peaks, 4, region_length)
- hits_df (Union[pl.DataFrame, pl.LazyFrame]):
Fi-NeMo hit calls with required columns:
- peak_id, start_untrimmed, end_untrimmed, strand, motif_name
- peaks_df (pl.DataFrame):
Peak metadata with columns:
- peak_id, chr_id, peak_region_start
- seqlets_df (Union[pl.DataFrame, pl.LazyFrame, None]):
TF-MoDISco seqlets with columns:
- chr_id, start_untrimmed, is_revcomp, motif_name
If None, only basic hit statistics are computed.
- motifs_df (pl.DataFrame):
Motif metadata with columns:
- motif_name, strand, motif_id, motif_start, motif_end
- cwms_modisco (Float[ndarray, "M 4 W"]): TF-MoDISco contribution weight matrices. Shape: (n_modisco_motifs, 4, motif_width)
- motif_names (List[str]): Names of motifs to analyze.
- modisco_half_width (int): Half-width of the central region to which hits are restricted for fair comparison.
- motif_width (int): Width of motifs for CWM extraction.
- compute_recall (bool): Whether to compute recall metrics requiring seqlets_df.
Returns
- report_data (Dict[str, Dict[str, Any]]):
Per-motif evaluation metrics including:
- num_hits_total, num_hits_restricted, num_seqlets
- num_overlaps, seqlet_recall, cwm_similarity
- report_df (pl.DataFrame): Tabular format of report_data for easy analysis.
- cwms (Dict[str, Dict[str, Float[ndarray, "4 W"]]]):
Extracted CWMs for each motif and condition:
- hits_fc, hits_rc: Forward/reverse complement hit CWMs
- modisco_fc, modisco_rc: TF-MoDISco forward/reverse CWMs
- hits_ppm_fc, hits_ppm_rc: Forward/reverse complement hit PPMs (computed from the one-hot sequences)
- seqlets_only, hits_restricted_only: Non-overlapping instances (present only when recall is computed)
- cwm_trim_bounds (Dict[str, Dict[str, Tuple[int, int]]]): Trimming boundaries for each CWM type and motif.
Notes
- Hits are filtered to the central region defined by modisco_half_width
- CWM similarity is computed as the normalized dot product (cosine similarity) between the hit-derived and TF-MoDISco CWMs; see the sketch after this list
- Recall metrics are computed only when compute_recall is True and seqlets_df is provided
- Missing motifs are handled gracefully with empty DataFrames
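The CWM similarity is a plain cosine similarity over the flattened matrices. A standalone sketch of the same computation (the helper name is illustrative, not part of the module):

```python
import numpy as np

def cwm_cosine_similarity(hits_cwm: np.ndarray, modisco_cwm: np.ndarray) -> float:
    """Cosine similarity between two (4, W) contribution weight matrices."""
    hnorm = np.sqrt((hits_cwm ** 2).sum())
    snorm = np.sqrt((modisco_cwm ** 2).sum())
    return float((hits_cwm * modisco_cwm).sum() / (hnorm * snorm))
```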
Raises
- ValueError: If required columns are missing from input DataFrames.
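A self-contained toy sketch, assuming the module is importable as finemo.evaluation; the arrays, coordinates, and motif names are fabricated for illustration, whereas in practice these inputs come from the Fi-NeMo hit-calling and TF-MoDISco outputs:

```python
import numpy as np
import polars as pl

from finemo.evaluation import tfmodisco_comparison

rng = np.random.default_rng(0)
L, W = 100, 10  # region length and motif width for this toy example

regions = rng.normal(size=(2, 4, L))  # contributions * one-hot
sequences = np.eye(4, dtype=np.int8)[rng.integers(0, 4, size=(2, L))].transpose(0, 2, 1)

peaks_df = pl.DataFrame(
    {"peak_id": [0, 1], "chr_id": ["chr1", "chr1"], "peak_region_start": [1000, 2000]}
).with_columns(pl.col("peak_id").cast(pl.UInt32))

hits_df = pl.DataFrame(
    {
        "peak_id": [0, 0, 1],
        "start_untrimmed": [1040, 1060, 2045],
        "end_untrimmed": [1050, 1070, 2055],
        "strand": ["+", "-", "+"],
        "motif_name": ["motif_A", "motif_A", "motif_A"],
    }
)

seqlets_df = pl.DataFrame(
    {
        "chr_id": ["chr1", "chr1"],
        "start_untrimmed": [1040, 2045],
        "is_revcomp": [False, False],
        "motif_name": ["motif_A", "motif_A"],
    }
)

# One row per motif and strand; motif_id indexes into cwms_modisco.
motifs_df = pl.DataFrame(
    {
        "motif_name": ["motif_A", "motif_A"],
        "strand": ["+", "-"],
        "motif_id": [0, 1],
        "motif_start": [0, 0],
        "motif_end": [W, W],
    }
)
cwms_modisco = rng.normal(size=(2, 4, W))

report_data, report_df, cwms, bounds = tfmodisco_comparison(
    regions=regions,
    sequences=sequences,
    hits_df=hits_df,
    peaks_df=peaks_df,
    seqlets_df=seqlets_df,
    motifs_df=motifs_df,
    cwms_modisco=cwms_modisco,
    motif_names=["motif_A"],
    modisco_half_width=50,
    motif_width=W,
    compute_recall=True,
)
print(report_df)                         # hit/seqlet counts, recall, CWM similarity per motif
print(cwms["motif_A"]["hits_fc"].shape)  # (4, 10): average CWM over the hits
```

Both toy seqlets coincide with hits, so the reported seqlet_recall for motif_A is 1.0.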
```python
def seqlet_confusion(
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    motif_names: List[str],
    motif_width: int,
) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]:
    """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits."""
    bin_size = motif_width

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_binned = (
        hits_df.with_columns(
            peak_id=pl.col("peak_id").cast(pl.UInt32),
            is_revcomp=pl.col("strand") == "-",
        )
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
        .select(
            chr_id=pl.col("chr_id"),
            start_bin=pl.col("start_untrimmed") // bin_size,
            end_bin=pl.col("end_untrimmed") // bin_size,
            motif_name=pl.col("motif_name"),
        )
    )

    seqlets_lazy = seqlets_df.lazy()
    seqlets_binned = seqlets_lazy.select(
        chr_id=pl.col("chr_id"),
        start_bin=pl.col("start_untrimmed") // bin_size,
        end_bin=pl.col("end_untrimmed") // bin_size,
        motif_name=pl.col("motif_name"),
    )

    overlaps_df = seqlets_binned.join(
        hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits"
    )

    seqlet_counts = (
        seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect()
    )
    overlap_counts = (
        overlaps_df.group_by(["motif_name", "motif_name_hits"])
        .len(name="num_overlaps")
        .collect()
    )

    num_motifs = len(motif_names)
    confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32)
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    confusion_df = overlap_counts.join(
        seqlet_counts, on="motif_name", how="inner"
    ).select(
        motif_name_seqlets=pl.col("motif_name"),
        motif_name_hits=pl.col("motif_name_hits"),
        frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"),
    )

    confusion_idx_df = confusion_df.select(
        row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx),
        col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx),
        frac_overlap=pl.col("frac_overlap"),
    )

    row_idx = confusion_idx_df["row_idx"].to_numpy()
    col_idx = confusion_idx_df["col_idx"].to_numpy()
    frac_overlap = confusion_idx_df["frac_overlap"].to_numpy()

    confusion_mat[row_idx, col_idx] = frac_overlap

    return confusion_df, confusion_mat
```
Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.
This function creates a confusion matrix showing the overlap between TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs. Overlap frequencies are estimated using binned genomic coordinates.
Parameters
- hits_df (Union[pl.DataFrame, pl.LazyFrame]):
Fi-NeMo hit calls with required columns:
- peak_id, start_untrimmed, end_untrimmed, strand, motif_name
- seqlets_df (Union[pl.DataFrame, pl.LazyFrame]):
TF-MoDISco seqlets with required columns:
- chr_id, start_untrimmed, end_untrimmed, motif_name
- peaks_df (pl.DataFrame):
Peak metadata for joining coordinates:
- peak_id, chr_id
- motif_names (List[str]): Names of motifs to include in confusion matrix. Determines matrix dimensions.
- motif_width (int): Width used for binning genomic coordinates. Positions are binned to motif_width resolution.
Returns
- confusion_df (pl.DataFrame):
Detailed confusion matrix in tabular format with columns:
- motif_name_seqlets : Seqlet motif labels (rows)
- motif_name_hits : Hit motif labels (columns)
- frac_overlap : Fraction of seqlets overlapping with hits
- confusion_mat (Float[ndarray, "M M"]): Confusion matrix where M = len(motif_names). Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits. Rows represent seqlet motifs, columns represent hit motifs.
Notes
- Genomic coordinates are binned to motif_width resolution for overlap detection
- Only exact bin overlaps are considered (same chr_id, start_bin, end_bin)
- Fractions are computed as: overlaps / total_seqlets_per_motif
- Missing motif combinations result in zero entries in the confusion matrix
Raises
- ValueError: If required columns are missing from input DataFrames.
- KeyError: If motif names in data don't match those in motif_names list.
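A self-contained toy sketch, again assuming the module is importable as finemo.evaluation and using fabricated coordinates and motif names:

```python
import polars as pl

from finemo.evaluation import seqlet_confusion

peaks_df = pl.DataFrame(
    {"peak_id": [0, 1], "chr_id": ["chr1", "chr1"]}
).with_columns(pl.col("peak_id").cast(pl.UInt32))

# Hits and seqlets overlap when their binned start/end coordinates match exactly.
# With motif_width=10, a hit spanning 1040-1050 falls in bins (104, 105).
hits_df = pl.DataFrame(
    {
        "peak_id": [0, 1],
        "start_untrimmed": [1040, 2050],
        "end_untrimmed": [1050, 2060],
        "strand": ["+", "-"],
        "motif_name": ["motif_A", "motif_B"],
    }
)

seqlets_df = pl.DataFrame(
    {
        "chr_id": ["chr1", "chr1"],
        "start_untrimmed": [1041, 2055],
        "end_untrimmed": [1051, 2065],
        "motif_name": ["motif_A", "motif_B"],
    }
)

confusion_df, confusion_mat = seqlet_confusion(
    hits_df, seqlets_df, peaks_df, motif_names=["motif_A", "motif_B"], motif_width=10
)
print(confusion_mat)  # rows = seqlet motifs, columns = hit motifs
```

Here each seqlet lands in the same bins as the hit of the same motif, so the matrix is the identity: diagonal entries of 1.0 and zeros elsewhere.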