jdti.utils

import os
import re

import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
from adjustText import adjust_text
from joblib import Parallel, delayed
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.io import mmread
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm


def load_sparse(path: str, name: str):
    """
    Load a sparse matrix dataset along with associated gene
    and cell metadata, and return it as a dense DataFrame.

    This function expects the input directory to contain three files in standard
    10x Genomics format:
      - "matrix.mtx": the gene expression matrix in Matrix Market format
      - "genes.tsv": tab-separated file containing gene identifiers
      - "barcodes.tsv": tab-separated file containing cell barcodes / names

    Parameters
    ----------
    path : str
        Path to the directory containing the matrix and annotation files.

    name : str
        Label or dataset identifier to be assigned to all cells in the metadata.

    Returns
    -------
    data : pandas.DataFrame
        Dense expression matrix where rows correspond to genes and columns
        correspond to cells.
    metadata : pandas.DataFrame
        Metadata DataFrame with two columns:
          - "cell_names": the names of the cells (barcodes)
          - "sets": the dataset label assigned to each cell (from `name`)

    Notes
    -----
    The function converts the sparse matrix into a dense DataFrame. This may
    require a large amount of memory for datasets with many cells and genes.
    """

    data = mmread(os.path.join(path, "matrix.mtx"))
    data = pd.DataFrame(data.todense())
    genes = pd.read_csv(os.path.join(path, "genes.tsv"), header=None, sep="\t")
    names = pd.read_csv(os.path.join(path, "barcodes.tsv"), header=None, sep="\t")
    data.columns = [str(x) for x in names[0]]
    data.index = list(genes[0])

    names = list(data.columns)
    sets = [name] * len(names)

    metadata = pd.DataFrame({"cell_names": names, "sets": sets})

    return data, metadata


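# Usage sketch (illustrative only, not part of the library API): load a
# 10x-style directory and inspect the result. The directory name
# "filtered_matrices" and the label "sample_1" are hypothetical placeholders.
def _demo_load_sparse():
    data, metadata = load_sparse("filtered_matrices", name="sample_1")
    print(data.shape)       # (n_genes, n_cells)
    print(metadata.head())  # columns: "cell_names", "sets"
    return data, metadata

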
def volcano_plot(
    deg_data: pd.DataFrame,
    p_adj: bool = True,
    top: int = 25,
    p_val: float | int = 0.05,
    lfc: float | int = 0.25,
    standard_scale: bool = False,
    rescale_adj: bool = True,
    image_width: int = 12,
    image_high: int = 12,
):
    """
    Generate a volcano plot from differential expression results.

    A volcano plot visualizes the relationship between statistical significance
    (p-values or standardized p-values) and log(fold change) for each gene,
    highlighting genes that pass significance thresholds.

    Parameters
    ----------
    deg_data : pandas.DataFrame
        DataFrame containing differential expression results from the calc_DEG() function.

    p_adj : bool, default=True
        If True, use adjusted p-values. If False, use raw p-values.

    top : int, default=25
        Number of top significant genes to highlight on the plot.

    p_val : float | int, default=0.05
        Significance threshold for p-values (or adjusted p-values).

    lfc : float | int, default=0.25
        Threshold for absolute log fold change.

    standard_scale : bool, default=False
        If True, use standardized p-values for visualization instead of -log10(p-values).

    rescale_adj : bool, default=True
        If True, rescale p-values to avoid long breaks caused by outlier values.

    image_width : int, default=12
        Width of the generated plot in inches.

    image_high : int, default=12
        Height of the generated plot in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The generated volcano plot figure.

    """

    if p_adj:
        pv = "adj_pval"
    else:
        pv = "p_val"

    deg_df = deg_data.copy()

    if standard_scale:

        p_val_scale = "scale(p-val)"

        scaler = MinMaxScaler(feature_range=(0, 1000))

        tmp_p = deg_df[
            ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
            | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
        ].copy()

        tmp_p[p_val_scale] = (-np.log10(tmp_p[pv])).astype(float)

        tmp_p[p_val_scale] = scaler.fit_transform(
            tmp_p[p_val_scale].values.reshape(-1, 1)
        ).astype(float)

        median_shift = abs(tmp_p[p_val_scale].diff().max()) * 25
        cur_max = tmp_p[p_val_scale].max()

        zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
        zero_p_plus = zero_p_plus.sort_values(by="log(FC)", ascending=True).reset_index(
            drop=True
        )

        zero_p_plus[p_val_scale] = [
            (cur_max + (x * median_shift)) for x in range(1, len(zero_p_plus.index) + 1)
        ]
        zero_p_plus[p_val_scale] = zero_p_plus[p_val_scale].astype(float)

        zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
        zero_p_minus = zero_p_minus.sort_values(
            by="log(FC)", ascending=False
        ).reset_index(drop=True)

        zero_p_minus[p_val_scale] = [
            (cur_max + (x * median_shift))
            for x in range(1, len(zero_p_minus.index) + 1)
        ]
        zero_p_minus[p_val_scale] = zero_p_minus[p_val_scale].astype(float)

        deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)

    else:

        shift = 0.25

        p_val_scale = "-log(p_val)"

        min_minus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)])
        min_plus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)])

        zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
        zero_p_plus = zero_p_plus.sort_values(
            by="log(FC)", ascending=False
        ).reset_index(drop=True)
        zero_p_plus[pv] = [
            (shift * x) * min_plus for x in range(1, len(zero_p_plus.index) + 1)
        ]

        zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
        zero_p_minus = zero_p_minus.sort_values(
            by="log(FC)", ascending=True
        ).reset_index(drop=True)
        zero_p_minus[pv] = [
            (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1)
        ]

        tmp_p = deg_df[
            ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
            | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
        ]

        del deg_df

        deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)

        deg_df[p_val_scale] = -np.log10(deg_df[pv])

    deg_df["top100"] = None

    if rescale_adj:

        deg_df = deg_df.sort_values(by=p_val_scale, ascending=False)

        deg_df = deg_df.reset_index(drop=True)

        doubled = []
        ratio = []
        for n in range(len(deg_df.index)):
            for j in range(1, 6):
                if (
                    n + j < len(deg_df.index)
                    and deg_df[p_val_scale][n] / deg_df[p_val_scale][n + j] >= 2
                ):
                    doubled.append(n)
                    ratio.append(deg_df[p_val_scale][n + j] / deg_df[p_val_scale][n])

        df = pd.DataFrame({"doubled": doubled, "ratio": ratio})
        df = df[df["doubled"] < 100]

        df["ratio"] = (1 - df["ratio"]) / 5
        df = df.reset_index(drop=True)

        df = df.sort_values("doubled")

        doubled2 = []

        if not (len(df["doubled"]) == 1 and 0 in df["doubled"].values):

            for l in df["doubled"]:
                if l + 1 != len(doubled):
                    doubled2.append(l)
                    doubled2.append(l + 1)
                else:
                    break

            doubled2 = sorted(set(doubled2), reverse=True)

        if len(doubled2) > 1:
            df = df[df["doubled"].isin(doubled2)]
            df = df.sort_values("doubled", ascending=False)
            df = df.reset_index(drop=True)
            for c in df.index:
                deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[
                    df["doubled"][c] + 1, p_val_scale
                ] * (1 + df["ratio"][c])

    deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df[pv] <= p_val), "top100"] = "red"
    deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df[pv] <= p_val), "top100"] = "blue"
    deg_df.loc[deg_df[pv] > p_val, "top100"] = "lightgray"

    if lfc > 0:
        deg_df.loc[
            (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100"
        ] = "lightgray"

    down_int = len(
        deg_df["top100"][(deg_df["log(FC)"] <= lfc * -1) & (deg_df[pv] <= p_val)]
    )
    up_int = len(
        deg_df["top100"][(deg_df["log(FC)"] > lfc) & (deg_df[pv] <= p_val)]
    )

    deg_df_up = deg_df[deg_df["log(FC)"] > 0]
    deg_df_up = deg_df_up.sort_values([pv, "log(FC)"], ascending=[True, False])
    deg_df_up = deg_df_up.reset_index(drop=True)

    n = -1
    l = 0
    while True:
        n += 1
        if n >= len(deg_df_up.index):
            break
        if deg_df_up["log(FC)"][n] > lfc:
            deg_df_up.loc[n, "top100"] = "green"
            l += 1
        if l == top or deg_df_up[pv][n] > p_val:
            break

    deg_df_down = deg_df[deg_df["log(FC)"] <= 0]
    deg_df_down = deg_df_down.sort_values([pv, "log(FC)"], ascending=[True, True])
    deg_df_down = deg_df_down.reset_index(drop=True)

    n = -1
    l = 0
    while True:
        n += 1
        if n >= len(deg_df_down.index):
            break
        if deg_df_down["log(FC)"][n] < lfc * -1:
            deg_df_down.loc[n, "top100"] = "yellow"
            l += 1
        if l == top or deg_df_down[pv][n] > p_val:
            break

    deg_df = pd.concat([deg_df_up, deg_df_down])

    que = ["lightgray", "red", "blue", "yellow", "green"]

    deg_df = deg_df.sort_values(
        by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)})
    )

    deg_df = deg_df.reset_index(drop=True)

    fig, ax = plt.subplots(figsize=(image_width, image_high))

    plt.scatter(
        x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2
    )

    plt.plot(
        [max(deg_df["log(FC)"]) * -1.1, max(deg_df["log(FC)"]) * 1.1],
        [-np.log10(p_val), -np.log10(p_val)],
        linestyle="--",
        linewidth=3,
        color="lightgray",
        zorder=1,
    )

    if lfc > 0:
        plt.plot(
            [lfc * -1, lfc * -1],
            [-3, max(deg_df[p_val_scale]) * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )
        plt.plot(
            [lfc, lfc],
            [-3, max(deg_df[p_val_scale]) * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )

    plt.xlabel("log(FC)")
    plt.ylabel(p_val_scale)
    plt.title("Volcano plot")

    plt.ylim(min(deg_df[p_val_scale]) - 5, max(deg_df[p_val_scale]) * 1.25)

    texts = [
        ax.text(deg_df["log(FC)"][i], deg_df[p_val_scale][i], deg_df["feature"][i])
        for i in deg_df.index
        if deg_df["top100"][i] in ["green", "yellow"]
    ]

    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

    legend_elements = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="top-upregulated",
            markerfacecolor="green",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="top-downregulated",
            markerfacecolor="yellow",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="upregulated",
            markerfacecolor="blue",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="downregulated",
            markerfacecolor="red",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="non-significant",
            markerfacecolor="lightgray",
            markersize=10,
        ),
    ]

    ax.legend(handles=legend_elements, loc="upper right")
    ax.grid(visible=False)

    ax.annotate(
        "\nmin p-value = " + str(p_val),
        xy=(0.025, 0.975),
        xycoords="axes fraction",
        fontsize=12,
    )

    if lfc > 0:
        ax.annotate(
            "\nmin log(FC) = " + str(lfc),
            xy=(0.025, 0.95),
            xycoords="axes fraction",
            fontsize=12,
        )

    ax.annotate(
        "\nDownregulated: " + str(down_int),
        xy=(0.025, 0.925),
        xycoords="axes fraction",
        fontsize=12,
        color="red",
    )

    ax.annotate(
        "\nUpregulated: " + str(up_int),
        xy=(0.025, 0.9),
        xycoords="axes fraction",
        fontsize=12,
        color="blue",
    )

    plt.show()

    return fig


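# Usage sketch (illustrative): render a volcano plot from a calc_DEG() results
# DataFrame. `deg` below stands for any DataFrame with the columns "feature",
# "p_val", "adj_pval", and "log(FC)" that calc_DEG() produces; the output
# filename is a hypothetical placeholder.
def _demo_volcano_plot(deg):
    fig = volcano_plot(deg, p_adj=True, top=10, p_val=0.05, lfc=0.25)
    fig.savefig("volcano.png", dpi=300, bbox_inches="tight")

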
def find_features(data: pd.DataFrame, features: list):
    """
    Identify features (rows) from a DataFrame that match a given list of features,
    ignoring case sensitivity.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with features in the index (rows).

    features : list
        List of feature names to search for.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of features found in the DataFrame index.
        - "not_included": list of requested features not found in the DataFrame index.
        - "potential": list of features in the DataFrame that may be similar.
    """

    features_upper = [str(x).upper() for x in features]

    index_set = set(data.index)

    features_in = [x for x in index_set if x.upper() in features_upper]
    features_in_upper = [x.upper() for x in features_in]
    features_out_upper = [x for x in features_upper if x not in features_in_upper]
    features_out = [x for x in features if x.upper() in features_out_upper]
    similar_features = [
        idx for idx in index_set if any(x in idx.upper() for x in features_out_upper)
    ]

    return {
        "included": features_in,
        "not_included": features_out,
        "potential": similar_features,
    }


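# Usage sketch (illustrative): case-insensitive lookup of gene symbols in a
# matrix index; the gene names below are arbitrary examples.
def _demo_find_features(data):
    hits = find_features(data, features=["CD3E", "cd8a", "NotAGene"])
    print(hits["included"])      # names present in data.index
    print(hits["not_included"])  # requested names that were not found
    print(hits["potential"])     # index entries containing a missing name

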
def find_names(data: pd.DataFrame, names: list):
    """
    Identify names (columns) from a DataFrame that match a given list of names,
    ignoring case sensitivity.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with names in the columns.

    names : list
        List of names to search for.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of names found in the DataFrame columns.
        - "not_included": list of requested names not found in the DataFrame columns.
        - "potential": list of names in the DataFrame that may be similar.
    """

    names_upper = [str(x).upper() for x in names]

    columns = set(data.columns)

    names_in = [x for x in columns if x.upper() in names_upper]
    names_in_upper = [x.upper() for x in names_in]
    names_out_upper = [x for x in names_upper if x not in names_in_upper]
    names_out = [x for x in names if x.upper() in names_out_upper]
    similar_names = [
        idx for idx in columns if any(x in idx.upper() for x in names_out_upper)
    ]

    return {"included": names_in, "not_included": names_out, "potential": similar_names}


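# Usage sketch (illustrative): the column-wise counterpart of find_features();
# the barcodes below are arbitrary examples.
def _demo_find_names(data):
    hits = find_names(data, names=["AAACCTG-1", "TTGCA-2"])
    print(hits["included"], hits["not_included"], hits["potential"])

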
def reduce_data(data: pd.DataFrame, features: list = [], names: list = []):
    """
    Subset a DataFrame based on selected features (rows) and/or names (columns).

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and names as columns.

    features : list
        List of features to include (rows). Default is an empty list.
        If empty, all rows are returned.

    names : list
        List of names to include (columns). Default is an empty list.
        If empty, all columns are returned.

    Returns
    -------
    pandas.DataFrame
        Subset of the input DataFrame containing only the selected rows
        and/or columns.

    Raises
    ------
    ValueError
        If both `features` and `names` are empty.
    """

    if len(features) > 0 and len(names) > 0:
        fet = find_features(data=data, features=features)

        nam = find_names(data=data, names=names)

        data_to_return = data.loc[fet["included"], nam["included"]]

    elif len(features) > 0 and len(names) == 0:
        fet = find_features(data=data, features=features)

        data_to_return = data.loc[fet["included"], :]

    elif len(features) == 0 and len(names) > 0:

        nam = find_names(data=data, names=names)

        data_to_return = data.loc[:, nam["included"]]

    else:

        raise ValueError("features and names have zero length!")

    return data_to_return


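# Usage sketch (illustrative): subset a matrix to two genes across two cells;
# passing only `features` keeps all columns, passing only `names` keeps all
# rows (the identifiers below are arbitrary examples).
def _demo_reduce_data(data):
    subset = reduce_data(data, features=["CD3E", "CD8A"], names=["cell_1", "cell_2"])
    return subset

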
def make_unique_list(lst):
    """
    Generate a list where duplicate items are renamed to ensure uniqueness.

    Each duplicate is appended with a suffix ".n", where n indicates the
    occurrence count (starting from 1).

    Parameters
    ----------
    lst : list
        Input list of items (strings or other hashable types).

    Returns
    -------
    list
        List with unique values.

    Examples
    --------
    >>> make_unique_list(["A", "B", "A", "A"])
    ['A', 'B', 'A.1', 'A.2']
    """
    seen = {}
    result = []
    for item in lst:
        if item not in seen:
            seen[item] = 0
            result.append(item)
        else:
            seen[item] += 1
            result.append(f"{item}.{seen[item]}")
    return result


def get_color_palette(variable_list, palette_name="tab10"):
    """
    Map each item in `variable_list` to a color drawn from the given
    matplotlib colormap, cycling through the colormap if there are more
    items than colors. Returns a dict of item -> RGBA color.
    """
    n = len(variable_list)
    cmap = plt.get_cmap(palette_name)
    colors = [cmap(i % cmap.N) for i in range(n)]
    return dict(zip(variable_list, colors))


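# Usage sketch (illustrative): assign repeatable colors to group labels.
def _demo_get_color_palette():
    palette = get_color_palette(["ctrl", "treated", "unknown"], palette_name="tab10")
    print(palette["ctrl"])  # RGBA tuple from the "tab10" colormap

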
def features_scatter(
    expression_data: pd.DataFrame,
    occurence_data: pd.DataFrame | None = None,
    scale: bool = False,
    features: list | None = None,
    metadata_list: list | None = None,
    colors: str = "viridis",
    hclust: str | None = "complete",
    img_width: int = 8,
    img_high: int = 5,
    label_size: int = 10,
    size_scale: int = 100,
    y_lab: str = "Genes",
    legend_lab: str = "log(CPM + 1)",
    set_box_size: float | int = 5,
    set_box_high: float | int = 5,
    bbox_to_anchor_scale: int = 25,
    bbox_to_anchor_perc: tuple = (0.91, 0.63),
    bbox_to_anchor_group: tuple = (1.01, 0.4),
):
    """
    Create a bubble scatter plot of selected features across samples.

    Each point represents a feature-sample pair, where the color encodes the
    expression value and the size encodes occurrence or relative abundance.
    Optionally, hierarchical clustering can be applied to order rows and columns.

    Parameters
    ----------
    expression_data : pandas.DataFrame
        Expression values (mean) with features as rows and samples as columns,
        derived from the average() function.

    occurence_data : pandas.DataFrame or None
        DataFrame with occurrence/frequency values (same shape as
        `expression_data`), derived from the occurrence() function.
        If None, bubble sizes are based on expression values.

    scale : bool, default=False
        If True, expression_data will be scaled (0–1) across the rows (features).

    features : list or None
        List of features (rows) to display. If None, all features are used.

    metadata_list : list or None, optional
        Metadata grouping for samples (same length as number of columns).
        Used to add group colors and separators in the plot.

    colors : str, default='viridis'
        Colormap for expression values.

    hclust : str or None, default='complete'
        Linkage method for hierarchical clustering. If None, no clustering
        is performed.

    img_width : int or float, default=8
        Width of the plot in inches.

    img_high : int or float, default=5
        Height of the plot in inches.

    label_size : int, default=10
        Font size for axis labels and ticks.

    size_scale : int or float, default=100
        Scaling factor for bubble sizes.

    y_lab : str, default='Genes'
        Label for the y-axis.

    legend_lab : str, default='log(CPM + 1)'
        Label for the colorbar legend.

    set_box_size : float | int, default=5
        Scaling factor for the height of the group color boxes drawn above
        the plot.

    set_box_high : float | int, default=5
        Scaling factor for the vertical position of the group color boxes.

    bbox_to_anchor_scale : int, default=25
        Height of the colorbar, as a percentage of the axes height.

    bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
        Anchor position for the size legend (percent bubble legend).

    bbox_to_anchor_group : tuple, default=(1.01, 0.4)
        Anchor position for the group legend.

    Returns
    -------
    matplotlib.figure.Figure
        The generated scatter plot figure.

    Raises
    ------
    ValueError
        If `metadata_list` is provided but its length does not match
        the number of columns in `expression_data`.

    Notes
    -----
    - Colors represent expression values normalized to the colormap.
    - Bubble sizes represent occurrence values (or expression values if
      `occurence_data` is None).
    - If `metadata_list` is given, groups are indicated with colors and
      dashed vertical separators.
    """

    scatter_df = expression_data.copy()

    if scale:

        legend_lab = "Scaled\n" + legend_lab

        scaler = MinMaxScaler(feature_range=(0, 1))
        scatter_df = pd.DataFrame(
            scaler.fit_transform(scatter_df.T).T,
            index=scatter_df.index,
            columns=scatter_df.columns,
        )

    metadata = {}

    metadata["primary_names"] = [str(x) for x in scatter_df.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)
    if features is not None:
        scatter_df = scatter_df.loc[
            find_features(data=scatter_df, features=features)["included"],
        ]
    scatter_df.columns = metadata["primary_names"] + "#" + metadata["sets"]

    if occurence_data is not None:
        # work on a copy so the caller's DataFrame is not modified in place
        occurence_data = occurence_data.copy()
        if features is not None:
            occurence_data = occurence_data.loc[
                find_features(data=occurence_data, features=features)["included"],
            ]
        occurence_data.columns = metadata["primary_names"] + "#" + metadata["sets"]

    # check duplicated names

    tmp_columns = scatter_df.columns

    new_cols = make_unique_list(list(tmp_columns))

    scatter_df.columns = new_cols

    if occurence_data is not None:
        # keep the occurrence columns aligned with the de-duplicated names
        occurence_data.columns = new_cols

    if hclust is not None and len(scatter_df.index) > 1:

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_rows = []
        for n in order_of_features:
            sorted_list_rows.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_columns = []
        for n in order_of_features:
            sorted_list_columns.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        scatter_df = scatter_df.loc[sorted_list_rows, sorted_list_columns]

        if occurence_data is not None:
            occurence_data = occurence_data.loc[sorted_list_rows, sorted_list_columns]

        metadata["sets"] = [re.sub(".*#", "", x) for x in scatter_df.columns]

    scatter_df.columns = [re.sub("#.*", "", x) for x in scatter_df.columns]

    if occurence_data is not None:
        occurence_data.columns = [re.sub("#.*", "", x) for x in occurence_data.columns]

    fig, ax = plt.subplots(figsize=(img_width, img_high))

    norm = plt.Normalize(0, scatter_df.to_numpy().max())

    cmap = plt.get_cmap(colors)

    # Bubble scatter
    for i, _ in enumerate(scatter_df.index):
        for j, _ in enumerate(scatter_df.columns):
            if occurence_data is not None:
                value_e = scatter_df.iloc[i, j]
                value_o = occurence_data.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value_o * size_scale,
                    c=[cmap(norm(value_e))],
                    edgecolors="k",
                    linewidths=0.3,
                )
            else:
                value = scatter_df.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value * size_scale,
                    c=[cmap(norm(value))],
                    edgecolors="k",
                    linewidths=0.3,
                )

    ax.set_yticks(range(len(scatter_df.index)))
    ax.set_yticklabels(scatter_df.index, fontsize=label_size * 0.8)
    ax.set_ylabel(y_lab, fontsize=label_size)
    ax.set_xticks(range(len(scatter_df.columns)))
    ax.set_xticklabels(scatter_df.columns, fontsize=label_size * 0.8, rotation=90)

    # color bar
    cax = inset_axes(
        ax,
        width="1%",
        height=f"{bbox_to_anchor_scale}%",
        bbox_to_anchor=(1, 0, 1, 1),
        bbox_transform=ax.transAxes,
        loc="upper left",
    )
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb.set_label(legend_lab, fontsize=label_size * 0.65)
    cb.ax.tick_params(labelsize=label_size * 0.7)

    if metadata_list is not None:

        metadata_list = list(metadata["sets"])
        group_colors = get_color_palette(list(set(metadata_list)), palette_name="tab10")

        for i, group in enumerate(metadata_list):
            ax.add_patch(
                plt.Rectangle(
                    (i - 0.5, len(scatter_df.index) - 0.1 * set_box_high),
                    1,
                    0.1 * set_box_size,
                    color=group_colors[group],
                    transform=ax.transData,
                    clip_on=False,
                )
            )

        for i in range(1, len(metadata_list)):
            if metadata_list[i] != metadata_list[i - 1]:
                ax.axvline(i - 0.5, color="black", linestyle="--", lw=1)

        group_patches = [
            mpatches.Patch(color=color, label=label)
            for label, color in group_colors.items()
        ]
        fig.legend(
            handles=group_patches,
            title="Group",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_group,
            frameon=False,
        )

    # second legend (size)
    if occurence_data is not None:
        size_values = [0.25, 0.5, 1]
        legend2_handles = [
            plt.Line2D(
                [],
                [],
                marker="o",
                linestyle="",
                markersize=np.sqrt(v * size_scale * 0.5),
                color="gray",
                alpha=0.6,
                label=f"{v * 100:.1f}",
            )
            for v in size_values
        ]

        fig.legend(
            handles=legend2_handles,
            title="Percent [%]",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_perc,
            frameon=False,
        )

    _, ymax = ax.get_ylim()

    ax.set_xlim(-0.5, len(scatter_df.columns) - 0.5)
    ax.set_ylim(-0.5, ymax + 0.5)

    return fig


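# Usage sketch (illustrative): group-averaged expression plus occurrence drive
# the bubble plot; color encodes mean expression and bubble size encodes the
# fraction of samples expressing the feature. `data` is assumed to have
# duplicated column names (one per group member) so that average() and
# occurrence() aggregate them; `groups` is one label per aggregated column,
# and the gene names are arbitrary examples.
def _demo_features_scatter(data, groups):
    avg = average(data)
    occ = occurrence(data)
    fig = features_scatter(
        expression_data=avg,
        occurence_data=occ,
        features=["CD3E", "CD8A", "NKG7"],
        metadata_list=groups,
        scale=True,
    )
    return fig

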
def calc_DEG(
    data,
    metadata_list: list | None = None,
    entities: str | list | dict | None = None,
    sets: str | list | dict | None = None,
    min_exp: int | float = 0,
    min_pct: int | float = 0.1,
    n_proc: int = 10,
):
    """
    Perform differential gene expression (DEG) analysis on gene expression data.

    The function compares groups of cells or samples (defined by `entities` or
    `sets`) using the Mann–Whitney U test. It computes p-values, adjusted
    p-values, fold changes, standardized effect sizes, and other statistics.

    Parameters
    ----------
    data : pandas.DataFrame
        Expression matrix with features (e.g., genes) as rows and samples/cells
        as columns.

    metadata_list : list or None, optional
        Metadata grouping corresponding to the columns in `data`. Required for
        comparisons based on sets. Default is None.

    entities : list, str, dict, or None, optional
        Defines the comparison strategy:
        - list of sample names → compare selected cells to the rest.
        - 'All' → compare each sample/cell to all others.
        - dict → user-defined groups for pairwise comparison.
        - None → must be combined with `sets`.

    sets : str, dict, or None, optional
        Defines group-based comparisons:
        - 'All' → compare each set/group to all others.
        - dict with two groups → perform pairwise set comparison.
        - None → must be combined with `entities`.

    min_exp : float | int, default=0
        Minimum expression threshold for filtering features.

    min_pct : float | int, default=0.1
        Minimum proportion of samples within the target group that must express
        a feature for it to be tested.

    n_proc : int, default=10
        Number of parallel processes to use for statistical testing.

    Returns
    -------
    pandas.DataFrame or dict
        Results of the differential expression analysis:
        - If `entities` is a list → dict with keys: 'valid',
          'control', and 'DEG' (results DataFrame).
        - If `entities == 'All'` or `sets == 'All'` → DataFrame with results
          for all groups.
        - If pairwise comparison (dict for `entities` or `sets`) → DataFrame
          with results for the specified groups.

        The results DataFrame contains:
        - 'feature': feature name
        - 'p_val': raw p-value
        - 'adj_pval': adjusted p-value (multiple testing correction)
        - 'pct_valid': fraction of target group expressing the feature
        - 'pct_ctrl': fraction of control group expressing the feature
        - 'avg_valid': mean expression in target group
        - 'avg_ctrl': mean expression in control group
        - 'sd_valid': standard deviation in target group
        - 'sd_ctrl': standard deviation in control group
        - 'esm': effect size metric
        - 'FC': fold change
        - 'log(FC)': log2-transformed fold change
        - 'norm_diff': difference in mean expression

    Raises
    ------
    ValueError
        - If `metadata_list` is provided but its length does not match
          the number of columns in `data`.
        - If neither `entities` nor `sets` is provided.

    Notes
    -----
    - Mann–Whitney U test is used for group comparisons.
    - Multiple testing correction is applied using a simple
      Benjamini–Hochberg-like method.
    - Features expressed below `min_exp` or in fewer than `min_pct` of target
      samples are filtered out.
    - Parallelization is handled by `joblib.Parallel`.

    Examples
    --------
    Compare a selected list of cells against all others:

    >>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])

    Compare each group to others (based on metadata):

    >>> result = calc_DEG(data, metadata_list=group_labels, sets="All")

    Perform pairwise comparison between two predefined sets:

    >>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
    >>> result = calc_DEG(data, sets=sets)
    """

    metadata = {}

    metadata["primary_names"] = [str(x) for x in data.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)

    def stat_calc(choose, feature_name):
        target_values = choose.loc[choose["DEG"] == "target", feature_name]
        rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

        pct_valid = (target_values > 0).sum() / len(target_values)
        pct_rest = (rest_values > 0).sum() / len(rest_values)

        avg_valid = np.mean(target_values)
        avg_ctrl = np.mean(rest_values)
        sd_valid = np.std(target_values, ddof=1)
        sd_ctrl = np.std(rest_values, ddof=1)
        esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

        if np.sum(target_values) == np.sum(rest_values):
            p_val = 1.0
        else:
            _, p_val = stats.mannwhitneyu(
                target_values, rest_values, alternative="two-sided"
            )

        return {
            "feature": feature_name,
            "p_val": p_val,
            "pct_valid": pct_valid,
            "pct_ctrl": pct_rest,
            "avg_valid": avg_valid,
            "avg_ctrl": avg_ctrl,
            "sd_valid": sd_valid,
            "sd_ctrl": sd_ctrl,
            "esm": esm,
        }

    def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc, factors):

        tmp_dat = choose[choose["DEG"] == "target"]
        tmp_dat = tmp_dat.drop("DEG", axis=1)

        counts = (tmp_dat > min_exp).sum(axis=0)

        total_count = tmp_dat.shape[0]

        info = pd.DataFrame(
            {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
        )

        del tmp_dat

        drop_col = info["feature"][info["pct"] <= min_pct]

        if len(drop_col) + 1 == len(choose.columns):
            drop_col = info["feature"][info["pct"] == 0]

        del info

        choose = choose.drop(list(drop_col), axis=1)

        results = Parallel(n_jobs=n_proc)(
            delayed(stat_calc)(choose, feature)
            for feature in tqdm(choose.columns[choose.columns != "DEG"])
        )

        if len(results) > 0:
            df = pd.DataFrame(results)
            # `valid_group` may be a list of names (entities mode); join it so
            # the column assignment is a single scalar label
            if isinstance(valid_group, (list, tuple, set)):
                df["valid_group"] = ", ".join(map(str, valid_group))
            else:
                df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            factors_df = factors.to_frame(name="factor")

            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")

            df[["avg_valid", "avg_ctrl", "factor"]] = df[
                ["avg_valid", "avg_ctrl", "factor"]
            ].astype(float)

            df["FC"] = (df["avg_valid"] + df["factor"]) / (
                df["avg_ctrl"] + df["factor"]
            )
            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
            df = df.drop(columns=["factor"])
        else:
            columns = [
                "feature",
                "valid_group",
                "p_val",
                "adj_pval",
                "avg_valid",
                "avg_ctrl",
                "FC",
                "log(FC)",
                "norm_diff",
            ]
            df = pd.DataFrame(columns=columns)
        return df

    choose = data.T

    # pseudo-count per feature: the smallest non-zero value in each row;
    # all-zero rows (NaN minimum) fall back to the smallest non-zero value
    # overall, or 0.01 if the matrix contains no non-zero values
    factors = data.replace(0, np.nan).min(axis=1)

    if factors.notna().any():
        factors = factors.fillna(factors.min())
    else:
        factors = factors.fillna(0.01)

    final_results = []

    if isinstance(entities, list) and sets is None:
        print("\nAnalysis started...\nComparing selected cells to the whole set...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

            if "#" not in entities[0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' list. "
                    "Only the names will be compared, without considering the set information."
                )

        labels = ["target" if idx in entities else "rest" for idx in choose.index]
        valid = list(
            set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
        )

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=valid,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )

        return {"valid": valid, "control": "rest", "DEG": result_df}

    elif entities == "All" and sets is None:
        print("\nAnalysis started...\nComparing each type of cell to others...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

        unique_labels = set(choose.index)

        for label in tqdm(unique_labels):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            final_results.append(result_df)

        final_results = pd.concat(final_results, ignore_index=True)

        if metadata_list is None:
            final_results["valid_group"] = [
                re.sub(" # ", "", x) for x in final_results["valid_group"]
            ]

        return final_results

    elif entities is None and sets == "All":
        print("\nAnalysis started...\nComparing each set/group to others...")
        choose.index = metadata["sets"]
        unique_sets = set(choose.index)

        for label in tqdm(unique_sets):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            final_results.append(result_df)

        return pd.concat(final_results, ignore_index=True)

    elif entities is None and isinstance(sets, dict):
        print("\nAnalysis started...\nComparing groups...")
        choose.index = metadata["sets"]

        group_list = list(sets.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in sets[group_list[0]]
                else "rest" if idx in sets[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]
        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )
        return result_df

    elif isinstance(entities, dict) and sets is None:
        print("\nAnalysis started...\nComparing groups...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]
            if "#" not in entities[list(entities.keys())[0]][0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' dict. "
                    "Only the names will be compared, without considering the set information."
                )

        group_list = list(entities.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in entities[group_list[0]]
                else "rest" if idx in entities[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )

        return result_df.reset_index(drop=True)

    else:
        raise ValueError(
            "You must specify either 'entities' or 'sets'. None were provided, which is not allowed for this analysis."
        )


def average(data):
    """
    Compute the column-wise average of a DataFrame, aggregating by column names.

    If multiple columns share the same name, their values are averaged.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Columns with identical names
        will be aggregated by their mean.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input but with unique columns,
        where duplicate columns have been replaced by their mean values.
    """

    aggregated_df = data.T.groupby(level=0).mean().T

    return aggregated_df


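# Worked sketch: columns sharing a name are collapsed into their mean.
def _demo_average():
    df = pd.DataFrame(
        [[1, 3, 5], [2, 4, 6]],
        index=["gene_a", "gene_b"],
        columns=["g1", "g1", "g2"],
    )
    print(average(df))  # "g1" becomes the mean of the two "g1" columns

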
def occurrence(data):
    """
    Calculate the occurrence frequency of features in a DataFrame.

    Converts the input DataFrame to binary (presence/absence) and computes
    the proportion of non-zero entries for each feature, aggregating by
    column names if duplicates exist.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Rows are features; columns are
        samples, grouped by identical column names.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input, where each value represents
        the proportion of samples in which the feature is present (non-zero).
        Columns with identical names are aggregated.
    """

    binary_data = (data > 0).astype(int)

    counts = binary_data.columns.value_counts()

    binary_data = binary_data.T.groupby(level=0).sum().T
    binary_data = binary_data.astype(float)

    for i in counts.index:
        binary_data.loc[:, i] = (binary_data.loc[:, i] / counts[i]).astype(float)

    return binary_data


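# Worked sketch: occurrence() returns, per duplicated column name, the
# fraction of samples with a non-zero value for each feature.
def _demo_occurrence():
    df = pd.DataFrame(
        [[0, 3, 5], [2, 0, 0]],
        index=["gene_a", "gene_b"],
        columns=["g1", "g1", "g2"],
    )
    print(occurrence(df))  # gene_a: g1=0.5, g2=1.0; gene_b: g1=0.5, g2=0.0

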
def add_subnames(names_list: list, parent_name: str, new_clusters: list):
    """
    Append sub-cluster names to a parent name within a list of names.

    This function replaces occurrences of `parent_name` in `names_list` with
    a concatenation of the parent name and corresponding sub-cluster name
    from `new_clusters` (formatted as "parent.subcluster"). Non-matching names
    are left unchanged.

    Parameters
    ----------
    names_list : list
        Original list of names (e.g., column names or cluster labels).

    parent_name : str
        Name of the parent cluster to which sub-cluster names will be added.
        Must exist in `names_list`.

    new_clusters : list
        List of sub-cluster names. Its length must match the number of times
        `parent_name` occurs in `names_list`.

    Returns
    -------
    list
        Updated list of names with sub-cluster names appended to the parent name.

    Raises
    ------
    ValueError
        - If `parent_name` is not found in `names_list`.
        - If `new_clusters` length does not match the number of occurrences of
          `parent_name`.

    Examples
    --------
    >>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
    ['A.1', 'B', 'A.2']
    """

    if str(parent_name) not in [str(x) for x in names_list]:
        raise ValueError(
            "Parent name is missing from the original dataset's column names!"
        )

    if len(new_clusters) != len([x for x in names_list if str(x) == str(parent_name)]):
        raise ValueError(
            "New cluster names list has a different length than the number of clusters in the original dataset!"
        )

    new_names = []
    ixn = 0
    for i in names_list:
        if str(i) == str(parent_name):
            new_names.append(f"{parent_name}.{new_clusters[ixn]}")
            ixn += 1
        else:
            new_names.append(i)

    return new_names


def development_clust(
    data: pd.DataFrame, method: str = "ward", img_width: int = 5, img_high: int = 5
):
    """
    Perform hierarchical clustering on the columns of a DataFrame and plot a dendrogram.

    Clusters the transposed data (columns) with the selected linkage method and
    generates a dendrogram showing the relationships between features or samples.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and samples/columns to be clustered.

    method : str, default='ward'
        Method for hierarchical clustering. Options include:
        - 'ward' : minimizes the variance of clusters being merged.
        - 'single' : uses the minimum of the distances between all observations of the two sets.
        - 'complete' : uses the maximum of the distances between all observations of the two sets.
        - 'average' : uses the average of the distances between all observations of the two sets.

    img_width : int or float, default=5
        Width of the resulting figure in inches.

    img_high : int or float, default=5
        Height of the resulting figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The dendrogram figure.
    """

    z = linkage(data.T, method=method)

    figure, ax = plt.subplots(figsize=(img_width, img_high))

    dendrogram(z, labels=data.columns, orientation="left", ax=ax)

    return figure


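# Usage sketch (illustrative): cluster group-averaged profiles and save the
# dendrogram to a hypothetical output file.
def _demo_development_clust(data):
    fig = development_clust(average(data), method="ward")
    fig.savefig("dendrogram.png", dpi=300, bbox_inches="tight")

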
def adjust_cells_to_group_mean(data, data_avg, beta=0.2):
    """
    Adjust each cell's values towards the mean of its group (centroid).

    This function moves each cell's values in `data` slightly towards the
    corresponding group mean in `data_avg`, controlled by the parameter `beta`.

    Parameters
    ----------
    data : pandas.DataFrame
        Original data with features as rows and cells/samples as columns.

    data_avg : pandas.DataFrame
        DataFrame of group averages (centroids) with features as rows and
        group names as columns.

    beta : float, default=0.2
        Weight for adjustment towards the group mean. 0 = no adjustment,
        1 = fully replaced by the group mean.

    Returns
    -------
    pandas.DataFrame
        Adjusted data with the same shape as the input `data`.
    """

    df_adjusted = data.copy()

    for group_name in data_avg.columns:
        col_idx = [
            i
            for i, c in enumerate(df_adjusted.columns)
            if str(c).startswith(group_name)
        ]
        if not col_idx:
            continue

        centroid = data_avg.loc[df_adjusted.index, group_name].to_numpy()[:, None]

        df_adjusted.iloc[:, col_idx] = (1 - beta) * df_adjusted.iloc[
            :, col_idx
        ].to_numpy() + beta * centroid

    return df_adjusted
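

# Usage sketch (illustrative): shrink each cell 20% of the way toward its
# group centroid. Columns of `data` are assumed to start with the name of the
# matching centroid column in `data_avg` (e.g. "cluster_1.cell_0001" matches
# the centroid column "cluster_1").
def _demo_adjust_cells_to_group_mean(data):
    centroids = average(data)
    smoothed = adjust_cells_to_group_mean(data, centroids, beta=0.2)
    return smoothed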
def volcano_plot(
    deg_data: pd.DataFrame,
    p_adj: bool = True,
    top: int = 25,
    p_val: float | int = 0.05,
    lfc: float | int = 0.25,
    standard_scale: bool = False,
    rescale_adj: bool = True,
    image_width: int = 12,
    image_high: int = 12,
):
    """
    Generate a volcano plot from differential expression results.

    A volcano plot visualizes the relationship between statistical significance
    (p-values or standardized p-values) and log(fold change) for each gene,
    highlighting genes that pass significance thresholds.

    Parameters
    ----------
    deg_data : pandas.DataFrame
        DataFrame containing differential expression results from the calc_DEG() function.

    p_adj : bool, default=True
        If True, use adjusted p-values. If False, use raw p-values.

    top : int, default=25
        Number of top significant genes to highlight on the plot.

    p_val : float | int, default=0.05
        Significance threshold for p-values (or adjusted p-values).

    lfc : float | int, default=0.25
        Threshold for absolute log fold change.

    standard_scale : bool, default=False
        If True, use standardized p-values for visualization instead of -log10(p-values).

    rescale_adj : bool, default=True
        If True, rescale p-values to avoid long breaks caused by outlier values.

    image_width : int, default=12
        Width of the generated plot in inches.

    image_high : int, default=12
        Height of the generated plot in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The generated volcano plot figure.
    """

    if p_adj:
        pv = "adj_pval"
    else:
        pv = "p_val"

    deg_df = deg_data.copy()

    if standard_scale:

        p_val_scale = "scale(p-val)"

        scaler = MinMaxScaler(feature_range=(0, 1000))

        # genes with a non-zero p-value and a non-zero fold change
        tmp_p = deg_df[
            ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
            | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
        ].copy()

        tmp_p[p_val_scale] = (-np.log10(tmp_p[pv])).astype(float)

        tmp_p[p_val_scale] = scaler.fit_transform(
            tmp_p[p_val_scale].values.reshape(-1, 1)
        ).astype(float)

        median_shift = abs(tmp_p[p_val_scale].diff().max()) * 25
        cur_max = tmp_p[p_val_scale].max()

        # genes with p-value == 0 are stacked above the current maximum
        zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
        zero_p_plus = zero_p_plus.sort_values(by="log(FC)", ascending=True).reset_index(
            drop=True
        )

        zero_p_plus[p_val_scale] = [
            (cur_max + (x * median_shift)) for x in range(1, len(zero_p_plus.index) + 1)
        ]
        zero_p_plus[p_val_scale] = zero_p_plus[p_val_scale].astype(float)

        zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
        zero_p_minus = zero_p_minus.sort_values(
            by="log(FC)", ascending=False
        ).reset_index(drop=True)

        zero_p_minus[p_val_scale] = [
            (cur_max + (x * median_shift))
            for x in range(1, len(zero_p_minus.index) + 1)
        ]
        zero_p_minus[p_val_scale] = zero_p_minus[p_val_scale].astype(float)

        deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)

    else:

        shift = 0.25

        p_val_scale = "-log(p_val)"

        min_minus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)])
        min_plus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)])

        # genes with p-value == 0 get a fraction of the smallest non-zero p-value
        zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
        zero_p_plus = zero_p_plus.sort_values(
            by="log(FC)", ascending=False
        ).reset_index(drop=True)
        zero_p_plus[pv] = [
            (shift * x) * min_plus for x in range(1, len(zero_p_plus.index) + 1)
        ]

        zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
        zero_p_minus = zero_p_minus.sort_values(
            by="log(FC)", ascending=True
        ).reset_index(drop=True)
        zero_p_minus[pv] = [
            (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1)
        ]

        tmp_p = deg_df[
            ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
            | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
        ].copy()

        del deg_df

        deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)

        deg_df[p_val_scale] = -np.log10(deg_df[pv])

    deg_df["top100"] = None

    if rescale_adj:

        deg_df = deg_df.sort_values(by=p_val_scale, ascending=False)

        deg_df = deg_df.reset_index(drop=True)

        # find positions where the scaled p-value drops by half or more
        doubled = []
        ratio = []
        for n in range(len(deg_df.index)):
            for j in range(1, 6):
                if (
                    n + j < len(deg_df.index)
                    and deg_df[p_val_scale][n] / deg_df[p_val_scale][n + j] >= 2
                ):
                    doubled.append(n)
                    ratio.append(deg_df[p_val_scale][n + j] / deg_df[p_val_scale][n])

        df = pd.DataFrame({"doubled": doubled, "ratio": ratio})
        df = df[df["doubled"] < 100]

        df["ratio"] = (1 - df["ratio"]) / 5
        df = df.reset_index(drop=True)

        df = df.sort_values("doubled")

        doubled2 = []

        if not (len(df["doubled"]) == 1 and 0 in df["doubled"]):
            for l in df["doubled"]:
                if l + 1 != len(doubled):
                    doubled2.append(l)
                    doubled2.append(l + 1)
                else:
                    break

            doubled2 = sorted(set(doubled2), reverse=True)

        if len(doubled2) > 1:
            df = df[df["doubled"].isin(doubled2)]
            df = df.sort_values("doubled", ascending=False)
            df = df.reset_index(drop=True)
            for c in df.index:
                deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[
                    df["doubled"][c] + 1, p_val_scale
                ] * (1 + df["ratio"][c])

    # significance coloring uses the raw 'p_val' column
    deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df["p_val"] <= p_val), "top100"] = "red"
    deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df["p_val"] <= p_val), "top100"] = "blue"
    deg_df.loc[deg_df["p_val"] > p_val, "top100"] = "lightgray"

    if lfc > 0:
        deg_df.loc[
            (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100"
        ] = "lightgray"

    down_int = len(
        deg_df["top100"][(deg_df["log(FC)"] <= lfc * -1) & (deg_df["p_val"] <= p_val)]
    )
    up_int = len(
        deg_df["top100"][(deg_df["log(FC)"] > lfc) & (deg_df["p_val"] <= p_val)]
    )

    # mark the `top` most significant up- and downregulated genes
    deg_df_up = deg_df[deg_df["log(FC)"] > 0]
    deg_df_up = deg_df_up.sort_values(["p_val", "log(FC)"], ascending=[True, False])
    deg_df_up = deg_df_up.reset_index(drop=True)

    n = -1
    l = 0
    while True:
        n += 1
        if n >= len(deg_df_up.index):
            break
        if deg_df_up["log(FC)"][n] > lfc:
            deg_df_up.loc[n, "top100"] = "green"
            l += 1
        if l == top or deg_df_up["p_val"][n] > p_val:
            break

    deg_df_down = deg_df[deg_df["log(FC)"] <= 0]
    deg_df_down = deg_df_down.sort_values(["p_val", "log(FC)"], ascending=[True, True])
    deg_df_down = deg_df_down.reset_index(drop=True)

    n = -1
    l = 0
    while True:
        n += 1
        if n >= len(deg_df_down.index):
            break
        if deg_df_down["log(FC)"][n] < lfc * -1:
            deg_df_down.loc[n, "top100"] = "yellow"
            l += 1
        if l == top or deg_df_down["p_val"][n] > p_val:
            break

    deg_df = pd.concat([deg_df_up, deg_df_down])

    # draw order: gray points first, highlighted genes last
    que = ["lightgray", "red", "blue", "yellow", "green"]

    deg_df = deg_df.sort_values(
        by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)})
    )

    deg_df = deg_df.reset_index(drop=True)

    fig, ax = plt.subplots(figsize=(image_width, image_high))

    plt.scatter(
        x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2
    )

    plt.plot(
        [max(deg_df["log(FC)"]) * -1.1, max(deg_df["log(FC)"]) * 1.1],
        [-np.log10(p_val), -np.log10(p_val)],
        linestyle="--",
        linewidth=3,
        color="lightgray",
        zorder=1,
    )

    if lfc > 0:
        plt.plot(
            [lfc * -1, lfc * -1],
            [-3, max(deg_df[p_val_scale]) * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )
        plt.plot(
            [lfc, lfc],
            [-3, max(deg_df[p_val_scale]) * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )

    plt.xlabel("log(FC)")
    plt.ylabel(p_val_scale)
    plt.title("Volcano plot")

    plt.ylim(min(deg_df[p_val_scale]) - 5, max(deg_df[p_val_scale]) * 1.25)

    texts = [
        ax.text(deg_df["log(FC)"][i], deg_df[p_val_scale][i], deg_df["feature"][i])
        for i in deg_df.index
        if deg_df["top100"][i] in ["green", "yellow"]
    ]

    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

    legend_elements = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="top-upregulated",
            markerfacecolor="green",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="top-downregulated",
            markerfacecolor="yellow",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="upregulated",
            markerfacecolor="blue",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="downregulated",
            markerfacecolor="red",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="non-significant",
            markerfacecolor="lightgray",
            markersize=10,
        ),
    ]

    ax.legend(handles=legend_elements, loc="upper right")
    ax.grid(visible=False)

    ax.annotate(
        "\nmin p-value = " + str(p_val),
        xy=(0.025, 0.975),
        xycoords="axes fraction",
        fontsize=12,
    )

    if lfc > 0:
        ax.annotate(
            "\nmin log(FC) = " + str(lfc),
            xy=(0.025, 0.95),
            xycoords="axes fraction",
            fontsize=12,
        )

    ax.annotate(
        "\nDownregulated: " + str(down_int),
        xy=(0.025, 0.925),
        xycoords="axes fraction",
        fontsize=12,
        color="red",
    )

    ax.annotate(
        "\nUpregulated: " + str(up_int),
        xy=(0.025, 0.9),
        xycoords="axes fraction",
        fontsize=12,
        color="blue",
    )

    plt.show()

    return fig
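A hedged usage sketch: assuming `deg_result` is the DataFrame returned by calc_DEG() (it carries the 'feature', 'p_val', 'adj_pval', and 'log(FC)' columns used above) and `groups` is a per-column label list, a typical call might look like:

>>> deg_result = calc_DEG(data, metadata_list=groups, sets="All")
>>> fig = volcano_plot(deg_result, p_adj=True, top=10, p_val=0.05, lfc=0.25)
>>> fig.savefig("volcano.png", dpi=300)  # illustrative output path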

def find_features(data: pd.DataFrame, features: list):
    """
    Identify features (rows) of a DataFrame that match a given list of
    features, ignoring case.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with features in the index (rows).

    features : list
        List of feature names to search for.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of features found in the DataFrame index.
        - "not_included": list of requested features not found in the DataFrame index.
        - "potential": list of features in the DataFrame that may be similar.
    """

    features_upper = [str(x).upper() for x in features]

    index_set = set(data.index)

    features_in = [x for x in index_set if x.upper() in features_upper]
    features_in_upper = [x.upper() for x in features_in]
    features_out_upper = [x for x in features_upper if x not in features_in_upper]
    features_out = [x for x in features if x.upper() in features_out_upper]
    similar_features = [
        idx for idx in index_set if any(x in idx.upper() for x in features_out_upper)
    ]

    return {
        "included": features_in,
        "not_included": features_out,
        "potential": similar_features,
    }
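A doctest-style illustration on a toy frame (ordering inside each list can vary because the lookup iterates over a set):

>>> df = pd.DataFrame({"s1": [1.0, 0.0]}, index=["GAPDH", "ACTB"])
>>> find_features(df, ["gapdh", "Act"])
{'included': ['GAPDH'], 'not_included': ['Act'], 'potential': ['ACTB']}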

def find_names(data: pd.DataFrame, names: list):
    """
    Identify names (columns) of a DataFrame that match a given list of names,
    ignoring case.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with names in the columns.

    names : list
        List of names to search for.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of names found in the DataFrame columns.
        - "not_included": list of requested names not found in the DataFrame columns.
        - "potential": list of names in the DataFrame that may be similar.
    """

    names_upper = [str(x).upper() for x in names]

    columns = set(data.columns)

    names_in = [x for x in columns if x.upper() in names_upper]
    names_in_upper = [x.upper() for x in names_in]
    names_out_upper = [x for x in names_upper if x not in names_in_upper]
    names_out = [x for x in names if x.upper() in names_out_upper]
    similar_names = [
        idx for idx in columns if any(x in idx.upper() for x in names_out_upper)
    ]

    return {"included": names_in, "not_included": names_out, "potential": similar_names}
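The column-wise counterpart behaves the same way; a toy example:

>>> df = pd.DataFrame({"Sample1": [1.0], "Control1": [2.0]})
>>> find_names(df, ["sample1", "contr"])
{'included': ['Sample1'], 'not_included': ['contr'], 'potential': ['Control1']}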

def reduce_data(data: pd.DataFrame, features: list = [], names: list = []):
    """
    Subset a DataFrame based on selected features (rows) and/or names (columns).

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and names as columns.

    features : list
        List of features to include (rows). Default is an empty list.
        If empty, all rows are returned.

    names : list
        List of names to include (columns). Default is an empty list.
        If empty, all columns are returned.

    Returns
    -------
    pandas.DataFrame
        Subset of the input DataFrame containing only the selected rows
        and/or columns.

    Raises
    ------
    ValueError
        If both `features` and `names` are empty.
    """

    if len(features) > 0 and len(names) > 0:
        fet = find_features(data=data, features=features)

        nam = find_names(data=data, names=names)

        data_to_return = data.loc[fet["included"], nam["included"]]

    elif len(features) > 0 and len(names) == 0:
        fet = find_features(data=data, features=features)

        data_to_return = data.loc[fet["included"], :]

    elif len(features) == 0 and len(names) > 0:

        nam = find_names(data=data, names=names)

        data_to_return = data.loc[:, nam["included"]]

    else:

        raise ValueError("features and names have zero length!")

    return data_to_return
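A hedged sketch (assumes `data` is a genes-by-cells frame such as the one returned by load_sparse(); the gene names and barcode are illustrative):

>>> sub = reduce_data(data, features=["GAPDH", "ACTB"], names=["AAACCTGAGAAACCAT-1"])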

def make_unique_list(lst):
    """
    Generate a list where duplicate items are renamed to ensure uniqueness.

    Each duplicate is appended with a suffix ".n", where n indicates the
    occurrence count (starting from 1).

    Parameters
    ----------
    lst : list
        Input list of items (strings or other hashable types).

    Returns
    -------
    list
        List with unique values.

    Examples
    --------
    >>> make_unique_list(["A", "B", "A", "A"])
    ['A', 'B', 'A.1', 'A.2']
    """
    seen = {}
    result = []
    for item in lst:
        if item not in seen:
            seen[item] = 0
            result.append(item)
        else:
            seen[item] += 1
            result.append(f"{item}.{seen[item]}")
    return result

def get_color_palette(variable_list, palette_name="tab10"):
    """
    Map each item in `variable_list` to a color drawn from a matplotlib
    colormap, cycling through the palette when there are more items
    than palette colors.
    """
    n = len(variable_list)
    cmap = plt.get_cmap(palette_name)
    colors = [cmap(i % cmap.N) for i in range(n)]
    return dict(zip(variable_list, colors))
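The mapping can be inspected directly; the RGBA tuple below is matplotlib's first 'tab10' color, and the exact repr may differ across matplotlib/numpy versions:

>>> get_color_palette(["A", "B"])["A"]
(0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0)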
def features_scatter(
    expression_data: pd.DataFrame,
    occurence_data: pd.DataFrame | None = None,
    scale: bool = False,
    features: list | None = None,
    metadata_list: list | None = None,
    colors: str = "viridis",
    hclust: str | None = "complete",
    img_width: int = 8,
    img_high: int = 5,
    label_size: int = 10,
    size_scale: int = 100,
    y_lab: str = "Genes",
    legend_lab: str = "log(CPM + 1)",
    set_box_size: float | int = 5,
    set_box_high: float | int = 5,
    bbox_to_anchor_scale: int = 25,
    bbox_to_anchor_perc: tuple = (0.91, 0.63),
    bbox_to_anchor_group: tuple = (1.01, 0.4),
):
    """
    Create a bubble scatter plot of selected features across samples.

    Each point represents a feature-sample pair, where the color encodes the
    expression value and the size encodes occurrence or relative abundance.
    Optionally, hierarchical clustering can be applied to order rows and columns.

    Parameters
    ----------
    expression_data : pandas.DataFrame
        Expression values (mean) with features as rows and samples as columns,
        derived from the average() function.

    occurence_data : pandas.DataFrame or None
        DataFrame with occurrence/frequency values (same shape as
        `expression_data`) derived from the occurrence() function.
        If None, bubble sizes are based on expression values.

    scale : bool, default=False
        If True, `expression_data` will be scaled (0-1) across the rows (features).

    features : list or None
        List of features (rows) to display. If None, all features are used.

    metadata_list : list or None, optional
        Metadata grouping for samples (same length as number of columns).
        Used to add group colors and separators in the plot.

    colors : str, default='viridis'
        Colormap for expression values.

    hclust : str or None, default='complete'
        Linkage method for hierarchical clustering. If None, no clustering
        is performed.

    img_width : int or float, default=8
        Width of the plot in inches.

    img_high : int or float, default=5
        Height of the plot in inches.

    label_size : int, default=10
        Font size for axis labels and ticks.

    size_scale : int or float, default=100
        Scaling factor for bubble sizes.

    y_lab : str, default='Genes'
        Label for the y-axis.

    legend_lab : str, default='log(CPM + 1)'
        Label for the colorbar legend.

    set_box_size : float | int, default=5
        Height scale of the group color boxes drawn above the plot.

    set_box_high : float | int, default=5
        Vertical offset scale of the group color boxes.

    bbox_to_anchor_scale : int, default=25
        Vertical scale (percentage) for positioning the colorbar.

    bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
        Anchor position for the size legend (percent bubble legend).

    bbox_to_anchor_group : tuple, default=(1.01, 0.4)
        Anchor position for the group legend.

    Returns
    -------
    matplotlib.figure.Figure
        The generated scatter plot figure.

    Raises
    ------
    ValueError
        If `metadata_list` is provided but its length does not match
        the number of columns in `expression_data`.

    Notes
    -----
    - Colors represent expression values normalized to the colormap.
    - Bubble sizes represent occurrence values (or expression values if
      `occurence_data` is None).
    - If `metadata_list` is given, groups are indicated with colors and
      dashed vertical separators.
    """

    scatter_df = expression_data.copy()

    if scale:

        legend_lab = "Scaled\n" + legend_lab

        scaler = MinMaxScaler(feature_range=(0, 1))
        scatter_df = pd.DataFrame(
            scaler.fit_transform(scatter_df.T).T,
            index=scatter_df.index,
            columns=scatter_df.columns,
        )

    metadata = {}

    metadata["primary_names"] = [str(x) for x in scatter_df.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)
    if features is not None:
        scatter_df = scatter_df.loc[
            find_features(data=scatter_df, features=features)["included"],
        ]
    scatter_df.columns = metadata["primary_names"] + "#" + metadata["sets"]

    if occurence_data is not None:
        occurence_data = occurence_data.copy()
        if features is not None:
            occurence_data = occurence_data.loc[
                find_features(data=occurence_data, features=features)["included"],
            ]
        occurence_data.columns = metadata["primary_names"] + "#" + metadata["sets"]

    # de-duplicate column names

    tmp_columns = scatter_df.columns

    new_cols = make_unique_list(list(tmp_columns))

    scatter_df.columns = new_cols

    if occurence_data is not None:
        # keep occurrence columns aligned with the de-duplicated expression columns
        occurence_data.columns = new_cols

    if hclust is not None and len(scatter_df.index) > 1:

        Z = linkage(scatter_df, method=hclust)

        # order rows (features) by the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_rows = []
        for n in order_of_features:
            sorted_list_rows.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        Z = linkage(scatter_df, method=hclust)

        # order columns (samples) by the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_columns = []
        for n in order_of_features:
            sorted_list_columns.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        scatter_df = scatter_df.loc[sorted_list_rows, sorted_list_columns]

        if occurence_data is not None:
            occurence_data = occurence_data.loc[sorted_list_rows, sorted_list_columns]

        metadata["sets"] = [re.sub(".*#", "", x) for x in scatter_df.columns]

    scatter_df.columns = [re.sub("#.*", "", x) for x in scatter_df.columns]

    if occurence_data is not None:
        occurence_data.columns = [re.sub("#.*", "", x) for x in occurence_data.columns]

    fig, ax = plt.subplots(figsize=(img_width, img_high))

    norm = plt.Normalize(0, scatter_df.to_numpy().max())

    cmap = plt.get_cmap(colors)

    # Bubble scatter
    for i, _ in enumerate(scatter_df.index):
        for j, _ in enumerate(scatter_df.columns):
            if occurence_data is not None:
                value_e = scatter_df.iloc[i, j]
                value_o = occurence_data.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value_o * size_scale,
                    c=[cmap(norm(value_e))],
                    edgecolors="k",
                    linewidths=0.3,
                )
            else:
                value = scatter_df.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value * size_scale,
                    c=[cmap(norm(value))],
                    edgecolors="k",
                    linewidths=0.3,
                )

    ax.set_yticks(range(len(scatter_df.index)))
    ax.set_yticklabels(scatter_df.index, fontsize=label_size * 0.8)
    ax.set_ylabel(y_lab, fontsize=label_size)
    ax.set_xticks(range(len(scatter_df.columns)))
    ax.set_xticklabels(scatter_df.columns, fontsize=label_size * 0.8, rotation=90)

    # color bar
    cax = inset_axes(
        ax,
        width="1%",
        height=f"{bbox_to_anchor_scale}%",
        bbox_to_anchor=(1, 0, 1, 1),
        bbox_transform=ax.transAxes,
        loc="upper left",
    )
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb.set_label(legend_lab, fontsize=label_size * 0.65)
    cb.ax.tick_params(labelsize=label_size * 0.7)

    if metadata_list is not None:

        metadata_list = list(metadata["sets"])
        group_colors = get_color_palette(list(set(metadata_list)), palette_name="tab10")

        # group color boxes above the plot
        for i, group in enumerate(metadata_list):
            ax.add_patch(
                plt.Rectangle(
                    (i - 0.5, len(scatter_df.index) - 0.1 * set_box_high),
                    1,
                    0.1 * set_box_size,
                    color=group_colors[group],
                    transform=ax.transData,
                    clip_on=False,
                )
            )

        # dashed separators between consecutive groups
        for i in range(1, len(metadata_list)):
            if metadata_list[i] != metadata_list[i - 1]:
                ax.axvline(i - 0.5, color="black", linestyle="--", lw=1)

        group_patches = [
            mpatches.Patch(color=color, label=label)
            for label, color in group_colors.items()
        ]
        fig.legend(
            handles=group_patches,
            title="Group",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_group,
            frameon=False,
        )

    # second legend (size)
    if occurence_data is not None:
        size_values = [0.25, 0.5, 1]
        legend2_handles = [
            plt.Line2D(
                [],
                [],
                marker="o",
                linestyle="",
                markersize=np.sqrt(v * size_scale * 0.5),
                color="gray",
                alpha=0.6,
                label=f"{v * 100:.1f}",
            )
            for v in size_values
        ]

        fig.legend(
            handles=legend2_handles,
            title="Percent [%]",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_perc,
            frameon=False,
        )

    _, ymax = ax.get_ylim()

    ax.set_xlim(-0.5, len(scatter_df.columns) - 0.5)
    ax.set_ylim(-0.5, ymax + 0.5)

    return fig
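A hedged end-to-end sketch (assumes `data` columns are already named by group, so average() and occurrence() aggregate them; gene names are illustrative):

>>> avg = average(data)
>>> occ = occurrence(data)
>>> fig = features_scatter(avg, occurence_data=occ, features=["GAPDH", "ACTB"])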

def calc_DEG(
    data,
    metadata_list: list | None = None,
    entities: str | list | dict | None = None,
    sets: str | list | dict | None = None,
    min_exp: int | float = 0,
    min_pct: int | float = 0.1,
    n_proc: int = 10,
):
    """
    Perform differential gene expression (DEG) analysis on gene expression data.

    The function compares groups of cells or samples (defined by `entities` or
    `sets`) using the Mann–Whitney U test. It computes p-values, adjusted
    p-values, fold changes, standardized effect sizes, and other statistics.

    Parameters
    ----------
    data : pandas.DataFrame
        Expression matrix with features (e.g., genes) as rows and samples/cells
        as columns.

    metadata_list : list or None, optional
        Metadata grouping corresponding to the columns in `data`. Required for
        comparisons based on sets. Default is None.

    entities : list, str, dict, or None, optional
        Defines the comparison strategy:
        - list of sample names → compare selected cells to the rest.
        - 'All' → compare each sample/cell to all others.
        - dict → user-defined groups for pairwise comparison.
        - None → must be combined with `sets`.

    sets : str, dict, or None, optional
        Defines group-based comparisons:
        - 'All' → compare each set/group to all others.
        - dict with two groups → perform pairwise set comparison.
        - None → must be combined with `entities`.

    min_exp : float | int, default=0
        Minimum expression threshold for filtering features.

    min_pct : float | int, default=0.1
        Minimum proportion of samples within the target group that must express
        a feature for it to be tested.

    n_proc : int, default=10
        Number of parallel processes to use for statistical testing.

    Returns
    -------
    pandas.DataFrame or dict
        Results of the differential expression analysis:
        - If `entities` is a list → dict with keys: 'valid' (target samples),
          'control' (control label), and 'DEG' (results DataFrame).
        - If `entities == 'All'` or `sets == 'All'` → DataFrame with results
          for all groups.
        - If pairwise comparison (dict for `entities` or `sets`) → DataFrame
          with results for the specified groups.

        The results DataFrame contains:
        - 'feature': feature name
        - 'p_val': raw p-value
        - 'adj_pval': adjusted p-value (multiple testing correction)
        - 'pct_valid': fraction of target group expressing the feature
        - 'pct_ctrl': fraction of control group expressing the feature
        - 'avg_valid': mean expression in target group
        - 'avg_ctrl': mean expression in control group
        - 'sd_valid': standard deviation in target group
        - 'sd_ctrl': standard deviation in control group
        - 'esm': effect size metric
        - 'FC': fold change
        - 'log(FC)': log2-transformed fold change
        - 'norm_diff': difference in mean expression

    Raises
    ------
    ValueError
        - If `metadata_list` is provided but its length does not match
          the number of columns in `data`.
        - If neither `entities` nor `sets` is provided.

    Notes
    -----
    - Mann–Whitney U test is used for group comparisons.
    - Multiple testing correction is applied using a simple
      Benjamini–Hochberg-like method.
    - Features expressed below `min_exp` or in fewer than `min_pct` of target
      samples are filtered out.
    - Parallelization is handled by `joblib.Parallel`.

    Examples
    --------
    Compare a selected list of cells against all others:

    >>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])

    Compare each group to others (based on metadata):

    >>> result = calc_DEG(data, metadata_list=group_labels, sets="All")

    Perform pairwise comparison between two predefined sets:

    >>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
    >>> result = calc_DEG(data, sets=sets)
    """

    metadata = {}

    metadata["primary_names"] = [str(x) for x in data.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)

    def stat_calc(choose, feature_name):
        target_values = choose.loc[choose["DEG"] == "target", feature_name]
        rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

        pct_valid = (target_values > 0).sum() / len(target_values)
        pct_rest = (rest_values > 0).sum() / len(rest_values)

        avg_valid = np.mean(target_values)
        avg_ctrl = np.mean(rest_values)
        sd_valid = np.std(target_values, ddof=1)
        sd_ctrl = np.std(rest_values, ddof=1)
        esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

        if np.sum(target_values) == np.sum(rest_values):
            p_val = 1.0
        else:
            _, p_val = stats.mannwhitneyu(
                target_values, rest_values, alternative="two-sided"
            )

        return {
            "feature": feature_name,
            "p_val": p_val,
            "pct_valid": pct_valid,
            "pct_ctrl": pct_rest,
            "avg_valid": avg_valid,
            "avg_ctrl": avg_ctrl,
            "sd_valid": sd_valid,
            "sd_ctrl": sd_ctrl,
            "esm": esm,
        }

    def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc, factors):

        tmp_dat = choose[choose["DEG"] == "target"]
        tmp_dat = tmp_dat.drop("DEG", axis=1)

        counts = (tmp_dat > min_exp).sum(axis=0)

        total_count = tmp_dat.shape[0]

        info = pd.DataFrame(
            {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
        )

        del tmp_dat

        drop_col = info["feature"][info["pct"] <= min_pct]

        # if the filter would drop every feature, fall back to dropping
        # only completely unexpressed ones
        if len(drop_col) + 1 == len(choose.columns):
            drop_col = info["feature"][info["pct"] == 0]

        del info

        choose = choose.drop(list(drop_col), axis=1)

        results = Parallel(n_jobs=n_proc)(
            delayed(stat_calc)(choose, feature)
            for feature in tqdm(choose.columns[choose.columns != "DEG"])
        )

        if len(results) > 0:
            df = pd.DataFrame(results)
            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            factors_df = factors.to_frame(name="factor")

            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")

            df[["avg_valid", "avg_ctrl", "factor"]] = df[
                ["avg_valid", "avg_ctrl", "factor"]
            ].astype(float)

            df["FC"] = (df["avg_valid"] + df["factor"]) / (
                df["avg_ctrl"] + df["factor"]
            )
            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
            df = df.drop(columns=["factor"])
        else:
            columns = [
                "feature",
                "valid_group",
                "p_val",
                "adj_pval",
                "avg_valid",
                "avg_ctrl",
                "FC",
                "log(FC)",
                "norm_diff",
            ]
            df = pd.DataFrame(columns=columns)
        return df

    choose = data.T

    # per-feature pseudocount: the smallest non-zero value in each row;
    # rows with no non-zero values fall back to the global minimum (or 0.01)
    factors = data.replace(0, np.nan).min(axis=1)

    nonzero_factors = factors.dropna()
    if len(nonzero_factors) > 0:
        factors = factors.fillna(nonzero_factors.min())
    else:
        factors = factors.fillna(0.01)

    final_results = []

    if isinstance(entities, list) and sets is None:
        print("\nAnalysis started...\nComparing selected cells to the whole set...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

            if "#" not in entities[0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' list. "
                    "Only the names will be compared, without considering the set information."
                )

        labels = ["target" if idx in entities else "rest" for idx in choose.index]
        valid = list(
            set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
        )

        choose["DEG"] = labels

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=valid,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )

        return {"valid": valid, "control": "rest", "DEG": result_df}

    elif entities == "All" and sets is None:
        print("\nAnalysis started...\nComparing each type of cell to others...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

        unique_labels = set(choose.index)

        for label in tqdm(unique_labels):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]
            choose["DEG"] = labels
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            final_results.append(result_df)

        final_results = pd.concat(final_results, ignore_index=True)

        if metadata_list is None:
            final_results["valid_group"] = [
                re.sub(" # ", "", x) for x in final_results["valid_group"]
            ]

        return final_results

    elif entities is None and sets == "All":
        print("\nAnalysis started...\nComparing each set/group to others...")
        choose.index = metadata["sets"]
        unique_sets = set(choose.index)

        for label in tqdm(unique_sets):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]

            choose["DEG"] = labels
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            final_results.append(result_df)

        return pd.concat(final_results, ignore_index=True)

    elif entities is None and isinstance(sets, dict):
        print("\nAnalysis started...\nComparing groups...")
        choose.index = metadata["sets"]

        group_list = list(sets.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in sets[group_list[0]]
                else "rest" if idx in sets[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]
        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )
        return result_df

    elif isinstance(entities, dict) and sets is None:
        print("\nAnalysis started...\nComparing groups...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]
            if "#" not in entities[list(entities.keys())[0]][0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' dict. "
                    "Only the names will be compared, without considering the set information."
                )

        group_list = list(entities.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in entities[group_list[0]]
                else "rest" if idx in entities[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )

        return result_df.reset_index(drop=True)

    else:
        raise ValueError(
            "You must specify either 'entities' or 'sets'; the combination provided is not supported."
        )

def average(data):
    """
    Compute the column-wise average of a DataFrame, aggregating by column names.

    If multiple columns share the same name, their values are averaged.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Columns with identical names
        will be aggregated by their mean.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input but with unique columns,
        where duplicate columns have been replaced by their mean values.
    """

    aggregated_df = data.T.groupby(level=0).mean().T

    return aggregated_df
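A doctest-style illustration on a toy frame with a duplicated column name:

>>> df = pd.DataFrame([[1.0, 3.0, 5.0]], index=["GAPDH"], columns=["A", "A", "B"])
>>> float(average(df).loc["GAPDH", "A"])
2.0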

def occurrence(data):
    """
    Calculate the occurrence frequency of features in a DataFrame.

    Converts the input DataFrame to binary (presence/absence) and computes
    the proportion of non-zero entries for each feature (row), aggregating
    samples (columns) by name if duplicates exist.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values, features as rows and samples as
        columns. Columns sharing a name are treated as one group.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input, where each value represents
        the proportion of samples in which the feature is present (non-zero).
        Columns with identical names are aggregated.
    """

    binary_data = (data > 0).astype(int)

    counts = binary_data.columns.value_counts()

    binary_data = binary_data.T.groupby(level=0).sum().T
    binary_data = binary_data.astype(float)

    for i in counts.index:
        binary_data.loc[:, i] = (binary_data.loc[:, i] / counts[i]).astype(float)

    return binary_data
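On the same kind of toy frame, occurrence() reports the expressed fraction per group of same-named columns:

>>> df = pd.DataFrame([[1.0, 0.0, 2.0]], index=["GAPDH"], columns=["A", "A", "B"])
>>> float(occurrence(df).loc["GAPDH", "A"])
0.5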

def add_subnames(names_list: list, parent_name: str, new_clusters: list):
    """
    Append sub-cluster names to a parent name within a list of names.

    This function replaces occurrences of `parent_name` in `names_list` with
    a concatenation of the parent name and corresponding sub-cluster name
    from `new_clusters` (formatted as "parent.subcluster"). Non-matching names
    are left unchanged.

    Parameters
    ----------
    names_list : list
        Original list of names (e.g., column names or cluster labels).

    parent_name : str
        Name of the parent cluster to which sub-cluster names will be added.
        Must exist in `names_list`.

    new_clusters : list
        List of sub-cluster names. Its length must match the number of times
        `parent_name` occurs in `names_list`.

    Returns
    -------
    list
        Updated list of names with sub-cluster names appended to the parent name.

    Raises
    ------
    ValueError
        - If `parent_name` is not found in `names_list`.
        - If `new_clusters` length does not match the number of occurrences of
          `parent_name`.

    Examples
    --------
    >>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
    ['A.1', 'B', 'A.2']
    """

    if str(parent_name) not in [str(x) for x in names_list]:
        raise ValueError(
            "Parent name is missing from the original dataset's column names!"
        )

    if len(new_clusters) != len([x for x in names_list if str(x) == str(parent_name)]):
        raise ValueError(
            "New cluster names list has a different length than the number of clusters in the original dataset!"
        )

    new_names = []
    ixn = 0
    for i in names_list:
        if str(i) == str(parent_name):
            new_names.append(f"{parent_name}.{new_clusters[ixn]}")
            ixn += 1
        else:
            new_names.append(i)

    return new_names

def development_clust(
    data: pd.DataFrame, method: str = "ward", img_width: int = 5, img_high: int = 5
):
    """
    Perform hierarchical clustering on the columns of a DataFrame and plot a dendrogram.

    Clusters the transposed data (columns) with the selected linkage method and
    generates a dendrogram showing the relationships between features or samples.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and samples/columns to be clustered.

    method : str, default='ward'
        Method for hierarchical clustering. Options include:
        - 'ward' : minimizes the variance of clusters being merged.
        - 'single' : uses the minimum of the distances between all observations of the two sets.
        - 'complete' : uses the maximum of the distances between all observations of the two sets.
        - 'average' : uses the average of the distances between all observations of the two sets.

    img_width : int or float, default=5
        Width of the resulting figure in inches.

    img_high : int or float, default=5
        Height of the resulting figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The dendrogram figure.
    """

    # Linkage on the transposed matrix clusters the columns (samples)
    z = linkage(data.T, method=method)

    figure, ax = plt.subplots(figsize=(img_width, img_high))

    dendrogram(z, labels=data.columns, orientation="left", ax=ax)

    return figure
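
A brief usage sketch (random data, purely illustrative; the output file name is arbitrary): cluster four sample columns with average linkage and save the resulting dendrogram.

>>> df = pd.DataFrame(np.random.rand(50, 4), columns=["c1", "c2", "c3", "c4"])
>>> fig = development_clust(df, method="average")
>>> fig.savefig("dendrogram.png")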

def adjust_cells_to_group_mean(data, data_avg, beta=0.2):
    """
    Adjust each cell's values towards the mean of its group (centroid).

    This function moves each cell's values in `data` slightly towards the
    corresponding group mean in `data_avg`, controlled by the parameter `beta`.

    Parameters
    ----------
    data : pandas.DataFrame
        Original data with features as rows and cells/samples as columns.

    data_avg : pandas.DataFrame
        DataFrame of group averages (centroids) with features as rows and
        group names as columns.

    beta : float, default=0.2
        Weight for adjustment towards the group mean. 0 = no adjustment,
        1 = fully replaced by the group mean.

    Returns
    -------
    pandas.DataFrame
        Adjusted data with the same shape as the input `data`.
    """

    df_adjusted = data.copy()

    for group_name in data_avg.columns:
        # Cells belong to a group when their column name starts with the
        # group name (e.g., "G1.c1" belongs to group "G1")
        col_idx = [
            i
            for i, c in enumerate(df_adjusted.columns)
            if str(c).startswith(group_name)
        ]
        if not col_idx:
            continue

        # Group centroid as a column vector, aligned to the row order of `data`
        centroid = data_avg.loc[df_adjusted.index, group_name].to_numpy()[:, None]

        # Convex combination: (1 - beta) * original + beta * centroid
        df_adjusted.iloc[:, col_idx] = (1 - beta) * df_adjusted.iloc[
            :, col_idx
        ].to_numpy() + beta * centroid

    return df_adjusted
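
An illustrative example (hypothetical gene and cell names): with beta=0.5, each cell in group "G1" moves halfway towards the group centroid.

>>> data = pd.DataFrame({"G1.c1": [1.0, 2.0], "G1.c2": [3.0, 4.0]}, index=["geneA", "geneB"])
>>> data_avg = pd.DataFrame({"G1": [2.0, 3.0]}, index=["geneA", "geneB"])
>>> adjust_cells_to_group_mean(data, data_avg, beta=0.5)
       G1.c1  G1.c2
geneA    1.5    2.5
geneB    2.5    3.5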
