finemo.postprocessing

Post-processing utilities for Fi-NeMo hit calling results.

This module provides functions for:

  • Collapsing overlapping hits based on similarity scores
  • Intersecting hit sets across multiple runs
  • Quality control and filtering operations

The overlap-collapsing sweep is JIT-compiled with Numba for efficient processing of large hit datasets; intersection uses Polars joins.

"""Post-processing utilities for Fi-NeMo hit calling results.

This module provides functions for:
- Collapsing overlapping hits based on similarity scores
- Intersecting hit sets across multiple runs
- Quality control and filtering operations

The overlap-collapsing sweep is JIT-compiled with Numba for efficient
processing of large hit datasets; intersection uses Polars joins.
"""

import heapq
from typing import List, Union

import numpy as np
from numpy import ndarray
import polars as pl
from numba import njit
from numba.types import Array, uint32, int32, float32  # type: ignore[attr-defined]
from jaxtyping import Float, Int


@njit(
    uint32[:](
        Array(uint32, 1, "C", readonly=True),
        Array(int32, 1, "C", readonly=True),
        Array(int32, 1, "C", readonly=True),
        Array(float32, 1, "C", readonly=True),
    ),
    cache=True,
)
def _collapse_hits(
    chrom_ids: Int[ndarray, " N"],
    starts: Int[ndarray, " N"],
    ends: Int[ndarray, " N"],
    similarities: Float[ndarray, " N"],
) -> Int[ndarray, " N"]:
    """Identify primary hits among overlapping hits using a sweep line algorithm.

    This function uses a heap-based sweep line algorithm to efficiently identify
    the best hit (highest similarity) among sets of overlapping hits within each
    chromosome. A hit is marked non-primary whenever it overlaps a hit with a
    higher similarity score; on ties, the earlier hit is kept.

    Parameters
    ----------
    chrom_ids : Int[ndarray, "N"]
        Chromosome identifiers for each hit, where N is the number of hits.
        Dtype must be uint32 to match the Numba signature.
    starts : Int[ndarray, "N"]
        Start positions of hits (adjusted for overlap computation).
        Dtype must be int32 to match the Numba signature.
    ends : Int[ndarray, "N"]
        End positions of hits (adjusted for overlap computation).
        Dtype must be int32 to match the Numba signature.
    similarities : Float[ndarray, "N"]
        Similarity scores used for selecting the best hit.
        Dtype must be float32 to match the Numba signature.

    Returns
    -------
    Int[ndarray, "N"]
        Binary array where 1 indicates the hit is primary, 0 otherwise.
        Returns uint32 array for consistency with input types.

    Notes
    -----
    This function is JIT-compiled with Numba for performance on large datasets.

    Hits must be ordered by chrom_id and then by the adjusted start values,
    because intervals are expired from the heap by comparing their end with
    the current start. The heap holds (chrom_id, end, index) entries for the
    active intervals; after expiry, every remaining entry overlaps the new
    interval. The new interval is compared against each of them, and the
    lower-scoring hit of every pair is marked non-primary. An interval whose
    end equals the new start is not expired, so touching intervals count as
    overlapping.
    """
    n = chrom_ids.shape[0]
    out = np.ones(n, dtype=np.uint32)
    # Empty list built from a typed comprehension so Numba can infer the
    # (uint32, int32, int) element type of the heap
    heap = [(np.uint32(0), np.int32(0), -1) for _ in range(0)]

    for i in range(n):
        chrom_new = chrom_ids[i]
        start_new = starts[i]
        end_new = ends[i]
        sim_new = similarities[i]

        # Remove expired intervals (earlier chromosome, or end < start_new)
        while heap and heap[0] < (chrom_new, start_new, -1):
            heapq.heappop(heap)

        # Every remaining interval overlaps the new one; demote the
        # lower-scoring hit of each pair (ties keep the earlier hit)
        for _, _, idx in heap:
            cmp = sim_new > similarities[idx]
            out[idx] &= not cmp
            out[i] &= cmp

        # Add current interval to heap
        heapq.heappush(heap, (chrom_new, end_new, i))

    return out


def collapse_hits(
    hits_df: Union[pl.DataFrame, pl.LazyFrame], overlap_frac: float
) -> pl.DataFrame:
    """Collapse overlapping hits by selecting the best hit per overlapping group.

    This function identifies overlapping hits and keeps as primary only the
    highest-similarity member of each set of mutually overlapping hits; any hit
    that overlaps a better-scoring hit is marked non-primary. Overlap is determined
    by a fractional threshold based on the average length of the two hits being
    compared.

    Parameters
    ----------
    hits_df : Union[pl.DataFrame, pl.LazyFrame]
        Hit data containing required columns: chr (or peak_id if no chr), start, end,
        hit_similarity. Will be collected to DataFrame if passed as LazyFrame.
    overlap_frac : float
        Overlap fraction threshold for considering hits as overlapping.
        For two hits with lengths x and y, minimum overlap = overlap_frac * (x + y) / 2.
        Must be between 0 and 1, where 0 means any overlap and 1 means complete overlap.

    Returns
    -------
    pl.DataFrame
        Original hit data with an additional 'is_primary' column (1 for primary hits, 0 otherwise).
        All original columns are preserved, with the new column added at the end.

    Raises
    ------
    KeyError
        If required columns (chr/peak_id, start, end, hit_similarity) are missing.

    Notes
    -----
    The algorithm doubles all coordinates and trims each doubled interval by
    length * overlap_frac at both ends, which turns the fractional-overlap
    criterion into a plain interval-overlap test. A sweep line algorithm can then
    identify all qualifying overlaps in a single pass.

    The transformation works as follows:
    - Original coordinates: [start, end]
    - Length = end - start
    - Adjusted start = start * 2 + int(length * overlap_frac)
    - Adjusted end = end * 2 - int(length * overlap_frac)

    Ignoring rounding of the trim term to an integer, two adjusted regions overlap
    exactly when the midpoints of the original hits are at most
    (1 - overlap_frac) * (x + y) / 2 apart. When neither hit contains the other,
    this is equivalent to requiring an overlap of at least
    overlap_frac * (x + y) / 2; when one hit lies inside the other, the test can
    accept hits that fall short of that overlap.

    Examples
    --------
    >>> hits_collapsed = collapse_hits(hits_df, overlap_frac=0.2)
    >>> primary_hits = hits_collapsed.filter(pl.col("is_primary") == 1)
    >>> print(f"Kept {primary_hits.height}/{hits_df.height} hits as primary")
    """
    # Ensure we're working with a DataFrame
    if isinstance(hits_df, pl.LazyFrame):
        hits_df = hits_df.collect()

    chroms = hits_df["chr"].unique(maintain_order=True)

    if not chroms.is_empty():
        chrom_to_id = {chrom: i for i, chrom in enumerate(chroms)}
        # Transform coordinates for overlap computation
        # Scale by 2 and adjust by overlap fraction to create effective overlap regions
        df = hits_df.select(
            chrom_id=pl.col("chr").replace_strict(chrom_to_id, return_dtype=pl.UInt32),
            start_trim=pl.col("start") * 2
            + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32),
            end_trim=pl.col("end") * 2
            - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32),
            similarity=pl.col("hit_similarity"),
        )
    else:
        # Fall back to peak_id when the chr column yields no values to map
        df = hits_df.select(
            chrom_id=pl.col("peak_id"),
            start_trim=pl.col("start") * 2
            + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32),
            end_trim=pl.col("end") * 2
            - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32),
            similarity=pl.col("hit_similarity"),
        )

    # Rechunk so each column is contiguous and can be exported without copying
    df = df.rechunk()
    chrom_ids = df["chrom_id"].to_numpy(allow_copy=False)
    starts = df["start_trim"].to_numpy(allow_copy=False)
    ends = df["end_trim"].to_numpy(allow_copy=False)
    similarities = df["similarity"].to_numpy(allow_copy=False)

    # Run the collapse algorithm
    is_primary = _collapse_hits(chrom_ids, starts, ends, similarities)

    # Add primary indicator column to original DataFrame
    df_out = hits_df.with_columns(is_primary=pl.Series(is_primary, dtype=pl.UInt32))

    return df_out


def intersect_hits(
    hits_dfs: List[Union[pl.DataFrame, pl.LazyFrame]], relaxed: bool
) -> pl.DataFrame:
    """Intersect hit datasets across multiple runs to find common hits.

    This function finds hits that appear consistently across multiple Fi-NeMo
    runs, which can be useful for identifying robust motif instances that are
    not sensitive to parameter variations or random initialization.

    Parameters
    ----------
    hits_dfs : List[Union[pl.DataFrame, pl.LazyFrame]]
        List of hit DataFrames from different Fi-NeMo runs. Each DataFrame must
        contain the columns specified by the intersection criteria. LazyFrames
        will be collected before processing.
    relaxed : bool
        If True, uses relaxed intersection criteria, matching only on chromosome,
        untrimmed coordinates, motif name, and strand. If False, uses strict
        criteria that additionally require matching trimmed coordinates and
        peak identifiers (peak_name, peak_id).

    Returns
    -------
    pl.DataFrame
        DataFrame containing hits that appear in all input datasets.
        Join-key columns are coalesced and appear once; other columns from later
        datasets are suffixed with their list index when their names clash
        (e.g., '_1' for the second dataset, '_2' for the third).
        The first dataset's columns retain their original names.

    Raises
    ------
    ValueError
        If hits_dfs is empty.
    KeyError
        If required columns for the specified intersection criteria are missing
        from any of the input DataFrames.

    Notes
    -----
    Relaxed intersection is useful when comparing results across different
    region definitions or motif trimming parameters, but it can match hits
    whose trimmed coordinates or peak assignments differ. Strict intersection
    requires identical region definitions and is recommended for most use cases.

    The intersection columns used are:
    - Relaxed: ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"]
    - Strict: ["chr", "start", "end", "start_untrimmed", "end_untrimmed",
               "motif_name", "strand", "peak_name", "peak_id"]

    The function performs successive inner joins starting with the first DataFrame,
    so the final result contains only hits present in all input datasets. If a
    key occurs more than once in a dataset, the join emits one row per matching
    combination.

    Examples
    --------
    >>> common_hits = intersect_hits([hits_df1, hits_df2], relaxed=False)
    >>> print(f"Found {common_hits.height} hits common to both runs")
    >>>
    >>> # Compare relaxed vs strict intersection
    >>> relaxed_hits = intersect_hits([hits_df1, hits_df2], relaxed=True)
    >>> strict_hits = intersect_hits([hits_df1, hits_df2], relaxed=False)
    >>> print(f"Relaxed: {relaxed_hits.height}, Strict: {strict_hits.height}")
    """
    if relaxed:
        # Relaxed criteria: only motif identity and untrimmed positions
        join_cols = ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"]
    else:
        # Strict criteria: all coordinate and metadata columns
        join_cols = [
            "chr",
            "start",
            "end",
            "start_untrimmed",
            "end_untrimmed",
            "motif_name",
            "strand",
            "peak_name",
            "peak_id",
        ]

    if len(hits_dfs) < 1:
        raise ValueError("At least one hits dataframe required")

    # Ensure all DataFrames are collected
    collected_dfs = []
    for df in hits_dfs:
        if isinstance(df, pl.LazyFrame):
            collected_dfs.append(df.collect())
        else:
            collected_dfs.append(df)

    # Start with first DataFrame and successively intersect with others
    hits_df = collected_dfs[0]
    for i in range(1, len(collected_dfs)):
        hits_df = hits_df.join(
            collected_dfs[i],
            on=join_cols,
            how="inner",
            suffix=f"_{i}",
            join_nulls=True,
            coalesce=True,
        )

    return hits_df
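
For readers who want to trace the sweep by hand, here is a plain-Python sketch of the heap-based procedure described in the _collapse_hits docstring, without Numba. The name collapse_sorted is hypothetical and not part of finemo. It assumes the start and end values are already in the adjusted (doubled, trimmed) coordinate space and that hits are sorted by chromosome ID and adjusted start, which the heap-expiry step relies on. It demotes the lower-scoring hit of every overlapping pair and keeps the earlier hit on ties.

```python
import heapq
from typing import List, Sequence, Tuple


def collapse_sorted(
    chrom_ids: Sequence[int],
    starts: Sequence[int],
    ends: Sequence[int],
    similarities: Sequence[float],
) -> List[int]:
    """Return 1 for primary hits and 0 otherwise (pure-Python sketch).

    Assumes starts/ends are already adjusted and that hits are sorted by
    (chrom_id, start); the heap-expiry step is only valid in that order.
    """
    n = len(chrom_ids)
    out = [1] * n
    # Min-heap of (chrom_id, end, index) for intervals that may still overlap.
    heap: List[Tuple[int, int, int]] = []
    for i in range(n):
        # Pop intervals on an earlier chromosome or ending before this start.
        # An end equal to the start is kept, so touching intervals overlap.
        key = (chrom_ids[i], starts[i], -1)
        while heap and heap[0] < key:
            heapq.heappop(heap)
        # Every remaining interval overlaps hit i; demote the lower scorer.
        for _, _, j in heap:
            if similarities[i] > similarities[j]:
                out[j] = 0
            else:
                out[i] = 0  # ties keep the earlier hit
        heapq.heappush(heap, (chrom_ids[i], ends[i], i))
    return out
```

For example, three hits on one chromosome spanning [0, 10], [8, 20] and [18, 30] with similarities 3, 1 and 2 give [1, 0, 1]: the middle hit loses to both neighbours, while the outer two never overlap each other and both stay primary. So "one primary hit per overlapping group" holds for sets of mutually overlapping hits, not for chains.
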
def collapse_hits(hits_df: Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame], overlap_frac: float) -> polars.dataframe.frame.DataFrame:

Collapse overlapping hits by selecting the best hit per overlapping group.

This function identifies overlapping hits and keeps as primary only the highest-similarity member of each set of mutually overlapping hits; any hit that overlaps a better-scoring hit is marked non-primary. Overlap is determined by a fractional threshold based on the average length of the two hits being compared.

Parameters
  • hits_df (Union[pl.DataFrame, pl.LazyFrame]): Hit data containing required columns: chr (or peak_id if no chr), start, end, hit_similarity. Will be collected to DataFrame if passed as LazyFrame.
  • overlap_frac (float): Overlap fraction threshold for considering hits as overlapping. For two hits with lengths x and y, minimum overlap = overlap_frac * (x + y) / 2. Must be between 0 and 1, where 0 means any overlap and 1 means complete overlap.
Returns
  • pl.DataFrame: Original hit data with an additional 'is_primary' column (1 for primary hits, 0 otherwise). All original columns are preserved, with the new column added at the end.
Raises
  • KeyError: If required columns (chr/peak_id, start, end, hit_similarity) are missing.
Notes

The algorithm doubles all coordinates and trims each doubled interval by length * overlap_frac at both ends, which turns the fractional-overlap criterion into a plain interval-overlap test. A sweep line algorithm can then identify all qualifying overlaps in a single pass.

The transformation works as follows:

  • Original coordinates: [start, end]
  • Length = end - start
  • Adjusted start = start * 2 + int(length * overlap_frac)
  • Adjusted end = end * 2 - int(length * overlap_frac)

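Ignoring rounding of the trim term to an integer, two adjusted regions overlap exactly when the midpoints of the original hits are at most (1 - overlap_frac) * (x + y) / 2 apart. When neither hit contains the other, this is equivalent to requiring an overlap of at least overlap_frac * (x + y) / 2; when one hit lies inside the other, the test can accept hits that fall short of that overlap.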
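
The relationship between the adjusted regions and the minimum-overlap formula can be checked numerically with a short sketch. The helpers transform, overlaps_after_transform and meets_min_overlap are hypothetical illustrations, not finemo functions. transform mirrors the adjusted-coordinate formulas, with int() assumed to stand in for the integer cast of the trim term, and overlaps_after_transform uses an inclusive comparison so that touching intervals count as overlapping.

```python
from typing import Tuple

Interval = Tuple[int, int]


def transform(hit: Interval, overlap_frac: float) -> Interval:
    """Doubled, trimmed coordinates for one [start, end] hit."""
    start, end = hit
    # int() drops the fractional part of the trim amount; for non-negative
    # lengths this is assumed to match the integer cast in collapse_hits.
    trim = int((end - start) * overlap_frac)
    return 2 * start + trim, 2 * end - trim


def overlaps_after_transform(a: Interval, b: Interval, overlap_frac: float) -> bool:
    """Inclusive overlap test on the adjusted intervals (touching counts)."""
    (s1, e1), (s2, e2) = transform(a, overlap_frac), transform(b, overlap_frac)
    return s1 <= e2 and s2 <= e1


def meets_min_overlap(a: Interval, b: Interval, overlap_frac: float) -> bool:
    """Direct check: overlap >= overlap_frac * (x + y) / 2."""
    overlap = min(a[1], b[1]) - max(a[0], b[0])
    mean_len = ((a[1] - a[0]) + (b[1] - b[0])) / 2
    return overlap >= overlap_frac * mean_len
```

With overlap_frac = 0.5, hits [0, 10] and [2, 22] pass both tests (an overlap of 8 against a threshold of 7.5), and [0, 10] with [3, 23] fails both. By contrast, [45, 55] lying inside [0, 100] passes the adjusted-interval test even though its overlap of 10 is below the 27.5 threshold given by the formula.
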

Examples
>>> hits_collapsed = collapse_hits(hits_df, overlap_frac=0.2)
>>> primary_hits = hits_collapsed.filter(pl.col("is_primary") == 1)
>>> print(f"Kept {primary_hits.height}/{hits_df.height} hits as primary")
def intersect_hits(hits_dfs: List[Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame]], relaxed: bool) -> polars.dataframe.frame.DataFrame:

Intersect hit datasets across multiple runs to find common hits.

This function finds hits that appear consistently across multiple Fi-NeMo runs, which can be useful for identifying robust motif instances that are not sensitive to parameter variations or random initialization.

Parameters
  • hits_dfs (List[Union[pl.DataFrame, pl.LazyFrame]]): List of hit DataFrames from different Fi-NeMo runs. Each DataFrame must contain the columns specified by the intersection criteria. LazyFrames will be collected before processing.
  • relaxed (bool): If True, uses relaxed intersection criteria, matching only on chromosome, untrimmed coordinates, motif name, and strand. If False, uses strict criteria that additionally require matching trimmed coordinates and peak identifiers (peak_name, peak_id).
Returns
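  • pl.DataFrame: DataFrame containing hits that appear in all input datasets. Join-key columns are coalesced and appear once; other columns from later datasets are suffixed with their list index when their names clash (e.g., '_1' for the second dataset, '_2' for the third). The first dataset's columns retain their original names.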
Raises
  • ValueError: If hits_dfs is empty.
  • KeyError: If required columns for the specified intersection criteria are missing from any of the input DataFrames.
Notes

Relaxed intersection is useful when comparing results across different region definitions or motif trimming parameters, but it can match hits whose trimmed coordinates or peak assignments differ. Strict intersection requires identical region definitions and is recommended for most use cases.

The intersection columns used are:

  • Relaxed: ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"]
  • Strict: ["chr", "start", "end", "start_untrimmed", "end_untrimmed", "motif_name", "strand", "peak_name", "peak_id"]

The function performs successive inner joins starting with the first DataFrame, so the final result contains only hits present in all input datasets. If a key occurs more than once in a dataset, the join emits one row per matching combination.
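
The key-matching logic can be mimicked in plain Python to see how the relaxed and strict column sets behave. This is a sketch, not the function's implementation (which uses Polars inner joins): common_keys, the dict-per-hit representation and the column-list constants are hypothetical. It intersects sets of key tuples run by run, which yields the keys that survive the chain of joins, but unlike a join it does not repeat rows when a key occurs more than once in a run. Python treats None as equal to None, which mirrors join_nulls=True.

```python
from typing import Dict, List, Sequence, Set, Tuple

# Each hit is a plain dict keyed by column name (hypothetical representation).
Hit = Dict[str, object]

RELAXED_COLS = ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"]
STRICT_COLS = [
    "chr", "start", "end", "start_untrimmed", "end_untrimmed",
    "motif_name", "strand", "peak_name", "peak_id",
]


def common_keys(runs: Sequence[List[Hit]], relaxed: bool) -> Set[Tuple[object, ...]]:
    """Join-key tuples present in every run (pure-Python sketch)."""
    if len(runs) < 1:
        raise ValueError("At least one hits dataframe required")
    cols = RELAXED_COLS if relaxed else STRICT_COLS
    keys = {tuple(hit[c] for c in cols) for hit in runs[0]}
    for run in runs[1:]:
        # Intersect with each later run in turn, like the chain of inner joins.
        keys &= {tuple(hit[c] for c in cols) for hit in run}
    return keys
```

For instance, two runs that report the same motif instance under different peak_id values share a key under relaxed=True but not under relaxed=False.
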

Examples
>>> common_hits = intersect_hits([hits_df1, hits_df2], relaxed=False)
>>> print(f"Found {common_hits.height} hits common to both runs")
>>>
>>> # Compare relaxed vs strict intersection
>>> relaxed_hits = intersect_hits([hits_df1, hits_df2], relaxed=True)
>>> strict_hits = intersect_hits([hits_df1, hits_df2], relaxed=False)
>>> print(f"Relaxed: {relaxed_hits.height}, Strict: {strict_hits.height}")