jdti.utils

   1import os
   2import re
   3
   4import matplotlib as mpl
   5import matplotlib.patches as mpatches
   6import matplotlib.pyplot as plt
   7import numpy as np
   8import pandas as pd
   9import scipy.stats as stats
  10from adjustText import adjust_text
  11from joblib import Parallel, delayed
  12from matplotlib.lines import Line2D
  13from mpl_toolkits.axes_grid1.inset_locator import inset_axes
  14from scipy.cluster.hierarchy import dendrogram, linkage
  15from scipy.io import mmread
  16from tqdm import tqdm
  17
  18
  19def load_sparse(path: str, name: str):
  20    """
  21    Load a sparse matrix dataset along with associated gene
  22    and cell metadata, and return it as a dense DataFrame.
  23
  24    This function expects the input directory to contain three files in standard
  25    10x Genomics format:
  26      - "matrix.mtx": the gene expression matrix in Matrix Market format
  27      - "genes.tsv": tab-separated file containing gene identifiers
  28      - "barcodes.tsv": tab-separated file containing cell barcodes / names
  29
  30    Parameters
  31    ----------
  32    path : str
  33        Path to the directory containing the matrix and annotation files.
  34
  35    name : str
  36        Label or dataset identifier to be assigned to all cells in the metadata.
  37
  38    Returns
  39    -------
  40    data : pandas.DataFrame
  41        Dense expression matrix where rows correspond to genes and columns
  42        correspond to cells.
  43    metadata : pandas.DataFrame
  44        Metadata DataFrame with two columns:
  45          - "cell_names": the names of the cells (barcodes)
  46          - "sets": the dataset label assigned to each cell (from `name`)
  47
  48    Notes
  49    -----
  50    The function converts the sparse matrix into a dense DataFrame. This may
  51    require a large amount of memory for datasets with many cells and genes.
  52    """
  53
  54    data = mmread(os.path.join(path, "matrix.mtx"))
  55    data = pd.DataFrame(data.todense())
  56    genes = pd.read_csv(os.path.join(path, "genes.tsv"), header=None, sep="\t")
  57    names = pd.read_csv(os.path.join(path, "barcodes.tsv"), header=None, sep="\t")
  58    data.columns = [str(x) for x in names[0]]
  59    data.index = list(genes[0])
  60
  61    names = list(data.columns)
  62    sets = [name] * len(names)
  63
  64    metadata = pd.DataFrame({"cell_names": names, "sets": sets})
  65
  66    return data, metadata
  67
  68
  69def volcano_plot(
  70    deg_data: pd.DataFrame,
  71    p_adj: bool = True,
  72    top: int = 25,
  73    top_rank: str = "p_value",
  74    p_val: float | int = 0.05,
  75    lfc: float | int = 0.25,
  76    rescale_adj: bool = True,
  77    image_width: int = 12,
  78    image_high: int = 12,
  79):
  80    """
  81    Generate a volcano plot from differential expression results.
  82
  83    A volcano plot visualizes the relationship between statistical significance
  84    (p-values or standarized p-value) and log(fold change) for each gene, highlighting
  85    genes that pass significance thresholds.
  86
  87    Parameters
  88    ----------
  89    deg_data : pandas.DataFrame
  90        DataFrame containing differential expression results from calc_DEG() function.
  91
  92    p_adj : bool, default=True
  93        If True, use adjusted p-values. If False, use raw p-values.
  94
  95    top : int, default=25
  96        Number of top significant genes to highlight on the plot.
  97
  98    top_rank : str, default='p_value'
  99        Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC']
 100
 101    p_val : float | int, default=0.05
 102        Significance threshold for p-values (or adjusted p-values).
 103
 104    lfc : float | int, default=0.25
 105        Threshold for absolute log fold change.
 106
 107    rescale_adj : bool, default=True
 108        If True, rescale p-values to avoid long breaks caused by outlier values.
 109
 110    image_width : int, default=12
 111        Width of the generated plot in inches.
 112
 113    image_high : int, default=12
 114        Height of the generated plot in inches.
 115
 116    Returns
 117    -------
 118    matplotlib.figure.Figure
 119        The generated volcano plot figure.
 120
 121    """
 122
 123    if top_rank.upper() not in ["FC", "P_VALUE"]:
 124        raise ValueError("top_rank must be either 'FC' or 'p_value'")
 125
 126    if p_adj:
 127        pv = "adj_pval"
 128    else:
 129        pv = "p_val"
 130
 131    deg_df = deg_data.copy()
 132
 133    shift = 0.25
 134
 135    p_val_scale = "-log(p_val)"
 136
 137    min_minus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)])
 138    min_plus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)])
 139
 140    zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
 141    zero_p_plus = zero_p_plus.sort_values(by="log(FC)", ascending=False).reset_index(
 142        drop=True
 143    )
 144    zero_p_plus[pv] = [
 145        (shift * x) * min_plus for x in range(1, len(zero_p_plus.index) + 1)
 146    ]
 147
 148    zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
 149    zero_p_minus = zero_p_minus.sort_values(by="log(FC)", ascending=True).reset_index(
 150        drop=True
 151    )
 152    zero_p_minus[pv] = [
 153        (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1)
 154    ]
 155
 156    tmp_p = deg_df[
 157        ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
 158        | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
 159    ]
 160
 161    del deg_df
 162
 163    deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)
 164
 165    deg_df[pv] = deg_df[pv].replace(0, 2**-1074)
 166
 167    deg_df[p_val_scale] = -np.log10(deg_df[pv])
 168
 169    deg_df["top100"] = None
 170
 171    if rescale_adj:
 172
 173        deg_df = deg_df.sort_values(by=p_val_scale, ascending=False)
 174
 175        deg_df = deg_df.reset_index(drop=True)
 176
 177        eps = 1e-300
 178        doubled = []
 179        ratio = []
 180        for n, i in enumerate(deg_df.index):
 181            for j in range(1, 6):
 182                if (
 183                    n + j < len(deg_df.index)
 184                    and (deg_df[p_val_scale][n] + eps)
 185                    / (deg_df[p_val_scale][n + j] + eps)
 186                    >= 2
 187                ):
 188                    doubled.append(n)
 189                    ratio.append(
 190                        (deg_df[p_val_scale][n + j] + eps)
 191                        / (deg_df[p_val_scale][n] + eps)
 192                    )
 193
 194        df = pd.DataFrame({"doubled": doubled, "ratio": ratio})
 195        df = df[df["doubled"] < 100]
 196
 197        df["ratio"] = (1 - df["ratio"]) / 5
 198        df = df.reset_index(drop=True)
 199
 200        df = df.sort_values("doubled")
 201
 202        if len(df["doubled"]) == 1 and 0 in df["doubled"]:
 203            df = df
 204        else:
 205            doubled2 = []
 206
 207            for l in df["doubled"]:
 208                if l + 1 != len(doubled) and l + 1 - l == 1:
 209                    doubled2.append(l)
 210                    doubled2.append(l + 1)
 211                else:
 212                    break
 213
 214            doubled2 = sorted(set(doubled2), reverse=True)
 215
 216        if len(doubled2) > 1:
 217            df = df[df["doubled"].isin(doubled2)]
 218            df = df.sort_values("doubled", ascending=False)
 219            df = df.reset_index(drop=True)
 220            for c in df.index:
 221                deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[
 222                    df["doubled"][c] + 1, p_val_scale
 223                ] * (1 + df["ratio"][c])
 224
 225    deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df[pv] <= p_val), "top100"] = "red"
 226    deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df[pv] <= p_val), "top100"] = "blue"
 227    deg_df.loc[deg_df[pv] > p_val, "top100"] = "lightgray"
 228
 229    if lfc > 0:
 230        deg_df.loc[
 231            (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100"
 232        ] = "lightgray"
 233
 234    down_int = len(
 235        deg_df["top100"][(deg_df["log(FC)"] <= lfc * -1) & (deg_df[pv] <= p_val)]
 236    )
 237    up_int = len(deg_df["top100"][(deg_df["log(FC)"] > lfc) & (deg_df[pv] <= p_val)])
 238
 239    deg_df_up = deg_df[deg_df["log(FC)"] > 0]
 240
 241    if top_rank.upper() == "P_VALUE":
 242        deg_df_up = deg_df_up.sort_values([pv, "log(FC)"], ascending=[True, False])
 243    elif top_rank.upper() == "FC":
 244        deg_df_up = deg_df_up.sort_values(["log(FC)", pv], ascending=[False, True])
 245
 246    deg_df_up = deg_df_up.reset_index(drop=True)
 247
 248    n = -1
 249    l = 0
 250    while True:
 251        n += 1
 252        if deg_df_up["log(FC)"][n] > lfc and deg_df_up[pv][n] <= p_val:
 253            deg_df_up.loc[n, "top100"] = "green"
 254            l += 1
 255        if l == top or deg_df_up[pv][n] > p_val:
 256            break
 257
 258    deg_df_down = deg_df[deg_df["log(FC)"] <= 0]
 259
 260    if top_rank.upper() == "P_VALUE":
 261        deg_df_down = deg_df_down.sort_values([pv, "log(FC)"], ascending=[True, True])
 262    elif top_rank.upper() == "FC":
 263        deg_df_down = deg_df_down.sort_values(["log(FC)", pv], ascending=[True, True])
 264
 265    deg_df_down = deg_df_down.reset_index(drop=True)
 266
 267    n = -1
 268    l = 0
 269    while True:
 270        n += 1
 271        if deg_df_down["log(FC)"][n] < lfc * -1 and deg_df_down[pv][n] <= p_val:
 272            deg_df_down.loc[n, "top100"] = "yellow"
 273
 274            l += 1
 275        if l == top or deg_df_down[pv][n] > p_val:
 276            break
 277
 278    deg_df = pd.concat([deg_df_up, deg_df_down])
 279
 280    que = ["lightgray", "red", "blue", "yellow", "green"]
 281
 282    deg_df = deg_df.sort_values(
 283        by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)})
 284    )
 285
 286    deg_df = deg_df.reset_index(drop=True)
 287
 288    fig, ax = plt.subplots(figsize=(image_width, image_high))
 289
 290    plt.scatter(
 291        x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2
 292    )
 293
 294    tl = deg_df[p_val_scale][deg_df[pv] >= p_val]
 295
 296    if len(tl) > 0:
 297
 298        line_p = np.max(tl)
 299
 300    else:
 301        line_p = np.min(deg_df[p_val_scale])
 302
 303    plt.plot(
 304        [max(deg_df["log(FC)"]) * -1.1, max(deg_df["log(FC)"]) * 1.1],
 305        [line_p, line_p],
 306        linestyle="--",
 307        linewidth=3,
 308        color="lightgray",
 309        zorder=1,
 310    )
 311
 312    if lfc > 0:
 313        plt.plot(
 314            [lfc * -1, lfc * -1],
 315            [-3, max(deg_df[p_val_scale]) * 1.1],
 316            linestyle="--",
 317            linewidth=3,
 318            color="lightgray",
 319            zorder=1,
 320        )
 321        plt.plot(
 322            [lfc, lfc],
 323            [-3, max(deg_df[p_val_scale]) * 1.1],
 324            linestyle="--",
 325            linewidth=3,
 326            color="lightgray",
 327            zorder=1,
 328        )
 329
 330    plt.xlabel("log(FC)")
 331    plt.ylabel(p_val_scale)
 332    plt.title("Volcano plot")
 333
 334    plt.ylim(min(deg_df[p_val_scale]) - 5, max(deg_df[p_val_scale]) * 1.25)
 335
 336    texts = [
 337        ax.text(deg_df["log(FC)"][i], deg_df[p_val_scale][i], deg_df["feature"][i])
 338        for i in deg_df.index
 339        if deg_df["top100"][i] in ["green", "yellow"]
 340    ]
 341
 342    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))
 343
 344    legend_elements = [
 345        Line2D(
 346            [0],
 347            [0],
 348            marker="o",
 349            color="w",
 350            label="top-upregulated",
 351            markerfacecolor="green",
 352            markersize=10,
 353        ),
 354        Line2D(
 355            [0],
 356            [0],
 357            marker="o",
 358            color="w",
 359            label="top-downregulated",
 360            markerfacecolor="yellow",
 361            markersize=10,
 362        ),
 363        Line2D(
 364            [0],
 365            [0],
 366            marker="o",
 367            color="w",
 368            label="upregulated",
 369            markerfacecolor="blue",
 370            markersize=10,
 371        ),
 372        Line2D(
 373            [0],
 374            [0],
 375            marker="o",
 376            color="w",
 377            label="downregulated",
 378            markerfacecolor="red",
 379            markersize=10,
 380        ),
 381        Line2D(
 382            [0],
 383            [0],
 384            marker="o",
 385            color="w",
 386            label="non-significant",
 387            markerfacecolor="lightgray",
 388            markersize=10,
 389        ),
 390    ]
 391
 392    ax.legend(handles=legend_elements, loc="upper right")
 393    ax.grid(visible=False)
 394
 395    ax.annotate(
 396        f"\nmin {pv} = " + str(p_val),
 397        xy=(0.025, 0.975),
 398        xycoords="axes fraction",
 399        fontsize=12,
 400    )
 401
 402    if lfc > 0:
 403        ax.annotate(
 404            "\nmin log(FC) = " + str(lfc),
 405            xy=(0.025, 0.95),
 406            xycoords="axes fraction",
 407            fontsize=12,
 408        )
 409
 410    ax.annotate(
 411        "\nDownregulated: " + str(down_int),
 412        xy=(0.025, 0.925),
 413        xycoords="axes fraction",
 414        fontsize=12,
 415        color="red",
 416    )
 417
 418    ax.annotate(
 419        "\nUpregulated: " + str(up_int),
 420        xy=(0.025, 0.9),
 421        xycoords="axes fraction",
 422        fontsize=12,
 423        color="blue",
 424    )
 425
 426    plt.show()
 427
 428    return fig
 429
 430
 431def find_features(data: pd.DataFrame, features: list):
 432    """
 433    Identify features (rows) from a DataFrame that match a given list of features,
 434    ignoring case sensitivity.
 435
 436    Parameters
 437    ----------
 438    data : pandas.DataFrame
 439        DataFrame with features in the index (rows).
 440
 441    features : list
 442        List of feature names to search for.
 443
 444    Returns
 445    -------
 446    dict
 447        Dictionary with keys:
 448        - "included": list of features found in the DataFrame index.
 449        - "not_included": list of requested features not found in the DataFrame index.
 450        - "potential": list of features in the DataFrame that may be similar.
 451    """
 452
 453    features_upper = [str(x).upper() for x in features]
 454
 455    index_set = set(data.index)
 456
 457    features_in = [x for x in index_set if x.upper() in features_upper]
 458    features_in_upper = [x.upper() for x in features_in]
 459    features_out_upper = [x for x in features_upper if x not in features_in_upper]
 460    features_out = [x for x in features if x.upper() in features_out_upper]
 461    similar_features = [
 462        idx for idx in index_set if any(x in idx.upper() for x in features_out_upper)
 463    ]
 464
 465    return {
 466        "included": features_in,
 467        "not_included": features_out,
 468        "potential": similar_features,
 469    }
 470
 471
 472def find_names(data: pd.DataFrame, names: list):
 473    """
 474    Identify names (columns) from a DataFrame that match a given list of names,
 475    ignoring case sensitivity.
 476
 477    Parameters
 478    ----------
 479    data : pandas.DataFrame
 480        DataFrame with names in the columns.
 481
 482    names : list
 483        List of names to search for.
 484
 485    Returns
 486    -------
 487    dict
 488        Dictionary with keys:
 489        - "included": list of names found in the DataFrame columns.
 490        - "not_included": list of requested names not found in the DataFrame columns.
 491        - "potential": list of names in the DataFrame that may be similar.
 492    """
 493
 494    names_upper = [str(x).upper() for x in names]
 495
 496    columns = set(data.columns)
 497
 498    names_in = [x for x in columns if x.upper() in names_upper]
 499    names_in_upper = [x.upper() for x in names_in]
 500    names_out_upper = [x for x in names_upper if x not in names_in_upper]
 501    names_out = [x for x in names if x.upper() in names_out_upper]
 502    similar_names = [
 503        idx for idx in columns if any(x in idx.upper() for x in names_out_upper)
 504    ]
 505
 506    return {"included": names_in, "not_included": names_out, "potential": similar_names}
 507
 508
 509def reduce_data(data: pd.DataFrame, features: list = [], names: list = []):
 510    """
 511    Subset a DataFrame based on selected features (rows) and/or names (columns).
 512
 513    Parameters
 514    ----------
 515    data : pandas.DataFrame
 516        Input DataFrame with features as rows and names as columns.
 517
 518    features : list
 519        List of features to include (rows). Default is an empty list.
 520        If empty, all rows are returned.
 521
 522    names : list
 523        List of names to include (columns). Default is an empty list.
 524        If empty, all columns are returned.
 525
 526    Returns
 527    -------
 528    pandas.DataFrame
 529        Subset of the input DataFrame containing only the selected rows
 530        and/or columns.
 531
 532    Raises
 533    ------
 534    ValueError
 535        If both `features` and `names` are empty.
 536    """
 537
 538    if len(features) > 0 and len(names) > 0:
 539        fet = find_features(data=data, features=features)
 540
 541        nam = find_names(data=data, names=names)
 542
 543        data_to_return = data.loc[fet["included"], nam["included"]]
 544
 545    elif len(features) > 0 and len(names) == 0:
 546        fet = find_features(data=data, features=features)
 547
 548        data_to_return = data.loc[fet["included"], :]
 549
 550    elif len(features) == 0 and len(names) > 0:
 551
 552        nam = find_names(data=data, names=names)
 553
 554        data_to_return = data.loc[:, nam["included"]]
 555
 556    else:
 557
 558        raise ValueError("features and names have zero length!")
 559
 560    return data_to_return
 561
 562
 563def make_unique_list(lst):
 564    """
 565    Generate a list where duplicate items are renamed to ensure uniqueness.
 566
 567    Each duplicate is appended with a suffix ".n", where n indicates the
 568    occurrence count (starting from 1).
 569
 570    Parameters
 571    ----------
 572    lst : list
 573        Input list of items (strings or other hashable types).
 574
 575    Returns
 576    -------
 577    list
 578        List with unique values.
 579
 580    Examples
 581    --------
 582    >>> make_unique_list(["A", "B", "A", "A"])
 583    ['A', 'B', 'A.1', 'A.2']
 584    """
 585    seen = {}
 586    result = []
 587    for item in lst:
 588        if item not in seen:
 589            seen[item] = 0
 590            result.append(item)
 591        else:
 592            seen[item] += 1
 593            result.append(f"{item}.{seen[item]}")
 594    return result
 595
 596
 597def get_color_palette(variable_list, palette_name="tab10"):
 598    n = len(variable_list)
 599    cmap = plt.get_cmap(palette_name)
 600    colors = [cmap(i % cmap.N) for i in range(n)]
 601    return dict(zip(variable_list, colors))
 602
 603
 604def features_scatter(
 605    expression_data: pd.DataFrame,
 606    occurence_data: pd.DataFrame | None = None,
 607    scale: bool = False,
 608    features: list | None = None,
 609    metadata_list: list | None = None,
 610    colors: str = "viridis",
 611    hclust: str | None = "complete",
 612    img_width: int = 8,
 613    img_high: int = 5,
 614    label_size: int = 10,
 615    size_scale: int = 100,
 616    y_lab: str = "Genes",
 617    legend_lab: str = "log(CPM + 1)",
 618    set_box_size: float | int = 5,
 619    set_box_high: float | int = 5,
 620    bbox_to_anchor_scale: int = 25,
 621    bbox_to_anchor_perc: tuple = (0.91, 0.63),
 622    bbox_to_anchor_group: tuple = (1.01, 0.4),
 623):
 624    """
 625    Create a bubble scatter plot of selected features across samples.
 626
 627    Each point represents a feature-sample pair, where the color encodes the
 628    expression value and the size encodes occurrence or relative abundance.
 629    Optionally, hierarchical clustering can be applied to order rows and columns.
 630
 631    Parameters
 632    ----------
 633    expression_data : pandas.DataFrame
 634        Expression values (mean) with features as rows and samples as columns derived from average() function.
 635
 636    occurence_data : pandas.DataFrame or None
 637        DataFrame with occurrence/frequency values (same shape as `expression_data`) derived from occurrence() function.
 638        If None, bubble sizes are based on expression values.
 639
 640    scale: bool, default False
 641        If True, expression_data (features) will be scaled (0–1) across the colums (sample).
 642
 643    features : list or None
 644        List of features (rows) to display. If None, all features are used.
 645
 646    metadata_list : list or None, optional
 647        Metadata grouping for samples (same length as number of columns).
 648        Used to add group colors and separators in the plot.
 649
 650    colors : str, default='viridis'
 651        Colormap for expression values.
 652
 653    hclust : str or None, default='complete'
 654        Linkage method for hierarchical clustering. If None, no clustering
 655        is performed.
 656
 657    img_width : int or float, default=8
 658        Width of the plot in inches.
 659
 660    img_high : int or float, default=5
 661        Height of the plot in inches.
 662
 663    label_size : int, default=10
 664        Font size for axis labels and ticks.
 665
 666    size_scale : int or float, default=100
 667        Scaling factor for bubble sizes.
 668
 669    y_lab : str, default='Genes'
 670        Label for the x-axis.
 671
 672    legend_lab : str, default='log(CPM + 1)'
 673        Label for the colorbar legend.
 674
 675    bbox_to_anchor_scale : int, default=25
 676        Vertical scale (percentage) for positioning the colorbar.
 677
 678    bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
 679        Anchor position for the size legend (percent bubble legend).
 680
 681    bbox_to_anchor_group : tuple, default=(1.01, 0.4)
 682        Anchor position for the group legend.
 683
 684    Returns
 685    -------
 686    matplotlib.figure.Figure
 687        The generated scatter plot figure.
 688
 689    Raises
 690    ------
 691    ValueError
 692        If `metadata_list` is provided but its length does not match
 693        the number of columns in `expression_data`.
 694
 695    Notes
 696    -----
 697    - Colors represent expression values normalized to the colormap.
 698    - Bubble sizes represent occurrence values (or expression values if
 699      `occurence_data` is None).
 700    - If `metadata_list` is given, groups are indicated with colors and
 701      dashed vertical separators.
 702    """
 703
 704    scatter_df = expression_data.copy()
 705
 706    if scale:
 707
 708        legend_lab = "Scaled\n" + legend_lab
 709
 710        column_max = scatter_df.max()
 711        scatter_df = scatter_df.div(column_max).replace([np.inf, -np.inf], np.nan).fillna(0)
 712        scatter_df = pd.DataFrame(scatter_df, index=scatter_df.index, columns=scatter_df.columns)
 713
 714
 715    metadata = {}
 716
 717    metadata["primary_names"] = [str(x) for x in scatter_df.columns]
 718
 719    if metadata_list is not None:
 720        metadata["sets"] = metadata_list
 721
 722        if len(metadata["primary_names"]) != len(metadata["sets"]):
 723
 724            raise ValueError(
 725                "Metadata list and DataFrame columns must have the same length."
 726            )
 727
 728    else:
 729
 730        metadata["sets"] = [""] * len(metadata["primary_names"])
 731
 732    metadata = pd.DataFrame(metadata)
 733    if features is not None:
 734        scatter_df = scatter_df.loc[
 735            find_features(data=scatter_df, features=features)["included"],
 736        ]
 737    scatter_df.columns = metadata["primary_names"] + "#" + metadata["sets"]
 738
 739    if occurence_data is not None:
 740        if features is not None:
 741            occurence_data = occurence_data.loc[
 742                find_features(data=occurence_data, features=features)["included"],
 743            ]
 744        occurence_data.columns = metadata["primary_names"] + "#" + metadata["sets"]
 745
 746    # check duplicated names
 747
 748    tmp_columns = scatter_df.columns
 749
 750    new_cols = make_unique_list(list(tmp_columns))
 751
 752    scatter_df.columns = new_cols
 753
 754    if hclust is not None and len(expression_data.index) != 1:
 755
 756        Z = linkage(scatter_df, method=hclust)
 757
 758        # Get the order of features based on the dendrogram
 759        order_of_features = dendrogram(Z, no_plot=True)["leaves"]
 760
 761        indexes_sort = list(scatter_df.index)
 762        sorted_list_rows = []
 763        for n in order_of_features:
 764            sorted_list_rows.append(indexes_sort[n])
 765
 766        scatter_df = scatter_df.transpose()
 767
 768        Z = linkage(scatter_df, method=hclust)
 769
 770        # Get the order of features based on the dendrogram
 771        order_of_features = dendrogram(Z, no_plot=True)["leaves"]
 772
 773        indexes_sort = list(scatter_df.index)
 774        sorted_list_columns = []
 775        for n in order_of_features:
 776            sorted_list_columns.append(indexes_sort[n])
 777
 778        scatter_df = scatter_df.transpose()
 779
 780        scatter_df = scatter_df.loc[sorted_list_rows, sorted_list_columns]
 781
 782        if occurence_data is not None:
 783            occurence_data = occurence_data.loc[sorted_list_rows, sorted_list_columns]
 784
 785        metadata["sets"] = [re.sub(".*#", "", x) for x in scatter_df.columns]
 786
 787    scatter_df.columns = [re.sub("#.*", "", x) for x in scatter_df.columns]
 788
 789    if occurence_data is not None:
 790        occurence_data.columns = [re.sub("#.*", "", x) for x in occurence_data.columns]
 791
 792    fig, ax = plt.subplots(figsize=(img_width, img_high))
 793
 794    norm = plt.Normalize(0, np.max(scatter_df))
 795
 796    cmap = plt.get_cmap(colors)
 797
 798    # Bubble scatter
 799    for i, _ in enumerate(scatter_df.index):
 800        for j, _ in enumerate(scatter_df.columns):
 801            if occurence_data is not None:
 802                value_e = scatter_df.iloc[i, j]
 803                value_o = occurence_data.iloc[i, j]
 804                ax.scatter(
 805                    j,
 806                    i,
 807                    s=value_o * size_scale,
 808                    c=[cmap(norm(value_e))],
 809                    edgecolors="k",
 810                    linewidths=0.3,
 811                )
 812            else:
 813                value = scatter_df.iloc[i, j]
 814                ax.scatter(
 815                    j,
 816                    i,
 817                    s=value * size_scale,
 818                    c=[cmap(norm(value))],
 819                    edgecolors="k",
 820                    linewidths=0.3,
 821                )
 822
 823    ax.set_yticks(range(len(scatter_df.index)))
 824    ax.set_yticklabels(scatter_df.index, fontsize=label_size * 0.8)
 825    ax.set_ylabel(y_lab, fontsize=label_size)
 826    ax.set_xticks(range(len(scatter_df.columns)))
 827    ax.set_xticklabels(scatter_df.columns, fontsize=label_size * 0.8, rotation=90)
 828
 829    ax_pos = ax.get_position()
 830
 831    width_fig = 0.01
 832    height_fig = ax_pos.height * (bbox_to_anchor_scale / 100)
 833    left_fig = ax_pos.x1 + 0.01
 834    bottom_fig = ax_pos.y1 - height_fig
 835
 836    cax = fig.add_axes([left_fig, bottom_fig, width_fig, height_fig])
 837    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
 838    cb.set_label(legend_lab, fontsize=label_size * 0.65)
 839    cb.ax.tick_params(labelsize=label_size * 0.7)
 840
 841    if metadata_list is not None:
 842
 843        metadata_list = list(metadata["sets"])
 844        group_colors = get_color_palette(list(set(metadata_list)), palette_name="tab10")
 845
 846        for i, group in enumerate(metadata_list):
 847            ax.add_patch(
 848                plt.Rectangle(
 849                    (i - 0.5, len(scatter_df.index) - 0.1 * set_box_high),
 850                    1,
 851                    0.1 * set_box_size,
 852                    color=group_colors[group],
 853                    transform=ax.transData,
 854                    clip_on=False,
 855                )
 856            )
 857
 858        for i in range(1, len(metadata_list)):
 859            if metadata_list[i] != metadata_list[i - 1]:
 860                ax.axvline(i - 0.5, color="black", linestyle="--", lw=1)
 861
 862        group_patches = [
 863            mpatches.Patch(color=color, label=label)
 864            for label, color in group_colors.items()
 865        ]
 866        fig.legend(
 867            handles=group_patches,
 868            title="Group",
 869            fontsize=label_size * 0.7,
 870            title_fontsize=label_size * 0.7,
 871            loc="center left",
 872            bbox_to_anchor=bbox_to_anchor_group,
 873            frameon=False,
 874        )
 875
 876    # second legend (size)
 877    if occurence_data is not None:
 878        size_values = [0.25, 0.5, 1]
 879        legend2_handles = [
 880            plt.Line2D(
 881                [],
 882                [],
 883                marker="o",
 884                linestyle="",
 885                markersize=np.sqrt(v * size_scale * 0.5),
 886                color="gray",
 887                alpha=0.6,
 888                label=f"{v * 100:.1f}",
 889            )
 890            for v in size_values
 891        ]
 892
 893        fig.legend(
 894            handles=legend2_handles,
 895            title="Percent [%]",
 896            fontsize=label_size * 0.7,
 897            title_fontsize=label_size * 0.7,
 898            loc="center left",
 899            bbox_to_anchor=bbox_to_anchor_perc,
 900            frameon=False,
 901        )
 902
 903    _, ymax = ax.get_ylim()
 904
 905    ax.set_xlim(-0.5, len(scatter_df.columns) - 0.5)
 906    ax.set_ylim(-0.5, ymax + 0.5)
 907
 908    return fig
 909
 910
def calc_DEG(
    data,
    metadata_list: list | None = None,
    entities: str | list | dict | None = None,
    sets: str | list | dict | None = None,
    min_exp: int | float = 0,
    min_pct: int | float = 0.1,
    n_proc: int = 10,
):
    """
    Perform differential gene expression (DEG) analysis on gene expression data.

    The function compares groups of cells or samples (defined by `entities` or
    `sets`) using the Mann–Whitney U test. It computes p-values, adjusted
    p-values, fold changes, standardized effect sizes, and other statistics.

    Parameters
    ----------
    data : pandas.DataFrame
        Expression matrix with features (e.g., genes) as rows and samples/cells
        as columns.

    metadata_list : list or None, optional
        Metadata grouping corresponding to the columns in `data`. Required for
        comparisons based on sets. Default is None.

    entities : list, str, dict, or None, optional
        Defines the comparison strategy:
        - list of sample names → compare selected cells to the rest.
        - 'All' → compare each sample/cell to all others.
        - dict → user-defined groups for pairwise comparison.
        - None → must be combined with `sets`.

    sets : str, dict, or None, optional
        Defines group-based comparisons:
        - 'All' → compare each set/group to all others.
        - dict with two groups → perform pairwise set comparison.
        - None → must be combined with `entities`.

    min_exp : float | int, default=0
        Minimum expression threshold for filtering features.

    min_pct : float | int, default=0.1
        Minimum proportion of samples within the target group that must express
        a feature for it to be tested.

    n_proc : int, default=10
        Number of parallel processes to use for statistical testing.

    Returns
    -------
    pandas.DataFrame or dict
        Results of the differential expression analysis:
        - If `entities` is a list → dict with keys: 'valid' (the selected
          sample names), 'control' (the string 'rest'), and 'DEG'
          (results DataFrame).
        - If `entities == 'All'` or `sets == 'All'` → DataFrame with results
          for all groups.
        - If pairwise comparison (dict for `entities` or `sets`) → DataFrame
          with results for the specified groups.

        The results DataFrame contains:
        - 'feature': feature name
        - 'p_val': raw p-value
        - 'adj_pval': adjusted p-value (multiple testing correction)
        - 'pct_valid': fraction of target group expressing the feature
        - 'pct_ctrl': fraction of control group expressing the feature
        - 'avg_valid': mean expression in target group
        - 'avg_ctrl': mean expression in control group
        - 'sd_valid': standard deviation in target group
        - 'sd_ctrl': standard deviation in control group
        - 'esm': effect size metric
        - 'FC': fold change
        - 'log(FC)': log2-transformed fold change
        - 'norm_diff': difference in mean expression

    Raises
    ------
    ValueError
        - If `metadata_list` is provided but its length does not match
          the number of columns in `data`.
        - If neither `entities` nor `sets` is provided.

    Notes
    -----
    - Mann–Whitney U test is used for group comparisons.
    - Multiple testing correction is applied using a simple
      Benjamini–Hochberg-like method.
    - Features expressed below `min_exp` or in fewer than `min_pct` of target
      samples are filtered out.
    - Parallelization is handled by `joblib.Parallel`.

    Examples
    --------
    Compare a selected list of cells against all others:

    >>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])

    Compare each group to others (based on metadata):

    >>> result = calc_DEG(data, metadata_list=group_labels, sets="All")

    Perform pairwise comparison between two predefined sets:

    >>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
    >>> result = calc_DEG(data, sets=sets)
    """
    # Tiny constant that keeps fold-change denominators finite when no
    # positive pseudo-count can be derived from the data itself.
    offset = 1e-100

    metadata = {}

    # Column names are coerced to str so index comparisons below are uniform.
    metadata["primary_names"] = [str(x) for x in data.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        # No grouping supplied: every sample gets an empty set label.
        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)

    def stat_calc(choose, feature_name):
        # Per-feature statistics comparing 'target'-labelled rows against
        # 'rest'-labelled rows of `choose` (samples are rows here).
        target_values = choose.loc[choose["DEG"] == "target", feature_name]
        rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

        # Fraction of samples expressing the feature (value > 0) per group.
        pct_valid = (target_values > 0).sum() / len(target_values)
        pct_rest = (rest_values > 0).sum() / len(rest_values)

        avg_valid = np.mean(target_values)
        avg_ctrl = np.mean(rest_values)
        sd_valid = np.std(target_values, ddof=1)
        sd_ctrl = np.std(rest_values, ddof=1)
        # Standardized effect size: mean difference over the pooled SD
        # (Cohen's d with simple two-group pooling).
        esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

        # NOTE(review): equal group sums are treated as "no difference" to
        # skip the test. Equal sums do not strictly imply identical
        # distributions — confirm this shortcut is intended.
        if np.sum(target_values) == np.sum(rest_values):
            p_val = 1.0
        else:
            _, p_val = stats.mannwhitneyu(
                target_values, rest_values, alternative="two-sided"
            )

        return {
            "feature": feature_name,
            "p_val": p_val,
            "pct_valid": pct_valid,
            "pct_ctrl": pct_rest,
            "avg_valid": avg_valid,
            "avg_ctrl": avg_ctrl,
            "sd_valid": sd_valid,
            "sd_ctrl": sd_ctrl,
            "esm": esm,
        }

    def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc):
        # Filter features by target-group expression, run the per-feature
        # tests in parallel, and assemble the annotated results DataFrame.

        def safe_min_half(series):
            # Half of the smallest strictly positive, non-NaN value.
            # (2**-1074)*2 is twice the smallest subnormal double, i.e. the
            # threshold effectively means "greater than zero".
            filtered = series[(series > ((2**-1074)*2)) & (series.notna())]
            return filtered.min() / 2 if not filtered.empty else 0

        # Expression of the target group only, without the label column.
        tmp_dat = choose[choose["DEG"] == "target"]
        tmp_dat = tmp_dat.drop("DEG", axis=1)

        # Per-feature count of target samples expressing above `min_exp`.
        counts = (tmp_dat > min_exp).sum(axis=0)

        total_count = tmp_dat.shape[0]

        info = pd.DataFrame(
            {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
        )

        del tmp_dat

        # Features expressed in too small a fraction of the target group.
        drop_col = info["feature"][info["pct"] <= min_pct]

        # If the filter would remove every feature (the +1 accounts for the
        # 'DEG' label column), relax it to dropping only never-expressed ones.
        if len(drop_col) + 1 == len(choose.columns):
            drop_col = info["feature"][info["pct"] == 0]

        del info

        choose = choose.drop(list(drop_col), axis=1)

        # One Mann–Whitney comparison per remaining feature, in parallel.
        results = Parallel(n_jobs=n_proc)(
            delayed(stat_calc)(choose, feature)
            for feature in tqdm(choose.columns[choose.columns != "DEG"])
        )

        if len(results) > 0:
            df = pd.DataFrame(results)

            # Keep features expressed in at least one of the two groups.
            df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)]

            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # Benjamini–Hochberg-style correction: p * m / rank, capped at 1.
            # NOTE(review): the canonical BH step-up also enforces
            # monotonicity across ranks (cumulative minimum), which is
            # omitted here — confirm this simplification is intended.
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            # Pseudo-count for fold change: half the smallest positive mean
            # observed in either group, falling back to `offset` when neither
            # group yields a usable value.
            valid_factor = safe_min_half(df["avg_valid"])
            ctrl_factor = safe_min_half(df["avg_ctrl"])

            cv_factor = min(valid_factor, ctrl_factor)

            if cv_factor == 0:
                cv_factor = max(valid_factor, ctrl_factor)

            if not np.isfinite(cv_factor) or cv_factor == 0:
                cv_factor += offset

            # Substitute zero means with the pseudo-count so FC and log(FC)
            # stay finite.
            valid = df["avg_valid"].where(
                df["avg_valid"] != 0, df["avg_valid"] + cv_factor
            )
            ctrl = df["avg_ctrl"].where(
                df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor
            )

            df["FC"] = valid / ctrl

            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]

        else:
            # No testable features: return an empty, correctly shaped frame.
            columns = [
                "feature",
                "valid_group",
                "p_val",
                "adj_pval",
                "avg_valid",
                "avg_ctrl",
                "FC",
                "log(FC)",
                "norm_diff",
            ]
            df = pd.DataFrame(columns=columns)
        return df

    # Work on the transpose: samples become rows, features become columns.
    choose = data.T

    final_results = []

    if isinstance(entities, list) and sets is None:
        # Branch 1: compare an explicit list of samples against all others.
        print("\nAnalysis started...\nComparing selected cells to the whole set...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            # Index samples as "name # set" so entities can address both.
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

            if "#" not in entities[0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' list. "
                    "Only the names will be compared, without considering the set information."
                )

        labels = ["target" if idx in entities else "rest" for idx in choose.index]
        valid = list(
            set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
        )

        choose["DEG"] = labels
        # No 'drop' labels are produced in this branch; kept for symmetry
        # with the dict-based branches below.
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=valid,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        return {"valid": valid, "control": "rest", "DEG": result_df}

    elif entities == "All" and sets is None:
        # Branch 2: one-vs-rest comparison for every distinct sample label.
        print("\nAnalysis started...\nComparing each type of cell to others...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

        unique_labels = set(choose.index)

        for label in tqdm(unique_labels):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            final_results.append(result_df)

        final_results = pd.concat(final_results, ignore_index=True)

        if metadata_list is None:
            # Strip the " # " separator from group names. NOTE(review): when
            # metadata_list is None the index never contains " # ", so this
            # is effectively a safeguard no-op — confirm intent.
            final_results["valid_group"] = [
                re.sub(" # ", "", x) for x in final_results["valid_group"]
            ]

        return final_results

    elif entities is None and sets == "All":
        # Branch 3: one-vs-rest comparison for every metadata set/group.
        print("\nAnalysis started...\nComparing each set/group to others...")
        choose.index = metadata["sets"]
        unique_sets = set(choose.index)

        for label in tqdm(unique_sets):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            final_results.append(result_df)

        return pd.concat(final_results, ignore_index=True)

    elif entities is None and isinstance(sets, dict):
        # Branch 4: pairwise comparison between two user-defined sets;
        # samples outside both sets are labelled 'drop' and excluded.
        print("\nAnalysis started...\nComparing groups...")
        choose.index = metadata["sets"]

        group_list = list(sets.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in sets[group_list[0]]
                else "rest" if idx in sets[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]
        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )
        return result_df

    elif isinstance(entities, dict) and sets is None:
        # Branch 5: pairwise comparison between two user-defined sample
        # groups; samples outside both groups are labelled 'drop'.
        print("\nAnalysis started...\nComparing groups...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]
            if "#" not in entities[list(entities.keys())[0]][0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' dict. "
                    "Only the names will be compared, without considering the set information."
                )

        group_list = list(entities.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in entities[group_list[0]]
                else "rest" if idx in entities[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        return result_df.reset_index(drop=True)

    else:
        raise ValueError(
            "You must specify either 'entities' or 'sets'. None were provided, which is not allowed for this analysis."
        )
1322
1323
def average(data):
    """
    Average columns of a DataFrame that share the same name.

    Columns are grouped by their label; each group of identically named
    columns is collapsed into one column holding the row-wise mean. Columns
    with unique names pass through unchanged.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Duplicate column names are
        allowed and trigger the aggregation.

    Returns
    -------
    pandas.DataFrame
        Same rows as the input, one column per unique column name, with
        duplicate columns replaced by their mean values.
    """

    # Transpose so duplicate column names become duplicate index labels,
    # collapse them with a level-0 groupby mean, then transpose back.
    transposed = data.T
    collapsed = transposed.groupby(level=0).mean()

    return collapsed.T
1348
1349
def occurrence(data):
    """
    Calculate the occurrence frequency of features in a DataFrame.

    Converts the input DataFrame to binary (presence/absence) and computes
    the proportion of non-zero entries for each feature, aggregating by
    column names if duplicates exist.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Each column represents a feature.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input, where each value represents
        the proportion of samples in which the feature is present (non-zero).
        Columns with identical names are aggregated.
    """

    # Presence/absence indicator for every entry.
    binary_data = (data > 0).astype(int)

    # How many original columns share each (unique) column name.
    counts = binary_data.columns.value_counts()

    # Sum the indicators across duplicate columns, then turn the counts into
    # frequencies with a single label-aligned division. This replaces the
    # previous per-column Python loop of repeated `.loc` assignments with one
    # vectorized operation producing identical values.
    summed = binary_data.T.groupby(level=0).sum().T.astype(float)

    return summed.div(counts, axis=1)
1382
1383
def add_subnames(names_list: list, parent_name: str, new_clusters: list):
    """
    Append sub-cluster names to a parent name within a list of names.

    This function replaces occurrences of `parent_name` in `names_list` with
    a concatenation of the parent name and corresponding sub-cluster name
    from `new_clusters` (formatted as "parent.subcluster"). Non-matching names
    are left unchanged.

    Parameters
    ----------
    names_list : list
        Original list of names (e.g., column names or cluster labels).

    parent_name : str
        Name of the parent cluster to which sub-cluster names will be added.
        Must exist in `names_list`.

    new_clusters : list
        List of sub-cluster names. Its length must match the number of times
        `parent_name` occurs in `names_list`.

    Returns
    -------
    list
        Updated list of names with sub-cluster names appended to the parent name.

    Raises
    ------
    ValueError
        - If `parent_name` is not found in `names_list`.
        - If `new_clusters` length does not match the number of occurrences of
          `parent_name`.

    Examples
    --------
    >>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
    ['A.1', 'B', 'A.2']
    """

    # Compare as strings so numeric and string labels match uniformly.
    parent = str(parent_name)

    # Count occurrences once instead of building throwaway lists twice.
    n_occurrences = sum(1 for x in names_list if str(x) == parent)

    if n_occurrences == 0:
        # Typo fix: "dataset`s" -> "dataset's" in the error message.
        raise ValueError(
            "Parent name is missing from the original dataset's column names!"
        )

    if len(new_clusters) != n_occurrences:
        raise ValueError(
            "New cluster names list has a different length than the number of clusters in the original dataset!"
        )

    # Consume one sub-cluster name per matching occurrence, in order.
    subnames = iter(new_clusters)
    return [
        f"{parent_name}.{next(subnames)}" if str(name) == parent else name
        for name in names_list
    ]
1446
1447
def development_clust(
    data: pd.DataFrame, method: str = "ward", img_width: int = 5, img_high: int = 5
):
    """
    Hierarchically cluster the columns of a DataFrame and draw a dendrogram.

    The columns (samples) are clustered on the transposed matrix with the
    requested linkage method and rendered as a horizontal dendrogram.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and samples/columns to be clustered.

    method : str
        Method for hierarchical clustering. Options include:
       - 'ward' : minimizes the variance of clusters being merged.
       - 'single' : uses the minimum of the distances between all observations of the two sets.
       - 'complete' : uses the maximum of the distances between all observations of the two sets.
       - 'average' : uses the average of the distances between all observations of the two sets.

    img_width : int or float, default=5
        Width of the resulting figure in inches.

    img_high : int or float, default=5
        Height of the resulting figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The dendrogram figure.
    """

    # Cluster the columns: linkage operates on rows, hence the transpose.
    linkage_matrix = linkage(data.T, method=method)

    fig, axis = plt.subplots(figsize=(img_width, img_high))
    dendrogram(linkage_matrix, labels=data.columns, orientation="left", ax=axis)

    return fig
1488
1489
def adjust_cells_to_group_mean(data, data_avg, beta=0.2):
    """
    Shrink each cell's values towards its group centroid.

    For every group column in `data_avg`, all columns of `data` whose name
    starts with that group name are blended with the group mean:
    ``(1 - beta) * value + beta * centroid``.

    Parameters
    ----------
    data : pandas.DataFrame
        Original data with features as rows and cells/samples as columns.

    data_avg : pandas.DataFrame
        DataFrame of group averages (centroids) with features as rows and
        group names as columns.

    beta : float, default=0.2
        Weight for adjustment towards the group mean. 0 = no adjustment,
        1 = fully replaced by the group mean.

    Returns
    -------
    pandas.DataFrame
        Adjusted copy of `data` with the same shape; the input is untouched.
    """

    adjusted = data.copy()

    # Pre-stringify the column labels once for the prefix matching below.
    column_labels = [str(c) for c in adjusted.columns]

    for group in data_avg.columns:
        # Positional indices of the cells assigned to this group
        # (matched by column-name prefix).
        positions = [
            pos for pos, label in enumerate(column_labels) if label.startswith(group)
        ]
        if not positions:
            continue

        # Centroid as a column vector so it broadcasts across the cells.
        group_mean = data_avg.loc[adjusted.index, group].to_numpy().reshape(-1, 1)

        blended = (1 - beta) * adjusted.iloc[:, positions].to_numpy() + beta * group_mean
        adjusted.iloc[:, positions] = blended

    return adjusted
def load_sparse(path: str, name: str):
    """
    Load a 10x-style sparse dataset and return it as a dense DataFrame.

    The directory at `path` must contain three files in standard
    10x Genomics format:
      - "matrix.mtx": the gene expression matrix in Matrix Market format
      - "genes.tsv": tab-separated file containing gene identifiers
      - "barcodes.tsv": tab-separated file containing cell barcodes / names

    Parameters
    ----------
    path : str
        Path to the directory containing the matrix and annotation files.

    name : str
        Label or dataset identifier to be assigned to all cells in the metadata.

    Returns
    -------
    data : pandas.DataFrame
        Dense expression matrix with genes as rows and cells as columns.
    metadata : pandas.DataFrame
        Two-column DataFrame: "cell_names" (barcodes) and "sets" (the
        dataset label from `name`, repeated for every cell).

    Notes
    -----
    The sparse matrix is densified, which may require a large amount of
    memory for datasets with many cells and genes.
    """

    matrix = mmread(os.path.join(path, "matrix.mtx"))
    data = pd.DataFrame(matrix.todense())

    genes = pd.read_csv(os.path.join(path, "genes.tsv"), header=None, sep="\t")
    barcodes = pd.read_csv(os.path.join(path, "barcodes.tsv"), header=None, sep="\t")

    # First column of each annotation file carries the identifiers.
    data.columns = [str(barcode) for barcode in barcodes[0]]
    data.index = list(genes[0])

    cell_names = list(data.columns)
    metadata = pd.DataFrame(
        {"cell_names": cell_names, "sets": [name] * len(cell_names)}
    )

    return data, metadata

Load a sparse matrix dataset along with associated gene and cell metadata, and return it as a dense DataFrame.

This function expects the input directory to contain three files in standard 10x Genomics format:

  • "matrix.mtx": the gene expression matrix in Matrix Market format
  • "genes.tsv": tab-separated file containing gene identifiers
  • "barcodes.tsv": tab-separated file containing cell barcodes / names

Parameters

path : str Path to the directory containing the matrix and annotation files.

name : str Label or dataset identifier to be assigned to all cells in the metadata.

Returns

data : pandas.DataFrame — Dense expression matrix where rows correspond to genes and columns correspond to cells.

metadata : pandas.DataFrame — Metadata DataFrame with two columns: "cell_names" (the names of the cells / barcodes) and "sets" (the dataset label assigned to each cell, from `name`).

Notes

The function converts the sparse matrix into a dense DataFrame. This may require a large amount of memory for datasets with many cells and genes.

def volcano_plot( deg_data: pandas.core.frame.DataFrame, p_adj: bool = True, top: int = 25, top_rank: str = 'p_value', p_val: float | int = 0.05, lfc: float | int = 0.25, rescale_adj: bool = True, image_width: int = 12, image_high: int = 12):
 70def volcano_plot(
 71    deg_data: pd.DataFrame,
 72    p_adj: bool = True,
 73    top: int = 25,
 74    top_rank: str = "p_value",
 75    p_val: float | int = 0.05,
 76    lfc: float | int = 0.25,
 77    rescale_adj: bool = True,
 78    image_width: int = 12,
 79    image_high: int = 12,
 80):
 81    """
 82    Generate a volcano plot from differential expression results.
 83
 84    A volcano plot visualizes the relationship between statistical significance
 85    (p-values or standarized p-value) and log(fold change) for each gene, highlighting
 86    genes that pass significance thresholds.
 87
 88    Parameters
 89    ----------
 90    deg_data : pandas.DataFrame
 91        DataFrame containing differential expression results from calc_DEG() function.
 92
 93    p_adj : bool, default=True
 94        If True, use adjusted p-values. If False, use raw p-values.
 95
 96    top : int, default=25
 97        Number of top significant genes to highlight on the plot.
 98
 99    top_rank : str, default='p_value'
100        Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC']
101
102    p_val : float | int, default=0.05
103        Significance threshold for p-values (or adjusted p-values).
104
105    lfc : float | int, default=0.25
106        Threshold for absolute log fold change.
107
108    rescale_adj : bool, default=True
109        If True, rescale p-values to avoid long breaks caused by outlier values.
110
111    image_width : int, default=12
112        Width of the generated plot in inches.
113
114    image_high : int, default=12
115        Height of the generated plot in inches.
116
117    Returns
118    -------
119    matplotlib.figure.Figure
120        The generated volcano plot figure.
121
122    """
123
124    if top_rank.upper() not in ["FC", "P_VALUE"]:
125        raise ValueError("top_rank must be either 'FC' or 'p_value'")
126
127    if p_adj:
128        pv = "adj_pval"
129    else:
130        pv = "p_val"
131
132    deg_df = deg_data.copy()
133
134    shift = 0.25
135
136    p_val_scale = "-log(p_val)"
137
138    min_minus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)])
139    min_plus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)])
140
141    zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
142    zero_p_plus = zero_p_plus.sort_values(by="log(FC)", ascending=False).reset_index(
143        drop=True
144    )
145    zero_p_plus[pv] = [
146        (shift * x) * min_plus for x in range(1, len(zero_p_plus.index) + 1)
147    ]
148
149    zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
150    zero_p_minus = zero_p_minus.sort_values(by="log(FC)", ascending=True).reset_index(
151        drop=True
152    )
153    zero_p_minus[pv] = [
154        (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1)
155    ]
156
157    tmp_p = deg_df[
158        ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
159        | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
160    ]
161
162    del deg_df
163
164    deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)
165
166    deg_df[pv] = deg_df[pv].replace(0, 2**-1074)
167
168    deg_df[p_val_scale] = -np.log10(deg_df[pv])
169
170    deg_df["top100"] = None
171
172    if rescale_adj:
173
174        deg_df = deg_df.sort_values(by=p_val_scale, ascending=False)
175
176        deg_df = deg_df.reset_index(drop=True)
177
178        eps = 1e-300
179        doubled = []
180        ratio = []
181        for n, i in enumerate(deg_df.index):
182            for j in range(1, 6):
183                if (
184                    n + j < len(deg_df.index)
185                    and (deg_df[p_val_scale][n] + eps)
186                    / (deg_df[p_val_scale][n + j] + eps)
187                    >= 2
188                ):
189                    doubled.append(n)
190                    ratio.append(
191                        (deg_df[p_val_scale][n + j] + eps)
192                        / (deg_df[p_val_scale][n] + eps)
193                    )
194
195        df = pd.DataFrame({"doubled": doubled, "ratio": ratio})
196        df = df[df["doubled"] < 100]
197
198        df["ratio"] = (1 - df["ratio"]) / 5
199        df = df.reset_index(drop=True)
200
201        df = df.sort_values("doubled")
202
203        if len(df["doubled"]) == 1 and 0 in df["doubled"]:
204            df = df
205        else:
206            doubled2 = []
207
208            for l in df["doubled"]:
209                if l + 1 != len(doubled) and l + 1 - l == 1:
210                    doubled2.append(l)
211                    doubled2.append(l + 1)
212                else:
213                    break
214
215            doubled2 = sorted(set(doubled2), reverse=True)
216
217        if len(doubled2) > 1:
218            df = df[df["doubled"].isin(doubled2)]
219            df = df.sort_values("doubled", ascending=False)
220            df = df.reset_index(drop=True)
221            for c in df.index:
222                deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[
223                    df["doubled"][c] + 1, p_val_scale
224                ] * (1 + df["ratio"][c])
225
226    deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df[pv] <= p_val), "top100"] = "red"
227    deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df[pv] <= p_val), "top100"] = "blue"
228    deg_df.loc[deg_df[pv] > p_val, "top100"] = "lightgray"
229
230    if lfc > 0:
231        deg_df.loc[
232            (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100"
233        ] = "lightgray"
234
235    down_int = len(
236        deg_df["top100"][(deg_df["log(FC)"] <= lfc * -1) & (deg_df[pv] <= p_val)]
237    )
238    up_int = len(deg_df["top100"][(deg_df["log(FC)"] > lfc) & (deg_df[pv] <= p_val)])
239
240    deg_df_up = deg_df[deg_df["log(FC)"] > 0]
241
242    if top_rank.upper() == "P_VALUE":
243        deg_df_up = deg_df_up.sort_values([pv, "log(FC)"], ascending=[True, False])
244    elif top_rank.upper() == "FC":
245        deg_df_up = deg_df_up.sort_values(["log(FC)", pv], ascending=[False, True])
246
247    deg_df_up = deg_df_up.reset_index(drop=True)
248
249    n = -1
250    l = 0
251    while True:
252        n += 1
253        if deg_df_up["log(FC)"][n] > lfc and deg_df_up[pv][n] <= p_val:
254            deg_df_up.loc[n, "top100"] = "green"
255            l += 1
256        if l == top or deg_df_up[pv][n] > p_val:
257            break
258
259    deg_df_down = deg_df[deg_df["log(FC)"] <= 0]
260
261    if top_rank.upper() == "P_VALUE":
262        deg_df_down = deg_df_down.sort_values([pv, "log(FC)"], ascending=[True, True])
263    elif top_rank.upper() == "FC":
264        deg_df_down = deg_df_down.sort_values(["log(FC)", pv], ascending=[True, True])
265
266    deg_df_down = deg_df_down.reset_index(drop=True)
267
268    n = -1
269    l = 0
270    while True:
271        n += 1
272        if deg_df_down["log(FC)"][n] < lfc * -1 and deg_df_down[pv][n] <= p_val:
273            deg_df_down.loc[n, "top100"] = "yellow"
274
275            l += 1
276        if l == top or deg_df_down[pv][n] > p_val:
277            break
278
279    deg_df = pd.concat([deg_df_up, deg_df_down])
280
281    que = ["lightgray", "red", "blue", "yellow", "green"]
282
283    deg_df = deg_df.sort_values(
284        by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)})
285    )
286
287    deg_df = deg_df.reset_index(drop=True)
288
289    fig, ax = plt.subplots(figsize=(image_width, image_high))
290
291    plt.scatter(
292        x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2
293    )
294
295    tl = deg_df[p_val_scale][deg_df[pv] >= p_val]
296
297    if len(tl) > 0:
298
299        line_p = np.max(tl)
300
301    else:
302        line_p = np.min(deg_df[p_val_scale])
303
304    plt.plot(
305        [max(deg_df["log(FC)"]) * -1.1, max(deg_df["log(FC)"]) * 1.1],
306        [line_p, line_p],
307        linestyle="--",
308        linewidth=3,
309        color="lightgray",
310        zorder=1,
311    )
312
313    if lfc > 0:
314        plt.plot(
315            [lfc * -1, lfc * -1],
316            [-3, max(deg_df[p_val_scale]) * 1.1],
317            linestyle="--",
318            linewidth=3,
319            color="lightgray",
320            zorder=1,
321        )
322        plt.plot(
323            [lfc, lfc],
324            [-3, max(deg_df[p_val_scale]) * 1.1],
325            linestyle="--",
326            linewidth=3,
327            color="lightgray",
328            zorder=1,
329        )
330
331    plt.xlabel("log(FC)")
332    plt.ylabel(p_val_scale)
333    plt.title("Volcano plot")
334
335    plt.ylim(min(deg_df[p_val_scale]) - 5, max(deg_df[p_val_scale]) * 1.25)
336
337    texts = [
338        ax.text(deg_df["log(FC)"][i], deg_df[p_val_scale][i], deg_df["feature"][i])
339        for i in deg_df.index
340        if deg_df["top100"][i] in ["green", "yellow"]
341    ]
342
343    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))
344
345    legend_elements = [
346        Line2D(
347            [0],
348            [0],
349            marker="o",
350            color="w",
351            label="top-upregulated",
352            markerfacecolor="green",
353            markersize=10,
354        ),
355        Line2D(
356            [0],
357            [0],
358            marker="o",
359            color="w",
360            label="top-downregulated",
361            markerfacecolor="yellow",
362            markersize=10,
363        ),
364        Line2D(
365            [0],
366            [0],
367            marker="o",
368            color="w",
369            label="upregulated",
370            markerfacecolor="blue",
371            markersize=10,
372        ),
373        Line2D(
374            [0],
375            [0],
376            marker="o",
377            color="w",
378            label="downregulated",
379            markerfacecolor="red",
380            markersize=10,
381        ),
382        Line2D(
383            [0],
384            [0],
385            marker="o",
386            color="w",
387            label="non-significant",
388            markerfacecolor="lightgray",
389            markersize=10,
390        ),
391    ]
392
393    ax.legend(handles=legend_elements, loc="upper right")
394    ax.grid(visible=False)
395
396    ax.annotate(
397        f"\nmin {pv} = " + str(p_val),
398        xy=(0.025, 0.975),
399        xycoords="axes fraction",
400        fontsize=12,
401    )
402
403    if lfc > 0:
404        ax.annotate(
405            "\nmin log(FC) = " + str(lfc),
406            xy=(0.025, 0.95),
407            xycoords="axes fraction",
408            fontsize=12,
409        )
410
411    ax.annotate(
412        "\nDownregulated: " + str(down_int),
413        xy=(0.025, 0.925),
414        xycoords="axes fraction",
415        fontsize=12,
416        color="red",
417    )
418
419    ax.annotate(
420        "\nUpregulated: " + str(up_int),
421        xy=(0.025, 0.9),
422        xycoords="axes fraction",
423        fontsize=12,
424        color="blue",
425    )
426
427    plt.show()
428
429    return fig

Generate a volcano plot from differential expression results.

A volcano plot visualizes the relationship between statistical significance (p-values or standardized p-values) and log(fold change) for each gene, highlighting genes that pass significance thresholds.

Parameters

deg_data : pandas.DataFrame DataFrame containing differential expression results from calc_DEG() function.

p_adj : bool, default=True If True, use adjusted p-values. If False, use raw p-values.

top : int, default=25 Number of top significant genes to highlight on the plot.

top_rank : str, default='p_value' Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC']

p_val : float | int, default=0.05 Significance threshold for p-values (or adjusted p-values).

lfc : float | int, default=0.25 Threshold for absolute log fold change.

rescale_adj : bool, default=True If True, rescale p-values to avoid long breaks caused by outlier values.

image_width : int, default=12 Width of the generated plot in inches.

image_high : int, default=12 Height of the generated plot in inches.

Returns

matplotlib.figure.Figure The generated volcano plot figure.

def find_features(data: pandas.core.frame.DataFrame, features: list):
def find_features(data: pd.DataFrame, features: list):
    """
    Identify features (rows) from a DataFrame that match a given list of features,
    ignoring case sensitivity.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with features in the index (rows).

    features : list
        List of feature names to search for. Non-string entries are
        compared via their ``str()`` representation.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of features found in the DataFrame index.
        - "not_included": list of requested features not found in the DataFrame index.
        - "potential": list of features in the DataFrame that may be similar
          (a missing query name occurs as a substring of the index entry).
    """

    features_upper = [str(x).upper() for x in features]

    # dict.fromkeys de-duplicates while preserving the original index order,
    # so results are deterministic (iterating a set randomizes the order).
    index_entries = list(dict.fromkeys(data.index))

    features_in = [x for x in index_entries if str(x).upper() in features_upper]
    features_in_upper = [str(x).upper() for x in features_in]
    features_out_upper = [x for x in features_upper if x not in features_in_upper]
    # str() here as well: the original called x.upper() directly and crashed
    # whenever `features` contained a non-string entry.
    features_out = [x for x in features if str(x).upper() in features_out_upper]
    similar_features = [
        idx
        for idx in index_entries
        if any(x in str(idx).upper() for x in features_out_upper)
    ]

    return {
        "included": features_in,
        "not_included": features_out,
        "potential": similar_features,
    }

Identify features (rows) from a DataFrame that match a given list of features, ignoring case sensitivity.

Parameters

data : pandas.DataFrame DataFrame with features in the index (rows).

features : list List of feature names to search for.

Returns

dict Dictionary with keys: - "included": list of features found in the DataFrame index. - "not_included": list of requested features not found in the DataFrame index. - "potential": list of features in the DataFrame that may be similar.

def find_names(data: pandas.core.frame.DataFrame, names: list):
def find_names(data: pd.DataFrame, names: list):
    """
    Identify names (columns) from a DataFrame that match a given list of names,
    ignoring case sensitivity.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with names in the columns.

    names : list
        List of names to search for. Non-string entries are compared via
        their ``str()`` representation.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of names found in the DataFrame columns.
        - "not_included": list of requested names not found in the DataFrame columns.
        - "potential": list of names in the DataFrame that may be similar
          (a missing query name occurs as a substring of the column name).
    """

    names_upper = [str(x).upper() for x in names]

    # dict.fromkeys de-duplicates while preserving column order, so the
    # returned lists are deterministic (a set would randomize the order).
    column_entries = list(dict.fromkeys(data.columns))

    names_in = [x for x in column_entries if str(x).upper() in names_upper]
    names_in_upper = [str(x).upper() for x in names_in]
    names_out_upper = [x for x in names_upper if x not in names_in_upper]
    # str() here as well: the original called x.upper() directly and crashed
    # whenever `names` contained a non-string entry.
    names_out = [x for x in names if str(x).upper() in names_out_upper]
    similar_names = [
        idx
        for idx in column_entries
        if any(x in str(idx).upper() for x in names_out_upper)
    ]

    return {"included": names_in, "not_included": names_out, "potential": similar_names}

Identify names (columns) from a DataFrame that match a given list of names, ignoring case sensitivity.

Parameters

data : pandas.DataFrame DataFrame with names in the columns.

names : list List of names to search for.

Returns

dict Dictionary with keys: - "included": list of names found in the DataFrame columns. - "not_included": list of requested names not found in the DataFrame columns. - "potential": list of names in the DataFrame that may be similar.

def reduce_data( data: pandas.core.frame.DataFrame, features: list = [], names: list = []):
def reduce_data(
    data: pd.DataFrame, features: list | None = None, names: list | None = None
):
    """
    Subset a DataFrame based on selected features (rows) and/or names (columns).

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and names as columns.

    features : list or None, default None
        List of features to include (rows). If None or empty, all rows
        are returned.

    names : list or None, default None
        List of names to include (columns). If None or empty, all columns
        are returned.

    Returns
    -------
    pandas.DataFrame
        Subset of the input DataFrame containing only the selected rows
        and/or columns.

    Raises
    ------
    ValueError
        If both `features` and `names` are empty (or None).
    """

    # None defaults instead of mutable [] defaults (shared-list pitfall);
    # treat None and empty list identically, as before.
    features = features if features is not None else []
    names = names if names is not None else []

    if not features and not names:
        raise ValueError("features and names have zero length!")

    data_to_return = data

    if features:
        fet = find_features(data=data_to_return, features=features)
        data_to_return = data_to_return.loc[fet["included"], :]

    if names:
        # Row subsetting leaves columns untouched, so matching names on the
        # reduced frame is equivalent to matching them on the original.
        nam = find_names(data=data_to_return, names=names)
        data_to_return = data_to_return.loc[:, nam["included"]]

    return data_to_return

Subset a DataFrame based on selected features (rows) and/or names (columns).

Parameters

data : pandas.DataFrame Input DataFrame with features as rows and names as columns.

features : list List of features to include (rows). Default is an empty list. If empty, all rows are returned.

names : list List of names to include (columns). Default is an empty list. If empty, all columns are returned.

Returns

pandas.DataFrame Subset of the input DataFrame containing only the selected rows and/or columns.

Raises

ValueError If both features and names are empty.

def make_unique_list(lst):
def make_unique_list(lst):
    """
    Return a copy of ``lst`` in which repeated items are renamed so that
    every entry is unique.

    The first occurrence keeps its original value; the n-th repeat is
    rewritten as ``"<item>.<n>"``, with n starting at 1.

    Parameters
    ----------
    lst : list
        Input list of items (strings or other hashable types).

    Returns
    -------
    list
        List of the same length with unique values.

    Examples
    --------
    >>> make_unique_list(["A", "B", "A", "A"])
    ['A', 'B', 'A.1', 'A.2']
    """
    occurrences = {}
    renamed = []
    for element in lst:
        count = occurrences.get(element)
        if count is None:
            occurrences[element] = 0
            renamed.append(element)
        else:
            occurrences[element] = count + 1
            renamed.append(f"{element}.{count + 1}")
    return renamed

Generate a list where duplicate items are renamed to ensure uniqueness.

Each duplicate is appended with a suffix ".n", where n indicates the occurrence count (starting from 1).

Parameters

lst : list Input list of items (strings or other hashable types).

Returns

list List with unique values.

Examples

>>> make_unique_list(["A", "B", "A", "A"])
['A', 'B', 'A.1', 'A.2']
def get_color_palette(variable_list, palette_name='tab10'):
def get_color_palette(variable_list, palette_name="tab10"):
    """
    Assign a color to every entry of ``variable_list``.

    Colors are sampled from the given matplotlib colormap and recycled
    (modulo the colormap size) when there are more entries than colors.
    If the list contains duplicates, the last occurrence wins.

    Parameters
    ----------
    variable_list : list
        Items to map to colors.

    palette_name : str, default='tab10'
        Name of the matplotlib colormap to sample.

    Returns
    -------
    dict
        Mapping of each item in ``variable_list`` to an RGBA color tuple.
    """
    cmap = plt.get_cmap(palette_name)
    return {
        item: cmap(position % cmap.N) for position, item in enumerate(variable_list)
    }
def features_scatter( expression_data: pandas.core.frame.DataFrame, occurence_data: pandas.core.frame.DataFrame | None = None, scale: bool = False, features: list | None = None, metadata_list: list | None = None, colors: str = 'viridis', hclust: str | None = 'complete', img_width: int = 8, img_high: int = 5, label_size: int = 10, size_scale: int = 100, y_lab: str = 'Genes', legend_lab: str = 'log(CPM + 1)', set_box_size: float | int = 5, set_box_high: float | int = 5, bbox_to_anchor_scale: int = 25, bbox_to_anchor_perc: tuple = (0.91, 0.63), bbox_to_anchor_group: tuple = (1.01, 0.4)):
def features_scatter(
    expression_data: pd.DataFrame,
    occurence_data: pd.DataFrame | None = None,
    scale: bool = False,
    features: list | None = None,
    metadata_list: list | None = None,
    colors: str = "viridis",
    hclust: str | None = "complete",
    img_width: int = 8,
    img_high: int = 5,
    label_size: int = 10,
    size_scale: int = 100,
    y_lab: str = "Genes",
    legend_lab: str = "log(CPM + 1)",
    set_box_size: float | int = 5,
    set_box_high: float | int = 5,
    bbox_to_anchor_scale: int = 25,
    bbox_to_anchor_perc: tuple = (0.91, 0.63),
    bbox_to_anchor_group: tuple = (1.01, 0.4),
):
    """
    Create a bubble scatter plot of selected features across samples.

    Each point represents a feature-sample pair, where the color encodes the
    expression value and the size encodes occurrence or relative abundance.
    Optionally, hierarchical clustering can be applied to order rows and columns.

    Parameters
    ----------
    expression_data : pandas.DataFrame
        Expression values (mean) with features as rows and samples as columns derived from average() function.

    occurence_data : pandas.DataFrame or None
        DataFrame with occurrence/frequency values (same shape as `expression_data`) derived from occurrence() function.
        If None, bubble sizes are based on expression values.

    scale : bool, default False
        If True, expression_data (features) will be scaled (0-1) across the columns (samples).

    features : list or None
        List of features (rows) to display. If None, all features are used.

    metadata_list : list or None, optional
        Metadata grouping for samples (same length as number of columns).
        Used to add group colors and separators in the plot.

    colors : str, default='viridis'
        Colormap for expression values.

    hclust : str or None, default='complete'
        Linkage method for hierarchical clustering. If None, no clustering
        is performed.

    img_width : int or float, default=8
        Width of the plot in inches.

    img_high : int or float, default=5
        Height of the plot in inches.

    label_size : int, default=10
        Font size for axis labels and ticks.

    size_scale : int or float, default=100
        Scaling factor for bubble sizes.

    y_lab : str, default='Genes'
        Label for the y-axis.

    legend_lab : str, default='log(CPM + 1)'
        Label for the colorbar legend.

    set_box_size : float or int, default=5
        Height scale of the group color boxes (box height = 0.1 * set_box_size
        data units).

    set_box_high : float or int, default=5
        Vertical offset scale for the group color boxes above the plot.

    bbox_to_anchor_scale : int, default=25
        Vertical scale (percentage) for positioning the colorbar.

    bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
        Anchor position for the size legend (percent bubble legend).

    bbox_to_anchor_group : tuple, default=(1.01, 0.4)
        Anchor position for the group legend.

    Returns
    -------
    matplotlib.figure.Figure
        The generated scatter plot figure.

    Raises
    ------
    ValueError
        If `metadata_list` is provided but its length does not match
        the number of columns in `expression_data`.

    Notes
    -----
    - Colors represent expression values normalized to the colormap.
    - Bubble sizes represent occurrence values (or expression values if
      `occurence_data` is None).
    - If `metadata_list` is given, groups are indicated with colors and
      dashed vertical separators.
    - Neither `expression_data` nor `occurence_data` is modified in place.
    """

    scatter_df = expression_data.copy()

    if scale:

        legend_lab = "Scaled\n" + legend_lab

        # Per-column max-scaling to [0, 1]; columns whose max is 0 would
        # divide to inf/NaN, so those are mapped back to 0.
        column_max = scatter_df.max()
        scatter_df = (
            scatter_df.div(column_max).replace([np.inf, -np.inf], np.nan).fillna(0)
        )

    metadata = {}

    metadata["primary_names"] = [str(x) for x in scatter_df.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)
    if features is not None:
        scatter_df = scatter_df.loc[
            find_features(data=scatter_df, features=features)["included"],
        ]
    # Encode the group in the column name ("name#set") so it survives the
    # clustering-driven column reordering below.
    scatter_df.columns = metadata["primary_names"] + "#" + metadata["sets"]

    if occurence_data is not None:
        # Work on a copy: the original renamed the caller's DataFrame columns
        # in place as a side effect.
        occurence_data = occurence_data.copy()
        if features is not None:
            occurence_data = occurence_data.loc[
                find_features(data=occurence_data, features=features)["included"],
            ]
        occurence_data.columns = metadata["primary_names"] + "#" + metadata["sets"]

    # check duplicated names

    tmp_columns = scatter_df.columns

    new_cols = make_unique_list(list(tmp_columns))

    scatter_df.columns = new_cols

    # Guard on scatter_df (the data actually clustered), not expression_data:
    # a `features` subset can leave a single row even when the input had more.
    if hclust is not None and len(scatter_df.index) != 1:

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_rows = [indexes_sort[n] for n in order_of_features]

        # Cluster the transpose to order columns the same way.
        scatter_df = scatter_df.transpose()

        Z = linkage(scatter_df, method=hclust)

        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_columns = [indexes_sort[n] for n in order_of_features]

        scatter_df = scatter_df.transpose()

        scatter_df = scatter_df.loc[sorted_list_rows, sorted_list_columns]

        if occurence_data is not None:
            occurence_data = occurence_data.loc[sorted_list_rows, sorted_list_columns]

        # NOTE(review): columns de-duplicated by make_unique_list carry a
        # ".n" suffix after the "#set" part, which leaks into the group label
        # here for duplicated name/set pairs — confirm whether that is intended.
        metadata["sets"] = [re.sub(".*#", "", x) for x in scatter_df.columns]

    # Strip the "#set" suffix back off for display.
    scatter_df.columns = [re.sub("#.*", "", x) for x in scatter_df.columns]

    if occurence_data is not None:
        occurence_data.columns = [re.sub("#.*", "", x) for x in occurence_data.columns]

    fig, ax = plt.subplots(figsize=(img_width, img_high))

    # Anchor the color scale at 0; to_numpy() gives a scalar max regardless
    # of the pandas version's np.max(DataFrame) dispatch behavior.
    norm = plt.Normalize(0, float(scatter_df.to_numpy().max()))

    cmap = plt.get_cmap(colors)

    # Bubble scatter: color = expression, size = occurrence (or expression
    # when no occurrence data is provided).
    for i, _ in enumerate(scatter_df.index):
        for j, _ in enumerate(scatter_df.columns):
            if occurence_data is not None:
                value_e = scatter_df.iloc[i, j]
                value_o = occurence_data.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value_o * size_scale,
                    c=[cmap(norm(value_e))],
                    edgecolors="k",
                    linewidths=0.3,
                )
            else:
                value = scatter_df.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value * size_scale,
                    c=[cmap(norm(value))],
                    edgecolors="k",
                    linewidths=0.3,
                )

    ax.set_yticks(range(len(scatter_df.index)))
    ax.set_yticklabels(scatter_df.index, fontsize=label_size * 0.8)
    ax.set_ylabel(y_lab, fontsize=label_size)
    ax.set_xticks(range(len(scatter_df.columns)))
    ax.set_xticklabels(scatter_df.columns, fontsize=label_size * 0.8, rotation=90)

    # Place the colorbar just right of the axes, top-aligned.
    ax_pos = ax.get_position()

    width_fig = 0.01
    height_fig = ax_pos.height * (bbox_to_anchor_scale / 100)
    left_fig = ax_pos.x1 + 0.01
    bottom_fig = ax_pos.y1 - height_fig

    cax = fig.add_axes([left_fig, bottom_fig, width_fig, height_fig])
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb.set_label(legend_lab, fontsize=label_size * 0.65)
    cb.ax.tick_params(labelsize=label_size * 0.7)

    if metadata_list is not None:

        # Use the (possibly reordered) sets so boxes line up with columns.
        metadata_list = list(metadata["sets"])
        group_colors = get_color_palette(list(set(metadata_list)), palette_name="tab10")

        for i, group in enumerate(metadata_list):
            ax.add_patch(
                plt.Rectangle(
                    (i - 0.5, len(scatter_df.index) - 0.1 * set_box_high),
                    1,
                    0.1 * set_box_size,
                    color=group_colors[group],
                    transform=ax.transData,
                    clip_on=False,
                )
            )

        # Dashed separators at every group boundary.
        for i in range(1, len(metadata_list)):
            if metadata_list[i] != metadata_list[i - 1]:
                ax.axvline(i - 0.5, color="black", linestyle="--", lw=1)

        group_patches = [
            mpatches.Patch(color=color, label=label)
            for label, color in group_colors.items()
        ]
        fig.legend(
            handles=group_patches,
            title="Group",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_group,
            frameon=False,
        )

    # second legend (size)
    if occurence_data is not None:
        size_values = [0.25, 0.5, 1]
        legend2_handles = [
            plt.Line2D(
                [],
                [],
                marker="o",
                linestyle="",
                markersize=np.sqrt(v * size_scale * 0.5),
                color="gray",
                alpha=0.6,
                label=f"{v * 100:.1f}",
            )
            for v in size_values
        ]

        fig.legend(
            handles=legend2_handles,
            title="Percent [%]",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_perc,
            frameon=False,
        )

    _, ymax = ax.get_ylim()

    ax.set_xlim(-0.5, len(scatter_df.columns) - 0.5)
    ax.set_ylim(-0.5, ymax + 0.5)

    return fig

Create a bubble scatter plot of selected features across samples.

Each point represents a feature-sample pair, where the color encodes the expression value and the size encodes occurrence or relative abundance. Optionally, hierarchical clustering can be applied to order rows and columns.

Parameters

expression_data : pandas.DataFrame Expression values (mean) with features as rows and samples as columns derived from average() function.

occurence_data : pandas.DataFrame or None DataFrame with occurrence/frequency values (same shape as expression_data) derived from occurrence() function. If None, bubble sizes are based on expression values.

scale: bool, default False If True, expression_data (features) will be scaled (0–1) across the columns (samples).

features : list or None List of features (rows) to display. If None, all features are used.

metadata_list : list or None, optional Metadata grouping for samples (same length as number of columns). Used to add group colors and separators in the plot.

colors : str, default='viridis' Colormap for expression values.

hclust : str or None, default='complete' Linkage method for hierarchical clustering. If None, no clustering is performed.

img_width : int or float, default=8 Width of the plot in inches.

img_high : int or float, default=5 Height of the plot in inches.

label_size : int, default=10 Font size for axis labels and ticks.

size_scale : int or float, default=100 Scaling factor for bubble sizes.

y_lab : str, default='Genes' Label for the y-axis.

legend_lab : str, default='log(CPM + 1)' Label for the colorbar legend.

bbox_to_anchor_scale : int, default=25 Vertical scale (percentage) for positioning the colorbar.

bbox_to_anchor_perc : tuple, default=(0.91, 0.63) Anchor position for the size legend (percent bubble legend).

bbox_to_anchor_group : tuple, default=(1.01, 0.4) Anchor position for the group legend.

Returns

matplotlib.figure.Figure The generated scatter plot figure.

Raises

ValueError If metadata_list is provided but its length does not match the number of columns in expression_data.

Notes

  • Colors represent expression values normalized to the colormap.
  • Bubble sizes represent occurrence values (or expression values if occurence_data is None).
  • If metadata_list is given, groups are indicated with colors and dashed vertical separators.
def calc_DEG(
    data,
    metadata_list: list | None = None,
    entities: str | list | dict | None = None,
    sets: str | list | dict | None = None,
    min_exp: int | float = 0,
    min_pct: int | float = 0.1,
    n_proc: int = 10,
):
    """
    Perform differential gene expression (DEG) analysis on gene expression data.

    The function compares groups of cells or samples (defined by `entities` or
    `sets`) using the Mann–Whitney U test. It computes p-values, adjusted
    p-values, fold changes, standardized effect sizes, and other statistics.

    Parameters
    ----------
    data : pandas.DataFrame
        Expression matrix with features (e.g., genes) as rows and samples/cells
        as columns.

    metadata_list : list or None, optional
        Metadata grouping corresponding to the columns in `data`. Required for
        comparisons based on sets. Default is None.

    entities : list, str, dict, or None, optional
        Defines the comparison strategy:
        - list of sample names → compare selected cells to the rest.
        - 'All' → compare each sample/cell to all others.
        - dict → user-defined groups for pairwise comparison.
        - None → must be combined with `sets`.

    sets : str, dict, or None, optional
        Defines group-based comparisons:
        - 'All' → compare each set/group to all others.
        - dict with two groups → perform pairwise set comparison.
        - None → must be combined with `entities`.

    min_exp : float | int, default=0
        Minimum expression threshold for filtering features.

    min_pct : float | int, default=0.1
        Minimum proportion of samples within the target group that must express
        a feature for it to be tested.

    n_proc : int, default=10
        Number of parallel processes to use for statistical testing.

    Returns
    -------
    pandas.DataFrame or dict
        Results of the differential expression analysis:
        - If `entities` is a list → dict with keys: 'valid' (the selected
          sample names), 'control' (always the string 'rest'), and 'DEG'
          (results DataFrame).
        - If `entities == 'All'` or `sets == 'All'` → DataFrame with results
          for all groups.
        - If pairwise comparison (dict for `entities` or `sets`) → DataFrame
          with results for the specified groups.

        The results DataFrame contains:
        - 'feature': feature name
        - 'p_val': raw p-value
        - 'adj_pval': adjusted p-value (multiple testing correction)
        - 'pct_valid': fraction of target group expressing the feature
        - 'pct_ctrl': fraction of control group expressing the feature
        - 'avg_valid': mean expression in target group
        - 'avg_ctrl': mean expression in control group
        - 'sd_valid': standard deviation in target group
        - 'sd_ctrl': standard deviation in control group
        - 'esm': effect size metric
        - 'FC': fold change
        - 'log(FC)': log2-transformed fold change
        - 'norm_diff': difference in mean expression

    Raises
    ------
    ValueError
        - If `metadata_list` is provided but its length does not match
          the number of columns in `data`.
        - If neither `entities` nor `sets` is provided.

    Notes
    -----
    - Mann–Whitney U test is used for group comparisons.
    - Multiple testing correction is applied using a simple
      Benjamini–Hochberg-like method.
    - Features expressed below `min_exp` or in fewer than `min_pct` of target
      samples are filtered out.
    - Parallelization is handled by `joblib.Parallel`.

    Examples
    --------
    Compare a selected list of cells against all others:

    >>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])

    Compare each group to others (based on metadata):

    >>> result = calc_DEG(data, metadata_list=group_labels, sets="All")

    Perform pairwise comparison between two predefined sets:

    >>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
    >>> result = calc_DEG(data, sets=sets)
    """
    # Tiny pseudocount used as a last-resort fallback so fold-change
    # denominators can never be exactly zero.
    offset = 1e-100

    metadata = {}

    metadata["primary_names"] = [str(x) for x in data.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        # No grouping supplied: use an empty set label for every column.
        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)

    def stat_calc(choose, feature_name):
        # Per-feature statistics for the 'target' vs 'rest' split encoded
        # in the 'DEG' column of `choose` (samples as rows).
        target_values = choose.loc[choose["DEG"] == "target", feature_name]
        rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

        pct_valid = (target_values > 0).sum() / len(target_values)
        pct_rest = (rest_values > 0).sum() / len(rest_values)

        avg_valid = np.mean(target_values)
        avg_ctrl = np.mean(rest_values)
        sd_valid = np.std(target_values, ddof=1)
        sd_ctrl = np.std(rest_values, ddof=1)
        # Cohen's-d-like effect size: mean difference over pooled SD.
        esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

        # Shortcut: equal group sums are reported as p = 1.0. This mainly
        # guards the all-zero / identical case where the U test is
        # uninformative. NOTE(review): equal sums do not strictly imply
        # equal distributions — confirm this heuristic is intended.
        if np.sum(target_values) == np.sum(rest_values):
            p_val = 1.0
        else:
            _, p_val = stats.mannwhitneyu(
                target_values, rest_values, alternative="two-sided"
            )

        return {
            "feature": feature_name,
            "p_val": p_val,
            "pct_valid": pct_valid,
            "pct_ctrl": pct_rest,
            "avg_valid": avg_valid,
            "avg_ctrl": avg_ctrl,
            "sd_valid": sd_valid,
            "sd_ctrl": sd_ctrl,
            "esm": esm,
        }

    def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc):
        # Filter low-expressed features, run stat_calc in parallel, then
        # derive adjusted p-values and fold changes.

        def safe_min_half(series):
            # Half of the smallest value that is meaningfully above zero
            # ((2**-1074)*2 is twice the smallest subnormal double, i.e.
            # an "effective zero" cutoff); 0 when nothing qualifies.
            filtered = series[(series > ((2**-1074)*2)) & (series.notna())]
            return filtered.min() / 2 if not filtered.empty else 0

        tmp_dat = choose[choose["DEG"] == "target"]
        tmp_dat = tmp_dat.drop("DEG", axis=1)

        # Per-feature count of target samples expressing above min_exp.
        counts = (tmp_dat > min_exp).sum(axis=0)

        total_count = tmp_dat.shape[0]

        info = pd.DataFrame(
            {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
        )

        del tmp_dat

        drop_col = info["feature"][info["pct"] <= min_pct]

        # If the min_pct filter would drop every feature (+1 accounts for
        # the 'DEG' column), fall back to dropping only never-expressed ones.
        if len(drop_col) + 1 == len(choose.columns):
            drop_col = info["feature"][info["pct"] == 0]

        del info

        choose = choose.drop(list(drop_col), axis=1)

        results = Parallel(n_jobs=n_proc)(
            delayed(stat_calc)(choose, feature)
            for feature in tqdm(choose.columns[choose.columns != "DEG"])
        )

        if len(results) > 0:
            df = pd.DataFrame(results)

            # Keep features expressed in at least one of the two groups.
            df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)]

            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # Benjamini–Hochberg-like correction: p * m / rank, capped at 1
            # (no enforcement of monotonicity across ranks).
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            # Pseudocount for zero means so FC and log(FC) stay finite.
            valid_factor = safe_min_half(df["avg_valid"])
            ctrl_factor = safe_min_half(df["avg_ctrl"])

            cv_factor = min(valid_factor, ctrl_factor)

            if cv_factor == 0:
                cv_factor = max(valid_factor, ctrl_factor)

            if not np.isfinite(cv_factor) or cv_factor == 0:
                cv_factor += offset

            valid = df["avg_valid"].where(
                df["avg_valid"] != 0, df["avg_valid"] + cv_factor
            )
            ctrl = df["avg_ctrl"].where(
                df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor
            )

            df["FC"] = valid / ctrl

            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]

        else:
            # No testable features: return an empty, well-formed frame.
            columns = [
                "feature",
                "valid_group",
                "p_val",
                "adj_pval",
                "avg_valid",
                "avg_ctrl",
                "FC",
                "log(FC)",
                "norm_diff",
            ]
            df = pd.DataFrame(columns=columns)
        return df

    # Samples as rows from here on; the 'DEG' label column is added per branch.
    choose = data.T

    final_results = []

    if isinstance(entities, list) and sets is None:
        print("\nAnalysis started...\nComparing selected cells to the whole set...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            # Combined "name # set" index so entities can address both.
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

            if "#" not in entities[0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' list. "
                    "Only the names will be compared, without considering the set information."
                )

        labels = ["target" if idx in entities else "rest" for idx in choose.index]
        valid = list(
            set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
        )

        choose["DEG"] = labels
        # No-op here (labels are only 'target'/'rest'); kept for symmetry
        # with the dict-based branches where 'drop' labels exist.
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=valid,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        return {"valid": valid, "control": "rest", "DEG": result_df}

    elif entities == "All" and sets is None:
        print("\nAnalysis started...\nComparing each type of cell to others...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

        unique_labels = set(choose.index)

        # One one-vs-rest comparison per unique label.
        for label in tqdm(unique_labels):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            final_results.append(result_df)

        final_results = pd.concat(final_results, ignore_index=True)

        if metadata_list is None:
            final_results["valid_group"] = [
                re.sub(" # ", "", x) for x in final_results["valid_group"]
            ]

        return final_results

    elif entities is None and sets == "All":
        print("\nAnalysis started...\nComparing each set/group to others...")
        choose.index = metadata["sets"]
        unique_sets = set(choose.index)

        # One one-vs-rest comparison per unique set label.
        for label in tqdm(unique_sets):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            final_results.append(result_df)

        return pd.concat(final_results, ignore_index=True)

    elif entities is None and isinstance(sets, dict):
        print("\nAnalysis started...\nComparing groups...")
        choose.index = metadata["sets"]

        group_list = list(sets.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        # Samples in neither group are labelled 'drop' and excluded below.
        labels = [
            (
                "target"
                if idx in sets[group_list[0]]
                else "rest" if idx in sets[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]
        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )
        return result_df

    elif isinstance(entities, dict) and sets is None:
        print("\nAnalysis started...\nComparing groups...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]
            if "#" not in entities[list(entities.keys())[0]][0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' dict. "
                    "Only the names will be compared, without considering the set information."
                )

        group_list = list(entities.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        # Samples in neither group are labelled 'drop' and excluded below.
        labels = [
            (
                "target"
                if idx in entities[group_list[0]]
                else "rest" if idx in entities[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        return result_df.reset_index(drop=True)

    else:
        raise ValueError(
            "You must specify either 'entities' or 'sets'. None were provided, which is not allowed for this analysis."
        )

Perform differential gene expression (DEG) analysis on gene expression data.

The function compares groups of cells or samples (defined by entities or sets) using the Mann–Whitney U test. It computes p-values, adjusted p-values, fold changes, standardized effect sizes, and other statistics.

Parameters

data : pandas.DataFrame Expression matrix with features (e.g., genes) as rows and samples/cells as columns.

metadata_list : list or None, optional Metadata grouping corresponding to the columns in data. Required for comparisons based on sets. Default is None.

entities : list, str, dict, or None, optional Defines the comparison strategy: - list of sample names → compare selected cells to the rest. - 'All' → compare each sample/cell to all others. - dict → user-defined groups for pairwise comparison. - None → must be combined with sets.

sets : str, dict, or None, optional Defines group-based comparisons: - 'All' → compare each set/group to all others. - dict with two groups → perform pairwise set comparison. - None → must be combined with entities.

min_exp : float | int, default=0 Minimum expression threshold for filtering features.

min_pct : float | int, default=0.1 Minimum proportion of samples within the target group that must express a feature for it to be tested.

n_proc : int, default=10 Number of parallel processes to use for statistical testing.

Returns

pandas.DataFrame or dict Results of the differential expression analysis: - If entities is a list → dict with keys: 'valid', 'control', and 'DEG' (results DataFrame). - If entities == 'All' or sets == 'All' → DataFrame with results for all groups. - If pairwise comparison (dict for entities or sets) → DataFrame with results for the specified groups.

The results DataFrame contains:
- 'feature': feature name
- 'p_val': raw p-value
- 'adj_pval': adjusted p-value (multiple testing correction)
- 'pct_valid': fraction of target group expressing the feature
- 'pct_ctrl': fraction of control group expressing the feature
- 'avg_valid': mean expression in target group
- 'avg_ctrl': mean expression in control group
- 'sd_valid': standard deviation in target group
- 'sd_ctrl': standard deviation in control group
- 'esm': effect size metric
- 'FC': fold change
- 'log(FC)': log2-transformed fold change
- 'norm_diff': difference in mean expression

Raises

ValueError - If metadata_list is provided but its length does not match the number of columns in data. - If neither entities nor sets is provided.

Notes

  • Mann–Whitney U test is used for group comparisons.
  • Multiple testing correction is applied using a simple Benjamini–Hochberg-like method.
  • Features expressed below min_exp or in fewer than min_pct of target samples are filtered out.
  • Parallelization is handled by joblib.Parallel.

Examples

Compare a selected list of cells against all others:

>>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])

Compare each group to others (based on metadata):

>>> result = calc_DEG(data, metadata_list=group_labels, sets="All")

Perform pairwise comparison between two predefined sets:

>>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
>>> result = calc_DEG(data, sets=sets)
def average(data):
    """
    Compute the column-wise average of a DataFrame, aggregating by column names.

    If multiple columns share the same name, their values are averaged.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Columns with identical names
        will be aggregated by their mean.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input but with unique columns,
        where duplicate columns have been replaced by their mean values.
    """

    # Transpose so duplicate column names become duplicate index labels,
    # average rows sharing a label, then transpose back.
    return data.T.groupby(level=0).mean().T

Compute the column-wise average of a DataFrame, aggregating by column names.

If multiple columns share the same name, their values are averaged.

Parameters

data : pandas.DataFrame Input DataFrame with numeric values. Columns with identical names will be aggregated by their mean.

Returns

pandas.DataFrame DataFrame with the same rows as the input but with unique columns, where duplicate columns have been replaced by their mean values.

def occurrence(data):
    """
    Calculate the occurrence frequency of features in a DataFrame.

    Converts the input DataFrame to binary (presence/absence) and computes
    the proportion of non-zero entries for each feature, aggregating by
    column names if duplicates exist.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Each column represents a feature.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input, where each value represents
        the proportion of samples in which the feature is present (non-zero).
        Columns with identical names are aggregated.
    """

    # 0/1 presence indicator per entry.
    presence = (data > 0).astype(int)

    # Averaging the indicators over columns that share a name yields the
    # fraction of those columns in which each feature is present
    # (identical to summing per group and dividing by the group size).
    return presence.T.groupby(level=0).mean().T

Calculate the occurrence frequency of features in a DataFrame.

Converts the input DataFrame to binary (presence/absence) and computes the proportion of non-zero entries for each feature, aggregating by column names if duplicates exist.

Parameters

data : pandas.DataFrame Input DataFrame with numeric values. Each column represents a feature.

Returns

pandas.DataFrame DataFrame with the same rows as the input, where each value represents the proportion of samples in which the feature is present (non-zero). Columns with identical names are aggregated.

def add_subnames(names_list: list, parent_name: str, new_clusters: list):
    """
    Append sub-cluster names to a parent name within a list of names.

    This function replaces occurrences of `parent_name` in `names_list` with
    a concatenation of the parent name and corresponding sub-cluster name
    from `new_clusters` (formatted as "parent.subcluster"). Non-matching names
    are left unchanged.

    Parameters
    ----------
    names_list : list
        Original list of names (e.g., column names or cluster labels).

    parent_name : str
        Name of the parent cluster to which sub-cluster names will be added.
        Must exist in `names_list`.

    new_clusters : list
        List of sub-cluster names. Its length must match the number of times
        `parent_name` occurs in `names_list`.

    Returns
    -------
    list
        Updated list of names with sub-cluster names appended to the parent name.

    Raises
    ------
    ValueError
        - If `parent_name` is not found in `names_list`.
        - If `new_clusters` length does not match the number of occurrences of
          `parent_name`.

    Examples
    --------
    >>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
    ['A.1', 'B', 'A.2']
    """

    # Comparison is done on string form so e.g. int labels still match.
    parent_str = str(parent_name)
    occurrences = sum(1 for x in names_list if str(x) == parent_str)

    if occurrences == 0:
        raise ValueError(
            "Parent name is missing from the original dataset's column names!"
        )

    if len(new_clusters) != occurrences:
        raise ValueError(
            "New cluster names list has a different length than the number of clusters in the original dataset!"
        )

    new_names = []
    ixn = 0  # position in new_clusters, advanced on each parent match
    for name in names_list:
        if str(name) == parent_str:
            new_names.append(f"{parent_name}.{new_clusters[ixn]}")
            ixn += 1
        else:
            new_names.append(name)

    return new_names

Append sub-cluster names to a parent name within a list of names.

This function replaces occurrences of parent_name in names_list with a concatenation of the parent name and corresponding sub-cluster name from new_clusters (formatted as "parent.subcluster"). Non-matching names are left unchanged.

Parameters

names_list : list Original list of names (e.g., column names or cluster labels).

parent_name : str Name of the parent cluster to which sub-cluster names will be added. Must exist in names_list.

new_clusters : list List of sub-cluster names. Its length must match the number of times parent_name occurs in names_list.

Returns

list Updated list of names with sub-cluster names appended to the parent name.

Raises

ValueError - If parent_name is not found in names_list. - If new_clusters length does not match the number of occurrences of parent_name.

Examples

>>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
['A.1', 'B', 'A.2']
def development_clust(
    data: pd.DataFrame, method: str = "ward", img_width: int = 5, img_high: int = 5
):
    """
    Perform hierarchical clustering on the columns of a DataFrame and plot a dendrogram.

    Clusters the transposed data (i.e. the columns) with the selected linkage
    method and generates a horizontal dendrogram showing the relationships
    between features or samples.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and samples/columns to be clustered.

    method : str, default='ward'
        Method for hierarchical clustering. Options include:
       - 'ward' : minimizes the variance of clusters being merged.
       - 'single' : uses the minimum of the distances between all observations of the two sets.
       - 'complete' : uses the maximum of the distances between all observations of the two sets.
       - 'average' : uses the average of the distances between all observations of the two sets.

    img_width : int or float, default=5
        Width of the resulting figure in inches.

    img_high : int or float, default=5
        Height of the resulting figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The dendrogram figure.
    """

    # linkage() clusters rows, so transpose to cluster the columns.
    z = linkage(data.T, method=method)

    figure, ax = plt.subplots(figsize=(img_width, img_high))

    # orientation="left" draws the tree horizontally with labels on the y-axis.
    dendrogram(z, labels=data.columns, orientation="left", ax=ax)

    return figure

Perform hierarchical clustering on the columns of a DataFrame and plot a dendrogram.

Uses Ward's method to cluster the transposed data (columns) and generates a dendrogram showing the relationships between features or samples.

Parameters

data : pandas.DataFrame Input DataFrame with features as rows and samples/columns to be clustered.

method : str Method for hierarchical clustering. Options include:

  • 'ward' : minimizes the variance of clusters being merged.
  • 'single' : uses the minimum of the distances between all observations of the two sets.
  • 'complete' : uses the maximum of the distances between all observations of the two sets.
  • 'average' : uses the average of the distances between all observations of the two sets.

img_width : int or float, default=5 Width of the resulting figure in inches.

img_high : int or float, default=5 Height of the resulting figure in inches.

Returns

matplotlib.figure.Figure The dendrogram figure.

def adjust_cells_to_group_mean(data, data_avg, beta=0.2):
    """
    Adjust each cell's values towards the mean of its group (centroid).

    Every cell column whose name starts with a group name in `data_avg` is
    blended with that group's centroid: ``(1 - beta) * cell + beta * centroid``.

    Parameters
    ----------
    data : pandas.DataFrame
        Original data with features as rows and cells/samples as columns.

    data_avg : pandas.DataFrame
        DataFrame of group averages (centroids) with features as rows and
        group names as columns.

    beta : float, default=0.2
        Weight for adjustment towards the group mean. 0 = no adjustment,
        1 = fully replaced by the group mean.

    Returns
    -------
    pandas.DataFrame
        Adjusted data with the same shape as the input `data`.
    """

    result = data.copy()

    for group in data_avg.columns:
        # Membership is decided by column-name prefix.
        # NOTE(review): prefix matching can also capture columns of a
        # different group whose name happens to start with `group` —
        # confirm group names are prefix-free.
        member_positions = [
            pos
            for pos, col in enumerate(result.columns)
            if str(col).startswith(group)
        ]
        if not member_positions:
            continue

        # Centroid aligned to the feature order of `result`, as a column vector.
        centroid_col = data_avg.loc[result.index, group].to_numpy().reshape(-1, 1)

        blended = (1 - beta) * result.iloc[:, member_positions].to_numpy()
        blended = blended + beta * centroid_col
        result.iloc[:, member_positions] = blended

    return result

Adjust each cell's values towards the mean of its group (centroid).

This function moves each cell's values in data slightly towards the corresponding group mean in data_avg, controlled by the parameter beta.

Parameters

data : pandas.DataFrame Original data with features as rows and cells/samples as columns.

data_avg : pandas.DataFrame DataFrame of group averages (centroids) with features as rows and group names as columns.

beta : float, default=0.2 Weight for adjustment towards the group mean. 0 = no adjustment, 1 = fully replaced by the group mean.

Returns

pandas.DataFrame Adjusted data with the same shape as the input data.