jimg_int.utils

   1import copy
   2import io
   3import math
   4import os
   5import pickle
   6import re
   7import sys
   8import tarfile
   9import tkinter as tk
  10from itertools import combinations
  11
  12import cv2
  13import gdown
  14import matplotlib.pyplot as plt
  15import numpy as np
  16import pandas as pd
  17import plotly.express as px
  18import tifffile as tiff
  19from joblib import Parallel, delayed
  20from scipy import stats
  21from scipy.spatial import ConvexHull, cKDTree
  22from scipy.stats import chi2_contingency
  23from tqdm import tqdm
  24
# NOTE(review): replacing sys.stdout at import time redirects every later
# print() in the importing process into an in-memory buffer that nothing in
# this module section reads - including the error messages printed by the
# helpers below. Confirm this is intended (e.g. to silence a dependency);
# contextlib.redirect_stdout around the noisy call would be narrower.
sys.stdout = io.StringIO()
  26
  27
  28def umap_html(umap_result, width=1000, height=1200):
  29    """
  30    Create an interactive HTML UMAP scatter plot.
  31
  32    Parameters
  33    ----------
  34    umap_result : pandas.DataFrame or dict-like
  35        UMAP embedding containing at least:
  36        - `0` : array-like, UMAP dimension 1
  37        - `1` : array-like, UMAP dimension 2
  38        - `'clusters'` : array-like, assigned cluster labels
  39
  40    width : int, optional
  41        Width of the output figure in pixels. Default is 1000.
  42
  43    height : int, optional
  44        Height of the output figure in pixels. Default is 1200.
  45
  46    Returns
  47    -------
  48    plotly.graph_objs._figure.Figure
  49        Interactive Plotly scatter plot object visualizing UMAP with colored clusters.
  50    """
  51
  52    fig = px.scatter(
  53        x=umap_result[0],
  54        y=umap_result[1],
  55        color=umap_result["clusters"],
  56        labels={"color": "Cells"},
  57        template="simple_white",
  58        width=width,
  59        height=height,
  60        render_mode="svg",
  61        color_discrete_sequence=px.colors.qualitative.Dark24
  62        + px.colors.qualitative.Light24,
  63    )
  64
  65    fig.update_xaxes(title_text="UMAP 1")
  66    fig.update_yaxes(title_text="UMAP 2")
  67
  68    return fig
  69
  70
  71def umap_static(umap_result, width=10, height=13):
  72    """
  73    Create a static matplotlib UMAP scatter plot.
  74
  75    Parameters
  76    ----------
  77    umap_result : pandas.DataFrame or dict-like
  78        UMAP projection containing:
  79        - `0` : array-like, UMAP dimension 1
  80        - `1` : array-like, UMAP dimension 2
  81        - `'clusters'` : array-like, cluster assignments
  82
  83    width : float, optional
  84        Width of the figure in inches. Default is 10.
  85
  86    height : float, optional
  87        Height of the figure in inches. Default is 13.
  88
  89    Returns
  90    -------
  91    matplotlib.figure.Figure
  92        Static matplotlib figure representing the UMAP embedding with clusters.
  93    """
  94
  95    plotly_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24
  96    num_colors = len(plotly_colors)
  97
  98    fig = plt.figure(figsize=(width, height))
  99
 100    cluster_counts = {
 101        label: np.sum(umap_result["clusters"] == label)
 102        for label in np.unique(umap_result["clusters"])
 103    }
 104
 105    sorted_labels = sorted(
 106        cluster_counts, key=lambda label: cluster_counts[label], reverse=True
 107    )
 108
 109    color_map = {
 110        label: plotly_colors[i % num_colors] for i, label in enumerate(sorted_labels)
 111    }
 112
 113    for label in sorted_labels:
 114        subset = umap_result[umap_result["clusters"] == label]
 115        plt.scatter(
 116            subset[0],
 117            subset[1],
 118            c=[color_map[label]],
 119            label=f"Cluster {label}",
 120            alpha=0.7,
 121            s=20,
 122            edgecolor="black",
 123            linewidths=0.1,
 124        )
 125
 126    plt.xlabel("UMAP 1", fontsize=14)
 127    plt.ylabel("UMAP 2", fontsize=14)
 128    plt.grid(True, which="both", linestyle="--", linewidth=0.1)
 129    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
 130    plt.tight_layout()
 131
 132    return fig
 133
 134
 135def test_data(path=""):
 136    """
 137    Download and extract test data from Google Drive.
 138
 139    This function downloads a compressed archive containing example test data
 140    and extracts it into the specified directory. The data is fetched using
 141    a direct Google Drive link. If the download or extraction fails, an
 142    error message is printed.
 143
 144    Parameters
 145    ----------
 146    path : str, optional
 147        Destination directory where the test dataset will be downloaded and
 148        extracted. Defaults to the current working directory.
 149
 150    Notes
 151    -----
 152    - The downloaded file is named ``test_data.tar.gz``.
 153    - The archive is extracted into ``<path>/test_data``.
 154    - In case of any failure (download or extraction), the function prints
 155      an informative message instead of raising an exception.
 156    """
 157
 158    try:
 159
 160        file_name = "test_data.tar.gz"
 161
 162        file_name = os.path.join(path, file_name)
 163
 164        url = "https://drive.google.com/uc?id=1MhzhleMP7iTzlBVW8eP5sFaonJdg1a3T"
 165
 166        gdown.download(url, file_name, quiet=False)
 167
 168        # Unzip
 169
 170        with tarfile.open(file_name, "r:gz") as tar:
 171            tar.extractall(path=path)
 172
 173        print(
 174            f"\nTest data downloaded succesfully -> {os.path.join(path, 'test_data')}"
 175        )
 176
 177    except:
 178
 179        print(
 180            "\nTest data could not be downloaded. Please check your connection and try again!"
 181        )
 182
 183
 184def prop_plot(df_pivot, chi_df):
 185    """
 186    Create a stacked bar plot of proportional data with post-hoc significance annotations.
 187
 188    Parameters
 189    ----------
 190    df_pivot : pandas.DataFrame
 191        Pivot table where rows represent categories (e.g., compartments) and columns
 192        represent groups. Values are counts or frequencies.
 193
 194    chi_df : pandas.DataFrame
 195        DataFrame containing pairwise Chi-square test results with an added
 196        'Significance_Label' column (e.g., '***', '**', '*', 'ns') for each pair
 197        of groups. Typically output from `chi_pairs` and `get_significance_label`.
 198
 199    Returns
 200    -------
 201    matplotlib.figure.Figure
 202        The Matplotlib figure object containing the stacked bar plot.
 203
 204    Notes
 205    -----
 206    - The function converts raw counts to percentages per group for visualization.
 207    - Each pairwise comparison and its significance label is displayed as a text box
 208      next to the plot.
 209    - Colors are assigned using the 'viridis' colormap by default.
 210    - The plot is configured for clarity with labeled axes, legend, and appropriately
 211      sized text.
 212    """
 213
 214    df_pivot_perc = df_pivot.div(df_pivot.sum(axis=0), axis=1) * 100
 215
 216    posthoc_text = "\n".join(
 217        [
 218            f"{row['Group 1']}{row['Group 2']}: {row['Significance_Label']}"
 219            for _, row in chi_df.iterrows()
 220        ]
 221    )
 222
 223    fig, ax = plt.subplots(figsize=(12, 7))
 224
 225    df_pivot_perc.T.plot(kind="bar", stacked=True, ax=ax, cmap="viridis")
 226
 227    ax.set_ylabel("Percentage (%)", fontsize=16)
 228    ax.set_xlabel("Groups", fontsize=16)
 229
 230    ax.tick_params(axis="both", labelsize=12)
 231
 232    ax.legend(
 233        title="Compartment", loc="upper left", bbox_to_anchor=(1.02, 1.05), fontsize=14
 234    )
 235
 236    plt.figtext(
 237        0.93,
 238        0.6,
 239        posthoc_text,
 240        ha="left",
 241        va="top",
 242        fontsize=12,
 243        bbox={"facecolor": "white", "alpha": 0.7, "pad": 5},
 244    )
 245
 246    return fig
 247
 248
 249def get_significance_label(p_value):
 250    """
 251    Return a standard significance label based on a p-value.
 252
 253    Parameters
 254    ----------
 255    p_value : float
 256        The p-value for which the significance label should be determined.
 257
 258    Returns
 259    -------
 260    str
 261        A significance marker commonly used in statistical reporting:
 262
 263        - '***' : p < 0.001
 264        - '**'  : p < 0.01
 265        - '*'   : p < 0.05
 266        - 'ns'  : not significant (p ≥ 0.05)
 267
 268    Notes
 269    -----
 270    This helper function is typically used for annotating statistical test
 271    results in tables or visualizations. Thresholds follow conventional
 272    statistical notation for significance levels.
 273    """
 274
 275    if p_value < 0.001:
 276        return "***"
 277    elif p_value < 0.01:
 278        return "**"
 279    elif p_value < 0.05:
 280        return "*"
 281    else:
 282        return "ns"
 283
 284
 285def chi_pairs(df_pivot):
 286    """
 287    Compute pairwise Chi-square tests for all combinations of groups in a pivoted dataframe.
 288
 289    Parameters
 290    ----------
 291    df_pivot : pandas.DataFrame
 292        A pivot table where rows represent categories and columns represent groups.
 293        Values should be counts (frequencies). The function will add +1 to each cell
 294        to avoid zero counts during chi-square computation.
 295
 296    Returns
 297    -------
 298    pandas.DataFrame
 299        A DataFrame containing pairwise Chi-square test results with the following columns:
 300        - 'Group 1' : str
 301            Name of the first group in the pair.
 302        - 'Group 2' : str
 303            Name of the second group in the pair.
 304        - 'Chi²' : float
 305            The Chi-square statistic for the comparison.
 306        - 'p-value' : float
 307            The p-value of the Chi-square test.
 308
 309    Notes
 310    -----
 311    The function compares every possible pair of columns using `scipy.stats.chi2_contingency`.
 312    Yates' correction is applied by default unless disabled in the SciPy version used.
 313    A value of 1 is added to all cells to avoid issues with zero frequencies.
 314    """
 315
 316    group_pairs = list(combinations(df_pivot.columns, 2))
 317
 318    posthoc_results = []
 319
 320    for group1, group2 in group_pairs:
 321        sub_table = df_pivot.T.loc[[group1, group2]] + 1
 322        chi2, p, dof, expected = chi2_contingency(sub_table)
 323
 324        posthoc_results.append(
 325            {"Group 1": group1, "Group 2": group2, "Chi²": chi2, "p-value": p}
 326        )
 327
 328    posthoc_df = pd.DataFrame(posthoc_results)
 329
 330    return posthoc_df
 331
 332
 333def statistic(input_df, sets=None, metadata=None, n_proc=10):
 334    """
 335    Compute statistical comparison between cell groups or clusters.
 336
 337    This function performs differential feature analysis between either:
 338    (1) every group vs. all other groups (default mode), or
 339    (2) two user-defined groups specified in ``sets``.
 340
 341    The analysis includes:
 342    - Mann–Whitney U test
 343    - Percentage of non-zero values
 344    - Means and standard deviations
 345    - Effect size metric (ESM)
 346    - Benjamini–Hochberg FDR correction
 347    - Fold-change and log2 fold-change
 348
 349
 350    Parameters
 351    ----------
 352    input_df : pandas.DataFrame
 353        Input feature matrix where rows represent features and columns represent cells.
 354        The function transposes this table internally, treating columns as features.
 355
 356    sets : dict or None, optional
 357        Mode selection:
 358        - ``None`` (default): each unique label in ``metadata['sets']`` is compared
 359          against all remaining groups.
 360        - ``dict``: must contain exactly two keys, each mapping to a list of labels
 361          belonging to each comparison group. Example:
 362          ``{'A': ['T1', 'T2'], 'B': ['C1', 'C2']}``.
 363
 364    metadata : pandas.DataFrame, optional
 365        Metadata containing at least a column ``'sets'`` with group labels
 366        corresponding to columns of ``input_df``.
 367
 368    n_proc : int, optional
 369        Number of parallel processes used for statistical computation.
 370        Default is ``10``.
 371
 372    Returns
 373    -------
 374    pandas.DataFrame or None
 375        A DataFrame containing statistical results for each feature, including:
 376
 377        - ``feature`` : str
 378        - ``p_val`` : float
 379        - ``adj_pval`` : float
 380        - ``pct_valid`` : float
 381        - ``pct_ctrl`` : float
 382        - ``avg_valid`` : float
 383        - ``avg_ctrl`` : float
 384        - ``sd_valid`` : float
 385        - ``sd_ctrl`` : float
 386        - ``esm`` : float
 387        - ``FC`` : float
 388        - ``log(FC)`` : float
 389        - ``norm_diff`` : float
 390        - ``valid_group`` : str
 391        - ``-log(p_val)`` : float
 392
 393        If ``sets`` is ``None``, results for each group are concatenated.
 394
 395        Returns ``None`` in case of errors or invalid parameters.
 396
 397    Notes
 398    -----
 399    - Columns containing only zeros are automatically removed.
 400    - p-values equal for both groups produce ``p_val = 1``.
 401    - Benjamini–Hochberg correction is applied separately within each group comparison.
 402    - Fold-change is stabilized using a small, data-derived ``low_factor``.
 403    - Uses ``Mann–Whitney U`` test with ``alternative='two-sided'``.
 404
 405    Raises
 406    ------
 407    None
 408        All exceptions are caught internally and printed as messages.
 409
 410    Examples
 411    --------
 412    >>> df = pd.DataFrame(...)
 413    >>> meta = pd.DataFrame({'sets': [...]})
 414    >>> stat = statistic(df, metadata=meta)
 415    >>> stat.head()
 416
 417    >>> # Compare two groups explicitly
 418    >>> sets = {'A': ['Type1'], 'B': ['Type2']}
 419    >>> stat = statistic(df, sets=sets, metadata=meta, n_proc=4)
 420    """
 421    try:
 422        offset = 1e-100
 423
 424        def stat_calc(choose, feature_name):
 425            target_values = choose.loc[choose["DEG"] == "target", feature_name]
 426            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]
 427
 428            pct_valid = (target_values > 0).sum() / len(target_values)
 429            pct_rest = (rest_values > 0).sum() / len(rest_values)
 430
 431            avg_valid = np.mean(target_values)
 432            avg_ctrl = np.mean(rest_values)
 433            sd_valid = np.std(target_values, ddof=1)
 434            sd_ctrl = np.std(rest_values, ddof=1)
 435            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))
 436
 437            if np.sum(target_values) == np.sum(rest_values):
 438                p_val = 1.0
 439            else:
 440                _, p_val = stats.mannwhitneyu(
 441                    target_values, rest_values, alternative="two-sided"
 442                )
 443
 444            return {
 445                "feature": feature_name,
 446                "p_val": p_val,
 447                "pct_valid": pct_valid,
 448                "pct_ctrl": pct_rest,
 449                "avg_valid": avg_valid,
 450                "avg_ctrl": avg_ctrl,
 451                "sd_valid": sd_valid,
 452                "sd_ctrl": sd_ctrl,
 453                "esm": esm,
 454            }
 455
 456        # Transpose the input DataFrame
 457        choose = input_df.copy().T
 458
 459        if sets is None:
 460            print("\nAnalysis started...")
 461            print("\nComparing each type of cell to others...")
 462            final_results = []
 463
 464            if len(set(metadata["sets"])) > 1:
 465                choose.index = metadata["sets"]
 466
 467            indexes = list(choose.index)
 468
 469            for c in set(indexes):
 470                print(f"Calculating statistics for {c}")
 471
 472                choose.index = indexes
 473                choose["DEG"] = np.where(choose.index == c, "target", "rest")
 474
 475                valid = ",".join(set(choose.index[choose["DEG"] == "target"]))
 476                choose = choose.loc[
 477                    :, (choose != 0).any(axis=0)
 478                ]  # Remove all-zero columns
 479
 480                # Parallel computation
 481                results = Parallel(n_jobs=n_proc)(
 482                    delayed(stat_calc)(choose, feature)
 483                    for feature in tqdm(choose.columns[choose.columns != "DEG"])
 484                )
 485
 486                # Convert results to DataFrame
 487                combined_df = pd.DataFrame(results)
 488                combined_df["valid_group"] = valid
 489                combined_df.sort_values(by="p_val", inplace=True)
 490
 491                # Adjusted p-values using Benjamini-Hochberg method
 492                num_tests = len(combined_df)
 493                combined_df["adj_pval"] = np.minimum(
 494                    1, (combined_df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
 495                )
 496
 497                combined_df["-log(p_val)"] = -np.log10(offset + combined_df["p_val"])
 498                low_factor = (
 499                    np.min(
 500                        (combined_df["avg_valid"] + combined_df["avg_ctrl"])[
 501                            combined_df["avg_valid"] + combined_df["avg_ctrl"] > 0
 502                        ]
 503                    )
 504                    * 0.95
 505                )
 506
 507                combined_df["FC"] = (combined_df["avg_valid"] + low_factor) / (
 508                    combined_df["avg_ctrl"] + low_factor
 509                )
 510                combined_df["log(FC)"] = np.log2(combined_df["FC"])
 511                combined_df["norm_diff"] = (combined_df["avg_valid"] + low_factor) - (
 512                    combined_df["avg_ctrl"]
 513                )
 514
 515                final_results.append(combined_df)
 516
 517            print("\nAnalysis has finished!")
 518            return pd.concat(final_results, ignore_index=True)
 519
 520        elif isinstance(sets, dict):
 521            print("\nAnalysis started...")
 522            print("\nComparing groups...")
 523
 524            group_list = list(sets.keys())
 525            choose.index = metadata["sets"]
 526
 527            inx = sorted([item for sublist in sets.values() for item in sublist])
 528            choose = choose.loc[inx]
 529
 530            full_df = pd.DataFrame()
 531            for n, g in enumerate(group_list):
 532                print(f"Calculating statistics for {g}")
 533
 534                rest_indices = [
 535                    idx
 536                    for i, group in enumerate(group_list)
 537                    if i != n
 538                    for idx in sets[group]
 539                ]
 540
 541                choose["DEG"] = np.where(
 542                    choose.index.isin(sets[g]),
 543                    "target",
 544                    np.where(choose.index.isin(rest_indices), "rest", "drop"),
 545                )
 546
 547                choose = choose[choose["DEG"] != "drop"]
 548
 549                valid = g
 550                choose = choose.loc[
 551                    :, (choose != 0).any(axis=0)
 552                ]  # Remove all-zero columns
 553
 554                # Parallel computation
 555                results = Parallel(n_jobs=n_proc)(
 556                    delayed(stat_calc)(choose, feature)
 557                    for feature in tqdm(choose.columns[choose.columns != "DEG"])
 558                )
 559
 560                # Convert results to DataFrame
 561                combined_df = pd.DataFrame(results)
 562                combined_df["valid_group"] = valid
 563                combined_df.sort_values(by="p_val", inplace=True)
 564
 565                # Adjusted p-values using Benjamini-Hochberg method
 566                num_tests = len(combined_df)
 567                combined_df["adj_pval"] = np.minimum(
 568                    1, (combined_df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
 569                )
 570
 571                combined_df["-log(p_val)"] = -np.log10(offset + combined_df["p_val"])
 572                low_factor = (
 573                    np.min(
 574                        (combined_df["avg_valid"] + combined_df["avg_ctrl"])[
 575                            combined_df["avg_valid"] + combined_df["avg_ctrl"] > 0
 576                        ]
 577                    )
 578                    * 0.95
 579                )
 580
 581                combined_df["FC"] = (combined_df["avg_valid"] + low_factor) / (
 582                    combined_df["avg_ctrl"] + low_factor
 583                )
 584                combined_df["log(FC)"] = np.log2(combined_df["FC"])
 585                combined_df["norm_diff"] = (
 586                    combined_df["avg_valid"] - combined_df["avg_ctrl"]
 587                )
 588
 589                full_df = pd.concat([full_df, combined_df])
 590
 591            print("\nAnalysis has finished!")
 592            return full_df
 593
 594        else:
 595            print("\nInvalid parameters. Please check the input.")
 596            return None
 597
 598    except Exception as e:
 599        print(f"Error: {e}")
 600        return None
 601
 602
 603# UTILS JIMG
 604
 605
 606def save_image(image, path_to_save):
 607    """
 608    Save an image to disk.
 609
 610    Parameters
 611    ----------
 612    image : np.ndarray
 613        Input image array to be saved.
 614
 615    path_to_save : str
 616        Output file path including the filename.
 617        Must include one of the following extensions:
 618        ``.png``, ``.tiff`` or ``.tif``.
 619
 620    Returns
 621    -------
 622    str
 623        Path to the saved file.
 624    """
 625
 626    try:
 627        if (
 628            len(path_to_save) == 0
 629            or ".png" not in path_to_save
 630            or ".tiff" not in path_to_save
 631            or ".tif" not in path_to_save
 632        ):
 633            print(
 634                "\nThe path is not provided or the file extension is not *.png, *.tiff or *.tif"
 635            )
 636        else:
 637            cv2.imwrite(path_to_save, image)
 638
 639    except:
 640        print("Something went wrong. Check the function input data and try again!")
 641
 642
 643def load_tiff(path_to_tiff: str):
 644    """
 645    Load a *.tiff image and ensure it is in 16-bit format.
 646
 647    Parameters
 648    ----------
 649    path_to_tiff : str
 650        Path to the *.tiff file.
 651
 652    Returns
 653    -------
 654    np.ndarray
 655        Loaded image stack, converted to 16-bit if necessary.
 656    """
 657
 658    try:
 659        stack = tiff.imread(path_to_tiff)
 660
 661        if stack.dtype != "uint16":
 662
 663            stack = stack.astype(np.uint16)
 664
 665            for n, _ in enumerate(stack):
 666
 667                min_val = np.min(stack[n])
 668                max_val = np.max(stack[n])
 669
 670                stack[n] = ((stack[n] - min_val) / (max_val - min_val) * 65535).astype(
 671                    np.uint16
 672                )
 673
 674                stack[n] = np.clip(stack[n], 0, 65535)
 675
 676        return stack
 677
 678    except:
 679        print("Something went wrong. Check the function input data and try again!")
 680
 681
 682def z_projection(tiff_object, projection_type="avg"):
 683    """
 684    Perform Z-projection on a stacked (3D) image.
 685
 686    Parameters
 687    ----------
 688    tiff_object : np.ndarray
 689        Input stacked 3D image (e.g., loaded with `load_tiff()`).
 690
 691    projection_type : str
 692        Type of Z-axis projection. Options: 'avg', 'median', 'min', 'max', 'std'.
 693
 694    Returns
 695    -------
 696    np.ndarray
 697        Resulting 2D projection image.
 698    """
 699
 700    try:
 701
 702        if projection_type == "avg":
 703            img = np.mean(tiff_object, axis=0).astype(np.uint16)
 704        elif projection_type == "max":
 705            img = np.max(tiff_object, axis=0).astype(np.uint16)
 706        elif projection_type == "min":
 707            img = np.min(tiff_object, axis=0).astype(np.uint16)
 708        elif projection_type == "std":
 709            img = np.std(tiff_object, axis=0).astype(np.uint16)
 710        elif projection_type == "median":
 711            img = np.median(tiff_object, axis=0).astype(np.uint16)
 712
 713        return img
 714
 715    except:
 716        print("Something went wrong. Check the function input data and try again!")
 717
 718
 719def clahe_16bit(img, kernal=(100, 100)):
 720    """
 721    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.
 722
 723    Parameters
 724    ----------
 725    img : np.ndarray
 726        Input image.
 727
 728    Returns
 729    -------
 730    img : np.ndarray
 731        Image after CLAHE enhancement.
 732
 733    kernel : tuple
 734        Size of the CLAHE tile grid used for processing, e.g., (100, 100).
 735    """
 736
 737    try:
 738
 739        img = img.copy()
 740
 741        img8bit = img.copy()
 742
 743        min_val = np.min(img8bit)
 744        max_val = np.max(img8bit)
 745
 746        img8bit = ((img8bit - min_val) / (max_val - min_val) * 255).astype(np.uint8)
 747
 748        clahe = cv2.createCLAHE(clipLimit=10, tileGridSize=kernal)
 749        img8bit = clahe.apply(img8bit)
 750
 751        img8bit = img8bit / 255
 752
 753        img = img * img8bit
 754
 755        min_val = np.min(img)
 756        max_val = np.max(img)
 757
 758        img = ((img - min_val) / (max_val - min_val) * 65535).astype(np.uint16)
 759
 760        return img
 761
 762    except:
 763        print("Something went wrong. Check the function input data and try again!")
 764
 765
 766def equalizeHist_16bit(image_eq):
 767    """
 768    Apply global histogram equalization to an image.
 769
 770    Parameters
 771    ----------
 772    image_eq : np.ndarray
 773        Input image.
 774
 775    Returns
 776    -------
 777    np.ndarray
 778        Image after global histogram equalization.
 779    """
 780
 781    try:
 782
 783        image = image_eq.copy()
 784
 785        min_val = np.min(image)
 786        max_val = np.max(image)
 787
 788        scaled_image = ((image - min_val) / (max_val - min_val) * 255).astype(np.uint8)
 789
 790        eq_image = cv2.equalizeHist(scaled_image)
 791
 792        eq_image_bin = eq_image / 255
 793
 794        image_eq_16 = image * eq_image_bin
 795        image_eq_16 = (image_eq_16 / np.max(image_eq_16)) * 65535
 796        image_eq_16[image_eq_16 > (65535 / 2)] += 65535 - np.max(image_eq_16)
 797        image_eq_16 = image_eq_16.astype(np.uint16)
 798
 799        return image_eq_16
 800
 801    except:
 802        print("Something went wrong. Check the function input data and try again!")
 803
 804
 805def load_image(path):
 806    """
 807    Load an image and convert it to 16-bit if necessary.
 808
 809    Parameters
 810    ----------
 811    path : str
 812        Path to the image file.
 813
 814    Returns
 815    -------
 816    np.ndarray
 817        Loaded 16-bit image.
 818    """
 819
 820    try:
 821
 822        img = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
 823
 824        # convert to 16 bit (the function are working on 16 bit images!)
 825        if img.dtype != "uint16":
 826
 827            min_val = np.min(img)
 828            max_val = np.max(img)
 829
 830            img = ((img - min_val) / (max_val - min_val) * 65535).astype(np.uint16)
 831
 832            img = np.clip(img, 0, 65535)
 833
 834        return img
 835
 836    except:
 837        print("Something went wrong. Check the function input data and try again!")
 838
 839
 840def rotate_image(img, rotate: int):
 841    """
 842    Rotate an image by a specified angle.
 843
 844    Parameters
 845    ----------
 846    img : np.ndarray
 847        Image to rotate.
 848
 849    rotate : int
 850        Degree of rotation. Available options: 90, 180, 270.
 851
 852    Returns
 853    -------
 854    np.ndarray
 855        Rotated image.
 856    """
 857
 858    try:
 859
 860        if rotate == 0:
 861            r = 0
 862        elif rotate == 90:
 863            r = -1
 864        elif rotate == 180:
 865            r = 2
 866        elif rotate == 180:
 867            r = 2
 868        elif rotate == 270:
 869            r = 1
 870        else:
 871            print("Wrong argument - rotate!")
 872            return None
 873
 874        img = img.copy()
 875
 876        img = np.rot90(img.copy(), k=r)
 877
 878        return img
 879
 880    except:
 881        print("Something went wrong. Check the function input data and try again!")
 882
 883
 884def mirror_image(img, rotate: str):
 885    """
 886    Mirror an image along specified axes.
 887
 888    Parameters
 889    ----------
 890    img : np.ndarray
 891        Image to mirror.
 892
 893    rotate : str
 894        Type of mirroring to apply. Options:
 895            'h'  - horizontal mirroring
 896            'v'  - vertical mirroring
 897            'hv' - both horizontal and vertical mirroring
 898
 899    Returns
 900    -------
 901    np.ndarray
 902        Mirrored image.
 903    """
 904
 905    try:
 906
 907        if rotate == "h":
 908            img = np.fliplr(img.copy())
 909        elif rotate == "v":
 910            img = np.flipud(img.copy())
 911        elif rotate == "hv":
 912            img = np.flipud(np.fliplr(img.copy()))
 913        else:
 914            print("Wrong argument - rotate!")
 915            return None
 916
 917        return img
 918
 919    except:
 920        print("Something went wrong. Check the function input data and try again!")
 921
 922
# validated UTILS
 924
 925
 926def merge_images(image_list: list, intensity_factors: list = []):
 927    """
 928    Merge multiple image projections from different channels.
 929
 930    Parameters
 931    ----------
 932    image_list : list of np.ndarray
 933        List of images to merge. All images must have the same shape and size.
 934
 935    intensity_factors : list of float
 936        Intensity scaling factors for each image in `image_list`. Base value is 1.
 937            - Values < 1 decrease intensity.
 938            - Values > 1 increase intensity.
 939
 940    Returns
 941    -------
 942    np.ndarray
 943        Merged image after applying intensity scaling.
 944    """
 945
 946    try:
 947
 948        result = None
 949
 950        if len(intensity_factors) == 0:
 951
 952            intensity_factors = []
 953            for bt in range(len(image_list)):
 954                intensity_factors.append(1)
 955
 956        for i, image in enumerate(image_list):
 957            if result is None:
 958                result = image.astype(np.uint64) * intensity_factors[i]
 959            else:
 960                result = cv2.addWeighted(
 961                    result, 1, image.astype(np.uint64) * intensity_factors[i], 1, 0
 962                )
 963
 964        result = np.clip(result, 0, 65535)
 965
 966        result = result.astype(np.uint16)
 967
 968        return result
 969
 970    except:
 971        print("Something went wrong. Check the function input data and try again!")
 972
 973
 974def adjust_img_16bit(
 975    img,
 976    color="gray",
 977    max_intensity: int = 65535,
 978    min_intenisty: int = 0,
 979    brightness: int = 1000,
 980    contrast=1.0,
 981    gamma=1.0,
 982):
 983    """
 984    Manually adjust image parameters and return the adjusted image.
 985
 986    Parameters
 987    ----------
 988    img : np.ndarray
 989        Input image.
 990
 991    color : str
 992        Image color channel. Options: ['green', 'blue', 'red', 'yellow', 'magenta', 'cyan'].
 993
 994    max_intensity : int
 995        Upper threshold for pixel values. Pixels exceeding this value are set to `max_intensity`.
 996
 997    min_intensity : int
 998        Lower threshold for pixel values. Pixels below this value are set to 0.
 999
1000    brightness : int, optional
1001        Image brightness adjustment. Typical range: [900-2000]. Default is 1000.
1002
1003    contrast : float or int, optional
1004        Image contrast adjustment. Typical range: [0-5]. Default is 1.
1005
1006    gamma : float or int, optional
1007        Gamma correction factor. Typical range: [0-5]. Default is 1.
1008
1009    Returns
1010    -------
1011    np.ndarray
1012        Adjusted image after applying brightness, contrast, gamma, and intensity thresholds.
1013    """
1014
1015    try:
1016
1017        img = img.copy()
1018
1019        img = img.astype(np.uint64)
1020
1021        img = np.clip(img, 0, 65535)
1022
1023        # brightness
1024        if brightness != 1000:
1025            factor = -1000 + brightness
1026            side = factor / abs(factor)
1027            img[img > 0] = img[img > 0] + ((img[img > 0] * abs(factor) / 100) * side)
1028            img = np.clip(img, 0, 65535)
1029
1030        # contrast
1031        if contrast != 1:
1032            img = ((img - np.mean(img)) * contrast) + np.mean(img)
1033            img = np.clip(img, 0, 65535)
1034
1035        # gamma
1036        if gamma != 1:
1037
1038            max_val = np.max(img)
1039
1040            image_array = img.copy() / max_val
1041
1042            image_array = np.clip(image_array, 0, 1)
1043
1044            corrected_array = image_array ** (1 / gamma)
1045
1046            img = corrected_array * max_val
1047
1048            del image_array, corrected_array
1049
1050            img = np.clip(img, 0, 65535)
1051
1052        img = np.nan_to_num(img, nan=0, posinf=65535, neginf=0)
1053        max_val = np.max(img)
1054        if max_val > 0:
1055            img = ((img / max_val) * 65535).astype(np.uint16)
1056
1057        # max intenisty
1058        if max_intensity != 65535:
1059            img[img >= max_intensity] = 65535
1060
1061        # min intenisty
1062        if min_intenisty != 0:
1063            img[img <= min_intenisty] = 0
1064
1065        img_gamma = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint16)
1066
1067        if color == "green":
1068            img_gamma[:, :, 1] = img
1069
1070        elif color == "red":
1071            img_gamma[:, :, 2] = img
1072
1073        elif color == "blue":
1074            img_gamma[:, :, 0] = img
1075
1076        elif color == "magenta":
1077            img_gamma[:, :, 0] = img
1078            img_gamma[:, :, 2] = img
1079
1080        elif color == "yellow":
1081            img_gamma[:, :, 1] = img
1082            img_gamma[:, :, 2] = img
1083
1084        elif color == "cyan":
1085            img_gamma[:, :, 0] = img
1086            img_gamma[:, :, 1] = img
1087
1088        elif color == "gray":
1089            img_gamma[:, :, 0] = img
1090            img_gamma[:, :, 1] = img
1091            img_gamma[:, :, 2] = img
1092
1093        return img_gamma
1094
1095    except:
1096        print("Something went wrong. Check the function input data and try again!")
1097
1098
def get_screan():
    """
    Get the current screen width and height.

    A temporary Tk root window is created to query the display. It is hidden
    and always destroyed afterwards, even if the query fails.

    Returns
    -------
    tuple of int
        ``(screen_width, screen_height)`` in pixels.
    """

    root = tk.Tk()
    try:
        # Hide the helper window so it does not flash on screen.
        root.withdraw()
        return root.winfo_screenwidth(), root.winfo_screenheight()
    finally:
        root.destroy()
1116
1117
def resize_to_screen_img(img_file, factor=1):
    """
    Resize an input image to fit the screen size, optionally scaled by a factor.

    The image is only ever shrunk, never enlarged. Its aspect ratio is kept,
    and it is resampled at most once.

    Parameters
    ----------
    img_file : np.ndarray
        Input image; the first two axes are height and width.

    factor : float, optional
        Fraction of the screen size the image may occupy. Default is 1.

    Returns
    -------
    np.ndarray
        The resized image, or the input itself when it already fits within
        the scaled screen dimensions.
    """

    screen_width, screen_height = get_screan()

    max_width = int(screen_width * factor)
    max_height = int(screen_height * factor)

    height, width = img_file.shape[:2]

    # A single scale for both axes keeps the aspect ratio. min() picks the
    # tighter of the width/height limits, and 1.0 prevents upscaling.
    scale = min(1.0, max_width / width, max_height / height)

    if scale < 1.0:
        # cv2.resize takes (width, height); never request a 0-pixel side.
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
        img_file = cv2.resize(img_file, new_size)

    return img_file
1166
1167
def display_preview(image):
    """
    Quickly preview an image in an OpenCV display window.

    The image is fitted to 80 % of the screen size with
    `resize_to_screen_img` and shown in a window titled "Display". The
    window is closed after any key press.

    Parameters
    ----------
    image : np.ndarray
        Input image to be displayed. It is not modified.

    Notes
    -----
    Nothing is returned. Errors while preparing or showing the image are
    reported with a printed message instead of being raised.
    """
    try:
        preview = resize_to_screen_img(image.copy(), factor=0.8)

        cv2.imshow("Display", preview)

        try:
            # Block until a key is pressed in the window.
            cv2.waitKey(0)
        finally:
            # Close the window even if waiting is interrupted (e.g. Ctrl+C).
            cv2.destroyAllWindows()

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
1193
1194
1195## flow_JIMG_functions
1196
1197
class ImageTools:
    """
    Collection of image utility methods.

    Covers screen-size queries, resizing for display, loading of projects,
    images and masks, saving images, and helpers for filtering object
    dictionaries, building gradated masks and estimating histogram-based
    intensity limits.

    Notes
    -----
    All methods operate on NumPy arrays and follow OpenCV conventions for
    image I/O. They are instance methods, so create an instance first.

    Examples
    --------
    >>> tools = ImageTools()
    >>> mask = tools.load_mask("mask.png")          # doctest: +SKIP
    >>> preview = tools.resize_to_screen_img(mask)  # doctest: +SKIP
    """

    def get_screan(self):
        """
        Return the current screen size.

        A temporary Tk root window is used for the query. It is hidden and
        always destroyed afterwards, even if the query fails.

        Returns
        -------
        screen_width : int
            Width of the screen in pixels.

        screen_height : int
            Height of the screen in pixels.
        """

        root = tk.Tk()
        try:
            # Hide the helper window so it does not flash on screen.
            root.withdraw()
            return root.winfo_screenwidth(), root.winfo_screenheight()
        finally:
            root.destroy()

    def resize_to_screen_img(self, img_file, factor=0.5):
        """
        Resize an input image relative to a scaled version of the screen size.

        With ``max_w, max_h = screen size * factor``:

        1. If the image width is larger than ``max_w`` or smaller than
           ``0.3 * max_w``, the image is scaled so its width equals ``max_w``.
        2. If the resulting height is then larger than ``max_h`` or smaller
           than ``0.3 * max_h``, it is scaled so its height equals ``max_h``.

        Both steps preserve the aspect ratio. They are combined into a single
        resampling, so the image is interpolated at most once.

        Parameters
        ----------
        img_file : np.ndarray
            Input image; the first two axes are height and width.

        factor : float, optional
            Fraction of the screen size used as the target. Default is 0.5.

        Returns
        -------
        np.ndarray
            The resized image, or the input itself when no scaling is needed.
        """

        screen_width, screen_height = self.get_screan()

        max_width = int(screen_width * factor)
        max_height = int(screen_height * factor)

        height, width = img_file.shape[:2]

        scale = 1.0
        if width > max_width or width < 0.3 * max_width:
            scale = max_width / width

        scaled_height = height * scale
        if scaled_height > max_height or scaled_height < 0.3 * max_height:
            scale *= max_height / scaled_height

        if scale != 1.0:
            # cv2.resize takes (width, height); never request a 0-pixel side.
            new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
            img_file = cv2.resize(img_file, new_size)

        return img_file

    def load_JIMG_project(self, project_path):
        """
        Load a JIMG project from a `.pjm` file.

        Parameters
        ----------
        project_path : str or os.PathLike
            Path to the project file. It must end with the `.pjm` extension.

        Returns
        -------
        object or None
            The unpickled project object. Returns ``None``, after printing a
            message, when the path does not end with `.pjm`.

        Warnings
        --------
        The file is read with ``pickle.load``, which can execute arbitrary
        code. Only load project files from trusted sources.
        """

        if not str(project_path).endswith(".pjm"):
            print("\nProvide path to the project metadata file with *.pjm extension!!!")
            return None

        with open(project_path, "rb") as file:
            return pickle.load(file)

    def ajd_mask_size(self, image, mask):
        """
        Resize a mask to match the spatial size of a reference image.

        Parameters
        ----------
        image : np.ndarray
            Reference image. For 3-D (or higher) input, axes 1 and 2 are taken
            as (height, width), e.g. a ``(z, y, x)`` stack. For 2-D input the
            axes are ``(height, width)``.

        mask : np.ndarray
            Mask image to be resized.

        Returns
        -------
        np.ndarray
            Resized mask matching the spatial size of `image`.

        Notes
        -----
        A channels-last colour image ``(H, W, 3)`` is treated like a stack,
        so the mask would be resized to width 3. Pass a 2-D image in that
        case.
        """

        if image.ndim >= 3:
            target_size = (image.shape[2], image.shape[1])
        else:
            target_size = (image.shape[1], image.shape[0])

        # cv2.resize expects the target size as (width, height).
        return cv2.resize(mask, target_size)

    def load_image(self, path_to_image):
        """
        Load an image from the specified file path.

        Delegates to the module-level ``load_image`` helper.

        Parameters
        ----------
        path_to_image : str
            Path to the image file.

        Returns
        -------
        image : np.ndarray
            Loaded image as a NumPy array.
        """

        # Resolves to the module-level function, not this method.
        return load_image(path_to_image)

    def load_3D_tiff(self, path_to_image):
        """
        Load a 3D image from a TIFF file.

        Delegates to the module-level ``load_tiff`` helper.

        Parameters
        ----------
        path_to_image : str
            Path to the 3D TIFF image file.

        Returns
        -------
        image : np.ndarray
            Loaded 3D image as a NumPy array.
        """

        return load_tiff(path_to_image)

    def load_mask(self, path_to_mask):
        """
        Load a mask image as single-channel grayscale.

        Parameters
        ----------
        path_to_mask : str
            Path to the mask image file.

        Returns
        -------
        mask : np.ndarray or None
            Loaded mask image, or ``None`` if OpenCV cannot read the file.
        """

        return cv2.imread(path_to_mask, cv2.IMREAD_GRAYSCALE)

    def save(self, image, file_name):
        """
        Save an image to disk with OpenCV.

        Parameters
        ----------
        image : np.ndarray
            Image data to be saved.

        file_name : str
            Output file path including the desired image extension
            (e.g., ".png", ".jpg"), which selects the encoder.
        """

        cv2.imwrite(filename=file_name, img=image)

    # calculation methods

    def drop_dict(self, dictionary, key, var, action=None):
        """
        Drop entries from a dictionary of parallel lists based on a comparison.

        For each index ``i``, the test ``var <action> value`` is evaluated,
        where ``value`` is ``dictionary[key][i]``, or its mean if it is a
        NumPy array. Where the test is true, element ``i`` is removed from
        *every* list in the dictionary.

        Parameters
        ----------
        dictionary : dict
            Dictionary whose values are equally long lists or arrays.

        key : str
            Key whose values are tested.

        var : numeric
            Left-hand operand of the comparison.

        action : {'<=', '>=', '==', '<', '>'}
            Comparison operator. Any other value, including the default
            ``None``, prints "Wrong action!" and returns ``None``.

        Returns
        -------
        dict or None
            A deep-copied dictionary with the matching elements removed (the
            input is not modified), or ``None`` for an invalid `action`.
        """

        comparisons = {
            "<=": lambda threshold, value: threshold <= value,
            ">=": lambda threshold, value: threshold >= value,
            "==": lambda threshold, value: threshold == value,
            "<": lambda threshold, value: threshold < value,
            ">": lambda threshold, value: threshold > value,
        }
        compare = comparisons.get(action) if isinstance(action, str) else None
        if compare is None:
            print("\nWrong action!")
            return None

        indices_to_drop = set()
        for index, value in enumerate(dictionary[key]):
            if isinstance(value, np.ndarray):
                value = np.mean(value)
            if compare(var, value):
                indices_to_drop.add(index)

        # A set makes each membership test O(1). A single deepcopy keeps the
        # returned elements independent of the caller's objects.
        return copy.deepcopy(
            {
                name: [
                    item
                    for index, item in enumerate(values)
                    if index not in indices_to_drop
                ]
                for name, values in dictionary.items()
            }
        )

    # modified for gradation (separation near nucleus)

    def create_mask(self, dictionary, image):
        """
        Create a gradated ``uint16`` mask from lists of pixel coordinates.

        With ``N`` coordinate arrays and ``base = floor(65535 / 4)``, the n-th
        array (1-based) is painted with ``base + n * (65535 - base) / N``,
        truncated to an integer. Values therefore rise from about a quarter of
        the 16-bit range up to 65535, and later arrays overwrite earlier ones
        where they overlap. Background pixels are 0.

        Parameters
        ----------
        dictionary : dict
            Must contain ``'coords'``: a list of integer arrays whose first two
            columns hold row and column indices.

        image : np.ndarray
            Image whose shape defines the shape of the mask.

        Returns
        -------
        np.ndarray
            Mask with the same shape as `image` and dtype ``uint16``.
        """

        image_mask = np.zeros(image.shape)
        coords_list = dictionary["coords"]

        if len(coords_list) > 0:
            base_value = math.floor((2**16 - 1) / 4)
            step = ((2**16 - 1) - base_value) / len(coords_list)

            for gradation, coords in enumerate(coords_list, start=1):
                image_mask[coords[:, 0], coords[:, 1]] = base_value + (
                    gradation * step
                )

        return image_mask.astype("uint16")

    def min_max_histograme(self, image):
        """
        Estimate lower/upper intensity limits from image quantiles.

        For q = 0, 5, ..., 95 % the quantile value ``val`` is computed,
        together with ``perc``, the fraction of image elements strictly below
        it.

        - ``min_val`` is the first non-zero ``perc`` among quantiles with a
          non-zero value, or 0 if there is none.
        - ``max_val`` is found by walking the rows with ``perc > 0`` from the
          highest quantile down. At position ``i > 1``, the first row whose
          ``val * 1.5`` exceeds the previous (higher) row's ``val`` gives the
          result. Otherwise the ``perc`` of the last row is used, or 0 if no
          rows remain.

        Parameters
        ----------
        image : np.ndarray
            Input image for histogram analysis.

        Returns
        -------
        min_val : float
            Lower limit expressed as a fraction of elements.

        max_val : float
            Upper limit expressed as a fraction of elements.

        df : pd.DataFrame
            Columns ``q`` (percent), ``val`` (quantile value) and ``perc``,
            filtered to ``perc > 0`` and sorted by ``q`` descending.
        """

        total = image.size

        q = []
        val = []
        perc = []

        for n in range(0, 100, 5):
            # Compute each quantile once and reuse it for the count.
            threshold = np.quantile(image, n / 100)
            q.append(n)
            val.append(threshold)
            perc.append(np.sum(image < threshold) / total)

        df = pd.DataFrame({"q": q, "val": val, "perc": perc})

        min_val = 0
        for i in df.index:
            if df["val"][i] != 0 and min_val == 0:
                min_val = df["perc"][i]

        max_val = 0
        df = df[df["perc"] > 0]
        df = df.sort_values("q", ascending=False).reset_index(drop=True)

        for i in df.index:
            if i > 1 and df["val"][i] * 1.5 > df["val"][i - 1]:
                max_val = df["perc"][i]
                break
            elif i == len(df.index) - 1:
                max_val = df["perc"][i]

        return min_val, max_val, df
def umap_html(umap_result, width=1000, height=1200):
def umap_html(umap_result, width=1000, height=1200):
    """
    Build an interactive Plotly scatter plot of a 2-D UMAP embedding.

    Parameters
    ----------
    umap_result : pandas.DataFrame or dict-like
        Embedding providing:
        - `0` : array-like, first UMAP coordinate (x axis)
        - `1` : array-like, second UMAP coordinate (y axis)
        - `'clusters'` : array-like, cluster label of each point (colour)

    width : int, optional
        Figure width in pixels. Default is 1000.

    height : int, optional
        Figure height in pixels. Default is 1200.

    Returns
    -------
    plotly.graph_objs._figure.Figure
        SVG-rendered scatter plot with points coloured by cluster. Colours
        come from Plotly's Dark24 followed by Light24 palettes.
    """

    x_coords = umap_result[0]
    y_coords = umap_result[1]
    cluster_labels = umap_result["clusters"]
    palette = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24

    figure = px.scatter(
        x=x_coords,
        y=y_coords,
        color=cluster_labels,
        labels={"color": "Cells"},
        template="simple_white",
        width=width,
        height=height,
        render_mode="svg",
        color_discrete_sequence=palette,
    )

    figure.update_xaxes(title_text="UMAP 1")
    figure.update_yaxes(title_text="UMAP 2")

    return figure

Create an interactive HTML UMAP scatter plot.

Parameters

umap_result : pandas.DataFrame or dict-like UMAP embedding containing at least: - 0 : array-like, UMAP dimension 1 - 1 : array-like, UMAP dimension 2 - 'clusters' : array-like, assigned cluster labels

width : int, optional Width of the output figure in pixels. Default is 1000.

height : int, optional Height of the output figure in pixels. Default is 1200.

Returns

plotly.graph_objs._figure.Figure Interactive Plotly scatter plot object visualizing UMAP with colored clusters.

def umap_static(umap_result, width=10, height=13):
 72def umap_static(umap_result, width=10, height=13):
 73    """
 74    Create a static matplotlib UMAP scatter plot.
 75
 76    Parameters
 77    ----------
 78    umap_result : pandas.DataFrame or dict-like
 79        UMAP projection containing:
 80        - `0` : array-like, UMAP dimension 1
 81        - `1` : array-like, UMAP dimension 2
 82        - `'clusters'` : array-like, cluster assignments
 83
 84    width : float, optional
 85        Width of the figure in inches. Default is 10.
 86
 87    height : float, optional
 88        Height of the figure in inches. Default is 13.
 89
 90    Returns
 91    -------
 92    matplotlib.figure.Figure
 93        Static matplotlib figure representing the UMAP embedding with clusters.
 94    """
 95
 96    plotly_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24
 97    num_colors = len(plotly_colors)
 98
 99    fig = plt.figure(figsize=(width, height))
100
101    cluster_counts = {
102        label: np.sum(umap_result["clusters"] == label)
103        for label in np.unique(umap_result["clusters"])
104    }
105
106    sorted_labels = sorted(
107        cluster_counts, key=lambda label: cluster_counts[label], reverse=True
108    )
109
110    color_map = {
111        label: plotly_colors[i % num_colors] for i, label in enumerate(sorted_labels)
112    }
113
114    for label in sorted_labels:
115        subset = umap_result[umap_result["clusters"] == label]
116        plt.scatter(
117            subset[0],
118            subset[1],
119            c=[color_map[label]],
120            label=f"Cluster {label}",
121            alpha=0.7,
122            s=20,
123            edgecolor="black",
124            linewidths=0.1,
125        )
126
127    plt.xlabel("UMAP 1", fontsize=14)
128    plt.ylabel("UMAP 2", fontsize=14)
129    plt.grid(True, which="both", linestyle="--", linewidth=0.1)
130    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
131    plt.tight_layout()
132
133    return fig

Create a static matplotlib UMAP scatter plot.

Parameters

umap_result : pandas.DataFrame or dict-like UMAP projection containing: - 0 : array-like, UMAP dimension 1 - 1 : array-like, UMAP dimension 2 - 'clusters' : array-like, cluster assignments

width : float, optional Width of the figure in inches. Default is 10.

height : float, optional Height of the figure in inches. Default is 13.

Returns

matplotlib.figure.Figure Static matplotlib figure representing the UMAP embedding with clusters.

def test_data(path=''):
def test_data(path=""):
    """
    Download and extract test data from Google Drive.

    This function downloads a compressed archive containing example test data
    and extracts it into the specified directory. The data is fetched using
    a direct Google Drive link. If the download or extraction fails, an
    error message is printed.

    Parameters
    ----------
    path : str, optional
        Destination directory where the test dataset will be downloaded and
        extracted. Defaults to the current working directory.

    Notes
    -----
    - The downloaded file is named ``test_data.tar.gz`` and is kept in `path`.
    - The archive is expected to unpack into ``<path>/test_data``.
    - When the running Python provides tarfile extraction filters, the
      ``"data"`` filter is used. This prevents archive members from being
      written outside `path`.
    - In case of any failure (download or extraction), the function prints
      an informative message instead of raising an exception.
    """

    try:

        archive_path = os.path.join(path, "test_data.tar.gz")

        url = "https://drive.google.com/uc?id=1MhzhleMP7iTzlBVW8eP5sFaonJdg1a3T"

        gdown.download(url, archive_path, quiet=False)

        with tarfile.open(archive_path, "r:gz") as tar:
            # The archive comes from the network. Use the safe "data"
            # extraction filter where available (Python 3.12+ and security
            # backports) to block path traversal and unsafe links.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=path, filter="data")
            else:
                tar.extractall(path=path)

        print(
            f"\nTest data downloaded succesfully -> {os.path.join(path, 'test_data')}"
        )

    except Exception:

        print(
            "\nTest data could not be downloaded. Please check your connection and try again!"
        )

Download and extract test data from Google Drive.

This function downloads a compressed archive containing example test data and extracts it into the specified directory. The data is fetched using a direct Google Drive link. If the download or extraction fails, an error message is printed.

Parameters

path : str, optional Destination directory where the test dataset will be downloaded and extracted. Defaults to the current working directory.

Notes

  • The downloaded file is named test_data.tar.gz.
  • The archive is extracted into <path>/test_data.
  • In case of any failure (download or extraction), the function prints an informative message instead of raising an exception.
def prop_plot(df_pivot, chi_df):
def prop_plot(df_pivot, chi_df):
    """
    Create a stacked bar plot of proportional data with post-hoc significance annotations.

    Parameters
    ----------
    df_pivot : pandas.DataFrame
        Pivot table where rows represent categories (e.g., compartments) and
        columns represent groups. Values are counts or frequencies.

    chi_df : pandas.DataFrame
        Pairwise test results with columns 'Group 1', 'Group 2' and
        'Significance_Label' (e.g., '***', '**', '*', 'ns') for each pair of
        groups. Typically this is the output of `chi_pairs` with a label
        column added via `get_significance_label`.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the stacked bar plot.

    Notes
    -----
    - Counts are converted to percentages within each group (column), so
      each bar sums to 100 %.
    - Each pairwise comparison is listed as ``"<group 1> vs <group 2>: <label>"``
      in a text box to the right of the plot.
    - Colors are assigned using the 'viridis' colormap.
    """

    # Normalise each column (group) to percentages of its own total.
    df_pivot_perc = df_pivot.div(df_pivot.sum(axis=0), axis=1) * 100

    # Separate the two group names; without a separator they run together.
    posthoc_text = "\n".join(
        f"{row['Group 1']} vs {row['Group 2']}: {row['Significance_Label']}"
        for _, row in chi_df.iterrows()
    )

    fig, ax = plt.subplots(figsize=(12, 7))

    # Transpose so groups are on the x-axis and categories are stacked.
    df_pivot_perc.T.plot(kind="bar", stacked=True, ax=ax, cmap="viridis")

    ax.set_ylabel("Percentage (%)", fontsize=16)
    ax.set_xlabel("Groups", fontsize=16)

    ax.tick_params(axis="both", labelsize=12)

    ax.legend(
        title="Compartment", loc="upper left", bbox_to_anchor=(1.02, 1.05), fontsize=14
    )

    # Attach the annotation to this figure explicitly rather than to
    # pyplot's "current" figure.
    fig.text(
        0.93,
        0.6,
        posthoc_text,
        ha="left",
        va="top",
        fontsize=12,
        bbox={"facecolor": "white", "alpha": 0.7, "pad": 5},
    )

    return fig

Create a stacked bar plot of proportional data with post-hoc significance annotations.

Parameters

df_pivot : pandas.DataFrame Pivot table where rows represent categories (e.g., compartments) and columns represent groups. Values are counts or frequencies.

chi_df : pandas.DataFrame DataFrame containing pairwise Chi-square test results with an added 'Significance_Label' column (e.g., '***', '**', '*', 'ns') for each pair of groups. Typically output from chi_pairs and get_significance_label.

Returns

matplotlib.figure.Figure The Matplotlib figure object containing the stacked bar plot.

Notes

  • The function converts raw counts to percentages per group for visualization.
  • Each pairwise comparison and its significance label is displayed as a text box next to the plot.
  • Colors are assigned using the 'viridis' colormap by default.
  • The plot is configured for clarity with labeled axes, legend, and appropriately sized text.
def get_significance_label(p_value):
def get_significance_label(p_value):
    """
    Map a p-value to the conventional significance marker.

    Parameters
    ----------
    p_value : float
        The p-value to classify.

    Returns
    -------
    str
        One of:

        - '***' : p < 0.001
        - '**'  : p < 0.01
        - '*'   : p < 0.05
        - 'ns'  : not significant (p ≥ 0.05, or NaN)

    Notes
    -----
    Thresholds are checked from the strictest to the loosest, and the first
    one that the p-value falls below determines the marker.
    """

    thresholds = ((0.001, "***"), (0.01, "**"), (0.05, "*"))

    for cutoff, marker in thresholds:
        if p_value < cutoff:
            return marker

    return "ns"

Return a standard significance label based on a p-value.

Parameters

p_value : float The p-value for which the significance label should be determined.

Returns

str A significance marker commonly used in statistical reporting:

- '***' : p < 0.001
- '**'  : p < 0.01
- '*'   : p < 0.05
- 'ns'  : not significant (p ≥ 0.05)

Notes

This helper function is typically used for annotating statistical test results in tables or visualizations. Thresholds follow conventional statistical notation for significance levels.

def chi_pairs(df_pivot):
def chi_pairs(df_pivot):
    """
    Compute pairwise Chi-square tests for all combinations of groups in a pivoted dataframe.

    Parameters
    ----------
    df_pivot : pandas.DataFrame
        A pivot table where rows represent categories and columns represent
        groups. Values should be counts (frequencies).

    Returns
    -------
    pandas.DataFrame
        One row per pair of groups, in ``itertools.combinations`` order of the
        columns, with the following columns:
        - 'Group 1' : name of the first group in the pair.
        - 'Group 2' : name of the second group in the pair.
        - 'Chi²' : the Chi-square statistic for the comparison.
        - 'p-value' : the p-value of the Chi-square test.
        With fewer than two groups, the result is empty but still has these
        columns.

    Notes
    -----
    Each pair is tested with `scipy.stats.chi2_contingency` on a
    ``2 x n_categories`` table using SciPy's default settings. Yates'
    continuity correction is therefore applied only when the table has one
    degree of freedom, i.e. when there are exactly two categories.

    A value of 1 is added to every cell because ``chi2_contingency`` rejects
    tables with a zero expected frequency (e.g. a category that is empty in
    both groups).
    """

    columns = ["Group 1", "Group 2", "Chi²", "p-value"]

    # Rows = groups, columns = categories; transpose once, outside the loop.
    counts_by_group = df_pivot.T

    posthoc_results = []

    for group1, group2 in combinations(df_pivot.columns, 2):
        sub_table = counts_by_group.loc[[group1, group2]] + 1
        chi2, p_value, _dof, _expected = chi2_contingency(sub_table)

        posthoc_results.append(
            {"Group 1": group1, "Group 2": group2, "Chi²": chi2, "p-value": p_value}
        )

    return pd.DataFrame(posthoc_results, columns=columns)

Compute pairwise Chi-square tests for all combinations of groups in a pivoted dataframe.

Parameters

df_pivot : pandas.DataFrame A pivot table where rows represent categories and columns represent groups. Values should be counts (frequencies). The function will add +1 to each cell to avoid zero counts during chi-square computation.

Returns

pandas.DataFrame A DataFrame containing pairwise Chi-square test results with the following columns: - 'Group 1' : str Name of the first group in the pair. - 'Group 2' : str Name of the second group in the pair. - 'Chi²' : float The Chi-square statistic for the comparison. - 'p-value' : float The p-value of the Chi-square test.

Notes

The function compares every possible pair of columns using scipy.stats.chi2_contingency with its default settings, so Yates' continuity correction is applied only when the test has one degree of freedom (i.e. exactly two categories). A value of 1 is added to all cells to avoid zero expected frequencies, which chi2_contingency rejects.

def _bh_adjust(sorted_pvals):
    """Return Benjamini–Hochberg adjusted p-values for p-values sorted ascending."""
    pvals = np.asarray(sorted_pvals, dtype=float)
    num_tests = pvals.size
    if num_tests == 0:
        return pvals
    scaled = pvals * num_tests / np.arange(1, num_tests + 1)
    # Step-up procedure: each adjusted value is the minimum of the scaled values
    # at its own and every larger rank, which keeps the result monotone.
    # fmin (not minimum) so a stray NaN does not poison all smaller ranks.
    monotone = np.fmin.accumulate(scaled[::-1])[::-1]
    return np.minimum(monotone, 1.0)


def _summarise_comparison(results, valid, offset):
    """Turn the per-feature stat dicts of one target-vs-rest comparison into a table."""
    combined_df = pd.DataFrame(results)
    combined_df["valid_group"] = valid
    combined_df.sort_values(by="p_val", inplace=True)

    # Benjamini–Hochberg FDR within this comparison (rows are sorted by p_val).
    combined_df["adj_pval"] = _bh_adjust(combined_df["p_val"].to_numpy())
    # `offset` keeps -log10 finite when p_val == 0.
    combined_df["-log(p_val)"] = -np.log10(offset + combined_df["p_val"])

    # Pseudo-count for the fold change: 95% of the smallest positive summed mean.
    summed_means = combined_df["avg_valid"] + combined_df["avg_ctrl"]
    low_factor = np.min(summed_means[summed_means > 0]) * 0.95

    combined_df["FC"] = (combined_df["avg_valid"] + low_factor) / (
        combined_df["avg_ctrl"] + low_factor
    )
    combined_df["log(FC)"] = np.log2(combined_df["FC"])
    combined_df["norm_diff"] = combined_df["avg_valid"] - combined_df["avg_ctrl"]
    return combined_df


def statistic(input_df, sets=None, metadata=None, n_proc=10):
    """
    Compute statistical comparison between cell groups or clusters.

    This function performs differential feature analysis between either:
    (1) every group vs. all other groups (default mode), or
    (2) user-defined groups given in ``sets``, each compared against the
        union of the remaining groups.

    The analysis includes:
    - Mann–Whitney U test (two-sided)
    - Percentage of non-zero values
    - Means and standard deviations
    - Effect size metric (ESM)
    - Benjamini–Hochberg FDR correction
    - Fold-change and log2 fold-change

    Parameters
    ----------
    input_df : pandas.DataFrame
        Input feature matrix where rows represent features and columns
        represent cells. It is transposed internally so that cells become
        rows and features become columns.

    sets : dict or None, optional
        Mode selection:
        - ``None`` (default): each unique label in ``metadata['sets']`` is
          compared against all remaining labels.
        - ``dict``: at least two keys, each mapping to a list of labels
          (values of ``metadata['sets']``) forming that comparison group.
          Example: ``{'A': ['T1', 'T2'], 'B': ['C1', 'C2']}``.

    metadata : pandas.DataFrame, optional
        Metadata with a column ``'sets'`` holding one group label per column
        of ``input_df``.

    n_proc : int, optional
        Number of parallel jobs used by joblib. Default is ``10``.

    Returns
    -------
    pandas.DataFrame or None
        One row per feature and comparison with the columns ``feature``,
        ``p_val``, ``pct_valid``, ``pct_ctrl``, ``avg_valid``, ``avg_ctrl``,
        ``sd_valid``, ``sd_ctrl``, ``esm``, ``valid_group``, ``adj_pval``,
        ``-log(p_val)``, ``FC``, ``log(FC)`` and ``norm_diff``.
        Results of all comparisons are concatenated. ``None`` is returned for
        invalid parameters or when any error occurs.

    Notes
    -----
    - Features that are zero in every retained cell are removed.
    - When all values of a feature are identical across both groups the
      Mann–Whitney U test is undefined and ``p_val`` is set to 1.
    - Benjamini–Hochberg adjusted p-values (step-up, monotone, capped at 1)
      are computed separately within each comparison.
    - ``FC = (avg_valid + low_factor) / (avg_ctrl + low_factor)`` where
      ``low_factor`` is 0.95 × the smallest positive ``avg_valid + avg_ctrl``.
    - ``norm_diff = avg_valid - avg_ctrl``.
    - ``-log(p_val)`` is ``-log10(p_val + 1e-100)``.

    Raises
    ------
    None
        All exceptions are caught internally and printed as messages.

    Examples
    --------
    >>> df = pd.DataFrame(...)
    >>> meta = pd.DataFrame({'sets': [...]})
    >>> stat = statistic(df, metadata=meta)
    >>> stat.head()

    >>> # Compare two groups explicitly
    >>> sets = {'A': ['Type1'], 'B': ['Type2']}
    >>> stat = statistic(df, sets=sets, metadata=meta, n_proc=4)
    """
    try:
        # Added before log10 so that p_val == 0 does not produce infinity.
        offset = 1e-100

        def stat_calc(choose, feature_name):
            """Compute per-feature statistics for the target vs. rest split."""
            target_values = choose.loc[choose["DEG"] == "target", feature_name]
            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

            pct_valid = (target_values > 0).sum() / len(target_values)
            pct_rest = (rest_values > 0).sum() / len(rest_values)

            avg_valid = np.mean(target_values)
            avg_ctrl = np.mean(rest_values)
            sd_valid = np.std(target_values, ddof=1)
            sd_ctrl = np.std(rest_values, ddof=1)
            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

            pooled = np.concatenate(
                [np.asarray(target_values), np.asarray(rest_values)]
            )
            if pooled.size == 0 or np.all(pooled == pooled[0]):
                # Every observation is tied: there is no difference to test.
                p_val = 1.0
            else:
                _, p_val = stats.mannwhitneyu(
                    target_values, rest_values, alternative="two-sided"
                )

            return {
                "feature": feature_name,
                "p_val": p_val,
                "pct_valid": pct_valid,
                "pct_ctrl": pct_rest,
                "avg_valid": avg_valid,
                "avg_ctrl": avg_ctrl,
                "sd_valid": sd_valid,
                "sd_ctrl": sd_ctrl,
                "esm": esm,
            }

        # Cells become rows, features become columns.
        choose = input_df.copy().T

        if sets is None:
            print("\nAnalysis started...")
            print("\nComparing each type of cell to others...")
            final_results = []

            if len(set(metadata["sets"])) > 1:
                choose.index = metadata["sets"]

            indexes = list(choose.index)

            for c in set(indexes):
                print(f"Calculating statistics for {c}")

                choose.index = indexes
                choose["DEG"] = np.where(choose.index == c, "target", "rest")

                valid = ",".join(set(choose.index[choose["DEG"] == "target"]))
                # Remove all-zero feature columns.
                choose = choose.loc[:, (choose != 0).any(axis=0)]

                results = Parallel(n_jobs=n_proc)(
                    delayed(stat_calc)(choose, feature)
                    for feature in tqdm(choose.columns[choose.columns != "DEG"])
                )

                final_results.append(_summarise_comparison(results, valid, offset))

            print("\nAnalysis has finished!")
            return pd.concat(final_results, ignore_index=True)

        elif isinstance(sets, dict):
            print("\nAnalysis started...")
            print("\nComparing groups...")

            group_list = list(sets.keys())
            choose.index = metadata["sets"]

            inx = sorted([item for sublist in sets.values() for item in sublist])
            choose = choose.loc[inx]

            frames = []
            for n, g in enumerate(group_list):
                print(f"Calculating statistics for {g}")

                rest_indices = [
                    idx
                    for i, group in enumerate(group_list)
                    if i != n
                    for idx in sets[group]
                ]

                choose["DEG"] = np.where(
                    choose.index.isin(sets[g]),
                    "target",
                    np.where(choose.index.isin(rest_indices), "rest", "drop"),
                )

                choose = choose[choose["DEG"] != "drop"]
                # Remove all-zero feature columns.
                choose = choose.loc[:, (choose != 0).any(axis=0)]

                results = Parallel(n_jobs=n_proc)(
                    delayed(stat_calc)(choose, feature)
                    for feature in tqdm(choose.columns[choose.columns != "DEG"])
                )

                frames.append(_summarise_comparison(results, g, offset))

            print("\nAnalysis has finished!")
            return pd.concat(frames) if frames else pd.DataFrame()

        else:
            print("\nInvalid parameters. Please check the input.")
            return None

    except Exception as e:
        print(f"Error: {e}")
        return None

Compute statistical comparison between cell groups or clusters.

This function performs differential feature analysis between either: (1) every group vs. all other groups (default mode), or (2) user-defined groups specified in sets, where each group is compared against the union of the remaining groups.

The analysis includes:

  • Mann–Whitney U test
  • Percentage of non-zero values
  • Means and standard deviations
  • Effect size metric (ESM)
  • Benjamini–Hochberg FDR correction
  • Fold-change and log2 fold-change

Parameters

input_df : pandas.DataFrame Input feature matrix where rows represent features and columns represent cells. The function transposes this table internally, so that cells become rows and features become columns.

sets : dict or None, optional Mode selection: - None (default): each unique label in metadata['sets'] is compared against all remaining groups. - dict: should contain at least two keys, each mapping to a list of labels (values of metadata['sets']) that make up that comparison group; each group is compared against the union of the other groups. Example: {'A': ['T1', 'T2'], 'B': ['C1', 'C2']}.

metadata : pandas.DataFrame, optional Metadata containing at least a column 'sets' with group labels corresponding to columns of input_df.

n_proc : int, optional Number of parallel processes used for statistical computation. Default is 10.

Returns

pandas.DataFrame or None A DataFrame containing statistical results for each feature, including:

- ``feature`` : str
- ``p_val`` : float
- ``adj_pval`` : float
- ``pct_valid`` : float
- ``pct_ctrl`` : float
- ``avg_valid`` : float
- ``avg_ctrl`` : float
- ``sd_valid`` : float
- ``sd_ctrl`` : float
- ``esm`` : float
- ``FC`` : float
- ``log(FC)`` : float
- ``norm_diff`` : float
- ``valid_group`` : str
- ``-log(p_val)`` : float

If ``sets`` is ``None``, results for each group are concatenated.

Returns ``None`` in case of errors or invalid parameters.

Notes

  • Columns containing only zeros are automatically removed.
  • p-values equal for both groups produce p_val = 1.
  • Benjamini–Hochberg correction is applied separately within each group comparison.
  • Fold-change is stabilized using a small, data-derived low_factor.
  • Uses Mann–Whitney U test with alternative='two-sided'.

Raises

None All exceptions are caught internally and printed as messages.

Examples

>>> df = pd.DataFrame(...)
>>> meta = pd.DataFrame({'sets': [...]})
>>> stat = statistic(df, metadata=meta)
>>> stat.head()
>>> # Compare two groups explicitly
>>> sets = {'A': ['Type1'], 'B': ['Type2']}
>>> stat = statistic(df, sets=sets, metadata=meta, n_proc=4)
def save_image(image, path_to_save):
    """
    Save an image to disk.

    Parameters
    ----------
    image : np.ndarray
        Input image array to be saved.

    path_to_save : str or os.PathLike
        Output file path including the filename. Must end with one of the
        extensions ``.png``, ``.tiff`` or ``.tif`` (case-insensitive).

    Returns
    -------
    str or None
        The path the image was written to, or ``None`` when the path is
        missing/invalid or writing failed (a message is printed).
    """

    try:
        path = os.fspath(path_to_save) if path_to_save is not None else ""

        # The original check OR-ed three "not in" tests, which is true for
        # every path, so nothing was ever written; test the suffix instead.
        if not path or not path.lower().endswith((".png", ".tiff", ".tif")):
            print(
                "\nThe path is not provided or the file extension is not *.png, *.tiff or *.tif"
            )
            return None

        # cv2.imwrite reports failure through its boolean return value.
        if not cv2.imwrite(path, image):
            print("Something went wrong. Check the function input data and try again!")
            return None

        return path

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Save an image to disk.

Parameters

image : np.ndarray Input image array to be saved.

path_to_save : str Output file path including the filename. Must include one of the following extensions: .png, .tiff or .tif.

Returns

str Path to the saved file.

def load_tiff(path_to_tiff: str):
    """
    Load a *.tiff image and ensure it is in 16-bit format.

    Images that are already ``uint16`` are returned unchanged. Any other
    dtype is min-max rescaled to the full ``[0, 65535]`` range plane by plane.
    For a 2-D file the whole image is one plane; otherwise axis 0 is assumed
    to index the planes (e.g. a ``(Z, H, W)`` stack — confirm for other layouts).

    Parameters
    ----------
    path_to_tiff : str
        Path to the *.tiff file.

    Returns
    -------
    np.ndarray or None
        Loaded image stack as ``uint16``, or ``None`` on failure.
        Planes with a single constant value become all zeros.
    """

    try:
        stack = tiff.imread(path_to_tiff)

        if stack.dtype == np.uint16:
            return stack

        # Normalise in float BEFORE casting: casting first would wrap or
        # truncate values outside the uint16 range (e.g. floats in [0, 1]).
        data = stack.astype(np.float64)
        planes = data[np.newaxis] if data.ndim == 2 else data
        out = np.zeros(planes.shape, dtype=np.uint16)

        for n, plane in enumerate(planes):
            min_val = np.min(plane)
            max_val = np.max(plane)
            if max_val > min_val:  # flat planes stay zero instead of 0/0
                out[n] = ((plane - min_val) / (max_val - min_val) * 65535).astype(
                    np.uint16
                )

        return out[0] if data.ndim == 2 else out

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Load a *.tiff image and ensure it is in 16-bit format.

Parameters

path_to_tiff : str Path to the *.tiff file.

Returns

np.ndarray Loaded image stack, converted to 16-bit if necessary.

def z_projection(tiff_object, projection_type="avg"):
    """
    Perform Z-projection on a stacked (3D) image.

    Parameters
    ----------
    tiff_object : np.ndarray
        Input stacked image (e.g., loaded with `load_tiff()`); the projection
        is taken along axis 0.

    projection_type : str
        Type of Z-axis projection. Options: 'avg', 'median', 'min', 'max', 'std'.

    Returns
    -------
    np.ndarray or None
        Resulting projection cast to ``uint16`` (fractional values are
        truncated), or ``None`` for an unknown ``projection_type`` or on failure.
    """

    reducers = {
        "avg": np.mean,
        "max": np.max,
        "min": np.min,
        "std": np.std,
        "median": np.median,
    }

    try:
        reducer = reducers.get(projection_type)
        if reducer is None:
            # Previously this fell through to an unbound variable error.
            print(
                f"\nUnknown projection_type {projection_type!r}. "
                "Options: 'avg', 'median', 'min', 'max', 'std'."
            )
            return None

        return reducer(tiff_object, axis=0).astype(np.uint16)

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Perform Z-projection on a stacked (3D) image.

Parameters

tiff_object : np.ndarray Input stacked 3D image (e.g., loaded with load_tiff()).

projection_type : str Type of Z-axis projection. Options: 'avg', 'median', 'min', 'max', 'std'.

Returns

np.ndarray Resulting 2D projection image.

def clahe_16bit(img, kernal=(100, 100), clip_limit=10):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

    The image is min-max scaled to 8 bits and equalised with OpenCV's CLAHE.
    The equalised values (divided by 255) are used as per-pixel weights on
    the original image, and the result is rescaled to ``[0, 65535]``.

    Parameters
    ----------
    img : np.ndarray
        Input image (single-channel, as required by OpenCV CLAHE).

    kernal : tuple of int, optional
        Size of the CLAHE tile grid (``tileGridSize``), e.g. ``(100, 100)``.
        The parameter name keeps its original spelling for compatibility.

    clip_limit : float, optional
        CLAHE contrast limit (``clipLimit``). Default is 10.

    Returns
    -------
    np.ndarray or None
        ``uint16`` image after CLAHE enhancement; all zeros for an image with
        no intensity range; ``None`` on failure.
    """

    try:
        img = np.asarray(img, dtype=np.float64)

        min_val = np.min(img)
        max_val = np.max(img)
        if max_val == min_val:
            # A flat image has no contrast to redistribute (and would divide by 0).
            return np.zeros(img.shape, dtype=np.uint16)

        img8bit = ((img - min_val) / (max_val - min_val) * 255).astype(np.uint8)

        clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=kernal)
        weights = clahe.apply(img8bit) / 255

        enhanced = img * weights

        min_val = np.min(enhanced)
        max_val = np.max(enhanced)
        if max_val == min_val:
            return np.zeros(enhanced.shape, dtype=np.uint16)

        return ((enhanced - min_val) / (max_val - min_val) * 65535).astype(np.uint16)

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to an image.

Parameters

img : np.ndarray Input image.

Returns

img : np.ndarray Image after CLAHE enhancement.

kernal : tuple (an input parameter, not a return value; note the spelling used in the signature) Size of the CLAHE tile grid passed to OpenCV as tileGridSize, e.g., (100, 100).

def equalizeHist_16bit(image_eq):
    """
    Apply global histogram equalization to an image.

    The image is min-max scaled to 8 bits and equalised with
    ``cv2.equalizeHist``. The equalised values (divided by 255) are used as
    per-pixel weights on the original image, and the result is rescaled so
    its maximum becomes 65535.

    Parameters
    ----------
    image_eq : np.ndarray
        Input single-channel image (non-negative values expected).

    Returns
    -------
    np.ndarray or None
        ``uint16`` image after global histogram equalization; all zeros for an
        image with no intensity range; ``None`` on failure.
    """

    try:
        image = np.asarray(image_eq, dtype=np.float64)

        min_val = np.min(image)
        max_val = np.max(image)
        if max_val == min_val:
            # Nothing to equalise, and the scaling below would divide by zero.
            return np.zeros(image.shape, dtype=np.uint16)

        scaled_image = ((image - min_val) / (max_val - min_val) * 255).astype(np.uint8)

        eq_image = cv2.equalizeHist(scaled_image)

        weighted = image * (eq_image / 255)

        peak = np.max(weighted)
        if peak <= 0:
            return np.zeros(image.shape, dtype=np.uint16)

        # After this scaling the maximum is exactly 65535, so the former
        # "+= 65535 - max" brightening step always added 0 and was removed.
        return (weighted / peak * 65535).astype(np.uint16)

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Apply global histogram equalization to an image.

Parameters

image_eq : np.ndarray Input image.

Returns

np.ndarray Image after global histogram equalization.

def load_image(path):
    """
    Load an image and convert it to 16-bit if necessary.

    The file is read with ``cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR``.
    ``uint16`` images are returned unchanged; any other dtype is min-max
    rescaled (jointly over all channels) to the full ``[0, 65535]`` range.

    Parameters
    ----------
    path : str or os.PathLike
        Path to the image file.

    Returns
    -------
    np.ndarray or None
        Loaded ``uint16`` image; all zeros for an image with a single constant
        value; ``None`` when the file cannot be read or on failure.
    """

    try:
        img = cv2.imread(os.fspath(path), cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)

        # cv2.imread signals a missing/unreadable file by returning None.
        if img is None:
            print(f"\nCould not read an image from {path!r}.")
            return None

        # The rest of the module works on 16-bit images.
        if img.dtype == np.uint16:
            return img

        data = img.astype(np.float64)
        min_val = np.min(data)
        max_val = np.max(data)
        if max_val == min_val:
            return np.zeros(data.shape, dtype=np.uint16)

        return ((data - min_val) / (max_val - min_val) * 65535).astype(np.uint16)

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Load an image and convert it to 16-bit if necessary.

Parameters

path : str Path to the image file.

Returns

np.ndarray Loaded 16-bit image.

def rotate_image(img, rotate: int):
    """
    Rotate an image clockwise by a multiple of 90 degrees.

    Parameters
    ----------
    img : np.ndarray
        Image to rotate (at least 2-D; the first two axes are rotated).

    rotate : int
        Clockwise rotation in degrees. Available options: 0, 90, 180, 270.

    Returns
    -------
    np.ndarray or None
        Rotated image as a new C-contiguous array, or ``None`` for an
        unsupported angle or on failure.
    """

    # Number of counter-clockwise quarter turns for np.rot90 (negative = clockwise).
    quarter_turns = {0: 0, 90: -1, 180: 2, 270: 1}

    try:
        k = quarter_turns.get(rotate)
        if k is None:
            print("Wrong argument - rotate!")
            return None

        # .copy() yields a contiguous array that never aliases the input,
        # which OpenCV functions downstream require.
        return np.rot90(img, k=k).copy()

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Rotate an image by a specified angle.

Parameters

img : np.ndarray Image to rotate.

rotate : int Clockwise rotation in degrees. Available options: 0, 90, 180, 270; any other value prints a message and returns None.

Returns

np.ndarray Rotated image.

def mirror_image(img, rotate: str):
    """
    Mirror an image along specified axes.

    Parameters
    ----------
    img : np.ndarray
        Image to mirror (at least 2-D).

    rotate : str
        Type of mirroring to apply. Options:
            'h'  - horizontal mirroring (left/right)
            'v'  - vertical mirroring (up/down)
            'hv' - both horizontal and vertical mirroring

    Returns
    -------
    np.ndarray or None
        Mirrored image as a new C-contiguous array, or ``None`` for an
        unknown option or on failure.
    """

    flips = {
        "h": np.fliplr,
        "v": np.flipud,
        "hv": lambda arr: np.flipud(np.fliplr(arr)),
    }

    try:
        flip = flips.get(rotate)
        if flip is None:
            print("Wrong argument - rotate!")
            return None

        # Flips return strided views; copy to get a contiguous array that
        # does not alias the input.
        return flip(img).copy()

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Mirror an image along specified axes.

Parameters

img : np.ndarray Image to mirror.

rotate : str Type of mirroring to apply. Options: 'h' - horizontal mirroring 'v' - vertical mirroring 'hv' - both horizontal and vertical mirroring

Returns

np.ndarray Mirrored image.

def merge_images(image_list: list, intensity_factors: list = []):
927def merge_images(image_list: list, intensity_factors: list = []):
928    """
929    Merge multiple image projections from different channels.
930
931    Parameters
932    ----------
933    image_list : list of np.ndarray
934        List of images to merge. All images must have the same shape and size.
935
936    intensity_factors : list of float
937        Intensity scaling factors for each image in `image_list`. Base value is 1.
938            - Values < 1 decrease intensity.
939            - Values > 1 increase intensity.
940
941    Returns
942    -------
943    np.ndarray
944        Merged image after applying intensity scaling.
945    """
946
947    try:
948
949        result = None
950
951        if len(intensity_factors) == 0:
952
953            intensity_factors = []
954            for bt in range(len(image_list)):
955                intensity_factors.append(1)
956
957        for i, image in enumerate(image_list):
958            if result is None:
959                result = image.astype(np.uint64) * intensity_factors[i]
960            else:
961                result = cv2.addWeighted(
962                    result, 1, image.astype(np.uint64) * intensity_factors[i], 1, 0
963                )
964
965        result = np.clip(result, 0, 65535)
966
967        result = result.astype(np.uint16)
968
969        return result
970
971    except:
972        print("Something went wrong. Check the function input data and try again!")

Merge multiple image projections from different channels.

Parameters

image_list : list of np.ndarray List of images to merge. All images must have the same shape and size.

intensity_factors : list of float Intensity scaling factors for each image in image_list. Base value is 1. - Values < 1 decrease intensity. - Values > 1 increase intensity.

Returns

np.ndarray Merged image after applying intensity scaling.

def adjust_img_16bit(
    img,
    color="gray",
    max_intensity: int = 65535,
    min_intenisty: int = 0,
    brightness: int = 1000,
    contrast=1.0,
    gamma=1.0,
    min_intensity=None,
):
    """
    Manually adjust image parameters and return the adjusted image.

    Processing order: clip to [0, 65535] -> brightness -> contrast -> gamma
    -> rescale so the maximum becomes 65535 -> max/min thresholds -> colour.

    Parameters
    ----------
    img : np.ndarray
        Input 2-D (single-channel) image.

    color : str
        Output colour. Options: 'gray' (default), 'green', 'blue', 'red',
        'yellow', 'magenta', 'cyan'. Channels are filled in OpenCV BGR order.

    max_intensity : int
        Applied after rescaling: pixels >= this value are set to 65535.
        Default 65535 (no thresholding).

    min_intenisty : int
        Applied after rescaling: pixels <= this value are set to 0.
        Default 0 (no thresholding). The misspelled name is kept for
        backward compatibility; see ``min_intensity``.

    brightness : int, optional
        1000 means unchanged. Positive pixels are multiplied by
        ``1 + (brightness - 1000) / 100``. Typical range: [900-2000].

    contrast : float or int, optional
        Stretch around the mean: ``(img - mean) * contrast + mean``.
        Typical range: [0-5]. Default is 1.

    gamma : float or int, optional
        Gamma correction ``(img / max) ** (1 / gamma) * max``.
        Typical range: [0-5]. Default is 1.

    min_intensity : int or None, optional
        Correctly spelled alias of ``min_intenisty``; overrides it when given.

    Returns
    -------
    np.ndarray or None
        ``uint16`` array of shape (H, W, 3), or ``None`` for an unknown
        ``color`` or on failure.
    """

    # BGR channel indices filled for each colour option.
    channels = {
        "green": (1,),
        "red": (2,),
        "blue": (0,),
        "magenta": (0, 2),
        "yellow": (1, 2),
        "cyan": (0, 1),
        "gray": (0, 1, 2),
    }

    try:
        if min_intensity is not None:
            min_intenisty = min_intensity

        if color not in channels:
            print(f"\nUnknown color {color!r}. Options: {', '.join(channels)}.")
            return None

        # Work in float64: writing negative results back into an unsigned
        # array wrapped them to huge values, turning strong darkening into white.
        img = np.clip(np.asarray(img, dtype=np.float64), 0, 65535)

        # brightness
        if brightness != 1000:
            positive = img > 0
            img[positive] *= 1 + (brightness - 1000) / 100
            img = np.clip(img, 0, 65535)

        # contrast
        if contrast != 1:
            mean = np.mean(img)
            img = np.clip((img - mean) * contrast + mean, 0, 65535)

        # gamma
        if gamma != 1:
            max_val = np.max(img)
            if max_val > 0:
                normalised = np.clip(img / max_val, 0, 1)
                img = np.clip(normalised ** (1 / gamma) * max_val, 0, 65535)

        img = np.nan_to_num(img, nan=0, posinf=65535, neginf=0)
        max_val = np.max(img)
        if max_val > 0:
            img = img / max_val * 65535
        img = img.astype(np.uint16)

        # max intensity: saturate everything at or above the threshold
        if max_intensity != 65535:
            img[img >= max_intensity] = 65535

        # min intensity: zero everything at or below the threshold
        if min_intenisty != 0:
            img[img <= min_intenisty] = 0

        img_color = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint16)
        for channel in channels[color]:
            img_color[:, :, channel] = img

        return img_color

    except Exception:
        print("Something went wrong. Check the function input data and try again!")
        return None

Manually adjust image parameters and return the adjusted image.

Parameters

img : np.ndarray Input image.

color : str Image color channel. Options: ['gray', 'green', 'blue', 'red', 'yellow', 'magenta', 'cyan']. Default is 'gray'. Channels are filled in OpenCV's BGR order.

max_intensity : int Upper threshold applied after the image has been rescaled to [0, 65535]: pixels greater than or equal to this value are set to 65535 (saturated). Default is 65535 (no thresholding).

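min_intensity : int Lower threshold applied after rescaling: pixels less than or equal to this value are set to 0. Note that the keyword in the signature is spelled min_intenisty. Default is 0 (no thresholding).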

brightness : int, optional Image brightness adjustment. Typical range: [900-2000]. Default is 1000.

contrast : float or int, optional Image contrast adjustment. Typical range: [0-5]. Default is 1.

gamma : float or int, optional Gamma correction factor. Typical range: [0-5]. Default is 1.

Returns

np.ndarray Adjusted image after applying brightness, contrast, gamma, and intensity thresholds.

def get_screan():
    """
    Get the current screen width and height.

    A temporary, hidden Tk root window is created to query the screen and is
    always destroyed afterwards, even if the query fails.

    Returns
    -------
    tuple of int
        screen_width, screen_height
    """

    root = tk.Tk()
    try:
        # Hide the helper window so it does not flash on screen.
        root.withdraw()
        return root.winfo_screenwidth(), root.winfo_screenheight()
    finally:
        root.destroy()

Get the current screen width and height.

Returns

tuple of int screen_width, screen_height

def resize_to_screen_img(img_file, factor=1):
    """
    Resize an input image to fit the screen size, optionally scaled by a factor.

    The image is only ever shrunk, never enlarged. A single uniform scale is
    chosen so that both width and height fit, which preserves the aspect ratio
    and avoids interpolating the image twice.

    Parameters
    ----------
    img_file : np.ndarray
        Input image to be resized.

    factor : float, optional
        Scaling factor applied to the screen dimensions before resizing the image. Default is 1.

    Returns
    -------
    np.ndarray
        The resized image, or the input itself when it already fits.
    """

    screen_width, screen_height = get_screan()

    max_w = int(screen_width * factor)
    max_h = int(screen_height * factor)

    h = int(img_file.shape[0])
    w = int(img_file.shape[1])
    if h == 0 or w == 0:
        return img_file

    scale = min(1.0, max_w / w, max_h / h)
    if scale >= 1.0:
        return img_file

    # Never request a zero-sized output, which cv2.resize rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    # INTER_AREA is OpenCV's recommended interpolation for shrinking.
    return cv2.resize(img_file, (new_w, new_h), interpolation=cv2.INTER_AREA)

Resize an input image to fit the screen size, optionally scaled by a factor.

Parameters

img_file : np.ndarray Input image to be resized.

factor : float, optional Scaling factor applied to the screen dimensions before resizing the image. Default is 1.

Returns

np.ndarray Resized image that fits within the screen dimensions while maintaining aspect ratio.

def display_preview(image):
def display_preview(image):
    """
    Quickly preview an image in an OpenCV display window.

    The image is first shrunk to at most 80% of the screen size, then shown
    until a key is pressed, after which all OpenCV windows are closed.

    Parameters
    ----------
    image : np.ndarray
        Input image to be displayed. It is copied before resizing, so the
        caller's array is not modified.

    Notes
    -----
    Does not return a value. Any ``Exception`` raised while resizing or
    displaying is caught, and a message is printed instead.
    """
    try:
        preview = resize_to_screen_img(image.copy(), factor=0.8)

        cv2.imshow("Display", preview)

        # Block until any key is pressed; the key code itself is not needed.
        cv2.waitKey(0)

        cv2.destroyAllWindows()

    except Exception:
        # Catch Exception (not a bare except) so Ctrl+C / SystemExit still
        # propagate.
        # NOTE(review): the module sets `sys.stdout = io.StringIO()` at import
        # time, so this message is not visible on the console; consider
        # logging or re-raising instead.
        print("Something went wrong. Check the function input data and try again!")

Quickly preview an image using a display window.

Parameters

image : np.ndarray Input image to be displayed.

Notes

The function displays the image in a window. Does not return a value.
class ImageTools:
class ImageTools:
    """
    Collection of helper methods for loading, saving, resizing and
    analysing images and masks.

    The class defines no ``__init__`` and stores no instance attributes.
    Its methods cover:

    - screen-size lookup and screen-fitted resizing;
    - loading of images, 3D TIFF stacks, masks and pickled ``.pjm`` project
      files;
    - mask resizing;
    - dictionary filtering;
    - gradated mask creation;
    - a quantile-based intensity summary.

    Notes
    -----
    All methods operate on NumPy arrays in standard OpenCV layout.

    Examples
    --------
    >>> tools = ImageTools()
    >>> mask = tools.load_mask("mask.png")
    >>> mask = tools.ajd_mask_size(image, mask)
    """

    def get_screan(self):
        """
        Return the current screen size.

        A temporary Tk root window is created only to query the screen
        dimensions and is always destroyed afterwards.

        Returns
        -------
        screen_width : int
            Width of the screen in pixels.

        screen_height : int
            Height of the screen in pixels.
        """

        root = tk.Tk()
        try:
            screen_width = root.winfo_screenwidth()
            screen_height = root.winfo_screenheight()
        finally:
            # Release the Tk root even if a query raises.
            root.destroy()

        return screen_width, screen_height

    def resize_to_screen_img(self, img_file, factor=0.5):
        """
        Resize an image to a scaled version of the current screen size.

        The width is fitted first and the height second. Each step triggers
        when that dimension is larger than the target or smaller than 30% of
        it, so this method both shrinks large images and enlarges small ones.
        Both axes are rescaled by the same ratio, which preserves the aspect
        ratio up to integer rounding.

        Parameters
        ----------
        img_file : np.ndarray
            Input image to be resized. Only ``shape[0]`` (height) and
            ``shape[1]`` (width) are inspected.

        factor : float, optional
            Scaling factor applied to the screen dimensions. Default is 0.5.

        Returns
        -------
        resized_image : np.ndarray
            Resized image, or the input itself when no resizing was needed.
        """

        screen_width, screen_height = self.get_screan()

        target_width = int(screen_width * factor)
        target_height = int(screen_height * factor)

        h, w = int(img_file.shape[0]), int(img_file.shape[1])

        if target_width < w or target_width * 0.3 > w:
            ratio = target_width / w
            # cv2.resize takes dsize as (width, height); max(1, ...) avoids a
            # zero-sized target for images with an extreme aspect ratio.
            img_file = cv2.resize(
                img_file, (max(1, int(ratio * w)), max(1, int(ratio * h)))
            )
            h, w = img_file.shape[0], img_file.shape[1]

        if target_height < h or target_height * 0.3 > h:
            ratio = target_height / h
            img_file = cv2.resize(
                img_file, (max(1, int(ratio * w)), max(1, int(ratio * h)))
            )

        return img_file

    def load_JIMG_project(self, project_path):
        """
        Load a JIMG project from a ``.pjm`` file.

        Parameters
        ----------
        project_path : str or os.PathLike
            Path to the project file. It must end with ``.pjm``.

        Returns
        -------
        project : object or None
            The unpickled project object. Returns ``None`` when the path does
            not end with ``.pjm``; in that case a message is printed instead
            of raising.
        """

        path = os.fspath(project_path)

        # Check the real extension rather than a substring, so names such as
        # "backup.pjm.old" are not treated as project files.
        if not path.endswith(".pjm"):
            print("\nProvide path to the project metadata file with *.pjm extension!!!")
            return None

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load project files from trusted sources.
        with open(path, "rb") as file:
            return pickle.load(file)

    def ajd_mask_size(self, image, mask):
        """
        Resize a mask so that its spatial size matches a reference image.

        Parameters
        ----------
        image : np.ndarray
            Reference image.
            - 2D input: the target (width, height) is
              ``(shape[1], shape[0])``.
            - Input with more than two dimensions: the target is
              ``(shape[2], shape[1])``, i.e. axes 1 and 2 are treated as
              (height, width).

        mask : np.ndarray
            Mask image to be resized.

        Returns
        -------
        np.ndarray
            Resized mask.
        """

        if image.ndim > 2:
            # NOTE(review): assumes a (depth, height, width) stack. For an
            # (height, width, channels) colour image this picks the wrong
            # axes; confirm the expected layout with callers.
            target_size = (image.shape[2], image.shape[1])
        else:
            target_size = (image.shape[1], image.shape[0])

        # cv2.resize takes dsize as (width, height).
        return cv2.resize(mask, target_size)

    def load_image(self, path_to_image):
        """
        Load an image from the specified file path.

        Thin wrapper around the module-level ``load_image`` function.

        Parameters
        ----------
        path_to_image : str
            Path to the image file.

        Returns
        -------
        image : np.ndarray
            Image as returned by the module-level ``load_image``.
        """

        # Inside a method body this name resolves to the module-level
        # function, not to this method.
        return load_image(path_to_image)

    def load_3D_tiff(self, path_to_image):
        """
        Load a 3D image from a TIFF file.

        Thin wrapper around the module-level ``load_tiff`` function.

        Parameters
        ----------
        path_to_image : str
            Path to the 3D TIFF image file.

        Returns
        -------
        image : np.ndarray
            Image as returned by ``load_tiff``.
        """

        return load_tiff(path_to_image)

    def load_mask(self, path_to_mask):
        """
        Load a mask image as a single-channel (grayscale) array.

        Parameters
        ----------
        path_to_mask : str
            Path to the mask image file.

        Returns
        -------
        mask : np.ndarray or None
            Loaded mask. OpenCV returns ``None`` when the file cannot be read.
        """

        return cv2.imread(path_to_mask, cv2.IMREAD_GRAYSCALE)

    def save(self, image, file_name):
        """
        Save an image to disk with OpenCV.

        Parameters
        ----------
        image : np.ndarray
            Image data to be saved.

        file_name : str
            Output file path including the desired image extension
            (e.g. ".png", ".jpg"). The extension selects the encoder.

        Returns
        -------
        bool
            Success flag reported by ``cv2.imwrite``. Previously this flag
            was discarded, so a failed write went unnoticed.
        """

        return cv2.imwrite(filename=file_name, img=image)

    # calculation methods

    def drop_dict(self, dictionary, key, var, action=None):
        """
        Drop entries from parallel lists in a dictionary based on one key.

        Each element of ``dictionary[key]`` is compared with ``var``. Array
        elements are reduced to their mean first. An element is dropped when
        ``var <action> element`` is true; for example, with ``action='<='``
        an element is dropped when ``var <= element``. The same positions are
        then removed from every key of the dictionary.

        Parameters
        ----------
        dictionary : dict
            Dictionary of equal-length lists or arrays. It is deep-copied and
            not modified.

        key : str
            Key whose values are compared against ``var``.

        var : numeric
            Value compared against each element of ``dictionary[key]``.

        action : str, optional
            Comparison operator: '<=', '>=', '==', '<' or '>'. Any other
            value, including the default ``None``, is invalid. When
            ``dictionary[key]`` is non-empty, an invalid action prints a
            message and returns ``None``; no exception is raised.

        Returns
        -------
        dict or None
            A filtered deep copy of ``dictionary``, with every value turned
            into a list. Returns ``None`` for an invalid action.
        """

        comparisons = {
            "<=": lambda left, right: left <= right,
            ">=": lambda left, right: left >= right,
            "==": lambda left, right: left == right,
            "<": lambda left, right: left < right,
            ">": lambda left, right: left > right,
        }
        compare = comparisons.get(action)

        dictionary = copy.deepcopy(dictionary)
        # A set gives O(1) membership tests in the filtering pass below.
        indices_to_drop = set()
        for i, dr in enumerate(dictionary[key]):
            if compare is None:
                print("\nWrong action!")
                return None

            if isinstance(dr, np.ndarray):
                dr = np.mean(dr)

            if compare(var, dr):
                indices_to_drop.add(i)

        # Assigning to existing keys (no insertions) is safe while iterating,
        # and it keeps the copy's mapping type (e.g. OrderedDict).
        for name, values in dictionary.items():
            dictionary[name] = [
                v for i, v in enumerate(values) if i not in indices_to_drop
            ]

        return dictionary

    # modified for gradation (separation near nucleus)

    def create_mask(self, dictionary, image):
        """
        Create a uint16 mask with a distinct, gradated intensity per object.

        Each coordinate array in ``dictionary['coords']`` is painted with its
        own value. Values are evenly spaced above ``floor(65535 / 4)``, and
        the last object receives 65535. Later objects overwrite earlier ones
        where they overlap.

        Parameters
        ----------
        dictionary : dict
            Must contain a 'coords' key holding a list of integer arrays of
            shape (k, 2) with (row, column) pixel coordinates.

        image : np.ndarray
            Image whose shape defines the shape of the mask.

        Returns
        -------
        np.ndarray
            Mask of dtype uint16. All zeros if there are no objects.
        """

        image_mask = np.zeros(image.shape)

        # The coordinate arrays are only read (used as indices), so no copy
        # is needed.
        arrays_list = dictionary["coords"]
        n_objects = len(arrays_list)

        if n_objects > 0:
            full_scale = 2**16 - 1
            initial_val = math.floor(full_scale / 4)
            step = (full_scale - initial_val) / n_objects

            for gradation, arr in enumerate(arrays_list, start=1):
                image_mask[arr[:, 0], arr[:, 1]] = initial_val + (gradation * step)

        return image_mask.astype("uint16")

    def min_max_histograme(self, image):
        """
        Summarise an image's intensity distribution at 5% quantile steps.

        For q = 0, 5, ..., 95, the method records:
        - the q-th quantile value of the image;
        - the fraction of all pixels strictly below that value.

        Parameters
        ----------
        image : np.ndarray
            Input image. All elements (every channel / slice) are used.

        Returns
        -------
        min_val : float
            The first non-zero pixel fraction, in ascending-q order, among
            rows whose quantile value is non-zero. 0 if there is none.

        max_val : float
            Pixel fraction of the first row, scanning from high to low q
            (index > 1), whose value times 1.5 exceeds the previous row's
            value. Otherwise the fraction of the last row; 0 if no rows
            remain.

        df : pd.DataFrame
            Columns 'q', 'val' and 'perc'. Only rows with perc > 0 are kept,
            sorted by descending q with a fresh index.
        """

        levels = list(range(0, 100, 5))
        # Use the total element count so 'perc' stays a fraction in [0, 1]
        # for multichannel / 3D inputs as well.
        total = image.size

        val = [np.quantile(image, n / 100) for n in levels]
        perc = [np.sum(image < v) / total for v in val]

        df = pd.DataFrame({"q": levels, "val": val, "perc": perc})

        # While min_val is still 0, every row with a non-zero value overwrites
        # it. The result is therefore the first non-zero perc among such rows.
        min_val = 0
        for i in df.index:
            if df["val"][i] != 0 and min_val == 0:
                min_val = df["perc"][i]

        max_val = 0
        df = df[df["perc"] > 0]
        df = df.sort_values("q", ascending=False).reset_index(drop=True)

        # NOTE(review): `i > 1` never compares row 1 against row 0; confirm
        # whether `i > 0` was intended.
        for i in df.index:
            if i > 1 and df["val"][i] * 1.5 > df["val"][i - 1]:
                max_val = df["perc"][i]
                break
            elif i == len(df.index) - 1:
                max_val = df["perc"][i]

        return min_val, max_val, df

Collection of helper methods for loading, saving, resizing and analysing images and masks.

The class defines no __init__ and stores no instance attributes. Its methods cover screen-size lookup and screen-fitted resizing, loading of images, 3D TIFF stacks, masks and pickled .pjm project files, mask resizing, dictionary filtering, gradated mask creation, and a quantile-based intensity summary. These tools are used internally by ImagesManagement but can also be applied independently for generic image-processing workflows.

Notes

All methods operate on NumPy arrays and expect images in standard OpenCV format. Some functions assume 16-bit images unless stated otherwise.

Examples

>>> from jimg_int.utils import ImageTools
>>> tools = ImageTools()
>>> mask = tools.ajd_mask_size(image, tools.load_mask("mask.png"))
def get_screan(self):
def get_screan(self):
    """
    Return the current screen size.

    A temporary Tk root window is created only to query the screen
    dimensions and is always destroyed afterwards.

    Returns
    -------
    screen_width : int
        Width of the screen in pixels.

    screen_height : int
        Height of the screen in pixels.
    """

    root = tk.Tk()
    try:
        screen_width = root.winfo_screenwidth()
        screen_height = root.winfo_screenheight()
    finally:
        # Release the Tk root even if a query raises.
        root.destroy()

    return screen_width, screen_height

Return the current screen size.

Returns

screen_width : int Width of the screen in pixels.

screen_height : int Height of the screen in pixels.

def resize_to_screen_img(self, img_file, factor=0.5):
def resize_to_screen_img(self, img_file, factor=0.5):
    """
    Resize an image to a scaled version of the current screen size.

    The width is fitted first and the height second. Each step triggers when
    that dimension is larger than the target or smaller than 30% of it, so
    this method both shrinks large images and enlarges small ones. Both axes
    are rescaled by the same ratio, which preserves the aspect ratio up to
    integer rounding.

    Parameters
    ----------
    img_file : np.ndarray
        Input image to be resized. Only ``shape[0]`` (height) and
        ``shape[1]`` (width) are inspected.

    factor : float, optional
        Scaling factor applied to the screen dimensions. Default is 0.5.

    Returns
    -------
    resized_image : np.ndarray
        Resized image, or the input itself when no resizing was needed.
    """

    screen_width, screen_height = self.get_screan()

    target_width = int(screen_width * factor)
    target_height = int(screen_height * factor)

    h, w = int(img_file.shape[0]), int(img_file.shape[1])

    if target_width < w or target_width * 0.3 > w:
        ratio = target_width / w
        # cv2.resize takes dsize as (width, height); max(1, ...) avoids a
        # zero-sized target for images with an extreme aspect ratio.
        img_file = cv2.resize(
            img_file, (max(1, int(ratio * w)), max(1, int(ratio * h)))
        )
        h, w = img_file.shape[0], img_file.shape[1]

    if target_height < h or target_height * 0.3 > h:
        ratio = target_height / h
        img_file = cv2.resize(
            img_file, (max(1, int(ratio * w)), max(1, int(ratio * h)))
        )

    return img_file

Resize an input image to a scaled version of the current screen size.

Parameters

img_file : np.ndarray Input image to be resized.

factor : float, optional Scaling factor applied to the screen dimensions. Default is 0.5.

Returns

resized_image : np.ndarray Resized image adjusted to the scaled screen size.

def load_JIMG_project(self, project_path):
def load_JIMG_project(self, project_path):
    """
    Load a JIMG project from a ``.pjm`` file.

    Parameters
    ----------
    project_path : str or os.PathLike
        Path to the project file. It must end with ``.pjm``.

    Returns
    -------
    project : object or None
        The unpickled project object. Returns ``None`` when the path does not
        end with ``.pjm``; in that case a message is printed instead of
        raising.
    """

    path = os.fspath(project_path)

    # Check the real extension rather than a substring, so names such as
    # "backup.pjm.old" are not treated as project files.
    if not path.endswith(".pjm"):
        print("\nProvide path to the project metadata file with *.pjm extension!!!")
        return None

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load project files from trusted sources.
    with open(path, "rb") as file:
        return pickle.load(file)

Load a JIMG project from a .pjm file.

Parameters

project_path : str Path to the project file with .pjm extension.

Returns

project : object Loaded project object.

Notes

If the path is not a .pjm file, a message is printed and None is returned; no exception is raised.

def ajd_mask_size(self, image, mask):
def ajd_mask_size(self, image, mask):
    """
    Resize a mask so that its spatial size matches a reference image.

    Parameters
    ----------
    image : np.ndarray
        Reference image.
        - 2D input: the target (width, height) is ``(shape[1], shape[0])``.
        - Input with more than two dimensions: the target is
          ``(shape[2], shape[1])``, i.e. axes 1 and 2 are treated as
          (height, width).

    mask : np.ndarray
        Mask image to be resized.

    Returns
    -------
    np.ndarray
        Resized mask.
    """

    # An explicit dimensionality check replaces the original bare
    # `except:`, which also hid unrelated errors.
    if image.ndim > 2:
        # NOTE(review): assumes a (depth, height, width) stack. For an
        # (height, width, channels) colour image this picks the wrong axes.
        target_size = (image.shape[2], image.shape[1])
    else:
        target_size = (image.shape[1], image.shape[0])

    # cv2.resize takes dsize as (width, height).
    return cv2.resize(mask, target_size)

Adjusts the size of a mask to match the dimensions of a given image.

Parameters

image : np.ndarray Reference image whose size the mask should match. Can be 2D or 3D.

mask : np.ndarray Mask image to be resized.

Returns

np.ndarray Resized mask matching the input image dimensions.

def load_image(self, path_to_image):
def load_image(self, path_to_image):
    """
    Read an image file into memory.

    This is a thin delegate to the module-level ``load_image`` helper.

    Parameters
    ----------
    path_to_image : str
        Location of the image file.

    Returns
    -------
    image : np.ndarray
        Whatever the module-level ``load_image`` returns for this path.
    """
    return load_image(path_to_image)

Load an image from the specified file path.

Parameters

path_to_image : str Path to the image file.

Returns

image : np.ndarray Loaded image as a NumPy array.

def load_3D_tiff(self, path_to_image):
def load_3D_tiff(self, path_to_image):
    """
    Read a 3D (multi-page) TIFF file.

    This is a thin delegate to the module-level ``load_tiff`` helper.

    Parameters
    ----------
    path_to_image : str
        Location of the TIFF file.

    Returns
    -------
    image : np.ndarray
        Whatever ``load_tiff`` returns for this path.
    """
    return load_tiff(path_to_image)

Load a 3D image from a TIFF file.

Parameters

path_to_image : str Path to the 3D TIFF image file.

Returns

image : np.ndarray Loaded 3D image as a NumPy array.

def load_mask(self, path_to_mask):
def load_mask(self, path_to_mask):
    """
    Read a mask image as a single-channel (grayscale) array.

    Parameters
    ----------
    path_to_mask : str
        Location of the mask image file.

    Returns
    -------
    mask : np.ndarray or None
        Grayscale mask. OpenCV returns ``None`` when the file cannot be read.
    """
    return cv2.imread(path_to_mask, cv2.IMREAD_GRAYSCALE)

Load a mask image.

Parameters

path_to_mask : str Path to the mask image file.

Returns

mask : np.ndarray Loaded mask image as a NumPy array.

def save(self, image, file_name):
def save(self, image, file_name):
    """
    Save an image to disk with OpenCV.

    Parameters
    ----------
    image : np.ndarray
        Image data to be saved.

    file_name : str
        Output file path including the desired image extension
        (e.g. ".png", ".jpg"). The extension selects the encoder.

    Returns
    -------
    bool
        Success flag reported by ``cv2.imwrite``. It was previously
        discarded, so a failed write went unnoticed.
    """

    return cv2.imwrite(filename=file_name, img=image)

Save an image to disk.

Parameters

image : np.ndarray Image data to be saved.

file_name : str Output file path including the desired image extension (e.g., ".png", ".jpg").

def drop_dict(self, dictionary, key, var, action=None):
def drop_dict(self, dictionary, key, var, action=None):
    """
    Drop entries from parallel lists in a dictionary based on one key.

    Each element of ``dictionary[key]`` is compared with ``var``. Array
    elements are reduced to their mean first. An element is dropped when
    ``var <action> element`` is true; for example, with ``action='<='`` an
    element is dropped when ``var <= element``. The same positions are then
    removed from every key of the dictionary.

    Parameters
    ----------
    dictionary : dict
        Dictionary of equal-length lists or arrays. It is deep-copied and not
        modified.

    key : str
        Key whose values are compared against ``var``.

    var : numeric
        Value compared against each element of ``dictionary[key]``.

    action : str, optional
        Comparison operator: '<=', '>=', '==', '<' or '>'. Any other value,
        including the default ``None``, is invalid. When ``dictionary[key]``
        is non-empty, an invalid action prints a message and returns
        ``None``; no exception is raised.

    Returns
    -------
    dict or None
        A filtered deep copy of ``dictionary``, with every value turned into
        a list. Returns ``None`` for an invalid action.
    """

    comparisons = {
        "<=": lambda left, right: left <= right,
        ">=": lambda left, right: left >= right,
        "==": lambda left, right: left == right,
        "<": lambda left, right: left < right,
        ">": lambda left, right: left > right,
    }
    compare = comparisons.get(action)

    dictionary = copy.deepcopy(dictionary)
    # A set gives O(1) membership tests in the filtering pass below.
    indices_to_drop = set()
    for i, dr in enumerate(dictionary[key]):
        if compare is None:
            print("\nWrong action!")
            return None

        if isinstance(dr, np.ndarray):
            dr = np.mean(dr)

        if compare(var, dr):
            indices_to_drop.add(i)

    # Assigning to existing keys (no insertions) is safe while iterating, and
    # it keeps the copy's mapping type (e.g. OrderedDict).
    for name, values in dictionary.items():
        dictionary[name] = [
            v for i, v in enumerate(values) if i not in indices_to_drop
        ]

    return dictionary

Filters elements from a dictionary based on a condition applied to a specified key.

Parameters

dictionary : dict Dictionary containing lists or arrays under each key.

key : str The key in the dictionary on which the filtering condition will be applied.

var : numeric Value to compare against each element in dictionary[key].

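action : str, optional Comparison operator as string: '<=', '>=', '==', '<', '>'. An element is dropped when var <action> element is true (array elements are reduced to their mean first). Default is None, which is treated as invalid: when dictionary[key] is non-empty, a message is printed and None is returned instead of raising an error.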

Returns

dict A new dictionary with elements removed where the condition matches.

def create_mask(self, dictionary, image):
def create_mask(self, dictionary, image):
    """
    Create a uint16 mask with a distinct, gradated intensity per object.

    Each coordinate array in ``dictionary['coords']`` is painted with its own
    value. Values are evenly spaced above ``floor(65535 / 4)``, and the last
    object receives 65535. Later objects overwrite earlier ones where they
    overlap.

    Parameters
    ----------
    dictionary : dict
        Must contain a 'coords' key holding a list of integer arrays of shape
        (k, 2) with (row, column) pixel coordinates.

    image : np.ndarray
        Image whose shape defines the shape of the mask.

    Returns
    -------
    np.ndarray
        Mask of dtype uint16. All zeros if there are no objects.
    """

    image_mask = np.zeros(image.shape)

    # The coordinate arrays are only read (used as indices), so the
    # original deep copy was unnecessary work.
    arrays_list = dictionary["coords"]
    n_objects = len(arrays_list)

    if n_objects > 0:
        full_scale = 2**16 - 1
        initial_val = math.floor(full_scale / 4)
        step = (full_scale - initial_val) / n_objects

        for gradation, arr in enumerate(arrays_list, start=1):
            image_mask[arr[:, 0], arr[:, 1]] = initial_val + (gradation * step)

    return image_mask.astype("uint16")

Creates a mask image with gradated intensity values for each coordinate set in a dictionary.

Parameters

dictionary : dict Dictionary containing a 'coords' key with a list of arrays of coordinates.

image : np.ndarray Base image to define the shape of the mask.

Returns

np.ndarray Mask image with uint16 gradated intensity.

def min_max_histograme(self, image):
def min_max_histograme(self, image):
    """
    Summarise an image's intensity distribution at 5% quantile steps.

    For q = 0, 5, ..., 95, the method records:
    - the q-th quantile value of the image;
    - the fraction of all pixels strictly below that value.

    Parameters
    ----------
    image : np.ndarray
        Input image. All elements (every channel / slice) are used.

    Returns
    -------
    min_val : float
        The first non-zero pixel fraction, in ascending-q order, among rows
        whose quantile value is non-zero. 0 if there is none.

    max_val : float
        Pixel fraction of the first row, scanning from high to low q
        (index > 1), whose value times 1.5 exceeds the previous row's value.
        Otherwise the fraction of the last row; 0 if no rows remain.

    df : pd.DataFrame
        Columns 'q', 'val' and 'perc'. Only rows with perc > 0 are kept,
        sorted by descending q with a fresh index.
    """

    levels = list(range(0, 100, 5))
    # Use the total element count so 'perc' stays a fraction in [0, 1] for
    # multichannel / 3D inputs as well (shape[0] * shape[1] did not).
    total = image.size

    # Each quantile is computed once and reused for the pixel count.
    val = [np.quantile(image, n / 100) for n in levels]
    perc = [np.sum(image < v) / total for v in val]

    df = pd.DataFrame({"q": levels, "val": val, "perc": perc})

    # While min_val is still 0, every row with a non-zero value overwrites
    # it. The result is therefore the first non-zero perc among such rows.
    min_val = 0
    for i in df.index:
        if df["val"][i] != 0 and min_val == 0:
            min_val = df["perc"][i]

    max_val = 0
    df = df[df["perc"] > 0]
    df = df.sort_values("q", ascending=False).reset_index(drop=True)

    # NOTE(review): `i > 1` never compares row 1 against row 0; confirm
    # whether `i > 0` was intended.
    for i in df.index:
        if i > 1 and df["val"][i] * 1.5 > df["val"][i - 1]:
            max_val = df["perc"][i]
            break
        elif i == len(df.index) - 1:
            max_val = df["perc"][i]

    return min_val, max_val, df
1570        return min_val, max_val, df

Calculates histogram-based minimum and maximum intensity percentiles in an image.

Parameters

image : np.ndarray Input image for histogram analysis.

Returns

min_val : float Minimum intensity percentile above zero.

max_val : float Maximum intensity percentile based on histogram gradient.

df : pd.DataFrame DataFrame containing quantiles, corresponding intensity values, and cumulative percents.