jimg_ncd.utils

   1import copy
   2import io
   3import math
   4import os
   5import pickle
   6import re
   7import sys
   8import tarfile
   9import tkinter as tk
  10from itertools import combinations
  11
  12import cv2
  13import gdown
  14import matplotlib.pyplot as plt
  15import numpy as np
  16import pandas as pd
  17import plotly.express as px
  18import tifffile as tiff
  19from joblib import Parallel, delayed
  20from scipy import stats
  21from scipy.spatial import ConvexHull, cKDTree
  22from scipy.stats import chi2_contingency
  23from tqdm import tqdm
  24
# NOTE(review): this rebinds sys.stdout for the whole interpreter as soon as
# the module is imported. Every later print() goes into an in-memory buffer
# instead of the console, including the status and error messages printed by
# the functions in this module. The intent is not documented here; presumably
# it silences noisy output from a dependency. Confirm before relying on any
# console message from this module.
sys.stdout = io.StringIO()
  26
  27
  28def umap_html(umap_result, width=1000, height=1200):
  29    """
  30    Create an interactive HTML UMAP scatter plot.
  31
  32    Parameters
  33    ----------
  34    umap_result : pandas.DataFrame or dict-like
  35        UMAP embedding containing at least:
  36        - `0` : array-like, UMAP dimension 1
  37        - `1` : array-like, UMAP dimension 2
  38        - `'clusters'` : array-like, assigned cluster labels
  39
  40    width : int, optional
  41        Width of the output figure in pixels. Default is 1000.
  42
  43    height : int, optional
  44        Height of the output figure in pixels. Default is 1200.
  45
  46    Returns
  47    -------
  48    plotly.graph_objs._figure.Figure
  49        Interactive Plotly scatter plot object visualizing UMAP with colored clusters.
  50    """
  51
  52    fig = px.scatter(
  53        x=umap_result[0],
  54        y=umap_result[1],
  55        color=umap_result["clusters"],
  56        labels={"color": "Cells"},
  57        template="simple_white",
  58        width=width,
  59        height=height,
  60        render_mode="svg",
  61        color_discrete_sequence=px.colors.qualitative.Dark24
  62        + px.colors.qualitative.Light24,
  63    )
  64
  65    fig.update_xaxes(title_text="UMAP 1")
  66    fig.update_yaxes(title_text="UMAP 2")
  67
  68    return fig
  69
  70
  71def umap_static(umap_result, width=10, height=13, n_per_col=20):
  72    """
  73    Create a static matplotlib UMAP scatter plot.
  74
  75    Parameters
  76    ----------
  77    umap_result : pandas.DataFrame or dict-like
  78        UMAP projection containing:
  79        - `0` : array-like, UMAP dimension 1
  80        - `1` : array-like, UMAP dimension 2
  81        - `'clusters'` : array-like, cluster assignments
  82
  83    width : float, optional
  84        Width of the figure in inches. Default is 10.
  85
  86    height : float, optional
  87        Height of the figure in inches. Default is 13.
  88
  89    n_per_col : int, optional
  90        Maximum number of legend entries per column. Default is 20.
  91
  92    Returns
  93    -------
  94    matplotlib.figure.Figure
  95        Static matplotlib figure representing the UMAP embedding with clusters.
  96    """
  97
  98    plotly_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24
  99    num_colors = len(plotly_colors)
 100
 101    fig = plt.figure(figsize=(width, height))
 102
 103    sorted_labels = pd.unique(umap_result["clusters"])
 104
 105    color_map = {
 106        label: plotly_colors[i % num_colors] for i, label in enumerate(sorted_labels)
 107    }
 108
 109    for label in sorted_labels:
 110        subset = umap_result[umap_result["clusters"] == label]
 111        plt.scatter(
 112            subset[0],
 113            subset[1],
 114            c=[color_map[label]],
 115            label=f"Cluster {label}",
 116            alpha=0.7,
 117            s=20,
 118            edgecolor="black",
 119            linewidths=0.1,
 120        )
 121
 122    n_col = -(-len(set(umap_result["clusters"])) // n_per_col)
 123
 124    plt.xlabel("UMAP 1", fontsize=14)
 125    plt.ylabel("UMAP 2", fontsize=14)
 126    plt.grid(True, which="both", linestyle="--", linewidth=0.1)
 127    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", ncol=n_col)
 128    plt.tight_layout()
 129
 130    return fig
 131
 132
 133def test_data(path=""):
 134    """
 135    Download and extract test data from Google Drive.
 136
 137    This function downloads a compressed archive containing example test data
 138    and extracts it into the specified directory. The data is fetched using
 139    a direct Google Drive link. If the download or extraction fails, an
 140    error message is printed.
 141
 142    Parameters
 143    ----------
 144    path : str, optional
 145        Destination directory where the test dataset will be downloaded and
 146        extracted. Defaults to the current working directory.
 147
 148
 149    Notes
 150    -----
 151    - The downloaded file is named ``test_data.tar.gz``.
 152    - The archive is extracted into ``<path>/test_data``.
 153    - In case of any failure (download or extraction), the function prints
 154      an informative message instead of raising an exception.
 155    """
 156
 157    try:
 158
 159        file_name = "test_data.tar.gz"
 160
 161        file_name = os.path.join(path, file_name)
 162
 163        url = "https://drive.google.com/uc?id=1MhzhleMP7iTzlBVW8eP5sFaonJdg1a3T"
 164
 165        gdown.download(url, file_name, quiet=False)
 166
 167        # Unzip
 168
 169        with tarfile.open(file_name, "r:gz") as tar:
 170            tar.extractall(path=path)
 171
 172        print(
 173            f"\nTest data downloaded succesfully -> {os.path.join(path, 'test_data')}"
 174        )
 175
 176    except:
 177
 178        print(
 179            "\nTest data could not be downloaded. Please check your connection and try again!"
 180        )
 181
 182
 183def prop_plot(df_pivot, chi_df):
 184    """
 185    Create a stacked bar plot of proportional data with post-hoc significance annotations.
 186
 187    Parameters
 188    ----------
 189    df_pivot : pandas.DataFrame
 190        Pivot table where rows represent categories (e.g., compartments) and columns
 191        represent groups. Values are counts or frequencies.
 192
 193    chi_df : pandas.DataFrame
 194        DataFrame containing pairwise Chi-square test results with an added
 195        'Significance_Label' column (e.g., '***', '**', '*', 'ns') for each pair
 196        of groups. Typically output from `chi_pairs` and `get_significance_label`.
 197
 198    Returns
 199    -------
 200    matplotlib.figure.Figure
 201        The Matplotlib figure object containing the stacked bar plot.
 202
 203    Notes
 204    -----
 205    - The function converts raw counts to percentages per group for visualization.
 206    - Each pairwise comparison and its significance label is displayed as a text box
 207      next to the plot.
 208    - Colors are assigned using the 'viridis' colormap by default.
 209    - The plot is configured for clarity with labeled axes, legend, and appropriately
 210      sized text.
 211    """
 212
 213    df_pivot_perc = df_pivot.div(df_pivot.sum(axis=0), axis=1) * 100
 214
 215    chi_df = chi_df.sort_values(by="p-value", ascending=True)
 216
 217    df_pivot_perc = df_pivot_perc.T.sort_values(by=list(df_pivot_perc.index)).T
 218
 219    posthoc_text = "\n".join(
 220        [
 221            f"{row['Group 1']}{row['Group 2']}: {row['Significance_Label']}"
 222            for _, row in chi_df.iterrows()
 223        ]
 224    )
 225
 226    fig, ax = plt.subplots(figsize=(12, 7))
 227
 228    df_pivot_perc.T.plot(kind="bar", stacked=True, ax=ax, cmap="viridis")
 229
 230    ax.set_ylabel("Percentage (%)", fontsize=16)
 231    ax.set_xlabel("Groups", fontsize=16)
 232
 233    ax.tick_params(axis="both", labelsize=12)
 234
 235    ax.legend(
 236        title="Compartment", loc="upper left", bbox_to_anchor=(1.02, 1.05), fontsize=14
 237    )
 238
 239    plt.figtext(
 240        0.93,
 241        0.6,
 242        posthoc_text,
 243        ha="left",
 244        va="top",
 245        fontsize=12,
 246        bbox={"facecolor": "white", "alpha": 0.7, "pad": 5},
 247    )
 248
 249    return fig
 250
 251
 252def get_significance_label(p_value):
 253    """
 254    Return a standard significance label based on a p-value.
 255
 256    Parameters
 257    ----------
 258    p_value : float
 259        The p-value for which the significance label should be determined.
 260
 261    Returns
 262    -------
 263    str
 264        A significance marker commonly used in statistical reporting:
 265
 266        - '***' : p < 0.001
 267        - '**'  : p < 0.01
 268        - '*'   : p < 0.05
 269        - 'ns'  : not significant (p ≥ 0.05)
 270
 271    Notes
 272    -----
 273    This helper function is typically used for annotating statistical test
 274    results in tables or visualizations. Thresholds follow conventional
 275    statistical notation for significance levels.
 276    """
 277
 278    if p_value < 0.001:
 279        return "***"
 280    elif p_value < 0.01:
 281        return "**"
 282    elif p_value < 0.05:
 283        return "*"
 284    else:
 285        return "ns"
 286
 287
 288def chi_pairs(df_pivot):
 289    """
 290    Compute pairwise Chi-square tests for all combinations of groups in a pivoted dataframe.
 291
 292    Parameters
 293    ----------
 294    df_pivot : pandas.DataFrame
 295        A pivot table where rows represent categories and columns represent groups.
 296        Values should be counts (frequencies). The function will add +1 to each cell
 297        to avoid zero counts during chi-square computation.
 298
 299    Returns
 300    -------
 301    pandas.DataFrame
 302        A DataFrame containing pairwise Chi-square test results with the following columns:
 303        - 'Group 1' : str
 304            Name of the first group in the pair.
 305        - 'Group 2' : str
 306            Name of the second group in the pair.
 307        - 'Chi²' : float
 308            The Chi-square statistic for the comparison.
 309        - 'p-value' : float
 310            The p-value of the Chi-square test.
 311
 312    Notes
 313    -----
 314    The function compares every possible pair of columns using `scipy.stats.chi2_contingency`.
 315    Yates' correction is applied by default unless disabled in the SciPy version used.
 316    A value of 1 is added to all cells to avoid issues with zero frequencies.
 317    """
 318
 319    group_pairs = list(combinations(df_pivot.columns, 2))
 320
 321    posthoc_results = []
 322
 323    for group1, group2 in group_pairs:
 324        sub_table = df_pivot.T.loc[[group1, group2]] + 1
 325        chi2, p, dof, expected = chi2_contingency(sub_table)
 326
 327        posthoc_results.append(
 328            {"Group 1": group1, "Group 2": group2, "Chi²": chi2, "p-value": p}
 329        )
 330
 331    posthoc_df = pd.DataFrame(posthoc_results)
 332
 333    return posthoc_df
 334
 335
 336def statistic(input_df, sets=None, metadata=None, n_proc=10):
 337    """
 338    Compute statistical comparison between cell groups or clusters.
 339
 340    This function performs differential feature analysis between either:
 341    (1) every group vs. all other groups (default mode), or
 342    (2) two user-defined groups specified in ``sets``.
 343
 344    The analysis includes:
 345    - Mann–Whitney U test
 346    - Percentage of non-zero values
 347    - Means and standard deviations
 348    - Effect size metric (ESM)
 349    - Benjamini–Hochberg FDR correction
 350    - Fold-change and log2 fold-change
 351
 352
 353    Parameters
 354    ----------
 355    input_df : pandas.DataFrame
 356        Input feature matrix where rows represent features and columns represent cells.
 357        The function transposes this table internally, treating columns as features.
 358
 359    sets : dict or None, optional
 360        Mode selection:
 361        - ``None`` (default): each unique label in ``metadata['sets']`` is compared
 362          against all remaining groups.
 363        - ``dict``: must contain exactly two keys, each mapping to a list of labels
 364          belonging to each comparison group. Example:
 365          ``{'A': ['T1', 'T2'], 'B': ['C1', 'C2']}``.
 366
 367    metadata : pandas.DataFrame, optional
 368        Metadata containing at least a column ``'sets'`` with group labels
 369        corresponding to columns of ``input_df``.
 370
 371    n_proc : int, optional
 372        Number of parallel processes used for statistical computation.
 373        Default is ``10``.
 374
 375    Returns
 376    -------
 377    pandas.DataFrame or None
 378        A DataFrame containing statistical results for each feature, including:
 379
 380        - ``feature`` : str
 381        - ``p_val`` : float
 382        - ``adj_pval`` : float
 383        - ``pct_valid`` : float
 384        - ``pct_ctrl`` : float
 385        - ``avg_valid`` : float
 386        - ``avg_ctrl`` : float
 387        - ``sd_valid`` : float
 388        - ``sd_ctrl`` : float
 389        - ``esm`` : float
 390        - ``FC`` : float
 391        - ``log(FC)`` : float
 392        - ``norm_diff`` : float
 393        - ``valid_group`` : str
 394        - ``-log(p_val)`` : float
 395
 396        If ``sets`` is ``None``, results for each group are concatenated.
 397
 398        Returns ``None`` in case of errors or invalid parameters.
 399
 400    Notes
 401    -----
 402    - Columns containing only zeros are automatically removed.
 403    - p-values equal for both groups produce ``p_val = 1``.
 404    - Benjamini–Hochberg correction is applied separately within each group comparison.
 405    - Fold-change is stabilized using a small, data-derived ``low_factor``.
 406    - Uses ``Mann–Whitney U`` test with ``alternative='two-sided'``.
 407
 408    Raises
 409    ------
 410    None
 411        All exceptions are caught internally and printed as messages.
 412
 413    Examples
 414    --------
 415    >>> df = pd.DataFrame(...)
 416    >>> meta = pd.DataFrame({'sets': [...]})
 417    >>> stat = statistic(df, metadata=meta)
 418    >>> stat.head()
 419
 420    >>> # Compare two groups explicitly
 421    >>> sets = {'A': ['Type1'], 'B': ['Type2']}
 422    >>> stat = statistic(df, sets=sets, metadata=meta, n_proc=4)
 423    """
 424    try:
 425        offset = 1e-100
 426
 427        def stat_calc(choose, feature_name):
 428            target_values = choose.loc[choose["DEG"] == "target", feature_name]
 429            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]
 430
 431            pct_valid = (target_values > 0).sum() / len(target_values)
 432            pct_rest = (rest_values > 0).sum() / len(rest_values)
 433
 434            avg_valid = np.mean(target_values)
 435            avg_ctrl = np.mean(rest_values)
 436            sd_valid = np.std(target_values, ddof=1)
 437            sd_ctrl = np.std(rest_values, ddof=1)
 438            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))
 439
 440            if np.sum(target_values) == np.sum(rest_values):
 441                p_val = 1.0
 442            else:
 443                _, p_val = stats.mannwhitneyu(
 444                    target_values, rest_values, alternative="two-sided"
 445                )
 446
 447            return {
 448                "feature": feature_name,
 449                "p_val": p_val,
 450                "pct_valid": pct_valid,
 451                "pct_ctrl": pct_rest,
 452                "avg_valid": avg_valid,
 453                "avg_ctrl": avg_ctrl,
 454                "sd_valid": sd_valid,
 455                "sd_ctrl": sd_ctrl,
 456                "esm": esm,
 457            }
 458
 459        def safe_min_half(series):
 460            filtered = series[(series > ((2**-1074) * 2)) & (series.notna())]
 461            return filtered.min() / 2 if not filtered.empty else 0
 462
 463        # Transpose the input DataFrame
 464        choose = input_df.copy().T
 465
 466        if sets is None:
 467            print("\nAnalysis started...")
 468            print("\nComparing each type of cell to others...")
 469            final_results = []
 470
 471            if len(set(metadata["sets"])) > 1:
 472                choose.index = metadata["sets"]
 473
 474            indexes = list(choose.index)
 475
 476            for c in set(indexes):
 477                print(f"Calculating statistics for {c}")
 478
 479                choose.index = indexes
 480                choose["DEG"] = np.where(choose.index == c, "target", "rest")
 481
 482                valid = ",".join(set(choose.index[choose["DEG"] == "target"]))
 483                choose = choose.loc[
 484                    :, (choose != 0).any(axis=0)
 485                ]  # Remove all-zero columns
 486
 487                # Parallel computation
 488                results = Parallel(n_jobs=n_proc)(
 489                    delayed(stat_calc)(choose, feature)
 490                    for feature in tqdm(choose.columns[choose.columns != "DEG"])
 491                )
 492
 493                # Convert results to DataFrame
 494                combined_df = pd.DataFrame(results)
 495                combined_df = combined_df[
 496                    (combined_df["avg_valid"] > 0) | (combined_df["avg_ctrl"] > 0)
 497                ]
 498
 499                combined_df["valid_group"] = valid
 500                combined_df.sort_values(by="p_val", inplace=True)
 501
 502                # Adjusted p-values using Benjamini-Hochberg method
 503                num_tests = len(combined_df)
 504                combined_df["adj_pval"] = np.minimum(
 505                    1, (combined_df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
 506                )
 507
 508                combined_df["p_val"] = combined_df["p_val"].replace(0, 2**-1074)
 509
 510                combined_df["-log(p_val)"] = -np.log10(combined_df["p_val"])
 511
 512                valid_factor = safe_min_half(combined_df["avg_valid"])
 513
 514                ctrl_factor = safe_min_half(combined_df["avg_ctrl"])
 515
 516                cv_factor = min(valid_factor, ctrl_factor)
 517
 518                if cv_factor == 0:
 519                    cv_factor = max(valid_factor, ctrl_factor)
 520
 521                if not np.isfinite(cv_factor) or cv_factor == 0:
 522                    cv_factor += offset
 523
 524                valid = combined_df["avg_valid"].where(
 525                    combined_df["avg_valid"] != 0, combined_df["avg_valid"] + cv_factor
 526                )
 527
 528                ctrl = combined_df["avg_ctrl"].where(
 529                    combined_df["avg_ctrl"] != 0, combined_df["avg_ctrl"] + cv_factor
 530                )
 531
 532                combined_df["FC"] = valid / ctrl
 533
 534                combined_df["log(FC)"] = np.log2(combined_df["FC"])
 535                combined_df["norm_diff"] = (
 536                    combined_df["avg_valid"] - combined_df["avg_ctrl"]
 537                )
 538
 539                final_results.append(combined_df)
 540
 541            print("\nAnalysis has finished!")
 542            return pd.concat(final_results, ignore_index=True)
 543
 544        elif isinstance(sets, dict):
 545            print("\nAnalysis started...")
 546            print("\nComparing groups...")
 547
 548            group_list = list(sets.keys())
 549            choose.index = metadata["sets"]
 550
 551            inx = sorted([item for sublist in sets.values() for item in sublist])
 552            choose = choose.loc[inx]
 553
 554            full_df = pd.DataFrame()
 555            for n, g in enumerate(group_list):
 556                print(f"Calculating statistics for {g}")
 557
 558                rest_indices = [
 559                    idx
 560                    for i, group in enumerate(group_list)
 561                    if i != n
 562                    for idx in sets[group]
 563                ]
 564
 565                choose["DEG"] = np.where(
 566                    choose.index.isin(sets[g]),
 567                    "target",
 568                    np.where(choose.index.isin(rest_indices), "rest", "drop"),
 569                )
 570
 571                choose = choose[choose["DEG"] != "drop"]
 572
 573                valid = g
 574                choose = choose.loc[
 575                    :, (choose != 0).any(axis=0)
 576                ]  # Remove all-zero columns
 577
 578                # Parallel computation
 579                results = Parallel(n_jobs=n_proc)(
 580                    delayed(stat_calc)(choose, feature)
 581                    for feature in tqdm(choose.columns[choose.columns != "DEG"])
 582                )
 583
 584                # Convert results to DataFrame
 585                combined_df = pd.DataFrame(results)
 586                combined_df["valid_group"] = valid
 587                combined_df.sort_values(by="p_val", inplace=True)
 588
 589                # Adjusted p-values using Benjamini-Hochberg method
 590                num_tests = len(combined_df)
 591                combined_df["adj_pval"] = np.minimum(
 592                    1, (combined_df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
 593                )
 594
 595                combined_df["-log(p_val)"] = -np.log10(offset + combined_df["p_val"])
 596
 597                valid_factor = safe_min_half(combined_df["avg_valid"])
 598
 599                ctrl_factor = safe_min_half(combined_df["avg_ctrl"])
 600
 601                cv_factor = min(valid_factor, ctrl_factor)
 602
 603                if cv_factor == 0:
 604                    cv_factor = max(valid_factor, ctrl_factor)
 605
 606                if not np.isfinite(cv_factor) or cv_factor == 0:
 607                    cv_factor += offset
 608
 609                valid = combined_df["avg_valid"].where(
 610                    combined_df["avg_valid"] != 0, combined_df["avg_valid"] + cv_factor
 611                )
 612
 613                ctrl = combined_df["avg_ctrl"].where(
 614                    combined_df["avg_ctrl"] != 0, combined_df["avg_ctrl"] + cv_factor
 615                )
 616
 617                combined_df["FC"] = valid / ctrl
 618
 619                combined_df["log(FC)"] = np.log2(combined_df["FC"])
 620                combined_df["norm_diff"] = (
 621                    combined_df["avg_valid"] - combined_df["avg_ctrl"]
 622                ) + offset
 623
 624                full_df = pd.concat([full_df, combined_df])
 625
 626            print("\nAnalysis has finished!")
 627            return full_df
 628
 629        else:
 630            print("\nInvalid parameters. Please check the input.")
 631            return None
 632
 633    except Exception as e:
 634        print(f"Error: {e}")
 635        return None
 636
 637
 638# UTILS JIMG
 639
 640
 641def save_image(image, path_to_save):
 642    """
 643    Save an image to disk.
 644
 645    Parameters
 646    ----------
 647    image : np.ndarray
 648        Input image array to be saved.
 649
 650    path_to_save : str
 651        Output file path including the filename.
 652        Must include one of the following extensions:
 653        ``.png``, ``.tiff`` or ``.tif``.
 654
 655    Returns
 656    -------
 657    str
 658        Path to the saved file.
 659    """
 660
 661    try:
 662        if (
 663            len(path_to_save) == 0
 664            or ".png" not in path_to_save
 665            or ".tiff" not in path_to_save
 666            or ".tif" not in path_to_save
 667        ):
 668            print(
 669                "\nThe path is not provided or the file extension is not *.png, *.tiff or *.tif"
 670            )
 671        else:
 672            cv2.imwrite(path_to_save, image)
 673
 674    except:
 675        print("Something went wrong. Check the function input data and try again!")
 676
 677
 678def load_tiff(path_to_tiff: str):
 679    """
 680    Load a *.tiff image and ensure it is in 16-bit format.
 681
 682    Parameters
 683    ----------
 684    path_to_tiff : str
 685        Path to the *.tiff file.
 686
 687    Returns
 688    -------
 689    np.ndarray
 690        Loaded image stack, converted to 16-bit if necessary.
 691    """
 692
 693    try:
 694        stack = tiff.imread(path_to_tiff)
 695
 696        if stack.dtype != "uint16":
 697
 698            stack = stack.astype(np.uint16)
 699
 700            for n, _ in enumerate(stack):
 701
 702                min_val = np.min(stack[n])
 703                max_val = np.max(stack[n])
 704
 705                stack[n] = ((stack[n] - min_val) / (max_val - min_val) * 65535).astype(
 706                    np.uint16
 707                )
 708
 709                stack[n] = np.clip(stack[n], 0, 65535)
 710
 711        return stack
 712
 713    except:
 714        print("Something went wrong. Check the function input data and try again!")
 715
 716
 717def z_projection(tiff_object, projection_type="avg"):
 718    """
 719    Perform Z-projection on a stacked (3D) image.
 720
 721    Parameters
 722    ----------
 723    tiff_object : np.ndarray
 724        Input stacked 3D image (e.g., loaded with `load_tiff()`).
 725
 726    projection_type : str
 727        Type of Z-axis projection. Options: 'avg', 'median', 'min', 'max', 'std'.
 728
 729    Returns
 730    -------
 731    np.ndarray
 732        Resulting 2D projection image.
 733    """
 734
 735    try:
 736
 737        if projection_type == "avg":
 738            img = np.mean(tiff_object, axis=0).astype(np.uint16)
 739        elif projection_type == "max":
 740            img = np.max(tiff_object, axis=0).astype(np.uint16)
 741        elif projection_type == "min":
 742            img = np.min(tiff_object, axis=0).astype(np.uint16)
 743        elif projection_type == "std":
 744            img = np.std(tiff_object, axis=0).astype(np.uint16)
 745        elif projection_type == "median":
 746            img = np.median(tiff_object, axis=0).astype(np.uint16)
 747
 748        return img
 749
 750    except:
 751        print("Something went wrong. Check the function input data and try again!")
 752
 753
 754def clahe_16bit(img, kernal=(100, 100)):
 755    """
 756    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.
 757
 758    Parameters
 759    ----------
 760    img : np.ndarray
 761        Input image.
 762
 763    Returns
 764    -------
 765    img : np.ndarray
 766        Image after CLAHE enhancement.
 767
 768    kernel : tuple
 769        Size of the CLAHE tile grid used for processing, e.g., (100, 100).
 770    """
 771
 772    try:
 773
 774        img = img.copy()
 775
 776        img8bit = img.copy()
 777
 778        min_val = np.min(img8bit)
 779        max_val = np.max(img8bit)
 780
 781        img8bit = ((img8bit - min_val) / (max_val - min_val) * 255).astype(np.uint8)
 782
 783        clahe = cv2.createCLAHE(clipLimit=10, tileGridSize=kernal)
 784        img8bit = clahe.apply(img8bit)
 785
 786        img8bit = img8bit / 255
 787
 788        img = img * img8bit
 789
 790        min_val = np.min(img)
 791        max_val = np.max(img)
 792
 793        img = ((img - min_val) / (max_val - min_val) * 65535).astype(np.uint16)
 794
 795        return img
 796
 797    except:
 798        print("Something went wrong. Check the function input data and try again!")
 799
 800
 801def equalizeHist_16bit(image_eq):
 802    """
 803    Apply global histogram equalization to an image.
 804
 805    Parameters
 806    ----------
 807    image_eq : np.ndarray
 808        Input image.
 809
 810    Returns
 811    -------
 812    np.ndarray
 813        Image after global histogram equalization.
 814    """
 815
 816    try:
 817
 818        image = image_eq.copy()
 819
 820        min_val = np.min(image)
 821        max_val = np.max(image)
 822
 823        scaled_image = ((image - min_val) / (max_val - min_val) * 255).astype(np.uint8)
 824
 825        eq_image = cv2.equalizeHist(scaled_image)
 826
 827        eq_image_bin = eq_image / 255
 828
 829        image_eq_16 = image * eq_image_bin
 830        image_eq_16 = (image_eq_16 / np.max(image_eq_16)) * 65535
 831        image_eq_16[image_eq_16 > (65535 / 2)] += 65535 - np.max(image_eq_16)
 832        image_eq_16 = image_eq_16.astype(np.uint16)
 833
 834        return image_eq_16
 835
 836    except:
 837        print("Something went wrong. Check the function input data and try again!")
 838
 839
 840def load_image(path):
 841    """
 842    Load an image and convert it to 16-bit if necessary.
 843
 844    Parameters
 845    ----------
 846    path : str
 847        Path to the image file.
 848
 849    Returns
 850    -------
 851    np.ndarray
 852        Loaded 16-bit image.
 853    """
 854
 855    try:
 856
 857        img = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
 858
 859        # convert to 16 bit (the function are working on 16 bit images!)
 860        if img.dtype != "uint16":
 861
 862            min_val = np.min(img)
 863            max_val = np.max(img)
 864
 865            img = ((img - min_val) / (max_val - min_val) * 65535).astype(np.uint16)
 866
 867            img = np.clip(img, 0, 65535)
 868
 869        return img
 870
 871    except:
 872        print("Something went wrong. Check the function input data and try again!")
 873
 874
 875def rotate_image(img, rotate: int):
 876    """
 877    Rotate an image by a specified angle.
 878
 879    Parameters
 880    ----------
 881    img : np.ndarray
 882        Image to rotate.
 883
 884    rotate : int
 885        Degree of rotation. Available options: 90, 180, 270.
 886
 887    Returns
 888    -------
 889    np.ndarray
 890        Rotated image.
 891    """
 892
 893    try:
 894
 895        if rotate == 0:
 896            r = 0
 897        elif rotate == 90:
 898            r = -1
 899        elif rotate == 180:
 900            r = 2
 901        elif rotate == 180:
 902            r = 2
 903        elif rotate == 270:
 904            r = 1
 905        else:
 906            print("Wrong argument - rotate!")
 907            return None
 908
 909        img = img.copy()
 910
 911        img = np.rot90(img.copy(), k=r)
 912
 913        return img
 914
 915    except:
 916        print("Something went wrong. Check the function input data and try again!")
 917
 918
 919def mirror_image(img, rotate: str):
 920    """
 921    Mirror an image along specified axes.
 922
 923    Parameters
 924    ----------
 925    img : np.ndarray
 926        Image to mirror.
 927
 928    rotate : str
 929        Type of mirroring to apply. Options:
 930            'h'  - horizontal mirroring
 931            'v'  - vertical mirroring
 932            'hv' - both horizontal and vertical mirroring
 933
 934    Returns
 935    -------
 936    np.ndarray
 937        Mirrored image.
 938    """
 939
 940    try:
 941
 942        if rotate == "h":
 943            img = np.fliplr(img.copy())
 944        elif rotate == "v":
 945            img = np.flipud(img.copy())
 946        elif rotate == "hv":
 947            img = np.flipud(np.fliplr(img.copy()))
 948        else:
 949            print("Wrong argument - rotate!")
 950            return None
 951
 952        return img
 953
 954    except:
 955        print("Something went wrong. Check the function input data and try again!")
 956
 957
# validated UTILS
 959
 960
 961def merge_images(image_list: list, intensity_factors: list = []):
 962    """
 963    Merge multiple image projections from different channels.
 964
 965    Parameters
 966    ----------
 967    image_list : list of np.ndarray
 968        List of images to merge. All images must have the same shape and size.
 969
 970    intensity_factors : list of float
 971        Intensity scaling factors for each image in `image_list`. Base value is 1.
 972            - Values < 1 decrease intensity.
 973            - Values > 1 increase intensity.
 974
 975    Returns
 976    -------
 977    np.ndarray
 978        Merged image after applying intensity scaling.
 979    """
 980
 981    try:
 982
 983        result = None
 984
 985        if len(intensity_factors) == 0:
 986
 987            intensity_factors = []
 988            for bt in range(len(image_list)):
 989                intensity_factors.append(1)
 990
 991        for i, image in enumerate(image_list):
 992            if result is None:
 993                result = image.astype(np.uint64) * intensity_factors[i]
 994            else:
 995                result = cv2.addWeighted(
 996                    result, 1, image.astype(np.uint64) * intensity_factors[i], 1, 0
 997                )
 998
 999        result = np.clip(result, 0, 65535)
1000
1001        result = result.astype(np.uint16)
1002
1003        return result
1004
1005    except:
1006        print("Something went wrong. Check the function input data and try again!")
1007
1008
def adjust_img_16bit(
    img,
    color="gray",
    max_intensity: int = 65535,
    min_intenisty: int = 0,
    brightness: int = 1000,
    contrast=1.0,
    gamma=1.0,
):
    """
    Manually adjust a 16-bit image and return it as a 3-channel 16-bit image.

    Adjustments are applied in this order:
      1. brightness;
      2. contrast;
      3. gamma;
      4. rescaling so the brightest pixel becomes 65535;
      5. the `max_intensity` / `min_intenisty` thresholds.
    Finally, the single-channel result is written into the colour channels
    selected by `color`.

    Parameters
    ----------
    img : np.ndarray
        Input 2-D image. Values are clipped to the 16-bit range [0, 65535].

    color : str, optional
        Output colour. Options: ['gray', 'green', 'blue', 'red', 'yellow',
        'magenta', 'cyan']. Channels are in BGR order (e.g. 'blue' fills
        channel 0, 'red' fills channel 2). Any other value yields an all-zero
        image. Default is 'gray'.

    max_intensity : int, optional
        Upper threshold applied after rescaling: pixels ``>= max_intensity``
        are set to 65535 (saturated). The default 65535 disables it.

    min_intenisty : int, optional
        Lower threshold applied after rescaling: pixels ``<= min_intenisty``
        are set to 0. The default 0 disables it. (The misspelt name is kept
        for backward compatibility.)

    brightness : int, optional
        Brightness adjustment; 1000 is neutral. Each unit above/below 1000
        scales non-zero pixels up/down by 1 %. Typical range: [900-2000].
        Default is 1000.

    contrast : float or int, optional
        Contrast factor applied around the image mean. Typical range: [0-5].
        Default is 1 (no change).

    gamma : float or int, optional
        Gamma correction factor; normalised pixels are raised to
        ``1 / gamma``. Typical range: (0-5]. Default is 1 (no change).

    Returns
    -------
    np.ndarray or None
        ``uint16`` array of shape (height, width, 3). ``None`` is returned
        (after printing a message) if the input could not be processed.
    """

    try:
        # Work in float64: dimming (brightness < 1000) can produce negative
        # intermediate values. In an unsigned integer array those wrap around
        # and are then clipped to full white instead of black.
        img = np.clip(np.array(img, dtype=np.float64), 0, 65535)

        # brightness
        if brightness != 1000:
            factor = -1000 + brightness
            side = factor / abs(factor)
            nonzero = img > 0
            img[nonzero] = img[nonzero] + ((img[nonzero] * abs(factor) / 100) * side)
            img = np.clip(img, 0, 65535)

        # contrast: stretch the values around the image mean
        if contrast != 1:
            mean_val = np.mean(img)
            img = ((img - mean_val) * contrast) + mean_val
            img = np.clip(img, 0, 65535)

        # gamma (skipped for an all-black image to avoid dividing by zero)
        if gamma != 1:
            max_val = np.max(img)
            if max_val > 0:
                normalized = np.clip(img / max_val, 0, 1)
                img = np.clip((normalized ** (1 / gamma)) * max_val, 0, 65535)

        img = np.nan_to_num(img, nan=0, posinf=65535, neginf=0)

        # Rescale so that the brightest pixel maps to 65535.
        max_val = np.max(img)
        if max_val > 0:
            img = (img / max_val) * 65535
        img = img.astype(np.uint16)

        # max intensity: saturate everything at or above the threshold
        if max_intensity != 65535:
            img[img >= max_intensity] = 65535

        # min intensity: blank everything at or below the threshold
        if min_intenisty != 0:
            img[img <= min_intenisty] = 0

        channel_map = {
            "blue": (0,),
            "green": (1,),
            "red": (2,),
            "cyan": (0, 1),
            "magenta": (0, 2),
            "yellow": (1, 2),
            "gray": (0, 1, 2),
        }
        channels = channel_map.get(color, ()) if isinstance(color, str) else ()

        img_color = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint16)
        for channel in channels:
            img_color[:, :, channel] = img

        return img_color

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None
1132
1133
def get_screan():
    """
    Get the current screen width and height.

    A temporary Tk root window is created to query the display. It is hidden
    immediately and always destroyed afterwards, even if the query fails.

    Returns
    -------
    tuple of int
        screen_width, screen_height in pixels.
    """

    root = tk.Tk()
    try:
        # Hide the helper window so it does not flash on screen.
        root.withdraw()
        return root.winfo_screenwidth(), root.winfo_screenheight()
    finally:
        # Release the Tk interpreter/window regardless of success.
        root.destroy()
1151
1152
def resize_to_screen_img(img_file, factor=1):
    """
    Shrink an image so that it fits on the current screen.

    The screen size reported by ``get_screan`` is multiplied by `factor` to
    obtain the maximum allowed width and height. If the image is wider than
    that limit, it is scaled down proportionally to the maximum width. If it
    is then still taller than the height limit, it is scaled down
    proportionally again. Images that already fit are returned unchanged
    (the same object).

    Parameters
    ----------
    img_file : np.ndarray
        Input image to be resized.

    factor : float, optional
        Fraction of the screen size the image may occupy. Default is 1.

    Returns
    -------
    np.ndarray
        The (possibly) downscaled image, with its aspect ratio kept.
    """

    screen_w, screen_h = get_screan()
    limit_w = int(screen_w * factor)
    limit_h = int(screen_h * factor)

    # Fit the width first.
    if img_file.shape[1] > limit_w:
        rows, cols = img_file.shape[0], img_file.shape[1]
        scale = limit_w / cols
        img_file = cv2.resize(img_file, (int(scale * cols), int(scale * rows)))

    # Then fit the (possibly already reduced) height.
    if img_file.shape[0] > limit_h:
        rows, cols = img_file.shape[0], img_file.shape[1]
        scale = limit_h / rows
        img_file = cv2.resize(img_file, (int(scale * cols), int(scale * rows)))

    return img_file
1201
1202
def display_preview(image):
    """
    Quickly preview an image in an OpenCV window.

    The image is shrunk to at most 80 % of the screen size (via
    ``resize_to_screen_img``) and shown in a window titled "Display". The
    call blocks until a key is pressed, then all OpenCV windows are closed.

    Parameters
    ----------
    image : np.ndarray
        Input image to be displayed.

    Notes
    -----
    Does not return a value. Errors while preparing or showing the preview
    are reported with a printed message. A keyboard interrupt (Ctrl+C) is no
    longer swallowed, so a blocked preview can be aborted.
    """
    try:
        preview = resize_to_screen_img(image.copy(), factor=0.8)

        cv2.imshow("Display", preview)

        # Block until any key is pressed.
        cv2.waitKey(0)

        cv2.destroyAllWindows()

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
1228
1229
1230## flow_JIMG_functions
1231
1232
1233class ImageTools:
1234    """
1235    Collection of utility functions for image preprocessing, adjustment,
1236    merging, and stitching.
1237
1238    This class provides standalone static methods for operations such as
1239    histogram equalization, CLAHE enhancement, intensity adjustments,
1240    image merging based on weighted ratios, and horizontal stitching.
1241    These tools are used internally by `ImagesManagement` but can also be
1242    applied independently for generic image-processing workflows.
1243
1244    Notes
1245    -----
1246    All methods operate on NumPy arrays and expect images in standard
1247    OpenCV format. Some functions assume 16-bit images unless stated
1248    otherwise.
1249
1250    Examples
1251    --------
1252    >>> from ImageTools import ImageTools
1253    >>> img = ImageTools.equalize_hist_16bit(img)
1254    >>> merged = ImageTools.merge_images([img1, img2], [0.5, 0.5])
1255    """
1256
1257    def get_screan(self):
1258        """
1259        Return the current screen size.
1260
1261        Returns
1262        -------
1263        screen_width : int
1264            Width of the screen in pixels.
1265
1266        screen_height : int
1267            Height of the screen in pixels.
1268        """
1269
1270        root = tk.Tk()
1271        screen_width = root.winfo_screenwidth()
1272        screen_height = root.winfo_screenheight()
1273
1274        root.destroy()
1275
1276        return screen_width, screen_height
1277
1278    def resize_to_screen_img(self, img_file, factor=0.5):
1279        """
1280        Resize an input image to a scaled version of the current screen size.
1281
1282        Parameters
1283        ----------
1284        img_file : np.ndarray
1285            Input image to be resized.
1286
1287        factor : int
1288            Scaling factor applied to the screen dimensions.
1289
1290        Returns
1291        -------
1292        resized_image : np.ndarray
1293            Resized image adjusted to the scaled screen size.
1294        """
1295
1296        screen_width, screen_height = self.get_screan()
1297
1298        screen_width = int(screen_width * factor)
1299        screen_height = int(screen_height * factor)
1300
1301        h = int(img_file.shape[0])
1302        w = int(img_file.shape[1])
1303
1304        if screen_width < w or screen_width * 0.3 > w:
1305            h = img_file.shape[0]
1306            w = img_file.shape[1]
1307
1308            ww = int((screen_width / w) * w)
1309            hh = int((screen_width / w) * h)
1310
1311            img_file = cv2.resize(img_file, (ww, hh))
1312
1313            h = img_file.shape[0]
1314            w = img_file.shape[1]
1315
1316        if screen_height < h or screen_height * 0.3 > h:
1317            h = img_file.shape[0]
1318            w = img_file.shape[1]
1319
1320            ww = int((screen_height / h) * w)
1321            hh = int((screen_height / h) * h)
1322
1323            img_file = cv2.resize(img_file, (ww, hh))
1324
1325        return img_file
1326
1327    def load_JIMG_project(self, project_path):
1328        """
1329        Load a JIMG project from a `.pjm` file.
1330
1331        Parameters
1332        ----------
1333        file_path : str
1334            Path to the project file with `.pjm` extension.
1335
1336        Returns
1337        -------
1338        project : object
1339            Loaded project object.
1340
1341        Raises
1342        ------
1343        ValueError
1344            If the provided file does not have a `.pjm` extension.
1345        """
1346
1347        if ".pjm" in project_path:
1348            with open(project_path, "rb") as file:
1349                app_metadata_tmp = pickle.load(file)
1350
1351            return app_metadata_tmp
1352
1353        else:
1354            print("\nProvide path to the project metadata file with *.pjm extension!!!")
1355
1356    def ajd_mask_size(self, image, mask):
1357        """
1358        Adjusts the size of a mask to match the dimensions of a given image.
1359
1360        Parameters
1361        ----------
1362        image : np.ndarray
1363            Reference image whose size the mask should match. Can be 2D or 3D.
1364
1365        mask : np.ndarray
1366            Mask image to be resized.
1367
1368        Returns
1369        -------
1370        np.ndarray
1371            Resized mask matching the input image dimensions.
1372        """
1373
1374        try:
1375            mask = cv2.resize(mask, (image.shape[2], image.shape[1]))
1376        except:
1377            mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
1378
1379        return mask
1380
1381    def load_image(self, path_to_image):
1382        """
1383        Load an image from the specified file path.
1384
1385        Parameters
1386        ----------
1387        path_to_image : str
1388            Path to the image file.
1389
1390        Returns
1391        -------
1392        image : np.ndarray
1393            Loaded image as a NumPy array.
1394        """
1395
1396        image = load_image(path_to_image)
1397        return image
1398
1399    def load_3D_tiff(self, path_to_image):
1400        """
1401        Load a 3D image from a TIFF file.
1402
1403        Parameters
1404        ----------
1405        path_to_image : str
1406            Path to the 3D TIFF image file.
1407
1408        Returns
1409        -------
1410        image : np.ndarray
1411            Loaded 3D image as a NumPy array.
1412        """
1413
1414        image = load_tiff(path_to_image)
1415
1416        return image
1417
1418    def load_mask(self, path_to_mask):
1419        """
1420        Load a mask image.
1421
1422        Parameters
1423        ----------
1424        path_to_mask : str
1425            Path to the mask image file.
1426
1427        Returns
1428        -------
1429        mask : np.ndarray
1430            Loaded mask image as a NumPy array.
1431        """
1432
1433        mask = cv2.imread(path_to_mask, cv2.IMREAD_GRAYSCALE)
1434        return mask
1435
1436    def save(self, image, file_name):
1437        """
1438        Save an image to disk.
1439
1440        Parameters
1441        ----------
1442        image : np.ndarray
1443            Image data to be saved.
1444
1445        file_name : str
1446            Output file path including the desired image extension (e.g., ".png", ".jpg").
1447
1448        """
1449
1450        cv2.imwrite(filename=file_name, img=image)
1451
1452    # calculation methods
1453
1454    def drop_dict(self, dictionary, key, var, action=None):
1455        """
1456        Filters elements from a dictionary based on a condition applied to a specified key.
1457
1458        Parameters
1459        ----------
1460        dictionary : dict
1461            Dictionary containing lists or arrays under each key.
1462
1463        key : str
1464            The key in the dictionary on which the filtering condition will be applied.
1465
1466        var : numeric
1467            Value to compare against each element in dictionary[key].
1468
1469        action : str, optional
1470            Comparison operator as string: '<=', '>=', '==', '<', '>'.
1471            Default is None, which raises an error.
1472
1473        Returns
1474        -------
1475        dict
1476            A new dictionary with elements removed where the condition matches.
1477        """
1478
1479        dictionary = copy.deepcopy(dictionary)
1480        indices_to_drop = []
1481        for i, dr in enumerate(dictionary[key]):
1482
1483            if isinstance(dr, np.ndarray):
1484                dr = np.mean(dr)
1485
1486            if action == "<=":
1487                if var <= dr:
1488                    indices_to_drop.append(i)
1489            elif action == ">=":
1490                if var >= dr:
1491                    indices_to_drop.append(i)
1492            elif action == "==":
1493                if var == dr:
1494                    indices_to_drop.append(i)
1495            elif action == "<":
1496                if var < dr:
1497                    indices_to_drop.append(i)
1498            elif action == ">":
1499                if var > dr:
1500                    indices_to_drop.append(i)
1501            else:
1502                print("\nWrong action!")
1503                return None
1504
1505        for key, value in dictionary.items():
1506            dictionary[key] = [
1507                v for i, v in enumerate(value) if i not in indices_to_drop
1508            ]
1509
1510        return dictionary
1511
1512    # modified for gradation (separation near nucleus)
1513
1514    def create_mask(self, dictionary, image):
1515        """
1516        Creates a mask image with gradated intensity values for each coordinate set in a dictionary.
1517
1518        Parameters
1519        ----------
1520        dictionary : dict
1521            Dictionary containing a 'coords' key with a list of arrays of coordinates.
1522
1523        image : np.ndarray
1524            Base image to define the shape of the mask.
1525
1526        Returns
1527        -------
1528        np.ndarray
1529            Mask image with uint16 gradated intensity.
1530        """
1531
1532        image_mask = np.zeros(image.shape)
1533
1534        arrays_list = copy.deepcopy(dictionary["coords"])
1535
1536        if len(arrays_list) > 0:
1537
1538            initial_val = math.floor((2**16 - 1) / 4)
1539            intensity_list = ((2**16 - 1) - math.floor((2**16 - 1) / 4)) / len(
1540                arrays_list
1541            )
1542
1543            gradation = 1
1544            for arr in arrays_list:
1545                image_mask[arr[:, 0], arr[:, 1]] = initial_val + (
1546                    gradation * intensity_list
1547                )
1548                gradation += 1
1549
1550        return image_mask.astype("uint16")
1551
1552    def min_max_histograme(self, image):
1553        """
1554        Calculates histogram-based minimum and maximum intensity percentiles in an image.
1555
1556        Parameters
1557        ----------
1558        image : np.ndarray
1559            Input image for histogram analysis.
1560
1561        Returns
1562        -------
1563        min_val : float
1564            Minimum intensity percentile above zero.
1565
1566        max_val : float
1567            Maximum intensity percentile based on histogram gradient.
1568
1569        df : pd.DataFrame
1570            DataFrame containing quantiles, corresponding intensity values, and cumulative percents.
1571        """
1572
1573        q = []
1574        val = []
1575        perc = []
1576
1577        max_val = image.shape[0] * image.shape[1]
1578
1579        for n in range(0, 100, 5):
1580            q.append(n)
1581            val.append(np.quantile(image, n / 100))
1582            sum_val = np.sum(image < np.quantile(image, n / 100))
1583            pr = sum_val / max_val
1584            perc.append(pr)
1585
1586        df = pd.DataFrame({"q": q, "val": val, "perc": perc})
1587
1588        min_val = 0
1589        for i in df.index:
1590            if df["val"][i] != 0 and min_val == 0:
1591                min_val = df["perc"][i]
1592
1593        max_val = 0
1594        df = df[df["perc"] > 0]
1595        df = df.sort_values("q", ascending=False).reset_index(drop=True)
1596
1597        for i in df.index:
1598            if i > 1 and df["val"][i] * 1.5 > df["val"][i - 1]:
1599                max_val = df["perc"][i]
1600                break
1601            elif i == len(df.index) - 1:
1602                max_val = df["perc"][i]
1603
1604        return min_val, max_val, df
def umap_html(umap_result, width=1000, height=1200):
29def umap_html(umap_result, width=1000, height=1200):
30    """
31    Create an interactive HTML UMAP scatter plot.
32
33    Parameters
34    ----------
35    umap_result : pandas.DataFrame or dict-like
36        UMAP embedding containing at least:
37        - `0` : array-like, UMAP dimension 1
38        - `1` : array-like, UMAP dimension 2
39        - `'clusters'` : array-like, assigned cluster labels
40
41    width : int, optional
42        Width of the output figure in pixels. Default is 1000.
43
44    height : int, optional
45        Height of the output figure in pixels. Default is 1200.
46
47    Returns
48    -------
49    plotly.graph_objs._figure.Figure
50        Interactive Plotly scatter plot object visualizing UMAP with colored clusters.
51    """
52
53    fig = px.scatter(
54        x=umap_result[0],
55        y=umap_result[1],
56        color=umap_result["clusters"],
57        labels={"color": "Cells"},
58        template="simple_white",
59        width=width,
60        height=height,
61        render_mode="svg",
62        color_discrete_sequence=px.colors.qualitative.Dark24
63        + px.colors.qualitative.Light24,
64    )
65
66    fig.update_xaxes(title_text="UMAP 1")
67    fig.update_yaxes(title_text="UMAP 2")
68
69    return fig

Create an interactive HTML UMAP scatter plot.

Parameters

umap_result : pandas.DataFrame or dict-like UMAP embedding containing at least:
  • 0 : array-like, UMAP dimension 1
  • 1 : array-like, UMAP dimension 2
  • 'clusters' : array-like, assigned cluster labels

width : int, optional Width of the output figure in pixels. Default is 1000.

height : int, optional Height of the output figure in pixels. Default is 1200.

Returns

plotly.graph_objs._figure.Figure Interactive Plotly scatter plot object visualizing UMAP with colored clusters.

def umap_static(umap_result, width=10, height=13, n_per_col=20):
 72def umap_static(umap_result, width=10, height=13, n_per_col=20):
 73    """
 74    Create a static matplotlib UMAP scatter plot.
 75
 76    Parameters
 77    ----------
 78    umap_result : pandas.DataFrame or dict-like
 79        UMAP projection containing:
 80        - `0` : array-like, UMAP dimension 1
 81        - `1` : array-like, UMAP dimension 2
 82        - `'clusters'` : array-like, cluster assignments
 83
 84    width : float, optional
 85        Width of the figure in inches. Default is 10.
 86
 87    height : float, optional
 88        Height of the figure in inches. Default is 13.
 89
 90    n_per_col : int, optional
 91        Maximum number of legend entries per column. Default is 20.
 92
 93    Returns
 94    -------
 95    matplotlib.figure.Figure
 96        Static matplotlib figure representing the UMAP embedding with clusters.
 97    """
 98
 99    plotly_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24
100    num_colors = len(plotly_colors)
101
102    fig = plt.figure(figsize=(width, height))
103
104    sorted_labels = pd.unique(umap_result["clusters"])
105
106    color_map = {
107        label: plotly_colors[i % num_colors] for i, label in enumerate(sorted_labels)
108    }
109
110    for label in sorted_labels:
111        subset = umap_result[umap_result["clusters"] == label]
112        plt.scatter(
113            subset[0],
114            subset[1],
115            c=[color_map[label]],
116            label=f"Cluster {label}",
117            alpha=0.7,
118            s=20,
119            edgecolor="black",
120            linewidths=0.1,
121        )
122
123    n_col = -(-len(set(umap_result["clusters"])) // n_per_col)
124
125    plt.xlabel("UMAP 1", fontsize=14)
126    plt.ylabel("UMAP 2", fontsize=14)
127    plt.grid(True, which="both", linestyle="--", linewidth=0.1)
128    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", ncol=n_col)
129    plt.tight_layout()
130
131    return fig

Create a static matplotlib UMAP scatter plot.

Parameters

umap_result : pandas.DataFrame or dict-like UMAP projection containing:
  • 0 : array-like, UMAP dimension 1
  • 1 : array-like, UMAP dimension 2
  • 'clusters' : array-like, cluster assignments

width : float, optional Width of the figure in inches. Default is 10.

height : float, optional Height of the figure in inches. Default is 13.

n_per_col : int, optional Maximum number of legend entries per column. Default is 20.

Returns

matplotlib.figure.Figure Static matplotlib figure representing the UMAP embedding with clusters.

def test_data(path=''):
134def test_data(path=""):
135    """
136    Download and extract test data from Google Drive.
137
138    This function downloads a compressed archive containing example test data
139    and extracts it into the specified directory. The data is fetched using
140    a direct Google Drive link. If the download or extraction fails, an
141    error message is printed.
142
143    Parameters
144    ----------
145    path : str, optional
146        Destination directory where the test dataset will be downloaded and
147        extracted. Defaults to the current working directory.
148
149
150    Notes
151    -----
152    - The downloaded file is named ``test_data.tar.gz``.
153    - The archive is extracted into ``<path>/test_data``.
154    - In case of any failure (download or extraction), the function prints
155      an informative message instead of raising an exception.
156    """
157
158    try:
159
160        file_name = "test_data.tar.gz"
161
162        file_name = os.path.join(path, file_name)
163
164        url = "https://drive.google.com/uc?id=1MhzhleMP7iTzlBVW8eP5sFaonJdg1a3T"
165
166        gdown.download(url, file_name, quiet=False)
167
168        # Unzip
169
170        with tarfile.open(file_name, "r:gz") as tar:
171            tar.extractall(path=path)
172
173        print(
174            f"\nTest data downloaded succesfully -> {os.path.join(path, 'test_data')}"
175        )
176
177    except:
178
179        print(
180            "\nTest data could not be downloaded. Please check your connection and try again!"
181        )

Download and extract test data from Google Drive.

This function downloads a compressed archive containing example test data and extracts it into the specified directory. The data is fetched using a direct Google Drive link. If the download or extraction fails, an error message is printed.

Parameters

path : str, optional Destination directory where the test dataset will be downloaded and extracted. Defaults to the current working directory.

Notes

  • The downloaded file is named test_data.tar.gz.
  • The archive is extracted into <path>/test_data.
  • In case of any failure (download or extraction), the function prints an informative message instead of raising an exception.
def prop_plot(df_pivot, chi_df):
184def prop_plot(df_pivot, chi_df):
185    """
186    Create a stacked bar plot of proportional data with post-hoc significance annotations.
187
188    Parameters
189    ----------
190    df_pivot : pandas.DataFrame
191        Pivot table where rows represent categories (e.g., compartments) and columns
192        represent groups. Values are counts or frequencies.
193
194    chi_df : pandas.DataFrame
195        DataFrame containing pairwise Chi-square test results with an added
196        'Significance_Label' column (e.g., '***', '**', '*', 'ns') for each pair
197        of groups. Typically output from `chi_pairs` and `get_significance_label`.
198
199    Returns
200    -------
201    matplotlib.figure.Figure
202        The Matplotlib figure object containing the stacked bar plot.
203
204    Notes
205    -----
206    - The function converts raw counts to percentages per group for visualization.
207    - Each pairwise comparison and its significance label is displayed as a text box
208      next to the plot.
209    - Colors are assigned using the 'viridis' colormap by default.
210    - The plot is configured for clarity with labeled axes, legend, and appropriately
211      sized text.
212    """
213
214    df_pivot_perc = df_pivot.div(df_pivot.sum(axis=0), axis=1) * 100
215
216    chi_df = chi_df.sort_values(by="p-value", ascending=True)
217
218    df_pivot_perc = df_pivot_perc.T.sort_values(by=list(df_pivot_perc.index)).T
219
220    posthoc_text = "\n".join(
221        [
222            f"{row['Group 1']}{row['Group 2']}: {row['Significance_Label']}"
223            for _, row in chi_df.iterrows()
224        ]
225    )
226
227    fig, ax = plt.subplots(figsize=(12, 7))
228
229    df_pivot_perc.T.plot(kind="bar", stacked=True, ax=ax, cmap="viridis")
230
231    ax.set_ylabel("Percentage (%)", fontsize=16)
232    ax.set_xlabel("Groups", fontsize=16)
233
234    ax.tick_params(axis="both", labelsize=12)
235
236    ax.legend(
237        title="Compartment", loc="upper left", bbox_to_anchor=(1.02, 1.05), fontsize=14
238    )
239
240    plt.figtext(
241        0.93,
242        0.6,
243        posthoc_text,
244        ha="left",
245        va="top",
246        fontsize=12,
247        bbox={"facecolor": "white", "alpha": 0.7, "pad": 5},
248    )
249
250    return fig

Create a stacked bar plot of proportional data with post-hoc significance annotations.

Parameters

df_pivot : pandas.DataFrame Pivot table where rows represent categories (e.g., compartments) and columns represent groups. Values are counts or frequencies.

chi_df : pandas.DataFrame DataFrame containing pairwise Chi-square test results with an added 'Significance_Label' column (e.g., '***', '**', '*', 'ns') for each pair of groups. Typically output from chi_pairs and get_significance_label.

Returns

matplotlib.figure.Figure The Matplotlib figure object containing the stacked bar plot.

Notes

  • The function converts raw counts to percentages per group for visualization.
  • Each pairwise comparison and its significance label is displayed as a text box next to the plot.
  • Colors are assigned using the 'viridis' colormap by default.
  • The plot is configured for clarity with labeled axes, legend, and appropriately sized text.
def get_significance_label(p_value):
253def get_significance_label(p_value):
254    """
255    Return a standard significance label based on a p-value.
256
257    Parameters
258    ----------
259    p_value : float
260        The p-value for which the significance label should be determined.
261
262    Returns
263    -------
264    str
265        A significance marker commonly used in statistical reporting:
266
267        - '***' : p < 0.001
268        - '**'  : p < 0.01
269        - '*'   : p < 0.05
270        - 'ns'  : not significant (p ≥ 0.05)
271
272    Notes
273    -----
274    This helper function is typically used for annotating statistical test
275    results in tables or visualizations. Thresholds follow conventional
276    statistical notation for significance levels.
277    """
278
279    if p_value < 0.001:
280        return "***"
281    elif p_value < 0.01:
282        return "**"
283    elif p_value < 0.05:
284        return "*"
285    else:
286        return "ns"

Return a standard significance label based on a p-value.

Parameters

p_value : float The p-value for which the significance label should be determined.

Returns

str A significance marker commonly used in statistical reporting:

- '***' : p < 0.001
- '**'  : p < 0.01
- '*'   : p < 0.05
- 'ns'  : not significant (p ≥ 0.05)

Notes

This helper function is typically used for annotating statistical test results in tables or visualizations. Thresholds follow conventional statistical notation for significance levels.

def chi_pairs(df_pivot):
289def chi_pairs(df_pivot):
290    """
291    Compute pairwise Chi-square tests for all combinations of groups in a pivoted dataframe.
292
293    Parameters
294    ----------
295    df_pivot : pandas.DataFrame
296        A pivot table where rows represent categories and columns represent groups.
297        Values should be counts (frequencies). The function will add +1 to each cell
298        to avoid zero counts during chi-square computation.
299
300    Returns
301    -------
302    pandas.DataFrame
303        A DataFrame containing pairwise Chi-square test results with the following columns:
304        - 'Group 1' : str
305            Name of the first group in the pair.
306        - 'Group 2' : str
307            Name of the second group in the pair.
308        - 'Chi²' : float
309            The Chi-square statistic for the comparison.
310        - 'p-value' : float
311            The p-value of the Chi-square test.
312
313    Notes
314    -----
315    The function compares every possible pair of columns using `scipy.stats.chi2_contingency`.
316    Yates' correction is applied by default unless disabled in the SciPy version used.
317    A value of 1 is added to all cells to avoid issues with zero frequencies.
318    """
319
320    group_pairs = list(combinations(df_pivot.columns, 2))
321
322    posthoc_results = []
323
324    for group1, group2 in group_pairs:
325        sub_table = df_pivot.T.loc[[group1, group2]] + 1
326        chi2, p, dof, expected = chi2_contingency(sub_table)
327
328        posthoc_results.append(
329            {"Group 1": group1, "Group 2": group2, "Chi²": chi2, "p-value": p}
330        )
331
332    posthoc_df = pd.DataFrame(posthoc_results)
333
334    return posthoc_df

Compute pairwise Chi-square tests for all combinations of groups in a pivoted dataframe.

Parameters

df_pivot : pandas.DataFrame A pivot table where rows represent categories and columns represent groups. Values should be counts (frequencies). The function will add +1 to each cell to avoid zero counts during chi-square computation.

Returns

pandas.DataFrame A DataFrame containing pairwise Chi-square test results with the following columns:
  • 'Group 1' : str Name of the first group in the pair.
  • 'Group 2' : str Name of the second group in the pair.
  • 'Chi²' : float The Chi-square statistic for the comparison.
  • 'p-value' : float The p-value of the Chi-square test.

Notes

The function compares every possible pair of columns using scipy.stats.chi2_contingency. With SciPy's default (correction=True), Yates' continuity correction is applied only when the table has one degree of freedom, i.e. when there are exactly two categories. A value of 1 is added to all cells to avoid issues with zero frequencies.

def statistic(input_df, sets=None, metadata=None, n_proc=10):
_STAT_COLUMNS = (
    "feature",
    "p_val",
    "pct_valid",
    "pct_ctrl",
    "avg_valid",
    "avg_ctrl",
    "sd_valid",
    "sd_ctrl",
    "esm",
)


def _feature_stats(feature_name, target_values, rest_values):
    """
    Compute summary statistics and a Mann–Whitney U p-value for one feature.

    Parameters
    ----------
    feature_name : hashable
        Name of the feature (column) being tested.
    target_values, rest_values : pandas.Series
        Values of the feature in the target cells and in the reference cells.
        Because these are pandas objects, ``np.mean``/``np.std``/``np.sum``
        dispatch to the pandas reductions, which skip NaN values.

    Returns
    -------
    dict
        One result row keyed by the names in ``_STAT_COLUMNS``.
    """
    pct_valid = (target_values > 0).sum() / len(target_values)
    pct_rest = (rest_values > 0).sum() / len(rest_values)

    avg_valid = np.mean(target_values)
    avg_ctrl = np.mean(rest_values)
    sd_valid = np.std(target_values, ddof=1)
    sd_ctrl = np.std(rest_values, ddof=1)
    # Mean difference scaled by the quadratic mean of the two sample SDs
    # (a Cohen's-d style effect size with equal group weights).
    esm = (avg_valid - avg_ctrl) / np.sqrt((sd_valid**2 + sd_ctrl**2) / 2)

    # Shortcut: identical totals in both groups are reported as p = 1 without
    # running the test (equal totals alone do not imply equal distributions).
    if np.sum(target_values) == np.sum(rest_values):
        p_val = 1.0
    else:
        _, p_val = stats.mannwhitneyu(
            target_values, rest_values, alternative="two-sided"
        )

    return {
        "feature": feature_name,
        "p_val": p_val,
        "pct_valid": pct_valid,
        "pct_ctrl": pct_rest,
        "avg_valid": avg_valid,
        "avg_ctrl": avg_ctrl,
        "sd_valid": sd_valid,
        "sd_ctrl": sd_ctrl,
        "esm": esm,
    }


def _safe_min_half(series):
    """Return half of the smallest value above ``2 * 2**-1074`` in *series*, or 0 if none."""
    filtered = series[(series > ((2**-1074) * 2)) & (series.notna())]
    return filtered.min() / 2 if not filtered.empty else 0


def _bh_adjust(sorted_p_values):
    """
    Return Benjamini–Hochberg adjusted p-values.

    *sorted_p_values* must be in ascending order with NaNs last (as produced by
    ``DataFrame.sort_values``). For the i-th smallest p-value the result is
    ``min(1, min_{j >= i} p_(j) * m / j)`` where ``m`` is the number of entries
    (NaN entries included). NaN inputs stay NaN.
    """
    p_values = np.asarray(sorted_p_values, dtype=float)
    m = p_values.size
    if m == 0:
        return p_values

    scaled = p_values * m / np.arange(1, m + 1)
    # Running minimum from the largest p-value downwards enforces the
    # monotonicity of the step-up procedure; fmin ignores the trailing NaNs.
    adjusted = np.fmin.accumulate(scaled[::-1])[::-1]
    return np.minimum(1, adjusted)


def _fold_change(avg_valid, avg_ctrl, offset):
    """
    Return ``avg_valid / avg_ctrl`` with zero means replaced by a pseudo-count.

    The pseudo-count is half of the smallest positive mean, taking the smaller
    of the two groups' values (or the non-zero one if only one group has a
    positive mean). If neither group has one, *offset* is used instead.
    """
    valid_factor = _safe_min_half(avg_valid)
    ctrl_factor = _safe_min_half(avg_ctrl)

    cv_factor = min(valid_factor, ctrl_factor)
    if cv_factor == 0:
        cv_factor = max(valid_factor, ctrl_factor)
    if not np.isfinite(cv_factor) or cv_factor == 0:
        cv_factor = offset

    valid = avg_valid.mask(avg_valid == 0, cv_factor)
    ctrl = avg_ctrl.mask(avg_ctrl == 0, cv_factor)
    return valid / ctrl


def _compare_groups(choose, target_mask, rest_mask, n_proc):
    """
    Run ``_feature_stats`` for every column of *choose* in parallel.

    Only the two per-feature slices are sent to each worker, not the whole
    table. Returns one row per feature with the columns in ``_STAT_COLUMNS``.
    """
    target = choose.loc[target_mask]
    rest = choose.loc[rest_mask]
    results = Parallel(n_jobs=n_proc)(
        delayed(_feature_stats)(feature, target[feature], rest[feature])
        for feature in tqdm(choose.columns)
    )
    return pd.DataFrame(results, columns=list(_STAT_COLUMNS))


def statistic(input_df, sets=None, metadata=None, n_proc=10):
    """
    Compute differential feature statistics between groups of cells.

    Two modes are supported:

    1. ``sets is None`` (default): every unique label in ``metadata['sets']``
       is compared against all remaining cells.
    2. ``sets`` is a dict: each key names a group made of one or more labels,
       and every group is compared against the union of the other groups'
       labels.

    For each feature the analysis reports the Mann–Whitney U p-value
    (two-sided), the percentage of values > 0, means and standard deviations,
    an effect-size metric (ESM), Benjamini–Hochberg adjusted p-values and fold
    changes.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Feature matrix with features as rows and cells as columns. It is
        transposed internally so that each feature becomes a column.
    sets : dict or None, optional
        ``None`` for one-vs-rest mode, or a mapping
        ``{group_name: [label, ...]}`` with two or more groups, e.g.
        ``{'A': ['T1', 'T2'], 'B': ['C1', 'C2']}``. Only cells whose label
        appears in one of the lists are used.
    metadata : pandas.DataFrame
        Must contain a ``'sets'`` column with one label per column of
        ``input_df``, matched by position. In one-vs-rest mode, if it holds a
        single unique label, the cell (column) names of ``input_df`` are used as
        labels instead.
    n_proc : int, optional
        Number of joblib workers used for the per-feature tests. Default 10.

    Returns
    -------
    pandas.DataFrame or None
        One row per feature and comparison, with these columns:

        - ``feature``, ``p_val``, ``pct_valid``, ``pct_ctrl``
        - ``avg_valid``, ``avg_ctrl``, ``sd_valid``, ``sd_ctrl``, ``esm``
        - ``valid_group``, ``adj_pval``, ``-log(p_val)``
        - ``FC``, ``log(FC)``, ``norm_diff``

        In one-vs-rest mode the per-group tables are concatenated with a fresh
        index. ``None`` is returned, after a printed message, for invalid
        parameters or when any error occurs.

    Notes
    -----
    - All-zero feature columns are removed before testing. In one-vs-rest mode,
      features with non-positive means in both groups are also dropped.
    - Features whose summed values are equal in both groups get ``p_val = 1``
      without running the test.
    - The Benjamini–Hochberg (step-up, monotone) correction is applied
      separately within each comparison.
    - Zero means are replaced by a small data-derived pseudo-count before
      computing ``FC = avg_valid / avg_ctrl``.
    - One-vs-rest mode replaces ``p_val == 0`` by ``2**-1074``. Dict mode
      instead adds ``1e-100`` inside ``-log10`` and to ``norm_diff``.
    - In one-vs-rest mode ``valid_group`` is the label converted to ``str``.
      In dict mode it is the group key.

    Examples
    --------
    >>> stat = statistic(df, metadata=meta)
    >>> sets = {'A': ['Type1'], 'B': ['Type2']}
    >>> stat = statistic(df, sets=sets, metadata=meta, n_proc=4)
    """
    try:
        offset = 1e-100

        # Rows become cells, columns become features.
        choose = input_df.copy().T

        if sets is None:
            print("\nAnalysis started...")
            print("\nComparing each type of cell to others...")

            # With a single group label the original cell names are kept, so
            # each individual cell is compared to the rest.
            if len(set(metadata["sets"])) > 1:
                choose.index = metadata["sets"]
            labels = np.asarray(choose.index)

            choose = choose.loc[:, (choose != 0).any(axis=0)]  # drop all-zero columns

            final_results = []
            # pd.unique keeps first-appearance order, so the output is deterministic.
            for c in pd.unique(labels):
                print(f"Calculating statistics for {c}")

                target_mask = labels == c
                combined_df = _compare_groups(choose, target_mask, ~target_mask, n_proc)
                combined_df = combined_df[
                    (combined_df["avg_valid"] > 0) | (combined_df["avg_ctrl"] > 0)
                ].copy()

                # str() so non-string labels (e.g. integer clusters) work.
                combined_df["valid_group"] = str(c)
                combined_df = combined_df.sort_values(by="p_val")
                combined_df["adj_pval"] = _bh_adjust(combined_df["p_val"])

                # Smallest positive double keeps -log10 finite for p == 0.
                combined_df["p_val"] = combined_df["p_val"].replace(0, 2**-1074)
                combined_df["-log(p_val)"] = -np.log10(combined_df["p_val"])

                combined_df["FC"] = _fold_change(
                    combined_df["avg_valid"], combined_df["avg_ctrl"], offset
                )
                combined_df["log(FC)"] = np.log2(combined_df["FC"])
                combined_df["norm_diff"] = (
                    combined_df["avg_valid"] - combined_df["avg_ctrl"]
                )

                final_results.append(combined_df)

            print("\nAnalysis has finished!")
            return pd.concat(final_results, ignore_index=True)

        elif isinstance(sets, dict):
            print("\nAnalysis started...")
            print("\nComparing groups...")

            group_list = list(sets.keys())
            choose.index = metadata["sets"]

            # Keep only cells whose label belongs to one of the requested groups.
            inx = sorted(item for sublist in sets.values() for item in sublist)
            choose = choose.loc[inx]
            labels = choose.index
            choose = choose.loc[:, (choose != 0).any(axis=0)]  # drop all-zero columns

            group_results = []
            for n, g in enumerate(group_list):
                print(f"Calculating statistics for {g}")

                rest_labels = [
                    idx
                    for i, group in enumerate(group_list)
                    if i != n
                    for idx in sets[group]
                ]
                target_mask = labels.isin(sets[g])
                rest_mask = ~target_mask & labels.isin(rest_labels)

                combined_df = _compare_groups(choose, target_mask, rest_mask, n_proc)
                combined_df["valid_group"] = g
                combined_df = combined_df.sort_values(by="p_val")
                combined_df["adj_pval"] = _bh_adjust(combined_df["p_val"])

                combined_df["-log(p_val)"] = -np.log10(offset + combined_df["p_val"])

                combined_df["FC"] = _fold_change(
                    combined_df["avg_valid"], combined_df["avg_ctrl"], offset
                )
                combined_df["log(FC)"] = np.log2(combined_df["FC"])
                combined_df["norm_diff"] = (
                    combined_df["avg_valid"] - combined_df["avg_ctrl"]
                ) + offset

                group_results.append(combined_df)

            print("\nAnalysis has finished!")
            return pd.concat(group_results) if group_results else pd.DataFrame()

        else:
            print("\nInvalid parameters. Please check the input.")
            return None

    except Exception as e:
        print(f"Error: {e}")
        return None

Compute statistical comparison between cell groups or clusters.

This function performs differential feature analysis between either: (1) every group vs. all other groups (default mode), or (2) user-defined groups of labels specified in sets, where each group is compared against the union of the other groups.

The analysis includes:

  • Mann–Whitney U test
  • Percentage of non-zero values
  • Means and standard deviations
  • Effect size metric (ESM)
  • Benjamini–Hochberg FDR correction
  • Fold-change and log2 fold-change

Parameters

input_df : pandas.DataFrame Input feature matrix where rows represent features and columns represent cells. The function transposes this table internally, treating columns as features.

sets : dict or None, optional Mode selection: - None (default): each unique label in metadata['sets'] is compared against all remaining groups. - dict: two or more keys, each mapping to a list of labels belonging to that comparison group; each group is compared against the union of the other groups. Example: {'A': ['T1', 'T2'], 'B': ['C1', 'C2']}.

metadata : pandas.DataFrame, optional Metadata containing at least a column 'sets' with group labels corresponding to columns of input_df.

n_proc : int, optional Number of parallel processes used for statistical computation. Default is 10.

Returns

pandas.DataFrame or None A DataFrame containing statistical results for each feature, including:

- ``feature`` : str
- ``p_val`` : float
- ``adj_pval`` : float
- ``pct_valid`` : float
- ``pct_ctrl`` : float
- ``avg_valid`` : float
- ``avg_ctrl`` : float
- ``sd_valid`` : float
- ``sd_ctrl`` : float
- ``esm`` : float
- ``FC`` : float
- ``log(FC)`` : float
- ``norm_diff`` : float
- ``valid_group`` : str
- ``-log(p_val)`` : float

If ``sets`` is ``None``, results for each group are concatenated.

Returns ``None`` in case of errors or invalid parameters.

Notes

  • Columns containing only zeros are automatically removed.
  • Features whose summed values are equal in both groups are assigned p_val = 1 without running the test.
  • Benjamini–Hochberg correction is applied separately within each group comparison.
  • Fold-change is stabilized by replacing zero group means with a small, data-derived pseudo-count (half of the smallest positive mean).
  • Uses Mann–Whitney U test with alternative='two-sided'.

Raises

None All exceptions are caught internally and printed as messages.

Examples

>>> df = pd.DataFrame(...)
>>> meta = pd.DataFrame({'sets': [...]})
>>> stat = statistic(df, metadata=meta)
>>> stat.head()
>>> # Compare two groups explicitly
>>> sets = {'A': ['Type1'], 'B': ['Type2']}
>>> stat = statistic(df, sets=sets, metadata=meta, n_proc=4)
def save_image(image, path_to_save):
def save_image(image, path_to_save):
    """
    Save an image to disk with OpenCV.

    Parameters
    ----------
    image : np.ndarray
        Image array to write.
    path_to_save : str
        Output file path including the file name. The extension must be
        ``.png``, ``.tiff`` or ``.tif`` (case-insensitive).

    Returns
    -------
    str or None
        ``path_to_save`` when the file was written, otherwise ``None``. A
        message is printed instead of raising.
    """
    allowed_extensions = (".png", ".tiff", ".tif")

    try:
        # Check the real extension; the previous chained "not in" test rejected
        # every path, because no path contains all three extensions.
        if (
            not path_to_save
            or os.path.splitext(path_to_save)[1].lower() not in allowed_extensions
        ):
            print(
                "\nThe path is not provided or the file extension is not *.png, *.tiff or *.tif"
            )
            return None

        # cv2.imwrite reports failure through its boolean return value.
        if not cv2.imwrite(path_to_save, image):
            print("Something went wrong. Check the function input data and try again!")
            return None

        return path_to_save

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Save an image to disk.

Parameters

image : np.ndarray Input image array to be saved.

path_to_save : str Output file path including the filename. Must include one of the following extensions: .png, .tiff or .tif.

Returns

str Path to the saved file.

def load_tiff(path_to_tiff: str):
def load_tiff(path_to_tiff: str):
    """
    Load a TIFF file and return it as a 16-bit (``uint16``) array.

    ``uint16`` data is returned unchanged. Any other dtype is rescaled to the
    full 0–65535 range in floating point before casting, so that float or
    wide-integer data is not truncated or wrapped.

    For arrays with three or more dimensions, each slice along the first axis
    is rescaled independently. Presumably these are z-planes; confirm this for
    multi-channel files. A 2-D image is rescaled as a whole, and constant
    planes become all zeros.

    Parameters
    ----------
    path_to_tiff : str
        Path to the *.tiff file.

    Returns
    -------
    np.ndarray or None
        Loaded image/stack as ``uint16``, or ``None`` (with a printed message)
        if loading fails.
    """
    try:
        stack = tiff.imread(path_to_tiff)

        if stack.dtype == np.uint16:
            return stack

        data = stack.astype(np.float64)
        planes = data if data.ndim >= 3 else data[np.newaxis]
        out = np.empty(planes.shape, dtype=np.uint16)

        for i, plane in enumerate(planes):
            min_val = np.min(plane)
            max_val = np.max(plane)
            if max_val > min_val:
                out[i] = ((plane - min_val) / (max_val - min_val) * 65535).astype(
                    np.uint16
                )
            else:
                # Constant plane: avoid 0/0 and return black.
                out[i] = 0

        return out if data.ndim >= 3 else out[0]

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Load a *.tiff image and ensure it is in 16-bit format.

Parameters

path_to_tiff : str Path to the *.tiff file.

Returns

np.ndarray Loaded image stack, converted to 16-bit if necessary.

def z_projection(tiff_object, projection_type='avg'):
def z_projection(tiff_object, projection_type="avg"):
    """
    Project a stacked (3-D) image along its first axis.

    Parameters
    ----------
    tiff_object : np.ndarray
        Input stack, e.g. as returned by ``load_tiff()``; projected over axis 0.
    projection_type : str, optional
        One of ``'avg'`` (default), ``'median'``, ``'min'``, ``'max'``, ``'std'``.

    Returns
    -------
    np.ndarray or None
        Projection cast to ``uint16`` (fractional values are truncated), or
        ``None`` with a printed message for an unknown projection type or on
        error.
    """
    projections = {
        "avg": np.mean,
        "max": np.max,
        "min": np.min,
        "std": np.std,
        "median": np.median,
    }

    try:
        project = projections.get(projection_type)
        if project is None:
            print("Wrong argument - projection_type!")
            return None

        return project(tiff_object, axis=0).astype(np.uint16)

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Perform Z-projection on a stacked (3D) image.

Parameters

tiff_object : np.ndarray Input stacked 3D image (e.g., loaded with load_tiff()).

projection_type : str Type of Z-axis projection. Options: 'avg', 'median', 'min', 'max', 'std'.

Returns

np.ndarray Resulting 2D projection image.

def clahe_16bit(img, kernal=(100, 100)):
def clahe_16bit(img, kernal=(100, 100), clip_limit=10):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

    The image is first rescaled to 8 bit and CLAHE is applied to that copy.
    The equalised 8-bit result, scaled to 0–1, is then used as a per-pixel
    weight on the original intensities. Finally the weighted image is
    stretched to 0–65535.

    Parameters
    ----------
    img : np.ndarray
        Input image. ``cv2.createCLAHE().apply`` is used, which presumably
        requires a single-channel 2-D array; confirm this for your OpenCV
        version.
    kernal : tuple of int, optional
        CLAHE tile grid size (``tileGridSize``). Default ``(100, 100)``.
    clip_limit : float, optional
        CLAHE contrast-limiting threshold (``clipLimit``). Default ``10``.

    Returns
    -------
    np.ndarray or None
        The ``uint16`` enhanced image. It is all zeros when the input (or the
        weighted result) is constant. Returns ``None``, with a printed message,
        on error.
    """
    try:
        img = np.asarray(img, dtype=np.float64)

        min_val = np.min(img)
        max_val = np.max(img)
        if max_val == min_val:
            # Constant image: the normalisation below would divide by zero.
            return np.zeros(img.shape, dtype=np.uint16)

        img8bit = ((img - min_val) / (max_val - min_val) * 255).astype(np.uint8)

        clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=kernal)
        weights = clahe.apply(img8bit) / 255

        weighted = img * weights

        min_val = np.min(weighted)
        max_val = np.max(weighted)
        if max_val == min_val:
            return np.zeros(weighted.shape, dtype=np.uint16)

        return ((weighted - min_val) / (max_val - min_val) * 65535).astype(np.uint16)

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

Parameters

img : np.ndarray Input image.

Returns

img : np.ndarray Image after CLAHE enhancement.

kernal : tuple (an input parameter, not a return value) Size of the CLAHE tile grid used for processing. Default is (100, 100).

def equalizeHist_16bit(image_eq):
def equalizeHist_16bit(image_eq):
    """
    Apply global histogram equalization to an image and return 16-bit output.

    The image is rescaled to 8 bit and equalised with ``cv2.equalizeHist``.
    The equalised values, scaled to 0–1, are used as per-pixel weights on the
    original intensities. The result is then rescaled so that its maximum is
    65535.

    Parameters
    ----------
    image_eq : np.ndarray
        Input image. ``cv2.equalizeHist`` presumably needs a single-channel
        2-D array; confirm this for your input.

    Returns
    -------
    np.ndarray or None
        The ``uint16`` equalised image. It is all zeros for a constant input or
        an all-zero weighted result. Returns ``None``, with a printed message,
        on error.
    """
    try:
        image = np.asarray(image_eq, dtype=np.float64)

        min_val = np.min(image)
        max_val = np.max(image)
        if max_val == min_val:
            # Constant image: the 8-bit rescale would divide by zero.
            return np.zeros(image.shape, dtype=np.uint16)

        scaled_image = ((image - min_val) / (max_val - min_val) * 255).astype(np.uint8)

        weights = cv2.equalizeHist(scaled_image) / 255
        weighted = image * weights

        peak = np.max(weighted)
        if peak <= 0:
            return np.zeros(weighted.shape, dtype=np.uint16)

        # After this scaling the maximum is exactly 65535, so no further
        # top-end shift is needed.
        return (weighted / peak * 65535).astype(np.uint16)

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Apply global histogram equalization to an image.

Parameters

image_eq : np.ndarray Input image.

Returns

np.ndarray Image after global histogram equalization.

def load_image(path):
def load_image(path):
    """
    Load an image with OpenCV and convert it to 16 bit if necessary.

    The image is read with ``cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR``, which,
    per OpenCV's flag semantics, keeps the stored bit depth and returns a
    3-channel (BGR) array. Non-``uint16`` data is stretched to the full
    0–65535 range.

    Parameters
    ----------
    path : str
        Path to the image file.

    Returns
    -------
    np.ndarray or None
        The ``uint16`` image. A constant non-``uint16`` image becomes all zeros.
        Returns ``None``, with a printed message, if the file cannot be read or
        an error occurs.
    """
    try:
        img = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)

        # cv2.imread signals a missing/unreadable file by returning None.
        if img is None:
            print(f"Could not read an image from: {path}")
            return None

        # The rest of this module works on 16-bit images.
        if img.dtype != np.uint16:
            min_val = np.min(img)
            max_val = np.max(img)
            if max_val == min_val:
                return np.zeros(img.shape, dtype=np.uint16)

            img = (
                (img.astype(np.float64) - min_val) / (max_val - min_val) * 65535
            ).astype(np.uint16)

        return img

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Load an image and convert it to 16-bit if necessary.

Parameters

path : str Path to the image file.

Returns

np.ndarray Loaded 16-bit image.

def rotate_image(img, rotate: int):
def rotate_image(img, rotate: int):
    """
    Rotate an image by a multiple of 90 degrees.

    Parameters
    ----------
    img : np.ndarray
        Image to rotate. The input array is not modified.
    rotate : int
        Rotation in degrees: 0, 90, 180 or 270. 90 rotates clockwise
        (``np.rot90`` with ``k=-1``).

    Returns
    -------
    np.ndarray or None
        Rotated copy of the image, or ``None`` (with a printed message) for an
        unsupported angle or on error.
    """
    # Number of counter-clockwise quarter turns passed to np.rot90.
    quarter_turns = {0: 0, 90: -1, 180: 2, 270: 1}

    try:
        k = quarter_turns[rotate]
    except (KeyError, TypeError):
        print("Wrong argument - rotate!")
        return None

    try:
        return np.rot90(img, k=k).copy()
    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Rotate an image by a specified angle.

Parameters

img : np.ndarray Image to rotate.

rotate : int Degree of rotation. Available options: 0, 90, 180, 270 (90 rotates the image clockwise).

Returns

np.ndarray Rotated image.

def mirror_image(img, rotate: str):
def mirror_image(img, rotate: str):
    """
    Mirror an image along one or both axes.

    Parameters
    ----------
    img : np.ndarray
        Image to mirror. The input array is not modified.
    rotate : str
        ``'h'`` flips left/right, ``'v'`` flips up/down, and ``'hv'`` does both.

    Returns
    -------
    np.ndarray or None
        Mirrored copy of the image, or ``None`` (with a printed message) for an
        unknown option or on error.
    """
    flips = {
        "h": np.fliplr,
        "v": np.flipud,
        "hv": lambda arr: np.flipud(np.fliplr(arr)),
    }

    try:
        flip = flips[rotate]
    except (KeyError, TypeError):
        print("Wrong argument - rotate!")
        return None

    try:
        return flip(img).copy()
    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Mirror an image along specified axes.

Parameters

img : np.ndarray Image to mirror.

rotate : str Type of mirroring to apply. Options: 'h' - horizontal mirroring 'v' - vertical mirroring 'hv' - both horizontal and vertical mirroring

Returns

np.ndarray Mirrored image.

def merge_images(image_list: list, intensity_factors: list = []):
 962def merge_images(image_list: list, intensity_factors: list = []):
 963    """
 964    Merge multiple image projections from different channels.
 965
 966    Parameters
 967    ----------
 968    image_list : list of np.ndarray
 969        List of images to merge. All images must have the same shape and size.
 970
 971    intensity_factors : list of float
 972        Intensity scaling factors for each image in `image_list`. Base value is 1.
 973            - Values < 1 decrease intensity.
 974            - Values > 1 increase intensity.
 975
 976    Returns
 977    -------
 978    np.ndarray
 979        Merged image after applying intensity scaling.
 980    """
 981
 982    try:
 983
 984        result = None
 985
 986        if len(intensity_factors) == 0:
 987
 988            intensity_factors = []
 989            for bt in range(len(image_list)):
 990                intensity_factors.append(1)
 991
 992        for i, image in enumerate(image_list):
 993            if result is None:
 994                result = image.astype(np.uint64) * intensity_factors[i]
 995            else:
 996                result = cv2.addWeighted(
 997                    result, 1, image.astype(np.uint64) * intensity_factors[i], 1, 0
 998                )
 999
1000        result = np.clip(result, 0, 65535)
1001
1002        result = result.astype(np.uint16)
1003
1004        return result
1005
1006    except:
1007        print("Something went wrong. Check the function input data and try again!")

Merge multiple image projections from different channels.

Parameters

image_list : list of np.ndarray List of images to merge. All images must have the same shape and size.

intensity_factors : list of float Intensity scaling factors for each image in image_list. Base value is 1. - Values < 1 decrease intensity. - Values > 1 increase intensity.

Returns

np.ndarray Merged image after applying intensity scaling.

def adjust_img_16bit( img, color='gray', max_intensity: int = 65535, min_intenisty: int = 0, brightness: int = 1000, contrast=1.0, gamma=1.0):
def adjust_img_16bit(
    img,
    color="gray",
    max_intensity: int = 65535,
    min_intenisty: int = 0,
    brightness: int = 1000,
    contrast=1.0,
    gamma=1.0,
):
    """
    Manually adjust a 2-D image and return it as a 3-channel 16-bit image.

    Processing order:

    1. Clip to 0–65535.
    2. Apply brightness, contrast and gamma.
    3. Rescale so the maximum is 65535.
    4. Apply the max/min intensity thresholds.
    5. Place the result in the colour channels selected by ``color``.

    Parameters
    ----------
    img : np.ndarray
        Input 2-D image. It is not modified.
    color : str, optional
        Target colour: ``'gray'`` (default), ``'green'``, ``'blue'``, ``'red'``,
        ``'yellow'``, ``'magenta'`` or ``'cyan'``. Channels are indexed in
        (B, G, R) order. An unknown colour yields an all-zero image.
    max_intensity : int, optional
        After rescaling, pixels at or above this value are set to 65535.
    min_intenisty : int, optional
        After rescaling, pixels at or below this value are set to 0. (The
        keyword keeps its original spelling for compatibility.)
    brightness : int, optional
        1000 means unchanged. Every positive pixel is changed by
        ``(brightness - 1000)`` percent, so values below 900 turn positive
        pixels black. Typical range is 900–2000.
    contrast : float, optional
        Scales deviations from the image mean. Default 1 (unchanged).
    gamma : float, optional
        Gamma correction factor (values are raised to ``1 / gamma``).
        Default 1 (unchanged).

    Returns
    -------
    np.ndarray or None
        Array of shape ``(H, W, 3)`` and dtype ``uint16``, or ``None`` (with a
        printed message) on error.
    """
    # Channel indices (B, G, R order) that receive the image for each colour.
    channels_by_color = {
        "blue": (0,),
        "green": (1,),
        "red": (2,),
        "magenta": (0, 2),
        "yellow": (1, 2),
        "cyan": (0, 1),
        "gray": (0, 1, 2),
    }

    try:
        # Work in float64: negative intermediates (e.g. brightness < 900) must
        # clip to 0 rather than wrap around as they would in an unsigned array.
        img = np.clip(np.array(img, dtype=np.float64), 0, 65535)

        # brightness
        if brightness != 1000:
            factor = -1000 + brightness
            side = factor / abs(factor)
            positive = img > 0
            img[positive] = img[positive] + ((img[positive] * abs(factor) / 100) * side)
            img = np.clip(img, 0, 65535)

        # contrast
        if contrast != 1:
            mean = np.mean(img)
            img = np.clip(((img - mean) * contrast) + mean, 0, 65535)

        # gamma (skipped for an all-zero image, which gamma cannot change)
        if gamma != 1:
            max_val = np.max(img)
            if max_val > 0:
                normalized = np.clip(img / max_val, 0, 1)
                img = np.clip((normalized ** (1 / gamma)) * max_val, 0, 65535)

        img = np.nan_to_num(img, nan=0, posinf=65535, neginf=0)
        max_val = np.max(img)
        if max_val > 0:
            img = ((img / max_val) * 65535).astype(np.uint16)

        # max intensity: saturate bright pixels
        if max_intensity != 65535:
            img[img >= max_intensity] = 65535

        # min intensity: suppress dim pixels
        if min_intenisty != 0:
            img[img <= min_intenisty] = 0

        img_gamma = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint16)
        for channel in channels_by_color.get(color, ()):
            img_gamma[:, :, channel] = img

        return img_gamma

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Manually adjust image parameters and return the adjusted image.

Parameters

img : np.ndarray Input image.

color : str Image color channel. Options: ['gray', 'green', 'blue', 'red', 'yellow', 'magenta', 'cyan']. Default is 'gray'.

max_intensity : int Upper threshold for pixel values. After the image is rescaled to 0–65535, pixels at or above this value are set to 65535.

min_intenisty : int (note the keyword spelling) Lower threshold for pixel values. After rescaling, pixels at or below this value are set to 0.

brightness : int, optional Image brightness adjustment. Typical range: [900-2000]. Default is 1000.

contrast : float or int, optional Image contrast adjustment. Typical range: [0-5]. Default is 1.

gamma : float or int, optional Gamma correction factor. Typical range: [0-5]. Default is 1.

Returns

np.ndarray Adjusted image after applying brightness, contrast, gamma, and intensity thresholds.

def get_screan():
def get_screan():
    """
    Return the screen size reported by Tk.

    A temporary Tk root window is created and hidden immediately so it does
    not flash on screen. It is always destroyed, even if the size query fails.

    Returns
    -------
    tuple of int
        ``(screen_width, screen_height)`` in pixels.
    """
    root = tk.Tk()
    try:
        root.withdraw()
        return root.winfo_screenwidth(), root.winfo_screenheight()
    finally:
        root.destroy()

Get the current screen width and height.

Returns

tuple of int screen_width, screen_height

def resize_to_screen_img(img_file, factor=1):
def resize_to_screen_img(img_file, factor=1):
    """
    Resize an input image to fit the screen size, optionally scaled by a factor.

    The screen size (from ``get_screan``) is multiplied by ``factor``. If the
    image is wider than the scaled width, it is scaled to that width. If the
    result is still taller than the scaled height, it is then scaled to that
    height. Aspect ratio is preserved at each step. The final size is
    computed first and the image is resampled only once.

    Parameters
    ----------
    img_file : np.ndarray
        Input image to be resized.

    factor : float, optional
        Scaling factor applied to the screen dimensions before resizing the image. Default is 1.

    Returns
    -------
    np.ndarray
        Resized image that fits within the screen dimensions while maintaining
        aspect ratio, or the input image itself when it already fits.
    """

    screen_width, screen_height = get_screan()

    screen_width = int(screen_width * factor)
    screen_height = int(screen_height * factor)

    h = int(img_file.shape[0])
    w = int(img_file.shape[1])
    new_w, new_h = w, h

    # Fit the width first. Clamp to 1 px so extreme aspect ratios never
    # produce a zero-sized target, which cv2.resize rejects.
    if screen_width < new_w:
        new_h = max(1, int(screen_width / new_w * new_h))
        new_w = max(1, screen_width)

    # Then fit the height of the width-fitted size.
    if screen_height < new_h:
        new_w = max(1, int(screen_height / new_h * new_w))
        new_h = max(1, screen_height)

    if (new_w, new_h) != (w, h):
        # One resampling step (instead of two sequential resizes) avoids
        # compounding interpolation error. INTER_AREA suits shrinking, which
        # is the only direction this function resizes in.
        img_file = cv2.resize(
            img_file, (new_w, new_h), interpolation=cv2.INTER_AREA
        )

    return img_file

Resize an input image to fit the screen size, optionally scaled by a factor.

Parameters

img_file : np.ndarray Input image to be resized.

factor : float, optional Scaling factor applied to the screen dimensions before resizing the image. Default is 1.

Returns

np.ndarray Resized image that fits within the screen dimensions while maintaining aspect ratio.

def display_preview(image):
def display_preview(image):
    """
    Quickly preview an image using a display window.

    The image is shrunk to fit 80% of the screen, shown in a window titled
    "Display", and the window is closed after any key press.

    Parameters
    ----------
    image : np.ndarray
        Input image to be displayed.

    Notes
    -----
        The function displays the image in a window. Does not return a value.
        Errors are reported with a printed message instead of being raised.
        NOTE(review): this module reassigns ``sys.stdout`` to an
        ``io.StringIO`` at import time, so the message may not reach the
        console unless stdout has been restored.
    """
    try:

        # resize_to_screen_img does not modify its input, so no copy is needed.
        res_sc = resize_to_screen_img(image, factor=0.8)

        cv2.imshow("Display", res_sc)

        # Block until any key is pressed; the key code itself is not used.
        cv2.waitKey(0)

        cv2.destroyAllWindows()

    # Catch Exception, not everything: Ctrl+C while waiting for a key
    # (KeyboardInterrupt) or SystemExit must still propagate.
    except Exception:
        print("Something went wrong. Check the function input data and try again!")

Quickly preview an image using a display window.

Parameters

image : np.ndarray Input image to be displayed.

Notes

The function displays the image in a window. Does not return a value.
class ImageTools:
class ImageTools:
    """
    Collection of helper methods for image I/O, screen fitting, mask handling
    and simple intensity statistics.

    The instance methods query the screen size and resize images to fit it.
    They load JIMG project files (``.pjm``), images, 3D TIFF stacks and masks,
    and save images. They also filter per-object dictionaries, build
    gradated masks from coordinate lists, and derive histogram-based
    intensity thresholds.

    Notes
    -----
    All methods operate on NumPy arrays and expect images in standard
    OpenCV format.

    Examples
    --------
    >>> tools = ImageTools()
    >>> mask = tools.create_mask({"coords": coords_list}, img)
    """

    def get_screan(self):
        """
        Return the current screen size.

        A temporary Tk root window is created to query the display. It is
        hidden immediately and destroyed afterwards, even if the query raises.

        Returns
        -------
        screen_width : int
            Width of the screen in pixels.

        screen_height : int
            Height of the screen in pixels.
        """

        root = tk.Tk()
        try:
            # Keep the helper window from flashing on screen.
            root.withdraw()
            screen_width = root.winfo_screenwidth()
            screen_height = root.winfo_screenheight()
        finally:
            # Always destroy the helper window, even if a query failed.
            root.destroy()

        return screen_width, screen_height

    def resize_to_screen_img(self, img_file, factor=0.5):
        """
        Resize an input image to a scaled version of the current screen size.

        The screen size is multiplied by ``factor``. The image is rescaled to
        the scaled screen width if it is wider than that width or narrower
        than 30% of it. The result is then rescaled to the scaled screen
        height if it is taller than that height or shorter than 30% of it.
        Aspect ratio is preserved at each step. The final size is computed
        first and the image is resampled only once.

        Parameters
        ----------
        img_file : np.ndarray
            Input image to be resized.

        factor : float, optional
            Scaling factor applied to the screen dimensions. Default is 0.5.

        Returns
        -------
        resized_image : np.ndarray
            Resized image adjusted to the scaled screen size, or the input
            image itself when no resizing is needed.
        """

        screen_width, screen_height = self.get_screan()

        screen_width = int(screen_width * factor)
        screen_height = int(screen_height * factor)

        h = int(img_file.shape[0])
        w = int(img_file.shape[1])
        new_w, new_h = w, h

        # Fit width: shrink oversized images, enlarge very small ones.
        # Clamp to 1 px so cv2.resize never receives a zero-sized target.
        if screen_width < new_w or screen_width * 0.3 > new_w:
            new_h = max(1, int(screen_width / new_w * new_h))
            new_w = max(1, screen_width)

        # Then apply the same rule to the height of the width-fitted size.
        if screen_height < new_h or screen_height * 0.3 > new_h:
            new_w = max(1, int(screen_height / new_h * new_w))
            new_h = max(1, screen_height)

        if (new_w, new_h) != (w, h):
            # A single resampling step avoids compounding interpolation error.
            img_file = cv2.resize(img_file, (new_w, new_h))

        return img_file

    def load_JIMG_project(self, project_path):
        """
        Load a JIMG project from a `.pjm` file.

        Parameters
        ----------
        project_path : str or os.PathLike
            Path to the project file with `.pjm` extension.

        Returns
        -------
        project : object or None
            Unpickled project object, or ``None`` when the path does not end
            with `.pjm` (a message is printed in that case; nothing is raised).

        Notes
        -----
        The file is read with ``pickle.load``, which can execute arbitrary
        code. Only load project files from trusted sources.
        """

        # os.fspath also accepts pathlib.Path objects.
        path = os.fspath(project_path)

        # Check the actual extension rather than a substring anywhere in the
        # path, so names such as "x.pjm.bak" are rejected.
        if not path.endswith(".pjm"):
            print("\nProvide path to the project metadata file with *.pjm extension!!!")
            return None

        # SECURITY: pickle.load must never be used on untrusted files.
        with open(path, "rb") as file:
            return pickle.load(file)

    def ajd_mask_size(self, image, mask):
        """
        Adjusts the size of a mask to match the dimensions of a given image.

        Parameters
        ----------
        image : np.ndarray
            Reference image whose size the mask should match. Can be 2D or 3D.
            For 3D input, axes 1 and 2 are used as height and width.

        mask : np.ndarray
            Mask image to be resized.

        Returns
        -------
        np.ndarray
            Resized mask matching the input image dimensions.
        """

        if image.ndim >= 3:
            # NOTE(review): this assumes a (z, y, x) stack; a channels-last
            # colour image (H, W, C) would get a mask of size (C, W) — confirm.
            target_size = (image.shape[2], image.shape[1])
        else:
            target_size = (image.shape[1], image.shape[0])

        # cv2.resize takes the destination size as (width, height).
        return cv2.resize(mask, target_size)

    def load_image(self, path_to_image):
        """
        Load an image from the specified file path.

        Delegates to the module-level ``load_image`` function (not shown
        here). Inside a method body the bare name resolves to that global,
        not to this method.

        Parameters
        ----------
        path_to_image : str
            Path to the image file.

        Returns
        -------
        image : np.ndarray
            Loaded image as a NumPy array.
        """

        return load_image(path_to_image)

    def load_3D_tiff(self, path_to_image):
        """
        Load a 3D image from a TIFF file.

        Delegates to the module-level ``load_tiff`` function (not shown here).

        Parameters
        ----------
        path_to_image : str
            Path to the 3D TIFF image file.

        Returns
        -------
        image : np.ndarray
            Loaded 3D image as a NumPy array.
        """

        return load_tiff(path_to_image)

    def load_mask(self, path_to_mask):
        """
        Load a mask image as a single-channel (grayscale) array.

        Parameters
        ----------
        path_to_mask : str
            Path to the mask image file.

        Returns
        -------
        mask : np.ndarray or None
            Loaded mask image. ``cv2.imread`` returns ``None`` rather than
            raising when the file cannot be read.
        """

        return cv2.imread(path_to_mask, cv2.IMREAD_GRAYSCALE)

    def save(self, image, file_name):
        """
        Save an image to disk.

        Parameters
        ----------
        image : np.ndarray
            Image data to be saved.

        file_name : str
            Output file path including the desired image extension (e.g., ".png", ".jpg").

        Returns
        -------
        bool
            The status reported by ``cv2.imwrite``: ``True`` if the image was
            written, ``False`` otherwise.
        """

        return cv2.imwrite(filename=file_name, img=image)

    # calculation methods

    def drop_dict(self, dictionary, key, var, action=None):
        """
        Remove positions from every list in a dictionary based on one key.

        Each element ``dr`` of ``dictionary[key]`` is first reduced to its
        mean if it is a NumPy array. The element's position is dropped from
        *every* value of the dictionary when ``var <action> dr`` is true.

        Parameters
        ----------
        dictionary : dict
            Dictionary containing lists or arrays under each key. It is deep
            copied and not modified.

        key : str
            The key in the dictionary on which the filtering condition will be applied.

        var : numeric
            Value on the left-hand side of each comparison.

        action : str, optional
            One of '<=', '>=', '==', '<', '>'. Because ``var`` is the left
            operand, ``action='<'`` drops elements greater than ``var``.
            Any other value (including the default ``None``) prints
            "Wrong action!" and returns ``None``, provided ``dictionary[key]``
            is non-empty.

        Returns
        -------
        dict or None
            A deep copy with matching positions removed (every value becomes
            a list), or ``None`` for an invalid ``action``.
        """

        comparators = {
            "<=": lambda a, b: a <= b,
            ">=": lambda a, b: a >= b,
            "==": lambda a, b: a == b,
            "<": lambda a, b: a < b,
            ">": lambda a, b: a > b,
        }
        compare = comparators.get(action) if isinstance(action, str) else None

        dictionary = copy.deepcopy(dictionary)

        # A set gives O(1) membership checks in the rebuild below.
        indices_to_drop = set()
        for i, dr in enumerate(dictionary[key]):
            if compare is None:
                print("\nWrong action!")
                return None

            if isinstance(dr, np.ndarray):
                dr = np.mean(dr)

            if compare(var, dr):
                indices_to_drop.add(i)

        # Assigning to existing keys while iterating is safe (the size is
        # unchanged) and preserves the dictionary's own type.
        for name, values in dictionary.items():
            dictionary[name] = [
                v for i, v in enumerate(values) if i not in indices_to_drop
            ]

        return dictionary

    # modified for gradation (separation near nucleus)

    def create_mask(self, dictionary, image):
        """
        Creates a mask image with gradated intensity values for each coordinate set in a dictionary.

        With ``n`` coordinate arrays, the k-th array (1-based) is painted with
        ``16383 + k * (65535 - 16383) / n``. The last array therefore reaches
        65535, and each object gets a distinct intensity.

        Parameters
        ----------
        dictionary : dict
            Dictionary containing a 'coords' key with a list of arrays of
            coordinates. Column 0 of each array is used as the row index and
            column 1 as the column index.

        image : np.ndarray
            Base image to define the shape of the mask.

        Returns
        -------
        np.ndarray
            Mask image with uint16 gradated intensity; all zeros when
            'coords' is empty.
        """

        image_mask = np.zeros(image.shape)

        # The coordinate arrays are only read, so no defensive copy is needed.
        coords_list = dictionary["coords"]

        if len(coords_list) > 0:
            max_level = 2**16 - 1
            initial_val = math.floor(max_level / 4)
            step = (max_level - initial_val) / len(coords_list)

            for gradation, arr in enumerate(coords_list, start=1):
                image_mask[arr[:, 0], arr[:, 1]] = initial_val + (gradation * step)

        return image_mask.astype("uint16")

    def min_max_histograme(self, image):
        """
        Derive histogram-based lower and upper pixel fractions for an image.

        Quantiles of ``image`` are sampled at 0, 5, ..., 95 percent. For each
        sampled value, the fraction of all elements strictly below it is
        computed relative to ``image.size``. This keeps the fraction within
        [0, 1] for 3-D or multi-channel input too.

        Parameters
        ----------
        image : np.ndarray
            Input image for histogram analysis.

        Returns
        -------
        min_val : float
            Fraction of pixels below the first sampled quantile (in ascending
            order) whose value and fraction are both non-zero; 0 if none.

        max_val : float
            Scanning rows with a non-zero fraction from the highest quantile
            downward: the fraction at the first row with index > 1 whose value
            times 1.5 exceeds the previous (higher) row's value. If there is
            no such row, the last row's fraction; 0 if no rows remain.

        df : pd.DataFrame
            Columns ``q`` (percent), ``val`` (quantile value) and ``perc``
            (fraction below ``val``). It is restricted to ``perc > 0`` and
            sorted by descending ``q``.
        """

        q = list(range(0, 100, 5))

        # Count every element, not just shape[0] * shape[1].
        n_pixels = image.size

        val = []
        perc = []
        for n in q:
            # Compute each quantile once and reuse it for the fraction.
            quantile = np.quantile(image, n / 100)
            val.append(quantile)
            perc.append(np.sum(image < quantile) / n_pixels)

        df = pd.DataFrame({"q": q, "val": val, "perc": perc})

        min_val = 0
        for v, p in zip(df["val"], df["perc"]):
            if v != 0 and min_val == 0:
                min_val = p

        max_val = 0
        df = df[df["perc"] > 0]
        df = df.sort_values("q", ascending=False).reset_index(drop=True)

        for i in df.index:
            # NOTE(review): `i > 1` never compares row 1 with row 0; confirm
            # whether `i >= 1` was intended.
            if i > 1 and df["val"][i] * 1.5 > df["val"][i - 1]:
                max_val = df["perc"][i]
                break
            elif i == len(df.index) - 1:
                max_val = df["perc"][i]

        return min_val, max_val, df

Collection of helper methods for image I/O, screen fitting, mask handling, and simple intensity statistics.

This class provides instance methods for:

- querying the screen size and resizing images to fit it;
- loading JIMG project files (.pjm), images, 3D TIFF stacks and masks;
- saving images;
- filtering per-object dictionaries;
- building gradated masks from coordinate lists;
- computing histogram-based intensity thresholds.

These tools are used internally by ImagesManagement but can also be applied independently for generic image-processing workflows.

Notes

All methods operate on NumPy arrays and expect images in standard OpenCV format. Some functions assume 16-bit images unless stated otherwise.

Examples

>>> from jimg_ncd.utils import ImageTools
>>> tools = ImageTools()
>>> mask = tools.create_mask({"coords": coords_list}, img)
def get_screan(self):
1258    def get_screan(self):
1259        """
1260        Return the current screen size.
1261
1262        Returns
1263        -------
1264        screen_width : int
1265            Width of the screen in pixels.
1266
1267        screen_height : int
1268            Height of the screen in pixels.
1269        """
1270
1271        root = tk.Tk()
1272        screen_width = root.winfo_screenwidth()
1273        screen_height = root.winfo_screenheight()
1274
1275        root.destroy()
1276
1277        return screen_width, screen_height

Return the current screen size.

Returns

screen_width : int Width of the screen in pixels.

screen_height : int Height of the screen in pixels.

def resize_to_screen_img(self, img_file, factor=0.5):
1279    def resize_to_screen_img(self, img_file, factor=0.5):
1280        """
1281        Resize an input image to a scaled version of the current screen size.
1282
1283        Parameters
1284        ----------
1285        img_file : np.ndarray
1286            Input image to be resized.
1287
1288        factor : int
1289            Scaling factor applied to the screen dimensions.
1290
1291        Returns
1292        -------
1293        resized_image : np.ndarray
1294            Resized image adjusted to the scaled screen size.
1295        """
1296
1297        screen_width, screen_height = self.get_screan()
1298
1299        screen_width = int(screen_width * factor)
1300        screen_height = int(screen_height * factor)
1301
1302        h = int(img_file.shape[0])
1303        w = int(img_file.shape[1])
1304
1305        if screen_width < w or screen_width * 0.3 > w:
1306            h = img_file.shape[0]
1307            w = img_file.shape[1]
1308
1309            ww = int((screen_width / w) * w)
1310            hh = int((screen_width / w) * h)
1311
1312            img_file = cv2.resize(img_file, (ww, hh))
1313
1314            h = img_file.shape[0]
1315            w = img_file.shape[1]
1316
1317        if screen_height < h or screen_height * 0.3 > h:
1318            h = img_file.shape[0]
1319            w = img_file.shape[1]
1320
1321            ww = int((screen_height / h) * w)
1322            hh = int((screen_height / h) * h)
1323
1324            img_file = cv2.resize(img_file, (ww, hh))
1325
1326        return img_file

Resize an input image to a scaled version of the current screen size.

Parameters

img_file : np.ndarray Input image to be resized.

factor : float, optional Scaling factor applied to the screen dimensions. Default is 0.5.

Returns

resized_image : np.ndarray Resized image adjusted to the scaled screen size.

def load_JIMG_project(self, project_path):
1328    def load_JIMG_project(self, project_path):
1329        """
1330        Load a JIMG project from a `.pjm` file.
1331
1332        Parameters
1333        ----------
1334        file_path : str
1335            Path to the project file with `.pjm` extension.
1336
1337        Returns
1338        -------
1339        project : object
1340            Loaded project object.
1341
1342        Raises
1343        ------
1344        ValueError
1345            If the provided file does not have a `.pjm` extension.
1346        """
1347
1348        if ".pjm" in project_path:
1349            with open(project_path, "rb") as file:
1350                app_metadata_tmp = pickle.load(file)
1351
1352            return app_metadata_tmp
1353
1354        else:
1355            print("\nProvide path to the project metadata file with *.pjm extension!!!")

Load a JIMG project from a .pjm file.

Parameters

project_path : str Path to the project file with .pjm extension.

Returns

project : object or None Loaded project object, or None if the path is rejected as not being a .pjm file.

Notes

If the path is not recognised as a .pjm file, a message is printed and None is returned; no exception is raised.

def ajd_mask_size(self, image, mask):
1357    def ajd_mask_size(self, image, mask):
1358        """
1359        Adjusts the size of a mask to match the dimensions of a given image.
1360
1361        Parameters
1362        ----------
1363        image : np.ndarray
1364            Reference image whose size the mask should match. Can be 2D or 3D.
1365
1366        mask : np.ndarray
1367            Mask image to be resized.
1368
1369        Returns
1370        -------
1371        np.ndarray
1372            Resized mask matching the input image dimensions.
1373        """
1374
1375        try:
1376            mask = cv2.resize(mask, (image.shape[2], image.shape[1]))
1377        except:
1378            mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
1379
1380        return mask

Adjusts the size of a mask to match the dimensions of a given image.

Parameters

image : np.ndarray Reference image whose size the mask should match. Can be 2D or 3D; for 3D input, axes 1 and 2 are taken as height and width.

mask : np.ndarray Mask image to be resized.

Returns

np.ndarray Resized mask matching the input image dimensions.

def load_image(self, path_to_image):
1382    def load_image(self, path_to_image):
1383        """
1384        Load an image from the specified file path.
1385
1386        Parameters
1387        ----------
1388        path_to_image : str
1389            Path to the image file.
1390
1391        Returns
1392        -------
1393        image : np.ndarray
1394            Loaded image as a NumPy array.
1395        """
1396
1397        image = load_image(path_to_image)
1398        return image

Load an image from the specified file path.

Parameters

path_to_image : str Path to the image file.

Returns

image : np.ndarray Loaded image as a NumPy array.

def load_3D_tiff(self, path_to_image):
1400    def load_3D_tiff(self, path_to_image):
1401        """
1402        Load a 3D image from a TIFF file.
1403
1404        Parameters
1405        ----------
1406        path_to_image : str
1407            Path to the 3D TIFF image file.
1408
1409        Returns
1410        -------
1411        image : np.ndarray
1412            Loaded 3D image as a NumPy array.
1413        """
1414
1415        image = load_tiff(path_to_image)
1416
1417        return image

Load a 3D image from a TIFF file.

Parameters

path_to_image : str Path to the 3D TIFF image file.

Returns

image : np.ndarray Loaded 3D image as a NumPy array.

def load_mask(self, path_to_mask):
1419    def load_mask(self, path_to_mask):
1420        """
1421        Load a mask image.
1422
1423        Parameters
1424        ----------
1425        path_to_mask : str
1426            Path to the mask image file.
1427
1428        Returns
1429        -------
1430        mask : np.ndarray
1431            Loaded mask image as a NumPy array.
1432        """
1433
1434        mask = cv2.imread(path_to_mask, cv2.IMREAD_GRAYSCALE)
1435        return mask

Load a mask image.

Parameters

path_to_mask : str Path to the mask image file.

Returns

mask : np.ndarray Loaded mask image as a NumPy array.

def save(self, image, file_name):
1437    def save(self, image, file_name):
1438        """
1439        Save an image to disk.
1440
1441        Parameters
1442        ----------
1443        image : np.ndarray
1444            Image data to be saved.
1445
1446        file_name : str
1447            Output file path including the desired image extension (e.g., ".png", ".jpg").
1448
1449        """
1450
1451        cv2.imwrite(filename=file_name, img=image)

Save an image to disk.

Parameters

image : np.ndarray Image data to be saved.

file_name : str Output file path including the desired image extension (e.g., ".png", ".jpg").

def drop_dict(self, dictionary, key, var, action=None):
1455    def drop_dict(self, dictionary, key, var, action=None):
1456        """
1457        Filters elements from a dictionary based on a condition applied to a specified key.
1458
1459        Parameters
1460        ----------
1461        dictionary : dict
1462            Dictionary containing lists or arrays under each key.
1463
1464        key : str
1465            The key in the dictionary on which the filtering condition will be applied.
1466
1467        var : numeric
1468            Value to compare against each element in dictionary[key].
1469
1470        action : str, optional
1471            Comparison operator as string: '<=', '>=', '==', '<', '>'.
1472            Default is None, which raises an error.
1473
1474        Returns
1475        -------
1476        dict
1477            A new dictionary with elements removed where the condition matches.
1478        """
1479
1480        dictionary = copy.deepcopy(dictionary)
1481        indices_to_drop = []
1482        for i, dr in enumerate(dictionary[key]):
1483
1484            if isinstance(dr, np.ndarray):
1485                dr = np.mean(dr)
1486
1487            if action == "<=":
1488                if var <= dr:
1489                    indices_to_drop.append(i)
1490            elif action == ">=":
1491                if var >= dr:
1492                    indices_to_drop.append(i)
1493            elif action == "==":
1494                if var == dr:
1495                    indices_to_drop.append(i)
1496            elif action == "<":
1497                if var < dr:
1498                    indices_to_drop.append(i)
1499            elif action == ">":
1500                if var > dr:
1501                    indices_to_drop.append(i)
1502            else:
1503                print("\nWrong action!")
1504                return None
1505
1506        for key, value in dictionary.items():
1507            dictionary[key] = [
1508                v for i, v in enumerate(value) if i not in indices_to_drop
1509            ]
1510
1511        return dictionary

Filters elements from a dictionary based on a condition applied to a specified key.

Parameters

dictionary : dict Dictionary containing lists or arrays under each key.

key : str The key in the dictionary on which the filtering condition will be applied.

var : numeric Value to compare against each element in dictionary[key].

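action : str, optional Comparison operator as string: '<=', '>=', '==', '<', '>'. An element is dropped when `var <action> element` is true (e.g. '<' drops elements greater than var). Default is None, which is invalid: "Wrong action!" is printed and None is returned.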

Returns

dict or None A new dictionary with elements removed where the condition matches, or None if action is invalid (and dictionary[key] is non-empty).

def create_mask(self, dictionary, image):
1515    def create_mask(self, dictionary, image):
1516        """
1517        Creates a mask image with gradated intensity values for each coordinate set in a dictionary.
1518
1519        Parameters
1520        ----------
1521        dictionary : dict
1522            Dictionary containing a 'coords' key with a list of arrays of coordinates.
1523
1524        image : np.ndarray
1525            Base image to define the shape of the mask.
1526
1527        Returns
1528        -------
1529        np.ndarray
1530            Mask image with uint16 gradated intensity.
1531        """
1532
1533        image_mask = np.zeros(image.shape)
1534
1535        arrays_list = copy.deepcopy(dictionary["coords"])
1536
1537        if len(arrays_list) > 0:
1538
1539            initial_val = math.floor((2**16 - 1) / 4)
1540            intensity_list = ((2**16 - 1) - math.floor((2**16 - 1) / 4)) / len(
1541                arrays_list
1542            )
1543
1544            gradation = 1
1545            for arr in arrays_list:
1546                image_mask[arr[:, 0], arr[:, 1]] = initial_val + (
1547                    gradation * intensity_list
1548                )
1549                gradation += 1
1550
1551        return image_mask.astype("uint16")

Creates a mask image with gradated intensity values for each coordinate set in a dictionary.

Parameters

dictionary : dict Dictionary containing a 'coords' key with a list of arrays of coordinates; column 0 of each array is used as the row index and column 1 as the column index.

image : np.ndarray Base image to define the shape of the mask.

Returns

np.ndarray Mask image with uint16 gradated intensity.

def min_max_histograme(self, image):
1553    def min_max_histograme(self, image):
1554        """
1555        Calculates histogram-based minimum and maximum intensity percentiles in an image.
1556
1557        Parameters
1558        ----------
1559        image : np.ndarray
1560            Input image for histogram analysis.
1561
1562        Returns
1563        -------
1564        min_val : float
1565            Minimum intensity percentile above zero.
1566
1567        max_val : float
1568            Maximum intensity percentile based on histogram gradient.
1569
1570        df : pd.DataFrame
1571            DataFrame containing quantiles, corresponding intensity values, and cumulative percents.
1572        """
1573
1574        q = []
1575        val = []
1576        perc = []
1577
1578        max_val = image.shape[0] * image.shape[1]
1579
1580        for n in range(0, 100, 5):
1581            q.append(n)
1582            val.append(np.quantile(image, n / 100))
1583            sum_val = np.sum(image < np.quantile(image, n / 100))
1584            pr = sum_val / max_val
1585            perc.append(pr)
1586
1587        df = pd.DataFrame({"q": q, "val": val, "perc": perc})
1588
1589        min_val = 0
1590        for i in df.index:
1591            if df["val"][i] != 0 and min_val == 0:
1592                min_val = df["perc"][i]
1593
1594        max_val = 0
1595        df = df[df["perc"] > 0]
1596        df = df.sort_values("q", ascending=False).reset_index(drop=True)
1597
1598        for i in df.index:
1599            if i > 1 and df["val"][i] * 1.5 > df["val"][i - 1]:
1600                max_val = df["perc"][i]
1601                break
1602            elif i == len(df.index) - 1:
1603                max_val = df["perc"][i]
1604
1605        return min_val, max_val, df

Calculates histogram-based minimum and maximum intensity percentiles in an image.

Parameters

image : np.ndarray Input image for histogram analysis.

Returns

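min_val : float Fraction of pixels lying below the first sampled quantile (0th–95th percentile in 5% steps) whose value and pixel fraction are both non-zero; 0 if there is none.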

max_val : float Maximum intensity percentile based on histogram gradient.

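df : pd.DataFrame DataFrame containing the sampled quantiles (q), the corresponding intensity values (val), and the fraction of pixels below each value (perc). Only rows with a non-zero fraction are kept, sorted by descending quantile.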