finemo.postprocessing
Post-processing utilities for Fi-NeMo hit calling results.
This module provides functions for:
- Collapsing overlapping hits based on similarity scores
- Intersecting hit sets across multiple runs
- Quality control and filtering operations
The main operations are optimized using Numba for efficient processing of large hit datasets.
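A minimal end-to-end sketch of the typical workflow (the Parquet paths below are placeholders; this assumes hit tables with the columns documented for each function):

>>> import polars as pl
>>> from finemo.postprocessing import collapse_hits, intersect_hits
>>> hits_a = pl.scan_parquet("run_a_hits.parquet")  # placeholder path
>>> hits_b = pl.scan_parquet("run_b_hits.parquet")  # placeholder path
>>> collapsed = collapse_hits(hits_a, overlap_frac=0.2)
>>> primary = collapsed.filter(pl.col("is_primary") == 1)
>>> common = intersect_hits([hits_a, hits_b], relaxed=False)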
1"""Post-processing utilities for Fi-NeMo hit calling results. 2 3This module provides functions for: 4- Collapsing overlapping hits based on similarity scores 5- Intersecting hit sets across multiple runs 6- Quality control and filtering operations 7 8The main operations are optimized using Numba for efficient processing 9of large hit datasets. 10""" 11 12import heapq 13from typing import List, Union 14 15import numpy as np 16from numpy import ndarray 17import polars as pl 18from numba import njit 19from numba.types import Array, uint32, int32, float32 # type: ignore[attr-defined] 20from jaxtyping import Float, Int 21 22 23@njit( 24 uint32[:]( 25 Array(uint32, 1, "C", readonly=True), 26 Array(int32, 1, "C", readonly=True), 27 Array(int32, 1, "C", readonly=True), 28 Array(float32, 1, "C", readonly=True), 29 ), 30 cache=True, 31) 32def _collapse_hits( 33 chrom_ids: Int[ndarray, " N"], 34 starts: Int[ndarray, " N"], 35 ends: Int[ndarray, " N"], 36 similarities: Float[ndarray, " N"], 37) -> Int[ndarray, " N"]: 38 """Identify primary hits among overlapping hits using a sweep line algorithm. 39 40 This function uses a heap-based sweep line algorithm to efficiently identify 41 the best hit (highest similarity) among sets of overlapping hits within each 42 chromosome. Only one hit per overlapping group is marked as primary. 43 44 Parameters 45 ---------- 46 chrom_ids : Int[ndarray, "N"] 47 Chromosome identifiers for each hit, where N is the number of hits. 48 Dtype should be uint32 for Numba compatibility. 49 starts : Int[ndarray, "N"] 50 Start positions of hits (adjusted for overlap computation). 51 Dtype should be int32 for Numba compatibility. 52 ends : Int[ndarray, "N"] 53 End positions of hits (adjusted for overlap computation). 54 Dtype should be int32 for Numba compatibility. 55 similarities : Float[ndarray, "N"] 56 Similarity scores used for selecting the best hit. 57 Dtype should be float32 for Numba compatibility. 58 59 Returns 60 ------- 61 Int[ndarray, "N"] 62 Binary array where 1 indicates the hit is primary, 0 otherwise. 63 Returns uint32 array for consistency with input types. 64 65 Notes 66 ----- 67 This function is JIT-compiled with Numba for performance on large datasets. 68 The algorithm maintains active intervals in a heap and resolves overlaps 69 by keeping only the hit with the highest similarity score. 70 71 The sweep line algorithm processes hits in order and maintains a heap of 72 currently active intervals. When a new interval is encountered, it is 73 compared against all overlapping intervals in the heap, and only the 74 interval with the highest similarity score remains marked as primary. 75 """ 76 n = chrom_ids.shape[0] 77 out = np.ones(n, dtype=np.uint32) 78 heap = [(np.uint32(0), np.int32(0), -1) for _ in range(0)] 79 80 for i in range(n): 81 chrom_new = chrom_ids[i] 82 start_new = starts[i] 83 end_new = ends[i] 84 sim_new = similarities[i] 85 86 # Remove expired intervals from heap 87 while heap and heap[0] < (chrom_new, start_new, -1): 88 heapq.heappop(heap) 89 90 # Check overlaps with active intervals 91 for _, _, idx in heap: 92 cmp = sim_new > similarities[idx] 93 out[idx] &= cmp 94 out[i] &= not cmp 95 96 # Add current interval to heap 97 heapq.heappush(heap, (chrom_new, end_new, i)) 98 99 return out 100 101 102def collapse_hits( 103 hits_df: Union[pl.DataFrame, pl.LazyFrame], overlap_frac: float 104) -> pl.DataFrame: 105 """Collapse overlapping hits by selecting the best hit per overlapping group. 
106 107 This function identifies overlapping hits and marks only the highest-similarity 108 hit as primary in each overlapping group. Overlap is determined by a fractional 109 threshold based on the average length of the two hits being compared. 110 111 Parameters 112 ---------- 113 hits_df : Union[pl.DataFrame, pl.LazyFrame] 114 Hit data containing required columns: chr (or peak_id if no chr), start, end, 115 hit_similarity. Will be collected to DataFrame if passed as LazyFrame. 116 overlap_frac : float 117 Overlap fraction threshold for considering hits as overlapping. 118 For two hits with lengths x and y, minimum overlap = overlap_frac * (x + y) / 2. 119 Must be between 0 and 1, where 0 means any overlap and 1 means complete overlap. 120 121 Returns 122 ------- 123 pl.DataFrame 124 Original hit data with an additional 'is_primary' column (1 for primary hits, 0 otherwise). 125 All original columns are preserved, with the new column added at the end. 126 127 Raises 128 ------ 129 KeyError 130 If required columns (chr/peak_id, start, end, hit_similarity) are missing. 131 132 Notes 133 ----- 134 The algorithm transforms coordinates by scaling by 2 and adjusting by the overlap 135 fraction to create effective overlap regions for efficient processing. This allows 136 using a sweep line algorithm to identify overlaps in a single pass. 137 138 The transformation works as follows: 139 - Original coordinates: [start, end] 140 - Length = end - start 141 - Adjusted start = start * 2 + length * overlap_frac 142 - Adjusted end = end * 2 - length * overlap_frac 143 144 This creates regions that overlap only when the original regions have sufficient 145 overlap according to the specified fraction. 146 147 Examples 148 -------- 149 >>> hits_collapsed = collapse_hits(hits_df, overlap_frac=0.2) 150 >>> primary_hits = hits_collapsed.filter(pl.col("is_primary") == 1) 151 >>> print(f"Kept {primary_hits.height}/{hits_df.height} hits as primary") 152 """ 153 # Ensure we're working with a DataFrame 154 if isinstance(hits_df, pl.LazyFrame): 155 hits_df = hits_df.collect() 156 157 chroms = hits_df["chr"].unique(maintain_order=True) 158 159 if not chroms.is_empty(): 160 chrom_to_id = {chrom: i for i, chrom in enumerate(chroms)} 161 # Transform coordinates for overlap computation 162 # Scale by 2 and adjust by overlap fraction to create effective overlap regions 163 df = hits_df.select( 164 chrom_id=pl.col("chr").replace_strict(chrom_to_id, return_dtype=pl.UInt32), 165 start_trim=pl.col("start") * 2 166 + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), 167 end_trim=pl.col("end") * 2 168 - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), 169 similarity=pl.col("hit_similarity"), 170 ) 171 else: 172 # Fall back to peak_id when chr column is not available 173 df = hits_df.select( 174 chrom_id=pl.col("peak_id"), 175 start_trim=pl.col("start") * 2 176 + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), 177 end_trim=pl.col("end") * 2 178 - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), 179 similarity=pl.col("hit_similarity"), 180 ) 181 182 # Rechunk for efficient array access 183 df = df.rechunk() 184 chrom_ids = df["chrom_id"].to_numpy(allow_copy=False) 185 starts = df["start_trim"].to_numpy(allow_copy=False) 186 ends = df["end_trim"].to_numpy(allow_copy=False) 187 similarities = df["similarity"].to_numpy(allow_copy=False) 188 189 # Run the collapse algorithm 190 is_primary = _collapse_hits(chrom_ids, starts, ends, similarities) 191 192 # Add 
primary indicator column to original DataFrame 193 df_out = hits_df.with_columns(is_primary=pl.Series(is_primary, dtype=pl.UInt32)) 194 195 return df_out 196 197 198def intersect_hits( 199 hits_dfs: List[Union[pl.DataFrame, pl.LazyFrame]], relaxed: bool 200) -> pl.DataFrame: 201 """Intersect hit datasets across multiple runs to find common hits. 202 203 This function finds hits that appear consistently across multiple Fi-NeMo 204 runs, which can be useful for identifying robust motif instances that are 205 not sensitive to parameter variations or random initialization. 206 207 Parameters 208 ---------- 209 hits_dfs : List[Union[pl.DataFrame, pl.LazyFrame]] 210 List of hit DataFrames from different Fi-NeMo runs. Each DataFrame must 211 contain the columns specified by the intersection criteria. LazyFrames 212 will be collected before processing. 213 relaxed : bool 214 If True, uses relaxed intersection criteria with only motif names and 215 untrimmed coordinates. If False, uses strict criteria including all 216 coordinate and metadata columns. 217 218 Returns 219 ------- 220 pl.DataFrame 221 DataFrame containing hits that appear in all input datasets. 222 Columns from later datasets are suffixed with their index (e.g., '_1', '_2'). 223 The first dataset's columns retain their original names. 224 225 Raises 226 ------ 227 ValueError 228 If fewer than one hits DataFrame is provided. 229 KeyError 230 If required columns for the specified intersection criteria are missing 231 from any of the input DataFrames. 232 233 Notes 234 ----- 235 Relaxed intersection is useful when comparing results across different 236 region definitions or motif trimming parameters, but may produce less 237 precise matches. Strict intersection requires identical region definitions 238 and is recommended for most use cases. 239 240 The intersection columns used are: 241 - Relaxed: ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"] 242 - Strict: ["chr", "start", "end", "start_untrimmed", "end_untrimmed", 243 "motif_name", "strand", "peak_name", "peak_id"] 244 245 The function performs successive inner joins starting with the first DataFrame, 246 so the final result contains only hits present in all input datasets. 
247 248 Examples 249 -------- 250 >>> common_hits = intersect_hits([hits_df1, hits_df2], relaxed=False) 251 >>> print(f"Found {common_hits.height} hits common to both runs") 252 >>> 253 >>> # Compare relaxed vs strict intersection 254 >>> relaxed_hits = intersect_hits([hits_df1, hits_df2], relaxed=True) 255 >>> strict_hits = intersect_hits([hits_df1, hits_df2], relaxed=False) 256 >>> print(f"Relaxed: {relaxed_hits.height}, Strict: {strict_hits.height}") 257 """ 258 if relaxed: 259 # Relaxed criteria: only motif identity and untrimmed positions 260 join_cols = ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"] 261 else: 262 # Strict criteria: all coordinate and metadata columns 263 join_cols = [ 264 "chr", 265 "start", 266 "end", 267 "start_untrimmed", 268 "end_untrimmed", 269 "motif_name", 270 "strand", 271 "peak_name", 272 "peak_id", 273 ] 274 275 if len(hits_dfs) < 1: 276 raise ValueError("At least one hits dataframe required") 277 278 # Ensure all DataFrames are collected 279 collected_dfs = [] 280 for df in hits_dfs: 281 if isinstance(df, pl.LazyFrame): 282 collected_dfs.append(df.collect()) 283 else: 284 collected_dfs.append(df) 285 286 # Start with first DataFrame and successively intersect with others 287 hits_df = collected_dfs[0] 288 for i in range(1, len(collected_dfs)): 289 hits_df = hits_df.join( 290 collected_dfs[i], 291 on=join_cols, 292 how="inner", 293 suffix=f"_{i}", 294 join_nulls=True, 295 coalesce=True, 296 ) 297 298 return hits_df
def collapse_hits(hits_df: Union[pl.DataFrame, pl.LazyFrame], overlap_frac: float) -> pl.DataFrame
Collapse overlapping hits by selecting the best hit per overlapping group.
This function identifies overlapping hits and marks only the highest-similarity hit as primary in each overlapping group. Overlap is determined by a fractional threshold based on the average length of the two hits being compared.
Parameters
- hits_df (Union[pl.DataFrame, pl.LazyFrame]): Hit data containing required columns: chr (or peak_id if no chr), start, end, hit_similarity. Will be collected to DataFrame if passed as LazyFrame.
- overlap_frac (float): Overlap fraction threshold for considering hits as overlapping. For two hits with lengths x and y, minimum overlap = overlap_frac * (x + y) / 2. Must be between 0 and 1, where 0 means any overlap and 1 means complete overlap.
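For example, with overlap_frac = 0.2, two 20 bp hits are treated as overlapping only if they share at least 0.2 * (20 + 20) / 2 = 4 bp.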
Returns
- pl.DataFrame: Original hit data with an additional 'is_primary' column (1 for primary hits, 0 otherwise). All original columns are preserved, with the new column added at the end.
Raises
- KeyError: If required columns (chr/peak_id, start, end, hit_similarity) are missing.
Notes
The algorithm transforms coordinates by scaling by 2 and adjusting by the overlap fraction to create effective overlap regions for efficient processing. This allows using a sweep line algorithm to identify overlaps in a single pass.
The transformation works as follows:
- Original coordinates: [start, end]
- Length = end - start
- Adjusted start = start * 2 + length * overlap_frac
- Adjusted end = end * 2 - length * overlap_frac
This creates regions that overlap only when the original regions have sufficient overlap according to the specified fraction.
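As a worked example, take overlap_frac = 0.2 and two hits A = [100, 120] and B = [110, 130], each 20 bp long and sharing 10 bp. The adjustment term is 0.2 * 20 = 4, so the transformed intervals are A' = [204, 236] and B' = [224, 256], which overlap, and the two hits are grouped. With overlap_frac = 0.6 the adjustment is 12, giving A' = [212, 228] and B' = [232, 248], which do not overlap: the shared 10 bp falls short of the required 0.6 * (20 + 20) / 2 = 12 bp.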
Examples
>>> hits_collapsed = collapse_hits(hits_df, overlap_frac=0.2)
>>> primary_hits = hits_collapsed.filter(pl.col("is_primary") == 1)
>>> print(f"Kept {primary_hits.height}/{hits_df.height} hits as primary")
def intersect_hits(hits_dfs: List[Union[pl.DataFrame, pl.LazyFrame]], relaxed: bool) -> pl.DataFrame
Intersect hit datasets across multiple runs to find common hits.
This function finds hits that appear consistently across multiple Fi-NeMo runs, which can be useful for identifying robust motif instances that are not sensitive to parameter variations or random initialization.
Parameters
- hits_dfs (List[Union[pl.DataFrame, pl.LazyFrame]]): List of hit DataFrames from different Fi-NeMo runs. Each DataFrame must contain the columns specified by the intersection criteria. LazyFrames will be collected before processing.
- relaxed (bool): If True, uses relaxed intersection criteria with only motif names and untrimmed coordinates. If False, uses strict criteria including all coordinate and metadata columns.
Returns
- pl.DataFrame: DataFrame containing hits that appear in all input datasets. Columns from later datasets are suffixed with their index (e.g., '_1', '_2'). The first dataset's columns retain their original names.
Raises
- ValueError: If fewer than one hits DataFrame is provided.
- KeyError: If required columns for the specified intersection criteria are missing from any of the input DataFrames.
Notes
Relaxed intersection is useful when comparing results across different region definitions or motif trimming parameters, but may produce less precise matches. Strict intersection requires identical region definitions and is recommended for most use cases.
The intersection columns used are:
- Relaxed: ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"]
- Strict: ["chr", "start", "end", "start_untrimmed", "end_untrimmed", "motif_name", "strand", "peak_name", "peak_id"]
The function performs successive inner joins starting with the first DataFrame, so the final result contains only hits present in all input datasets.
Examples
>>> common_hits = intersect_hits([hits_df1, hits_df2], relaxed=False)
>>> print(f"Found {common_hits.height} hits common to both runs")
>>>
>>> # Compare relaxed vs strict intersection
>>> relaxed_hits = intersect_hits([hits_df1, hits_df2], relaxed=True)
>>> strict_hits = intersect_hits([hits_df1, hits_df2], relaxed=False)
>>> print(f"Relaxed: {relaxed_hits.height}, Strict: {strict_hits.height}")