finemo.data_io

Data input/output module for the Fi-NeMo motif instance calling pipeline.

This module handles loading and processing of various genomic data formats including:

  • Peak region files (ENCODE NarrowPeak format)
  • Genome sequences (FASTA format)
  • Contribution scores (bigWig, HDF5 formats)
  • Neural network model outputs
  • Motif data from TF-MoDISco
  • Hit calling results

The module supports multiple input formats used for contribution scores and provides utilities for data conversion and quality control.
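
A typical preprocessing workflow chains these loaders together. The sketch below is illustrative only (the file names and half-width are hypothetical, and it assumes the package is importable as finemo):

    from finemo import data_io

    # Center 1,000-bp regions (half_width = 500) on peak summits.
    peaks = data_io.load_peaks("peaks.narrowPeak", None, 500)

    # One-hot sequences plus contribution scores averaged across bigWig replicates.
    sequences, contribs = data_io.load_regions_from_bw(
        peaks, "hg38.fa", ["fold0.bw", "fold1.bw"], 500
    )

    # Persist the regions for downstream hit calling.
    data_io.write_regions_npz(sequences, contribs, "regions.npz", peaks_df=peaks)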

"""Data input/output module for the Fi-NeMo motif instance calling pipeline.

This module handles loading and processing of various genomic data formats including:
- Peak region files (ENCODE NarrowPeak format)
- Genome sequences (FASTA format)
- Contribution scores (bigWig, HDF5 formats)
- Neural network model outputs
- Motif data from TF-MoDISco
- Hit calling results

The module supports multiple input formats used for contribution scores
and provides utilities for data conversion and quality control.
"""

import json
import os
import warnings
from contextlib import ExitStack
from typing import List, Dict, Tuple, Optional, Any, Union, Callable

import numpy as np
from numpy import ndarray
import h5py
import hdf5plugin  # noqa: F401, imported for side effects (HDF5 plugin registration)
import polars as pl
import pyBigWig
import pyfaidx
from jaxtyping import Float, Int

from tqdm import tqdm


def load_txt(path: str) -> List[str]:
    """Load a text file containing one item per line.

    Parameters
    ----------
    path : str
        Path to the text file.

    Returns
    -------
    List[str]
        List of strings, one per line (first column if tab-delimited).
    """
    entries = []
    with open(path) as f:
        for line in f:
            item = line.rstrip("\n").split("\t")[0]
            entries.append(item)

    return entries


def load_mapping(path: str, value_type: Callable[[str], Any]) -> Dict[str, Any]:
    """Load a two-column tab-delimited mapping file.

    Parameters
    ----------
    path : str
        Path to the mapping file. Must be tab-delimited with exactly two columns.
    value_type : Callable[[str], Any]
        Type constructor to apply to values (e.g., int, float, str).
        Must accept a string and return the converted value.

    Returns
    -------
    Dict[str, Any]
        Dictionary mapping keys to values of the specified type.

    Raises
    ------
    ValueError
        If lines don't contain exactly two tab-separated values.
    FileNotFoundError
        If the specified file does not exist.
    """
    mapping = {}
    with open(path) as f:
        for line in f:
            key, val = line.rstrip("\n").split("\t")
            mapping[key] = value_type(val)

    return mapping


def load_mapping_tuple(
    path: str, value_type: Callable[[str], Any]
) -> Dict[str, Tuple[Any, ...]]:
    """Load a mapping file where values are tuples from multiple columns.

    Parameters
    ----------
    path : str
        Path to the mapping file. Must be tab-delimited with multiple columns.
    value_type : Callable[[str], Any]
        Type constructor to apply to each value element.
        Must accept a string and return the converted value.

    Returns
    -------
    Dict[str, Tuple[Any, ...]]
        Dictionary mapping keys to tuples of values of the specified type.
        The first column is used as the key, remaining columns as tuple values.

    Raises
    ------
    ValueError
        If lines don't contain at least two tab-separated values.
    FileNotFoundError
        If the specified file does not exist.
    """
    mapping = {}
    with open(path) as f:
        for line in f:
            entries = line.rstrip("\n").split("\t")
            key = entries[0]
            val = entries[1:]
            mapping[key] = tuple(value_type(i) for i in val)

    return mapping
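

# Illustrative usage (not part of the pipeline; file names hypothetical): for a
# tab-delimited file "lambdas.tsv" with lines like "pos_patterns.pattern_0<TAB>0.7",
#
#     load_mapping("lambdas.tsv", float)
#     # -> {"pos_patterns.pattern_0": 0.7, ...}
#
# and for a three-column file "trim_coords.tsv" with lines like
# "pos_patterns.pattern_0<TAB>3<TAB>12",
#
#     load_mapping_tuple("trim_coords.tsv", int)
#     # -> {"pos_patterns.pattern_0": (3, 12), ...}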


# ENCODE NarrowPeak format column definitions
NARROWPEAK_SCHEMA: List[str] = [
    "chr",
    "peak_start",
    "peak_end",
    "peak_name",
    "peak_score",
    "peak_strand",
    "peak_signal",
    "peak_pval",
    "peak_qval",
    "peak_summit",
]
NARROWPEAK_DTYPES: List[Any] = [
    pl.String,
    pl.Int32,
    pl.Int32,
    pl.String,
    pl.UInt32,
    pl.String,
    pl.Float32,
    pl.Float32,
    pl.Float32,
    pl.Int32,
]


def load_peaks(
    peaks_path: str, chrom_order_path: Optional[str], half_width: int
) -> pl.DataFrame:
    """Load peak region data from ENCODE NarrowPeak format file.

    Parameters
    ----------
    peaks_path : str
        Path to the NarrowPeak format file.
    chrom_order_path : str, optional
        Path to file defining chromosome ordering. If None, uses order from peaks file.
    half_width : int
        Half-width of regions around peak summits.

    Returns
    -------
    pl.DataFrame
        DataFrame containing peak information with columns:
        - chr: Chromosome name
        - peak_region_start: Start coordinate of centered region
        - peak_name: Peak identifier
        - peak_id: Sequential peak index
        - chr_id: Numeric chromosome identifier
    """
    peaks = (
        pl.scan_csv(
            peaks_path,
            has_header=False,
            new_columns=NARROWPEAK_SCHEMA,
            separator="\t",
            quote_char=None,
            schema_overrides=NARROWPEAK_DTYPES,
            null_values=[".", "NA", "null", "NaN"],
        )
        .select(
            chr=pl.col("chr"),
            peak_region_start=pl.col("peak_start") + pl.col("peak_summit") - half_width,
            peak_name=pl.col("peak_name"),
        )
        .with_row_index(name="peak_id")
        .collect()
    )

    if chrom_order_path is not None:
        chrom_order = load_txt(chrom_order_path)
    else:
        chrom_order = []

    chrom_order_set = set(chrom_order)
    chrom_order_peaks = [
        i
        for i in peaks.get_column("chr").unique(maintain_order=True)
        if i not in chrom_order_set
    ]
    chrom_order.extend(chrom_order_peaks)
    chrom_ind_map = {val: ind for ind, val in enumerate(chrom_order)}

    peaks = peaks.with_columns(
        pl.col("chr").replace_strict(chrom_ind_map).alias("chr_id")
    )

    return peaks


# DNA sequence alphabet for one-hot encoding
SEQ_ALPHABET: np.ndarray = np.array(["A", "C", "G", "T"], dtype="S1")


def one_hot_encode(sequence: str, dtype: Any = np.int8) -> Int[ndarray, "4 L"]:
    """Convert DNA sequence string to one-hot encoded matrix.

    Parameters
    ----------
    sequence : str
        DNA sequence string containing A, C, G, T characters.
    dtype : np.dtype, default np.int8
        Data type for the output array.

    Returns
    -------
    Int[ndarray, "4 L"]
        One-hot encoded sequence where rows correspond to A, C, G, T and
        L is the sequence length.

    Notes
    -----
    The output array has shape (4, len(sequence)) with rows corresponding to
    nucleotides A, C, G, T in that order. Non-standard nucleotides (N, etc.)
    result in all-zero columns.
    """
    sequence = sequence.upper()

    seq_chararray = np.frombuffer(sequence.encode("UTF-8"), dtype="S1")
    one_hot = (seq_chararray[None, :] == SEQ_ALPHABET[:, None]).astype(dtype)

    return one_hot
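

# Illustrative example: one_hot_encode("ACGN") returns a (4, 4) int8 array with
# one 1 per standard base and an all-zero column for the "N":
#
#     array([[1, 0, 0, 0],
#            [0, 1, 0, 0],
#            [0, 0, 1, 0],
#            [0, 0, 0, 0]], dtype=int8)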


def load_regions_from_bw(
    peaks: pl.DataFrame, fa_path: str, bw_paths: List[str], half_width: int
) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N L"]]:
    """Load genomic sequences and contribution scores from FASTA and bigWig files.

    Parameters
    ----------
    peaks : pl.DataFrame
        Peak regions DataFrame from load_peaks() containing columns:
        'chr', 'peak_region_start'.
    fa_path : str
        Path to genome FASTA file (.fa or .fasta format).
    bw_paths : List[str]
        List of paths to bigWig files containing contribution scores.
        Must be non-empty.
    half_width : int
        Half-width of regions to extract around peak centers.
        Total region width will be 2 * half_width.

    Returns
    -------
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded DNA sequences where N is the number of peaks,
        4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
    contribs : Float[ndarray, "N L"]
        Contribution scores averaged across input bigWig files.
        Shape is (N peaks, L region_length).

    Notes
    -----
    BigWig files only provide projected contribution scores, not hypothetical scores.
    Regions extending beyond chromosome boundaries are zero-padded.
    Missing values in bigWig files are converted to zero.
    """
    num_peaks = peaks.height
    region_width = half_width * 2

    sequences = np.zeros((num_peaks, 4, region_width), dtype=np.int8)
    contribs = np.zeros((num_peaks, region_width), dtype=np.float16)

    # Load genome reference
    genome = pyfaidx.Fasta(fa_path, one_based_attributes=False)

    bws = [pyBigWig.open(i) for i in bw_paths]
    contrib_buffer = np.zeros((len(bw_paths), half_width * 2), dtype=np.float16)

    try:
        for ind, row in tqdm(
            enumerate(peaks.iter_rows(named=True)),
            disable=None,
            unit="regions",
            total=num_peaks,
        ):
            chrom = row["chr"]
            start = row["peak_region_start"]
            end = start + 2 * half_width

            sequence_data: pyfaidx.FastaRecord = genome[chrom][start:end]  # type: ignore
            sequence: str = sequence_data.seq  # type: ignore
            start_adj: int = sequence_data.start  # type: ignore
            end_adj: int = sequence_data.end  # type: ignore
            a = start_adj - start
            b = end_adj - start

            if b > a:
                sequences[ind, :, a:b] = one_hot_encode(sequence)

                for j, bw in enumerate(bws):
                    # Slice to b - a so regions truncated at chromosome
                    # boundaries don't raise a shape mismatch; the remaining
                    # positions keep their zero padding.
                    contrib_buffer[j, : b - a] = np.nan_to_num(
                        bw.values(chrom, start_adj, end_adj, numpy=True)
                    )

                contribs[ind, a:b] = np.mean(contrib_buffer[:, : b - a], axis=0)

    finally:
        for bw in bws:
            bw.close()

    return sequences, contribs


def load_regions_from_chrombpnet_h5(
    h5_paths: List[str], half_width: int
) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]:
    """Load genomic sequences and contribution scores from ChromBPNet HDF5 files.

    Parameters
    ----------
    h5_paths : List[str]
        List of paths to ChromBPNet HDF5 files containing sequences and SHAP scores.
        Must be non-empty and contain compatible data shapes.
    half_width : int
        Half-width of regions to extract around the center.
        Total region width will be 2 * half_width.

    Returns
    -------
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded DNA sequences where N is the number of regions,
        4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
    contribs : Float[ndarray, "N 4 L"]
        SHAP contribution scores averaged across input files.
        Shape is (N regions, 4 nucleotides, L region_length).

    Notes
    -----
    ChromBPNet files store sequences in 'raw/seq' and SHAP scores in 'shap/seq'.
    All input files must have the same dimensions and number of regions.
    Missing values in contribution scores are converted to zero.
    """
    with ExitStack() as stack:
        h5s = [stack.enter_context(h5py.File(i)) for i in h5_paths]

        start = h5s[0]["raw/seq"].shape[-1] // 2 - half_width  # type: ignore  # HDF5 array access
        end = start + 2 * half_width

        sequences = h5s[0]["raw/seq"][:, :, start:end].astype(np.int8)  # type: ignore  # HDF5 array access
        contribs = np.mean(
            [np.nan_to_num(f["shap/seq"][:, :, start:end]) for f in h5s],  # type: ignore  # HDF5 array access
            axis=0,
            dtype=np.float16,
        )

    return sequences, contribs  # type: ignore  # HDF5 arrays converted to NumPy


def load_regions_from_bpnet_h5(
    h5_paths: List[str], half_width: int
) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]:
    """Load genomic sequences and contribution scores from BPNet HDF5 files.

    Parameters
    ----------
    h5_paths : List[str]
        List of paths to BPNet HDF5 files containing sequences and contribution scores.
        Must be non-empty and contain compatible data shapes.
    half_width : int
        Half-width of regions to extract around the center.
        Total region width will be 2 * half_width.

    Returns
    -------
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded DNA sequences where N is the number of regions,
        4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
    contribs : Float[ndarray, "N 4 L"]
        Hypothetical contribution scores averaged across input files.
        Shape is (N regions, 4 nucleotides, L region_length).

    Notes
    -----
    BPNet files store sequences in 'input_seqs' and hypothetical scores in 'hyp_scores'.
    The data requires axis swapping to convert from (n, length, 4) to (n, 4, length) format.
    All input files must have the same dimensions and number of regions.
    Missing values in contribution scores are converted to zero.
    """
    with ExitStack() as stack:
        h5s = [stack.enter_context(h5py.File(i)) for i in h5_paths]

        start = h5s[0]["input_seqs"].shape[-2] // 2 - half_width  # type: ignore  # HDF5 array access
        end = start + 2 * half_width

        sequences = h5s[0]["input_seqs"][:, start:end, :].swapaxes(1, 2).astype(np.int8)  # type: ignore  # HDF5 array access with axis swap
        contribs = np.mean(
            [
                np.nan_to_num(f["hyp_scores"][:, start:end, :].swapaxes(1, 2))  # type: ignore  # HDF5 array access
                for f in h5s
            ],
            axis=0,
            dtype=np.float16,
        )

    return sequences, contribs


def load_npy_or_npz(path: str) -> ndarray:
    """Load array data from .npy or .npz file.

    Parameters
    ----------
    path : str
        Path to .npy or .npz file. File must exist and contain valid NumPy data.

    Returns
    -------
    ndarray
        Loaded array data. For .npz files, returns the first array ('arr_0').
        For .npy files, returns the array directly.

    Raises
    ------
    FileNotFoundError
        If the specified file does not exist.
    KeyError
        If .npz file does not contain 'arr_0' key.
    """
    f = np.load(path)
    if isinstance(f, np.ndarray):
        arr = f
    else:
        arr = f["arr_0"]

    return arr


def load_regions_from_modisco_fmt(
    shaps_paths: List[str], ohe_path: str, half_width: int
) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]:
    """Load genomic sequences and contribution scores from TF-MoDISco format files.

    Parameters
    ----------
    shaps_paths : List[str]
        List of paths to .npy/.npz files containing SHAP/attribution scores.
        Must be non-empty and all files must have compatible shapes.
    ohe_path : str
        Path to .npy/.npz file containing one-hot encoded sequences.
        Must have shape (n_regions, 4, sequence_length).
    half_width : int
        Half-width of regions to extract around the center.
        Total region width will be 2 * half_width.

    Returns
    -------
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded DNA sequences where N is the number of regions,
        4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
    contribs : Float[ndarray, "N 4 L"]
        SHAP contribution scores averaged across input files.
        Shape is (N regions, 4 nucleotides, L region_length).

    Notes
    -----
    All SHAP files must have the same shape as the sequence file.
    Missing values in contribution scores are converted to zero.
    The center of the input sequences is used as the reference point for extraction.
    """
    sequences_raw = load_npy_or_npz(ohe_path)

    start = sequences_raw.shape[-1] // 2 - half_width
    end = start + 2 * half_width

    sequences = sequences_raw[:, :, start:end].astype(np.int8)

    shaps = [np.nan_to_num(load_npy_or_npz(p)[:, :, start:end]) for p in shaps_paths]
    contribs = np.mean(shaps, axis=0, dtype=np.float16)

    return sequences, contribs


def load_regions_npz(
    npz_path: str,
) -> Tuple[
    Int[ndarray, "N 4 L"],
    Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]],
    pl.DataFrame,
    bool,
]:
    """Load preprocessed genomic regions from NPZ file.

    Parameters
    ----------
    npz_path : str
        Path to NPZ file containing sequences, contributions, and optional coordinates.
        Must contain 'sequences' and 'contributions' arrays at minimum.

    Returns
    -------
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded DNA sequences where N is the number of regions,
        4 represents A,C,G,T nucleotides, and L is the region length.
    contributions : Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]]
        Contribution scores in either hypothetical format (N, 4, L) or
        projected format (N, L). Shape depends on the input data format.
    peaks_df : pl.DataFrame
        DataFrame containing peak region information with columns:
        'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'.
    has_peaks : bool
        Whether the file contains genomic coordinate information.
        If False, placeholder coordinate data is used.

    Notes
    -----
    If genomic coordinates are not present in the NPZ file, creates placeholder
    coordinate data and issues a warning. The placeholder data uses 'NA' for
    chromosome names and sequential indices for peak IDs.

    Raises
    ------
    KeyError
        If required arrays 'sequences' or 'contributions' are missing from the file.
    """
    data = np.load(npz_path)

    if "chr" not in data.keys():
        warnings.warn(
            "No genome coordinates present in the input .npz file. Returning sequences and contributions only."
        )
        has_peaks = False
        num_regions = data["sequences"].shape[0]
        peak_data = {
            "chr": np.array(["NA"] * num_regions, dtype="U"),
            "chr_id": np.arange(num_regions, dtype=np.uint32),
            "peak_region_start": np.zeros(num_regions, dtype=np.int32),
            "peak_id": np.arange(num_regions, dtype=np.uint32),
            "peak_name": np.array(["NA"] * num_regions, dtype="U"),
        }

    else:
        has_peaks = True
        peak_data = {
            "chr": data["chr"],
            "chr_id": data["chr_id"],
            "peak_region_start": data["start"],
            "peak_id": data["peak_id"],
            "peak_name": data["peak_name"],
        }

    peaks_df = pl.DataFrame(peak_data)

    return data["sequences"], data["contributions"], peaks_df, has_peaks


def write_regions_npz(
    sequences: Int[ndarray, "N 4 L"],
    contributions: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]],
    out_path: str,
    peaks_df: Optional[pl.DataFrame] = None,
) -> None:
    """Write genomic regions and contribution scores to compressed NPZ file.

    Parameters
    ----------
    sequences : Int[ndarray, "N 4 L"]
        One-hot encoded DNA sequences where N is the number of regions,
        4 represents A,C,G,T nucleotides, and L is the region length.
    contributions : Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]]
        Contribution scores in either hypothetical format (N, 4, L) or
        projected format (N, L).
    out_path : str
        Output path for the NPZ file. Parent directory must exist.
    peaks_df : Optional[pl.DataFrame]
        DataFrame containing peak region information with columns:
        'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'.
        If None, only sequences and contributions are saved.

    Raises
    ------
    ValueError
        If the number of regions in sequences/contributions doesn't match peaks_df.
    FileNotFoundError
        If the parent directory of out_path does not exist.

    Notes
    -----
    The output file is compressed using NumPy's savez_compressed format.
    If peaks_df is provided, genomic coordinate information is included
    in the output file for downstream analysis.
    """
    if peaks_df is None:
        warnings.warn(
            "No genome coordinates provided. Writing sequences and contributions only."
        )
        np.savez_compressed(out_path, sequences=sequences, contributions=contributions)

    else:
        num_regions = peaks_df.height
        if (num_regions != sequences.shape[0]) or (
            num_regions != contributions.shape[0]
        ):
            raise ValueError(
                f"Input sequences of shape {sequences.shape} and/or "
                f"input contributions of shape {contributions.shape} "
                f"are not compatible with peak region count of {num_regions}"
            )

        chr_arr = peaks_df.get_column("chr").to_numpy().astype("U")
        chr_id_arr = peaks_df.get_column("chr_id").to_numpy()
        start_arr = peaks_df.get_column("peak_region_start").to_numpy()
        peak_id_arr = peaks_df.get_column("peak_id").to_numpy()
        peak_name_arr = peaks_df.get_column("peak_name").to_numpy().astype("U")
        np.savez_compressed(
            out_path,
            sequences=sequences,
            contributions=contributions,
            chr=chr_arr,
            chr_id=chr_id_arr,
            start=start_arr,
            peak_id=peak_id_arr,
            peak_name=peak_name_arr,
        )
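

# Illustrative round trip (file name hypothetical): writing with a peaks
# DataFrame preserves coordinates, so load_regions_npz reports has_peaks=True.
#
#     write_regions_npz(sequences, contribs, "regions.npz", peaks_df=peaks)
#     sequences, contribs, peaks_df, has_peaks = load_regions_npz("regions.npz")
#     assert has_peaks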


def trim_motif(cwm: Float[ndarray, "4 W"], trim_threshold: float) -> Tuple[int, int]:
    """Determine trimmed start and end positions for a motif based on contribution magnitude.

    This function identifies the core region of a motif by finding positions where
    the total absolute contribution exceeds a threshold relative to the maximum.

    Parameters
    ----------
    cwm : Float[ndarray, "4 W"]
        Contribution weight matrix for the motif where 4 represents A,C,G,T
        nucleotides and W is the motif width.
    trim_threshold : float
        Fraction of maximum score to use as trimming threshold (0.0 to 1.0).
        Higher values result in more aggressive trimming.

    Returns
    -------
    start : int
        Start position of the trimmed motif (inclusive).
    end : int
        End position of the trimmed motif (exclusive).

    Notes
    -----
    The trimming is based on the sum of absolute contributions across all nucleotides
    at each position. Positions with contributions below trim_threshold * max_score
    are removed from the motif edges.

    Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L213-L236
    """
    score = np.sum(np.abs(cwm), axis=0)
    trim_thresh = np.max(score) * trim_threshold
    pass_inds = np.nonzero(score >= trim_thresh)
    start = max(int(np.min(pass_inds)), 0)  # type: ignore  # nonzero returns tuple of arrays
    end = min(int(np.max(pass_inds)) + 1, len(score))  # type: ignore  # nonzero returns tuple of arrays

    return start, end
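

# Worked example: if the per-position score sum(|cwm|, axis=0) is
# [0.1, 1.0, 0.9, 0.05] and trim_threshold = 0.3, the cutoff is
# 0.3 * 1.0 = 0.3, only positions 1 and 2 pass, and trim_motif
# returns (start, end) = (1, 3).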


def softmax(x: Float[ndarray, "4 W"], temp: float = 100) -> Float[ndarray, "4 W"]:
    """Apply softmax transformation with temperature scaling.

    Parameters
    ----------
    x : Float[ndarray, "4 W"]
        Input array to transform where 4 represents A,C,G,T nucleotides
        and W is the motif width.
    temp : float, default 100
        Temperature parameter for softmax scaling. Higher values create
        sharper probability distributions.

    Returns
    -------
    Float[ndarray, "4 W"]
        Softmax-transformed array with same shape as input. Each column
        sums to 1.0, representing nucleotide probabilities at each position.

    Notes
    -----
    The softmax is applied along the nucleotide axis (axis=0), normalizing
    each position to have probabilities that sum to 1. The temperature
    parameter controls the sharpness of the distribution.
    """
    norm_x = x - np.mean(x, axis=1, keepdims=True)
    exp = np.exp(temp * norm_x)
    return exp / np.sum(exp, axis=0, keepdims=True)
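

# Illustrative example: with the default temperature, each column collapses
# toward its dominant base. For
#
#     x = np.array([[1.0, 0.0],
#                   [0.0, 0.0],
#                   [0.0, 0.0],
#                   [0.0, 1.0]])
#
# softmax(x) is approximately [[1, 0], [0, 0], [0, 0], [0, 1]]: column 0 is
# nearly all "A", column 1 nearly all "T", and each column sums to 1.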


def _motif_name_sort_key(
    data: Tuple[str, Any],
) -> Union[Tuple[int, int], Tuple[int, str]]:
    """Generate sort key for TF-MoDISco motif names.

    This function creates a sort key that orders motifs by pattern number,
    with non-standard patterns sorted to the end.

    Parameters
    ----------
    data : Tuple[str, Any]
        Tuple containing motif name as first element and additional data.
        The motif name should follow the format 'pattern_N' or 'pattern#N' where N is an integer.

    Returns
    -------
    Union[Tuple[int, int], Tuple[int, str]]
        Sort key tuple for ordering motifs. Standard pattern names return
        (0, pattern_number) while non-standard names return (1, name).

    Notes
    -----
    This function is used internally by load_modisco_motifs to ensure
    consistent motif ordering across runs.
    """
    pattern_name = data[0]
    try:
        return (0, int(pattern_name.split("_")[-1]))
    except (ValueError, IndexError):
        try:
            return (0, int(pattern_name.split("#")[-1]))
        except (ValueError, IndexError):
            return (1, pattern_name)
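

# Illustrative example: sorting pattern names with this key orders numeric
# suffixes numerically and pushes non-standard names to the end:
#
#     items = [("pattern_10", None), ("pattern_2", None), ("custom", None)]
#     sorted(items, key=_motif_name_sort_key)
#     # -> [("pattern_2", None), ("pattern_10", None), ("custom", None)]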


MODISCO_PATTERN_GROUPS = ["pos_patterns", "neg_patterns"]


def load_modisco_motifs(
    modisco_h5_path: str,
    trim_coords: Optional[Dict[str, Tuple[int, int]]],
    trim_thresholds: Optional[Dict[str, float]],
    trim_threshold_default: float,
    motif_type: str,
    motifs_include: Optional[List[str]],
    motif_name_map: Optional[Dict[str, str]],
    motif_lambdas: Optional[Dict[str, float]],
    motif_lambda_default: float,
    include_rc: bool,
) -> Tuple[pl.DataFrame, Float[ndarray, "M 4 W"], Int[ndarray, "M W"], ndarray]:
    """Load motif data from TF-MoDISco HDF5 file with customizable processing options.

    This function extracts contribution weight matrices and associated metadata from
    TF-MoDISco results, with support for custom naming, trimming, and regularization
    parameters.

    Parameters
    ----------
    modisco_h5_path : str
        Path to TF-MoDISco HDF5 results file containing pattern groups.
    trim_coords : Optional[Dict[str, Tuple[int, int]]]
        Manual trim coordinates for specific motifs {motif_name: (start, end)}.
        Takes precedence over automatic trimming based on thresholds.
    trim_thresholds : Optional[Dict[str, float]]
        Custom trim thresholds for specific motifs {motif_name: threshold}.
        Values should be between 0.0 and 1.0.
    trim_threshold_default : float
        Default trim threshold for motifs not in trim_thresholds.
        Fraction of maximum contribution used for trimming.
    motif_type : str
        Type of motif to extract. Must be one of:
        - 'cwm': Contribution weight matrix (normalized)
        - 'hcwm': Hypothetical contribution weight matrix
        - 'pfm': Position frequency matrix
        - 'pfm_softmax': Softmax-transformed position frequency matrix
    motifs_include : Optional[List[str]]
        List of motif names to include. If None, includes all motifs found.
        Names should follow format 'pos_patterns.pattern_N' or 'neg_patterns.pattern_N'.
    motif_name_map : Optional[Dict[str, str]]
        Mapping from original to custom motif names {orig_name: new_name}.
        New names must be unique across all motifs.
    motif_lambdas : Optional[Dict[str, float]]
        Custom lambda regularization values for specific motifs {motif_name: lambda}.
        Higher values increase sparsity penalty for the corresponding motif.
    motif_lambda_default : float
        Default lambda value for motifs not specified in motif_lambdas.
    include_rc : bool
        Whether to include reverse complement motifs in addition to forward motifs.
        If True, doubles the number of motifs returned.

    Returns
    -------
    motifs_df : pl.DataFrame
        DataFrame containing motif metadata with columns: motif_id, motif_name,
        motif_name_orig, strand, motif_start, motif_end, motif_scale, lambda.
    cwms : Float[ndarray, "M 4 W"]
        Contribution weight matrices for all motifs where M is the number of motifs,
        4 represents A,C,G,T nucleotides, and W is the motif width.
    trim_masks : Int[ndarray, "M W"]
        Binary masks indicating core motif regions (1) vs trimmed regions (0).
        Shape is (M motifs, W motif_width).
    names : ndarray
        Array of unique motif names (forward strand only).

    Raises
    ------
    ValueError
        If motif_type is not one of the supported types, or if motif names
        in motif_name_map are not unique.
    FileNotFoundError
        If the specified HDF5 file does not exist.
    KeyError
        If required datasets are missing from the HDF5 file.

    Notes
    -----
    Motif trimming removes low-contribution positions from the edges based on
    the position-wise sum of absolute contributions across nucleotides. The trimming
    helps focus on the core binding site.

    Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L252-L272
    """
    motif_data_lsts = {
        "motif_name": [],
        "motif_name_orig": [],
        "strand": [],
        "motif_start": [],
        "motif_end": [],
        "motif_scale": [],
        "lambda": [],
    }
    motif_lst = []
    trim_mask_lst = []

    if motifs_include is not None:
        motifs_include_set = set(motifs_include)
    else:
        motifs_include_set = None

    if motif_name_map is None:
        motif_name_map = {}

    if motif_lambdas is None:
        motif_lambdas = {}

    if trim_coords is None:
        trim_coords = {}
    if trim_thresholds is None:
        trim_thresholds = {}

    if len(motif_name_map.values()) != len(set(motif_name_map.values())):
        raise ValueError("Specified motif names are not unique")

    with h5py.File(modisco_h5_path, "r") as modisco_results:
        for name in MODISCO_PATTERN_GROUPS:
            if name not in modisco_results.keys():
                continue

            metacluster = modisco_results[name]
            for pattern_name, pattern in sorted(
                metacluster.items(), key=_motif_name_sort_key  # type: ignore  # HDF5 access
            ):
                pattern_tag = f"{name}.{pattern_name}"

                if (
                    motifs_include_set is not None
                    and pattern_tag not in motifs_include_set
                ):
                    continue

                motif_lambda = motif_lambdas.get(pattern_tag, motif_lambda_default)
                pattern_tag_orig = pattern_tag
                pattern_tag = motif_name_map.get(pattern_tag, pattern_tag)

                cwm_raw = pattern["contrib_scores"][:].T  # type: ignore
                cwm_norm = np.sqrt((cwm_raw**2).sum())

                cwm_fwd = cwm_raw / cwm_norm
                cwm_rev = cwm_fwd[::-1, ::-1]

                if pattern_tag in trim_coords:
                    start_fwd, end_fwd = trim_coords[pattern_tag]
                else:
                    trim_threshold = trim_thresholds.get(
                        pattern_tag, trim_threshold_default
                    )
                    start_fwd, end_fwd = trim_motif(cwm_fwd, trim_threshold)

                cwm_len = cwm_fwd.shape[1]
                start_rev, end_rev = cwm_len - end_fwd, cwm_len - start_fwd

                trim_mask_fwd = np.zeros(cwm_fwd.shape[1], dtype=np.int8)
                trim_mask_fwd[start_fwd:end_fwd] = 1
                trim_mask_rev = np.zeros(cwm_rev.shape[1], dtype=np.int8)
                trim_mask_rev[start_rev:end_rev] = 1

                if motif_type == "cwm":
                    motif_fwd = cwm_fwd
                    motif_rev = cwm_rev
                    motif_norm = cwm_norm

                elif motif_type == "hcwm":
                    motif_raw = pattern["hypothetical_contribs"][:].T  # type: ignore
                    motif_norm = np.sqrt((motif_raw**2).sum())

                    motif_fwd = motif_raw / motif_norm
                    motif_rev = motif_fwd[::-1, ::-1]

                elif motif_type == "pfm":
                    motif_raw = pattern["sequence"][:].T  # type: ignore
                    motif_norm = 1

                    motif_fwd = motif_raw / np.sum(motif_raw, axis=0, keepdims=True)
                    motif_rev = motif_fwd[::-1, ::-1]

                elif motif_type == "pfm_softmax":
                    motif_raw = pattern["sequence"][:].T  # type: ignore
                    motif_norm = 1

                    motif_fwd = softmax(motif_raw)
                    motif_rev = motif_fwd[::-1, ::-1]

                else:
                    raise ValueError(
                        f"Invalid motif_type: {motif_type}. Must be one of 'cwm', 'hcwm', 'pfm', 'pfm_softmax'."
                    )

                motif_data_lsts["motif_name"].append(pattern_tag)
                motif_data_lsts["motif_name_orig"].append(pattern_tag_orig)
                motif_data_lsts["strand"].append("+")
                motif_data_lsts["motif_start"].append(start_fwd)
                motif_data_lsts["motif_end"].append(end_fwd)
                motif_data_lsts["motif_scale"].append(motif_norm)
                motif_data_lsts["lambda"].append(motif_lambda)

                if include_rc:
                    motif_data_lsts["motif_name"].append(pattern_tag)
                    motif_data_lsts["motif_name_orig"].append(pattern_tag_orig)
                    motif_data_lsts["strand"].append("-")
                    motif_data_lsts["motif_start"].append(start_rev)
                    motif_data_lsts["motif_end"].append(end_rev)
                    motif_data_lsts["motif_scale"].append(motif_norm)
                    motif_data_lsts["lambda"].append(motif_lambda)

                    motif_lst.extend([motif_fwd, motif_rev])
                    trim_mask_lst.extend([trim_mask_fwd, trim_mask_rev])

                else:
                    motif_lst.append(motif_fwd)
                    trim_mask_lst.append(trim_mask_fwd)

    motifs_df = pl.DataFrame(motif_data_lsts).with_row_index(name="motif_id")
    cwms = np.stack(motif_lst, dtype=np.float16, axis=0)
    trim_masks = np.stack(trim_mask_lst, dtype=np.int8, axis=0)
    names = (
        motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy()
    )

    return motifs_df, cwms, trim_masks, names
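

# Illustrative call (the HDF5 path and the default threshold/lambda values
# here are hypothetical): load normalized, trimmed CWMs for both strands.
#
#     motifs_df, cwms, trim_masks, names = load_modisco_motifs(
#         "modisco_results.h5",
#         trim_coords=None,
#         trim_thresholds=None,
#         trim_threshold_default=0.3,
#         motif_type="cwm",
#         motifs_include=None,
#         motif_name_map=None,
#         motif_lambdas=None,
#         motif_lambda_default=0.7,
#         include_rc=True,
#     )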


def load_modisco_seqlets(
    modisco_h5_path: str,
    peaks_df: pl.DataFrame,
    motifs_df: pl.DataFrame,
    half_width: int,
    modisco_half_width: int,
    lazy: bool = False,
) -> Union[pl.DataFrame, pl.LazyFrame]:
    """Load seqlet data from TF-MoDISco HDF5 file and convert to genomic coordinates.

    This function extracts seqlet instances from TF-MoDISco results and converts
    their relative positions to absolute genomic coordinates using peak region
    information.

    Parameters
    ----------
    modisco_h5_path : str
        Path to TF-MoDISco HDF5 results file containing seqlet data.
    peaks_df : pl.DataFrame
        DataFrame containing peak region information with columns:
        'peak_id', 'chr', 'chr_id', 'peak_region_start'.
    motifs_df : pl.DataFrame
        DataFrame containing motif metadata with columns:
        'motif_name_orig', 'strand', 'motif_name', 'motif_start', 'motif_end'.
    half_width : int
        Half-width of the current analysis regions.
    modisco_half_width : int
        Half-width of the regions used in the original TF-MoDISco analysis.
        Used to calculate coordinate offsets.
    lazy : bool, default False
        If True, returns a LazyFrame for efficient chaining of operations.
        If False, collects the result into a DataFrame.

    Returns
    -------
    Union[pl.DataFrame, pl.LazyFrame]
        Seqlets with genomic coordinates containing columns:
        - chr: Chromosome name
        - chr_id: Numeric chromosome identifier
        - start: Start coordinate of trimmed motif instance
        - end: End coordinate of trimmed motif instance
        - start_untrimmed: Start coordinate of full motif instance
        - end_untrimmed: End coordinate of full motif instance
        - is_revcomp: Whether the motif is reverse complemented
        - strand: Motif strand ('+' or '-')
        - motif_name: Motif name (may be remapped)
        - peak_id: Peak identifier
        - peak_region_start: Peak region start coordinate

    Notes
    -----
    Seqlets are deduplicated based on chromosome ID, start position (untrimmed),
    motif name, and reverse complement status to avoid redundant instances.

    The coordinate transformation accounts for differences in region sizes
    between the original TF-MoDISco analysis and the current analysis.
    """
    start_lst = []
    end_lst = []
    is_revcomp_lst = []
    strand_lst = []
    peak_id_lst = []
    pattern_tags = []

    with h5py.File(modisco_h5_path, "r") as modisco_results:
        for name in MODISCO_PATTERN_GROUPS:
            if name not in modisco_results.keys():
                continue

            metacluster = modisco_results[name]

            for pattern_name, pattern in sorted(
                metacluster.items(), key=_motif_name_sort_key  # type: ignore  # HDF5 access
            ):
                pattern_tag = f"{name}.{pattern_name}"

                starts = pattern["seqlets/start"][:].astype(np.int32)  # type: ignore
                ends = pattern["seqlets/end"][:].astype(np.int32)  # type: ignore
                is_revcomps = pattern["seqlets/is_revcomp"][:].astype(bool)  # type: ignore
                strands = ["+" if not i else "-" for i in is_revcomps]
                peak_ids = pattern["seqlets/example_idx"][:].astype(np.uint32)  # type: ignore

                n_seqlets = int(pattern["seqlets/n_seqlets"][0])  # type: ignore

                start_lst.append(starts)
                end_lst.append(ends)
                is_revcomp_lst.append(is_revcomps)
                strand_lst.extend(strands)
                peak_id_lst.append(peak_ids)
                pattern_tags.extend([pattern_tag for _ in range(n_seqlets)])

    df_data = {
        "seqlet_start": np.concatenate(start_lst),
        "seqlet_end": np.concatenate(end_lst),
        "is_revcomp": np.concatenate(is_revcomp_lst),
        "strand": strand_lst,
        "peak_id": np.concatenate(peak_id_lst),
        "motif_name_orig": pattern_tags,
    }

    offset = half_width - modisco_half_width

    seqlets_df = (
        pl.LazyFrame(df_data)
        .join(motifs_df.lazy(), on=("motif_name_orig", "strand"), how="inner")
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        .select(
            chr=pl.col("chr"),
            chr_id=pl.col("chr_id"),
            start=pl.col("peak_region_start")
            + pl.col("seqlet_start")
            + pl.col("motif_start")
            + offset,
            end=pl.col("peak_region_start")
            + pl.col("seqlet_start")
            + pl.col("motif_end")
            + offset,
            start_untrimmed=pl.col("peak_region_start")
            + pl.col("seqlet_start")
            + offset,
            end_untrimmed=pl.col("peak_region_start") + pl.col("seqlet_end") + offset,
            is_revcomp=pl.col("is_revcomp"),
            strand=pl.col("strand"),
            motif_name=pl.col("motif_name"),
            peak_id=pl.col("peak_id"),
            peak_region_start=pl.col("peak_region_start"),
        )
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
    )

    seqlets_df = seqlets_df if lazy else seqlets_df.collect()

    return seqlets_df
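

# Illustrative usage (file names hypothetical): with lazy=True the joins and
# deduplication are deferred until write_modisco_seqlets collects the frame.
#
#     seqlets = load_modisco_seqlets(
#         "modisco_results.h5", peaks_df, motifs_df,
#         half_width=500, modisco_half_width=400, lazy=True,
#     )
#     write_modisco_seqlets(seqlets, "seqlets.tsv")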


def write_modisco_seqlets(
    seqlets_df: Union[pl.DataFrame, pl.LazyFrame], out_path: str
) -> None:
    """Write TF-MoDISco seqlets to TSV file.

    Parameters
    ----------
    seqlets_df : Union[pl.DataFrame, pl.LazyFrame]
        Seqlets DataFrame with genomic coordinates. Must contain the internal
        columns 'chr_id' and 'is_revcomp', which are dropped before writing.
    out_path : str
        Output TSV file path.

    Notes
    -----
    Removes internal columns 'chr_id' and 'is_revcomp' before writing
    to create a clean output format suitable for downstream analysis.
    """
    seqlets_df = seqlets_df.drop(["chr_id", "is_revcomp"])
    if isinstance(seqlets_df, pl.LazyFrame):
        seqlets_df = seqlets_df.collect()
    seqlets_df.write_csv(out_path, separator="\t")


HITS_DTYPES = {
    "chr": pl.String,
    "start": pl.Int32,
    "end": pl.Int32,
    "start_untrimmed": pl.Int32,
    "end_untrimmed": pl.Int32,
    "motif_name": pl.String,
    "hit_coefficient": pl.Float32,
    "hit_coefficient_global": pl.Float32,
    "hit_similarity": pl.Float32,
    "hit_correlation": pl.Float32,
    "hit_importance": pl.Float32,
    "hit_importance_sq": pl.Float32,
    "strand": pl.String,
    "peak_name": pl.String,
    "peak_id": pl.UInt32,
}
HITS_COLLAPSED_DTYPES = HITS_DTYPES | {"is_primary": pl.UInt32}


def load_hits(
    hits_path: str, lazy: bool = False, schema: Dict[str, Any] = HITS_DTYPES
) -> Union[pl.DataFrame, pl.LazyFrame]:
    """Load motif hit data from TSV file.

    Parameters
    ----------
    hits_path : str
        Path to TSV file containing motif hit results.
    lazy : bool, default False
        If True, returns a LazyFrame for efficient chaining operations.
        If False, collects the result into a DataFrame.
    schema : Dict[str, Any], default HITS_DTYPES
        Schema defining column names and data types for the hit data.

    Returns
    -------
    Union[pl.DataFrame, pl.LazyFrame]
        Hit data with an additional 'count' column set to 1 for aggregation.
    """
    hits_df = pl.scan_csv(
        hits_path, separator="\t", quote_char=None, schema=schema
    ).with_columns(pl.lit(1).alias("count"))

    return hits_df if lazy else hits_df.collect()
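

# Illustrative usage (file name hypothetical): the added "count" column makes
# per-motif tallies a simple Polars aggregation.
#
#     hits = load_hits("hits.tsv", lazy=True)
#     per_motif = hits.group_by("motif_name").agg(pl.sum("count")).collect()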
1175
1176
1177def write_hits_processed(
1178    hits_df: Union[pl.DataFrame, pl.LazyFrame],
1179    out_path: str,
1180    schema: Optional[Dict[str, Any]] = HITS_DTYPES,
1181) -> None:
1182    """Write processed hit data to TSV file with optional column filtering.
1183
1184    Parameters
1185    ----------
1186    hits_df : Union[pl.DataFrame, pl.LazyFrame]
1187        Hit data to write to file.
1188    out_path : str
1189        Output path for the TSV file.
1190    schema : Optional[Dict[str, Any]], default HITS_DTYPES
1191        Schema defining which columns to include in output.
1192        If None, all columns are written.
1193    """
1194    if schema is not None:
1195        hits_df = hits_df.select(schema.keys())
1196
1197    if isinstance(hits_df, pl.LazyFrame):
1198        hits_df = hits_df.collect()
1199
1200    hits_df.write_csv(out_path, separator="\t")
1201
1202
1203def write_hits(
1204    hits_df: Union[pl.DataFrame, pl.LazyFrame],
1205    peaks_df: pl.DataFrame,
1206    motifs_df: pl.DataFrame,
1207    qc_df: pl.DataFrame,
1208    out_dir: str,
1209    motif_width: int,
1210) -> None:
1211    """Write comprehensive hit results to multiple output files.
1212
1213    This function combines hit data with peak, motif, and quality control information
1214    to generate complete output files including genomic coordinates and scores.
1215
1216    Parameters
1217    ----------
1218    hits_df : Union[pl.DataFrame, pl.LazyFrame]
1219        Hit data containing motif instance information.
1220    peaks_df : pl.DataFrame
1221        Peak region information for coordinate conversion.
1222    motifs_df : pl.DataFrame
1223        Motif metadata for annotation and trimming information.
1224    qc_df : pl.DataFrame
1225        Quality control data for normalization factors.
1226    out_dir : str
1227        Output directory for results files. Will be created if it doesn't exist.
1228    motif_width : int
1229        Width of motif instances for coordinate calculations.
1230
1231    Notes
1232    -----
1233    Creates three output files:
1234    - hits.tsv: Complete hit data with all instances
1235    - hits_unique.tsv: Deduplicated hits by genomic position and motif (excludes rows with NA chromosome coordinates)
1236    - hits.bed: BED format file for genome browser visualization
1237    
1238    Rows where the chromosome field is NA are filtered out during deduplication
1239    to ensure that data_unique only contains well-defined genomic coordinates.
1240    """
1241    os.makedirs(out_dir, exist_ok=True)
1242    out_path_tsv = os.path.join(out_dir, "hits.tsv")
1243    out_path_tsv_unique = os.path.join(out_dir, "hits_unique.tsv")
1244    out_path_bed = os.path.join(out_dir, "hits.bed")
1245
1246    data_all = (
1247        hits_df.lazy()
1248        .join(peaks_df.lazy(), on="peak_id", how="inner")
1249        .join(qc_df.lazy(), on="peak_id", how="inner")
1250        .join(motifs_df.lazy(), on="motif_id", how="inner")
1251        .select(
1252            chr_id=pl.col("chr_id"),
1253            chr=pl.col("chr"),
1254            start=pl.col("peak_region_start")
1255            + pl.col("hit_start")
1256            + pl.col("motif_start"),
1257            end=pl.col("peak_region_start") + pl.col("hit_start") + pl.col("motif_end"),
1258            start_untrimmed=pl.col("peak_region_start") + pl.col("hit_start"),
1259            end_untrimmed=pl.col("peak_region_start")
1260            + pl.col("hit_start")
1261            + motif_width,
1262            motif_name=pl.col("motif_name"),
1263            hit_coefficient=pl.col("hit_coefficient"),
1264            hit_coefficient_global=pl.col("hit_coefficient")
1265            * (pl.col("global_scale") ** 2),
1266            hit_similarity=pl.col("hit_similarity"),
1267            hit_correlation=pl.col("hit_similarity"),
1268            hit_importance=pl.col("hit_importance") * pl.col("global_scale"),
1269            hit_importance_sq=pl.col("hit_importance_sq")
1270            * (pl.col("global_scale") ** 2),
1271            strand=pl.col("strand"),
1272            peak_name=pl.col("peak_name"),
1273            peak_id=pl.col("peak_id"),
1274            motif_lambda=pl.col("lambda"),
1275        )
1276        .sort(["chr_id", "start"])
1277        .select(HITS_DTYPES.keys())
1278    )
1279
1280    data_unique = data_all.filter(pl.col("chr").is_not_null()).unique(
1281        subset=["chr", "start", "motif_name", "strand"], maintain_order=True
1282    )
1283
1284    data_bed = data_unique.select(
1285        chr=pl.col("chr"),
1286        start=pl.col("start"),
1287        end=pl.col("end"),
1288        motif_name=pl.col("motif_name"),
1289        score=pl.lit(0),
1290        strand=pl.col("strand"),
1291    )
1292
1293    data_all.collect().write_csv(out_path_tsv, separator="\t")
1294    data_unique.collect().write_csv(out_path_tsv_unique, separator="\t")
1295    data_bed.collect().write_csv(out_path_bed, include_header=False, separator="\t")
1296
1297
1298def write_qc(qc_df: pl.DataFrame, peaks_df: pl.DataFrame, out_path: str) -> None:
1299    """Write quality control data with peak information to TSV file.
1300
1301    Parameters
1302    ----------
1303    qc_df : pl.DataFrame
1304        Quality control metrics for each peak region.
1305    peaks_df : pl.DataFrame
1306        Peak region information for coordinate annotation.
1307    out_path : str
1308        Output path for the TSV file.
1309    """
1310    df = (
1311        qc_df.lazy()
1312        .join(peaks_df.lazy(), on="peak_id", how="inner")
1313        .sort(["chr_id", "peak_region_start"])
1314        .drop("chr_id")
1315        .collect()
1316    )
1317    df.write_csv(out_path, separator="\t")
1318
1319
1320def write_motifs_df(motifs_df: pl.DataFrame, out_path: str) -> None:
1321    """Write motif metadata to TSV file.
1322
1323    Parameters
1324    ----------
1325    motifs_df : pl.DataFrame
1326        Motif metadata DataFrame.
1327    out_path : str
1328        Output path for the TSV file.
1329    """
1330    motifs_df.write_csv(out_path, separator="\t")
1331
1332
1333MOTIF_DTYPES = {
1334    "motif_id": pl.UInt32,
1335    "motif_name": pl.String,
1336    "motif_name_orig": pl.String,
1337    "strand": pl.String,
1338    "motif_start": pl.UInt32,
1339    "motif_end": pl.UInt32,
1340    "motif_scale": pl.Float32,
1341    "lambda": pl.Float32,
1342}
1343
1344
1345def load_motifs_df(motifs_path: str) -> Tuple[pl.DataFrame, ndarray]:
1346    """Load motif metadata from TSV file.
1347
1348    Parameters
1349    ----------
1350    motifs_path : str
1351        Path to motif metadata TSV file.
1352
1353    Returns
1354    -------
1355    motifs_df : pl.DataFrame
1356        Motif metadata with predefined schema.
1357    motif_names : ndarray
1358        Array of unique forward-strand motif names.
1359    """
1360    motifs_df = pl.read_csv(motifs_path, separator="\t", schema=MOTIF_DTYPES)
1361    motif_names = (
1362        motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy()
1363    )
1364
1365    return motifs_df, motif_names
1366
1367
1368def write_motif_cwms(cwms: Float[ndarray, "M 4 W"], out_path: str) -> None:
1369    """Write motif contribution weight matrices to .npy file.
1370
1371    Parameters
1372    ----------
1373    cwms : Float[ndarray, "M 4 W"]
1374        Contribution weight matrices for M motifs, 4 nucleotides, W width.
1375    out_path : str
1376        Output path for the .npy file.
1377    """
1378    np.save(out_path, cwms)
1379
1380
1381def load_motif_cwms(cwms_path: str) -> Float[ndarray, "M 4 W"]:
1382    """Load motif contribution weight matrices from .npy file.
1383
1384    Parameters
1385    ----------
1386    cwms_path : str
1387        Path to .npy file containing CWMs.
1388
1389    Returns
1390    -------
1391    Float[ndarray, "M 4 W"]
1392        Loaded contribution weight matrices.
1393    """
1394    return np.load(cwms_path)
1395
1396
1397def write_params(params: Dict[str, Any], out_path: str) -> None:
1398    """Write parameter dictionary to JSON file.
1399
1400    Parameters
1401    ----------
1402    params : Dict[str, Any]
1403        Parameter dictionary to serialize.
1404    out_path : str
1405        Output path for the JSON file.
1406    """
1407    with open(out_path, "w") as f:
1408        json.dump(params, f, indent=4)
1409
1410
1411def load_params(params_path: str) -> Dict[str, Any]:
1412    """Load parameter dictionary from JSON file.
1413
1414    Parameters
1415    ----------
1416    params_path : str
1417        Path to JSON file containing parameters.
1418
1419    Returns
1420    -------
1421    Dict[str, Any]
1422        Loaded parameter dictionary.
1423    """
1424    with open(params_path) as f:
1425        params = json.load(f)
1426
1427    return params
1428
1429
1430def write_occ_df(occ_df: pl.DataFrame, out_path: str) -> None:
1431    """Write occurrence data to TSV file.
1432
1433    Parameters
1434    ----------
1435    occ_df : pl.DataFrame
1436        Occurrence data DataFrame.
1437    out_path : str
1438        Output path for the TSV file.
1439    """
1440    occ_df.write_csv(out_path, separator="\t")
1441
1442
1443def write_seqlet_confusion_df(seqlet_confusion_df: pl.DataFrame, out_path: str) -> None:
1444    """Write seqlet confusion matrix data to TSV file.
1445
1446    Parameters
1447    ----------
1448    seqlet_confusion_df : pl.DataFrame
1449        Seqlet confusion matrix DataFrame.
1450    out_path : str
1451        Output path for the TSV file.
1452    """
1453    seqlet_confusion_df.write_csv(out_path, separator="\t")
1454
1455
1456def write_report_data(
1457    report_df: pl.DataFrame, cwms: Dict[str, Dict[str, ndarray]], out_dir: str
1458) -> None:
1459    """Write comprehensive motif report data including CWMs and metadata.
1460
1461    Parameters
1462    ----------
1463    report_df : pl.DataFrame
1464        Report metadata DataFrame.
1465    cwms : Dict[str, Dict[str, ndarray]]
1466        Nested dictionary of motif names to CWM types to arrays.
1467    out_dir : str
1468        Output directory for report files.
1469    """
1470    cwms_dir = os.path.join(out_dir, "CWMs")
1471    os.makedirs(cwms_dir, exist_ok=True)
1472
1473    for m, v in cwms.items():
1474        motif_dir = os.path.join(cwms_dir, m)
1475        os.makedirs(motif_dir, exist_ok=True)
1476        for cwm_type, cwm in v.items():
1477            np.savetxt(os.path.join(motif_dir, f"{cwm_type}.txt"), cwm)
1478
1479    report_df.write_csv(os.path.join(out_dir, "motif_report.tsv"), separator="\t")
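
A minimal round-trip sketch for the small I/O helpers at the end of the listing above (write_motif_cwms/load_motif_cwms and write_params/load_params); the file names and parameter keys here are hypothetical:

    import numpy as np
    from finemo.data_io import (
        write_motif_cwms, load_motif_cwms, write_params, load_params,
    )

    cwms = np.zeros((2, 4, 8), dtype=np.float16)  # 2 motifs, 4 nucleotides, width 8
    write_motif_cwms(cwms, "cwms.npy")
    assert load_motif_cwms("cwms.npy").shape == (2, 4, 8)

    write_params({"half_width": 500, "alpha": 0.6}, "params.json")
    assert load_params("params.json")["half_width"] == 500
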
def load_txt(path: str) -> List[str]:
34def load_txt(path: str) -> List[str]:
35    """Load a text file containing one item per line.
36
37    Parameters
38    ----------
39    path : str
40        Path to the text file.
41
42    Returns
43    -------
44    List[str]
45        List of strings, one per line (first column if tab-delimited).
46    """
47    entries = []
48    with open(path) as f:
49        for line in f:
50            item = line.rstrip("\n").split("\t")[0]
51            entries.append(item)
52
53    return entries

Load a text file containing one item per line.

Parameters
  • path (str): Path to the text file.
Returns
  • List[str]: List of strings, one per line (first column if tab-delimited).
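
A minimal usage sketch; the file path and its contents are hypothetical:

    from finemo.data_io import load_txt

    # chroms.txt contains one chromosome name per line, e.g. chr1, chr2, ...
    chrom_order = load_txt("chroms.txt")  # ['chr1', 'chr2', ...]
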
def load_mapping(path: str, value_type: Callable[[str], Any]) -> Dict[str, Any]:
56def load_mapping(path: str, value_type: Callable[[str], Any]) -> Dict[str, Any]:
57    """Load a two-column tab-delimited mapping file.
58
59    Parameters
60    ----------
61    path : str
62        Path to the mapping file. Must be tab-delimited with exactly two columns.
63    value_type : Callable[[str], Any]
64        Type constructor to apply to values (e.g., int, float, str).
65        Must accept a string and return the converted value.
66
67    Returns
68    -------
69    Dict[str, Any]
70        Dictionary mapping keys to values of the specified type.
71
72    Raises
73    ------
74    ValueError
75        If lines don't contain exactly two tab-separated values.
76    FileNotFoundError
77        If the specified file does not exist.
78    """
79    mapping = {}
80    with open(path) as f:
81        for line in f:
82            key, val = line.rstrip("\n").split("\t")
83            mapping[key] = value_type(val)
84
85    return mapping

Load a two-column tab-delimited mapping file.

Parameters
  • path (str): Path to the mapping file. Must be tab-delimited with exactly two columns.
  • value_type (Callable[[str], Any]): Type constructor to apply to values (e.g., int, float, str). Must accept a string and return the converted value.
Returns
  • Dict[str, Any]: Dictionary mapping keys to values of the specified type.
Raises
  • ValueError: If lines don't contain exactly two tab-separated values.
  • FileNotFoundError: If the specified file does not exist.
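
A minimal usage sketch; the file path and its contents are hypothetical. Passing float as value_type converts the second column, e.g. for per-motif lambda overrides:

    from finemo.data_io import load_mapping

    # lambdas.tsv: lines like "pos_patterns.pattern_0<TAB>0.8"
    motif_lambdas = load_mapping("lambdas.tsv", float)
    # {'pos_patterns.pattern_0': 0.8, ...}
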
def load_mapping_tuple( path: str, value_type: Callable[[str], Any]) -> Dict[str, Tuple[Any, ...]]:
 88def load_mapping_tuple(
 89    path: str, value_type: Callable[[str], Any]
 90) -> Dict[str, Tuple[Any, ...]]:
 91    """Load a mapping file where values are tuples from multiple columns.
 92
 93    Parameters
 94    ----------
 95    path : str
 96        Path to the mapping file. Must be tab-delimited with multiple columns.
 97    value_type : Callable[[str], Any]
 98        Type constructor to apply to each value element.
 99        Must accept a string and return the converted value.
100
101    Returns
102    -------
103    Dict[str, Tuple[Any, ...]]
104        Dictionary mapping keys to tuples of values of the specified type.
105        The first column is used as the key, remaining columns as tuple values.
106
107    Raises
108    ------
109    ValueError
110        If lines don't contain at least two tab-separated values.
111    FileNotFoundError
112        If the specified file does not exist.
113    """
114    mapping = {}
115    with open(path) as f:
116        for line in f:
117            entries = line.rstrip("\n").split("\t")
118            key = entries[0]
119            val = entries[1:]
120            mapping[key] = tuple(value_type(i) for i in val)
121
122    return mapping

Load a mapping file where values are tuples from multiple columns.

Parameters
  • path (str): Path to the mapping file. Must be tab-delimited with multiple columns.
  • value_type (Callable[[str], Any]): Type constructor to apply to each value element. Must accept a string and return the converted value.
Returns
  • Dict[str, Tuple[Any, ...]]: Dictionary mapping keys to tuples of values of the specified type. The first column is used as the key, remaining columns as tuple values.
Raises
  • ValueError: If lines don't contain at least two tab-separated values.
  • FileNotFoundError: If the specified file does not exist.
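
A minimal usage sketch; the file path and its contents are hypothetical. With int as value_type, each line's trailing columns become an integer tuple, e.g. for per-motif trim coordinates:

    from finemo.data_io import load_mapping_tuple

    # trim_coords.tsv: lines like "pos_patterns.pattern_0<TAB>4<TAB>18"
    trim_coords = load_mapping_tuple("trim_coords.tsv", int)
    # {'pos_patterns.pattern_0': (4, 18), ...}
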
NARROWPEAK_SCHEMA: List[str] = ['chr', 'peak_start', 'peak_end', 'peak_name', 'peak_score', 'peak_strand', 'peak_signal', 'peak_pval', 'peak_qval', 'peak_summit']
NARROWPEAK_DTYPES: List[Any] = [String, Int32, Int32, String, UInt32, String, Float32, Float32, Float32, Int32]
def load_peaks( peaks_path: str, chrom_order_path: Optional[str], half_width: int) -> polars.dataframe.frame.DataFrame:
152def load_peaks(
153    peaks_path: str, chrom_order_path: Optional[str], half_width: int
154) -> pl.DataFrame:
155    """Load peak region data from ENCODE NarrowPeak format file.
156
157    Parameters
158    ----------
159    peaks_path : str
160        Path to the NarrowPeak format file.
161    chrom_order_path : str, optional
162        Path to file defining chromosome ordering. If None, uses order from peaks file.
163    half_width : int
164        Half-width of regions around peak summits.
165
166    Returns
167    -------
168    pl.DataFrame
169        DataFrame containing peak information with columns:
170        - chr: Chromosome name
171        - peak_region_start: Start coordinate of centered region
172        - peak_name: Peak identifier
173        - peak_id: Sequential peak index
174        - chr_id: Numeric chromosome identifier
175    """
176    peaks = (
177        pl.scan_csv(
178            peaks_path,
179            has_header=False,
180            new_columns=NARROWPEAK_SCHEMA,
181            separator="\t",
182            quote_char=None,
183            schema_overrides=NARROWPEAK_DTYPES,
184            null_values=[".", "NA", "null", "NaN"],
185        )
186        .select(
187            chr=pl.col("chr"),
188            peak_region_start=pl.col("peak_start") + pl.col("peak_summit") - half_width,
189            peak_name=pl.col("peak_name"),
190        )
191        .with_row_index(name="peak_id")
192        .collect()
193    )
194
195    if chrom_order_path is not None:
196        chrom_order = load_txt(chrom_order_path)
197    else:
198        chrom_order = []
199
200    chrom_order_set = set(chrom_order)
201    chrom_order_peaks = [
202        i
203        for i in peaks.get_column("chr").unique(maintain_order=True)
204        if i not in chrom_order_set
205    ]
206    chrom_order.extend(chrom_order_peaks)
207    chrom_ind_map = {val: ind for ind, val in enumerate(chrom_order)}
208
209    peaks = peaks.with_columns(
210        pl.col("chr").replace_strict(chrom_ind_map).alias("chr_id")
211    )
212
213    return peaks

Load peak region data from ENCODE NarrowPeak format file.

Parameters
  • peaks_path (str): Path to the NarrowPeak format file.
  • chrom_order_path (str, optional): Path to file defining chromosome ordering. If None, uses order from peaks file.
  • half_width (int): Half-width of regions around peak summits.
Returns
  • pl.DataFrame: DataFrame containing peak information with columns:
    • chr: Chromosome name
    • peak_region_start: Start coordinate of centered region
    • peak_name: Peak identifier
    • peak_id: Sequential peak index
    • chr_id: Numeric chromosome identifier
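
A minimal usage sketch; the file path and half-width are hypothetical:

    from finemo.data_io import load_peaks

    peaks_df = load_peaks("peaks.narrowPeak", chrom_order_path=None, half_width=500)
    peaks_df.select("chr", "peak_region_start", "peak_name", "peak_id", "chr_id")
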
SEQ_ALPHABET: numpy.ndarray = array([b'A', b'C', b'G', b'T'], dtype='|S1')
def one_hot_encode( sequence: str, dtype: Any = np.int8) -> jaxtyping.Int[ndarray, '4 L']:
220def one_hot_encode(sequence: str, dtype: Any = np.int8) -> Int[ndarray, "4 L"]:
221    """Convert DNA sequence string to one-hot encoded matrix.
222
223    Parameters
224    ----------
225    sequence : str
226        DNA sequence string containing A, C, G, T characters.
227    dtype : np.dtype, default np.int8
228        Data type for the output array.
229
230    Returns
231    -------
232    Int[ndarray, "4 L"]
233        One-hot encoded sequence where rows correspond to A, C, G, T and
234        L is the sequence length.
235
236    Notes
237    -----
238    The output array has shape (4, len(sequence)) with rows corresponding to
239    nucleotides A, C, G, T in that order. Non-standard nucleotides (N, etc.)
240    result in all-zero columns.
241    """
242    sequence = sequence.upper()
243
244    seq_chararray = np.frombuffer(sequence.encode("UTF-8"), dtype="S1")
245    one_hot = (seq_chararray[None, :] == SEQ_ALPHABET[:, None]).astype(dtype)
246
247    return one_hot

Convert DNA sequence string to one-hot encoded matrix.

Parameters
  • sequence (str): DNA sequence string containing A, C, G, T characters.
  • dtype (np.dtype, default np.int8): Data type for the output array.
Returns
  • Int[ndarray, "4 L"]: One-hot encoded sequence where rows correspond to A, C, G, T and L is the sequence length.
Notes

The output array has shape (4, len(sequence)) with rows corresponding to nucleotides A, C, G, T in that order. Non-standard nucleotides (N, etc.) result in all-zero columns.

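A minimal usage sketch; for "ACGT" the result is the 4x4 identity matrix:

    from finemo.data_io import one_hot_encode

    one_hot_encode("ACGT")
    # array([[1, 0, 0, 0],   row A
    #        [0, 1, 0, 0],   row C
    #        [0, 0, 1, 0],   row G
    #        [0, 0, 0, 1]],  row T
    #       dtype=int8)
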
def load_regions_from_bw( peaks: polars.dataframe.frame.DataFrame, fa_path: str, bw_paths: List[str], half_width: int) -> Tuple[jaxtyping.Int[ndarray, 'N 4 L'], jaxtyping.Float[ndarray, 'N L']]:
250def load_regions_from_bw(
251    peaks: pl.DataFrame, fa_path: str, bw_paths: List[str], half_width: int
252) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N L"]]:
253    """Load genomic sequences and contribution scores from FASTA and bigWig files.
254
255    Parameters
256    ----------
257    peaks : pl.DataFrame
258        Peak regions DataFrame from load_peaks() containing columns:
259        'chr', 'peak_region_start'.
260    fa_path : str
261        Path to genome FASTA file (.fa or .fasta format).
262    bw_paths : List[str]
263        List of paths to bigWig files containing contribution scores.
264        Must be non-empty.
265    half_width : int
266        Half-width of regions to extract around peak centers.
267        Total region width will be 2 * half_width.
268
269    Returns
270    -------
271    sequences : Int[ndarray, "N 4 L"]
272        One-hot encoded DNA sequences where N is the number of peaks,
273        4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
274    contribs : Float[ndarray, "N L"]
275        Contribution scores averaged across input bigWig files.
276        Shape is (N peaks, L region_length).
277
278    Notes
279    -----
280    BigWig files only provide projected contribution scores, not hypothetical scores.
281    Regions extending beyond chromosome boundaries are zero-padded.
282    Missing values in bigWig files are converted to zero.
283    """
284    num_peaks = peaks.height
285    region_width = half_width * 2
286
287    sequences = np.zeros((num_peaks, 4, region_width), dtype=np.int8)
288    contribs = np.zeros((num_peaks, region_width), dtype=np.float16)
289
290    # Load genome reference
291    genome = pyfaidx.Fasta(fa_path, one_based_attributes=False)
292
293    bws = [pyBigWig.open(i) for i in bw_paths]
294    contrib_buffer = np.zeros((len(bw_paths), half_width * 2), dtype=np.float16)
295
296    try:
297        for ind, row in tqdm(
298            enumerate(peaks.iter_rows(named=True)),
299            disable=None,
300            unit="regions",
301            total=num_peaks,
302        ):
303            chrom = row["chr"]
304            start = row["peak_region_start"]
305            end = start + 2 * half_width
306
307            sequence_data: pyfaidx.FastaRecord = genome[chrom][start:end]  # type: ignore
308            sequence: str = sequence_data.seq  # type: ignore
309            start_adj: int = sequence_data.start  # type: ignore
310            end_adj: int = sequence_data.end  # type: ignore
311            a = start_adj - start
312            b = end_adj - start
313
314            if b > a:
315                sequences[ind, :, a:b] = one_hot_encode(sequence)
316
317                for j, bw in enumerate(bws):
318                    contrib_buffer[j, : b - a] = np.nan_to_num(
319                        bw.values(chrom, start_adj, end_adj, numpy=True)
320                    )
321
322                contribs[ind, a:b] = np.mean(contrib_buffer[:, : b - a], axis=0)
323
324    finally:
325        for bw in bws:
326            bw.close()
327
328    return sequences, contribs

Load genomic sequences and contribution scores from FASTA and bigWig files.

Parameters
  • peaks (pl.DataFrame): Peak regions DataFrame from load_peaks() containing columns: 'chr', 'peak_region_start'.
  • fa_path (str): Path to genome FASTA file (.fa or .fasta format).
  • bw_paths (List[str]): List of paths to bigWig files containing contribution scores. Must be non-empty.
  • half_width (int): Half-width of regions to extract around peak centers. Total region width will be 2 * half_width.
Returns
  • sequences (Int[ndarray, "N 4 L"]): One-hot encoded DNA sequences where N is the number of peaks, 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
  • contribs (Float[ndarray, "N L"]): Contribution scores averaged across input bigWig files. Shape is (N peaks, L region_length).
Notes

BigWig files only provide projected contribution scores, not hypothetical scores. Regions extending beyond chromosome boundaries are zero-padded. Missing values in bigWig files are converted to zero.

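A minimal usage sketch; the FASTA and bigWig paths are hypothetical, and half_width must match the value passed to load_peaks:

    from finemo.data_io import load_peaks, load_regions_from_bw

    peaks_df = load_peaks("peaks.narrowPeak", None, half_width=500)
    sequences, contribs = load_regions_from_bw(
        peaks_df, "genome.fa", ["fold0.bw", "fold1.bw"], half_width=500
    )
    # sequences: (N, 4, 1000) int8; contribs: (N, 1000) float16
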
def load_regions_from_chrombpnet_h5( h5_paths: List[str], half_width: int) -> Tuple[jaxtyping.Int[ndarray, 'N 4 L'], jaxtyping.Float[ndarray, 'N 4 L']]:
331def load_regions_from_chrombpnet_h5(
332    h5_paths: List[str], half_width: int
333) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]:
334    """Load genomic sequences and contribution scores from ChromBPNet HDF5 files.
335
336    Parameters
337    ----------
338    h5_paths : List[str]
339        List of paths to ChromBPNet HDF5 files containing sequences and SHAP scores.
340        Must be non-empty and contain compatible data shapes.
341    half_width : int
342        Half-width of regions to extract around the center.
343        Total region width will be 2 * half_width.
344
345    Returns
346    -------
347    sequences : Int[ndarray, "N 4 L"]
348        One-hot encoded DNA sequences where N is the number of regions,
349        4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
350    contribs : Float[ndarray, "N 4 L"]
351        SHAP contribution scores averaged across input files.
352        Shape is (N regions, 4 nucleotides, L region_length).
353
354    Notes
355    -----
356    ChromBPNet files store sequences in 'raw/seq' and SHAP scores in 'shap/seq'.
357    All input files must have the same dimensions and number of regions.
358    Missing values in contribution scores are converted to zero.
359    """
360    with ExitStack() as stack:
361        h5s = [stack.enter_context(h5py.File(i)) for i in h5_paths]
362
363        start = h5s[0]["raw/seq"].shape[-1] // 2 - half_width  # type: ignore  # HDF5 array access
364        end = start + 2 * half_width
365
366        sequences = h5s[0]["raw/seq"][:, :, start:end].astype(np.int8)  # type: ignore  # HDF5 array access
367        contribs = np.mean(
368            [np.nan_to_num(f["shap/seq"][:, :, start:end]) for f in h5s],  # type: ignore  # HDF5 array access
369            axis=0,
370            dtype=np.float16,
371        )
372
373    return sequences, contribs  # type: ignore  # HDF5 arrays converted to NumPy

Load genomic sequences and contribution scores from ChromBPNet HDF5 files.

Parameters
  • h5_paths (List[str]): List of paths to ChromBPNet HDF5 files containing sequences and SHAP scores. Must be non-empty and contain compatible data shapes.
  • half_width (int): Half-width of regions to extract around the center. Total region width will be 2 * half_width.
Returns
  • sequences (Int[ndarray, "N 4 L"]): One-hot encoded DNA sequences where N is the number of regions, 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
  • contribs (Float[ndarray, "N 4 L"]): SHAP contribution scores averaged across input files. Shape is (N regions, 4 nucleotides, L region_length).
Notes

ChromBPNet files store sequences in 'raw/seq' and SHAP scores in 'shap/seq'. All input files must have the same dimensions and number of regions. Missing values in contribution scores are converted to zero.

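A minimal usage sketch; the HDF5 paths are hypothetical:

    from finemo.data_io import load_regions_from_chrombpnet_h5

    sequences, contribs = load_regions_from_chrombpnet_h5(
        ["fold0.counts_shap.h5", "fold1.counts_shap.h5"], half_width=500
    )
    # sequences: (N, 4, 1000) int8; contribs: (N, 4, 1000) float16, averaged across files
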
def load_regions_from_bpnet_h5( h5_paths: List[str], half_width: int) -> Tuple[jaxtyping.Int[ndarray, 'N 4 L'], jaxtyping.Float[ndarray, 'N 4 L']]:
376def load_regions_from_bpnet_h5(
377    h5_paths: List[str], half_width: int
378) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]:
379    """Load genomic sequences and contribution scores from BPNet HDF5 files.
380
381    Parameters
382    ----------
383    h5_paths : List[str]
384        List of paths to BPNet HDF5 files containing sequences and contribution scores.
385        Must be non-empty and contain compatible data shapes.
386    half_width : int
387        Half-width of regions to extract around the center.
388        Total region width will be 2 * half_width.
389
390    Returns
391    -------
392    sequences : Int[ndarray, "N 4 L"]
393        One-hot encoded DNA sequences where N is the number of regions,
394        4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
395    contribs : Float[ndarray, "N 4 L"]
396        Hypothetical contribution scores averaged across input files.
397        Shape is (N regions, 4 nucleotides, L region_length).
398
399    Notes
400    -----
401    BPNet files store sequences in 'input_seqs' and hypothetical scores in 'hyp_scores'.
402    The data requires axis swapping to convert from (n, length, 4) to (n, 4, length) format.
403    All input files must have the same dimensions and number of regions.
404    Missing values in contribution scores are converted to zero.
405    """
406    with ExitStack() as stack:
407        h5s = [stack.enter_context(h5py.File(i)) for i in h5_paths]
408
409        start = h5s[0]["input_seqs"].shape[-2] // 2 - half_width  # type: ignore  # HDF5 array access
410        end = start + 2 * half_width
411
412        sequences = h5s[0]["input_seqs"][:, start:end, :].swapaxes(1, 2).astype(np.int8)  # type: ignore  # HDF5 array access with axis swap
413        contribs = np.mean(
414            [
415                np.nan_to_num(f["hyp_scores"][:, start:end, :].swapaxes(1, 2))  # type: ignore  # HDF5 array access
416                for f in h5s
417            ],
418            axis=0,
419            dtype=np.float16,
420        )
421
422    return sequences, contribs

Load genomic sequences and contribution scores from BPNet HDF5 files.

Parameters
  • h5_paths (List[str]): List of paths to BPNet HDF5 files containing sequences and contribution scores. Must be non-empty and contain compatible data shapes.
  • half_width (int): Half-width of regions to extract around the center. Total region width will be 2 * half_width.
Returns
  • sequences (Int[ndarray, "N 4 L"]): One-hot encoded DNA sequences where N is the number of regions, 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
  • contribs (Float[ndarray, "N 4 L"]): Hypothetical contribution scores averaged across input files. Shape is (N regions, 4 nucleotides, L region_length).
Notes

BPNet files store sequences in 'input_seqs' and hypothetical scores in 'hyp_scores'. The data requires axis swapping to convert from (n, length, 4) to (n, 4, length) format. All input files must have the same dimensions and number of regions. Missing values in contribution scores are converted to zero.

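Usage mirrors the ChromBPNet loader above; the path here is hypothetical:

    from finemo.data_io import load_regions_from_bpnet_h5

    sequences, contribs = load_regions_from_bpnet_h5(["shap.h5"], half_width=500)
    # both arrays come back as (N, 4, 2 * half_width) after the axis swap
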
def load_npy_or_npz(path: str) -> numpy.ndarray:
425def load_npy_or_npz(path: str) -> ndarray:
426    """Load array data from .npy or .npz file.
427
428    Parameters
429    ----------
430    path : str
431        Path to .npy or .npz file. File must exist and contain valid NumPy data.
432
433    Returns
434    -------
435    ndarray
436        Loaded array data. For .npz files, returns the first array ('arr_0').
437        For .npy files, returns the array directly.
438
439    Raises
440    ------
441    FileNotFoundError
442        If the specified file does not exist.
443    KeyError
444        If .npz file does not contain 'arr_0' key.
445    """
446    f = np.load(path)
447    if isinstance(f, np.ndarray):
448        arr = f
449    else:
450        arr = f["arr_0"]
451
452    return arr

Load array data from .npy or .npz file.

Parameters
  • path (str): Path to .npy or .npz file. File must exist and contain valid NumPy data.
Returns
  • ndarray: Loaded array data. For .npz files, returns the first array ('arr_0'). For .npy files, returns the array directly.
Raises
  • FileNotFoundError: If the specified file does not exist.
  • KeyError: If .npz file does not contain 'arr_0' key.
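
A minimal sketch exercising both branches; the paths are hypothetical:

    import numpy as np
    from finemo.data_io import load_npy_or_npz

    np.save("a.npy", np.arange(4))
    np.savez("b.npz", np.arange(4))  # stored under the default key 'arr_0'
    assert (load_npy_or_npz("a.npy") == load_npy_or_npz("b.npz")).all()
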
def load_regions_from_modisco_fmt( shaps_paths: List[str], ohe_path: str, half_width: int) -> Tuple[jaxtyping.Int[ndarray, 'N 4 L'], jaxtyping.Float[ndarray, 'N 4 L']]:
455def load_regions_from_modisco_fmt(
456    shaps_paths: List[str], ohe_path: str, half_width: int
457) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]:
458    """Load genomic sequences and contribution scores from TF-MoDISco format files.
459
460    Parameters
461    ----------
462    shaps_paths : List[str]
463        List of paths to .npy/.npz files containing SHAP/attribution scores.
464        Must be non-empty and all files must have compatible shapes.
465    ohe_path : str
466        Path to .npy/.npz file containing one-hot encoded sequences.
467        Must have shape (n_regions, 4, sequence_length).
468    half_width : int
469        Half-width of regions to extract around the center.
470        Total region width will be 2 * half_width.
471
472    Returns
473    -------
474    sequences : Int[ndarray, "N 4 L"]
475        One-hot encoded DNA sequences where N is the number of regions,
476        4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
477    contribs : Float[ndarray, "N 4 L"]
478        SHAP contribution scores averaged across input files.
479        Shape is (N regions, 4 nucleotides, L region_length).
480
481    Notes
482    -----
483    All SHAP files must have the same shape as the sequence file.
484    Missing values in contribution scores are converted to zero.
485    The center of the input sequences is used as the reference point for extraction.
486    """
487    sequences_raw = load_npy_or_npz(ohe_path)
488
489    start = sequences_raw.shape[-1] // 2 - half_width
490    end = start + 2 * half_width
491
492    sequences = sequences_raw[:, :, start:end].astype(np.int8)
493
494    shaps = [np.nan_to_num(load_npy_or_npz(p)[:, :, start:end]) for p in shaps_paths]
495    contribs = np.mean(shaps, axis=0, dtype=np.float16)
496
497    return sequences, contribs

Load genomic sequences and contribution scores from TF-MoDISco format files.

Parameters
  • shaps_paths (List[str]): List of paths to .npy/.npz files containing SHAP/attribution scores. Must be non-empty and all files must have compatible shapes.
  • ohe_path (str): Path to .npy/.npz file containing one-hot encoded sequences. Must have shape (n_regions, 4, sequence_length).
  • half_width (int): Half-width of regions to extract around the center. Total region width will be 2 * half_width.
Returns
  • sequences (Int[ndarray, "N 4 L"]): One-hot encoded DNA sequences where N is the number of regions, 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width).
  • contribs (Float[ndarray, "N 4 L"]): SHAP contribution scores averaged across input files. Shape is (N regions, 4 nucleotides, L region_length).
Notes

All SHAP files must have the same shape as the sequence file. Missing values in contribution scores are converted to zero. The center of the input sequences is used as the reference point for extraction.

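A minimal usage sketch; the paths are hypothetical:

    from finemo.data_io import load_regions_from_modisco_fmt

    sequences, contribs = load_regions_from_modisco_fmt(
        shaps_paths=["shap_fold0.npy", "shap_fold1.npy"],
        ohe_path="ohe.npz",
        half_width=500,
    )
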
def load_regions_npz( npz_path: str) -> Tuple[jaxtyping.Int[ndarray, 'N 4 L'], Union[jaxtyping.Float[ndarray, 'N 4 L'], jaxtyping.Float[ndarray, 'N L']], polars.dataframe.frame.DataFrame, bool]:
500def load_regions_npz(
501    npz_path: str,
502) -> Tuple[
503    Int[ndarray, "N 4 L"],
504    Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]],
505    pl.DataFrame,
506    bool,
507]:
508    """Load preprocessed genomic regions from NPZ file.
509
510    Parameters
511    ----------
512    npz_path : str
513        Path to NPZ file containing sequences, contributions, and optional coordinates.
514        Must contain 'sequences' and 'contributions' arrays at minimum.
515
516    Returns
517    -------
518    sequences : Int[ndarray, "N 4 L"]
519        One-hot encoded DNA sequences where N is the number of regions,
520        4 represents A,C,G,T nucleotides, and L is the region length.
521    contributions : Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]]
522        Contribution scores in either hypothetical format (N, 4, L) or
523        projected format (N, L). Shape depends on the input data format.
524    peaks_df : pl.DataFrame
525        DataFrame containing peak region information with columns:
526        'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'.
527    has_peaks : bool
528        Whether the file contains genomic coordinate information.
529        If False, placeholder coordinate data is used.
530
531    Notes
532    -----
533    If genomic coordinates are not present in the NPZ file, creates placeholder
534    coordinate data and issues a warning. The placeholder data uses 'NA' for
535    chromosome names and sequential indices for peak IDs.
536
537    Raises
538    ------
539    KeyError
540        If required arrays 'sequences' or 'contributions' are missing from the file.
541    """
542    data = np.load(npz_path)
543
544    if "chr" not in data.keys():
545        warnings.warn(
546            "No genome coordinates present in the input .npz file. Returning sequences and contributions only."
547        )
548        has_peaks = False
549        num_regions = data["sequences"].shape[0]
550        peak_data = {
551            "chr": np.array(["NA"] * num_regions, dtype="U"),
552            "chr_id": np.arange(num_regions, dtype=np.uint32),
553            "peak_region_start": np.zeros(num_regions, dtype=np.int32),
554            "peak_id": np.arange(num_regions, dtype=np.uint32),
555            "peak_name": np.array(["NA"] * num_regions, dtype="U"),
556        }
557
558    else:
559        has_peaks = True
560        peak_data = {
561            "chr": data["chr"],
562            "chr_id": data["chr_id"],
563            "peak_region_start": data["start"],
564            "peak_id": data["peak_id"],
565            "peak_name": data["peak_name"],
566        }
567
568    peaks_df = pl.DataFrame(peak_data)
569
570    return data["sequences"], data["contributions"], peaks_df, has_peaks

Load preprocessed genomic regions from NPZ file.

Parameters
  • npz_path (str): Path to NPZ file containing sequences, contributions, and optional coordinates. Must contain 'sequences' and 'contributions' arrays at minimum.
Returns
  • sequences (Int[ndarray, "N 4 L"]): One-hot encoded DNA sequences where N is the number of regions, 4 represents A,C,G,T nucleotides, and L is the region length.
  • contributions (Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]]): Contribution scores in either hypothetical format (N, 4, L) or projected format (N, L). Shape depends on the input data format.
  • peaks_df (pl.DataFrame): DataFrame containing peak region information with columns: 'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'.
  • has_peaks (bool): Whether the file contains genomic coordinate information. If False, placeholder coordinate data is used.
Notes

If genomic coordinates are not present in the NPZ file, creates placeholder coordinate data and issues a warning. The placeholder data uses 'NA' for chromosome names and sequential indices for peak IDs.

Raises
  • KeyError: If required arrays 'sequences' or 'contributions' are missing from the file.
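
A minimal usage sketch; the path is hypothetical:

    from finemo.data_io import load_regions_npz

    sequences, contributions, peaks_df, has_peaks = load_regions_npz("regions.npz")
    if not has_peaks:
        # peaks_df holds 'NA' placeholders; coordinates were absent from the file
        pass
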
def write_regions_npz( sequences: jaxtyping.Int[ndarray, 'N 4 L'], contributions: Union[jaxtyping.Float[ndarray, 'N 4 L'], jaxtyping.Float[ndarray, 'N L']], out_path: str, peaks_df: Optional[polars.dataframe.frame.DataFrame] = None) -> None:
573def write_regions_npz(
574    sequences: Int[ndarray, "N 4 L"],
575    contributions: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]],
576    out_path: str,
577    peaks_df: Optional[pl.DataFrame] = None,
578) -> None:
579    """Write genomic regions and contribution scores to compressed NPZ file.
580
581    Parameters
582    ----------
583    sequences : Int[ndarray, "N 4 L"]
584        One-hot encoded DNA sequences where N is the number of regions,
585        4 represents A,C,G,T nucleotides, and L is the region length.
586    contributions : Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]]
587        Contribution scores in either hypothetical format (N, 4, L) or
588        projected format (N, L).
589    out_path : str
590        Output path for the NPZ file. Parent directory must exist.
591    peaks_df : Optional[pl.DataFrame]
592        DataFrame containing peak region information with columns:
593        'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'.
594        If None, only sequences and contributions are saved.
595
596    Raises
597    ------
598    ValueError
599        If the number of regions in sequences/contributions doesn't match peaks_df.
600    FileNotFoundError
601        If the parent directory of out_path does not exist.
602
603    Notes
604    -----
605    The output file is compressed using NumPy's savez_compressed format.
606    If peaks_df is provided, genomic coordinate information is included
607    in the output file for downstream analysis.
608    """
609    if peaks_df is None:
610        warnings.warn(
611            "No genome coordinates provided. Writing sequences and contributions only."
612        )
613        np.savez_compressed(out_path, sequences=sequences, contributions=contributions)
614
615    else:
616        num_regions = peaks_df.height
617        if (num_regions != sequences.shape[0]) or (
618            num_regions != contributions.shape[0]
619        ):
620            raise ValueError(
621                f"Input sequences of shape {sequences.shape} and/or "
622                f"input contributions of shape {contributions.shape} "
623                f"are not compatible with peak region count of {num_regions}"
624            )
625
626        chr_arr = peaks_df.get_column("chr").to_numpy().astype("U")
627        chr_id_arr = peaks_df.get_column("chr_id").to_numpy()
628        start_arr = peaks_df.get_column("peak_region_start").to_numpy()
629        peak_id_arr = peaks_df.get_column("peak_id").to_numpy()
630        peak_name_arr = peaks_df.get_column("peak_name").to_numpy().astype("U")
631        np.savez_compressed(
632            out_path,
633            sequences=sequences,
634            contributions=contributions,
635            chr=chr_arr,
636            chr_id=chr_id_arr,
637            start=start_arr,
638            peak_id=peak_id_arr,
639            peak_name=peak_name_arr,
640        )

Write genomic regions and contribution scores to compressed NPZ file.

Parameters
  • sequences (Int[ndarray, "N 4 L"]): One-hot encoded DNA sequences where N is the number of regions, 4 represents A,C,G,T nucleotides, and L is the region length.
  • contributions (Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]]): Contribution scores in either hypothetical format (N, 4, L) or projected format (N, L).
  • out_path (str): Output path for the NPZ file. Parent directory must exist.
  • peaks_df (Optional[pl.DataFrame]): DataFrame containing peak region information with columns: 'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'. If None, only sequences and contributions are saved.
Raises
  • ValueError: If the number of regions in sequences/contributions doesn't match peaks_df.
  • FileNotFoundError: If the parent directory of out_path does not exist.
Notes

The output file is compressed using NumPy's savez_compressed format. If peaks_df is provided, genomic coordinate information is included in the output file for downstream analysis.

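A self-contained round-trip sketch with toy data (two regions of width 10; all values hypothetical):

    import numpy as np
    import polars as pl
    from finemo.data_io import write_regions_npz, load_regions_npz

    sequences = np.zeros((2, 4, 10), dtype=np.int8)
    contributions = np.zeros((2, 10), dtype=np.float16)
    peaks_df = pl.DataFrame({
        "chr": ["chr1", "chr1"],
        "chr_id": np.array([0, 0], dtype=np.uint32),
        "peak_region_start": np.array([100, 500], dtype=np.int32),
        "peak_id": np.array([0, 1], dtype=np.uint32),
        "peak_name": ["peak_a", "peak_b"],
    })
    write_regions_npz(sequences, contributions, "regions.npz", peaks_df=peaks_df)
    _, _, peaks_df2, has_peaks = load_regions_npz("regions.npz")
    assert has_peaks
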
def trim_motif( cwm: jaxtyping.Float[ndarray, '4 W'], trim_threshold: float) -> Tuple[int, int]:
643def trim_motif(cwm: Float[ndarray, "4 W"], trim_threshold: float) -> Tuple[int, int]:
644    """Determine trimmed start and end positions for a motif based on contribution magnitude.
645
646    This function identifies the core region of a motif by finding positions where
647    the total absolute contribution exceeds a threshold relative to the maximum.
648
649    Parameters
650    ----------
651    cwm : Float[ndarray, "4 W"]
652        Contribution weight matrix for the motif where 4 represents A,C,G,T
653        nucleotides and W is the motif width.
654    trim_threshold : float
655        Fraction of maximum score to use as trimming threshold (0.0 to 1.0).
656        Higher values result in more aggressive trimming.
657
658    Returns
659    -------
660    start : int
661        Start position of the trimmed motif (inclusive).
662    end : int
663        End position of the trimmed motif (exclusive).
664
665    Notes
666    -----
667    The trimming is based on the sum of absolute contributions across all nucleotides
668    at each position. Positions with contributions below trim_threshold * max_score
669    are removed from the motif edges.
670
671    Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L213-L236
672    """
673    score = np.sum(np.abs(cwm), axis=0)
674    trim_thresh = np.max(score) * trim_threshold
675    pass_inds = np.nonzero(score >= trim_thresh)
676    start = max(int(np.min(pass_inds)), 0)  # type: ignore  # nonzero returns tuple of arrays
677    end = min(int(np.max(pass_inds)) + 1, len(score))  # type: ignore  # nonzero returns tuple of arrays
678
679    return start, end

Determine trimmed start and end positions for a motif based on contribution magnitude.

This function identifies the core region of a motif by finding positions where the total absolute contribution exceeds a threshold relative to the maximum.

Parameters
  • cwm (Float[ndarray, "4 W"]): Contribution weight matrix for the motif where 4 represents A,C,G,T nucleotides and W is the motif width.
  • trim_threshold (float): Fraction of maximum score to use as trimming threshold (0.0 to 1.0). Higher values result in more aggressive trimming.
Returns
  • start (int): Start position of the trimmed motif (inclusive).
  • end (int): End position of the trimmed motif (exclusive).
Notes

The trimming is based on the sum of absolute contributions across all nucleotides at each position. Positions with contributions below trim_threshold * max_score are removed from the motif edges.

Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L213-L236

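A worked toy example (hypothetical contribution values):

    import numpy as np
    from finemo.data_io import trim_motif

    cwm = np.zeros((4, 10))
    cwm[0, 3:7] = [0.2, 0.9, 1.0, 0.4]   # signal concentrated near the center
    trim_motif(cwm, trim_threshold=0.3)  # (4, 7): position 3 scores 0.2 < 0.3 * 1.0
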
def softmax( x: jaxtyping.Float[ndarray, '4 W'], temp: float = 100) -> jaxtyping.Float[ndarray, '4 W']:
682def softmax(x: Float[ndarray, "4 W"], temp: float = 100) -> Float[ndarray, "4 W"]:
683    """Apply softmax transformation with temperature scaling.
684
685    Parameters
686    ----------
687    x : Float[ndarray, "4 W"]
688        Input array to transform where 4 represents A,C,G,T nucleotides
689        and W is the motif width.
690    temp : float, default 100
691        Temperature parameter for softmax scaling. Higher values create
692        sharper probability distributions.
693
694    Returns
695    -------
696    Float[ndarray, "4 W"]
697        Softmax-transformed array with same shape as input. Each column
698        sums to 1.0, representing nucleotide probabilities at each position.
699
700    Notes
701    -----
702    The softmax is applied along the nucleotide axis (axis=0), normalizing
703    each position to have probabilities that sum to 1. The temperature
704    parameter controls the sharpness of the distribution.
705    """
706    norm_x = x - np.mean(x, axis=1, keepdims=True)
707    exp = np.exp(temp * norm_x)
708    return exp / np.sum(exp, axis=0, keepdims=True)

Apply softmax transformation with temperature scaling.

Parameters
  • x (Float[ndarray, "4 W"]): Input array to transform where 4 represents A,C,G,T nucleotides and W is the motif width.
  • temp (float, default 100): Temperature parameter for softmax scaling. Higher values create sharper probability distributions.
Returns
  • Float[ndarray, "4 W"]: Softmax-transformed array with same shape as input. Each column sums to 1.0, representing nucleotide probabilities at each position.
Notes

The softmax is applied along the nucleotide axis (axis=0), normalizing each position to have probabilities that sum to 1. The temperature parameter controls the sharpness of the distribution.

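A small sketch with hypothetical values; each output column sums to 1, and the default temp=100 pushes columns toward one-hot:

    import numpy as np
    from finemo.data_io import softmax

    x = np.array([
        [0.90, 0.10],
        [0.05, 0.60],
        [0.03, 0.20],
        [0.02, 0.10],
    ])  # (4 nucleotides, 2 positions)
    probs = softmax(x)
    probs.sum(axis=0)  # array([1., 1.])
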
MODISCO_PATTERN_GROUPS = ['pos_patterns', 'neg_patterns']
def load_modisco_motifs( modisco_h5_path: str, trim_coords: Optional[Dict[str, Tuple[int, int]]], trim_thresholds: Optional[Dict[str, float]], trim_threshold_default: float, motif_type: str, motifs_include: Optional[List[str]], motif_name_map: Optional[Dict[str, str]], motif_lambdas: Optional[Dict[str, float]], motif_lambda_default: float, include_rc: bool) -> Tuple[polars.dataframe.frame.DataFrame, jaxtyping.Float[ndarray, 'M 4 W'], jaxtyping.Int[ndarray, 'M W'], numpy.ndarray]:
747def load_modisco_motifs(
748    modisco_h5_path: str,
749    trim_coords: Optional[Dict[str, Tuple[int, int]]],
750    trim_thresholds: Optional[Dict[str, float]],
751    trim_threshold_default: float,
752    motif_type: str,
753    motifs_include: Optional[List[str]],
754    motif_name_map: Optional[Dict[str, str]],
755    motif_lambdas: Optional[Dict[str, float]],
756    motif_lambda_default: float,
757    include_rc: bool,
758) -> Tuple[pl.DataFrame, Float[ndarray, "M 4 W"], Int[ndarray, "M W"], ndarray]:
759    """Load motif data from TF-MoDISco HDF5 file with customizable processing options.
760
761    This function extracts contribution weight matrices and associated metadata from
762    TF-MoDISco results, with support for custom naming, trimming, and regularization
763    parameters.
764
765    Parameters
766    ----------
767    modisco_h5_path : str
768        Path to TF-MoDISco HDF5 results file containing pattern groups.
769    trim_coords : Optional[Dict[str, Tuple[int, int]]]
770        Manual trim coordinates for specific motifs {motif_name: (start, end)}.
771        Takes precedence over automatic trimming based on thresholds.
772    trim_thresholds : Optional[Dict[str, float]]
773        Custom trim thresholds for specific motifs {motif_name: threshold}.
774        Values should be between 0.0 and 1.0.
775    trim_threshold_default : float
776        Default trim threshold for motifs not in trim_thresholds.
777        Fraction of maximum contribution used for trimming.
778    motif_type : str
779        Type of motif to extract. Must be one of:
780        - 'cwm': Contribution weight matrix (normalized)
781        - 'hcwm': Hypothetical contribution weight matrix
782        - 'pfm': Position frequency matrix
783        - 'pfm_softmax': Softmax-transformed position frequency matrix
784    motifs_include : Optional[List[str]]
785        List of motif names to include. If None, includes all motifs found.
786        Names should follow format 'pos_patterns.pattern_N' or 'neg_patterns.pattern_N'.
787    motif_name_map : Optional[Dict[str, str]]
788        Mapping from original to custom motif names {orig_name: new_name}.
789        New names must be unique across all motifs.
790    motif_lambdas : Optional[Dict[str, float]]
791        Custom lambda regularization values for specific motifs {motif_name: lambda}.
792        Higher values increase sparsity penalty for the corresponding motif.
793    motif_lambda_default : float
794        Default lambda value for motifs not specified in motif_lambdas.
795    include_rc : bool
796        Whether to include reverse complement motifs in addition to forward motifs.
797        If True, doubles the number of motifs returned.
798
799    Returns
800    -------
801    motifs_df : pl.DataFrame
802        DataFrame containing motif metadata with columns: motif_id, motif_name,
803        motif_name_orig, strand, motif_start, motif_end, motif_scale, lambda.
804    cwms : Float[ndarray, "M 4 W"]
805        Contribution weight matrices for all motifs where M is the number of motifs,
806        4 represents A,C,G,T nucleotides, and W is the motif width.
807    trim_masks : Int[ndarray, "M W"]
808        Binary masks indicating core motif regions (1) vs trimmed regions (0).
809        Shape is (M motifs, W motif_width).
810    names : ndarray
811        Array of unique motif names (forward strand only).
812
813    Raises
814    ------
815    ValueError
816        If motif_type is not one of the supported types, or if motif names
817        in motif_name_map are not unique.
818    FileNotFoundError
819        If the specified HDF5 file does not exist.
820    KeyError
821        If required datasets are missing from the HDF5 file.
822
823    Notes
824    -----
825    Motif trimming removes low-contribution positions from the edges based on
826    the position-wise sum of absolute contributions across nucleotides. The trimming
827    helps focus on the core binding site.
828
829    Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L252-L272
830    """
831    motif_data_lsts = {
832        "motif_name": [],
833        "motif_name_orig": [],
834        "strand": [],
835        "motif_start": [],
836        "motif_end": [],
837        "motif_scale": [],
838        "lambda": [],
839    }
840    motif_lst = []
841    trim_mask_lst = []
842
843    if motifs_include is not None:
844        motifs_include_set = set(motifs_include)
845    else:
846        motifs_include_set = None
847
848    if motif_name_map is None:
849        motif_name_map = {}
850
851    if motif_lambdas is None:
852        motif_lambdas = {}
853
854    if trim_coords is None:
855        trim_coords = {}
856    if trim_thresholds is None:
857        trim_thresholds = {}
858
859    if len(motif_name_map.values()) != len(set(motif_name_map.values())):
860        raise ValueError("Specified motif names are not unique")
861
862    with h5py.File(modisco_h5_path, "r") as modisco_results:
863        for name in MODISCO_PATTERN_GROUPS:
864            if name not in modisco_results.keys():
865                continue
866
867            metacluster = modisco_results[name]
868            for _, (pattern_name, pattern) in enumerate(
869                sorted(metacluster.items(), key=_motif_name_sort_key)  # type: ignore  # HDF5 access
870            ):
871                pattern_tag = f"{name}.{pattern_name}"
872
873                if (
874                    motifs_include_set is not None
875                    and pattern_tag not in motifs_include_set
876                ):
877                    continue
878
879                motif_lambda = motif_lambdas.get(pattern_tag, motif_lambda_default)
880                pattern_tag_orig = pattern_tag
881                pattern_tag = motif_name_map.get(pattern_tag, pattern_tag)
882
883                cwm_raw = pattern["contrib_scores"][:].T  # type: ignore
884                cwm_norm = np.sqrt((cwm_raw**2).sum())
885
886                cwm_fwd = cwm_raw / cwm_norm
887                cwm_rev = cwm_fwd[::-1, ::-1]
888
889                if pattern_tag in trim_coords:
890                    start_fwd, end_fwd = trim_coords[pattern_tag]
891                else:
892                    trim_threshold = trim_thresholds.get(
893                        pattern_tag, trim_threshold_default
894                    )
895                    start_fwd, end_fwd = trim_motif(cwm_fwd, trim_threshold)
896
897                cwm_len = cwm_fwd.shape[1]
898                start_rev, end_rev = cwm_len - end_fwd, cwm_len - start_fwd
899
900                trim_mask_fwd = np.zeros(cwm_fwd.shape[1], dtype=np.int8)
901                trim_mask_fwd[start_fwd:end_fwd] = 1
902                trim_mask_rev = np.zeros(cwm_rev.shape[1], dtype=np.int8)
903                trim_mask_rev[start_rev:end_rev] = 1
904
905                if motif_type == "cwm":
906                    motif_fwd = cwm_fwd
907                    motif_rev = cwm_rev
908                    motif_norm = cwm_norm
909
910                elif motif_type == "hcwm":
911                    motif_raw = pattern["hypothetical_contribs"][:].T  # type: ignore
912                    motif_norm = np.sqrt((motif_raw**2).sum())
913
914                    motif_fwd = motif_raw / motif_norm
915                    motif_rev = motif_fwd[::-1, ::-1]
916
917                elif motif_type == "pfm":
918                    motif_raw = pattern["sequence"][:].T  # type: ignore
919                    motif_norm = 1
920
921                    motif_fwd = motif_raw / np.sum(motif_raw, axis=0, keepdims=True)
922                    motif_rev = motif_fwd[::-1, ::-1]
923
924                elif motif_type == "pfm_softmax":
925                    motif_raw = pattern["sequence"][:].T  # type: ignore
926                    motif_norm = 1
927
928                    motif_fwd = softmax(motif_raw)
929                    motif_rev = motif_fwd[::-1, ::-1]
930
931                else:
932                    raise ValueError(
933                        f"Invalid motif_type: {motif_type}. Must be one of 'cwm', 'hcwm', 'pfm', 'pfm_softmax'."
934                    )
935
936                motif_data_lsts["motif_name"].append(pattern_tag)
937                motif_data_lsts["motif_name_orig"].append(pattern_tag_orig)
938                motif_data_lsts["strand"].append("+")
939                motif_data_lsts["motif_start"].append(start_fwd)
940                motif_data_lsts["motif_end"].append(end_fwd)
941                motif_data_lsts["motif_scale"].append(motif_norm)
942                motif_data_lsts["lambda"].append(motif_lambda)
943
944                if include_rc:
945                    motif_data_lsts["motif_name"].append(pattern_tag)
946                    motif_data_lsts["motif_name_orig"].append(pattern_tag_orig)
947                    motif_data_lsts["strand"].append("-")
948                    motif_data_lsts["motif_start"].append(start_rev)
949                    motif_data_lsts["motif_end"].append(end_rev)
950                    motif_data_lsts["motif_scale"].append(motif_norm)
951                    motif_data_lsts["lambda"].append(motif_lambda)
952
953                    motif_lst.extend([motif_fwd, motif_rev])
954                    trim_mask_lst.extend([trim_mask_fwd, trim_mask_rev])
955
956                else:
957                    motif_lst.append(motif_fwd)
958                    trim_mask_lst.append(trim_mask_fwd)
959
960    motifs_df = pl.DataFrame(motif_data_lsts).with_row_index(name="motif_id")
961    cwms = np.stack(motif_lst, dtype=np.float16, axis=0)
962    trim_masks = np.stack(trim_mask_lst, dtype=np.int8, axis=0)
963    names = (
964        motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy()
965    )
966
967    return motifs_df, cwms, trim_masks, names

Load motif data from TF-MoDISco HDF5 file with customizable processing options.

This function extracts contribution weight matrices and associated metadata from TF-MoDISco results, with support for custom naming, trimming, and regularization parameters.

Parameters
  • modisco_h5_path (str): Path to TF-MoDISco HDF5 results file containing pattern groups.
  • trim_coords (Optional[Dict[str, Tuple[int, int]]]): Manual trim coordinates for specific motifs {motif_name: (start, end)}. Takes precedence over automatic trimming based on thresholds.
  • trim_thresholds (Optional[Dict[str, float]]): Custom trim thresholds for specific motifs {motif_name: threshold}. Values should be between 0.0 and 1.0.
  • trim_threshold_default (float): Default trim threshold for motifs not in trim_thresholds. Fraction of maximum contribution used for trimming.
  • motif_type (str): Type of motif to extract. Must be one of:
    • 'cwm': Contribution weight matrix (normalized)
    • 'hcwm': Hypothetical contribution weight matrix
    • 'pfm': Position frequency matrix
    • 'pfm_softmax': Softmax-transformed position frequency matrix
  • motifs_include (Optional[List[str]]): List of motif names to include. If None, includes all motifs found. Names should follow format 'pos_patterns.pattern_N' or 'neg_patterns.pattern_N'.
  • motif_name_map (Optional[Dict[str, str]]): Mapping from original to custom motif names {orig_name: new_name}. New names must be unique across all motifs.
  • motif_lambdas (Optional[Dict[str, float]]): Custom lambda regularization values for specific motifs {motif_name: lambda}. Higher values increase sparsity penalty for the corresponding motif.
  • motif_lambda_default (float): Default lambda value for motifs not specified in motif_lambdas.
  • include_rc (bool): Whether to include reverse complement motifs in addition to forward motifs. If True, doubles the number of motifs returned.
Returns
  • motifs_df (pl.DataFrame): DataFrame containing motif metadata with columns: motif_id, motif_name, motif_name_orig, strand, motif_start, motif_end, motif_scale, lambda.
  • cwms (Float[ndarray, "M 4 W"]): Contribution weight matrices for all motifs where M is the number of motifs, 4 represents A,C,G,T nucleotides, and W is the motif width.
  • trim_masks (Int[ndarray, "M W"]): Binary masks indicating core motif regions (1) vs trimmed regions (0). Shape is (M motifs, W motif_width).
  • names (ndarray): Array of unique motif names (forward strand only).
Raises
  • ValueError: If motif_type is not one of the supported types, or if motif names in motif_name_map are not unique.
  • FileNotFoundError: If the specified HDF5 file does not exist.
  • KeyError: If required datasets are missing from the HDF5 file.
Notes

Motif trimming removes low-contribution positions from the edges based on the position-wise sum of absolute contributions across nucleotides. The trimming helps focus on the core binding site.

Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L252-L272

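A minimal usage sketch; the HDF5 path, trim threshold, and lambda value are hypothetical:

    from finemo.data_io import load_modisco_motifs

    motifs_df, cwms, trim_masks, names = load_modisco_motifs(
        "modisco_results.h5",
        trim_coords=None,
        trim_thresholds=None,
        trim_threshold_default=0.3,
        motif_type="cwm",
        motifs_include=None,
        motif_name_map=None,
        motif_lambdas=None,
        motif_lambda_default=0.7,
        include_rc=True,
    )
    # cwms.shape == (M, 4, W); with include_rc=True, M counts both strands
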
def load_modisco_seqlets( modisco_h5_path: str, peaks_df: polars.dataframe.frame.DataFrame, motifs_df: polars.dataframe.frame.DataFrame, half_width: int, modisco_half_width: int, lazy: bool = False) -> Union[polars.dataframe.frame.DataFrame, polars.lazyframe.frame.LazyFrame]:
 970def load_modisco_seqlets(
 971    modisco_h5_path: str,
 972    peaks_df: pl.DataFrame,
 973    motifs_df: pl.DataFrame,
 974    half_width: int,
 975    modisco_half_width: int,
 976    lazy: bool = False,
 977) -> Union[pl.DataFrame, pl.LazyFrame]:
 978    """Load seqlet data from TF-MoDISco HDF5 file and convert to genomic coordinates.
 979
 980    This function extracts seqlet instances from TF-MoDISco results and converts
 981    their relative positions to absolute genomic coordinates using peak region
 982    information.
 983
 984    Parameters
 985    ----------
 986    modisco_h5_path : str
 987        Path to TF-MoDISco HDF5 results file containing seqlet data.
 988    peaks_df : pl.DataFrame
 989        DataFrame containing peak region information with columns:
 990        'peak_id', 'chr', 'chr_id', 'peak_region_start'.
 991    motifs_df : pl.DataFrame
 992        DataFrame containing motif metadata with columns:
 993        'motif_name_orig', 'strand', 'motif_name', 'motif_start', 'motif_end'.
 994    half_width : int
 995        Half-width of the current analysis regions.
 996    modisco_half_width : int
 997        Half-width of the regions used in the original TF-MoDISco analysis.
 998        Used to calculate coordinate offsets.
 999    lazy : bool, default False
1000        If True, returns a LazyFrame for efficient chaining of operations.
1001        If False, collects the result into a DataFrame.
1002
1003    Returns
1004    -------
1005    Union[pl.DataFrame, pl.LazyFrame]
1006        Seqlets with genomic coordinates containing columns:
1007        - chr: Chromosome name
1008        - chr_id: Numeric chromosome identifier
1009        - start: Start coordinate of trimmed motif instance
1010        - end: End coordinate of trimmed motif instance
1011        - start_untrimmed: Start coordinate of full motif instance
1012        - end_untrimmed: End coordinate of full motif instance
1013        - is_revcomp: Whether the motif is reverse complemented
1014        - strand: Motif strand ('+' or '-')
1015        - motif_name: Motif name (may be remapped)
1016        - peak_id: Peak identifier
1017        - peak_region_start: Peak region start coordinate
1018
1019    Notes
1020    -----
1021    Seqlets are deduplicated based on chromosome ID, start position (untrimmed),
1022    motif name, and reverse complement status to avoid redundant instances.
1023
1024    The coordinate transformation accounts for differences in region sizes
1025    between the original TF-MoDISco analysis and the current analysis.
1026    """
1027
1028    start_lst = []
1029    end_lst = []
1030    is_revcomp_lst = []
1031    strand_lst = []
1032    peak_id_lst = []
1033    pattern_tags = []
1034
1035    with h5py.File(modisco_h5_path, "r") as modisco_results:
1036        for name in MODISCO_PATTERN_GROUPS:
1037            if name not in modisco_results.keys():
1038                continue
1039
1040            metacluster = modisco_results[name]
1041
1042            key = _motif_name_sort_key
1043            for _, (pattern_name, pattern) in enumerate(
1044                sorted(metacluster.items(), key=key)  # type: ignore  # HDF5 access
1045            ):
1046                pattern_tag = f"{name}.{pattern_name}"
1047
1048                starts = pattern["seqlets/start"][:].astype(np.int32)  # type: ignore
1049                ends = pattern["seqlets/end"][:].astype(np.int32)  # type: ignore
1050                is_revcomps = pattern["seqlets/is_revcomp"][:].astype(bool)  # type: ignore
1051                strands = ["+" if not i else "-" for i in is_revcomps]
1052                peak_ids = pattern["seqlets/example_idx"][:].astype(np.uint32)  # type: ignore
1053
1054                n_seqlets = int(pattern["seqlets/n_seqlets"][0])  # type: ignore
1055
1056                start_lst.append(starts)
1057                end_lst.append(ends)
1058                is_revcomp_lst.append(is_revcomps)
1059                strand_lst.extend(strands)
1060                peak_id_lst.append(peak_ids)
1061                pattern_tags.extend([pattern_tag for _ in range(n_seqlets)])
1062
1063    df_data = {
1064        "seqlet_start": np.concatenate(start_lst),
1065        "seqlet_end": np.concatenate(end_lst),
1066        "is_revcomp": np.concatenate(is_revcomp_lst),
1067        "strand": strand_lst,
1068        "peak_id": np.concatenate(peak_id_lst),
1069        "motif_name_orig": pattern_tags,
1070    }
1071
1072    offset = half_width - modisco_half_width
1073
1074    seqlets_df = (
1075        pl.LazyFrame(df_data)
1076        .join(motifs_df.lazy(), on=("motif_name_orig", "strand"), how="inner")
1077        .join(peaks_df.lazy(), on="peak_id", how="inner")
1078        .select(
1079            chr=pl.col("chr"),
1080            chr_id=pl.col("chr_id"),
1081            start=pl.col("peak_region_start")
1082            + pl.col("seqlet_start")
1083            + pl.col("motif_start")
1084            + offset,
1085            end=pl.col("peak_region_start")
1086            + pl.col("seqlet_start")
1087            + pl.col("motif_end")
1088            + offset,
1089            start_untrimmed=pl.col("peak_region_start")
1090            + pl.col("seqlet_start")
1091            + offset,
1092            end_untrimmed=pl.col("peak_region_start") + pl.col("seqlet_end") + offset,
1093            is_revcomp=pl.col("is_revcomp"),
1094            strand=pl.col("strand"),
1095            motif_name=pl.col("motif_name"),
1096            peak_id=pl.col("peak_id"),
1097            peak_region_start=pl.col("peak_region_start"),
1098        )
1099        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
1100    )
1101
1102    seqlets_df = seqlets_df if lazy else seqlets_df.collect()
1103
1104    return seqlets_df

Load seqlet data from a TF-MoDISco HDF5 file and convert it to genomic coordinates.

This function extracts seqlet instances from TF-MoDISco results and converts their relative positions to absolute genomic coordinates using peak region information.

Parameters
  • modisco_h5_path (str): Path to TF-MoDISco HDF5 results file containing seqlet data.
  • peaks_df (pl.DataFrame): DataFrame containing peak region information with columns: 'peak_id', 'chr', 'chr_id', 'peak_region_start'.
  • motifs_df (pl.DataFrame): DataFrame containing motif metadata with columns: 'motif_name_orig', 'strand', 'motif_name', 'motif_start', 'motif_end'.
  • half_width (int): Half-width of the current analysis regions.
  • modisco_half_width (int): Half-width of the regions used in the original TF-MoDISco analysis. Used to calculate coordinate offsets.
  • lazy (bool, default False): If True, returns a LazyFrame for efficient chaining of operations. If False, collects the result into a DataFrame.
Returns
  • Union[pl.DataFrame, pl.LazyFrame]: Seqlets with genomic coordinates containing columns:
    • chr: Chromosome name
    • chr_id: Numeric chromosome identifier
    • start: Start coordinate of trimmed motif instance
    • end: End coordinate of trimmed motif instance
    • start_untrimmed: Start coordinate of full motif instance
    • end_untrimmed: End coordinate of full motif instance
    • is_revcomp: Whether the motif is reverse complemented
    • strand: Motif strand ('+' or '-')
    • motif_name: Motif name (may be remapped)
    • peak_id: Peak identifier
    • peak_region_start: Peak region start coordinate
Notes

Seqlets are deduplicated based on chromosome ID, start position (untrimmed), motif name, and reverse complement status to avoid redundant instances.

The coordinate transformation accounts for differences in region sizes between the original TF-MoDISco analysis and the current analysis.
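
For illustration, a minimal usage sketch, assuming the module is importable as finemo.data_io; the HDF5 path is hypothetical, and the toy peaks_df/motifs_df stand in for tables built earlier in the pipeline (dtypes here are illustrative, except peak_id, which the loader reads as uint32):

    import polars as pl

    from finemo import data_io

    # Toy stand-ins with the documented columns.
    peaks_df = pl.DataFrame({
        "peak_id": pl.Series([0], dtype=pl.UInt32),
        "chr": ["chr1"],
        "chr_id": [0],
        "peak_region_start": [10_000],
    })
    motifs_df = pl.DataFrame({
        "motif_name_orig": ["pos_patterns.pattern_0"],
        "strand": ["+"],
        "motif_name": ["pattern_0"],
        "motif_start": [3],
        "motif_end": [12],
    })

    seqlets_df = data_io.load_modisco_seqlets(
        modisco_h5_path="modisco_results.h5",  # hypothetical path
        peaks_df=peaks_df,
        motifs_df=motifs_df,
        half_width=500,          # current analysis: 1000 bp regions
        modisco_half_width=200,  # TF-MoDISco ran on 400 bp regions
    )
    # Seqlet coordinates are shifted by half_width - modisco_half_width = 300
    # to move them from the narrower TF-MoDISco frame into the current one.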

def write_modisco_seqlets(seqlets_df: Union[pl.DataFrame, pl.LazyFrame], out_path: str) -> None:
1107def write_modisco_seqlets(
1108    seqlets_df: Union[pl.DataFrame, pl.LazyFrame], out_path: str
1109) -> None:
1110    """Write TF-MoDISco seqlets to TSV file.
1111
1112    Parameters
1113    ----------
1114    seqlets_df : Union[pl.DataFrame, pl.LazyFrame]
1115        Seqlets DataFrame with genomic coordinates. Must contain the internal
1116        columns 'chr_id' and 'is_revcomp', which are dropped before writing.
1117    out_path : str
1118        Output TSV file path.
1119
1120    Notes
1121    -----
1122    Removes internal columns 'chr_id' and 'is_revcomp' before writing
1123    to create a clean output format suitable for downstream analysis.
1124    """
1125    seqlets_df = seqlets_df.drop(["chr_id", "is_revcomp"])
1126    if isinstance(seqlets_df, pl.LazyFrame):
1127        seqlets_df = seqlets_df.collect()
1128    seqlets_df.write_csv(out_path, separator="\t")

Write TF-MoDISco seqlets to TSV file.

Parameters
  • seqlets_df (Union[pl.DataFrame, pl.LazyFrame]): Seqlets DataFrame with genomic coordinates. Must contain the internal columns 'chr_id' and 'is_revcomp', which are dropped before writing.
  • out_path (str): Output TSV file path.
Notes

Removes internal columns 'chr_id' and 'is_revcomp' before writing to create a clean output format suitable for downstream analysis.
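
A brief sketch continuing from the hypothetical seqlets_df above; a LazyFrame produced with lazy=True is also accepted and is collected internally before writing:

    # Output path is illustrative.
    data_io.write_modisco_seqlets(seqlets_df, "modisco_seqlets.tsv")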

HITS_DTYPES = {
    'chr': pl.String, 'start': pl.Int32, 'end': pl.Int32,
    'start_untrimmed': pl.Int32, 'end_untrimmed': pl.Int32,
    'motif_name': pl.String, 'hit_coefficient': pl.Float32,
    'hit_coefficient_global': pl.Float32, 'hit_similarity': pl.Float32,
    'hit_correlation': pl.Float32, 'hit_importance': pl.Float32,
    'hit_importance_sq': pl.Float32, 'strand': pl.String,
    'peak_name': pl.String, 'peak_id': pl.UInt32,
}

HITS_COLLAPSED_DTYPES = {
    'chr': pl.String, 'start': pl.Int32, 'end': pl.Int32,
    'start_untrimmed': pl.Int32, 'end_untrimmed': pl.Int32,
    'motif_name': pl.String, 'hit_coefficient': pl.Float32,
    'hit_coefficient_global': pl.Float32, 'hit_similarity': pl.Float32,
    'hit_correlation': pl.Float32, 'hit_importance': pl.Float32,
    'hit_importance_sq': pl.Float32, 'strand': pl.String,
    'peak_name': pl.String, 'peak_id': pl.UInt32, 'is_primary': pl.UInt32,
}
def load_hits(hits_path: str, lazy: bool = False, schema: Dict[str, Any] = HITS_DTYPES) -> Union[pl.DataFrame, pl.LazyFrame]:
1151def load_hits(
1152    hits_path: str, lazy: bool = False, schema: Dict[str, Any] = HITS_DTYPES
1153) -> Union[pl.DataFrame, pl.LazyFrame]:
1154    """Load motif hit data from TSV file.
1155
1156    Parameters
1157    ----------
1158    hits_path : str
1159        Path to TSV file containing motif hit results.
1160    lazy : bool, default False
1161        If True, returns a LazyFrame for efficient chaining operations.
1162        If False, collects the result into a DataFrame.
1163    schema : Dict[str, Any], default HITS_DTYPES
1164        Schema defining column names and data types for the hit data.
1165
1166    Returns
1167    -------
1168    Union[pl.DataFrame, pl.LazyFrame]
1169        Hit data with an additional 'count' column set to 1 for aggregation.
1170    """
1171    hits_df = pl.scan_csv(
1172        hits_path, separator="\t", quote_char=None, schema=schema
1173    ).with_columns(pl.lit(1).alias("count"))
1174
1175    return hits_df if lazy else hits_df.collect()

Load motif hit data from TSV file.

Parameters
  • hits_path (str): Path to TSV file containing motif hit results.
  • lazy (bool, default False): If True, returns a LazyFrame for efficient chaining operations. If False, collects the result into a DataFrame.
  • schema (Dict[str, Any], default HITS_DTYPES): Schema defining column names and data types for the hit data.
Returns
  • Union[pl.DataFrame, pl.LazyFrame]: Hit data with an additional 'count' column set to 1 for aggregation.
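
As a sketch of the lazy path, with hypothetical file names: the loader appends a 'count' column of ones, so per-motif totals reduce to a simple sum, and with lazy=True the aggregation plan runs before the full table is materialized.

    import polars as pl

    from finemo import data_io

    hits_lazy = data_io.load_hits("hits.tsv", lazy=True)

    # Tally hits per motif without loading the full table into memory.
    per_motif = (
        hits_lazy.group_by("motif_name")
        .agg(pl.col("count").sum().alias("n_hits"))
        .collect()
    )

    # Collapsed hit tables carry an extra 'is_primary' column; pass the
    # matching schema instead of the default.
    collapsed_df = data_io.load_hits(
        "hits_collapsed.tsv", schema=data_io.HITS_COLLAPSED_DTYPES
    )
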
def write_hits_processed(hits_df: Union[pl.DataFrame, pl.LazyFrame], out_path: str, schema: Optional[Dict[str, Any]] = HITS_DTYPES) -> None:
1178def write_hits_processed(
1179    hits_df: Union[pl.DataFrame, pl.LazyFrame],
1180    out_path: str,
1181    schema: Optional[Dict[str, Any]] = HITS_DTYPES,
1182) -> None:
1183    """Write processed hit data to TSV file with optional column filtering.
1184
1185    Parameters
1186    ----------
1187    hits_df : Union[pl.DataFrame, pl.LazyFrame]
1188        Hit data to write to file.
1189    out_path : str
1190        Output path for the TSV file.
1191    schema : Optional[Dict[str, Any]], default HITS_DTYPES
1192        Schema defining which columns to include in output.
1193        If None, all columns are written.
1194    """
1195    if schema is not None:
1196        hits_df = hits_df.select(schema.keys())
1197
1198    if isinstance(hits_df, pl.LazyFrame):
1199        hits_df = hits_df.collect()
1200
1201    hits_df.write_csv(out_path, separator="\t")

Write processed hit data to TSV file with optional column filtering.

Parameters
  • hits_df (Union[pl.DataFrame, pl.LazyFrame]): Hit data to write to file.
  • out_path (str): Output path for the TSV file.
  • schema (Optional[Dict[str, Any]], default HITS_DTYPES): Schema defining which columns to include in output. If None, all columns are written.
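
A short sketch with hypothetical paths, continuing from a table loaded as above; it shows how the default schema filters out caller-added columns while schema=None preserves them:

    import polars as pl

    from finemo import data_io

    hits_df = data_io.load_hits("hits.tsv")
    hits_df = hits_df.with_columns((pl.col("end") - pl.col("start")).alias("width"))

    # The default HITS_DTYPES schema drops the extra 'width' and 'count' columns.
    data_io.write_hits_processed(hits_df, "hits_processed.tsv")

    # schema=None keeps every column.
    data_io.write_hits_processed(hits_df, "hits_all_columns.tsv", schema=None)
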
def write_hits(hits_df: Union[pl.DataFrame, pl.LazyFrame], peaks_df: pl.DataFrame, motifs_df: pl.DataFrame, qc_df: pl.DataFrame, out_dir: str, motif_width: int) -> None:
1204def write_hits(
1205    hits_df: Union[pl.DataFrame, pl.LazyFrame],
1206    peaks_df: pl.DataFrame,
1207    motifs_df: pl.DataFrame,
1208    qc_df: pl.DataFrame,
1209    out_dir: str,
1210    motif_width: int,
1211) -> None:
1212    """Write comprehensive hit results to multiple output files.
1213
1214    This function combines hit data with peak, motif, and quality control information
1215    to generate complete output files including genomic coordinates and scores.
1216
1217    Parameters
1218    ----------
1219    hits_df : Union[pl.DataFrame, pl.LazyFrame]
1220        Hit data containing motif instance information.
1221    peaks_df : pl.DataFrame
1222        Peak region information for coordinate conversion.
1223    motifs_df : pl.DataFrame
1224        Motif metadata for annotation and trimming information.
1225    qc_df : pl.DataFrame
1226        Quality control data for normalization factors.
1227    out_dir : str
1228        Output directory for results files. Will be created if it doesn't exist.
1229    motif_width : int
1230        Width of motif instances for coordinate calculations.
1231
1232    Notes
1233    -----
1234    Creates three output files:
1235    - hits.tsv: Complete hit data with all instances
1236    - hits_unique.tsv: Hits deduplicated by chromosome, start, motif name, and strand
1237    - hits.bed: BED-format file for genome browser visualization
1238
1239    Rows with a null chromosome are filtered out before deduplication so that
1240    hits_unique.tsv contains only well-defined genomic coordinates.
1241    """
1242    os.makedirs(out_dir, exist_ok=True)
1243    out_path_tsv = os.path.join(out_dir, "hits.tsv")
1244    out_path_tsv_unique = os.path.join(out_dir, "hits_unique.tsv")
1245    out_path_bed = os.path.join(out_dir, "hits.bed")
1246
1247    data_all = (
1248        hits_df.lazy()
1249        .join(peaks_df.lazy(), on="peak_id", how="inner")
1250        .join(qc_df.lazy(), on="peak_id", how="inner")
1251        .join(motifs_df.lazy(), on="motif_id", how="inner")
1252        .select(
1253            chr_id=pl.col("chr_id"),
1254            chr=pl.col("chr"),
1255            start=pl.col("peak_region_start")
1256            + pl.col("hit_start")
1257            + pl.col("motif_start"),
1258            end=pl.col("peak_region_start") + pl.col("hit_start") + pl.col("motif_end"),
1259            start_untrimmed=pl.col("peak_region_start") + pl.col("hit_start"),
1260            end_untrimmed=pl.col("peak_region_start")
1261            + pl.col("hit_start")
1262            + motif_width,
1263            motif_name=pl.col("motif_name"),
1264            hit_coefficient=pl.col("hit_coefficient"),
1265            hit_coefficient_global=pl.col("hit_coefficient")
1266            * (pl.col("global_scale") ** 2),
1267            hit_similarity=pl.col("hit_similarity"),
1268            hit_correlation=pl.col("hit_similarity"),  # same values as hit_similarity
1269            hit_importance=pl.col("hit_importance") * pl.col("global_scale"),
1270            hit_importance_sq=pl.col("hit_importance_sq")
1271            * (pl.col("global_scale") ** 2),
1272            strand=pl.col("strand"),
1273            peak_name=pl.col("peak_name"),
1274            peak_id=pl.col("peak_id"),
1275            motif_lambda=pl.col("lambda"),
1276        )
1277        .sort(["chr_id", "start"])
1278        .select(HITS_DTYPES.keys())
1279    )
1280
1281    data_unique = data_all.filter(pl.col("chr").is_not_null()).unique(
1282        subset=["chr", "start", "motif_name", "strand"], maintain_order=True
1283    )
1284
1285    data_bed = data_unique.select(
1286        chr=pl.col("chr"),
1287        start=pl.col("start"),
1288        end=pl.col("end"),
1289        motif_name=pl.col("motif_name"),
1290        score=pl.lit(0),
1291        strand=pl.col("strand"),
1292    )
1293
1294    data_all.collect().write_csv(out_path_tsv, separator="\t")
1295    data_unique.collect().write_csv(out_path_tsv_unique, separator="\t")
1296    data_bed.collect().write_csv(out_path_bed, include_header=False, separator="\t")

Write comprehensive hit results to multiple output files.

This function combines hit data with peak, motif, and quality control information to generate complete output files including genomic coordinates and scores.

Parameters
  • hits_df (Union[pl.DataFrame, pl.LazyFrame]): Hit data containing motif instance information.
  • peaks_df (pl.DataFrame): Peak region information for coordinate conversion.
  • motifs_df (pl.DataFrame): Motif metadata for annotation and trimming information.
  • qc_df (pl.DataFrame): Quality control data for normalization factors.
  • out_dir (str): Output directory for results files. Will be created if it doesn't exist.
  • motif_width (int): Width of motif instances for coordinate calculations.
Notes

Creates three output files:

  • hits.tsv: Complete hit data with all instances
  • hits_unique.tsv: Hits deduplicated by chromosome, start, motif name, and strand
  • hits.bed: BED-format file for genome browser visualization

Rows with a null chromosome are filtered out before deduplication so that hits_unique.tsv contains only well-defined genomic coordinates.
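
To make the coordinate arithmetic concrete, a small worked example with assumed values, mirroring the column expressions in the select above:

    peak_region_start = 10_000       # genomic start of the peak region
    hit_start = 120                  # hit offset within the region
    motif_start, motif_end = 3, 12   # trimmed bounds within the motif CWM
    motif_width = 20                 # untrimmed motif width

    start_untrimmed = peak_region_start + hit_start                # 10120
    end_untrimmed = peak_region_start + hit_start + motif_width    # 10140
    start = peak_region_start + hit_start + motif_start            # 10123
    end = peak_region_start + hit_start + motif_end                # 10132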

def write_qc(qc_df: pl.DataFrame, peaks_df: pl.DataFrame, out_path: str) -> None:
1299def write_qc(qc_df: pl.DataFrame, peaks_df: pl.DataFrame, out_path: str) -> None:
1300    """Write quality control data with peak information to TSV file.
1301
1302    Parameters
1303    ----------
1304    qc_df : pl.DataFrame
1305        Quality control metrics for each peak region.
1306    peaks_df : pl.DataFrame
1307        Peak region information for coordinate annotation.
1308    out_path : str
1309        Output path for the TSV file.
1310    """
1311    df = (
1312        qc_df.lazy()
1313        .join(peaks_df.lazy(), on="peak_id", how="inner")
1314        .sort(["chr_id", "peak_region_start"])
1315        .drop("chr_id")
1316        .collect()
1317    )
1318    df.write_csv(out_path, separator="\t")

Write quality control data with peak information to TSV file.

Parameters
  • qc_df (pl.DataFrame): Quality control metrics for each peak region.
  • peaks_df (pl.DataFrame): Peak region information for coordinate annotation.
  • out_path (str): Output path for the TSV file.
def write_motifs_df(motifs_df: pl.DataFrame, out_path: str) -> None:
1321def write_motifs_df(motifs_df: pl.DataFrame, out_path: str) -> None:
1322    """Write motif metadata to TSV file.
1323
1324    Parameters
1325    ----------
1326    motifs_df : pl.DataFrame
1327        Motif metadata DataFrame.
1328    out_path : str
1329        Output path for the TSV file.
1330    """
1331    motifs_df.write_csv(out_path, separator="\t")

Write motif metadata to TSV file.

Parameters
  • motifs_df (pl.DataFrame): Motif metadata DataFrame.
  • out_path (str): Output path for the TSV file.
MOTIF_DTYPES = {
    'motif_id': pl.UInt32, 'motif_name': pl.String, 'motif_name_orig': pl.String,
    'strand': pl.String, 'motif_start': pl.UInt32, 'motif_end': pl.UInt32,
    'motif_scale': pl.Float32, 'lambda': pl.Float32,
}
def load_motifs_df(motifs_path: str) -> Tuple[pl.DataFrame, ndarray]:
1346def load_motifs_df(motifs_path: str) -> Tuple[pl.DataFrame, ndarray]:
1347    """Load motif metadata from TSV file.
1348
1349    Parameters
1350    ----------
1351    motifs_path : str
1352        Path to motif metadata TSV file.
1353
1354    Returns
1355    -------
1356    motifs_df : pl.DataFrame
1357        Motif metadata with predefined schema.
1358    motif_names : ndarray
1359        Array of unique forward-strand motif names.
1360    """
1361    motifs_df = pl.read_csv(motifs_path, separator="\t", schema=MOTIF_DTYPES)
1362    motif_names = (
1363        motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy()
1364    )
1365
1366    return motifs_df, motif_names

Load motif metadata from TSV file.

Parameters
  • motifs_path (str): Path to motif metadata TSV file.
Returns
  • motifs_df (pl.DataFrame): Motif metadata with predefined schema.
  • motif_names (ndarray): Array of unique forward-strand motif names.
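
A usage sketch with a hypothetical path; the input TSV must match the MOTIF_DTYPES schema shown above, and a round trip through write_motifs_df preserves that layout:

    from finemo import data_io

    motifs_df, motif_names = data_io.load_motifs_df("motifs.tsv")
    print(motif_names)  # forward-strand motif names, one per motif

    data_io.write_motifs_df(motifs_df, "motifs_copy.tsv")
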
def write_motif_cwms(cwms: Float[ndarray, "M 4 W"], out_path: str) -> None:
1369def write_motif_cwms(cwms: Float[ndarray, "M 4 W"], out_path: str) -> None:
1370    """Write motif contribution weight matrices to .npy file.
1371
1372    Parameters
1373    ----------
1374    cwms : Float[ndarray, "M 4 W"]
1375        Contribution weight matrices for M motifs, 4 nucleotides, W width.
1376    out_path : str
1377        Output path for the .npy file.
1378    """
1379    np.save(out_path, cwms)

Write motif contribution weight matrices to .npy file.

Parameters
  • cwms (Float[ndarray, "M 4 W"]): Contribution weight matrices for M motifs, 4 nucleotides, W width.
  • out_path (str): Output path for the .npy file.
def load_motif_cwms(cwms_path: str) -> Float[ndarray, "M 4 W"]:
1382def load_motif_cwms(cwms_path: str) -> Float[ndarray, "M 4 W"]:
1383    """Load motif contribution weight matrices from .npy file.
1384
1385    Parameters
1386    ----------
1387    cwms_path : str
1388        Path to .npy file containing CWMs.
1389
1390    Returns
1391    -------
1392    Float[ndarray, "M 4 W"]
1393        Loaded contribution weight matrices.
1394    """
1395    return np.load(cwms_path)

Load motif contribution weight matrices from .npy file.

Parameters
  • cwms_path (str): Path to .npy file containing CWMs.
Returns
  • Float[ndarray, "M 4 W"]: Loaded contribution weight matrices.
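
A round-trip sketch for the two CWM helpers above; the array shape (5 motifs, 4 nucleotides, width 30) is illustrative:

    import numpy as np

    from finemo import data_io

    cwms = np.zeros((5, 4, 30), dtype=np.float32)
    data_io.write_motif_cwms(cwms, "motif_cwms.npy")

    cwms_loaded = data_io.load_motif_cwms("motif_cwms.npy")
    assert cwms_loaded.shape == (5, 4, 30)
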
def write_params(params: Dict[str, Any], out_path: str) -> None:
1398def write_params(params: Dict[str, Any], out_path: str) -> None:
1399    """Write parameter dictionary to JSON file.
1400
1401    Parameters
1402    ----------
1403    params : Dict[str, Any]
1404        Parameter dictionary to serialize.
1405    out_path : str
1406        Output path for the JSON file.
1407    """
1408    with open(out_path, "w") as f:
1409        json.dump(params, f, indent=4)

Write parameter dictionary to JSON file.

Parameters
  • params (Dict[str, Any]): Parameter dictionary to serialize.
  • out_path (str): Output path for the JSON file.
def load_params(params_path: str) -> Dict[str, Any]:
1412def load_params(params_path: str) -> Dict[str, Any]:
1413    """Load parameter dictionary from JSON file.
1414
1415    Parameters
1416    ----------
1417    params_path : str
1418        Path to JSON file containing parameters.
1419
1420    Returns
1421    -------
1422    Dict[str, Any]
1423        Loaded parameter dictionary.
1424    """
1425    with open(params_path) as f:
1426        params = json.load(f)
1427
1428    return params

Load parameter dictionary from JSON file.

Parameters
  • params_path (str): Path to JSON file containing parameters.
Returns
  • Dict[str, Any]: Loaded parameter dictionary.
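
A round-trip sketch for the two JSON helpers above; the parameter names and values are illustrative:

    from finemo import data_io

    params = {"alpha": 0.7, "half_width": 500, "device": "cuda"}
    data_io.write_params(params, "params.json")

    # JSON round trip is lossless for these simple types.
    assert data_io.load_params("params.json") == params
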
def write_occ_df(occ_df: pl.DataFrame, out_path: str) -> None:
1431def write_occ_df(occ_df: pl.DataFrame, out_path: str) -> None:
1432    """Write occurrence data to TSV file.
1433
1434    Parameters
1435    ----------
1436    occ_df : pl.DataFrame
1437        Occurrence data DataFrame.
1438    out_path : str
1439        Output path for the TSV file.
1440    """
1441    occ_df.write_csv(out_path, separator="\t")

Write occurrence data to TSV file.

Parameters
  • occ_df (pl.DataFrame): Occurrence data DataFrame.
  • out_path (str): Output path for the TSV file.
def write_seqlet_confusion_df(seqlet_confusion_df: pl.DataFrame, out_path: str) -> None:
1444def write_seqlet_confusion_df(seqlet_confusion_df: pl.DataFrame, out_path: str) -> None:
1445    """Write seqlet confusion matrix data to TSV file.
1446
1447    Parameters
1448    ----------
1449    seqlet_confusion_df : pl.DataFrame
1450        Seqlet confusion matrix DataFrame.
1451    out_path : str
1452        Output path for the TSV file.
1453    """
1454    seqlet_confusion_df.write_csv(out_path, separator="\t")

Write seqlet confusion matrix data to TSV file.

Parameters
  • seqlet_confusion_df (pl.DataFrame): Seqlet confusion matrix DataFrame.
  • out_path (str): Output path for the TSV file.
def write_report_data(report_df: pl.DataFrame, cwms: Dict[str, Dict[str, ndarray]], out_dir: str) -> None:
1457def write_report_data(
1458    report_df: pl.DataFrame, cwms: Dict[str, Dict[str, ndarray]], out_dir: str
1459) -> None:
1460    """Write comprehensive motif report data including CWMs and metadata.
1461
1462    Parameters
1463    ----------
1464    report_df : pl.DataFrame
1465        Report metadata DataFrame.
1466    cwms : Dict[str, Dict[str, ndarray]]
1467        Nested dictionary of motif names to CWM types to arrays.
1468    out_dir : str
1469        Output directory for report files.
1470    """
1471    cwms_dir = os.path.join(out_dir, "CWMs")
1472    os.makedirs(cwms_dir, exist_ok=True)
1473
1474    for m, v in cwms.items():
1475        motif_dir = os.path.join(cwms_dir, m)
1476        os.makedirs(motif_dir, exist_ok=True)
1477        for cwm_type, cwm in v.items():
1478            np.savetxt(os.path.join(motif_dir, f"{cwm_type}.txt"), cwm)
1479
1480    report_df.write_csv(os.path.join(out_dir, "motif_report.tsv"), separator="\t")

Write comprehensive motif report data including CWMs and metadata.

Parameters
  • report_df (pl.DataFrame): Report metadata DataFrame.
  • cwms (Dict[str, Dict[str, ndarray]]): Nested dictionary of motif names to CWM types to arrays.
  • out_dir (str): Output directory for report files.
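
Finally, a usage sketch with toy data; the motif name and the CWM variant keys ('cwm_fwd', 'cwm_rev') are illustrative, since any provided keys become <key>.txt files under CWMs/<motif>/:

    import numpy as np
    import polars as pl

    from finemo import data_io

    report_df = pl.DataFrame({"motif_name": ["pos_patterns.pattern_0"]})
    cwms = {
        "pos_patterns.pattern_0": {
            "cwm_fwd": np.random.rand(4, 30),
            "cwm_rev": np.random.rand(4, 30),
        }
    }
    data_io.write_report_data(report_df, cwms, "report_out")
    # Writes report_out/motif_report.tsv plus
    # report_out/CWMs/pos_patterns.pattern_0/cwm_fwd.txt and cwm_rev.txt.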