finemo.evaluation
Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance.
This module provides functions for:
- Computing motif occurrence statistics and co-occurrence patterns
- Evaluating motif discovery quality against TF-MoDISco results
- Analyzing hit calling performance and recall metrics
- Generating confusion matrices for seqlet-hit comparisons
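As an illustration of the second bullet: the `tfmodisco_comparison` source on this page reports a `cwm_similarity` value, computed as a normalized dot product between a hit-derived CWM and the matching TF-MoDISco CWM. The following is a minimal sketch of that metric, not the library's API. The helper name `cwm_cosine_similarity`, the toy values, and the zero-norm guard are assumptions made for illustration.

```python
import numpy as np


def cwm_cosine_similarity(hits_cwm: np.ndarray, modisco_cwm: np.ndarray) -> float:
    """Normalized dot product between two (4, W) contribution weight matrices."""
    hnorm = np.sqrt((hits_cwm**2).sum())
    snorm = np.sqrt((modisco_cwm**2).sum())
    if hnorm == 0 or snorm == 0:
        # Assumption of this sketch: return NaN instead of dividing by zero.
        return float("nan")
    return float((hits_cwm * modisco_cwm).sum() / (hnorm * snorm))


# Toy CWM with hypothetical values; rows are A, C, G, T and columns are positions.
cwm = np.array(
    [
        [0.9, 0.0, 0.1],
        [0.0, 0.0, 0.2],
        [0.0, 0.8, 0.0],
        [0.1, 0.0, 0.7],
    ]
)

sim_scaled = cwm_cosine_similarity(cwm, 3.0 * cwm)  # rescaling leaves the score unchanged
sim_flipped = cwm_cosine_similarity(cwm, -cwm)  # opposite-sign contributions
print(sim_scaled, sim_flipped)
```

Because both matrices are normalized, the score reflects the shape of the contribution pattern rather than its magnitude: values near 1 indicate that the averaged hits reproduce the TF-MoDISco pattern.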
"""Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance.

This module provides functions for:
- Computing motif occurrence statistics and co-occurrence patterns
- Evaluating motif discovery quality against TF-MoDISco results
- Analyzing hit calling performance and recall metrics
- Generating confusion matrices for seqlet-hit comparisons
"""

import warnings
from typing import List, Tuple, Dict, Any, Union

import numpy as np
from numpy import ndarray
import polars as pl
from jaxtyping import Float, Int


def get_motif_occurences(
    hits_df: pl.LazyFrame, motif_names: List[str]
) -> Tuple[pl.DataFrame, Int[ndarray, "M M"]]:
    """Compute motif occurrence statistics and co-occurrence matrix.

    This function analyzes motif occurrence patterns across peaks by creating
    a pivot table of hit counts and computing pairwise co-occurrence statistics.

    Parameters
    ----------
    hits_df : pl.LazyFrame
        Lazy DataFrame containing hit data with required columns:
        - peak_id : Peak identifier
        - motif_name : Name of the motif
        Additional columns are ignored.
    motif_names : List[str]
        List of motif names to include in analysis. Missing motifs
        will be added as columns with zero counts.

    Returns
    -------
    occ_df : pl.DataFrame
        DataFrame with motif occurrence counts per peak. Contains:
        - peak_id column
        - One column per motif with hit counts
        - 'total' column summing all motif counts per peak
    coocc : Int[ndarray, "M M"]
        Co-occurrence matrix where M = len(motif_names).
        Entry (i,j) indicates number of peaks containing both motif i and motif j.
        Diagonal entries show total peaks containing each motif.

    Notes
    -----
    The co-occurrence matrix is computed using binary occurrence indicators,
    so multiple hits of the same motif in a peak are treated as a single occurrence.
    """
    occ_df = (
        hits_df.collect()
        .with_columns(pl.lit(1).alias("count"))
        .pivot(
            on="motif_name", index="peak_id", values="count", aggregate_function="sum"
        )
        .fill_null(0)
    )

    missing_cols = set(motif_names) - set(occ_df.columns)
    occ_df = (
        occ_df.with_columns([pl.lit(0).alias(m) for m in missing_cols])
        .with_columns(total=pl.sum_horizontal(*motif_names))
        .sort(["peak_id"])
    )

    num_peaks = occ_df.height
    num_motifs = len(motif_names)

    occ_mat = np.zeros((num_peaks, num_motifs), dtype=np.int16)
    for i, m in enumerate(motif_names):
        occ_mat[:, i] = occ_df.get_column(m).to_numpy()

    occ_bin = (occ_mat > 0).astype(np.int32)
    coocc = occ_bin.T @ occ_bin

    return occ_df, coocc


def get_cwms(
    regions: Float[ndarray, "N 4 L"], positions_df: pl.DataFrame, motif_width: int
) -> Float[ndarray, "4 W"]:
    """Extract contribution weight matrices from regions based on hit positions.

    This function extracts motif-sized windows from contribution score regions
    at positions specified by hit coordinates. It handles both forward and
    reverse complement orientations and filters out invalid positions.

    Parameters
    ----------
    regions : Float[ndarray, "N 4 L"]
        Input contribution score regions multiplied by one-hot sequences.
        Shape: (n_peaks, 4, region_width) where 4 represents DNA bases (A,C,G,T).
    positions_df : pl.DataFrame
        DataFrame containing hit positions with required columns:
        - peak_id : int, Peak index (0-based)
        - start_untrimmed : int, Start position in genomic coordinates
        - peak_region_start : int, Peak region start coordinate
        - is_revcomp : bool, Whether hit is on reverse complement strand
    motif_width : int
        Width of motifs to extract. Must be positive.

    Returns
    -------
    cwms : Float[ndarray, "4 W"]
        Mean contribution weight matrix over all valid hits.
        Shape: (4, motif_width)
        Invalid hits (outside region boundaries) are excluded before averaging.

    Notes
    -----
    - Start positions are converted from genomic to region-relative coordinates
    - Reverse complement hits are read with both the position and base axes reversed
    - Hits extending beyond region boundaries are excluded
    - The mean is computed across all valid hits, with warnings suppressed
      for empty slices or invalid operations

    Raises
    ------
    ValueError
        If motif_width is non-positive or positions_df lacks required columns.
    """
    idx_df = positions_df.select(
        peak_idx=pl.col("peak_id"),
        start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"),
        is_revcomp=pl.col("is_revcomp"),
    )
    peak_idx = idx_df.get_column("peak_idx").to_numpy()
    start_idx = idx_df.get_column("start_idx").to_numpy()
    is_revcomp = idx_df.get_column("is_revcomp").to_numpy().astype(bool)

    # Filter hits that fall outside the region boundaries
    valid_mask = (start_idx >= 0) & (start_idx + motif_width <= regions.shape[2])
    peak_idx = peak_idx[valid_mask]
    start_idx = start_idx[valid_mask]
    is_revcomp = is_revcomp[valid_mask]

    row_idx = peak_idx[:, None, None]
    pos_idx = start_idx[:, None, None] + np.zeros((1, 1, motif_width), dtype=int)
    pos_idx[~is_revcomp, :, :] += np.arange(motif_width)[None, None, :]
    pos_idx[is_revcomp, :, :] += np.arange(motif_width)[None, None, ::-1]
    nuc_idx = np.zeros((peak_idx.shape[0], 4, 1), dtype=int)
    nuc_idx[~is_revcomp, :, :] += np.arange(4)[None, :, None]
    nuc_idx[is_revcomp, :, :] += np.arange(4)[None, ::-1, None]

    seqs = regions[row_idx, nuc_idx, pos_idx]

    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore", message="invalid value encountered in divide"
        )
        warnings.filterwarnings(action="ignore", message="Mean of empty slice")
        cwms = seqs.mean(axis=0)

    return cwms


def tfmodisco_comparison(
    regions: Float[ndarray, "N 4 L"],
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None],
    motifs_df: pl.DataFrame,
    cwms_modisco: Float[ndarray, "M 4 W"],
    motif_names: List[str],
    modisco_half_width: int,
    motif_width: int,
    compute_recall: bool,
) -> Tuple[
    Dict[str, Dict[str, Any]],
    pl.DataFrame,
    Dict[str, Dict[str, Float[ndarray, "4 W"]]],
    Dict[str, Dict[str, Tuple[int, int]]],
]:
    """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.

    This function performs comprehensive comparison between Fi-NeMo hit calls
    and TF-MoDISco seqlets, computing recall metrics, CWM similarities,
    and extracting contribution weight matrices for visualization.

    Parameters
    ----------
    regions : Float[ndarray, "N 4 L"]
        Contribution score regions multiplied by one-hot sequences.
        Shape: (n_peaks, 4, region_length)
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    peaks_df : pl.DataFrame
        Peak metadata with columns:
        - peak_id, chr_id, peak_region_start
    seqlets_df : Optional[pl.DataFrame]
        TF-MoDISco seqlets with columns:
        - chr_id, start_untrimmed, is_revcomp, motif_name
        If None, only basic hit statistics are computed.
    motifs_df : pl.DataFrame
        Motif metadata with columns:
        - motif_name, strand, motif_id, motif_start, motif_end
    cwms_modisco : Float[ndarray, "M 4 W"]
        TF-MoDISco contribution weight matrices.
        Shape: (n_modisco_motifs, 4, motif_width)
    motif_names : List[str]
        Names of motifs to analyze.
    modisco_half_width : int
        Half-width for restricting hits to central region for fair comparison.
    motif_width : int
        Width of motifs for CWM extraction.
    compute_recall : bool
        Whether to compute recall metrics requiring seqlets_df.

    Returns
    -------
    report_data : Dict[str, Dict[str, Any]]
        Per-motif evaluation metrics including:
        - num_hits_total, num_hits_restricted, num_seqlets
        - num_overlaps, seqlet_recall, cwm_similarity
    report_df : pl.DataFrame
        Tabular format of report_data for easy analysis.
    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]]
        Extracted CWMs for each motif and condition:
        - hits_fc, hits_rc: Forward/reverse complement hits
        - modisco_fc, modisco_rc: TF-MoDISco forward/reverse
        - seqlets_only, hits_restricted_only: Non-overlapping instances
    cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]]
        Trimming boundaries for each CWM type and motif.

    Notes
    -----
    - Hits are filtered to central region defined by modisco_half_width
    - CWM similarity is computed as normalized dot product between hit and TF-MoDISco CWMs
    - Recall metrics require both hits_df and seqlets_df to be non-empty
    - Missing motifs are handled gracefully with empty DataFrames

    Raises
    ------
    ValueError
        If required columns are missing from input DataFrames.
    """

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_df = (
        hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32))
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .select(
            chr_id=pl.col("chr_id"),
            start_untrimmed=pl.col("start_untrimmed"),
            end_untrimmed=pl.col("end_untrimmed"),
            is_revcomp=pl.col("strand") == "-",
            motif_name=pl.col("motif_name"),
            peak_region_start=pl.col("peak_region_start"),
            peak_id=pl.col("peak_id"),
        )
    )

    hits_unique = hits_df.unique(
        subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]
    )

    region_len = regions.shape[2]
    center = region_len / 2
    hits_filtered = hits_df.filter(
        (
            (pl.col("start_untrimmed") - pl.col("peak_region_start"))
            >= (center - modisco_half_width)
        )
        & (
            (pl.col("end_untrimmed") - pl.col("peak_region_start"))
            <= (center + modisco_half_width)
        )
    ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])

    hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True)
    hits_filtered_by_motif = hits_filtered.collect().partition_by(
        "motif_name", as_dict=True
    )

    if seqlets_df is None:
        seqlets_collected = None
        seqlets_lazy = None
    elif isinstance(seqlets_df, pl.LazyFrame):
        seqlets_collected = seqlets_df.collect()
        seqlets_lazy = seqlets_df
    else:
        seqlets_collected = seqlets_df
        seqlets_lazy = seqlets_df.lazy()

    if seqlets_collected is not None:
        seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True)
    else:
        seqlets_by_motif = {}

    if compute_recall and seqlets_lazy is not None:
        overlaps_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="inner",
        ).collect()

        seqlets_only_df = seqlets_lazy.join(
            hits_df,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        hits_only_filtered_df = hits_filtered.join(
            seqlets_lazy,
            on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"],
            how="anti",
        ).collect()

        # Create partition dictionaries
        overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True)
        seqlets_only_by_motif = seqlets_only_df.partition_by("motif_name", as_dict=True)
        hits_only_filtered_by_motif = hits_only_filtered_df.partition_by(
            "motif_name", as_dict=True
        )
    else:
        overlaps_by_motif = {}
        seqlets_only_by_motif = {}
        hits_only_filtered_by_motif = {}

    report_data = {}
    cwms = {}
    cwm_trim_bounds = {}
    dummy_df = hits_df.clear().collect()
    for m in motif_names:
        hits = hits_by_motif.get((m,), dummy_df)
        hits_filtered = hits_filtered_by_motif.get((m,), dummy_df)

        # Initialize default values
        seqlets = dummy_df
        overlaps = dummy_df
        seqlets_only = dummy_df
        hits_only_filtered = dummy_df

        if seqlets_df is not None:
            seqlets = seqlets_by_motif.get((m,), dummy_df)

        if compute_recall and seqlets_df is not None:
            overlaps = overlaps_by_motif.get((m,), dummy_df)
            seqlets_only = seqlets_only_by_motif.get((m,), dummy_df)
            hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df)

        report_data[m] = {
            "num_hits_total": hits.height,
            "num_hits_restricted": hits_filtered.height,
        }

        if seqlets_df is not None:
            report_data[m]["num_seqlets"] = seqlets.height

        if compute_recall and seqlets_df is not None:
            report_data[m] |= {
                "num_overlaps": overlaps.height,
                "num_seqlets_only": seqlets_only.height,
                "num_hits_restricted_only": hits_only_filtered.height,
                "seqlet_recall": np.float64(overlaps.height) / seqlets.height
                if seqlets.height > 0
                else 0.0,
            }

        motif_data_fc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"),
            named=True,
        )
        motif_data_rc = motifs_df.row(
            by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"),
            named=True,
        )

        cwms[m] = {
            "hits_fc": get_cwms(regions, hits, motif_width),
            "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]],
            "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]],
        }
        cwms[m]["hits_rc"] = cwms[m]["hits_fc"][::-1, ::-1]

        if compute_recall and seqlets_df is not None:
            cwms[m] |= {
                "seqlets_only": get_cwms(regions, seqlets_only, motif_width),
                "hits_restricted_only": get_cwms(
                    regions, hits_only_filtered, motif_width
                ),
            }

        bounds_fc = (motif_data_fc["motif_start"], motif_data_fc["motif_end"])
        bounds_rc = (motif_data_rc["motif_start"], motif_data_rc["motif_end"])

        cwm_trim_bounds[m] = {
            "hits_fc": bounds_fc,
            "modisco_fc": bounds_fc,
            "modisco_rc": bounds_rc,
            "hits_rc": bounds_rc,
        }

        if compute_recall and seqlets_df is not None:
            cwm_trim_bounds[m] |= {
                "seqlets_only": bounds_fc,
                "hits_restricted_only": bounds_fc,
            }

        hits_cwm = cwms[m]["hits_fc"]
        modisco_cwm = cwms[m]["modisco_fc"]
        hnorm = np.sqrt((hits_cwm**2).sum())
        snorm = np.sqrt((modisco_cwm**2).sum())
        cwm_sim = (hits_cwm * modisco_cwm).sum() / (hnorm * snorm)

        report_data[m]["cwm_similarity"] = cwm_sim

    records = [{"motif_name": k} | v for k, v in report_data.items()]
    report_df = pl.from_dicts(records)

    return report_data, report_df, cwms, cwm_trim_bounds


def seqlet_confusion(
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    motif_names: List[str],
    motif_width: int,
) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]:
    """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.

    This function creates a confusion matrix showing the overlap between
    TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs.
    Overlap frequencies are estimated using binned genomic coordinates.

    Parameters
    ----------
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    seqlets_df : pl.DataFrame
        TF-MoDISco seqlets with required columns:
        - chr_id, start_untrimmed, end_untrimmed, motif_name
    peaks_df : pl.DataFrame
        Peak metadata for joining coordinates:
        - peak_id, chr_id
    motif_names : List[str]
        Names of motifs to include in confusion matrix.
        Determines matrix dimensions.
    motif_width : int
        Width used for binning genomic coordinates.
        Positions are binned to motif_width resolution.

    Returns
    -------
    confusion_df : pl.DataFrame
        Detailed confusion matrix in tabular format with columns:
        - motif_name_seqlets : Seqlet motif labels (rows)
        - motif_name_hits : Hit motif labels (columns)
        - frac_overlap : Fraction of seqlets overlapping with hits
    confusion_mat : Float[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits.
        Rows represent seqlet motifs, columns represent hit motifs.

    Notes
    -----
    - Genomic coordinates are binned to motif_width resolution for overlap detection
    - Only exact bin overlaps are considered (same chr_id, start_bin, end_bin)
    - Fractions are computed as: overlaps / total_seqlets_per_motif
    - Missing motif combinations result in zero entries in the confusion matrix

    Raises
    ------
    ValueError
        If required columns are missing from input DataFrames.
    KeyError
        If motif names in data don't match those in motif_names list.
    """
    bin_size = motif_width

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_binned = (
        hits_df.with_columns(
            peak_id=pl.col("peak_id").cast(pl.UInt32),
            is_revcomp=pl.col("strand") == "-",
        )
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
        .select(
            chr_id=pl.col("chr_id"),
            start_bin=pl.col("start_untrimmed") // bin_size,
            end_bin=pl.col("end_untrimmed") // bin_size,
            motif_name=pl.col("motif_name"),
        )
    )

    seqlets_lazy = seqlets_df.lazy()
    seqlets_binned = seqlets_lazy.select(
        chr_id=pl.col("chr_id"),
        start_bin=pl.col("start_untrimmed") // bin_size,
        end_bin=pl.col("end_untrimmed") // bin_size,
        motif_name=pl.col("motif_name"),
    )

    overlaps_df = seqlets_binned.join(
        hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits"
    )

    seqlet_counts = (
        seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect()
    )
    overlap_counts = (
        overlaps_df.group_by(["motif_name", "motif_name_hits"])
        .len(name="num_overlaps")
        .collect()
    )

    num_motifs = len(motif_names)
    confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32)
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    confusion_df = overlap_counts.join(
        seqlet_counts, on="motif_name", how="inner"
    ).select(
        motif_name_seqlets=pl.col("motif_name"),
        motif_name_hits=pl.col("motif_name_hits"),
        frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"),
    )

    confusion_idx_df = confusion_df.select(
        row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx),
        col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx),
        frac_overlap=pl.col("frac_overlap"),
    )

    row_idx = confusion_idx_df["row_idx"].to_numpy()
    col_idx = confusion_idx_df["col_idx"].to_numpy()
    frac_overlap = confusion_idx_df["frac_overlap"].to_numpy()

    confusion_mat[row_idx, col_idx] = frac_overlap

    return confusion_df, confusion_mat
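The binned overlap counting in `seqlet_confusion` can be illustrated without Polars. The sketch below is a simplified stand-in, not the library's implementation: the records and motif labels are hypothetical, and it skips the de-duplication of hits that the source performs before binning.

```python
from collections import Counter

import numpy as np

motif_names = ["motif_a", "motif_b"]  # hypothetical motif labels
bin_size = 10  # plays the role of motif_width

# Toy records: (chr_id, start_untrimmed, end_untrimmed, motif_name)
seqlets = [
    ("chr1", 100, 110, "motif_a"),
    ("chr1", 205, 215, "motif_a"),
    ("chr2", 50, 60, "motif_b"),
]
hits = [
    ("chr1", 103, 113, "motif_a"),
    ("chr1", 207, 218, "motif_b"),
    ("chr2", 55, 65, "motif_b"),
]


def bin_key(chr_id, start, end):
    # Floor-divide both ends so nearby intervals map to the same key.
    return (chr_id, start // bin_size, end // bin_size)


# Index hit motif labels by binned interval.
hits_by_bin = {}
for chr_id, start, end, name in hits:
    hits_by_bin.setdefault(bin_key(chr_id, start, end), []).append(name)

# Count seqlets per motif, and (seqlet motif, hit motif) pairs sharing a bin key.
seqlet_counts = Counter(name for *_, name in seqlets)
overlap_counts = Counter()
for chr_id, start, end, seqlet_name in seqlets:
    for hit_name in hits_by_bin.get(bin_key(chr_id, start, end), []):
        overlap_counts[(seqlet_name, hit_name)] += 1

# Rows are seqlet motifs, columns are hit motifs.
name_to_idx = {m: i for i, m in enumerate(motif_names)}
confusion_mat = np.zeros((len(motif_names), len(motif_names)), dtype=np.float32)
for (seqlet_name, hit_name), n in overlap_counts.items():
    confusion_mat[name_to_idx[seqlet_name], name_to_idx[hit_name]] = (
        n / seqlet_counts[seqlet_name]
    )
print(confusion_mat)
```

In this toy input, one of the two `motif_a` seqlets shares a bin with a `motif_b` hit, so the first row splits evenly between the two columns. Because matching requires identical start and end bins, two intervals that overlap but straddle a bin edge differently are not counted, which is why the docstring calls the frequencies estimates.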
get_motif_occurences(hits_df, motif_names)
Compute motif occurrence statistics and co-occurrence matrix.
This function analyzes motif occurrence patterns across peaks by creating a pivot table of hit counts and computing pairwise co-occurrence statistics.
Parameters
- hits_df (pl.LazyFrame):
Lazy DataFrame containing hit data with required columns:
- peak_id : Peak identifier
- motif_name : Name of the motif
Additional columns are ignored.
- motif_names (List[str]): List of motif names to include in analysis. Missing motifs will be added as columns with zero counts.
Returns
- occ_df (pl.DataFrame):
DataFrame with motif occurrence counts per peak. Contains:
- peak_id column
- One column per motif with hit counts
- 'total' column summing all motif counts per peak
- coocc (Int[ndarray, "M M"]): Co-occurrence matrix where M = len(motif_names). Entry (i,j) indicates number of peaks containing both motif i and motif j. Diagonal entries show total peaks containing each motif.
Notes
The co-occurrence matrix is computed using binary occurrence indicators, so multiple hits of the same motif in a peak are treated as a single occurrence.
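The binarize-then-multiply step described in the note can be seen on a small example. This is a sketch with hypothetical counts; only the last two operations mirror the source.

```python
import numpy as np

# Toy per-peak hit counts (hypothetical): rows are peaks, columns are motifs.
occ_mat = np.array(
    [
        [2, 0, 1],  # peak 0: two hits of motif 0, one hit of motif 2
        [0, 1, 0],  # peak 1: one hit of motif 1
        [1, 3, 0],  # peak 2: motifs 0 and 1 both present
    ]
)

# Collapse counts to presence/absence so repeated hits count once per peak.
occ_bin = (occ_mat > 0).astype(np.int32)

# Entry (i, j) counts the peaks that contain both motif i and motif j.
coocc = occ_bin.T @ occ_bin
print(coocc)
# [[2 1 1]
#  [1 2 0]
#  [1 0 1]]
```

The diagonal gives the number of peaks containing each motif (2, 2, and 1 here), and the matrix is symmetric by construction.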
get_cwms(regions, positions_df, motif_width)
Extract contribution weight matrices from regions based on hit positions.
This function extracts motif-sized windows from contribution score regions at positions specified by hit coordinates. It handles both forward and reverse complement orientations and filters out invalid positions.
Parameters
- regions (Float[ndarray, "N 4 L"]): Input contribution score regions multiplied by one-hot sequences. Shape: (n_peaks, 4, region_width) where 4 represents DNA bases (A,C,G,T).
- positions_df (pl.DataFrame):
DataFrame containing hit positions with required columns:
- peak_id : int, Peak index (0-based)
- start_untrimmed : int, Start position in genomic coordinates
- peak_region_start : int, Peak region start coordinate
- is_revcomp : bool, Whether hit is on reverse complement strand
- motif_width (int): Width of motifs to extract. Must be positive.
Returns
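- cwms (Float[ndarray, "4 W"]): Mean contribution weight matrix over all valid hits, with shape (4, motif_width). The function averages the extracted windows over the hit axis (`seqs.mean(axis=0)`), so it returns a single matrix rather than one per hit. Hits falling outside the region boundaries are excluded before averaging.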
Notes
- Start positions are converted from genomic to region-relative coordinates
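- Reverse complement hits are read with both the position axis and the base axis reversed, which reverse-complements the window under A,C,G,T ordering
- Hits extending beyond region boundaries are excluded
- The mean is computed across all valid hits, with warnings suppressed for empty slices or invalid operations; with no valid hits the result is filled with NaN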
Raises
- ValueError: If motif_width is non-positive or positions_df lacks required columns.
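The window extraction described in the notes can be sketched with plain NumPy fancy indexing. This is a compact illustration under assumed toy inputs (random scores, two hits that are already region-relative and in bounds). It builds the index arrays with `np.where` instead of the in-place updates used in the source.

```python
import numpy as np

rng = np.random.default_rng(0)
regions = rng.normal(size=(2, 4, 10))  # (n_peaks, 4, region_width), toy values

peak_idx = np.array([0, 1])
start_idx = np.array([2, 5])  # region-relative start positions
is_revcomp = np.array([False, True])
motif_width = 3

offsets = np.arange(motif_width)
bases = np.arange(4)

# Forward hits read positions left to right; reverse-complement hits right to left.
pos_idx = start_idx[:, None] + np.where(is_revcomp[:, None], offsets[::-1], offsets)
# Reverse-complement hits also flip the base axis (A<->T, C<->G under A,C,G,T order).
nuc_idx = np.where(is_revcomp[:, None], bases[::-1], bases)

# Broadcast (H,1,1), (H,4,1), and (H,1,W) index arrays to gather (H, 4, W) windows.
windows = regions[peak_idx[:, None, None], nuc_idx[:, :, None], pos_idx[:, None, :]]

# Average over hits to obtain a single (4, W) matrix.
mean_cwm = windows.mean(axis=0)
print(windows.shape, mean_cwm.shape)
```

For the reverse-complement hit, the gathered window equals the forward slice with both axes flipped, i.e. `regions[1, :, 5:8][::-1, ::-1]`. This is the same flip that `tfmodisco_comparison` applies to obtain `hits_rc` from `hits_fc`.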
tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms_modisco, motif_names, modisco_half_width, motif_width, compute_recall)
Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics.
This function performs comprehensive comparison between Fi-NeMo hit calls and TF-MoDISco seqlets, computing recall metrics, CWM similarities, and extracting contribution weight matrices for visualization.
Parameters
- regions (Float[ndarray, "N 4 L"]): Contribution score regions multiplied by one-hot sequences. Shape: (n_peaks, 4, region_length)
- hits_df (Union[pl.DataFrame, pl.LazyFrame]):
Fi-NeMo hit calls with required columns:
- peak_id, start_untrimmed, end_untrimmed, strand, motif_name
- peaks_df (pl.DataFrame):
Peak metadata with columns:
- peak_id, chr_id, peak_region_start
- seqlets_df (Optional[pl.DataFrame]):
TF-MoDISco seqlets with columns:
- chr_id, start_untrimmed, is_revcomp, motif_name
If None, only basic hit statistics are computed.
- motifs_df (pl.DataFrame):
Motif metadata with columns:
- motif_name, strand, motif_id, motif_start, motif_end
- cwms_modisco (Float[ndarray, "M 4 W"]): TF-MoDISco contribution weight matrices. Shape: (n_modisco_motifs, 4, motif_width)
- motif_names (List[str]): Names of motifs to analyze.
- modisco_half_width (int): Half-width for restricting hits to central region for fair comparison.
- motif_width (int): Width of motifs for CWM extraction.
- compute_recall (bool): Whether to compute recall metrics requiring seqlets_df.
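The effect of modisco_half_width can be illustrated with a small plain-Python sketch of the central-window test. It mirrors the filter condition in the source listing but is not the library's code: the helper name `in_central_window` and the toy coordinates are hypothetical.

```python
def in_central_window(start, end, peak_region_start, region_len, half_width):
    """Return True if a hit lies entirely inside the central window.

    Hypothetical helper: coordinates are made relative to the region start,
    then compared against center +/- half_width, as in the polars filter.
    """
    center = region_len / 2
    rel_start = start - peak_region_start
    rel_end = end - peak_region_start
    return rel_start >= center - half_width and rel_end <= center + half_width


# Toy values (assumed, not from the source): a 2114-bp region, 200-bp half-width.
inside = in_central_window(1000, 1030, 0, 2114, 200)
outside = in_central_window(100, 130, 0, 2114, 200)
```

A hit must fit completely within the window to be counted as restricted; hits that straddle the window edge are excluded.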
Returns
- report_data (Dict[str, Dict[str, Any]]):
Per-motif evaluation metrics including:
- num_hits_total, num_hits_restricted, num_seqlets
- num_overlaps, num_seqlets_only, num_hits_restricted_only
- seqlet_recall, cwm_similarity
- report_df (pl.DataFrame): Tabular format of report_data for easy analysis.
- cwms (Dict[str, Dict[str, Float[ndarray, "4 W"]]]):
Extracted CWMs for each motif and condition:
- hits_fc, hits_rc: Forward/reverse complement hits
- modisco_fc, modisco_rc: TF-MoDISco forward/reverse
- seqlets_only, hits_restricted_only: Non-overlapping instances
- cwm_trim_bounds (Dict[str, Dict[str, Tuple[int, int]]]): Trimming boundaries for each CWM type and motif.
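The overlap counts and seqlet_recall in report_data can be pictured as set operations on the join key (chr_id, start_untrimmed, is_revcomp, motif_name). The following is a sketch in plain Python rather than polars, with made-up coordinates. It assumes, as the source listing suggests, that seqlets-only instances are determined against all hits while overlaps use only the restricted hits.

```python
# Join keys: (chr_id, start_untrimmed, is_revcomp, motif_name). Toy data only.
a = ("chr1", 100, False, "motifA")
b = ("chr1", 250, True, "motifA")
c = ("chr2", 40, False, "motifA")
d = ("chr2", 90, False, "motifA")

hits_all = {a, b, c}         # all unique hits
hits_restricted = {a, b}     # hits inside the central window
seqlets = {a, c, d}          # TF-MoDISco seqlets

overlaps = hits_restricted & seqlets           # inner join
seqlets_only = seqlets - hits_all              # anti join against all hits
hits_restricted_only = hits_restricted - seqlets

# Recall falls back to 0.0 when a motif has no seqlets.
seqlet_recall = len(overlaps) / len(seqlets) if seqlets else 0.0
```

In this toy case seqlet `c` is matched only by a hit outside the central window, so it counts neither as an overlap nor as seqlets-only.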
Notes
- Hits are filtered to central region defined by modisco_half_width
- CWM similarity is computed as normalized dot product between hit and TF-MoDISco CWMs
- Recall metrics are computed only when compute_recall is True and seqlets_df is not None; seqlet_recall is 0.0 for motifs without seqlets
- Motifs without hits or seqlets are handled with empty DataFrames
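The normalized dot product used for cwm_similarity is a cosine similarity over all entries of the two (4, W) matrices. A minimal NumPy sketch follows; the function name `cwm_cosine_similarity` is hypothetical, and the zero-norm guard is an addition of this sketch that does not appear in the source listing.

```python
import numpy as np


def cwm_cosine_similarity(hits_cwm, modisco_cwm):
    """Cosine similarity between two CWMs of identical shape (hypothetical helper)."""
    hits_cwm = np.asarray(hits_cwm, dtype=float)
    modisco_cwm = np.asarray(modisco_cwm, dtype=float)
    hnorm = np.sqrt((hits_cwm**2).sum())
    snorm = np.sqrt((modisco_cwm**2).sum())
    if hnorm == 0 or snorm == 0:
        # Assumption of this sketch: report 0.0 instead of dividing by zero.
        return 0.0
    return float((hits_cwm * modisco_cwm).sum() / (hnorm * snorm))


rng = np.random.default_rng(0)
cwm = rng.normal(size=(4, 30))  # toy CWM with W = 30

same = cwm_cosine_similarity(cwm, cwm)
scaled = cwm_cosine_similarity(cwm, 3.0 * cwm)
flipped = cwm_cosine_similarity(cwm, -cwm)
```

The measure ignores overall scale, so a hit-derived CWM that is a scaled copy of the TF-MoDISco CWM still scores 1.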
Raises
- ValueError: If required columns are missing from input DataFrames.
def seqlet_confusion(
    hits_df: Union[pl.DataFrame, pl.LazyFrame],
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame],
    peaks_df: pl.DataFrame,
    motif_names: List[str],
    motif_width: int,
) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]:
    """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.

    This function creates a confusion matrix showing the overlap between
    TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs.
    Overlap frequencies are estimated using binned genomic coordinates.

    Parameters
    ----------
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Fi-NeMo hit calls with required columns:
        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame]
        TF-MoDISco seqlets with required columns:
        - chr_id, start_untrimmed, end_untrimmed, motif_name
    peaks_df : pl.DataFrame
        Peak metadata for joining coordinates:
        - peak_id, chr_id
    motif_names : List[str]
        Names of motifs to include in confusion matrix.
        Determines matrix dimensions.
    motif_width : int
        Width used for binning genomic coordinates.
        Positions are binned to motif_width resolution.

    Returns
    -------
    confusion_df : pl.DataFrame
        Detailed confusion matrix in tabular format with columns:
        - motif_name_seqlets : Seqlet motif labels (rows)
        - motif_name_hits : Hit motif labels (columns)
        - frac_overlap : Fraction of seqlets overlapping with hits
    confusion_mat : Float[ndarray, "M M"]
        Confusion matrix where M = len(motif_names).
        Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits.
        Rows represent seqlet motifs, columns represent hit motifs.

    Notes
    -----
    - Genomic coordinates are binned to motif_width resolution for overlap detection
    - Only exact bin overlaps are considered (same chr_id, start_bin, end_bin)
    - Fractions are computed as: overlaps / total_seqlets_per_motif
    - Missing motif combinations result in zero entries in the confusion matrix

    Raises
    ------
    ValueError
        If required columns are missing from input DataFrames.
    KeyError
        If motif names in data don't match those in motif_names list.
    """
    bin_size = motif_width

    # Ensure hits_df is LazyFrame for consistent operations
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    # Deduplicate hits, then reduce start and end coordinates to bin indices
    hits_binned = (
        hits_df.with_columns(
            peak_id=pl.col("peak_id").cast(pl.UInt32),
            is_revcomp=pl.col("strand") == "-",
        )
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
        .select(
            chr_id=pl.col("chr_id"),
            start_bin=pl.col("start_untrimmed") // bin_size,
            end_bin=pl.col("end_untrimmed") // bin_size,
            motif_name=pl.col("motif_name"),
        )
    )

    seqlets_lazy = seqlets_df.lazy()
    seqlets_binned = seqlets_lazy.select(
        chr_id=pl.col("chr_id"),
        start_bin=pl.col("start_untrimmed") // bin_size,
        end_bin=pl.col("end_untrimmed") // bin_size,
        motif_name=pl.col("motif_name"),
    )

    # Pair each seqlet with every hit that falls in the same bins
    overlaps_df = seqlets_binned.join(
        hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits"
    )

    seqlet_counts = (
        seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect()
    )
    overlap_counts = (
        overlaps_df.group_by(["motif_name", "motif_name_hits"])
        .len(name="num_overlaps")
        .collect()
    )

    num_motifs = len(motif_names)
    confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32)
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    confusion_df = overlap_counts.join(
        seqlet_counts, on="motif_name", how="inner"
    ).select(
        motif_name_seqlets=pl.col("motif_name"),
        motif_name_hits=pl.col("motif_name_hits"),
        frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"),
    )

    # Map motif names to matrix indices and scatter the fractions into the matrix
    confusion_idx_df = confusion_df.select(
        row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx),
        col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx),
        frac_overlap=pl.col("frac_overlap"),
    )

    row_idx = confusion_idx_df["row_idx"].to_numpy()
    col_idx = confusion_idx_df["col_idx"].to_numpy()
    frac_overlap = confusion_idx_df["frac_overlap"].to_numpy()

    confusion_mat[row_idx, col_idx] = frac_overlap

    return confusion_df, confusion_mat
Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits.
This function creates a confusion matrix showing the overlap between TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs. Overlap frequencies are estimated using binned genomic coordinates.
Parameters
- hits_df (Union[pl.DataFrame, pl.LazyFrame]):
Fi-NeMo hit calls with required columns:
- peak_id, start_untrimmed, end_untrimmed, strand, motif_name
- seqlets_df (Union[pl.DataFrame, pl.LazyFrame]):
TF-MoDISco seqlets with required columns:
- chr_id, start_untrimmed, end_untrimmed, motif_name
- peaks_df (pl.DataFrame):
Peak metadata for joining coordinates:
- peak_id, chr_id
- motif_names (List[str]): Names of motifs to include in confusion matrix. Determines matrix dimensions.
- motif_width (int): Width used for binning genomic coordinates. Positions are binned to motif_width resolution.
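To make the binning by motif_width concrete, the sketch below floor-divides start and end coordinates by the bin size, as the source listing does with `//`. The helper `to_bins` and the coordinates are hypothetical.

```python
def to_bins(start, end, bin_size):
    """Map a (start, end) interval to its (start_bin, end_bin) pair (hypothetical helper)."""
    return start // bin_size, end // bin_size


bin_size = 30  # stands in for motif_width

seqlet_bins = to_bins(1200, 1230, bin_size)
hit_close = to_bins(1205, 1235, bin_size)    # shifted by 5 bp, same bins
hit_shifted = to_bins(1195, 1225, bin_size)  # shifted by -5 bp, crosses a bin edge
```

Because matching is on bin indices, two intervals that overlap by most of their length can still land in different bins, which is why the resulting overlap frequencies are estimates.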
Returns
- confusion_df (pl.DataFrame):
Detailed confusion matrix in tabular format with columns:
- motif_name_seqlets : Seqlet motif labels (rows)
- motif_name_hits : Hit motif labels (columns)
- frac_overlap : Fraction of seqlets overlapping with hits
- confusion_mat (Float[ndarray, "M M"]): Confusion matrix where M = len(motif_names). Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits. Rows represent seqlet motifs, columns represent hit motifs.
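The relationship between confusion_df and confusion_mat can be sketched with NumPy integer-array indexing: each (seqlet motif, hit motif, fraction) record is scattered into a zero-initialized matrix. The records below are invented for illustration and plain tuples stand in for the DataFrame rows.

```python
import numpy as np

motif_names = ["motifA", "motifB", "motifC"]
name_to_idx = {m: i for i, m in enumerate(motif_names)}

# Toy records: (motif_name_seqlets, motif_name_hits, frac_overlap)
records = [
    ("motifA", "motifA", 0.8),
    ("motifA", "motifB", 0.1),
    ("motifB", "motifB", 0.5),
]

confusion_mat = np.zeros((len(motif_names), len(motif_names)), dtype=np.float32)
row_idx = [name_to_idx[seqlet_motif] for seqlet_motif, _, _ in records]
col_idx = [name_to_idx[hit_motif] for _, hit_motif, _ in records]
fracs = [frac for _, _, frac in records]

# Pairs that never occur keep their initial value of zero.
confusion_mat[row_idx, col_idx] = fracs
```

Rows index seqlet motifs and columns index hit motifs, so the row for motifC stays all zeros here because no record mentions it.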
Notes
- Genomic coordinates are binned to motif_width resolution for overlap detection
- Only exact bin overlaps are considered (same chr_id, start_bin, end_bin)
- Fractions are computed as: overlaps / total_seqlets_per_motif
- Missing motif combinations result in zero entries in the confusion matrix
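The steps in these notes can be sketched in plain Python instead of polars: seqlets and hits are matched on (chr_id, start_bin, end_bin), matches are counted per motif pair, and each count is divided by the number of seqlets of the row motif. All data below is invented.

```python
from collections import Counter, defaultdict

# Toy binned records: (chr_id, start_bin, end_bin, motif_name)
seqlets = [
    ("chr1", 40, 41, "motifA"),
    ("chr1", 80, 81, "motifA"),
    ("chr2", 10, 11, "motifB"),
]
hits = [
    ("chr1", 40, 41, "motifA"),
    ("chr1", 80, 81, "motifB"),
    ("chr2", 10, 11, "motifB"),
]

# Index hit motifs by bin key so each seqlet can look up its matches.
hits_by_bin = defaultdict(list)
for chr_id, start_bin, end_bin, motif in hits:
    hits_by_bin[(chr_id, start_bin, end_bin)].append(motif)

num_seqlets = Counter(motif for (_c, _s, _e, motif) in seqlets)

# Emulate the inner join: one count per (seqlet, matching hit) pair.
num_overlaps = Counter()
for chr_id, start_bin, end_bin, seqlet_motif in seqlets:
    for hit_motif in hits_by_bin.get((chr_id, start_bin, end_bin), []):
        num_overlaps[(seqlet_motif, hit_motif)] += 1

frac_overlap = {
    pair: count / num_seqlets[pair[0]] for pair, count in num_overlaps.items()
}
```

Since every matching hit contributes one count, a seqlet whose bins contain several hits of the same motif is counted more than once, so a fraction computed this way is not strictly bounded by 1.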
Raises
- ValueError: If required columns are missing from input DataFrames.
- KeyError: If motif names in data don't match those in motif_names list.