jdti.jdti

   1import math
   2import os
   3import pickle
   4import re
   5
   6import harmonypy as harmonize
   7import matplotlib.pyplot as plt
   8import numpy as np
   9import pandas as pd
  10import plotly.express as px
  11import seaborn as sns
  12import umap
  13from adjustText import adjust_text
  14from joblib import Parallel, delayed
  15from matplotlib.patches import FancyArrowPatch, Patch, Polygon
  16from scipy import sparse
  17from scipy.io import mmwrite
  18from scipy.spatial import ConvexHull
  19from scipy.spatial.distance import pdist, squareform
  20from scipy.stats import norm, stats, zscore
  21from sklearn.cluster import DBSCAN, MeanShift
  22from sklearn.decomposition import PCA
  23from sklearn.metrics import pairwise_distances, silhouette_score
  24from sklearn.preprocessing import StandardScaler
  25from tqdm import tqdm
  26
  27from .utils import *
  28
  29
  30class Clustering:
  31    """
  32    A class for performing dimensionality reduction, clustering, and visualization
  33    on high-dimensional data (e.g., single-cell gene expression).
  34
  35    The class provides methods for:
  36    - Normalizing and extracting subsets of data
  37    - Principal Component Analysis (PCA) and related clustering
  38    - Uniform Manifold Approximation and Projection (UMAP) and clustering
  39    - Visualization of PCA and UMAP embeddings
  40    - Harmonization of batch effects
  41    - Accessing processed data and cluster labels
  42
  43    Methods
  44    -------
  45    add_data_frame(data, metadata)
  46        Class method to create a Clustering instance from a DataFrame and metadata.
  47
  48    harmonize_sets()
  49        Perform batch effect harmonization on PCA data.
  50
  51    perform_PCA(pc_num=100, width=8, height=6)
  52        Perform PCA on the dataset and visualize the first two PCs.
  53
  54    knee_plot_PCA(width=8, height=6)
  55        Plot the cumulative variance explained by PCs to determine optimal dimensionality.
  56
  57    find_clusters_PCA(pc_num=0, eps=0.5, min_samples=10, width=8, height=6, harmonized=False)
  58        Apply DBSCAN clustering to PCA embeddings and visualize results.
  59
  60    perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...)
  61        Compute UMAP embeddings with optional parameter tuning.
  62
  63    knee_plot_umap(eps=0.5, min_samples=10)
  64        Determine optimal UMAP dimensionality using silhouette scores.
  65
  66    find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6)
  67        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
  68
  69    UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...)
  70        Visualize UMAP embeddings with labels and optional cluster numbering.
  71
  72    UMAP_feature(feature_name, features_data=None, point_size=0.6, ...)
  73        Plot a single feature over UMAP coordinates with customizable colormap.
  74
  75    get_umap_data()
  76        Return the UMAP embeddings along with cluster labels if available.
  77
  78    get_pca_data()
  79        Return the PCA results along with cluster labels if available.
  80
  81    return_clusters(clusters='umap')
  82        Return the cluster labels for UMAP or PCA embeddings.
  83
  84    Raises
  85    ------
  86    ValueError
  87        For invalid parameters, mismatched dimensions, or missing metadata.
  88    """
  89
  90    def __init__(self, data, metadata):
  91        """
  92        Initialize the clustering class with data and optional metadata.
  93
  94        Parameters
  95        ----------
  96        data : pandas.DataFrame
  97            Input data for clustering. Columns are considered as samples.
  98
  99        metadata : pandas.DataFrame, optional
 100            Metadata for the samples. If None, a default DataFrame with column
 101            names as 'cell_names' is created.
 102
 103        Attributes
 104        ----------
 105        -clustering_data : pandas.DataFrame
 106        -clustering_metadata : pandas.DataFrame
 107        -subclusters : None or dict
 108        -explained_var : None or numpy.ndarray
 109        -cumulative_var : None or numpy.ndarray
 110        -pca : None or pandas.DataFrame
 111        -harmonized_pca : None or pandas.DataFrame
 112        -umap : None or pandas.DataFrame
 113        """
 114
 115        self.clustering_data = data
 116        """The input data used for clustering."""
 117
 118        if metadata is None:
 119            metadata = pd.DataFrame({"cell_names": list(data.columns)})
 120
 121        self.clustering_metadata = metadata
 122        """Metadata associated with the samples."""
 123
 124        self.subclusters = None
 125        """Placeholder for storing subcluster information."""
 126
 127        self.explained_var = None
 128        """Explained variance from PCA, initialized as None."""
 129
 130        self.cumulative_var = None
 131        """Cumulative explained variance from PCA, initialized as None."""
 132
 133        self.pca = None
 134        """PCA-transformed data, initialized as None."""
 135
 136        self.harmonized_pca = None
 137        """PCA data after batch effect harmonization, initialized as None."""
 138
 139        self.umap = None
 140        """UMAP embeddings, initialized as None."""
 141
 142    @classmethod
 143    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
 144        """
 145        Create a Clustering instance from a DataFrame and optional metadata.
 146
 147        Parameters
 148        ----------
 149        data : pandas.DataFrame
 150            Input data with features as rows and samples/cells as columns.
 151
 152        metadata : pandas.DataFrame or None
 153            Optional metadata for the samples.
 154            Each row corresponds to a sample/cell, and column names in this DataFrame
 155            should match the sample/cell names in `data`. Columns can contain additional
 156            information such as cell type, experimental condition, batch, sets, etc.
 157
 158        Returns
 159        -------
 160        Clustering
 161            A new instance of the Clustering class.
 162        """
 163
 164        return cls(data, metadata)
 165
 166    def harmonize_sets(self, batch_col: str = "sets"):
 167        """
 168        Perform batch effect harmonization on PCA embeddings.
 169
 170        Parameters
 171        ----------
 172        batch_col : str, default 'sets'
 173            Name of the column in `metadata` that contains batch information for the samples/cells.
 174
 175        Returns
 176        -------
 177        None
 178            Updates the `harmonized_pca` attribute with harmonized data.
 179        """
 180
 181        data_mat = np.array(self.pca)
 182
 183        metadata = self.clustering_metadata
 184
 185        self.harmonized_pca = pd.DataFrame(
 186            harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr
 187        ).T
 188
 189        self.harmonized_pca.columns = self.pca.columns
 190
 191    def perform_PCA(self, pc_num: int = 100, width=8, height=6):
 192        """
 193        Perform Principal Component Analysis (PCA) on the dataset.
 194
 195        This method standardizes the data, applies PCA, stores results as attributes,
 196        and generates a scatter plot of the first two principal components.
 197
 198        Parameters
 199        ----------
 200        pc_num : int, default 100
 201            Number of principal components to compute.
 202            If 0, computes all available components.
 203
 204        width : int or float, default 8
 205            Width of the PCA figure.
 206
 207        height : int or float, default 6
 208            Height of the PCA figure.
 209
 210        Returns
 211        -------
 212        matplotlib.figure.Figure
 213            Scatter plot showing the first two principal components.
 214
 215        Updates
 216        -------
 217        self.pca : pandas.DataFrame
 218            DataFrame with principal component scores for each sample.
 219
 220        self.explained_var : numpy.ndarray
 221            Percentage of variance explained by each principal component.
 222
 223        self.cumulative_var : numpy.ndarray
 224            Cumulative explained variance.
 225        """
 226
 227        scaler = StandardScaler()
 228        data_scaled = scaler.fit_transform(self.clustering_data.T)
 229
 230        if pc_num == 0 or pc_num > data_scaled.shape[0]:
 231            pc_num = data_scaled.shape[0]
 232
 233        pca = PCA(n_components=pc_num, random_state=42)
 234
 235        principal_components = pca.fit_transform(data_scaled)
 236
 237        pca_df = pd.DataFrame(
 238            data=principal_components,
 239            columns=["PC" + str(x + 1) for x in range(pc_num)],
 240        )
 241
 242        self.explained_var = pca.explained_variance_ratio_ * 100
 243        self.cumulative_var = np.cumsum(self.explained_var)
 244
 245        self.pca = pca_df
 246
 247        fig = plt.figure(figsize=(width, height))
 248        plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7)
 249        plt.xlabel("PC 1")
 250        plt.ylabel("PC 2")
 251        plt.grid(True)
 252        plt.show()
 253
 254        return fig
 255
 256    def knee_plot_PCA(self, width: int = 8, height: int = 6):
 257        """
 258        Plot cumulative explained variance to determine the optimal number of PCs.
 259
 260        Parameters
 261        ----------
 262        width : int, default 8
 263            Width of the figure.
 264
 265        height : int or, default 6
 266            Height of the figure.
 267
 268        Returns
 269        -------
 270        matplotlib.figure.Figure
 271            Line plot showing cumulative variance explained by each PC.
 272        """
 273
 274        fig_knee = plt.figure(figsize=(width, height))
 275        plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o")
 276        plt.xlabel("PC (n components)")
 277        plt.ylabel("Cumulative explained variance (%)")
 278        plt.grid(True)
 279
 280        xticks = [1] + list(range(5, len(self.explained_var) + 1, 5))
 281        plt.xticks(xticks, rotation=60)
 282
 283        plt.show()
 284
 285        return fig_knee
 286
 287    def find_clusters_PCA(
 288        self,
 289        pc_num: int = 2,
 290        eps: float = 0.5,
 291        min_samples: int = 10,
 292        width: int = 8,
 293        height: int = 6,
 294        harmonized: bool = False,
 295    ):
 296        """
 297        Apply DBSCAN clustering to PCA embeddings and visualize the results.
 298
 299        This method performs density-based clustering (DBSCAN) on the PCA-reduced
 300        dataset. Cluster labels are stored in the object's metadata, and a scatter
 301        plot of the first two principal components with cluster annotations is returned.
 302
 303        Parameters
 304        ----------
 305        pc_num : int, default 2
 306            Number of principal components to use for clustering.
 307            If 0, uses all available components.
 308
 309        eps : float, default 0.5
 310            Maximum distance between two points for them to be considered
 311            as neighbors (DBSCAN parameter).
 312
 313        min_samples : int, default 10
 314            Minimum number of samples required to form a cluster (DBSCAN parameter).
 315
 316        width : int, default 8
 317            Width of the output scatter plot.
 318
 319        height : int, default 6
 320            Height of the output scatter plot.
 321
 322        harmonized : bool, default False
 323            If True, use harmonized PCA data (`self.harmonized_pca`).
 324            If False, use standard PCA results (`self.pca`).
 325
 326        Returns
 327        -------
 328        matplotlib.figure.Figure
 329            Scatter plot of the first two principal components colored by
 330            cluster assignments.
 331
 332        Updates
 333        -------
 334        self.clustering_metadata['PCA_clusters'] : list
 335            Cluster labels assigned to each cell/sample.
 336
 337        self.input_metadata['PCA_clusters'] : list, optional
 338            Cluster labels stored in input metadata (if available).
 339        """
 340
 341        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
 342
 343        if pc_num == 0 and harmonized:
 344            PCA = self.harmonized_pca
 345
 346        elif pc_num == 0:
 347            PCA = self.pca
 348
 349        else:
 350            if harmonized:
 351
 352                PCA = self.harmonized_pca.iloc[:, 0:pc_num]
 353            else:
 354
 355                PCA = self.pca.iloc[:, 0:pc_num]
 356
 357        dbscan_labels = dbscan.fit_predict(PCA)
 358
 359        pca_df = pd.DataFrame(PCA)
 360        pca_df["Cluster"] = dbscan_labels
 361
 362        fig = plt.figure(figsize=(width, height))
 363
 364        for cluster_id in sorted(pca_df["Cluster"].unique()):
 365            cluster_data = pca_df[pca_df["Cluster"] == cluster_id]
 366            plt.scatter(
 367                cluster_data["PC1"],
 368                cluster_data["PC2"],
 369                label=f"Cluster {cluster_id}",
 370                alpha=0.7,
 371            )
 372
 373        plt.xlabel("PC 1")
 374        plt.ylabel("PC 2")
 375        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
 376
 377        plt.grid(True)
 378        plt.show()
 379
 380        self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
 381
 382        try:
 383            self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
 384        except:
 385            pass
 386
 387        return fig
 388
 389    def perform_UMAP(
 390        self,
 391        factorize: bool = False,
 392        umap_num: int = 100,
 393        pc_num: int = 0,
 394        harmonized: bool = False,
 395        n_neighbors: int = 5,
 396        min_dist: float | int = 0.1,
 397        spread: float | int = 1.0,
 398        set_op_mix_ratio: float | int = 1.0,
 399        local_connectivity: int = 1,
 400        repulsion_strength: float | int = 1.0,
 401        negative_sample_rate: int = 5,
 402        width: int = 8,
 403        height: int = 6,
 404    ):
 405        """
 406        Compute and visualize UMAP embeddings of the dataset.
 407
 408        This method applies Uniform Manifold Approximation and Projection (UMAP)
 409        for dimensionality reduction on either raw, PCA, or harmonized PCA data.
 410        Results are stored as a DataFrame (`self.umap`) and a scatter plot figure
 411        (`self.UMAP_plot`).
 412
 413        Parameters
 414        ----------
 415        factorize : bool, default False
 416            If True, categorical sample labels (from column names) are factorized
 417            and used as supervision in UMAP fitting.
 418
 419        umap_num : int, default 100
 420            Number of UMAP dimensions to compute. If 0, matches the input dimension.
 421
 422        pc_num : int, default 0
 423            Number of principal components to use as UMAP input.
 424            If 0, use all available components or raw data.
 425
 426        harmonized : bool, default False
 427            If True, use harmonized PCA embeddings (`self.harmonized_pca`).
 428            If False, use standard PCA or raw scaled data.
 429
 430        n_neighbors : int, default 5
 431            UMAP parameter controlling the size of the local neighborhood.
 432
 433        min_dist : float, default 0.1
 434            UMAP parameter controlling minimum allowed distance between embedded points.
 435
 436        spread : int | float, default 1.0
 437            Effective scale of embedded space (UMAP parameter).
 438
 439        set_op_mix_ratio : int | float, default 1.0
 440            Interpolation parameter between union and intersection in fuzzy sets.
 441
 442        local_connectivity : int, default 1
 443            Number of nearest neighbors assumed for each point.
 444
 445        repulsion_strength : int | float, default 1.0
 446            Weighting applied to negative samples during optimization.
 447
 448        negative_sample_rate : int, default 5
 449            Number of negative samples per positive sample in optimization.
 450
 451        width : int, default 8
 452            Width of the output scatter plot.
 453
 454        height : int, default 6
 455            Height of the output scatter plot.
 456
 457        Updates
 458        -------
 459        self.umap : pandas.DataFrame
 460            Table of UMAP embeddings with columns `UMAP1 ... UMAPn`.
 461
 462        Notes
 463        -----
 464        For supervised UMAP (`factorize=True`), categorical codes from column
 465        names of the dataset are used as labels.
 466        """
 467
 468        scaler = StandardScaler()
 469
 470        if pc_num == 0 and harmonized:
 471            data_scaled = self.harmonized_pca
 472
 473        elif pc_num == 0:
 474            data_scaled = scaler.fit_transform(self.clustering_data.T)
 475
 476        else:
 477            if harmonized:
 478
 479                data_scaled = self.harmonized_pca.iloc[:, 0:pc_num]
 480            else:
 481
 482                data_scaled = self.pca.iloc[:, 0:pc_num]
 483
 484        if umap_num == 0 or umap_num > data_scaled.shape[1]:
 485
 486            umap_num = data_scaled.shape[1]
 487
 488            reducer = umap.UMAP(
 489                n_components=len(data_scaled.T),
 490                random_state=42,
 491                n_neighbors=n_neighbors,
 492                min_dist=min_dist,
 493                spread=spread,
 494                set_op_mix_ratio=set_op_mix_ratio,
 495                local_connectivity=local_connectivity,
 496                repulsion_strength=repulsion_strength,
 497                negative_sample_rate=negative_sample_rate,
 498                n_jobs=1,
 499            )
 500
 501        else:
 502
 503            reducer = umap.UMAP(
 504                n_components=umap_num,
 505                random_state=42,
 506                n_neighbors=n_neighbors,
 507                min_dist=min_dist,
 508                spread=spread,
 509                set_op_mix_ratio=set_op_mix_ratio,
 510                local_connectivity=local_connectivity,
 511                repulsion_strength=repulsion_strength,
 512                negative_sample_rate=negative_sample_rate,
 513                n_jobs=1,
 514            )
 515
 516        if factorize:
 517            embedding = reducer.fit_transform(
 518                X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes
 519            )
 520        else:
 521            embedding = reducer.fit_transform(X=data_scaled)
 522
 523        umap_df = pd.DataFrame(
 524            embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)]
 525        )
 526
 527        plt.figure(figsize=(width, height))
 528        plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
 529        plt.xlabel("UMAP 1")
 530        plt.ylabel("UMAP 2")
 531        plt.grid(True)
 532
 533        plt.show()
 534
 535        self.umap = umap_df
 536
 537    def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):
 538        """
 539        Plot silhouette scores for different UMAP dimensions to determine optimal n_components.
 540
 541        Parameters
 542        ----------
 543        eps : float, default 0.5
 544            DBSCAN eps parameter for clustering each UMAP dimension.
 545
 546        min_samples : int, default 10
 547            Minimum number of samples to form a cluster in DBSCAN.
 548
 549        Returns
 550        -------
 551        matplotlib.figure.Figure
 552            Silhouette score plot across UMAP dimensions.
 553        """
 554
 555        umap_range = range(2, len(self.umap.T) + 1)
 556
 557        silhouette_scores = []
 558        component = []
 559        for n in umap_range:
 560
 561            db = DBSCAN(eps=eps, min_samples=min_samples)
 562            labels = db.fit_predict(np.array(self.umap)[:, :n])
 563
 564            mask = labels != -1
 565            if len(set(labels[mask])) > 1:
 566                score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask])
 567            else:
 568                score = -1
 569
 570            silhouette_scores.append(score)
 571            component.append(n)
 572
 573        fig = plt.figure(figsize=(10, 5))
 574        plt.plot(component, silhouette_scores, marker="o")
 575        plt.xlabel("UMAP (n_components)")
 576        plt.ylabel("Silhouette Score")
 577        plt.grid(True)
 578        plt.xticks(range(int(min(component)), int(max(component)) + 1, 1))
 579
 580        plt.show()
 581
 582        return fig
 583
 584    def find_clusters_UMAP(
 585        self,
 586        umap_n: int = 5,
 587        eps: float | float = 0.5,
 588        min_samples: int = 10,
 589        width: int = 8,
 590        height: int = 6,
 591    ):
 592        """
 593        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
 594
 595        This method performs density-based clustering (DBSCAN) on the UMAP-reduced
 596        dataset. Cluster labels are stored in the object's metadata, and a scatter
 597        plot of the first two UMAP components with cluster annotations is returned.
 598
 599        Parameters
 600        ----------
 601        umap_n : int, default 5
 602            Number of UMAP dimensions to use for DBSCAN clustering.
 603            Must be <= number of columns in `self.umap`.
 604
 605        eps : float | int, default 0.5
 606            Maximum neighborhood distance between two samples for them to be considered
 607            as in the same cluster (DBSCAN parameter).
 608
 609        min_samples : int, default 10
 610            Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).
 611
 612        width : int, default 8
 613            Figure width.
 614
 615        height : int, default 6
 616            Figure height.
 617
 618        Returns
 619        -------
 620        matplotlib.figure.Figure
 621            Scatter plot of the first two UMAP components colored by
 622            cluster assignments.
 623
 624        Updates
 625        -------
 626        self.clustering_metadata['UMAP_clusters'] : list
 627            Cluster labels assigned to each cell/sample.
 628
 629        self.input_metadata['UMAP_clusters'] : list, optional
 630            Cluster labels stored in input metadata (if available).
 631        """
 632
 633        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
 634        dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n])
 635
 636        umap_df = self.umap
 637        umap_df["Cluster"] = dbscan_labels
 638
 639        fig = plt.figure(figsize=(width, height))
 640
 641        for cluster_id in sorted(umap_df["Cluster"].unique()):
 642            cluster_data = umap_df[umap_df["Cluster"] == cluster_id]
 643            plt.scatter(
 644                cluster_data["UMAP1"],
 645                cluster_data["UMAP2"],
 646                label=f"Cluster {cluster_id}",
 647                alpha=0.7,
 648            )
 649
 650        plt.xlabel("UMAP 1")
 651        plt.ylabel("UMAP 2")
 652        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
 653        plt.grid(True)
 654
 655        self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
 656
 657        try:
 658            self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
 659        except:
 660            pass
 661
 662        return fig
 663
    def UMAP_vis(
        self,
        names_slot: str = "cell_names",
        set_sep: bool = True,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        legend_split_col: int = 2,
        width: int = 8,
        height: int = 6,
        inc_num: bool = True,
    ):
        """
        Visualize UMAP embeddings with sample labels based on a specific metadata slot.

        Groups points by label (and, optionally, dataset), assigns each group a
        numeric id ordered by group size, plots the groups with per-dataset
        markers, and optionally annotates each group's most central point with
        its numeric id.

        Parameters
        ----------
        names_slot : str, default 'cell_names'
            Column in metadata to use as sample labels.

        set_sep : bool, default True
            If True, separate points by dataset (metadata column 'sets').

        point_size : float, default 0.6
            Size of scatter points.

        font_size : int, default 6
            Font size for numbers on points.

        legend_split_col : int, default 2
            Number of columns in legend.

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        inc_num : bool, default True
            If True, annotate points with numeric labels.

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot figure.
        """

        # Only the first two UMAP dimensions are plotted; copy so the added
        # helper columns never touch self.umap.
        umap_df = self.umap.iloc[:, 0:2].copy()
        umap_df["names"] = list(self.clustering_metadata[names_slot])

        if set_sep:

            if "sets" in list(self.clustering_metadata.columns):
                umap_df["dataset"] = list(self.clustering_metadata["sets"])
            else:
                umap_df["dataset"] = "default"

        else:
            umap_df["dataset"] = "default"

        # Composite key: one group per (label, dataset) pair.
        umap_df["tmp_nam"] = list(umap_df["names"] + umap_df["dataset"])

        # Group size, used to order groups (largest first) for numbering.
        umap_df["count"] = umap_df["tmp_nam"].map(umap_df["tmp_nam"].value_counts())

        numeric_df = (
            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
            .drop_duplicates()
            .sort_values("count", ascending=False)
        )
        # Numeric ids 0..k-1 assigned in descending group-size order.
        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])

        umap_df = umap_df.merge(
            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
        )

        fig, ax = plt.subplots(figsize=(width, height))

        # Cycle markers per dataset so sets stay distinguishable.
        markers = ["o", "s", "^", "D", "P", "*", "X"]
        marker_map = {
            ds: markers[i % len(markers)]
            for i, ds in enumerate(umap_df["dataset"].unique())
        }

        # Center coordinates per group, appended in numeric_df iteration
        # order (this ordering is relied upon by the zip below).
        cord_list = []

        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):

            cluster_data = umap_df[umap_df["numeric_values"] == num]

            ax.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"{num} - {nam}",
                # A tmp_nam group has exactly one dataset, so .iloc[0] is safe.
                marker=marker_map[cluster_data["dataset"].iloc[0]],
                alpha=0.6,
                s=point_size,
            )

            coords = cluster_data[["UMAP1", "UMAP2"]].values

            # Medoid of the group: point minimizing the sum of distances
            # to all other points — used as the annotation anchor.
            dists = pairwise_distances(coords)

            sum_dists = dists.sum(axis=1)

            center_idx = np.argmin(sum_dists)
            center_point = coords[center_idx]

            cord_list.append(center_point)

        if inc_num:
            texts = []
            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
                texts.append(
                    ax.text(
                        x,
                        y,
                        str(num),
                        ha="center",
                        va="center",
                        fontsize=font_size,
                        color="black",
                    )
                )

            # Nudge overlapping labels apart, connecting them to their
            # anchor points with thin gray arrows.
            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.legend(
            title="Clusters",
            loc="center left",
            bbox_to_anchor=(1.05, 0.5),
            ncol=legend_split_col,
            # Enlarge legend markers relative to the tiny scatter points.
            markerscale=5,
        )

        ax.grid(True)

        plt.tight_layout()

        return fig
 805
 806    def UMAP_feature(
 807        self,
 808        feature_name: str,
 809        features_data: pd.DataFrame | None,
 810        point_size: int | float = 0.6,
 811        font_size: int | float = 6,
 812        width: int = 8,
 813        height: int = 6,
 814        palette="light",
 815    ):
 816        """
 817        Visualize UMAP embedding with expression levels of a selected feature.
 818
 819        Each point (cell) in the UMAP plot is colored according to the expression
 820        value of the chosen feature, enabling interpretation of spatial patterns
 821        of gene activity or metadata distribution in low-dimensional space.
 822
 823        Parameters
 824        ----------
 825        feature_name : str
 826           Name of the feature to plot.
 827
 828        features_data : pandas.DataFrame or None, default None
 829            If None, the function uses the DataFrame containing the clustering data.
 830            To plot features not used in clustering, provide a wider DataFrame
 831            containing the original feature values.
 832
 833        point_size : float, default 0.6
 834            Size of scatter points in the plot.
 835
 836        font_size : int, default 6
 837            Font size for axis labels and annotations.
 838
 839        width : int, default 8
 840            Width of the matplotlib figure.
 841
 842        height : int, default 6
 843            Height of the matplotlib figure.
 844
 845        palette : str, default 'light'
 846            Color palette for expression visualization. Options are:
 847            - 'light'
 848            - 'dark'
 849            - 'green'
 850            - 'gray'
 851
 852        Returns
 853        -------
 854        matplotlib.figure.Figure
 855            UMAP scatter plot colored by feature values.
 856        """
 857
 858        umap_df = self.umap.iloc[:, 0:2].copy()
 859
 860        if features_data is None:
 861
 862            features_data = self.clustering_data
 863
 864        if features_data.shape[1] != umap_df.shape[0]:
 865            raise ValueError(
 866                "Imputed 'features_data' shape does not match the number of UMAP cells"
 867            )
 868
 869        blist = [
 870            True if x.upper() == feature_name.upper() else False
 871            for x in features_data.index
 872        ]
 873
 874        if not any(blist):
 875            raise ValueError("Imputed feature_name is not included in the data")
 876
 877        umap_df.loc[:, "feature"] = (
 878            features_data.loc[blist, :]
 879            .apply(lambda row: row.tolist(), axis=1)
 880            .values[0]
 881        )
 882
 883        umap_df = umap_df.sort_values("feature", ascending=True)
 884
 885        import matplotlib.colors as mcolors
 886
 887        if palette == "light":
 888            palette = px.colors.sequential.Sunsetdark
 889
 890        elif palette == "dark":
 891            palette = px.colors.sequential.thermal
 892
 893        elif palette == "green":
 894            palette = px.colors.sequential.Aggrnyl
 895
 896        elif palette == "gray":
 897            palette = px.colors.sequential.gray
 898            palette = palette[::-1]
 899
 900        else:
 901            raise ValueError(
 902                'Palette not found. Use: "light", "dark", "gray", or "green"'
 903            )
 904
 905        converted = []
 906        for c in palette:
 907            rgb_255 = px.colors.unlabel_rgb(c)
 908            rgb_01 = tuple(v / 255.0 for v in rgb_255)
 909            converted.append(rgb_01)
 910
 911        my_cmap = mcolors.ListedColormap(converted, name="custom")
 912
 913        fig, ax = plt.subplots(figsize=(width, height))
 914        sc = ax.scatter(
 915            umap_df["UMAP1"],
 916            umap_df["UMAP2"],
 917            c=umap_df["feature"],
 918            s=point_size,
 919            cmap=my_cmap,
 920            alpha=1.0,
 921            edgecolors="black",
 922            linewidths=0.1,
 923        )
 924
 925        cbar = plt.colorbar(sc, ax=ax)
 926        cbar.set_label(f"{feature_name}")
 927
 928        ax.set_xlabel("UMAP 1")
 929        ax.set_ylabel("UMAP 2")
 930
 931        ax.grid(True)
 932
 933        plt.tight_layout()
 934
 935        return fig
 936
 937    def get_umap_data(self):
 938        """
 939        Retrieve UMAP embedding data with optional cluster labels.
 940
 941        Returns the UMAP coordinates stored in `self.umap`. If clustering
 942        metadata is available (specifically `UMAP_clusters`), the corresponding
 943        cluster assignments are appended as an additional column.
 944
 945        Returns
 946        -------
 947        pandas.DataFrame
 948            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
 949            If available, includes an extra column 'clusters' with cluster labels.
 950
 951        Notes
 952        -----
 953        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
 954        - Cluster labels are added only if present in `self.clustering_metadata`.
 955        """
 956
 957        umap_data = self.umap
 958
 959        try:
 960            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
 961        except:
 962            pass
 963
 964        return umap_data
 965
 966    def get_pca_data(self):
 967        """
 968        Retrieve PCA embedding data with optional cluster labels.
 969
 970        Returns the principal component scores stored in `self.pca`. If clustering
 971        metadata is available (specifically `PCA_clusters`), the corresponding
 972        cluster assignments are appended as an additional column.
 973
 974        Returns
 975        -------
 976        pandas.DataFrame
 977            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
 978            If available, includes an extra column 'clusters' with cluster labels.
 979
 980        Notes
 981        -----
 982        - PCA must be computed beforehand (e.g., using `perform_PCA`).
 983        - Cluster labels are added only if present in `self.clustering_metadata`.
 984        """
 985
 986        pca_data = self.pca
 987
 988        try:
 989            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
 990        except:
 991            pass
 992
 993        return pca_data
 994
 995    def return_clusters(self, clusters="umap"):
 996        """
 997        Retrieve cluster labels from UMAP or PCA clustering results.
 998
 999        Parameters
1000        ----------
1001        clusters : str, default 'umap'
1002            Source of cluster labels to return. Must be one of:
1003            - 'umap': return cluster labels from UMAP embeddings.
1004            - 'pca' : return cluster labels from PCA embeddings.
1005
1006        Returns
1007        -------
1008        list
1009            Cluster labels corresponding to the selected embedding method.
1010
1011        Raises
1012        ------
1013        ValueError
1014            If `clusters` is not 'umap' or 'pca'.
1015
1016        Notes
1017        -----
1018        Requires that clustering has already been performed
1019        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
1020        """
1021
1022        if clusters.lower() == "umap":
1023            clusters_vector = self.clustering_metadata["UMAP_clusters"]
1024        elif clusters.lower() == "pca":
1025            clusters_vector = self.clustering_metadata["PCA_clusters"]
1026        else:
1027            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")
1028
1029        return clusters_vector
1030
1031
1032class COMPsc(Clustering):
1033    """
1034    A class `COMPsc` (Comparison of single-cell data) designed for the integration,
1035    analysis, and visualization of single-cell datasets.
1036    The class supports independent dataset integration, subclustering of existing clusters,
1037    marker detection, and multiple visualization strategies.
1038
1039    The COMPsc class provides methods for:
1040
1041        - Normalizing and filtering single-cell data
1042        - Loading and saving sparse 10x-style datasets
1043        - Computing differential expression and marker genes
1044        - Clustering and subclustering analysis
1045        - Visualizing similarity and spatial relationships
1046        - Aggregating data by cell and set annotations
1047        - Managing metadata and renaming labels
1048        - Plotting gene detection histograms and feature scatters
1049
1050    Methods
1051    -------
1052    project_dir(path_to_directory, project_list)
1053        Scans a directory to create a COMPsc instance mapping project names to their paths.
1054
1055    save_project(name, path=os.getcwd())
1056        Saves the COMPsc object to a pickle file on disk.
1057
1058    load_project(path)
1059        Loads a previously saved COMPsc object from a pickle file.
1060
1061    reduce_cols(reg, inc_set=False)
1062        Removes columns from data tables where column names contain a specified name or partial substring.
1063
    reduce_rows(features_list)
        Removes rows from data tables whose row (feature/gene) names are included in the provided list.
1066
1067    get_data(set_info=False)
1068        Returns normalized data with optional set annotations in column names.
1069
1070    get_metadata()
1071        Returns the stored input metadata.
1072
1073    get_partial_data(names=None, features=None, name_slot='cell_names')
1074        Return a subset of the data by sample names and/or features.
1075
1076    gene_calculation()
1077        Calculates and stores per-cell gene detection counts as a pandas Series.
1078
1079    gene_histograme(bins=100)
1080        Plots a histogram of genes detected per cell with an overlaid normal distribution.
1081
1082    gene_threshold(min_n=None, max_n=None)
1083        Filters cells based on minimum and/or maximum gene detection thresholds.
1084
1085    load_sparse_from_projects(normalized_data=False)
1086        Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.
1087
1088    rename_names(mapping, slot='cell_names')
1089        Renames entries in a specified metadata column using a provided mapping dictionary.
1090
1091    rename_subclusters(mapping)
1092        Renames subcluster labels using a provided mapping dictionary.
1093
1094    save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized')
1095        Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
1096
1097    normalize_data(normalize=True, normalize_factor=100000)
1098        Normalizes raw counts to counts-per-specified factor (e.g., CPM-like).
1099
1100    statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10)
1101        Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.
1102
1103    calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False)
1104        Computes and caches differential markers using the statistic method.
1105
1106    clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4)
1107        Prepares clustering input by selecting marker features and optionally smoothing cell values.
1108
1109    average()
1110        Aggregates normalized data by averaging across (cell_name, set) pairs.
1111
1112    estimating_similarity(method='pearson', p_val=0.05, top_n=25)
1113        Computes pairwise correlation and Euclidean distance between aggregated samples.
1114
1115    similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10)
1116        Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.
1117
1118    spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...)
1119        Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.
1120
1121    subcluster_prepare(features, cluster)
1122        Initializes a Clustering object for subcluster analysis on a selected parent cluster.
1123
1124    define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...)
1125        Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels.
1126
1127    subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...)
1128        Visualizes averaged expression and occurrence of features for subclusters as a scatter plot.
1129
1130    subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...)
1131        Plots top differential features for subclusters as a features-scatter visualization.
1132
1133    accept_subclusters()
1134        Commits subcluster labels to main metadata by renaming cell names and clears subcluster data.
1135
1136    Raises
1137    ------
1138    ValueError
1139        For invalid parameters, mismatched dimensions, or missing metadata.
1140
1141    """
1142
1143    def __init__(
1144        self,
1145        objects=None,
1146    ):
1147        """
1148        Initialize the COMPsc class for single-cell data integration and analysis.
1149
1150        Parameters
1151        ----------
1152        objects : list or None, optional
1153            Optional list of data objects to initialize the instance with.
1154
1155        Attributes
1156        ----------
1157        -objects : list or None
1158        -input_data : pandas.DataFrame or None
1159        -input_metadata : pandas.DataFrame or None
1160        -normalized_data : pandas.DataFrame or None
1161        -agg_metadata : pandas.DataFrame or None
1162        -agg_normalized_data : pandas.DataFrame or None
1163        -similarity : pandas.DataFrame or None
1164        -var_data : pandas.DataFrame or None
1165        -subclusters_ : instance of Clustering class or None
1166        -cells_calc : pandas.Series or None
1167        -gene_calc : pandas.Series or None
1168        -composition_data : pandas.DataFrame or None
1169        """
1170
1171        self.objects = objects
1172        """ Stores the input data objects."""
1173
1174        self.input_data = None
1175        """Raw input data for clustering or integration analysis."""
1176
1177        self.input_metadata = None
1178        """Metadata associated with the input data."""
1179
1180        self.normalized_data = None
1181        """Normalized version of the input data."""
1182
1183        self.agg_metadata = None
1184        '''Aggregated metadata for all sets in object related to "agg_normalized_data"'''
1185
1186        self.agg_normalized_data = None
1187        """Aggregated and normalized data across multiple sets."""
1188
1189        self.similarity = None
1190        """Similarity data between cells across all samples. and sets"""
1191
1192        self.var_data = None
1193        """DEG analysis results summarizing variance across all samples in the object."""
1194
1195        self.subclusters_ = None
1196        """Placeholder for information about subclusters analysis; if computed."""
1197
1198        self.cells_calc = None
1199        """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition."""
1200
1201        self.gene_calc = None
1202        """Number of genes detected per sample (cell), reflecting the sequencing depth."""
1203
1204        self.composition_data = None
1205        """Data describing composition of cells across clusters or sets."""
1206
1207    @classmethod
1208    def project_dir(cls, path_to_directory, project_list):
1209        """
1210        Scan a directory and build a COMPsc instance mapping provided project names
1211        to their paths.
1212
1213        Parameters
1214        ----------
1215        path_to_directory : str
1216            Path containing project subfolders.
1217
1218        project_list : list[str]
1219            List of filenames (folder names) to include in the returned object map.
1220
1221        Returns
1222        -------
1223        COMPsc
1224            New COMPsc instance with `objects` populated.
1225
1226        Raises
1227        ------
1228        Exception
1229            A generic exception is caught and a message printed if scanning fails.
1230
1231        Notes
1232        -----
1233        Function attempts to match entries in `project_list` to directory
1234        names and constructs a simplified object key from the folder name.
1235        """
1236        try:
1237            objects = {}
1238            for filename in tqdm(os.listdir(path_to_directory)):
1239                for c in project_list:
1240                    f = os.path.join(path_to_directory, filename)
1241                    if c == filename and os.path.isdir(f):
1242                        objects[str(c)] = f
1243
1244            return cls(objects)
1245
1246        except:
1247            print("Something went wrong. Check the function input data and try again!")
1248
1249    def save_project(self, name, path: str = os.getcwd()):
1250        """
1251        Save the COMPsc object to disk using pickle.
1252
1253        Parameters
1254        ----------
1255        name : str
1256            Base filename (without extension) to use when saving.
1257
1258        path : str, default os.getcwd()
1259            Directory in which to save the project file.
1260
1261        Returns
1262        -------
1263        None
1264
1265        Side Effects
1266        ------------
1267        - Writes a file `<path>/<name>.jpkl` containing the pickled object.
1268        - Prints a confirmation message with saved path.
1269        """
1270
1271        full = os.path.join(path, f"{name}.jpkl")
1272
1273        with open(full, "wb") as f:
1274            pickle.dump(self, f)
1275
1276        print(f"Project saved as {full}")
1277
1278    @classmethod
1279    def load_project(cls, path):
1280        """
1281        Load a previously saved COMPsc project from a pickle file.
1282
1283        Parameters
1284        ----------
1285        path : str
1286            Full path to the pickled project file.
1287
1288        Returns
1289        -------
1290        COMPsc
1291            The unpickled COMPsc object.
1292
1293        Raises
1294        ------
1295        FileNotFoundError
1296            If the provided path does not exist.
1297        """
1298
1299        if not os.path.exists(path):
1300            raise FileNotFoundError("File does not exist!")
1301        with open(path, "rb") as f:
1302            obj = pickle.load(f)
1303        return obj
1304
1305    def reduce_cols(
1306        self,
1307        reg: str | None = None,
1308        full: str | None = None,
1309        name_slot: str = "cell_names",
1310        inc_set: bool = False,
1311    ):
1312        """
1313        Remove columns (cells) whose names contain a substring `reg` or
1314        full name `full` from available tables.
1315
1316        Parameters
1317        ----------
1318        reg : str | None
1319            Substring to search for in column/cell names; matching columns will be removed.
1320            If not None, `full` must be None.
1321
1322        full : str | None
1323            Full name to search for in column/cell names; matching columns will be removed.
1324            If not None, `reg` must be None.
1325
1326        name_slot : str, default 'cell_names'
1327            Column in metadata to use as sample names.
1328
1329        inc_set : bool, default False
1330            If True, column names are interpreted as 'cell_name # set' when matching.
1331
1332        Update
1333        ------------
1334        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
1335        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
1336        removing columns/rows that match `reg`.
1337
1338        Raises
1339        ------
1340        Raises ValueError if nothing matches the reduction mask.
1341        """
1342
1343        if reg is None and full is None:
1344            raise ValueError(
1345                "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!"
1346            )
1347
1348        if reg is not None and full is not None:
1349            raise ValueError(
1350                "Both 'reg' and 'full' arguments are provided. "
1351                "Please provide only one of them!\n"
1352                "'reg' is used when only part of the name must be detected.\n"
1353                "'full' is used if the full name must be detected."
1354            )
1355
1356        if reg is not None:
1357
1358            if self.input_data is not None:
1359
1360                if inc_set:
1361
1362                    self.input_data.columns = (
1363                        self.input_metadata[name_slot]
1364                        + " # "
1365                        + self.input_metadata["sets"]
1366                    )
1367
1368                else:
1369
1370                    self.input_data.columns = self.input_metadata[name_slot]
1371
1372                mask = [reg.upper() not in x.upper() for x in self.input_data.columns]
1373
1374                if len([y for y in mask if y is False]) == 0:
1375                    raise ValueError("Nothing found to reduce")
1376
1377                self.input_data = self.input_data.loc[:, mask]
1378
1379            if self.normalized_data is not None:
1380
1381                if inc_set:
1382
1383                    self.normalized_data.columns = (
1384                        self.input_metadata[name_slot]
1385                        + " # "
1386                        + self.input_metadata["sets"]
1387                    )
1388
1389                else:
1390
1391                    self.normalized_data.columns = self.input_metadata[name_slot]
1392
1393                mask = [
1394                    reg.upper() not in x.upper() for x in self.normalized_data.columns
1395                ]
1396
1397                if len([y for y in mask if y is False]) == 0:
1398                    raise ValueError("Nothing found to reduce")
1399
1400                self.normalized_data = self.normalized_data.loc[:, mask]
1401
1402            if self.input_metadata is not None:
1403
1404                if inc_set:
1405
1406                    self.input_metadata["drop"] = (
1407                        self.input_metadata[name_slot]
1408                        + " # "
1409                        + self.input_metadata["sets"]
1410                    )
1411
1412                else:
1413
1414                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1415
1416                mask = [
1417                    reg.upper() not in x.upper() for x in self.input_metadata["drop"]
1418                ]
1419
1420                if len([y for y in mask if y is False]) == 0:
1421                    raise ValueError("Nothing found to reduce")
1422
1423                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1424                    drop=True
1425                )
1426
1427                self.input_metadata = self.input_metadata.drop(
1428                    columns=["drop"], errors="ignore"
1429                )
1430
1431            if self.agg_normalized_data is not None:
1432
1433                if inc_set:
1434
1435                    self.agg_normalized_data.columns = (
1436                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1437                    )
1438
1439                else:
1440
1441                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1442
1443                mask = [
1444                    reg.upper() not in x.upper()
1445                    for x in self.agg_normalized_data.columns
1446                ]
1447
1448                if len([y for y in mask if y is False]) == 0:
1449                    raise ValueError("Nothing found to reduce")
1450
1451                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1452
1453            if self.agg_metadata is not None:
1454
1455                if inc_set:
1456
1457                    self.agg_metadata["drop"] = (
1458                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1459                    )
1460
1461                else:
1462
1463                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1464
1465                mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]]
1466
1467                if len([y for y in mask if y is False]) == 0:
1468                    raise ValueError("Nothing found to reduce")
1469
1470                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1471                    drop=True
1472                )
1473
1474                self.agg_metadata = self.agg_metadata.drop(
1475                    columns=["drop"], errors="ignore"
1476                )
1477
1478        elif full is not None:
1479
1480            if self.input_data is not None:
1481
1482                if inc_set:
1483
1484                    self.input_data.columns = (
1485                        self.input_metadata[name_slot]
1486                        + " # "
1487                        + self.input_metadata["sets"]
1488                    )
1489
1490                    if "#" not in full:
1491
1492                        self.input_data.columns = self.input_metadata[name_slot]
1493
1494                        print(
1495                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1496                            "Only the names will be compared, without considering the set information."
1497                        )
1498
1499                else:
1500
1501                    self.input_data.columns = self.input_metadata[name_slot]
1502
1503                mask = [full.upper() != x.upper() for x in self.input_data.columns]
1504
1505                if len([y for y in mask if y is False]) == 0:
1506                    raise ValueError("Nothing found to reduce")
1507
1508                self.input_data = self.input_data.loc[:, mask]
1509
1510            if self.normalized_data is not None:
1511
1512                if inc_set:
1513
1514                    self.normalized_data.columns = (
1515                        self.input_metadata[name_slot]
1516                        + " # "
1517                        + self.input_metadata["sets"]
1518                    )
1519
1520                    if "#" not in full:
1521
1522                        self.normalized_data.columns = self.input_metadata[name_slot]
1523
1524                        print(
1525                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1526                            "Only the names will be compared, without considering the set information."
1527                        )
1528
1529                else:
1530
1531                    self.normalized_data.columns = self.input_metadata[name_slot]
1532
1533                mask = [full.upper() != x.upper() for x in self.normalized_data.columns]
1534
1535                if len([y for y in mask if y is False]) == 0:
1536                    raise ValueError("Nothing found to reduce")
1537
1538                self.normalized_data = self.normalized_data.loc[:, mask]
1539
1540            if self.input_metadata is not None:
1541
1542                if inc_set:
1543
1544                    self.input_metadata["drop"] = (
1545                        self.input_metadata[name_slot]
1546                        + " # "
1547                        + self.input_metadata["sets"]
1548                    )
1549
1550                    if "#" not in full:
1551
1552                        self.input_metadata["drop"] = self.input_metadata[name_slot]
1553
1554                        print(
1555                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1556                            "Only the names will be compared, without considering the set information."
1557                        )
1558
1559                else:
1560
1561                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1562
1563                mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]]
1564
1565                if len([y for y in mask if y is False]) == 0:
1566                    raise ValueError("Nothing found to reduce")
1567
1568                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1569                    drop=True
1570                )
1571
1572                self.input_metadata = self.input_metadata.drop(
1573                    columns=["drop"], errors="ignore"
1574                )
1575
1576            if self.agg_normalized_data is not None:
1577
1578                if inc_set:
1579
1580                    self.agg_normalized_data.columns = (
1581                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1582                    )
1583
1584                    if "#" not in full:
1585
1586                        self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1587
1588                        print(
1589                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1590                            "Only the names will be compared, without considering the set information."
1591                        )
1592                else:
1593
1594                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1595
1596                mask = [
1597                    full.upper() != x.upper() for x in self.agg_normalized_data.columns
1598                ]
1599
1600                if len([y for y in mask if y is False]) == 0:
1601                    raise ValueError("Nothing found to reduce")
1602
1603                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1604
1605            if self.agg_metadata is not None:
1606
1607                if inc_set:
1608
1609                    self.agg_metadata["drop"] = (
1610                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1611                    )
1612
1613                    if "#" not in full:
1614
1615                        self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1616
1617                        print(
1618                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1619                            "Only the names will be compared, without considering the set information."
1620                        )
1621                else:
1622
1623                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1624
1625                mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]]
1626
1627                if len([y for y in mask if y is False]) == 0:
1628                    raise ValueError("Nothing found to reduce")
1629
1630                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1631                    drop=True
1632                )
1633
1634                self.agg_metadata = self.agg_metadata.drop(
1635                    columns=["drop"], errors="ignore"
1636                )
1637
1638        self.gene_calculation()
1639        self.cells_calculation()
1640
1641    def reduce_rows(self, features_list: list):
1642        """
1643        Remove rows (features) whose names are included in features_list.
1644
1645        Parameters
1646        ----------
1647        features_list : list
1648            List of features to search for in index/gene names; matching entries will be removed.
1649
1650        Update
1651        ------------
1652        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
1653        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
1654        removing columns/rows that match `reg`.
1655
1656        Raises
1657        ------
1658        Prints a message listing features that are not found in the data.
1659        """
1660
1661        if self.input_data is not None:
1662
1663            res = find_features(self.input_data, features=features_list)
1664
1665            res_list = [x.upper() for x in res["included"]]
1666
1667            mask = [x.upper() not in res_list for x in self.input_data.index]
1668
1669            if len([y for y in mask if y is False]) == 0:
1670                raise ValueError("Nothing found to reduce")
1671
1672            self.input_data = self.input_data.loc[mask, :]
1673
1674        if self.normalized_data is not None:
1675
1676            res = find_features(self.normalized_data, features=features_list)
1677
1678            res_list = [x.upper() for x in res["included"]]
1679
1680            mask = [x.upper() not in res_list for x in self.normalized_data.index]
1681
1682            if len([y for y in mask if y is False]) == 0:
1683                raise ValueError("Nothing found to reduce")
1684
1685            self.normalized_data = self.normalized_data.loc[mask, :]
1686
1687        if self.agg_normalized_data is not None:
1688
1689            res = find_features(self.agg_normalized_data, features=features_list)
1690
1691            res_list = [x.upper() for x in res["included"]]
1692
1693            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]
1694
1695            if len([y for y in mask if y is False]) == 0:
1696                raise ValueError("Nothing found to reduce")
1697
1698            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]
1699
1700        if len(res["not_included"]) > 0:
1701            print("\nFeatures not found:")
1702            for i in res["not_included"]:
1703                print(i)
1704
1705        self.gene_calculation()
1706        self.cells_calculation()
1707
1708    def get_data(self, set_info: bool = False):
1709        """
1710        Return normalized data with optional set annotation appended to column names.
1711
1712        Parameters
1713        ----------
1714        set_info : bool, default False
1715            If True, column names are returned as "cell_name # set"; otherwise
1716            only the `cell_name` is used.
1717
1718        Returns
1719        -------
1720        pandas.DataFrame
1721            The `self.normalized_data` table with columns renamed according to `set_info`.
1722
1723        Raises
1724        ------
1725        AttributeError
1726            If `self.normalized_data` or `self.input_metadata` is missing.
1727        """
1728
1729        to_return = self.normalized_data
1730
1731        if set_info:
1732            to_return.columns = (
1733                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
1734            )
1735        else:
1736            to_return.columns = self.input_metadata["cell_names"]
1737
1738        return to_return
1739
1740    def get_partial_data(
1741        self,
1742        names: list | str | None = None,
1743        features: list | str | None = None,
1744        name_slot: str = "cell_names",
1745        inc_metadata: bool = False,
1746    ):
1747        """
1748        Return a subset of the data filtered by sample names and/or feature names.
1749
1750        Parameters
1751        ----------
1752        names : list, str, or None
1753            Names of samples to include. If None, all samples are considered.
1754
1755        features : list, str, or None
1756            Names of features to include. If None, all features are considered.
1757
1758        name_slot : str
1759            Column in metadata to use as sample names.
1760
1761        inc_metadata : bool
1762            If True return tuple (data, metadata)
1763
1764        Returns
1765        -------
1766        pandas.DataFrame
1767            Subset of the normalized data based on the specified names and features.
1768        """
1769
1770        data = self.normalized_data.copy()
1771        metadata = self.input_metadata
1772
1773        if name_slot in self.input_metadata.columns:
1774            data.columns = self.input_metadata[name_slot]
1775        else:
1776            raise ValueError("'name_slot' not occured in data!'")
1777
1778        if isinstance(features, str):
1779            features = [features]
1780        elif features is None:
1781            features = []
1782
1783        if isinstance(names, str):
1784            names = [names]
1785        elif names is None:
1786            names = []
1787
1788        features = [x.upper() for x in features]
1789        names = [x.upper() for x in names]
1790
1791        columns_names = [x.upper() for x in data.columns]
1792        features_names = [x.upper() for x in data.index]
1793
1794        columns_bool = [True if x in names else False for x in columns_names]
1795        features_bool = [True if x in features else False for x in features_names]
1796
1797        if True not in columns_bool and True not in features_bool:
1798            print("Missing 'names' and/or 'features'. Returning full dataset instead.")
1799            
1800        if True in columns_bool:
1801            data = data.loc[:, columns_bool]
1802            metadata = metadata.loc[columns_bool, :]
1803
1804        if True in features_bool:
1805            data = data.loc[features_bool, :]
1806
1807        not_in_features = [y for y in features if y not in features_names]
1808
1809        if len(not_in_features) > 0:
1810            print('\nThe following features were not found in data:')
1811            print('\n'.join(not_in_features))
1812
1813        not_in_names = [y for y in names if y not in columns_names]
1814
1815        if len(not_in_names) > 0:
1816            print('\nThe following names were not found in data:')
1817            print('\n'.join(not_in_names))
1818
1819        if inc_metadata:
1820            return data, metadata
1821        else:
1822            return data
1823
1824    def get_metadata(self):
1825        """
1826        Return the stored input metadata.
1827
1828        Returns
1829        -------
1830        pandas.DataFrame
1831            `self.input_metadata` (may be None if not set).
1832        """
1833
1834        to_return = self.input_metadata
1835
1836        return to_return
1837
1838    def gene_calculation(self):
1839        """
1840        Calculate and store per-cell counts (e.g., number of detected genes).
1841
1842        The method computes a binary (presence/absence) per cell and sums across
1843        features to produce `self.gene_calc`.
1844
1845        Update
1846        ------
1847        Sets `self.gene_calc` as a pandas.Series.
1848
1849        Side Effects
1850        ------------
1851        Uses `self.input_data` when available, otherwise `self.normalized_data`.
1852        """
1853
1854        if self.input_data is not None:
1855
1856            bin_col = self.input_data.columns.copy()
1857
1858            bin_col = bin_col.where(bin_col <= 0, 1)
1859
1860            sum_data = bin_col.sum(axis=0)
1861
1862            self.gene_calc = sum_data
1863
1864        elif self.normalized_data is not None:
1865
1866            bin_col = self.normalized_data.copy()
1867
1868            bin_col = bin_col.where(bin_col <= 0, 1)
1869
1870            sum_data = bin_col.sum(axis=0)
1871
1872            self.gene_calc = sum_data
1873
1874    def gene_histograme(self, bins=100):
1875        """
1876        Plot a histogram of the number of genes detected per cell.
1877
1878        Parameters
1879        ----------
1880        bins : int, default 100
1881            Number of histogram bins.
1882
1883        Returns
1884        -------
1885        matplotlib.figure.Figure
1886            Figure containing the histogram of gene contents.
1887
1888        Notes
1889        -----
1890        Requires `self.gene_calc` to be computed prior to calling.
1891        """
1892
1893        fig, ax = plt.subplots(figsize=(8, 5))
1894
1895        _, bin_edges, _ = ax.hist(
1896            self.gene_calc, bins=bins, edgecolor="black", alpha=0.6
1897        )
1898
1899        mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc)
1900
1901        x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000)
1902        y = norm.pdf(x, mu, sigma)
1903
1904        y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0])
1905
1906        ax.plot(
1907            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
1908        )
1909
1910        ax.set_xlabel("Value")
1911        ax.set_ylabel("Count")
1912        ax.set_title("Histogram of genes detected per cell")
1913
1914        ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20))
1915        ax.tick_params(axis="x", rotation=90)
1916
1917        ax.legend()
1918
1919        return fig
1920
1921    def gene_threshold(self, min_n: int | None, max_n: int | None):
1922        """
1923        Filter cells by gene-detection thresholds (min and/or max).
1924
1925        Parameters
1926        ----------
1927        min_n : int or None
1928            Minimum number of detected genes required to keep a cell.
1929
1930        max_n : int or None
1931            Maximum number of detected genes allowed to keep a cell.
1932
1933        Update
1934        -------
1935        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
1936        (and calls `average()` if `self.agg_normalized_data` exists).
1937
1938        Side Effects
1939        ------------
1940        Raises ValueError if both bounds are None or if filtering removes all cells.
1941        """
1942
1943        if min_n is not None and max_n is not None:
1944            mask = (self.gene_calc > min_n) & (self.gene_calc < max_n)
1945        elif min_n is None and max_n is not None:
1946            mask = self.gene_calc < max_n
1947        elif min_n is not None and max_n is None:
1948            mask = self.gene_calc > min_n
1949        else:
1950            raise ValueError("Lack of both min_n and max_n values")
1951
1952        if self.input_data is not None:
1953
1954            if len([y for y in mask if y is False]) == 0:
1955                raise ValueError("Nothing to reduce")
1956
1957            self.input_data = self.input_data.loc[:, mask.values]
1958
1959        if self.normalized_data is not None:
1960
1961            if len([y for y in mask if y is False]) == 0:
1962                raise ValueError("Nothing to reduce")
1963
1964            self.normalized_data = self.normalized_data.loc[:, mask.values]
1965
1966        if self.input_metadata is not None:
1967
1968            if len([y for y in mask if y is False]) == 0:
1969                raise ValueError("Nothing to reduce")
1970
1971            self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index(
1972                drop=True
1973            )
1974
1975            self.input_metadata = self.input_metadata.drop(
1976                columns=["drop"], errors="ignore"
1977            )
1978
1979        if self.agg_normalized_data is not None:
1980            self.average()
1981
1982        self.gene_calculation()
1983        self.cells_calculation()
1984
1985    def cells_calculation(self, name_slot="cell_names"):
1986        """
1987        Calculate number of cells per  call name / cluster.
1988
1989        The method computes a binary (presence/absence) per cell name / cluster and sums across
1990        cells.
1991
1992        Parameters
1993        ----------
1994        name_slot : str, default 'cell_names'
1995            Column in metadata to use as sample names.
1996
1997        Update
1998        ------
1999        Sets `self.cells_calc` as a pd.DataFrame.
2000        """
2001
2002        ls = list(self.input_metadata[name_slot])
2003
2004        df = pd.DataFrame(
2005            {
2006                "cluster": pd.Series(ls).value_counts().index,
2007                "n": pd.Series(ls).value_counts().values,
2008            }
2009        )
2010
2011        self.cells_calc = df
2012
2013    def cell_histograme(self, name_slot: str = "cell_names"):
2014        """
2015        Plot a histogram of the number of cells detected per cell name (cluster).
2016
2017        Parameters
2018        ----------
2019        name_slot : str, default 'cell_names'
2020            Column in metadata to use as sample names.
2021
2022        Returns
2023        -------
2024        matplotlib.figure.Figure
2025            Figure containing the histogram of cell contents.
2026
2027        Notes
2028        -----
2029        Requires `self.cells_calc` to be computed prior to calling.
2030        """
2031
2032        if name_slot != "cell_names":
2033            self.cells_calculation(name_slot=name_slot)
2034
2035        fig, ax = plt.subplots(figsize=(8, 5))
2036
2037        _, bin_edges, _ = ax.hist(
2038            list(self.cells_calc["n"]),
2039            bins=len(set(self.cells_calc["cluster"])),
2040            edgecolor="black",
2041            color="orange",
2042            alpha=0.6,
2043        )
2044
2045        mu, sigma = np.mean(list(self.cells_calc["n"])), np.std(
2046            list(self.cells_calc["n"])
2047        )
2048
2049        x = np.linspace(
2050            min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000
2051        )
2052        y = norm.pdf(x, mu, sigma)
2053
2054        y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0])
2055
2056        ax.plot(
2057            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
2058        )
2059
2060        ax.set_xlabel("Value")
2061        ax.set_ylabel("Count")
2062        ax.set_title("Histogram of cells detected per cell name / cluster")
2063
2064        ax.set_xticks(
2065            np.linspace(
2066                min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20
2067            )
2068        )
2069        ax.tick_params(axis="x", rotation=90)
2070
2071        ax.legend()
2072
2073        return fig
2074
    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
        """
        Drop cell names / clusters whose cell count falls below a threshold.

        Clusters with fewer than `min_n` cells (per `self.cells_calc`) are
        removed from every stored table.

        Parameters
        ----------
        min_n : int or None
            Minimum number of cells a cluster must contain to be kept.
            Must not be None.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.


        Update
        -------
        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`,
        `self.agg_normalized_data` and `self.agg_metadata`, then refreshes
        `self.gene_calc` and `self.cells_calc`.

        Raises
        ------
        ValueError
            If `min_n` is None.

        Notes
        -----
        Membership is tested by substring containment (`r in x`), so a cluster
        name that is a substring of another label (e.g. 'T1' vs 'T10') also
        drops the longer one — presumably intentional for compound labels
        such as "name # set"; verify against callers.  # NOTE(review)
        """

        # Refresh per-cluster counts when a non-default grouping is requested.
        if name_slot != "cell_names":
            self.cells_calculation(name_slot=name_slot)

        # Names of clusters that fall below the threshold and must be dropped.
        if min_n is not None:
            names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]
        else:
            raise ValueError("Lack of min_n value")

        if len(names) > 0:

            if self.input_data is not None:

                # Relabel columns so they can be matched against cluster names.
                self.input_data.columns = self.input_metadata[name_slot]

                # True = keep the column (no dropped cluster name occurs in it).
                mask = [not any(r in x for r in names) for x in self.input_data.columns]

                # `mask` holds plain Python bools here (list comprehension of
                # `not any(...)`), so the `is False` identity test is valid:
                # only filter when at least one column is actually dropped.
                if len([y for y in mask if y is False]) > 0:

                    self.input_data = self.input_data.loc[:, mask]

            if self.normalized_data is not None:

                self.normalized_data.columns = self.input_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.normalized_data.columns
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.normalized_data = self.normalized_data.loc[:, mask]

            if self.input_metadata is not None:

                # Temporary helper column used only to build the row mask.
                self.input_metadata["drop"] = self.input_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.input_metadata["drop"]
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
                        drop=True
                    )

                # Remove the helper column again (errors="ignore" in case the
                # filtered frame was reassigned without it).
                self.input_metadata = self.input_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            # Aggregated (per-cluster) table uses agg_metadata for its labels.
            if self.agg_normalized_data is not None:

                self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                mask = [
                    not any(r in x for r in names)
                    for x in self.agg_normalized_data.columns
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]

            if self.agg_metadata is not None:

                self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.agg_metadata["drop"]
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
                        drop=True
                    )

                self.agg_metadata = self.agg_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            # Recompute per-cell and per-cluster statistics on the reduced data.
            self.gene_calculation()
            self.cells_calculation()
2177
2178    def load_sparse_from_projects(self, normalized_data: bool = False):
2179        """
2180        Load sparse 10x-style datasets from stored project paths, concatenate them,
2181        and populate `input_data` / `normalized_data` and `input_metadata`.
2182
2183        Parameters
2184        ----------
2185        normalized_data : bool, default False
2186            If True, store concatenated tables in `self.normalized_data`.
2187            If False, store them in `self.input_data` and normalization
2188            is needed using normalize_data() method.
2189
2190        Side Effects
2191        ------------
2192        - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv).
2193        - Concatenates all projects column-wise and sets `self.input_metadata`.
2194        - Replaces NaNs with zeros and updates `self.gene_calc`.
2195        """
2196
2197        obj = self.objects
2198
2199        full_data = pd.DataFrame()
2200        full_metadata = pd.DataFrame()
2201
2202        for ke in obj.keys():
2203            print(ke)
2204
2205            dt, met = load_sparse(path=obj[ke], name=ke)
2206
2207            full_data = pd.concat([full_data, dt], axis=1)
2208            full_metadata = pd.concat([full_metadata, met], axis=0)
2209
2210        full_data[np.isnan(full_data)] = 0
2211
2212        if normalized_data:
2213            self.normalized_data = full_data
2214            self.input_metadata = full_metadata
2215        else:
2216
2217            self.input_data = full_data
2218            self.input_metadata = full_metadata
2219
2220        self.gene_calculation()
2221        self.cells_calculation()
2222
2223    def rename_names(self, mapping: dict, slot: str = "cell_names"):
2224        """
2225        Rename entries in `self.input_metadata[slot]` according to a provided mapping.
2226
2227        Parameters
2228        ----------
2229        mapping : dict
2230            Dictionary with keys 'old_name' and 'new_name', each mapping to a list
2231            of equal length describing replacements.
2232
2233        slot : str, default 'cell_names'
2234            Metadata column to operate on.
2235
2236        Update
2237        -------
2238        Updates `self.input_metadata[slot]` in-place with renamed values.
2239
2240        Raises
2241        ------
2242        ValueError
2243            If mapping keys are incorrect, lengths differ, or some 'old_name' values
2244            are not present in the metadata column.
2245        """
2246
2247        if set(["old_name", "new_name"]) != set(mapping.keys()):
2248            raise ValueError(
2249                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2250                "each with a list of names to change."
2251            )
2252
2253        if len(mapping["old_name"]) != len(mapping["new_name"]):
2254            raise ValueError(
2255                "Mapping dictionary lists 'old_name' and 'new_name' "
2256                "must have the same length!"
2257            )
2258
2259        names_vector = list(self.input_metadata[slot])
2260
2261        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2262            raise ValueError(
2263                f"Some entries from 'old_name' do not exist in the names of slot {slot}."
2264            )
2265
2266        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2267
2268        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2269
2270        self.input_metadata[slot] = names_vector_ret
2271
2272    def rename_subclusters(self, mapping):
2273        """
2274        Rename labels stored in `self.subclusters_.subclusters` according to mapping.
2275
2276        Parameters
2277        ----------
2278        mapping : dict
2279            Mapping with keys 'old_name' and 'new_name' (lists of equal length).
2280
2281        Update
2282        -------
2283        Updates `self.subclusters_.subclusters` with renamed labels.
2284
2285        Raises
2286        ------
2287        ValueError
2288            If mapping is invalid or old names are not present.
2289        """
2290
2291        if set(["old_name", "new_name"]) != set(mapping.keys()):
2292            raise ValueError(
2293                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2294                "each with a list of names to change."
2295            )
2296
2297        if len(mapping["old_name"]) != len(mapping["new_name"]):
2298            raise ValueError(
2299                "Mapping dictionary lists 'old_name' and 'new_name' "
2300                "must have the same length!"
2301            )
2302
2303        names_vector = list(self.subclusters_.subclusters)
2304
2305        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2306            raise ValueError(
2307                "Some entries from 'old_name' do not exist in the subcluster names."
2308            )
2309
2310        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2311
2312        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2313
2314        self.subclusters_.subclusters = names_vector_ret
2315
2316    def save_sparse(
2317        self,
2318        path_to_save: str = os.getcwd(),
2319        name_slot: str = "cell_names",
2320        data_slot: str = "normalized",
2321    ):
2322        """
2323        Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
2324
2325        Parameters
2326        ----------
2327        path_to_save : str, default current working directory
2328            Directory where files will be written.
2329
2330        name_slot : str, default 'cell_names'
2331            Metadata column providing cell names for barcodes.tsv.
2332
2333        data_slot : str, default 'normalized'
2334            Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).
2335
2336        Raises
2337        ------
2338        ValueError
2339            If `data_slot` is not 'normalized' or 'count'.
2340        """
2341
2342        names = self.input_metadata[name_slot]
2343
2344        if data_slot.lower() == "normalized":
2345
2346            features = list(self.normalized_data.index)
2347            mtx = sparse.csr_matrix(self.normalized_data)
2348
2349        elif data_slot.lower() == "count":
2350
2351            features = list(self.input_data.index)
2352            mtx = sparse.csr_matrix(self.input_data)
2353
2354        else:
2355            raise ValueError("'data_slot' must be included in 'normalized' or 'count'")
2356
2357        os.makedirs(path_to_save, exist_ok=True)
2358
2359        mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx)
2360
2361        pd.Series(names).to_csv(
2362            os.path.join(path_to_save, "barcodes.tsv"),
2363            index=False,
2364            header=False,
2365            sep="\t",
2366        )
2367
2368        pd.Series(features).to_csv(
2369            os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t"
2370        )
2371
2372    def normalize_counts(
2373        self, normalize_factor: int = 100000, log_transform: bool = True
2374    ):
2375        """
2376        Normalize raw counts to counts-per-(normalize_factor)
2377        (e.g., CPM, TPM - depending on normalize_factor).
2378
2379        Parameters
2380        ----------
2381        normalize_factor : int, default 100000
2382            Scaling factor used after dividing by column sums.
2383
2384        log_transform : bool, default True
2385            If True, apply log2(x+1) transformation to normalized values.
2386
2387        Update
2388        -------
2389            Sets `self.normalized_data` to normalized values (fills NaNs with 0).
2390
2391        Raises
2392        ------
2393        ValueError
2394            If `self.input_data` is missing (cannot normalize).
2395        """
2396        if self.input_data is None:
2397            raise ValueError("Input data is missing, cannot normalize.")
2398
2399        sum_col = self.input_data.sum()
2400        self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor
2401
2402        if log_transform:
2403            # log2(x + 1) to avoid -inf for zeros
2404            self.normalized_data = np.log2(self.normalized_data + 1)
2405
2406    def statistic(
2407        self,
2408        cells=None,
2409        sets=None,
2410        min_exp: float = 0.01,
2411        min_pct: float = 0.1,
2412        n_proc: int = 10,
2413    ):
2414        """
2415        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.
2416
2417        This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data`
2418        and `self.input_metadata`. It returns per-feature statistics including p-values,
2419        adjusted p-values, means, variances, effect-size measures and fold-changes.
2420
2421        Parameters
2422        ----------
2423        cells : list, 'All', dict, or None
2424            Defines the target cells or groups for comparison (several modes supported).
2425
2426        sets : 'All', dict, or None
2427            Alternative grouping mode (operate on `self.input_metadata['sets']`).
2428
2429        min_exp : float, default 0.01
2430            Minimum expression threshold used when filtering features.
2431
2432        min_pct : float, default 0.1
2433            Minimum proportion of expressing cells in the target group required to test a feature.
2434
2435        n_proc : int, default 10
2436            Number of parallel jobs to use.
2437
2438        Returns
2439        -------
2440        pandas.DataFrame or dict
2441            Results DataFrame (or dict containing valid/control cells + DataFrame),
2442            similar to `calc_DEG` interface.
2443
2444        Raises
2445        ------
2446        ValueError
2447            If neither `cells` nor `sets` is provided, or input metadata mismatch occurs.
2448
2449        Notes
2450        -----
2451        Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
2452        """
2453
2454        offset = 1e-100
2455
2456        def stat_calc(choose, feature_name):
2457            target_values = choose.loc[choose["DEG"] == "target", feature_name]
2458            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]
2459
2460            pct_valid = (target_values > 0).sum() / len(target_values)
2461            pct_rest = (rest_values > 0).sum() / len(rest_values)
2462
2463            avg_valid = np.mean(target_values)
2464            avg_ctrl = np.mean(rest_values)
2465            sd_valid = np.std(target_values, ddof=1)
2466            sd_ctrl = np.std(rest_values, ddof=1)
2467            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))
2468
2469            if np.sum(target_values) == np.sum(rest_values):
2470                p_val = 1.0
2471            else:
2472                _, p_val = stats.mannwhitneyu(
2473                    target_values, rest_values, alternative="two-sided"
2474                )
2475
2476            return {
2477                "feature": feature_name,
2478                "p_val": p_val,
2479                "pct_valid": pct_valid,
2480                "pct_ctrl": pct_rest,
2481                "avg_valid": avg_valid,
2482                "avg_ctrl": avg_ctrl,
2483                "sd_valid": sd_valid,
2484                "sd_ctrl": sd_ctrl,
2485                "esm": esm,
2486            }
2487
2488        def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc):
2489
2490            def safe_min_half(series):
2491                filtered = series[(series > ((2**-1074)*2)) & (series.notna())]
2492                return filtered.min() / 2 if not filtered.empty else 0
2493        
2494            tmp_dat = choose[choose["DEG"] == "target"]
2495            tmp_dat = tmp_dat.drop("DEG", axis=1)
2496
2497            counts = (tmp_dat > min_exp).sum(axis=0)
2498
2499            total_count = tmp_dat.shape[0]
2500
2501            info = pd.DataFrame(
2502                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
2503            )
2504
2505            del tmp_dat
2506
2507            drop_col = info["feature"][info["pct"] <= min_pct]
2508
2509            if len(drop_col) + 1 == len(choose.columns):
2510                drop_col = info["feature"][info["pct"] == 0]
2511
2512            del info
2513
2514            choose = choose.drop(list(drop_col), axis=1)
2515
2516            results = Parallel(n_jobs=n_proc)(
2517                delayed(stat_calc)(choose, feature)
2518                for feature in tqdm(choose.columns[choose.columns != "DEG"])
2519            )
2520
2521            df = pd.DataFrame(results)
2522            df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)]
2523
2524            df["valid_group"] = valid_group
2525            df.sort_values(by="p_val", inplace=True)
2526
2527            num_tests = len(df)
2528            df["adj_pval"] = np.minimum(
2529                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
2530            )
2531
2532            valid_factor = safe_min_half(df["avg_valid"])
2533            ctrl_factor = safe_min_half(df["avg_ctrl"])
2534
2535            cv_factor = min(valid_factor, ctrl_factor)
2536
2537            if cv_factor == 0:
2538                cv_factor = max(valid_factor, ctrl_factor)
2539
2540            if not np.isfinite(cv_factor) or cv_factor == 0:
2541                cv_factor += offset
2542
2543            valid = df["avg_valid"].where(
2544                df["avg_valid"] != 0, df["avg_valid"] + cv_factor
2545            )
2546            ctrl = df["avg_ctrl"].where(
2547                df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor
2548            )
2549
2550            df["FC"] = valid / ctrl
2551
2552            df["log(FC)"] = np.log2(df["FC"])
2553            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
2554
2555            return df
2556
2557        choose = self.normalized_data.copy().T
2558
2559        final_results = []
2560
2561        if isinstance(cells, list) and sets is None:
2562            print("\nAnalysis started...\nComparing selected cells to the whole set...")
2563            choose.index = (
2564                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2565            )
2566
2567            if "#" not in cells[0]:
2568                choose.index = self.input_metadata["cell_names"]
2569
2570                print(
2571                    "Not include the set info (name # set) in the 'cells' list. "
2572                    "Only the names will be compared, without considering the set information."
2573                )
2574
2575            labels = ["target" if idx in cells else "rest" for idx in choose.index]
2576            valid = list(
2577                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
2578            )
2579
2580            choose["DEG"] = labels
2581            choose = choose[choose["DEG"] != "drop"]
2582
2583            result_df = prepare_and_run_stat(
2584                choose.reset_index(drop=True),
2585                valid_group=valid,
2586                min_exp=min_exp,
2587                min_pct=min_pct,
2588                n_proc=n_proc,
2589            )
2590            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}
2591
2592        elif cells == "All" and sets is None:
2593            print("\nAnalysis started...\nComparing each type of cell to others...")
2594            choose.index = (
2595                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2596            )
2597            unique_labels = set(choose.index)
2598
2599            for label in tqdm(unique_labels):
2600                print(f"\nCalculating statistics for {label}")
2601                labels = ["target" if idx == label else "rest" for idx in choose.index]
2602                choose["DEG"] = labels
2603                choose = choose[choose["DEG"] != "drop"]
2604                result_df = prepare_and_run_stat(
2605                    choose.copy(),
2606                    valid_group=label,
2607                    min_exp=min_exp,
2608                    min_pct=min_pct,
2609                    n_proc=n_proc,
2610                )
2611                final_results.append(result_df)
2612
2613            return pd.concat(final_results, ignore_index=True)
2614
2615        elif cells is None and sets == "All":
2616            print("\nAnalysis started...\nComparing each set/group to others...")
2617            choose.index = self.input_metadata["sets"]
2618            unique_sets = set(choose.index)
2619
2620            for label in tqdm(unique_sets):
2621                print(f"\nCalculating statistics for {label}")
2622                labels = ["target" if idx == label else "rest" for idx in choose.index]
2623
2624                choose["DEG"] = labels
2625                choose = choose[choose["DEG"] != "drop"]
2626                result_df = prepare_and_run_stat(
2627                    choose.copy(),
2628                    valid_group=label,
2629                    min_exp=min_exp,
2630                    min_pct=min_pct,
2631                    n_proc=n_proc,
2632                )
2633                final_results.append(result_df)
2634
2635            return pd.concat(final_results, ignore_index=True)
2636
2637        elif cells is None and isinstance(sets, dict):
2638            print("\nAnalysis started...\nComparing groups...")
2639
2640            choose.index = self.input_metadata["sets"]
2641
2642            group_list = list(sets.keys())
2643            if len(group_list) != 2:
2644                print("Only pairwise group comparison is supported.")
2645                return None
2646
2647            labels = [
2648                (
2649                    "target"
2650                    if idx in sets[group_list[0]]
2651                    else "rest" if idx in sets[group_list[1]] else "drop"
2652                )
2653                for idx in choose.index
2654            ]
2655            choose["DEG"] = labels
2656            choose = choose[choose["DEG"] != "drop"]
2657
2658            result_df = prepare_and_run_stat(
2659                choose.reset_index(drop=True),
2660                valid_group=group_list[0],
2661                min_exp=min_exp,
2662                min_pct=min_pct,
2663                n_proc=n_proc,
2664            )
2665            return result_df
2666
2667        elif isinstance(cells, dict) and sets is None:
2668            print("\nAnalysis started...\nComparing groups...")
2669            choose.index = (
2670                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2671            )
2672
2673            if "#" not in cells[list(cells.keys())[0]][0]:
2674                choose.index = self.input_metadata["cell_names"]
2675
2676                print(
2677                    "Not include the set info (name # set) in the 'cells' dict. "
2678                    "Only the names will be compared, without considering the set information."
2679                )
2680
2681            group_list = list(cells.keys())
2682            if len(group_list) != 2:
2683                print("Only pairwise group comparison is supported.")
2684                return None
2685
2686            labels = [
2687                (
2688                    "target"
2689                    if idx in cells[group_list[0]]
2690                    else "rest" if idx in cells[group_list[1]] else "drop"
2691                )
2692                for idx in choose.index
2693            ]
2694
2695            choose["DEG"] = labels
2696            choose = choose[choose["DEG"] != "drop"]
2697
2698            result_df = prepare_and_run_stat(
2699                choose.reset_index(drop=True),
2700                valid_group=group_list[0],
2701                min_exp=min_exp,
2702                min_pct=min_pct,
2703                n_proc=n_proc,
2704            )
2705
2706            return result_df.reset_index(drop=True)
2707
2708        else:
2709            raise ValueError(
2710                "You must specify either 'cells' or 'sets' (or both). None were provided, which is not allowed for this analysis."
2711            )
2712
2713    def calculate_difference_markers(
2714        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
2715    ):
2716        """
2717        Compute differential markers (var_data) if not already present.
2718
2719        Parameters
2720        ----------
2721        min_exp : float, default 0
2722            Minimum expression threshold passed to `statistic`.
2723
2724        min_pct : float, default 0.25
2725            Minimum percent expressed in target group.
2726
2727        n_proc : int, default 10
2728            Parallel jobs.
2729
2730        force : bool, default False
2731            If True, recompute even if `self.var_data` is present.
2732
2733        Update
2734        -------
2735        Sets `self.var_data` to the result of `self.statistic(...)`.
2736
2737        Raise
2738        ------
2739        ValueError if already computed and `force` is False.
2740        """
2741
2742        if self.var_data is None or force:
2743
2744            self.var_data = self.statistic(
2745                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
2746            )
2747
2748        else:
2749            raise ValueError(
2750                "self.calculate_difference_markers() has already been executed. "
2751                "The results are stored in self.var. "
2752                "If you want to recalculate with different statistics, please rerun the method with force=True."
2753            )
2754
2755    def clustering_features(
2756        self,
2757        features_list: list | None,
2758        name_slot: str = "cell_names",
2759        p_val: float = 0.05,
2760        top_n: int = 25,
2761        adj_mean: bool = True,
2762        beta: float = 0.2,
2763    ):
2764        """
2765        Prepare clustering input by selecting marker features and optionally smoothing cell values
2766        toward group means.
2767
2768        Parameters
2769        ----------
2770        features_list : list or None
2771            If provided, use this list of features. If None, features are selected
2772            from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group.
2773
2774        name_slot : str, default 'cell_names'
2775            Metadata column used for naming.
2776
2777        p_val : float, default 0.05
2778            Adjusted p-value cutoff when selecting features automatically.
2779
2780        top_n : int, default 25
2781            Number of top features per valid group to keep if `features_list` is None.
2782
2783        adj_mean : bool, default True
2784            If True, adjust cell values toward group means using `beta`.
2785
2786        beta : float, default 0.2
2787            Adjustment strength toward group mean.
2788
2789        Update
2790        ------
2791        Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset,
2792        ready for PCA/UMAP/clustering.
2793        """
2794
2795        if features_list is None or len(features_list) == 0:
2796
2797            if self.var_data is None:
2798                raise ValueError(
2799                    "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2800                )
2801
2802            df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2803            df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2804            df_tmp = (
2805                df_tmp.sort_values(
2806                    ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2807                )
2808                .groupby("valid_group")
2809                .head(top_n)
2810            )
2811
2812            feaures_list = list(set(df_tmp["feature"]))
2813
2814        data = self.get_partial_data(
2815            names=None, features=feaures_list, name_slot=name_slot
2816        )
2817        data_avg = average(data)
2818
2819        if adj_mean:
2820            data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta)
2821
2822        self.clustering_data = data
2823
2824        self.clustering_metadata = self.input_metadata
2825
2826    def average(self):
2827        """
2828        Aggregate normalized data by (cell_name, set) pairs computing the mean per group.
2829
2830        The method constructs new column names as "cell_name # set", averages columns
2831        sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`.
2832
2833        Update
2834        ------
2835        Sets `self.agg_normalized_data` (features x aggregated samples) and
2836        `self.agg_metadata` (DataFrame with 'cell_names' and 'sets').
2837        """
2838
2839        wide_data = self.normalized_data
2840
2841        wide_metadata = self.input_metadata
2842
2843        new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"]
2844
2845        wide_data.columns = list(new_names)
2846
2847        aggregated_df = wide_data.T.groupby(level=0).mean().T
2848
2849        sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns]
2850        names = [re.sub(" #.*", "", x) for x in aggregated_df.columns]
2851
2852        aggregated_df.columns = names
2853        aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets})
2854
2855        self.agg_metadata = aggregated_metadata
2856        self.agg_normalized_data = aggregated_df
2857
2858    def estimating_similarity(
2859        self, method="pearson", p_val: float = 0.05, top_n: int = 25
2860    ):
2861        """
2862        Estimate pairwise similarity and Euclidean distance between aggregated samples.
2863
2864        Parameters
2865        ----------
2866        method : str, default 'pearson'
2867            Correlation method to use (passed to pandas.DataFrame.corr()).
2868
2869        p_val : float, default 0.05
2870            Adjusted p-value cutoff used to select marker features from `self.var_data`.
2871
2872        top_n : int, default 25
2873            Number of top features per valid group to include.
2874
2875        Update
2876        -------
2877        Computes a combined table with per-pair correlation and euclidean distance
2878        and stores it in `self.similarity`.
2879        """
2880
2881        if self.var_data is None:
2882            raise ValueError(
2883                "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2884            )
2885
2886        if self.agg_normalized_data is None:
2887            self.average()
2888
2889        metadata = self.agg_metadata
2890        data = self.agg_normalized_data
2891
2892        df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2893        df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2894        df_tmp = (
2895            df_tmp.sort_values(
2896                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2897            )
2898            .groupby("valid_group")
2899            .head(top_n)
2900        )
2901
2902        data = data.loc[list(set(df_tmp["feature"]))]
2903
2904        if len(set(metadata["sets"])) > 1:
2905            data.columns = data.columns + " # " + [x for x in metadata["sets"]]
2906        else:
2907            data = data.copy()
2908
2909        scaler = StandardScaler()
2910
2911        scaled_data = scaler.fit_transform(data)
2912
2913        scaled_df = pd.DataFrame(scaled_data, columns=data.columns)
2914
2915        cor = scaled_df.corr(method=method)
2916        cor_df = cor.stack().reset_index()
2917        cor_df.columns = ["cell1", "cell2", "correlation"]
2918
2919        distances = pdist(scaled_df.T, metric="euclidean")
2920        dist_mat = pd.DataFrame(
2921            squareform(distances), index=scaled_df.columns, columns=scaled_df.columns
2922        )
2923        dist_df = dist_mat.stack().reset_index()
2924        dist_df.columns = ["cell1", "cell2", "euclidean_dist"]
2925
2926        full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"])
2927
2928        full = full[full["cell1"] != full["cell2"]]
2929        full = full.reset_index(drop=True)
2930
2931        self.similarity = full
2932
2933    def similarity_plot(
2934        self,
2935        split_sets=True,
2936        set_info: bool = True,
2937        cmap="seismic",
2938        width=12,
2939        height=10,
2940    ):
2941        """
2942        Visualize pairwise similarity as a scatter plot.
2943
2944        Parameters
2945        ----------
2946        split_sets : bool, default True
2947            If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs.
2948
2949        set_info : bool, default True
2950            If True, keep the ' # set' annotation in labels; otherwise strip it.
2951
2952        cmap : str, default 'seismic'
2953            Color map for correlation (hue).
2954
2955        width : int, default 12
2956            Figure width.
2957
2958        height : int, default 10
2959            Figure height.
2960
2961        Returns
2962        -------
2963        matplotlib.figure.Figure
2964
2965        Raises
2966        ------
2967        ValueError
2968            If `self.similarity` is None.
2969
2970        Notes
2971        -----
2972        The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs.
2973        """
2974
2975        if self.similarity is None:
2976            raise ValueError(
2977                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
2978            )
2979
2980        similarity_data = self.similarity
2981
2982        if " # " in similarity_data["cell1"][0]:
2983            similarity_data["set1"] = [
2984                re.sub(".*# ", "", x) for x in similarity_data["cell1"]
2985            ]
2986            similarity_data["set2"] = [
2987                re.sub(".*# ", "", x) for x in similarity_data["cell2"]
2988            ]
2989
2990        if split_sets and " # " in similarity_data["cell1"][0]:
2991            sets = list(
2992                set(list(similarity_data["set1"]) + list(similarity_data["set2"]))
2993            )
2994
2995            mm = math.ceil(len(sets) / 2)
2996
2997            x_s = sets[0:mm]
2998            y_s = sets[mm : len(sets)]
2999
3000            similarity_data = similarity_data[similarity_data["set1"].isin(x_s)]
3001            similarity_data = similarity_data[similarity_data["set2"].isin(y_s)]
3002
3003            similarity_data = similarity_data.sort_values(["set1", "set2"])
3004
3005        if set_info is False and " # " in similarity_data["cell1"][0]:
3006            similarity_data["cell1"] = [
3007                re.sub(" #.*", "", x) for x in similarity_data["cell1"]
3008            ]
3009            similarity_data["cell2"] = [
3010                re.sub(" #.*", "", x) for x in similarity_data["cell2"]
3011            ]
3012
3013        similarity_data["-euclidean_zscore"] = -zscore(
3014            similarity_data["euclidean_dist"]
3015        )
3016
3017        similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0]
3018
3019        fig = plt.figure(figsize=(width, height))
3020        sns.scatterplot(
3021            data=similarity_data,
3022            x="cell1",
3023            y="cell2",
3024            hue="correlation",
3025            size="-euclidean_zscore",
3026            sizes=(1, 100),
3027            palette=cmap,
3028            alpha=1,
3029            edgecolor="black",
3030        )
3031
3032        plt.xticks(rotation=90)
3033        plt.yticks(rotation=0)
3034        plt.xlabel("Cell 1")
3035        plt.ylabel("Cell 2")
3036        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
3037
3038        plt.grid(True, alpha=0.6)
3039
3040        plt.tight_layout()
3041
3042        return fig
3043
3044    def spatial_similarity(
3045        self,
3046        set_info: bool = True,
3047        bandwidth=1,
3048        n_neighbors=5,
3049        min_dist=0.1,
3050        legend_split=2,
3051        point_size=100,
3052        spread=1.0,
3053        set_op_mix_ratio=1.0,
3054        local_connectivity=1,
3055        repulsion_strength=1.0,
3056        negative_sample_rate=5,
3057        threshold=0.1,
3058        width=12,
3059        height=10,
3060    ):
3061        """
3062        Create a spatial UMAP-like visualization of similarity relationships between samples.
3063
3064        Parameters
3065        ----------
3066        set_info : bool, default True
3067            If True, retain set information in labels.
3068
3069        bandwidth : float, default 1
3070            Bandwidth used by MeanShift for clustering polygons.
3071
3072        point_size : float, default 100
3073            Size of scatter points.
3074
3075        legend_split : int, default 2
3076            Number of columns in legend.
3077
3078        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.
3079
3080        threshold : float, default 0.1
3081            Minimum text distance for label adjustment to avoid overlap.
3082
3083        width : int, default 12
3084            Figure width.
3085
3086        height : int, default 10
3087            Figure height.
3088
3089        Returns
3090        -------
3091        matplotlib.figure.Figure
3092
3093        Raises
3094        ------
3095        ValueError
3096            If `self.similarity` is None.
3097
3098        Notes
3099        -----
3100        Builds a precomputed distance matrix combining correlation and euclidean distance,
3101        runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull)
3102        and arrows to indicate nearest neighbors (minimal combined distance).
3103        """
3104
3105        if self.similarity is None:
3106            raise ValueError(
3107                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
3108            )
3109
3110        similarity_data = self.similarity
3111
3112        sim = similarity_data["correlation"]
3113        sim_scaled = (sim - sim.min()) / (sim.max() - sim.min())
3114        eu_dist = similarity_data["euclidean_dist"]
3115        eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min())
3116
3117        similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled
3118
3119        # for nn target
3120        arrow_df = similarity_data.copy()
3121        arrow_df = similarity_data.loc[
3122            similarity_data.groupby("cell1")["combo_dist"].idxmin()
3123        ]
3124
3125        cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"]))
3126        combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float)
3127
3128        for _, row in similarity_data.iterrows():
3129            combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"]
3130            combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"]
3131
3132        umap_model = umap.UMAP(
3133            n_components=2,
3134            metric="precomputed",
3135            n_neighbors=n_neighbors,
3136            min_dist=min_dist,
3137            spread=spread,
3138            set_op_mix_ratio=set_op_mix_ratio,
3139            local_connectivity=set_op_mix_ratio,
3140            repulsion_strength=repulsion_strength,
3141            negative_sample_rate=negative_sample_rate,
3142            transform_seed=42,
3143            init="spectral",
3144            random_state=42,
3145            verbose=True,
3146        )
3147
3148        coords = umap_model.fit_transform(combo_matrix.values)
3149        cell_names = list(combo_matrix.index)
3150        num_cells = len(cell_names)
3151        palette = sns.color_palette("tab20c", num_cells)
3152
3153        if "#" in cell_names[0]:
3154            avsets = set(
3155                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
3156                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
3157            )
3158            num_sets = len(avsets)
3159            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
3160            color_mapping_sets = {
3161                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
3162            }
3163            color_mapping = {
3164                name: color_mapping_sets[re.sub(".*# ", "", name)]
3165                for i, name in enumerate(cell_names)
3166            }
3167        else:
3168            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}
3169
3170        meanshift = MeanShift(bandwidth=bandwidth)
3171        labels = meanshift.fit_predict(coords)
3172
3173        fig = plt.figure(figsize=(width, height))
3174        ax = plt.gca()
3175
3176        unique_labels = set(labels)
3177        cluster_palette = sns.color_palette("hls", len(unique_labels))
3178
3179        for label in unique_labels:
3180            if label == -1:
3181                continue
3182            cluster_coords = coords[labels == label]
3183            if len(cluster_coords) < 3:
3184                continue
3185
3186            hull = ConvexHull(cluster_coords)
3187            hull_points = cluster_coords[hull.vertices]
3188
3189            centroid = np.mean(hull_points, axis=0)
3190            expanded = hull_points + 0.05 * (hull_points - centroid)
3191
3192            poly = Polygon(
3193                expanded,
3194                closed=True,
3195                facecolor=cluster_palette[label],
3196                edgecolor="none",
3197                alpha=0.2,
3198                zorder=1,
3199            )
3200            ax.add_patch(poly)
3201
3202        texts = []
3203        for i, (x, y) in enumerate(coords):
3204            plt.scatter(
3205                x,
3206                y,
3207                s=point_size,
3208                color=color_mapping[cell_names[i]],
3209                edgecolors="black",
3210                linewidths=0.5,
3211                zorder=2,
3212            )
3213            texts.append(
3214                ax.text(
3215                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
3216                )
3217            )
3218
3219        def dist(p1, p2):
3220            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
3221
3222        texts_to_adjust = []
3223        for i, t1 in enumerate(texts):
3224            for j, t2 in enumerate(texts):
3225                if i >= j:
3226                    continue
3227                d = dist(
3228                    (t1.get_position()[0], t1.get_position()[1]),
3229                    (t2.get_position()[0], t2.get_position()[1]),
3230                )
3231                if d < threshold:
3232                    if t1 not in texts_to_adjust:
3233                        texts_to_adjust.append(t1)
3234                    if t2 not in texts_to_adjust:
3235                        texts_to_adjust.append(t2)
3236
3237        adjust_text(
3238            texts_to_adjust,
3239            expand_text=(1.0, 1.0),
3240            force_text=0.9,
3241            arrowprops=dict(arrowstyle="-", color="gray", lw=0.1),
3242            ax=ax,
3243        )
3244
3245        for _, row in arrow_df.iterrows():
3246            try:
3247                idx1 = cell_names.index(row["cell1"])
3248                idx2 = cell_names.index(row["cell2"])
3249            except ValueError:
3250                continue
3251            x1, y1 = coords[idx1]
3252            x2, y2 = coords[idx2]
3253            arrow = FancyArrowPatch(
3254                (x1, y1),
3255                (x2, y2),
3256                arrowstyle="->",
3257                color="gray",
3258                linewidth=1.5,
3259                alpha=0.5,
3260                mutation_scale=12,
3261                zorder=0,
3262            )
3263            ax.add_patch(arrow)
3264
3265        if set_info is False and " # " in cell_names[0]:
3266
3267            legend_elements = [
3268                Patch(
3269                    facecolor=color_mapping[name],
3270                    edgecolor="black",
3271                    label=f"{i}{re.sub(' #.*', '', name)}",
3272                )
3273                for i, name in enumerate(cell_names)
3274            ]
3275
3276        else:
3277
3278            legend_elements = [
3279                Patch(
3280                    facecolor=color_mapping[name],
3281                    edgecolor="black",
3282                    label=f"{i}{name}",
3283                )
3284                for i, name in enumerate(cell_names)
3285            ]
3286
3287        plt.legend(
3288            handles=legend_elements,
3289            title="Cells",
3290            bbox_to_anchor=(1.05, 1),
3291            loc="upper left",
3292            ncol=legend_split,
3293        )
3294
3295        plt.xlabel("UMAP 1")
3296        plt.ylabel("UMAP 2")
3297        plt.grid(False)
3298        plt.show()
3299
3300        return fig
3301
3302    # subclusters part
3303
3304    def subcluster_prepare(self, features: list, cluster: str):
3305        """
3306        Prepare a `Clustering` object for subcluster analysis on a selected parent cluster.
3307
3308        Parameters
3309        ----------
3310        features : list
3311            Features to include for subcluster analysis.
3312
3313        cluster : str
3314            Parent cluster name (used to select matching cells).
3315
3316        Update
3317        ------
3318        Initializes `self.subclusters_` as a new `Clustering` instance containing the
3319        reduced data for the given cluster and stores `current_features` and `current_cluster`.
3320        """
3321
3322        dat = self.normalized_data
3323        dat.columns = list(self.input_metadata["cell_names"])
3324
3325        dat = reduce_data(self.normalized_data, features=features, names=[cluster])
3326
3327        self.subclusters_ = Clustering(data=dat, metadata=None)
3328
3329        self.subclusters_.current_features = features
3330        self.subclusters_.current_cluster = cluster
3331
3332    def define_subclusters(
3333        self,
3334        umap_num: int = 2,
3335        eps: float = 0.5,
3336        min_samples: int = 10,
3337        n_neighbors: int = 5,
3338        min_dist: float = 0.1,
3339        spread: float = 1.0,
3340        set_op_mix_ratio: float = 1.0,
3341        local_connectivity: int = 1,
3342        repulsion_strength: float = 1.0,
3343        negative_sample_rate: int = 5,
3344        width=8,
3345        height=6,
3346    ):
3347        """
3348        Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.
3349
3350        Parameters
3351        ----------
3352        umap_num : int, default 2
3353            Number of UMAP dimensions to compute.
3354
3355        eps : float, default 0.5
3356            DBSCAN eps parameter.
3357
3358        min_samples : int, default 10
3359            DBSCAN min_samples parameter.
3360
3361        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
3362            Additional parameters passed to UMAP / plotting / MeanShift as appropriate.
3363
3364        Update
3365        -------
3366        Stores cluster labels in `self.subclusters_.subclusters`.
3367
3368        Raises
3369        ------
3370        RuntimeError
3371            If `self.subclusters_` has not been prepared.
3372        """
3373
3374        if self.subclusters_ is None:
3375            raise RuntimeError(
3376                "Nothing to return. 'self.subcluster_prepare' was not conducted!"
3377            )
3378
3379        self.subclusters_.perform_UMAP(
3380            factorize=False,
3381            umap_num=umap_num,
3382            pc_num=0,
3383            harmonized=False,
3384            n_neighbors=n_neighbors,
3385            min_dist=min_dist,
3386            spread=spread,
3387            set_op_mix_ratio=set_op_mix_ratio,
3388            local_connectivity=local_connectivity,
3389            repulsion_strength=repulsion_strength,
3390            negative_sample_rate=negative_sample_rate,
3391            width=width,
3392            height=height,
3393        )
3394
3395        fig = self.subclusters_.find_clusters_UMAP(
3396            umap_n=umap_num,
3397            eps=eps,
3398            min_samples=min_samples,
3399            width=width,
3400            height=height,
3401        )
3402
3403        clusters = self.subclusters_.return_clusters(clusters="umap")
3404
3405        self.subclusters_.subclusters = [str(x) for x in list(clusters)]
3406
3407        return fig
3408
3409    def subcluster_features_scatter(
3410        self,
3411        colors="viridis",
3412        hclust="complete",
3413        scale=False,
3414        img_width=3,
3415        img_high=5,
3416        label_size=6,
3417        size_scale=70,
3418        y_lab="Genes",
3419        legend_lab="normalized",
3420        bbox_to_anchor_scale: int = 25,
3421        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3422    ):
3423        """
3424        Create a features-scatter visualization for the subclusters (averaged and occurrence).
3425
3426        Parameters
3427        ----------
3428        colors : str, default 'viridis'
3429            Colormap name passed to `features_scatter`.
3430
3431        hclust : str or None
3432            Hierarchical clustering linkage to order rows/columns.
3433
3434        scale: bool, default False
3435            If True, expression data will be scaled (0–1) across the rows (features).
3436
3437        img_width, img_high : float
3438            Figure size.
3439
3440        label_size : int
3441            Font size for labels.
3442
3443        size_scale : int
3444            Bubble size scaling.
3445
3446        y_lab : str
3447            X axis label.
3448
3449        legend_lab : str
3450            Colorbar label.
3451
3452        bbox_to_anchor_scale : int, default=25
3453            Vertical scale (percentage) for positioning the colorbar.
3454
3455        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3456            Anchor position for the size legend (percent bubble legend).
3457
3458        Returns
3459        -------
3460        matplotlib.figure.Figure
3461
3462        Raises
3463        ------
3464        RuntimeError
3465            If subcluster preparation/definition has not been run.
3466        """
3467
3468        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3469            raise RuntimeError(
3470                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3471            )
3472
3473        dat = self.normalized_data
3474        dat.columns = list(self.input_metadata["cell_names"])
3475
3476        dat = reduce_data(
3477            self.normalized_data,
3478            features=self.subclusters_.current_features,
3479            names=[self.subclusters_.current_cluster],
3480        )
3481
3482        dat.columns = self.subclusters_.subclusters
3483
3484        avg = average(dat)
3485        occ = occurrence(dat)
3486
3487        scatter = features_scatter(
3488            expression_data=avg,
3489            occurence_data=occ,
3490            features=None,
3491            scale=scale,
3492            metadata_list=None,
3493            colors=colors,
3494            hclust=hclust,
3495            img_width=img_width,
3496            img_high=img_high,
3497            label_size=label_size,
3498            size_scale=size_scale,
3499            y_lab=y_lab,
3500            legend_lab=legend_lab,
3501            bbox_to_anchor_scale=bbox_to_anchor_scale,
3502            bbox_to_anchor_perc=bbox_to_anchor_perc,
3503        )
3504
3505        return scatter
3506
3507    def subcluster_DEG_scatter(
3508        self,
3509        top_n=3,
3510        min_exp=0,
3511        min_pct=0.25,
3512        p_val=0.05,
3513        colors="viridis",
3514        hclust="complete",
3515        scale=False,
3516        img_width=3,
3517        img_high=5,
3518        label_size=6,
3519        size_scale=70,
3520        y_lab="Genes",
3521        legend_lab="normalized",
3522        bbox_to_anchor_scale: int = 25,
3523        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3524        n_proc=10,
3525    ):
3526        """
3527        Plot top differential features (DEGs) for subclusters as a features-scatter.
3528
3529        Parameters
3530        ----------
3531        top_n : int, default 3
3532            Number of top features per subcluster to show.
3533
3534        min_exp : float, default 0
3535            Minimum expression threshold passed to `statistic`.
3536
3537        min_pct : float, default 0.25
3538            Minimum percent expressed in target group.
3539
3540        p_val: float, default 0.05
3541            Maximum p-value for visualizing features.
3542
3543        n_proc : int, default 10
3544            Parallel jobs used for DEG calculation.
3545
3546        scale: bool, default False
3547            If True, expression_data will be scaled (0–1) across the rows (features).
3548
3549        colors : str, default='viridis'
3550            Colormap for expression values.
3551
3552        hclust : str or None, default='complete'
3553            Linkage method for hierarchical clustering. If None, no clustering
3554            is performed.
3555
3556        img_width : int or float, default=8
3557            Width of the plot in inches.
3558
3559        img_high : int or float, default=5
3560            Height of the plot in inches.
3561
3562        label_size : int, default=10
3563            Font size for axis labels and ticks.
3564
3565        size_scale : int or float, default=100
3566            Scaling factor for bubble sizes.
3567
3568        y_lab : str, default='Genes'
3569            Label for the x-axis.
3570
3571        legend_lab : str, default='normalized'
3572            Label for the colorbar legend.
3573
3574        bbox_to_anchor_scale : int, default=25
3575            Vertical scale (percentage) for positioning the colorbar.
3576
3577        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3578            Anchor position for the size legend (percent bubble legend).
3579
3580        Returns
3581        -------
3582        matplotlib.figure.Figure
3583
3584        Raises
3585        ------
3586        RuntimeError
3587            If subcluster preparation/definition has not been run.
3588
3589        Notes
3590        -----
3591        Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters
3592        by p-value and effect-size, selects top features per valid group and plots them.
3593        """
3594
3595        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3596            raise RuntimeError(
3597                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3598            )
3599
3600        dat = self.normalized_data
3601        dat.columns = list(self.input_metadata["cell_names"])
3602
3603        dat = reduce_data(
3604            self.normalized_data, names=[self.subclusters_.current_cluster]
3605        )
3606
3607        dat.columns = self.subclusters_.subclusters
3608
3609        deg_stats = calc_DEG(
3610            dat,
3611            metadata_list=None,
3612            entities="All",
3613            sets=None,
3614            min_exp=min_exp,
3615            min_pct=min_pct,
3616            n_proc=n_proc,
3617        )
3618
3619        deg_stats = deg_stats[deg_stats["p_val"] <= p_val]
3620        deg_stats = deg_stats[deg_stats["log(FC)"] > 0]
3621
3622        deg_stats = (
3623            deg_stats.sort_values(
3624                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
3625            )
3626            .groupby("valid_group")
3627            .head(top_n)
3628        )
3629
3630        dat = reduce_data(dat, features=list(set(deg_stats["feature"])))
3631
3632        avg = average(dat)
3633        occ = occurrence(dat)
3634
3635        scatter = features_scatter(
3636            expression_data=avg,
3637            occurence_data=occ,
3638            features=None,
3639            metadata_list=None,
3640            colors=colors,
3641            hclust=hclust,
3642            img_width=img_width,
3643            img_high=img_high,
3644            label_size=label_size,
3645            size_scale=size_scale,
3646            y_lab=y_lab,
3647            legend_lab=legend_lab,
3648            bbox_to_anchor_scale=bbox_to_anchor_scale,
3649            bbox_to_anchor_perc=bbox_to_anchor_perc,
3650        )
3651
3652        return scatter
3653
3654    def accept_subclusters(self):
3655        """
3656        Commit subcluster labels into the main `input_metadata` by renaming cell names.
3657
3658        The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']`
3659        with the expanded names that include subcluster suffixes (via `add_subnames`),
3660        then clears `self.subclusters_`.
3661
3662        Update
3663        ------
3664        Modifies `self.input_metadata['cell_names']`.
3665
3666        Resets `self.subclusters_` to None.
3667
3668        Raises
3669        ------
3670        RuntimeError
3671            If `self.subclusters_` is not defined or subclusters were not computed.
3672        """
3673
3674        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3675            raise RuntimeError(
3676                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3677            )
3678
3679        new_meta = add_subnames(
3680            list(self.input_metadata["cell_names"]),
3681            parent_name=self.subclusters_.current_cluster,
3682            new_clusters=self.subclusters_.subclusters,
3683        )
3684
3685        self.input_metadata["cell_names"] = new_meta
3686
3687        self.subclusters_ = None
3688
3689    def scatter_plot(
3690        self,
3691        names: list | None = None,
3692        features: list | None = None,
3693        name_slot: str = "cell_names",
3694        scale=True,
3695        colors="viridis",
3696        hclust=None,
3697        img_width=15,
3698        img_high=1,
3699        label_size=10,
3700        size_scale=200,
3701        y_lab="Genes",
3702        legend_lab="log(CPM + 1)",
3703        set_box_size: float | int = 5,
3704        set_box_high: float | int = 5,
3705        bbox_to_anchor_scale=25,
3706        bbox_to_anchor_perc=(0.90, -0.24),
3707        bbox_to_anchor_group=(1.01, 0.4),
3708    ):
3709        """
3710        Create a bubble scatter plot of selected features across samples inside project.
3711
3712        Each point represents a feature-sample pair, where the color encodes the
3713        expression value and the size encodes occurrence or relative abundance.
3714        Optionally, hierarchical clustering can be applied to order rows and columns.
3715
3716        Parameters
3717        ----------
3718        names : list, str, or None
3719            Names of samples to include. If None, all samples are considered.
3720
3721        features : list, str, or None
3722            Names of features to include. If None, all features are considered.
3723
3724        name_slot : str
3725            Column in metadata to use as sample names.
3726
3727        scale: bool, default False
3728            If True, expression_data will be scaled (0–1) across the rows (features).
3729
3730        colors : str, default='viridis'
3731            Colormap for expression values.
3732
3733        hclust : str or None, default='complete'
3734            Linkage method for hierarchical clustering. If None, no clustering
3735            is performed.
3736
3737        img_width : int or float, default=8
3738            Width of the plot in inches.
3739
3740        img_high : int or float, default=5
3741            Height of the plot in inches.
3742
3743        label_size : int, default=10
3744            Font size for axis labels and ticks.
3745
3746        size_scale : int or float, default=100
3747            Scaling factor for bubble sizes.
3748
3749        y_lab : str, default='Genes'
3750            Label for the x-axis.
3751
3752        legend_lab : str, default='log(CPM + 1)'
3753            Label for the colorbar legend.
3754
3755        bbox_to_anchor_scale : int, default=25
3756            Vertical scale (percentage) for positioning the colorbar.
3757
3758        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3759            Anchor position for the size legend (percent bubble legend).
3760
3761        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
3762            Anchor position for the group legend.
3763
3764        Returns
3765        -------
3766        matplotlib.figure.Figure
3767            The generated scatter plot figure.
3768
3769        Notes
3770        -----
3771        Colors represent expression values normalized to the colormap.
3772        """
3773
3774        prtd, met = self.get_partial_data(
3775            names=names, features=features, name_slot=name_slot, inc_metadata=True
3776        )
3777
3778        prtd.columns = prtd.columns + "#" + met["sets"]
3779
3780        prtd_avg = average(prtd)
3781
3782        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]
3783
3784        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]
3785
3786        prtd_occ = occurrence(prtd)
3787
3788        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]
3789
3790        fig_scatter = features_scatter(
3791            expression_data=prtd_avg,
3792            occurence_data=prtd_occ,
3793            scale=scale,
3794            features=None,
3795            metadata_list=meta_sets,
3796            colors=colors,
3797            hclust=hclust,
3798            img_width=img_width,
3799            img_high=img_high,
3800            label_size=label_size,
3801            size_scale=size_scale,
3802            y_lab=y_lab,
3803            legend_lab=legend_lab,
3804            set_box_size=set_box_size,
3805            set_box_high=set_box_high,
3806            bbox_to_anchor_scale=bbox_to_anchor_scale,
3807            bbox_to_anchor_perc=bbox_to_anchor_perc,
3808            bbox_to_anchor_group=bbox_to_anchor_group,
3809        )
3810
3811        return fig_scatter
3812
3813    def data_composition(
3814        self,
3815        features_count: list | None,
3816        name_slot: str = "cell_names",
3817        set_sep: bool = True,
3818    ):
3819        """
3820        Compute composition of cell types in data set.
3821
3822        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
3823        within metadata entries, calculates their relative percentages, and stores
3824        the results in `self.composition_data`.
3825
3826        Parameters
3827        ----------
3828        features_count : list or None
3829            List of features (part or full names) to be counted.
3830            If None, all unique elements from the specified `name_slot` metadata field are used.
3831
3832        name_slot : str, default 'cell_names'
3833            Metadata field containing sample identifiers or labels.
3834
3835        set_sep : bool, default True
3836            If True and multiple sets are present in metadata, compute composition
3837            separately for each set.
3838
3839        Update
3840        -------
3841        Stores results in `self.composition_data` as a pandas DataFrame with:
3842        - 'name': feature name
3843        - 'n': number of occurrences
3844        - 'pct': percentage of occurrences
3845        - 'set' (if applicable): dataset identifier
3846        """
3847
3848        validated_list = list(self.input_metadata[name_slot])
3849        sets = list(self.input_metadata["sets"])
3850
3851        if features_count is None:
3852            features_count = list(set(self.input_metadata[name_slot]))
3853
3854        if set_sep and len(set(sets)) > 1:
3855
3856            final_res = pd.DataFrame()
3857
3858            for s in set(sets):
3859                print(s)
3860
3861                mask = [True if s == x else False for x in sets]
3862
3863                tmp_val_list = np.array(validated_list)
3864
3865                tmp_val_list = list(tmp_val_list[mask])
3866
3867                res_dict = {"name": [], "n": [], "set": []}
3868
3869                for f in tqdm(features_count):
3870                    res_dict["n"].append(
3871                        sum(1 for element in tmp_val_list if f in element)
3872                    )
3873                    res_dict["name"].append(f)
3874                    res_dict["set"].append(s)
3875                    res = pd.DataFrame(res_dict)
3876                    res["pct"] = res["n"] / sum(res["n"]) * 100
3877                    res["pct"] = res["pct"].round(2)
3878
3879                final_res = pd.concat([final_res, res])
3880
3881            res = final_res.sort_values(["set", "pct"], ascending=[True, False])
3882
3883        else:
3884
3885            res_dict = {"name": [], "n": []}
3886
3887            for f in tqdm(features_count):
3888                res_dict["n"].append(
3889                    sum(1 for element in validated_list if f in element)
3890                )
3891                res_dict["name"].append(f)
3892
3893            res = pd.DataFrame(res_dict)
3894            res["pct"] = res["n"] / sum(res["n"]) * 100
3895            res["pct"] = res["pct"].round(2)
3896
3897            res = res.sort_values("pct", ascending=False)
3898
3899        self.composition_data = res
3900
    def composition_pie(
        self,
        width=6,
        height=6,
        font_size=15,
        cmap: str = "tab20",
        legend_split_col: int = 1,
        offset_labels: float | int = 0.5,
        legend_bbox: tuple = (1.15, 0.95),
    ):
        """
        Visualize the composition of cell lineages using pie charts.

        Generates donut-style pie charts showing the relative proportions of
        features stored in `self.composition_data` (as produced by
        `self.data_composition`). If multiple sets are present, a separate
        chart is drawn for each set, all sharing one name-to-color mapping and
        a single figure-level legend.

        Parameters
        ----------
        width : int, default 6
            Width of the figure.

        height : int, default 6
            Height of the figure (applied per set if multiple sets are plotted).

        font_size : int, default 15
            Font size for labels and annotations.

        cmap : str, default 'tab20'
            Colormap used for pie slices.

        legend_split_col : int, default 1
            Number of columns in the legend.

        offset_labels : float or int, default 0.5
            Spacing offset for label placement relative to pie slices.

        legend_bbox : tuple, default (1.15, 0.95)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Pie chart visualization of composition data.
        """

        df = self.composition_data

        # Multi-set mode: one donut per set, stacked vertically.
        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))

            all_wedges = []
            # Rebinds `cmap` from colormap name (str) to the Colormap object.
            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))

            # NOTE(review): legend order is derived from a Python set, so it is
            # not deterministic between runs.
            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]

            # One fixed color per feature name so slices match across subplots.
            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                # This branch requires len(sets) > 1, so `axes` is an array.
                ax = axes[idx]
                tmp_df = df[df["set"] == s].reset_index(drop=True)

                # Slice annotations show percentages only; names go to the legend.
                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]

                wedges, _ = ax.pie(
                    tmp_df["n"],
                    startangle=90,
                    labeldistance=1.05,
                    colors=[cmap_dict[x] for x in tmp_df["name"]],
                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
                )

                all_wedges.extend(wedges)

                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
                n = 0
                for i, p in enumerate(wedges):
                    # Mid-angle of the wedge -> point on the unit circle where
                    # the annotation arrow attaches.
                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
                    y = np.sin(np.deg2rad(ang))
                    x = np.cos(np.deg2rad(ang))
                    # Put text to the left or right depending on which half of
                    # the circle the wedge midpoint lies in.
                    horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                    connectionstyle = f"angle,angleA=0,angleB={ang}"
                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
                    if len(labels[i]) > 0:
                        # `n` grows with each labeled wedge, pushing successive
                        # labels progressively outward to reduce overlap.
                        n += offset_labels
                        ax.annotate(
                            labels[i],
                            xy=(x, y),
                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
                            horizontalalignment=horizontalalignment,
                            fontsize=font_size,
                            weight="bold",
                            **kw,
                        )

                # White inner circle turns the pie into a donut.
                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
                ax.add_artist(circle2)

                # Set name rendered inside the donut hole.
                ax.text(
                    0,
                    0,
                    f"{s}",
                    ha="center",
                    va="center",
                    fontsize=font_size,
                    weight="bold",
                )

            # Shared legend built from colored patches (not wedges), since the
            # color mapping is common to all subplots.
            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            # Single-set mode: one donut; legend is built from wedge handles.
            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]

            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]

            cmap = plt.get_cmap(cmap)
            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(
                figsize=(width, height), subplot_kw=dict(aspect="equal")
            )

            wedges, _ = ax.pie(
                df["n"],
                startangle=90,
                labeldistance=1.05,
                colors=colors,
                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
            )

            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
            n = 0
            for i, p in enumerate(wedges):
                # Same label-placement geometry as the multi-set branch.
                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
                y = np.sin(np.deg2rad(ang))
                x = np.cos(np.deg2rad(ang))
                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                connectionstyle = "angle,angleA=0,angleB={}".format(ang)
                kw["arrowprops"].update({"connectionstyle": connectionstyle})
                if len(labels[i]) > 0:
                    n += offset_labels

                    ax.annotate(
                        labels[i],
                        xy=(x, y),
                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
                        horizontalalignment=horizontalalignment,
                        fontsize=font_size,
                        weight="bold",
                        **kw,
                    )

            # Donut hole; edge color set separately here (same visual result as
            # passing ec= in the multi-set branch).
            circle2 = plt.Circle((0, 0), 0.6, color="white")
            circle2.set_edgecolor("black")

            p = plt.gcf()
            p.gca().add_artist(circle2)

            ax.legend(
                wedges,
                legend_labels,
                title="",
                loc="center left",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
            )

            plt.show()

        return fig
4092
    def bar_composition(
        self,
        cmap="tab20b",
        width=2,
        height=6,
        font_size=15,
        legend_split_col: int = 1,
        legend_bbox: tuple = (1.3, 1),
    ):
        """
        Visualize the composition of cell lineages using bar plots.

        Produces stacked bar plots showing the distribution of features stored
        in `self.composition_data`. If multiple sets are present, a separate
        bar is drawn for each set. Percentages are annotated alongside the bars.

        Parameters
        ----------
        cmap : str, default 'tab20b'
            Colormap used for stacked bars.

        width : int, default 2
            Width of each subplot (per set).

        height : int, default 6
            Height of the figure.

        font_size : int, default 15
            Font size for labels and annotations.

        legend_split_col : int, default 1
            Number of columns in the legend.

        legend_bbox : tuple, default (1.3, 1)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Stacked bar plot visualization of composition data.

        Notes
        -----
        Adds a 'num' column to `self.composition_data` as a side effect (see
        below).
        """

        df = self.composition_data
        # NOTE(review): `df` is a reference, not a copy — this adds a 'num'
        # column to self.composition_data in place.
        df["num"] = range(1, len(df) + 1)

        # Multi-set mode: one stacked bar per set, side by side.
        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))

            # Rebinds `cmap` from colormap name (str) to the Colormap object.
            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))

            # NOTE(review): legend order is derived from a Python set, so it is
            # not deterministic between runs.
            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]

            # One fixed color per feature name so segments match across bars.
            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                ax = axes[idx]

                tmp_df = df[df["set"] == s].reset_index(drop=True)

                # Convert counts to per-set percentages, rounded to 2 decimals.
                values = tmp_df["n"].values
                total = sum(values)
                values = [v / total * 100 for v in values]
                values = [round(v, 2) for v in values]

                # Absorb the rounding error into the largest segment so the
                # stacked bar sums to exactly 100.
                idx_max = np.argmax(values)
                correction = 100 - sum(values)
                values[idx_max] += correction

                names = tmp_df["name"].values
                perc = tmp_df["pct"].values
                nums = tmp_df["num"].values

                # Stack segments bottom-up, remembering each segment's center
                # for the annotation arrows.
                bottom = 0
                centers = []
                # NOTE(review): `color` from zip(..., colors) is unused here —
                # the segment color comes from cmap_dict[name] instead.
                for name, num, val, color in zip(names, nums, values, colors):
                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
                    centers.append(bottom + val / 2)
                    bottom += val

                # Spread labels evenly along the bar height; arrows connect
                # each label back to its segment's center.
                y_positions = np.linspace(centers[0], centers[-1], len(centers))
                x_text = -0.8

                for y_label, y_center, pct, num in zip(
                    y_positions, centers, perc, nums
                ):
                    ax.annotate(
                        f"{pct:.1f}%",
                        xy=(0, y_center),
                        xycoords="data",
                        xytext=(x_text, y_label),
                        textcoords="data",
                        ha="right",
                        va="center",
                        fontsize=font_size,
                        arrowprops=dict(
                            arrowstyle="->",
                            lw=1,
                            color="black",
                            connectionstyle="angle3,angleA=0,angleB=90",
                        ),
                    )

                ax.set_ylim(0, 100)
                ax.set_xlabel(s, fontsize=font_size)
                ax.xaxis.label.set_rotation(30)

                # Hide axes decorations; the bar and labels carry all info.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

            # Shared legend built from colored patches, common to all bars.
            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            # Single-set mode: one stacked bar built from raw counts.
            cmap = plt.get_cmap(cmap)

            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(figsize=(width, height))

            values = df["n"].values
            names = df["name"].values
            perc = df["pct"].values
            nums = df["num"].values

            bottom = 0
            centers = []
            for name, num, val, color in zip(names, nums, values, colors):
                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
                centers.append(bottom + val / 2)
                bottom += val

            y_positions = np.linspace(centers[0], centers[-1], len(centers))
            x_text = -0.8

            # NOTE(review): label format here is "num) pct" without a '%' sign,
            # unlike the "{pct:.1f}%" format used in the multi-set branch.
            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
                ax.annotate(
                    f"{num}) {pct}",
                    xy=(0, y_center),
                    xycoords="data",
                    xytext=(x_text, y_label),
                    textcoords="data",
                    ha="right",
                    va="center",
                    fontsize=9,
                    arrowprops=dict(
                        arrowstyle="->",
                        lw=1,
                        color="black",
                        connectionstyle="angle3,angleA=0,angleB=90",
                    ),
                )

            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            ax.legend(
                title="Legend",
                bbox_to_anchor=legend_bbox,
                loc="upper left",
                ncol=legend_split_col,
            )

            plt.tight_layout()
            plt.show()

        return fig
4283
4284    def cell_regression(
4285        self,
4286        cell_x: str,
4287        cell_y: str,
4288        set_x: str | None,
4289        set_y: str | None,
4290        threshold=10,
4291        image_width=12,
4292        image_high=7,
4293        color="black",
4294    ):
4295        """
4296        Perform regression analysis between two selected cells and visualize the relationship.
4297
4298        This function computes a linear regression between two specified cells from
4299        aggregated normalized data, plots the regression line with scatter points,
4300        annotates regression statistics, and highlights potential outliers.
4301
4302        Parameters
4303        ----------
4304        cell_x : str
4305            Name of the first cell (X-axis).
4306
4307        cell_y : str
4308            Name of the second cell (Y-axis).
4309
4310        set_x : str or None
4311            Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name.
4312
4313        set_y : str or None
4314            Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name.
4315
4316        threshold : int or float, default 10
4317            Threshold for detecting outliers. Points deviating from the mean or diagonal by more
4318            than this value are annotated.
4319
4320        image_width : int, default 12
4321            Width of the regression plot (in inches).
4322
4323        image_high : int, default 7
4324            Height of the regression plot (in inches).
4325
4326        color : str, default 'black'
4327            Color of the regression scatter points and line.
4328
4329        Returns
4330        -------
4331        matplotlib.figure.Figure
4332            Regression plot figure with annotated regression line, R², p-value, and outliers.
4333
4334        Raises
4335        ------
4336        ValueError
4337            If `cell_x` or `cell_y` are not found in the dataset.
4338            If multiple matches are found for a cell name and `set_x`/`set_y` are not specified.
4339
4340        Notes
4341        -----
4342        - The function automatically calls `jseq_object.average()` if aggregated data is not available.
4343        - Outliers are annotated with their corresponding index labels.
4344        - Regression is computed using `scipy.stats.linregress`.
4345
4346        Examples
4347        --------
4348        >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
4349        >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")
4350        """
4351
4352        if self.agg_normalized_data is None:
4353            self.average()
4354
4355        metadata = self.agg_metadata
4356        data = self.agg_normalized_data
4357
4358        if set_x is not None and set_y is not None:
4359            data.columns = metadata["cell_names"] + " # " + metadata["sets"]
4360            cell_x = cell_x + " # " + set_x
4361            cell_y = cell_y + " # " + set_y
4362
4363        else:
4364            data.columns = metadata["cell_names"]
4365
4366        if not cell_x in data.columns:
4367            raise ValueError("'cell_x' value not in cell names!")
4368
4369        if not cell_y in data.columns:
4370            raise ValueError("'cell_y' value not in cell names!")
4371
4372        if list(data.columns).count(cell_x) > 1:
4373            raise ValueError(
4374                f"'{cell_x}' occurs more than once. If you want to select a specific cell, "
4375                f"please also provide the corresponding 'set_x' and 'set_y' values."
4376            )
4377
4378        if list(data.columns).count(cell_y) > 1:
4379            raise ValueError(
4380                f"'{cell_y}' occurs more than once. If you want to select a specific cell, "
4381                f"please also provide the corresponding 'set_x' and 'set_y' values."
4382            )
4383
4384        fig, ax = plt.subplots(figsize=(image_width, image_high))
4385        ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color)
4386
4387        slope, intercept, r_value, p_value, _ = stats.linregress(
4388            data[cell_x], data[cell_y]
4389        )
4390        equation = "y = {:.2f}x + {:.2f}".format(slope, intercept)
4391
4392        ax.annotate(
4393            "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format(
4394                r_value**2, p_value, equation
4395            ),
4396            xy=(0.05, 0.90),
4397            xycoords="axes fraction",
4398            fontsize=12,
4399        )
4400
4401        ax.spines["top"].set_visible(False)
4402        ax.spines["right"].set_visible(False)
4403
4404        diff = []
4405        x_mean, y_mean = data[cell_x].mean(), data[cell_y].mean()
4406        for i, (xi, yi) in enumerate(zip(data[cell_x], data[cell_y])):
4407            diff.append(abs(xi - x_mean))
4408            diff.append(abs(yi - y_mean))
4409
4410        def annotate_outliers(x, y, threshold):
4411            texts = []
4412            x_mean, y_mean = x.mean(), y.mean()
4413            for i, (xi, yi) in enumerate(zip(x, y)):
4414                if (
4415                    abs(xi - x_mean) > threshold
4416                    or abs(yi - y_mean) > threshold
4417                    or abs(yi - xi) > threshold
4418                ):
4419                    text = ax.text(xi, yi, data.index[i])
4420                    texts.append(text)
4421
4422            return texts
4423
4424        texts = annotate_outliers(data[cell_x], data[cell_y], threshold)
4425
4426        adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))
4427
4428        plt.show()
4429
4430        return fig
class Clustering:
  31class Clustering:
  32    """
  33    A class for performing dimensionality reduction, clustering, and visualization
  34    on high-dimensional data (e.g., single-cell gene expression).
  35
  36    The class provides methods for:
  37    - Normalizing and extracting subsets of data
  38    - Principal Component Analysis (PCA) and related clustering
  39    - Uniform Manifold Approximation and Projection (UMAP) and clustering
  40    - Visualization of PCA and UMAP embeddings
  41    - Harmonization of batch effects
  42    - Accessing processed data and cluster labels
  43
  44    Methods
  45    -------
  46    add_data_frame(data, metadata)
  47        Class method to create a Clustering instance from a DataFrame and metadata.
  48
  49    harmonize_sets()
  50        Perform batch effect harmonization on PCA data.
  51
  52    perform_PCA(pc_num=100, width=8, height=6)
  53        Perform PCA on the dataset and visualize the first two PCs.
  54
  55    knee_plot_PCA(width=8, height=6)
  56        Plot the cumulative variance explained by PCs to determine optimal dimensionality.
  57
  58    find_clusters_PCA(pc_num=0, eps=0.5, min_samples=10, width=8, height=6, harmonized=False)
  59        Apply DBSCAN clustering to PCA embeddings and visualize results.
  60
  61    perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...)
  62        Compute UMAP embeddings with optional parameter tuning.
  63
  64    knee_plot_umap(eps=0.5, min_samples=10)
  65        Determine optimal UMAP dimensionality using silhouette scores.
  66
  67    find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6)
  68        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
  69
  70    UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...)
  71        Visualize UMAP embeddings with labels and optional cluster numbering.
  72
  73    UMAP_feature(feature_name, features_data=None, point_size=0.6, ...)
  74        Plot a single feature over UMAP coordinates with customizable colormap.
  75
  76    get_umap_data()
  77        Return the UMAP embeddings along with cluster labels if available.
  78
  79    get_pca_data()
  80        Return the PCA results along with cluster labels if available.
  81
  82    return_clusters(clusters='umap')
  83        Return the cluster labels for UMAP or PCA embeddings.
  84
  85    Raises
  86    ------
  87    ValueError
  88        For invalid parameters, mismatched dimensions, or missing metadata.
  89    """
  90
  91    def __init__(self, data, metadata):
  92        """
  93        Initialize the clustering class with data and optional metadata.
  94
  95        Parameters
  96        ----------
  97        data : pandas.DataFrame
  98            Input data for clustering. Columns are considered as samples.
  99
 100        metadata : pandas.DataFrame, optional
 101            Metadata for the samples. If None, a default DataFrame with column
 102            names as 'cell_names' is created.
 103
 104        Attributes
 105        ----------
 106        -clustering_data : pandas.DataFrame
 107        -clustering_metadata : pandas.DataFrame
 108        -subclusters : None or dict
 109        -explained_var : None or numpy.ndarray
 110        -cumulative_var : None or numpy.ndarray
 111        -pca : None or pandas.DataFrame
 112        -harmonized_pca : None or pandas.DataFrame
 113        -umap : None or pandas.DataFrame
 114        """
 115
 116        self.clustering_data = data
 117        """The input data used for clustering."""
 118
 119        if metadata is None:
 120            metadata = pd.DataFrame({"cell_names": list(data.columns)})
 121
 122        self.clustering_metadata = metadata
 123        """Metadata associated with the samples."""
 124
 125        self.subclusters = None
 126        """Placeholder for storing subcluster information."""
 127
 128        self.explained_var = None
 129        """Explained variance from PCA, initialized as None."""
 130
 131        self.cumulative_var = None
 132        """Cumulative explained variance from PCA, initialized as None."""
 133
 134        self.pca = None
 135        """PCA-transformed data, initialized as None."""
 136
 137        self.harmonized_pca = None
 138        """PCA data after batch effect harmonization, initialized as None."""
 139
 140        self.umap = None
 141        """UMAP embeddings, initialized as None."""
 142
 143    @classmethod
 144    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
 145        """
 146        Create a Clustering instance from a DataFrame and optional metadata.
 147
 148        Parameters
 149        ----------
 150        data : pandas.DataFrame
 151            Input data with features as rows and samples/cells as columns.
 152
 153        metadata : pandas.DataFrame or None
 154            Optional metadata for the samples.
 155            Each row corresponds to a sample/cell, and column names in this DataFrame
 156            should match the sample/cell names in `data`. Columns can contain additional
 157            information such as cell type, experimental condition, batch, sets, etc.
 158
 159        Returns
 160        -------
 161        Clustering
 162            A new instance of the Clustering class.
 163        """
 164
 165        return cls(data, metadata)
 166
 167    def harmonize_sets(self, batch_col: str = "sets"):
 168        """
 169        Perform batch effect harmonization on PCA embeddings.
 170
 171        Parameters
 172        ----------
 173        batch_col : str, default 'sets'
 174            Name of the column in `metadata` that contains batch information for the samples/cells.
 175
 176        Returns
 177        -------
 178        None
 179            Updates the `harmonized_pca` attribute with harmonized data.
 180        """
 181
 182        data_mat = np.array(self.pca)
 183
 184        metadata = self.clustering_metadata
 185
 186        self.harmonized_pca = pd.DataFrame(
 187            harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr
 188        ).T
 189
 190        self.harmonized_pca.columns = self.pca.columns
 191
 192    def perform_PCA(self, pc_num: int = 100, width=8, height=6):
 193        """
 194        Perform Principal Component Analysis (PCA) on the dataset.
 195
 196        This method standardizes the data, applies PCA, stores results as attributes,
 197        and generates a scatter plot of the first two principal components.
 198
 199        Parameters
 200        ----------
 201        pc_num : int, default 100
 202            Number of principal components to compute.
 203            If 0, computes all available components.
 204
 205        width : int or float, default 8
 206            Width of the PCA figure.
 207
 208        height : int or float, default 6
 209            Height of the PCA figure.
 210
 211        Returns
 212        -------
 213        matplotlib.figure.Figure
 214            Scatter plot showing the first two principal components.
 215
 216        Updates
 217        -------
 218        self.pca : pandas.DataFrame
 219            DataFrame with principal component scores for each sample.
 220
 221        self.explained_var : numpy.ndarray
 222            Percentage of variance explained by each principal component.
 223
 224        self.cumulative_var : numpy.ndarray
 225            Cumulative explained variance.
 226        """
 227
 228        scaler = StandardScaler()
 229        data_scaled = scaler.fit_transform(self.clustering_data.T)
 230
 231        if pc_num == 0 or pc_num > data_scaled.shape[0]:
 232            pc_num = data_scaled.shape[0]
 233
 234        pca = PCA(n_components=pc_num, random_state=42)
 235
 236        principal_components = pca.fit_transform(data_scaled)
 237
 238        pca_df = pd.DataFrame(
 239            data=principal_components,
 240            columns=["PC" + str(x + 1) for x in range(pc_num)],
 241        )
 242
 243        self.explained_var = pca.explained_variance_ratio_ * 100
 244        self.cumulative_var = np.cumsum(self.explained_var)
 245
 246        self.pca = pca_df
 247
 248        fig = plt.figure(figsize=(width, height))
 249        plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7)
 250        plt.xlabel("PC 1")
 251        plt.ylabel("PC 2")
 252        plt.grid(True)
 253        plt.show()
 254
 255        return fig
 256
 257    def knee_plot_PCA(self, width: int = 8, height: int = 6):
 258        """
 259        Plot cumulative explained variance to determine the optimal number of PCs.
 260
 261        Parameters
 262        ----------
 263        width : int, default 8
 264            Width of the figure.
 265
 266        height : int or, default 6
 267            Height of the figure.
 268
 269        Returns
 270        -------
 271        matplotlib.figure.Figure
 272            Line plot showing cumulative variance explained by each PC.
 273        """
 274
 275        fig_knee = plt.figure(figsize=(width, height))
 276        plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o")
 277        plt.xlabel("PC (n components)")
 278        plt.ylabel("Cumulative explained variance (%)")
 279        plt.grid(True)
 280
 281        xticks = [1] + list(range(5, len(self.explained_var) + 1, 5))
 282        plt.xticks(xticks, rotation=60)
 283
 284        plt.show()
 285
 286        return fig_knee
 287
 288    def find_clusters_PCA(
 289        self,
 290        pc_num: int = 2,
 291        eps: float = 0.5,
 292        min_samples: int = 10,
 293        width: int = 8,
 294        height: int = 6,
 295        harmonized: bool = False,
 296    ):
 297        """
 298        Apply DBSCAN clustering to PCA embeddings and visualize the results.
 299
 300        This method performs density-based clustering (DBSCAN) on the PCA-reduced
 301        dataset. Cluster labels are stored in the object's metadata, and a scatter
 302        plot of the first two principal components with cluster annotations is returned.
 303
 304        Parameters
 305        ----------
 306        pc_num : int, default 2
 307            Number of principal components to use for clustering.
 308            If 0, uses all available components.
 309
 310        eps : float, default 0.5
 311            Maximum distance between two points for them to be considered
 312            as neighbors (DBSCAN parameter).
 313
 314        min_samples : int, default 10
 315            Minimum number of samples required to form a cluster (DBSCAN parameter).
 316
 317        width : int, default 8
 318            Width of the output scatter plot.
 319
 320        height : int, default 6
 321            Height of the output scatter plot.
 322
 323        harmonized : bool, default False
 324            If True, use harmonized PCA data (`self.harmonized_pca`).
 325            If False, use standard PCA results (`self.pca`).
 326
 327        Returns
 328        -------
 329        matplotlib.figure.Figure
 330            Scatter plot of the first two principal components colored by
 331            cluster assignments.
 332
 333        Updates
 334        -------
 335        self.clustering_metadata['PCA_clusters'] : list
 336            Cluster labels assigned to each cell/sample.
 337
 338        self.input_metadata['PCA_clusters'] : list, optional
 339            Cluster labels stored in input metadata (if available).
 340        """
 341
 342        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
 343
 344        if pc_num == 0 and harmonized:
 345            PCA = self.harmonized_pca
 346
 347        elif pc_num == 0:
 348            PCA = self.pca
 349
 350        else:
 351            if harmonized:
 352
 353                PCA = self.harmonized_pca.iloc[:, 0:pc_num]
 354            else:
 355
 356                PCA = self.pca.iloc[:, 0:pc_num]
 357
 358        dbscan_labels = dbscan.fit_predict(PCA)
 359
 360        pca_df = pd.DataFrame(PCA)
 361        pca_df["Cluster"] = dbscan_labels
 362
 363        fig = plt.figure(figsize=(width, height))
 364
 365        for cluster_id in sorted(pca_df["Cluster"].unique()):
 366            cluster_data = pca_df[pca_df["Cluster"] == cluster_id]
 367            plt.scatter(
 368                cluster_data["PC1"],
 369                cluster_data["PC2"],
 370                label=f"Cluster {cluster_id}",
 371                alpha=0.7,
 372            )
 373
 374        plt.xlabel("PC 1")
 375        plt.ylabel("PC 2")
 376        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
 377
 378        plt.grid(True)
 379        plt.show()
 380
 381        self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
 382
 383        try:
 384            self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
 385        except:
 386            pass
 387
 388        return fig
 389
 390    def perform_UMAP(
 391        self,
 392        factorize: bool = False,
 393        umap_num: int = 100,
 394        pc_num: int = 0,
 395        harmonized: bool = False,
 396        n_neighbors: int = 5,
 397        min_dist: float | int = 0.1,
 398        spread: float | int = 1.0,
 399        set_op_mix_ratio: float | int = 1.0,
 400        local_connectivity: int = 1,
 401        repulsion_strength: float | int = 1.0,
 402        negative_sample_rate: int = 5,
 403        width: int = 8,
 404        height: int = 6,
 405    ):
 406        """
 407        Compute and visualize UMAP embeddings of the dataset.
 408
 409        This method applies Uniform Manifold Approximation and Projection (UMAP)
 410        for dimensionality reduction on either raw, PCA, or harmonized PCA data.
 411        Results are stored as a DataFrame (`self.umap`) and a scatter plot figure
 412        (`self.UMAP_plot`).
 413
 414        Parameters
 415        ----------
 416        factorize : bool, default False
 417            If True, categorical sample labels (from column names) are factorized
 418            and used as supervision in UMAP fitting.
 419
 420        umap_num : int, default 100
 421            Number of UMAP dimensions to compute. If 0, matches the input dimension.
 422
 423        pc_num : int, default 0
 424            Number of principal components to use as UMAP input.
 425            If 0, use all available components or raw data.
 426
 427        harmonized : bool, default False
 428            If True, use harmonized PCA embeddings (`self.harmonized_pca`).
 429            If False, use standard PCA or raw scaled data.
 430
 431        n_neighbors : int, default 5
 432            UMAP parameter controlling the size of the local neighborhood.
 433
 434        min_dist : float, default 0.1
 435            UMAP parameter controlling minimum allowed distance between embedded points.
 436
 437        spread : int | float, default 1.0
 438            Effective scale of embedded space (UMAP parameter).
 439
 440        set_op_mix_ratio : int | float, default 1.0
 441            Interpolation parameter between union and intersection in fuzzy sets.
 442
 443        local_connectivity : int, default 1
 444            Number of nearest neighbors assumed for each point.
 445
 446        repulsion_strength : int | float, default 1.0
 447            Weighting applied to negative samples during optimization.
 448
 449        negative_sample_rate : int, default 5
 450            Number of negative samples per positive sample in optimization.
 451
 452        width : int, default 8
 453            Width of the output scatter plot.
 454
 455        height : int, default 6
 456            Height of the output scatter plot.
 457
 458        Updates
 459        -------
 460        self.umap : pandas.DataFrame
 461            Table of UMAP embeddings with columns `UMAP1 ... UMAPn`.
 462
 463        Notes
 464        -----
 465        For supervised UMAP (`factorize=True`), categorical codes from column
 466        names of the dataset are used as labels.
 467        """
 468
 469        scaler = StandardScaler()
 470
 471        if pc_num == 0 and harmonized:
 472            data_scaled = self.harmonized_pca
 473
 474        elif pc_num == 0:
 475            data_scaled = scaler.fit_transform(self.clustering_data.T)
 476
 477        else:
 478            if harmonized:
 479
 480                data_scaled = self.harmonized_pca.iloc[:, 0:pc_num]
 481            else:
 482
 483                data_scaled = self.pca.iloc[:, 0:pc_num]
 484
 485        if umap_num == 0 or umap_num > data_scaled.shape[1]:
 486
 487            umap_num = data_scaled.shape[1]
 488
 489            reducer = umap.UMAP(
 490                n_components=len(data_scaled.T),
 491                random_state=42,
 492                n_neighbors=n_neighbors,
 493                min_dist=min_dist,
 494                spread=spread,
 495                set_op_mix_ratio=set_op_mix_ratio,
 496                local_connectivity=local_connectivity,
 497                repulsion_strength=repulsion_strength,
 498                negative_sample_rate=negative_sample_rate,
 499                n_jobs=1,
 500            )
 501
 502        else:
 503
 504            reducer = umap.UMAP(
 505                n_components=umap_num,
 506                random_state=42,
 507                n_neighbors=n_neighbors,
 508                min_dist=min_dist,
 509                spread=spread,
 510                set_op_mix_ratio=set_op_mix_ratio,
 511                local_connectivity=local_connectivity,
 512                repulsion_strength=repulsion_strength,
 513                negative_sample_rate=negative_sample_rate,
 514                n_jobs=1,
 515            )
 516
 517        if factorize:
 518            embedding = reducer.fit_transform(
 519                X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes
 520            )
 521        else:
 522            embedding = reducer.fit_transform(X=data_scaled)
 523
 524        umap_df = pd.DataFrame(
 525            embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)]
 526        )
 527
 528        plt.figure(figsize=(width, height))
 529        plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
 530        plt.xlabel("UMAP 1")
 531        plt.ylabel("UMAP 2")
 532        plt.grid(True)
 533
 534        plt.show()
 535
 536        self.umap = umap_df
 537
 538    def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):
 539        """
 540        Plot silhouette scores for different UMAP dimensions to determine optimal n_components.
 541
 542        Parameters
 543        ----------
 544        eps : float, default 0.5
 545            DBSCAN eps parameter for clustering each UMAP dimension.
 546
 547        min_samples : int, default 10
 548            Minimum number of samples to form a cluster in DBSCAN.
 549
 550        Returns
 551        -------
 552        matplotlib.figure.Figure
 553            Silhouette score plot across UMAP dimensions.
 554        """
 555
 556        umap_range = range(2, len(self.umap.T) + 1)
 557
 558        silhouette_scores = []
 559        component = []
 560        for n in umap_range:
 561
 562            db = DBSCAN(eps=eps, min_samples=min_samples)
 563            labels = db.fit_predict(np.array(self.umap)[:, :n])
 564
 565            mask = labels != -1
 566            if len(set(labels[mask])) > 1:
 567                score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask])
 568            else:
 569                score = -1
 570
 571            silhouette_scores.append(score)
 572            component.append(n)
 573
 574        fig = plt.figure(figsize=(10, 5))
 575        plt.plot(component, silhouette_scores, marker="o")
 576        plt.xlabel("UMAP (n_components)")
 577        plt.ylabel("Silhouette Score")
 578        plt.grid(True)
 579        plt.xticks(range(int(min(component)), int(max(component)) + 1, 1))
 580
 581        plt.show()
 582
 583        return fig
 584
 585    def find_clusters_UMAP(
 586        self,
 587        umap_n: int = 5,
 588        eps: float | float = 0.5,
 589        min_samples: int = 10,
 590        width: int = 8,
 591        height: int = 6,
 592    ):
 593        """
 594        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
 595
 596        This method performs density-based clustering (DBSCAN) on the UMAP-reduced
 597        dataset. Cluster labels are stored in the object's metadata, and a scatter
 598        plot of the first two UMAP components with cluster annotations is returned.
 599
 600        Parameters
 601        ----------
 602        umap_n : int, default 5
 603            Number of UMAP dimensions to use for DBSCAN clustering.
 604            Must be <= number of columns in `self.umap`.
 605
 606        eps : float | int, default 0.5
 607            Maximum neighborhood distance between two samples for them to be considered
 608            as in the same cluster (DBSCAN parameter).
 609
 610        min_samples : int, default 10
 611            Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).
 612
 613        width : int, default 8
 614            Figure width.
 615
 616        height : int, default 6
 617            Figure height.
 618
 619        Returns
 620        -------
 621        matplotlib.figure.Figure
 622            Scatter plot of the first two UMAP components colored by
 623            cluster assignments.
 624
 625        Updates
 626        -------
 627        self.clustering_metadata['UMAP_clusters'] : list
 628            Cluster labels assigned to each cell/sample.
 629
 630        self.input_metadata['UMAP_clusters'] : list, optional
 631            Cluster labels stored in input metadata (if available).
 632        """
 633
 634        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
 635        dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n])
 636
 637        umap_df = self.umap
 638        umap_df["Cluster"] = dbscan_labels
 639
 640        fig = plt.figure(figsize=(width, height))
 641
 642        for cluster_id in sorted(umap_df["Cluster"].unique()):
 643            cluster_data = umap_df[umap_df["Cluster"] == cluster_id]
 644            plt.scatter(
 645                cluster_data["UMAP1"],
 646                cluster_data["UMAP2"],
 647                label=f"Cluster {cluster_id}",
 648                alpha=0.7,
 649            )
 650
 651        plt.xlabel("UMAP 1")
 652        plt.ylabel("UMAP 2")
 653        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
 654        plt.grid(True)
 655
 656        self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
 657
 658        try:
 659            self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
 660        except:
 661            pass
 662
 663        return fig
 664
    def UMAP_vis(
        self,
        names_slot: str = "cell_names",
        set_sep: bool = True,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        legend_split_col: int = 2,
        width: int = 8,
        height: int = 6,
        inc_num: bool = True,
    ):
        """
        Visualize UMAP embeddings with sample labels based on a specific metadata slot.

        Parameters
        ----------
        names_slot : str, default 'cell_names'
            Column in metadata to use as sample labels.

        set_sep : bool, default True
            If True, separate points by dataset (marker per 'sets' value).

        point_size : float, default 0.6
            Size of scatter points.

        font_size : int, default 6
            Font size for numbers on points.

        legend_split_col : int, default 2
            Number of columns in legend.

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        inc_num : bool, default True
            If True, annotate each group with its numeric label placed at the
            group's most central point.

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot figure.
        """

        # First two UMAP components only; .copy() keeps the helper columns
        # added below out of the stored embedding.
        umap_df = self.umap.iloc[:, 0:2].copy()
        umap_df["names"] = list(self.clustering_metadata[names_slot])

        if set_sep:

            # Split points by dataset when the 'sets' column exists.
            if "sets" in list(self.clustering_metadata.columns):
                umap_df["dataset"] = list(self.clustering_metadata["sets"])
            else:
                umap_df["dataset"] = "default"

        else:
            umap_df["dataset"] = "default"

        # Unique key per (name, dataset) combination.
        umap_df["tmp_nam"] = list(umap_df["names"] + umap_df["dataset"])

        # Number of points carrying each (name, dataset) key.
        umap_df["count"] = umap_df["tmp_nam"].map(umap_df["tmp_nam"].value_counts())

        # One row per group, ordered by group size, so the most frequent
        # group receives numeric label 0.
        numeric_df = (
            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
            .drop_duplicates()
            .sort_values("count", ascending=False)
        )
        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])

        # Attach the numeric group label back onto every point.
        umap_df = umap_df.merge(
            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
        )

        fig, ax = plt.subplots(figsize=(width, height))

        # Cycle through a fixed marker set, one marker per dataset.
        markers = ["o", "s", "^", "D", "P", "*", "X"]
        marker_map = {
            ds: markers[i % len(markers)]
            for i, ds in enumerate(umap_df["dataset"].unique())
        }

        cord_list = []

        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):

            cluster_data = umap_df[umap_df["numeric_values"] == num]

            ax.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"{num} - {nam}",
                marker=marker_map[cluster_data["dataset"].iloc[0]],
                alpha=0.6,
                s=point_size,
            )

            coords = cluster_data[["UMAP1", "UMAP2"]].values

            # Group medoid: the point minimizing the summed distance to all
            # other points in the group; used as the label anchor below.
            dists = pairwise_distances(coords)

            sum_dists = dists.sum(axis=1)

            center_idx = np.argmin(sum_dists)
            center_point = coords[center_idx]

            cord_list.append(center_point)

        if inc_num:
            texts = []
            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
                texts.append(
                    ax.text(
                        x,
                        y,
                        str(num),
                        ha="center",
                        va="center",
                        fontsize=font_size,
                        color="black",
                    )
                )

            # Nudge overlapping labels apart, connecting each to its anchor
            # point with a thin gray line.
            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.legend(
            title="Clusters",
            loc="center left",
            bbox_to_anchor=(1.05, 0.5),
            ncol=legend_split_col,
            markerscale=5,
        )

        ax.grid(True)

        plt.tight_layout()

        return fig
 806
 807    def UMAP_feature(
 808        self,
 809        feature_name: str,
 810        features_data: pd.DataFrame | None,
 811        point_size: int | float = 0.6,
 812        font_size: int | float = 6,
 813        width: int = 8,
 814        height: int = 6,
 815        palette="light",
 816    ):
 817        """
 818        Visualize UMAP embedding with expression levels of a selected feature.
 819
 820        Each point (cell) in the UMAP plot is colored according to the expression
 821        value of the chosen feature, enabling interpretation of spatial patterns
 822        of gene activity or metadata distribution in low-dimensional space.
 823
 824        Parameters
 825        ----------
 826        feature_name : str
 827           Name of the feature to plot.
 828
 829        features_data : pandas.DataFrame or None, default None
 830            If None, the function uses the DataFrame containing the clustering data.
 831            To plot features not used in clustering, provide a wider DataFrame
 832            containing the original feature values.
 833
 834        point_size : float, default 0.6
 835            Size of scatter points in the plot.
 836
 837        font_size : int, default 6
 838            Font size for axis labels and annotations.
 839
 840        width : int, default 8
 841            Width of the matplotlib figure.
 842
 843        height : int, default 6
 844            Height of the matplotlib figure.
 845
 846        palette : str, default 'light'
 847            Color palette for expression visualization. Options are:
 848            - 'light'
 849            - 'dark'
 850            - 'green'
 851            - 'gray'
 852
 853        Returns
 854        -------
 855        matplotlib.figure.Figure
 856            UMAP scatter plot colored by feature values.
 857        """
 858
 859        umap_df = self.umap.iloc[:, 0:2].copy()
 860
 861        if features_data is None:
 862
 863            features_data = self.clustering_data
 864
 865        if features_data.shape[1] != umap_df.shape[0]:
 866            raise ValueError(
 867                "Imputed 'features_data' shape does not match the number of UMAP cells"
 868            )
 869
 870        blist = [
 871            True if x.upper() == feature_name.upper() else False
 872            for x in features_data.index
 873        ]
 874
 875        if not any(blist):
 876            raise ValueError("Imputed feature_name is not included in the data")
 877
 878        umap_df.loc[:, "feature"] = (
 879            features_data.loc[blist, :]
 880            .apply(lambda row: row.tolist(), axis=1)
 881            .values[0]
 882        )
 883
 884        umap_df = umap_df.sort_values("feature", ascending=True)
 885
 886        import matplotlib.colors as mcolors
 887
 888        if palette == "light":
 889            palette = px.colors.sequential.Sunsetdark
 890
 891        elif palette == "dark":
 892            palette = px.colors.sequential.thermal
 893
 894        elif palette == "green":
 895            palette = px.colors.sequential.Aggrnyl
 896
 897        elif palette == "gray":
 898            palette = px.colors.sequential.gray
 899            palette = palette[::-1]
 900
 901        else:
 902            raise ValueError(
 903                'Palette not found. Use: "light", "dark", "gray", or "green"'
 904            )
 905
 906        converted = []
 907        for c in palette:
 908            rgb_255 = px.colors.unlabel_rgb(c)
 909            rgb_01 = tuple(v / 255.0 for v in rgb_255)
 910            converted.append(rgb_01)
 911
 912        my_cmap = mcolors.ListedColormap(converted, name="custom")
 913
 914        fig, ax = plt.subplots(figsize=(width, height))
 915        sc = ax.scatter(
 916            umap_df["UMAP1"],
 917            umap_df["UMAP2"],
 918            c=umap_df["feature"],
 919            s=point_size,
 920            cmap=my_cmap,
 921            alpha=1.0,
 922            edgecolors="black",
 923            linewidths=0.1,
 924        )
 925
 926        cbar = plt.colorbar(sc, ax=ax)
 927        cbar.set_label(f"{feature_name}")
 928
 929        ax.set_xlabel("UMAP 1")
 930        ax.set_ylabel("UMAP 2")
 931
 932        ax.grid(True)
 933
 934        plt.tight_layout()
 935
 936        return fig
 937
 938    def get_umap_data(self):
 939        """
 940        Retrieve UMAP embedding data with optional cluster labels.
 941
 942        Returns the UMAP coordinates stored in `self.umap`. If clustering
 943        metadata is available (specifically `UMAP_clusters`), the corresponding
 944        cluster assignments are appended as an additional column.
 945
 946        Returns
 947        -------
 948        pandas.DataFrame
 949            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
 950            If available, includes an extra column 'clusters' with cluster labels.
 951
 952        Notes
 953        -----
 954        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
 955        - Cluster labels are added only if present in `self.clustering_metadata`.
 956        """
 957
 958        umap_data = self.umap
 959
 960        try:
 961            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
 962        except:
 963            pass
 964
 965        return umap_data
 966
 967    def get_pca_data(self):
 968        """
 969        Retrieve PCA embedding data with optional cluster labels.
 970
 971        Returns the principal component scores stored in `self.pca`. If clustering
 972        metadata is available (specifically `PCA_clusters`), the corresponding
 973        cluster assignments are appended as an additional column.
 974
 975        Returns
 976        -------
 977        pandas.DataFrame
 978            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
 979            If available, includes an extra column 'clusters' with cluster labels.
 980
 981        Notes
 982        -----
 983        - PCA must be computed beforehand (e.g., using `perform_PCA`).
 984        - Cluster labels are added only if present in `self.clustering_metadata`.
 985        """
 986
 987        pca_data = self.pca
 988
 989        try:
 990            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
 991        except:
 992            pass
 993
 994        return pca_data
 995
 996    def return_clusters(self, clusters="umap"):
 997        """
 998        Retrieve cluster labels from UMAP or PCA clustering results.
 999
1000        Parameters
1001        ----------
1002        clusters : str, default 'umap'
1003            Source of cluster labels to return. Must be one of:
1004            - 'umap': return cluster labels from UMAP embeddings.
1005            - 'pca' : return cluster labels from PCA embeddings.
1006
1007        Returns
1008        -------
1009        list
1010            Cluster labels corresponding to the selected embedding method.
1011
1012        Raises
1013        ------
1014        ValueError
1015            If `clusters` is not 'umap' or 'pca'.
1016
1017        Notes
1018        -----
1019        Requires that clustering has already been performed
1020        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
1021        """
1022
1023        if clusters.lower() == "umap":
1024            clusters_vector = self.clustering_metadata["UMAP_clusters"]
1025        elif clusters.lower() == "pca":
1026            clusters_vector = self.clustering_metadata["PCA_clusters"]
1027        else:
1028            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")
1029
1030        return clusters_vector

A class for performing dimensionality reduction, clustering, and visualization on high-dimensional data (e.g., single-cell gene expression).

The class provides methods for:

  • Normalizing and extracting subsets of data
  • Principal Component Analysis (PCA) and related clustering
  • Uniform Manifold Approximation and Projection (UMAP) and clustering
  • Visualization of PCA and UMAP embeddings
  • Harmonization of batch effects
  • Accessing processed data and cluster labels

Methods

add_data_frame(data, metadata) Class method to create a Clustering instance from a DataFrame and metadata.

harmonize_sets() Perform batch effect harmonization on PCA data.

perform_PCA(pc_num=100, width=8, height=6) Perform PCA on the dataset and visualize the first two PCs.

knee_plot_PCA(width=8, height=6) Plot the cumulative variance explained by PCs to determine optimal dimensionality.

find_clusters_PCA(pc_num=2, eps=0.5, min_samples=10, width=8, height=6, harmonized=False) Apply DBSCAN clustering to PCA embeddings and visualize results.

perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...) Compute UMAP embeddings with optional parameter tuning.

knee_plot_umap(eps=0.5, min_samples=10) Determine optimal UMAP dimensionality using silhouette scores.

find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6) Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...) Visualize UMAP embeddings with labels and optional cluster numbering.

UMAP_feature(feature_name, features_data=None, point_size=0.6, ...) Plot a single feature over UMAP coordinates with customizable colormap.

get_umap_data() Return the UMAP embeddings along with cluster labels if available.

get_pca_data() Return the PCA results along with cluster labels if available.

return_clusters(clusters='umap') Return the cluster labels for UMAP or PCA embeddings.

Raises

ValueError For invalid parameters, mismatched dimensions, or missing metadata.

Clustering(data, metadata)
 91    def __init__(self, data, metadata):
 92        """
 93        Initialize the clustering class with data and optional metadata.
 94
 95        Parameters
 96        ----------
 97        data : pandas.DataFrame
 98            Input data for clustering. Columns are considered as samples.
 99
100        metadata : pandas.DataFrame, optional
101            Metadata for the samples. If None, a default DataFrame with column
102            names as 'cell_names' is created.
103
104        Attributes
105        ----------
106        -clustering_data : pandas.DataFrame
107        -clustering_metadata : pandas.DataFrame
108        -subclusters : None or dict
109        -explained_var : None or numpy.ndarray
110        -cumulative_var : None or numpy.ndarray
111        -pca : None or pandas.DataFrame
112        -harmonized_pca : None or pandas.DataFrame
113        -umap : None or pandas.DataFrame
114        """
115
116        self.clustering_data = data
117        """The input data used for clustering."""
118
119        if metadata is None:
120            metadata = pd.DataFrame({"cell_names": list(data.columns)})
121
122        self.clustering_metadata = metadata
123        """Metadata associated with the samples."""
124
125        self.subclusters = None
126        """Placeholder for storing subcluster information."""
127
128        self.explained_var = None
129        """Explained variance from PCA, initialized as None."""
130
131        self.cumulative_var = None
132        """Cumulative explained variance from PCA, initialized as None."""
133
134        self.pca = None
135        """PCA-transformed data, initialized as None."""
136
137        self.harmonized_pca = None
138        """PCA data after batch effect harmonization, initialized as None."""
139
140        self.umap = None
141        """UMAP embeddings, initialized as None."""

Initialize the clustering class with data and optional metadata.

Parameters

data : pandas.DataFrame Input data for clustering. Columns are considered as samples.

metadata : pandas.DataFrame, optional Metadata for the samples. If None, a default DataFrame with column names as 'cell_names' is created.

Attributes

  • clustering_data : pandas.DataFrame
  • clustering_metadata : pandas.DataFrame
  • subclusters : None or dict
  • explained_var : None or numpy.ndarray
  • cumulative_var : None or numpy.ndarray
  • pca : None or pandas.DataFrame
  • harmonized_pca : None or pandas.DataFrame
  • umap : None or pandas.DataFrame

clustering_data

The input data used for clustering.

clustering_metadata

Metadata associated with the samples.

subclusters

Placeholder for storing subcluster information.

explained_var

Explained variance from PCA, initialized as None.

cumulative_var

Cumulative explained variance from PCA, initialized as None.

pca

PCA-transformed data, initialized as None.

harmonized_pca

PCA data after batch effect harmonization, initialized as None.

umap

UMAP embeddings, initialized as None.

@classmethod
def add_data_frame( cls, data: pandas.core.frame.DataFrame, metadata: pandas.core.frame.DataFrame | None):
143    @classmethod
144    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
145        """
146        Create a Clustering instance from a DataFrame and optional metadata.
147
148        Parameters
149        ----------
150        data : pandas.DataFrame
151            Input data with features as rows and samples/cells as columns.
152
153        metadata : pandas.DataFrame or None
154            Optional metadata for the samples.
155            Each row corresponds to a sample/cell, and column names in this DataFrame
156            should match the sample/cell names in `data`. Columns can contain additional
157            information such as cell type, experimental condition, batch, sets, etc.
158
159        Returns
160        -------
161        Clustering
162            A new instance of the Clustering class.
163        """
164
165        return cls(data, metadata)

Create a Clustering instance from a DataFrame and optional metadata.

Parameters

data : pandas.DataFrame Input data with features as rows and samples/cells as columns.

metadata : pandas.DataFrame or None Optional metadata for the samples. Each row corresponds to a sample/cell, and column names in this DataFrame should match the sample/cell names in data. Columns can contain additional information such as cell type, experimental condition, batch, sets, etc.

Returns

Clustering A new instance of the Clustering class.

def harmonize_sets(self, batch_col: str = 'sets'):
167    def harmonize_sets(self, batch_col: str = "sets"):
168        """
169        Perform batch effect harmonization on PCA embeddings.
170
171        Parameters
172        ----------
173        batch_col : str, default 'sets'
174            Name of the column in `metadata` that contains batch information for the samples/cells.
175
176        Returns
177        -------
178        None
179            Updates the `harmonized_pca` attribute with harmonized data.
180        """
181
182        data_mat = np.array(self.pca)
183
184        metadata = self.clustering_metadata
185
186        self.harmonized_pca = pd.DataFrame(
187            harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr
188        ).T
189
190        self.harmonized_pca.columns = self.pca.columns

Perform batch effect harmonization on PCA embeddings.

Parameters

batch_col : str, default 'sets' Name of the column in metadata that contains batch information for the samples/cells.

Returns

None Updates the harmonized_pca attribute with harmonized data.

def perform_PCA(self, pc_num: int = 100, width=8, height=6):
192    def perform_PCA(self, pc_num: int = 100, width=8, height=6):
193        """
194        Perform Principal Component Analysis (PCA) on the dataset.
195
196        This method standardizes the data, applies PCA, stores results as attributes,
197        and generates a scatter plot of the first two principal components.
198
199        Parameters
200        ----------
201        pc_num : int, default 100
202            Number of principal components to compute.
203            If 0, computes all available components.
204
205        width : int or float, default 8
206            Width of the PCA figure.
207
208        height : int or float, default 6
209            Height of the PCA figure.
210
211        Returns
212        -------
213        matplotlib.figure.Figure
214            Scatter plot showing the first two principal components.
215
216        Updates
217        -------
218        self.pca : pandas.DataFrame
219            DataFrame with principal component scores for each sample.
220
221        self.explained_var : numpy.ndarray
222            Percentage of variance explained by each principal component.
223
224        self.cumulative_var : numpy.ndarray
225            Cumulative explained variance.
226        """
227
228        scaler = StandardScaler()
229        data_scaled = scaler.fit_transform(self.clustering_data.T)
230
231        if pc_num == 0 or pc_num > data_scaled.shape[0]:
232            pc_num = data_scaled.shape[0]
233
234        pca = PCA(n_components=pc_num, random_state=42)
235
236        principal_components = pca.fit_transform(data_scaled)
237
238        pca_df = pd.DataFrame(
239            data=principal_components,
240            columns=["PC" + str(x + 1) for x in range(pc_num)],
241        )
242
243        self.explained_var = pca.explained_variance_ratio_ * 100
244        self.cumulative_var = np.cumsum(self.explained_var)
245
246        self.pca = pca_df
247
248        fig = plt.figure(figsize=(width, height))
249        plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7)
250        plt.xlabel("PC 1")
251        plt.ylabel("PC 2")
252        plt.grid(True)
253        plt.show()
254
255        return fig

Perform Principal Component Analysis (PCA) on the dataset.

This method standardizes the data, applies PCA, stores results as attributes, and generates a scatter plot of the first two principal components.

Parameters

pc_num : int, default 100 Number of principal components to compute. If 0, computes all available components.

width : int or float, default 8 Width of the PCA figure.

height : int or float, default 6 Height of the PCA figure.

Returns

matplotlib.figure.Figure Scatter plot showing the first two principal components.

Updates

self.pca : pandas.DataFrame DataFrame with principal component scores for each sample.

self.explained_var : numpy.ndarray Percentage of variance explained by each principal component.

self.cumulative_var : numpy.ndarray Cumulative explained variance.

def knee_plot_PCA(self, width: int = 8, height: int = 6):
257    def knee_plot_PCA(self, width: int = 8, height: int = 6):
258        """
259        Plot cumulative explained variance to determine the optimal number of PCs.
260
261        Parameters
262        ----------
263        width : int, default 8
264            Width of the figure.
265
266        height : int or, default 6
267            Height of the figure.
268
269        Returns
270        -------
271        matplotlib.figure.Figure
272            Line plot showing cumulative variance explained by each PC.
273        """
274
275        fig_knee = plt.figure(figsize=(width, height))
276        plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o")
277        plt.xlabel("PC (n components)")
278        plt.ylabel("Cumulative explained variance (%)")
279        plt.grid(True)
280
281        xticks = [1] + list(range(5, len(self.explained_var) + 1, 5))
282        plt.xticks(xticks, rotation=60)
283
284        plt.show()
285
286        return fig_knee

Plot cumulative explained variance to determine the optimal number of PCs.

Parameters

width : int, default 8 Width of the figure.

height : int, default 6 Height of the figure.

Returns

matplotlib.figure.Figure Line plot showing cumulative variance explained by each PC.

def find_clusters_PCA( self, pc_num: int = 2, eps: float = 0.5, min_samples: int = 10, width: int = 8, height: int = 6, harmonized: bool = False):
288    def find_clusters_PCA(
289        self,
290        pc_num: int = 2,
291        eps: float = 0.5,
292        min_samples: int = 10,
293        width: int = 8,
294        height: int = 6,
295        harmonized: bool = False,
296    ):
297        """
298        Apply DBSCAN clustering to PCA embeddings and visualize the results.
299
300        This method performs density-based clustering (DBSCAN) on the PCA-reduced
301        dataset. Cluster labels are stored in the object's metadata, and a scatter
302        plot of the first two principal components with cluster annotations is returned.
303
304        Parameters
305        ----------
306        pc_num : int, default 2
307            Number of principal components to use for clustering.
308            If 0, uses all available components.
309
310        eps : float, default 0.5
311            Maximum distance between two points for them to be considered
312            as neighbors (DBSCAN parameter).
313
314        min_samples : int, default 10
315            Minimum number of samples required to form a cluster (DBSCAN parameter).
316
317        width : int, default 8
318            Width of the output scatter plot.
319
320        height : int, default 6
321            Height of the output scatter plot.
322
323        harmonized : bool, default False
324            If True, use harmonized PCA data (`self.harmonized_pca`).
325            If False, use standard PCA results (`self.pca`).
326
327        Returns
328        -------
329        matplotlib.figure.Figure
330            Scatter plot of the first two principal components colored by
331            cluster assignments.
332
333        Updates
334        -------
335        self.clustering_metadata['PCA_clusters'] : list
336            Cluster labels assigned to each cell/sample.
337
338        self.input_metadata['PCA_clusters'] : list, optional
339            Cluster labels stored in input metadata (if available).
340        """
341
342        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
343
344        if pc_num == 0 and harmonized:
345            PCA = self.harmonized_pca
346
347        elif pc_num == 0:
348            PCA = self.pca
349
350        else:
351            if harmonized:
352
353                PCA = self.harmonized_pca.iloc[:, 0:pc_num]
354            else:
355
356                PCA = self.pca.iloc[:, 0:pc_num]
357
358        dbscan_labels = dbscan.fit_predict(PCA)
359
360        pca_df = pd.DataFrame(PCA)
361        pca_df["Cluster"] = dbscan_labels
362
363        fig = plt.figure(figsize=(width, height))
364
365        for cluster_id in sorted(pca_df["Cluster"].unique()):
366            cluster_data = pca_df[pca_df["Cluster"] == cluster_id]
367            plt.scatter(
368                cluster_data["PC1"],
369                cluster_data["PC2"],
370                label=f"Cluster {cluster_id}",
371                alpha=0.7,
372            )
373
374        plt.xlabel("PC 1")
375        plt.ylabel("PC 2")
376        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
377
378        plt.grid(True)
379        plt.show()
380
381        self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
382
383        try:
384            self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
385        except:
386            pass
387
388        return fig

Apply DBSCAN clustering to PCA embeddings and visualize the results.

This method performs density-based clustering (DBSCAN) on the PCA-reduced dataset. Cluster labels are stored in the object's metadata, and a scatter plot of the first two principal components with cluster annotations is returned.

Parameters

pc_num : int, default 2 Number of principal components to use for clustering. If 0, uses all available components.

eps : float, default 0.5 Maximum distance between two points for them to be considered as neighbors (DBSCAN parameter).

min_samples : int, default 10 Minimum number of samples required to form a cluster (DBSCAN parameter).

width : int, default 8 Width of the output scatter plot.

height : int, default 6 Height of the output scatter plot.

harmonized : bool, default False If True, use harmonized PCA data (self.harmonized_pca). If False, use standard PCA results (self.pca).

Returns

matplotlib.figure.Figure Scatter plot of the first two principal components colored by cluster assignments.

Updates

self.clustering_metadata['PCA_clusters'] : list Cluster labels assigned to each cell/sample.

self.input_metadata['PCA_clusters'] : list, optional Cluster labels stored in input metadata (if available).

def perform_UMAP( self, factorize: bool = False, umap_num: int = 100, pc_num: int = 0, harmonized: bool = False, n_neighbors: int = 5, min_dist: float | int = 0.1, spread: float | int = 1.0, set_op_mix_ratio: float | int = 1.0, local_connectivity: int = 1, repulsion_strength: float | int = 1.0, negative_sample_rate: int = 5, width: int = 8, height: int = 6):
390    def perform_UMAP(
391        self,
392        factorize: bool = False,
393        umap_num: int = 100,
394        pc_num: int = 0,
395        harmonized: bool = False,
396        n_neighbors: int = 5,
397        min_dist: float | int = 0.1,
398        spread: float | int = 1.0,
399        set_op_mix_ratio: float | int = 1.0,
400        local_connectivity: int = 1,
401        repulsion_strength: float | int = 1.0,
402        negative_sample_rate: int = 5,
403        width: int = 8,
404        height: int = 6,
405    ):
406        """
407        Compute and visualize UMAP embeddings of the dataset.
408
409        This method applies Uniform Manifold Approximation and Projection (UMAP)
410        for dimensionality reduction on either raw, PCA, or harmonized PCA data.
411        Results are stored as a DataFrame (`self.umap`) and a scatter plot figure
412        (`self.UMAP_plot`).
413
414        Parameters
415        ----------
416        factorize : bool, default False
417            If True, categorical sample labels (from column names) are factorized
418            and used as supervision in UMAP fitting.
419
420        umap_num : int, default 100
421            Number of UMAP dimensions to compute. If 0, matches the input dimension.
422
423        pc_num : int, default 0
424            Number of principal components to use as UMAP input.
425            If 0, use all available components or raw data.
426
427        harmonized : bool, default False
428            If True, use harmonized PCA embeddings (`self.harmonized_pca`).
429            If False, use standard PCA or raw scaled data.
430
431        n_neighbors : int, default 5
432            UMAP parameter controlling the size of the local neighborhood.
433
434        min_dist : float, default 0.1
435            UMAP parameter controlling minimum allowed distance between embedded points.
436
437        spread : int | float, default 1.0
438            Effective scale of embedded space (UMAP parameter).
439
440        set_op_mix_ratio : int | float, default 1.0
441            Interpolation parameter between union and intersection in fuzzy sets.
442
443        local_connectivity : int, default 1
444            Number of nearest neighbors assumed for each point.
445
446        repulsion_strength : int | float, default 1.0
447            Weighting applied to negative samples during optimization.
448
449        negative_sample_rate : int, default 5
450            Number of negative samples per positive sample in optimization.
451
452        width : int, default 8
453            Width of the output scatter plot.
454
455        height : int, default 6
456            Height of the output scatter plot.
457
458        Updates
459        -------
460        self.umap : pandas.DataFrame
461            Table of UMAP embeddings with columns `UMAP1 ... UMAPn`.
462
463        Notes
464        -----
465        For supervised UMAP (`factorize=True`), categorical codes from column
466        names of the dataset are used as labels.
467        """
468
469        scaler = StandardScaler()
470
471        if pc_num == 0 and harmonized:
472            data_scaled = self.harmonized_pca
473
474        elif pc_num == 0:
475            data_scaled = scaler.fit_transform(self.clustering_data.T)
476
477        else:
478            if harmonized:
479
480                data_scaled = self.harmonized_pca.iloc[:, 0:pc_num]
481            else:
482
483                data_scaled = self.pca.iloc[:, 0:pc_num]
484
485        if umap_num == 0 or umap_num > data_scaled.shape[1]:
486
487            umap_num = data_scaled.shape[1]
488
489            reducer = umap.UMAP(
490                n_components=len(data_scaled.T),
491                random_state=42,
492                n_neighbors=n_neighbors,
493                min_dist=min_dist,
494                spread=spread,
495                set_op_mix_ratio=set_op_mix_ratio,
496                local_connectivity=local_connectivity,
497                repulsion_strength=repulsion_strength,
498                negative_sample_rate=negative_sample_rate,
499                n_jobs=1,
500            )
501
502        else:
503
504            reducer = umap.UMAP(
505                n_components=umap_num,
506                random_state=42,
507                n_neighbors=n_neighbors,
508                min_dist=min_dist,
509                spread=spread,
510                set_op_mix_ratio=set_op_mix_ratio,
511                local_connectivity=local_connectivity,
512                repulsion_strength=repulsion_strength,
513                negative_sample_rate=negative_sample_rate,
514                n_jobs=1,
515            )
516
517        if factorize:
518            embedding = reducer.fit_transform(
519                X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes
520            )
521        else:
522            embedding = reducer.fit_transform(X=data_scaled)
523
524        umap_df = pd.DataFrame(
525            embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)]
526        )
527
528        plt.figure(figsize=(width, height))
529        plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
530        plt.xlabel("UMAP 1")
531        plt.ylabel("UMAP 2")
532        plt.grid(True)
533
534        plt.show()
535
536        self.umap = umap_df

Compute and visualize UMAP embeddings of the dataset.

This method applies Uniform Manifold Approximation and Projection (UMAP) for dimensionality reduction on either raw, PCA, or harmonized PCA data. Results are stored as a DataFrame (self.umap) and a scatter plot figure (self.UMAP_plot).

Parameters

factorize : bool, default False If True, categorical sample labels (from column names) are factorized and used as supervision in UMAP fitting.

umap_num : int, default 100 Number of UMAP dimensions to compute. If 0, matches the input dimension.

pc_num : int, default 0 Number of principal components to use as UMAP input. If 0, use all available components or raw data.

harmonized : bool, default False If True, use harmonized PCA embeddings (self.harmonized_pca). If False, use standard PCA or raw scaled data.

n_neighbors : int, default 5 UMAP parameter controlling the size of the local neighborhood.

min_dist : float, default 0.1 UMAP parameter controlling minimum allowed distance between embedded points.

spread : int | float, default 1.0 Effective scale of embedded space (UMAP parameter).

set_op_mix_ratio : int | float, default 1.0 Interpolation parameter between union and intersection in fuzzy sets.

local_connectivity : int, default 1 Number of nearest neighbors assumed for each point.

repulsion_strength : int | float, default 1.0 Weighting applied to negative samples during optimization.

negative_sample_rate : int, default 5 Number of negative samples per positive sample in optimization.

width : int, default 8 Width of the output scatter plot.

height : int, default 6 Height of the output scatter plot.

Updates

self.umap : pandas.DataFrame Table of UMAP embeddings with columns UMAP1 ... UMAPn.

Notes

For supervised UMAP (factorize=True), categorical codes from column names of the dataset are used as labels.

def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):
538    def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):
539        """
540        Plot silhouette scores for different UMAP dimensions to determine optimal n_components.
541
542        Parameters
543        ----------
544        eps : float, default 0.5
545            DBSCAN eps parameter for clustering each UMAP dimension.
546
547        min_samples : int, default 10
548            Minimum number of samples to form a cluster in DBSCAN.
549
550        Returns
551        -------
552        matplotlib.figure.Figure
553            Silhouette score plot across UMAP dimensions.
554        """
555
556        umap_range = range(2, len(self.umap.T) + 1)
557
558        silhouette_scores = []
559        component = []
560        for n in umap_range:
561
562            db = DBSCAN(eps=eps, min_samples=min_samples)
563            labels = db.fit_predict(np.array(self.umap)[:, :n])
564
565            mask = labels != -1
566            if len(set(labels[mask])) > 1:
567                score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask])
568            else:
569                score = -1
570
571            silhouette_scores.append(score)
572            component.append(n)
573
574        fig = plt.figure(figsize=(10, 5))
575        plt.plot(component, silhouette_scores, marker="o")
576        plt.xlabel("UMAP (n_components)")
577        plt.ylabel("Silhouette Score")
578        plt.grid(True)
579        plt.xticks(range(int(min(component)), int(max(component)) + 1, 1))
580
581        plt.show()
582
583        return fig

Plot silhouette scores for different UMAP dimensions to determine optimal n_components.

Parameters

eps : float, default 0.5 DBSCAN eps parameter for clustering each UMAP dimension.

min_samples : int, default 10 Minimum number of samples to form a cluster in DBSCAN.

Returns

matplotlib.figure.Figure Silhouette score plot across UMAP dimensions.

def find_clusters_UMAP( self, umap_n: int = 5, eps: float = 0.5, min_samples: int = 10, width: int = 8, height: int = 6):
585    def find_clusters_UMAP(
586        self,
587        umap_n: int = 5,
588        eps: float | float = 0.5,
589        min_samples: int = 10,
590        width: int = 8,
591        height: int = 6,
592    ):
593        """
594        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
595
596        This method performs density-based clustering (DBSCAN) on the UMAP-reduced
597        dataset. Cluster labels are stored in the object's metadata, and a scatter
598        plot of the first two UMAP components with cluster annotations is returned.
599
600        Parameters
601        ----------
602        umap_n : int, default 5
603            Number of UMAP dimensions to use for DBSCAN clustering.
604            Must be <= number of columns in `self.umap`.
605
606        eps : float | int, default 0.5
607            Maximum neighborhood distance between two samples for them to be considered
608            as in the same cluster (DBSCAN parameter).
609
610        min_samples : int, default 10
611            Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).
612
613        width : int, default 8
614            Figure width.
615
616        height : int, default 6
617            Figure height.
618
619        Returns
620        -------
621        matplotlib.figure.Figure
622            Scatter plot of the first two UMAP components colored by
623            cluster assignments.
624
625        Updates
626        -------
627        self.clustering_metadata['UMAP_clusters'] : list
628            Cluster labels assigned to each cell/sample.
629
630        self.input_metadata['UMAP_clusters'] : list, optional
631            Cluster labels stored in input metadata (if available).
632        """
633
634        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
635        dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n])
636
637        umap_df = self.umap
638        umap_df["Cluster"] = dbscan_labels
639
640        fig = plt.figure(figsize=(width, height))
641
642        for cluster_id in sorted(umap_df["Cluster"].unique()):
643            cluster_data = umap_df[umap_df["Cluster"] == cluster_id]
644            plt.scatter(
645                cluster_data["UMAP1"],
646                cluster_data["UMAP2"],
647                label=f"Cluster {cluster_id}",
648                alpha=0.7,
649            )
650
651        plt.xlabel("UMAP 1")
652        plt.ylabel("UMAP 2")
653        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
654        plt.grid(True)
655
656        self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
657
658        try:
659            self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
660        except:
661            pass
662
663        return fig

Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

This method performs density-based clustering (DBSCAN) on the UMAP-reduced dataset. Cluster labels are stored in the object's metadata, and a scatter plot of the first two UMAP components with cluster annotations is returned.

Parameters

umap_n : int, default 5 Number of UMAP dimensions to use for DBSCAN clustering. Must be <= number of columns in self.umap.

eps : float | int, default 0.5 Maximum neighborhood distance between two samples for them to be considered as in the same cluster (DBSCAN parameter).

min_samples : int, default 10 Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).

width : int, default 8 Figure width.

height : int, default 6 Figure height.

Returns

matplotlib.figure.Figure Scatter plot of the first two UMAP components colored by cluster assignments.

Updates

self.clustering_metadata['UMAP_clusters'] : list Cluster labels assigned to each cell/sample.

self.input_metadata['UMAP_clusters'] : list, optional Cluster labels stored in input metadata (if available).

def UMAP_vis( self, names_slot: str = 'cell_names', set_sep: bool = True, point_size: int | float = 0.6, font_size: int | float = 6, legend_split_col: int = 2, width: int = 8, height: int = 6, inc_num: bool = True):
665    def UMAP_vis(
666        self,
667        names_slot: str = "cell_names",
668        set_sep: bool = True,
669        point_size: int | float = 0.6,
670        font_size: int | float = 6,
671        legend_split_col: int = 2,
672        width: int = 8,
673        height: int = 6,
674        inc_num: bool = True,
675    ):
676        """
677        Visualize UMAP embeddings with sample labels based on specyfic metadata slot.
678
679        Parameters
680        ----------
681        names_slot : str, default 'cell_names'
682            Column in metadata to use as sample labels.
683
684        set_sep : bool, default True
685            If True, separate points by dataset.
686
687        point_size : float, default 0.6
688            Size of scatter points.
689
690        font_size : int, default 6
691            Font size for numbers on points.
692
693        legend_split_col : int, default 2
694            Number of columns in legend.
695
696        width : int, default 8
697            Figure width.
698
699        height : int, default 6
700            Figure height.
701
702        inc_num : bool, default True
703            If True, annotate points with numeric labels.
704
705        Returns
706        -------
707        matplotlib.figure.Figure
708            UMAP scatter plot figure.
709        """
710
711        umap_df = self.umap.iloc[:, 0:2].copy()
712        umap_df["names"] = list(self.clustering_metadata[names_slot])
713
714        if set_sep:
715
716            if "sets" in list(self.clustering_metadata.columns):
717                umap_df["dataset"] = list(self.clustering_metadata["sets"])
718            else:
719                umap_df["dataset"] = "default"
720
721        else:
722            umap_df["dataset"] = "default"
723
724        umap_df["tmp_nam"] = list(umap_df["names"] + umap_df["dataset"])
725
726        umap_df["count"] = umap_df["tmp_nam"].map(umap_df["tmp_nam"].value_counts())
727
728        numeric_df = (
729            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
730            .drop_duplicates()
731            .sort_values("count", ascending=False)
732        )
733        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])
734
735        umap_df = umap_df.merge(
736            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
737        )
738
739        fig, ax = plt.subplots(figsize=(width, height))
740
741        markers = ["o", "s", "^", "D", "P", "*", "X"]
742        marker_map = {
743            ds: markers[i % len(markers)]
744            for i, ds in enumerate(umap_df["dataset"].unique())
745        }
746
747        cord_list = []
748
749        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):
750
751            cluster_data = umap_df[umap_df["numeric_values"] == num]
752
753            ax.scatter(
754                cluster_data["UMAP1"],
755                cluster_data["UMAP2"],
756                label=f"{num} - {nam}",
757                marker=marker_map[cluster_data["dataset"].iloc[0]],
758                alpha=0.6,
759                s=point_size,
760            )
761
762            coords = cluster_data[["UMAP1", "UMAP2"]].values
763
764            dists = pairwise_distances(coords)
765
766            sum_dists = dists.sum(axis=1)
767
768            center_idx = np.argmin(sum_dists)
769            center_point = coords[center_idx]
770
771            cord_list.append(center_point)
772
773        if inc_num:
774            texts = []
775            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
776                texts.append(
777                    ax.text(
778                        x,
779                        y,
780                        str(num),
781                        ha="center",
782                        va="center",
783                        fontsize=font_size,
784                        color="black",
785                    )
786                )
787
788            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))
789
790        ax.set_xlabel("UMAP 1")
791        ax.set_ylabel("UMAP 2")
792
793        ax.legend(
794            title="Clusters",
795            loc="center left",
796            bbox_to_anchor=(1.05, 0.5),
797            ncol=legend_split_col,
798            markerscale=5,
799        )
800
801        ax.grid(True)
802
803        plt.tight_layout()
804
805        return fig

Visualize UMAP embeddings with sample labels based on a specific metadata slot.

Parameters

names_slot : str, default 'cell_names' Column in metadata to use as sample labels.

set_sep : bool, default True If True, separate points by dataset.

point_size : float, default 0.6 Size of scatter points.

font_size : int, default 6 Font size for numbers on points.

legend_split_col : int, default 2 Number of columns in legend.

width : int, default 8 Figure width.

height : int, default 6 Figure height.

inc_num : bool, default True If True, annotate points with numeric labels.

Returns

matplotlib.figure.Figure UMAP scatter plot figure.

def UMAP_feature( self, feature_name: str, features_data: pandas.core.frame.DataFrame | None, point_size: int | float = 0.6, font_size: int | float = 6, width: int = 8, height: int = 6, palette='light'):
807    def UMAP_feature(
808        self,
809        feature_name: str,
810        features_data: pd.DataFrame | None,
811        point_size: int | float = 0.6,
812        font_size: int | float = 6,
813        width: int = 8,
814        height: int = 6,
815        palette="light",
816    ):
817        """
818        Visualize UMAP embedding with expression levels of a selected feature.
819
820        Each point (cell) in the UMAP plot is colored according to the expression
821        value of the chosen feature, enabling interpretation of spatial patterns
822        of gene activity or metadata distribution in low-dimensional space.
823
824        Parameters
825        ----------
826        feature_name : str
827           Name of the feature to plot.
828
829        features_data : pandas.DataFrame or None, default None
830            If None, the function uses the DataFrame containing the clustering data.
831            To plot features not used in clustering, provide a wider DataFrame
832            containing the original feature values.
833
834        point_size : float, default 0.6
835            Size of scatter points in the plot.
836
837        font_size : int, default 6
838            Font size for axis labels and annotations.
839
840        width : int, default 8
841            Width of the matplotlib figure.
842
843        height : int, default 6
844            Height of the matplotlib figure.
845
846        palette : str, default 'light'
847            Color palette for expression visualization. Options are:
848            - 'light'
849            - 'dark'
850            - 'green'
851            - 'gray'
852
853        Returns
854        -------
855        matplotlib.figure.Figure
856            UMAP scatter plot colored by feature values.
857        """
858
859        umap_df = self.umap.iloc[:, 0:2].copy()
860
861        if features_data is None:
862
863            features_data = self.clustering_data
864
865        if features_data.shape[1] != umap_df.shape[0]:
866            raise ValueError(
867                "Imputed 'features_data' shape does not match the number of UMAP cells"
868            )
869
870        blist = [
871            True if x.upper() == feature_name.upper() else False
872            for x in features_data.index
873        ]
874
875        if not any(blist):
876            raise ValueError("Imputed feature_name is not included in the data")
877
878        umap_df.loc[:, "feature"] = (
879            features_data.loc[blist, :]
880            .apply(lambda row: row.tolist(), axis=1)
881            .values[0]
882        )
883
884        umap_df = umap_df.sort_values("feature", ascending=True)
885
886        import matplotlib.colors as mcolors
887
888        if palette == "light":
889            palette = px.colors.sequential.Sunsetdark
890
891        elif palette == "dark":
892            palette = px.colors.sequential.thermal
893
894        elif palette == "green":
895            palette = px.colors.sequential.Aggrnyl
896
897        elif palette == "gray":
898            palette = px.colors.sequential.gray
899            palette = palette[::-1]
900
901        else:
902            raise ValueError(
903                'Palette not found. Use: "light", "dark", "gray", or "green"'
904            )
905
906        converted = []
907        for c in palette:
908            rgb_255 = px.colors.unlabel_rgb(c)
909            rgb_01 = tuple(v / 255.0 for v in rgb_255)
910            converted.append(rgb_01)
911
912        my_cmap = mcolors.ListedColormap(converted, name="custom")
913
914        fig, ax = plt.subplots(figsize=(width, height))
915        sc = ax.scatter(
916            umap_df["UMAP1"],
917            umap_df["UMAP2"],
918            c=umap_df["feature"],
919            s=point_size,
920            cmap=my_cmap,
921            alpha=1.0,
922            edgecolors="black",
923            linewidths=0.1,
924        )
925
926        cbar = plt.colorbar(sc, ax=ax)
927        cbar.set_label(f"{feature_name}")
928
929        ax.set_xlabel("UMAP 1")
930        ax.set_ylabel("UMAP 2")
931
932        ax.grid(True)
933
934        plt.tight_layout()
935
936        return fig

Visualize UMAP embedding with expression levels of a selected feature.

Each point (cell) in the UMAP plot is colored according to the expression value of the chosen feature, enabling interpretation of spatial patterns of gene activity or metadata distribution in low-dimensional space.

Parameters

feature_name : str Name of the feature to plot.

features_data : pandas.DataFrame or None, default None If None, the function uses the DataFrame containing the clustering data. To plot features not used in clustering, provide a wider DataFrame containing the original feature values.

point_size : float, default 0.6 Size of scatter points in the plot.

font_size : int, default 6 Font size for axis labels and annotations.

width : int, default 8 Width of the matplotlib figure.

height : int, default 6 Height of the matplotlib figure.

palette : str, default 'light' Color palette for expression visualization. Options are: - 'light' - 'dark' - 'green' - 'gray'

Returns

matplotlib.figure.Figure UMAP scatter plot colored by feature values.

def get_umap_data(self):
938    def get_umap_data(self):
939        """
940        Retrieve UMAP embedding data with optional cluster labels.
941
942        Returns the UMAP coordinates stored in `self.umap`. If clustering
943        metadata is available (specifically `UMAP_clusters`), the corresponding
944        cluster assignments are appended as an additional column.
945
946        Returns
947        -------
948        pandas.DataFrame
949            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
950            If available, includes an extra column 'clusters' with cluster labels.
951
952        Notes
953        -----
954        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
955        - Cluster labels are added only if present in `self.clustering_metadata`.
956        """
957
958        umap_data = self.umap
959
960        try:
961            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
962        except:
963            pass
964
965        return umap_data

Retrieve UMAP embedding data with optional cluster labels.

Returns the UMAP coordinates stored in self.umap. If clustering metadata is available (specifically UMAP_clusters), the corresponding cluster assignments are appended as an additional column.

Returns

pandas.DataFrame DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...). If available, includes an extra column 'clusters' with cluster labels.

Notes

  • UMAP embeddings must be computed beforehand (e.g., using perform_UMAP).
  • Cluster labels are added only if present in self.clustering_metadata.
def get_pca_data(self):
967    def get_pca_data(self):
968        """
969        Retrieve PCA embedding data with optional cluster labels.
970
971        Returns the principal component scores stored in `self.pca`. If clustering
972        metadata is available (specifically `PCA_clusters`), the corresponding
973        cluster assignments are appended as an additional column.
974
975        Returns
976        -------
977        pandas.DataFrame
978            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
979            If available, includes an extra column 'clusters' with cluster labels.
980
981        Notes
982        -----
983        - PCA must be computed beforehand (e.g., using `perform_PCA`).
984        - Cluster labels are added only if present in `self.clustering_metadata`.
985        """
986
987        pca_data = self.pca
988
989        try:
990            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
991        except:
992            pass
993
994        return pca_data

Retrieve PCA embedding data with optional cluster labels.

Returns the principal component scores stored in self.pca. If clustering metadata is available (specifically PCA_clusters), the corresponding cluster assignments are appended as an additional column.

Returns

pandas.DataFrame DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...). If available, includes an extra column 'clusters' with cluster labels.

Notes

  • PCA must be computed beforehand (e.g., using perform_PCA).
  • Cluster labels are added only if present in self.clustering_metadata.
def return_clusters(self, clusters='umap'):
 996    def return_clusters(self, clusters="umap"):
 997        """
 998        Retrieve cluster labels from UMAP or PCA clustering results.
 999
1000        Parameters
1001        ----------
1002        clusters : str, default 'umap'
1003            Source of cluster labels to return. Must be one of:
1004            - 'umap': return cluster labels from UMAP embeddings.
1005            - 'pca' : return cluster labels from PCA embeddings.
1006
1007        Returns
1008        -------
1009        list
1010            Cluster labels corresponding to the selected embedding method.
1011
1012        Raises
1013        ------
1014        ValueError
1015            If `clusters` is not 'umap' or 'pca'.
1016
1017        Notes
1018        -----
1019        Requires that clustering has already been performed
1020        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
1021        """
1022
1023        if clusters.lower() == "umap":
1024            clusters_vector = self.clustering_metadata["UMAP_clusters"]
1025        elif clusters.lower() == "pca":
1026            clusters_vector = self.clustering_metadata["PCA_clusters"]
1027        else:
1028            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")
1029
1030        return clusters_vector

Retrieve cluster labels from UMAP or PCA clustering results.

Parameters

clusters : str, default 'umap' Source of cluster labels to return. Must be one of: - 'umap': return cluster labels from UMAP embeddings. - 'pca' : return cluster labels from PCA embeddings.

Returns

list Cluster labels corresponding to the selected embedding method.

Raises

ValueError If clusters is not 'umap' or 'pca'.

Notes

Requires that clustering has already been performed (e.g., using find_clusters_UMAP or find_clusters_PCA).

class COMPsc(Clustering):
1033class COMPsc(Clustering):
1034    """
1035    A class `COMPsc` (Comparison of single-cell data) designed for the integration,
1036    analysis, and visualization of single-cell datasets.
1037    The class supports independent dataset integration, subclustering of existing clusters,
1038    marker detection, and multiple visualization strategies.
1039
1040    The COMPsc class provides methods for:
1041
1042        - Normalizing and filtering single-cell data
1043        - Loading and saving sparse 10x-style datasets
1044        - Computing differential expression and marker genes
1045        - Clustering and subclustering analysis
1046        - Visualizing similarity and spatial relationships
1047        - Aggregating data by cell and set annotations
1048        - Managing metadata and renaming labels
1049        - Plotting gene detection histograms and feature scatters
1050
1051    Methods
1052    -------
1053    project_dir(path_to_directory, project_list)
1054        Scans a directory to create a COMPsc instance mapping project names to their paths.
1055
1056    save_project(name, path=os.getcwd())
1057        Saves the COMPsc object to a pickle file on disk.
1058
1059    load_project(path)
1060        Loads a previously saved COMPsc object from a pickle file.
1061
1062    reduce_cols(reg, inc_set=False)
1063        Removes columns from data tables where column names contain a specified name or partial substring.
1064
1065    reduce_rows(reg, inc_set=False)
1066        Removes rows from data tables where column names contain a specified feature (gene) name.
1067
1068    get_data(set_info=False)
1069        Returns normalized data with optional set annotations in column names.
1070
1071    get_metadata()
1072        Returns the stored input metadata.
1073
1074    get_partial_data(names=None, features=None, name_slot='cell_names')
1075        Return a subset of the data by sample names and/or features.
1076
1077    gene_calculation()
1078        Calculates and stores per-cell gene detection counts as a pandas Series.
1079
1080    gene_histograme(bins=100)
1081        Plots a histogram of genes detected per cell with an overlaid normal distribution.
1082
1083    gene_threshold(min_n=None, max_n=None)
1084        Filters cells based on minimum and/or maximum gene detection thresholds.
1085
1086    load_sparse_from_projects(normalized_data=False)
1087        Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.
1088
1089    rename_names(mapping, slot='cell_names')
1090        Renames entries in a specified metadata column using a provided mapping dictionary.
1091
1092    rename_subclusters(mapping)
1093        Renames subcluster labels using a provided mapping dictionary.
1094
1095    save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized')
1096        Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
1097
1098    normalize_data(normalize=True, normalize_factor=100000)
1099        Normalizes raw counts to counts-per-specified factor (e.g., CPM-like).
1100
1101    statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10)
1102        Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.
1103
1104    calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False)
1105        Computes and caches differential markers using the statistic method.
1106
1107    clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4)
1108        Prepares clustering input by selecting marker features and optionally smoothing cell values.
1109
1110    average()
1111        Aggregates normalized data by averaging across (cell_name, set) pairs.
1112
1113    estimating_similarity(method='pearson', p_val=0.05, top_n=25)
1114        Computes pairwise correlation and Euclidean distance between aggregated samples.
1115
1116    similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10)
1117        Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.
1118
1119    spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...)
1120        Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.
1121
1122    subcluster_prepare(features, cluster)
1123        Initializes a Clustering object for subcluster analysis on a selected parent cluster.
1124
1125    define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...)
1126        Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels.
1127
1128    subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...)
1129        Visualizes averaged expression and occurrence of features for subclusters as a scatter plot.
1130
1131    subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...)
1132        Plots top differential features for subclusters as a features-scatter visualization.
1133
1134    accept_subclusters()
1135        Commits subcluster labels to main metadata by renaming cell names and clears subcluster data.
1136
1137    Raises
1138    ------
1139    ValueError
1140        For invalid parameters, mismatched dimensions, or missing metadata.
1141
1142    """
1143
    def __init__(
        self,
        objects=None,
    ):
        """
        Initialize the COMPsc class for single-cell data integration and analysis.

        All analysis slots start as None and are filled in by the dedicated
        methods (e.g., `load_sparse_from_projects`, `normalize_data`,
        `average`, `estimating_similarity`, `subcluster_prepare`).

        Parameters
        ----------
        objects : dict or None, optional
            Optional mapping of project names to their on-disk paths, as
            produced by `project_dir`; used by data-loading methods.

        Attributes
        ----------
        -objects : dict or None
        -input_data : pandas.DataFrame or None
        -input_metadata : pandas.DataFrame or None
        -normalized_data : pandas.DataFrame or None
        -agg_metadata : pandas.DataFrame or None
        -agg_normalized_data : pandas.DataFrame or None
        -similarity : pandas.DataFrame or None
        -var_data : pandas.DataFrame or None
        -subclusters_ : instance of Clustering class or None
        -cells_calc : pandas.Series or None
        -gene_calc : pandas.Series or None
        -composition_data : pandas.DataFrame or None
        """

        self.objects = objects
        """ Stores the input data objects."""

        self.input_data = None
        """Raw input data for clustering or integration analysis."""

        self.input_metadata = None
        """Metadata associated with the input data."""

        self.normalized_data = None
        """Normalized version of the input data."""

        self.agg_metadata = None
        '''Aggregated metadata for all sets in object related to "agg_normalized_data"'''

        self.agg_normalized_data = None
        """Aggregated and normalized data across multiple sets."""

        self.similarity = None
        """Similarity data between cells across all samples. and sets"""

        self.var_data = None
        """DEG analysis results summarizing variance across all samples in the object."""

        self.subclusters_ = None
        """Placeholder for information about subclusters analysis; if computed."""

        self.cells_calc = None
        """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition."""

        self.gene_calc = None
        """Number of genes detected per sample (cell), reflecting the sequencing depth."""

        self.composition_data = None
        """Data describing composition of cells across clusters or sets."""
1207
1208    @classmethod
1209    def project_dir(cls, path_to_directory, project_list):
1210        """
1211        Scan a directory and build a COMPsc instance mapping provided project names
1212        to their paths.
1213
1214        Parameters
1215        ----------
1216        path_to_directory : str
1217            Path containing project subfolders.
1218
1219        project_list : list[str]
1220            List of filenames (folder names) to include in the returned object map.
1221
1222        Returns
1223        -------
1224        COMPsc
1225            New COMPsc instance with `objects` populated.
1226
1227        Raises
1228        ------
1229        Exception
1230            A generic exception is caught and a message printed if scanning fails.
1231
1232        Notes
1233        -----
1234        Function attempts to match entries in `project_list` to directory
1235        names and constructs a simplified object key from the folder name.
1236        """
1237        try:
1238            objects = {}
1239            for filename in tqdm(os.listdir(path_to_directory)):
1240                for c in project_list:
1241                    f = os.path.join(path_to_directory, filename)
1242                    if c == filename and os.path.isdir(f):
1243                        objects[str(c)] = f
1244
1245            return cls(objects)
1246
1247        except:
1248            print("Something went wrong. Check the function input data and try again!")
1249
1250    def save_project(self, name, path: str = os.getcwd()):
1251        """
1252        Save the COMPsc object to disk using pickle.
1253
1254        Parameters
1255        ----------
1256        name : str
1257            Base filename (without extension) to use when saving.
1258
1259        path : str, default os.getcwd()
1260            Directory in which to save the project file.
1261
1262        Returns
1263        -------
1264        None
1265
1266        Side Effects
1267        ------------
1268        - Writes a file `<path>/<name>.jpkl` containing the pickled object.
1269        - Prints a confirmation message with saved path.
1270        """
1271
1272        full = os.path.join(path, f"{name}.jpkl")
1273
1274        with open(full, "wb") as f:
1275            pickle.dump(self, f)
1276
1277        print(f"Project saved as {full}")
1278
1279    @classmethod
1280    def load_project(cls, path):
1281        """
1282        Load a previously saved COMPsc project from a pickle file.
1283
1284        Parameters
1285        ----------
1286        path : str
1287            Full path to the pickled project file.
1288
1289        Returns
1290        -------
1291        COMPsc
1292            The unpickled COMPsc object.
1293
1294        Raises
1295        ------
1296        FileNotFoundError
1297            If the provided path does not exist.
1298        """
1299
1300        if not os.path.exists(path):
1301            raise FileNotFoundError("File does not exist!")
1302        with open(path, "rb") as f:
1303            obj = pickle.load(f)
1304        return obj
1305
1306    def reduce_cols(
1307        self,
1308        reg: str | None = None,
1309        full: str | None = None,
1310        name_slot: str = "cell_names",
1311        inc_set: bool = False,
1312    ):
1313        """
1314        Remove columns (cells) whose names contain a substring `reg` or
1315        full name `full` from available tables.
1316
1317        Parameters
1318        ----------
1319        reg : str | None
1320            Substring to search for in column/cell names; matching columns will be removed.
1321            If not None, `full` must be None.
1322
1323        full : str | None
1324            Full name to search for in column/cell names; matching columns will be removed.
1325            If not None, `reg` must be None.
1326
1327        name_slot : str, default 'cell_names'
1328            Column in metadata to use as sample names.
1329
1330        inc_set : bool, default False
1331            If True, column names are interpreted as 'cell_name # set' when matching.
1332
1333        Update
1334        ------------
1335        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
1336        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
1337        removing columns/rows that match `reg`.
1338
1339        Raises
1340        ------
1341        Raises ValueError if nothing matches the reduction mask.
1342        """
1343
1344        if reg is None and full is None:
1345            raise ValueError(
1346                "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!"
1347            )
1348
1349        if reg is not None and full is not None:
1350            raise ValueError(
1351                "Both 'reg' and 'full' arguments are provided. "
1352                "Please provide only one of them!\n"
1353                "'reg' is used when only part of the name must be detected.\n"
1354                "'full' is used if the full name must be detected."
1355            )
1356
1357        if reg is not None:
1358
1359            if self.input_data is not None:
1360
1361                if inc_set:
1362
1363                    self.input_data.columns = (
1364                        self.input_metadata[name_slot]
1365                        + " # "
1366                        + self.input_metadata["sets"]
1367                    )
1368
1369                else:
1370
1371                    self.input_data.columns = self.input_metadata[name_slot]
1372
1373                mask = [reg.upper() not in x.upper() for x in self.input_data.columns]
1374
1375                if len([y for y in mask if y is False]) == 0:
1376                    raise ValueError("Nothing found to reduce")
1377
1378                self.input_data = self.input_data.loc[:, mask]
1379
1380            if self.normalized_data is not None:
1381
1382                if inc_set:
1383
1384                    self.normalized_data.columns = (
1385                        self.input_metadata[name_slot]
1386                        + " # "
1387                        + self.input_metadata["sets"]
1388                    )
1389
1390                else:
1391
1392                    self.normalized_data.columns = self.input_metadata[name_slot]
1393
1394                mask = [
1395                    reg.upper() not in x.upper() for x in self.normalized_data.columns
1396                ]
1397
1398                if len([y for y in mask if y is False]) == 0:
1399                    raise ValueError("Nothing found to reduce")
1400
1401                self.normalized_data = self.normalized_data.loc[:, mask]
1402
1403            if self.input_metadata is not None:
1404
1405                if inc_set:
1406
1407                    self.input_metadata["drop"] = (
1408                        self.input_metadata[name_slot]
1409                        + " # "
1410                        + self.input_metadata["sets"]
1411                    )
1412
1413                else:
1414
1415                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1416
1417                mask = [
1418                    reg.upper() not in x.upper() for x in self.input_metadata["drop"]
1419                ]
1420
1421                if len([y for y in mask if y is False]) == 0:
1422                    raise ValueError("Nothing found to reduce")
1423
1424                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1425                    drop=True
1426                )
1427
1428                self.input_metadata = self.input_metadata.drop(
1429                    columns=["drop"], errors="ignore"
1430                )
1431
1432            if self.agg_normalized_data is not None:
1433
1434                if inc_set:
1435
1436                    self.agg_normalized_data.columns = (
1437                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1438                    )
1439
1440                else:
1441
1442                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1443
1444                mask = [
1445                    reg.upper() not in x.upper()
1446                    for x in self.agg_normalized_data.columns
1447                ]
1448
1449                if len([y for y in mask if y is False]) == 0:
1450                    raise ValueError("Nothing found to reduce")
1451
1452                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1453
1454            if self.agg_metadata is not None:
1455
1456                if inc_set:
1457
1458                    self.agg_metadata["drop"] = (
1459                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1460                    )
1461
1462                else:
1463
1464                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1465
1466                mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]]
1467
1468                if len([y for y in mask if y is False]) == 0:
1469                    raise ValueError("Nothing found to reduce")
1470
1471                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1472                    drop=True
1473                )
1474
1475                self.agg_metadata = self.agg_metadata.drop(
1476                    columns=["drop"], errors="ignore"
1477                )
1478
1479        elif full is not None:
1480
1481            if self.input_data is not None:
1482
1483                if inc_set:
1484
1485                    self.input_data.columns = (
1486                        self.input_metadata[name_slot]
1487                        + " # "
1488                        + self.input_metadata["sets"]
1489                    )
1490
1491                    if "#" not in full:
1492
1493                        self.input_data.columns = self.input_metadata[name_slot]
1494
1495                        print(
1496                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1497                            "Only the names will be compared, without considering the set information."
1498                        )
1499
1500                else:
1501
1502                    self.input_data.columns = self.input_metadata[name_slot]
1503
1504                mask = [full.upper() != x.upper() for x in self.input_data.columns]
1505
1506                if len([y for y in mask if y is False]) == 0:
1507                    raise ValueError("Nothing found to reduce")
1508
1509                self.input_data = self.input_data.loc[:, mask]
1510
1511            if self.normalized_data is not None:
1512
1513                if inc_set:
1514
1515                    self.normalized_data.columns = (
1516                        self.input_metadata[name_slot]
1517                        + " # "
1518                        + self.input_metadata["sets"]
1519                    )
1520
1521                    if "#" not in full:
1522
1523                        self.normalized_data.columns = self.input_metadata[name_slot]
1524
1525                        print(
1526                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1527                            "Only the names will be compared, without considering the set information."
1528                        )
1529
1530                else:
1531
1532                    self.normalized_data.columns = self.input_metadata[name_slot]
1533
1534                mask = [full.upper() != x.upper() for x in self.normalized_data.columns]
1535
1536                if len([y for y in mask if y is False]) == 0:
1537                    raise ValueError("Nothing found to reduce")
1538
1539                self.normalized_data = self.normalized_data.loc[:, mask]
1540
1541            if self.input_metadata is not None:
1542
1543                if inc_set:
1544
1545                    self.input_metadata["drop"] = (
1546                        self.input_metadata[name_slot]
1547                        + " # "
1548                        + self.input_metadata["sets"]
1549                    )
1550
1551                    if "#" not in full:
1552
1553                        self.input_metadata["drop"] = self.input_metadata[name_slot]
1554
1555                        print(
1556                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1557                            "Only the names will be compared, without considering the set information."
1558                        )
1559
1560                else:
1561
1562                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1563
1564                mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]]
1565
1566                if len([y for y in mask if y is False]) == 0:
1567                    raise ValueError("Nothing found to reduce")
1568
1569                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1570                    drop=True
1571                )
1572
1573                self.input_metadata = self.input_metadata.drop(
1574                    columns=["drop"], errors="ignore"
1575                )
1576
1577            if self.agg_normalized_data is not None:
1578
1579                if inc_set:
1580
1581                    self.agg_normalized_data.columns = (
1582                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1583                    )
1584
1585                    if "#" not in full:
1586
1587                        self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1588
1589                        print(
1590                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1591                            "Only the names will be compared, without considering the set information."
1592                        )
1593                else:
1594
1595                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1596
1597                mask = [
1598                    full.upper() != x.upper() for x in self.agg_normalized_data.columns
1599                ]
1600
1601                if len([y for y in mask if y is False]) == 0:
1602                    raise ValueError("Nothing found to reduce")
1603
1604                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1605
1606            if self.agg_metadata is not None:
1607
1608                if inc_set:
1609
1610                    self.agg_metadata["drop"] = (
1611                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1612                    )
1613
1614                    if "#" not in full:
1615
1616                        self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1617
1618                        print(
1619                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1620                            "Only the names will be compared, without considering the set information."
1621                        )
1622                else:
1623
1624                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1625
1626                mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]]
1627
1628                if len([y for y in mask if y is False]) == 0:
1629                    raise ValueError("Nothing found to reduce")
1630
1631                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1632                    drop=True
1633                )
1634
1635                self.agg_metadata = self.agg_metadata.drop(
1636                    columns=["drop"], errors="ignore"
1637                )
1638
1639        self.gene_calculation()
1640        self.cells_calculation()
1641
1642    def reduce_rows(self, features_list: list):
1643        """
1644        Remove rows (features) whose names are included in features_list.
1645
1646        Parameters
1647        ----------
1648        features_list : list
1649            List of features to search for in index/gene names; matching entries will be removed.
1650
1651        Update
1652        ------------
1653        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
1654        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
1655        removing columns/rows that match `reg`.
1656
1657        Raises
1658        ------
1659        Prints a message listing features that are not found in the data.
1660        """
1661
1662        if self.input_data is not None:
1663
1664            res = find_features(self.input_data, features=features_list)
1665
1666            res_list = [x.upper() for x in res["included"]]
1667
1668            mask = [x.upper() not in res_list for x in self.input_data.index]
1669
1670            if len([y for y in mask if y is False]) == 0:
1671                raise ValueError("Nothing found to reduce")
1672
1673            self.input_data = self.input_data.loc[mask, :]
1674
1675        if self.normalized_data is not None:
1676
1677            res = find_features(self.normalized_data, features=features_list)
1678
1679            res_list = [x.upper() for x in res["included"]]
1680
1681            mask = [x.upper() not in res_list for x in self.normalized_data.index]
1682
1683            if len([y for y in mask if y is False]) == 0:
1684                raise ValueError("Nothing found to reduce")
1685
1686            self.normalized_data = self.normalized_data.loc[mask, :]
1687
1688        if self.agg_normalized_data is not None:
1689
1690            res = find_features(self.agg_normalized_data, features=features_list)
1691
1692            res_list = [x.upper() for x in res["included"]]
1693
1694            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]
1695
1696            if len([y for y in mask if y is False]) == 0:
1697                raise ValueError("Nothing found to reduce")
1698
1699            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]
1700
1701        if len(res["not_included"]) > 0:
1702            print("\nFeatures not found:")
1703            for i in res["not_included"]:
1704                print(i)
1705
1706        self.gene_calculation()
1707        self.cells_calculation()
1708
1709    def get_data(self, set_info: bool = False):
1710        """
1711        Return normalized data with optional set annotation appended to column names.
1712
1713        Parameters
1714        ----------
1715        set_info : bool, default False
1716            If True, column names are returned as "cell_name # set"; otherwise
1717            only the `cell_name` is used.
1718
1719        Returns
1720        -------
1721        pandas.DataFrame
1722            The `self.normalized_data` table with columns renamed according to `set_info`.
1723
1724        Raises
1725        ------
1726        AttributeError
1727            If `self.normalized_data` or `self.input_metadata` is missing.
1728        """
1729
1730        to_return = self.normalized_data
1731
1732        if set_info:
1733            to_return.columns = (
1734                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
1735            )
1736        else:
1737            to_return.columns = self.input_metadata["cell_names"]
1738
1739        return to_return
1740
1741    def get_partial_data(
1742        self,
1743        names: list | str | None = None,
1744        features: list | str | None = None,
1745        name_slot: str = "cell_names",
1746        inc_metadata: bool = False,
1747    ):
1748        """
1749        Return a subset of the data filtered by sample names and/or feature names.
1750
1751        Parameters
1752        ----------
1753        names : list, str, or None
1754            Names of samples to include. If None, all samples are considered.
1755
1756        features : list, str, or None
1757            Names of features to include. If None, all features are considered.
1758
1759        name_slot : str
1760            Column in metadata to use as sample names.
1761
1762        inc_metadata : bool
1763            If True return tuple (data, metadata)
1764
1765        Returns
1766        -------
1767        pandas.DataFrame
1768            Subset of the normalized data based on the specified names and features.
1769        """
1770
1771        data = self.normalized_data.copy()
1772        metadata = self.input_metadata
1773
1774        if name_slot in self.input_metadata.columns:
1775            data.columns = self.input_metadata[name_slot]
1776        else:
1777            raise ValueError("'name_slot' not occured in data!'")
1778
1779        if isinstance(features, str):
1780            features = [features]
1781        elif features is None:
1782            features = []
1783
1784        if isinstance(names, str):
1785            names = [names]
1786        elif names is None:
1787            names = []
1788
1789        features = [x.upper() for x in features]
1790        names = [x.upper() for x in names]
1791
1792        columns_names = [x.upper() for x in data.columns]
1793        features_names = [x.upper() for x in data.index]
1794
1795        columns_bool = [True if x in names else False for x in columns_names]
1796        features_bool = [True if x in features else False for x in features_names]
1797
1798        if True not in columns_bool and True not in features_bool:
1799            print("Missing 'names' and/or 'features'. Returning full dataset instead.")
1800            
1801        if True in columns_bool:
1802            data = data.loc[:, columns_bool]
1803            metadata = metadata.loc[columns_bool, :]
1804
1805        if True in features_bool:
1806            data = data.loc[features_bool, :]
1807
1808        not_in_features = [y for y in features if y not in features_names]
1809
1810        if len(not_in_features) > 0:
1811            print('\nThe following features were not found in data:')
1812            print('\n'.join(not_in_features))
1813
1814        not_in_names = [y for y in names if y not in columns_names]
1815
1816        if len(not_in_names) > 0:
1817            print('\nThe following names were not found in data:')
1818            print('\n'.join(not_in_names))
1819
1820        if inc_metadata:
1821            return data, metadata
1822        else:
1823            return data
1824
1825    def get_metadata(self):
1826        """
1827        Return the stored input metadata.
1828
1829        Returns
1830        -------
1831        pandas.DataFrame
1832            `self.input_metadata` (may be None if not set).
1833        """
1834
1835        to_return = self.input_metadata
1836
1837        return to_return
1838
1839    def gene_calculation(self):
1840        """
1841        Calculate and store per-cell counts (e.g., number of detected genes).
1842
1843        The method computes a binary (presence/absence) per cell and sums across
1844        features to produce `self.gene_calc`.
1845
1846        Update
1847        ------
1848        Sets `self.gene_calc` as a pandas.Series.
1849
1850        Side Effects
1851        ------------
1852        Uses `self.input_data` when available, otherwise `self.normalized_data`.
1853        """
1854
1855        if self.input_data is not None:
1856
1857            bin_col = self.input_data.columns.copy()
1858
1859            bin_col = bin_col.where(bin_col <= 0, 1)
1860
1861            sum_data = bin_col.sum(axis=0)
1862
1863            self.gene_calc = sum_data
1864
1865        elif self.normalized_data is not None:
1866
1867            bin_col = self.normalized_data.copy()
1868
1869            bin_col = bin_col.where(bin_col <= 0, 1)
1870
1871            sum_data = bin_col.sum(axis=0)
1872
1873            self.gene_calc = sum_data
1874
1875    def gene_histograme(self, bins=100):
1876        """
1877        Plot a histogram of the number of genes detected per cell.
1878
1879        Parameters
1880        ----------
1881        bins : int, default 100
1882            Number of histogram bins.
1883
1884        Returns
1885        -------
1886        matplotlib.figure.Figure
1887            Figure containing the histogram of gene contents.
1888
1889        Notes
1890        -----
1891        Requires `self.gene_calc` to be computed prior to calling.
1892        """
1893
1894        fig, ax = plt.subplots(figsize=(8, 5))
1895
1896        _, bin_edges, _ = ax.hist(
1897            self.gene_calc, bins=bins, edgecolor="black", alpha=0.6
1898        )
1899
1900        mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc)
1901
1902        x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000)
1903        y = norm.pdf(x, mu, sigma)
1904
1905        y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0])
1906
1907        ax.plot(
1908            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
1909        )
1910
1911        ax.set_xlabel("Value")
1912        ax.set_ylabel("Count")
1913        ax.set_title("Histogram of genes detected per cell")
1914
1915        ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20))
1916        ax.tick_params(axis="x", rotation=90)
1917
1918        ax.legend()
1919
1920        return fig
1921
1922    def gene_threshold(self, min_n: int | None, max_n: int | None):
1923        """
1924        Filter cells by gene-detection thresholds (min and/or max).
1925
1926        Parameters
1927        ----------
1928        min_n : int or None
1929            Minimum number of detected genes required to keep a cell.
1930
1931        max_n : int or None
1932            Maximum number of detected genes allowed to keep a cell.
1933
1934        Update
1935        -------
1936        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
1937        (and calls `average()` if `self.agg_normalized_data` exists).
1938
1939        Side Effects
1940        ------------
1941        Raises ValueError if both bounds are None or if filtering removes all cells.
1942        """
1943
1944        if min_n is not None and max_n is not None:
1945            mask = (self.gene_calc > min_n) & (self.gene_calc < max_n)
1946        elif min_n is None and max_n is not None:
1947            mask = self.gene_calc < max_n
1948        elif min_n is not None and max_n is None:
1949            mask = self.gene_calc > min_n
1950        else:
1951            raise ValueError("Lack of both min_n and max_n values")
1952
1953        if self.input_data is not None:
1954
1955            if len([y for y in mask if y is False]) == 0:
1956                raise ValueError("Nothing to reduce")
1957
1958            self.input_data = self.input_data.loc[:, mask.values]
1959
1960        if self.normalized_data is not None:
1961
1962            if len([y for y in mask if y is False]) == 0:
1963                raise ValueError("Nothing to reduce")
1964
1965            self.normalized_data = self.normalized_data.loc[:, mask.values]
1966
1967        if self.input_metadata is not None:
1968
1969            if len([y for y in mask if y is False]) == 0:
1970                raise ValueError("Nothing to reduce")
1971
1972            self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index(
1973                drop=True
1974            )
1975
1976            self.input_metadata = self.input_metadata.drop(
1977                columns=["drop"], errors="ignore"
1978            )
1979
1980        if self.agg_normalized_data is not None:
1981            self.average()
1982
1983        self.gene_calculation()
1984        self.cells_calculation()
1985
1986    def cells_calculation(self, name_slot="cell_names"):
1987        """
1988        Calculate number of cells per  call name / cluster.
1989
1990        The method computes a binary (presence/absence) per cell name / cluster and sums across
1991        cells.
1992
1993        Parameters
1994        ----------
1995        name_slot : str, default 'cell_names'
1996            Column in metadata to use as sample names.
1997
1998        Update
1999        ------
2000        Sets `self.cells_calc` as a pd.DataFrame.
2001        """
2002
2003        ls = list(self.input_metadata[name_slot])
2004
2005        df = pd.DataFrame(
2006            {
2007                "cluster": pd.Series(ls).value_counts().index,
2008                "n": pd.Series(ls).value_counts().values,
2009            }
2010        )
2011
2012        self.cells_calc = df
2013
2014    def cell_histograme(self, name_slot: str = "cell_names"):
2015        """
2016        Plot a histogram of the number of cells detected per cell name (cluster).
2017
2018        Parameters
2019        ----------
2020        name_slot : str, default 'cell_names'
2021            Column in metadata to use as sample names.
2022
2023        Returns
2024        -------
2025        matplotlib.figure.Figure
2026            Figure containing the histogram of cell contents.
2027
2028        Notes
2029        -----
2030        Requires `self.cells_calc` to be computed prior to calling.
2031        """
2032
2033        if name_slot != "cell_names":
2034            self.cells_calculation(name_slot=name_slot)
2035
2036        fig, ax = plt.subplots(figsize=(8, 5))
2037
2038        _, bin_edges, _ = ax.hist(
2039            list(self.cells_calc["n"]),
2040            bins=len(set(self.cells_calc["cluster"])),
2041            edgecolor="black",
2042            color="orange",
2043            alpha=0.6,
2044        )
2045
2046        mu, sigma = np.mean(list(self.cells_calc["n"])), np.std(
2047            list(self.cells_calc["n"])
2048        )
2049
2050        x = np.linspace(
2051            min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000
2052        )
2053        y = norm.pdf(x, mu, sigma)
2054
2055        y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0])
2056
2057        ax.plot(
2058            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
2059        )
2060
2061        ax.set_xlabel("Value")
2062        ax.set_ylabel("Count")
2063        ax.set_title("Histogram of cells detected per cell name / cluster")
2064
2065        ax.set_xticks(
2066            np.linspace(
2067                min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20
2068            )
2069        )
2070        ax.tick_params(axis="x", rotation=90)
2071
2072        ax.legend()
2073
2074        return fig
2075
2076    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
2077        """
2078        Filter cell names / clusters by cell-detection threshold.
2079
2080        Parameters
2081        ----------
2082        min_n : int or None
2083            Minimum number of detected genes required to keep a cell.
2084
2085        name_slot : str, default 'cell_names'
2086            Column in metadata to use as sample names.
2087
2088
2089        Update
2090        -------
2091        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
2092        (and calls `average()` if `self.agg_normalized_data` exists).
2093        """
2094
2095        if name_slot != "cell_names":
2096            self.cells_calculation(name_slot=name_slot)
2097
2098        if min_n is not None:
2099            names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]
2100        else:
2101            raise ValueError("Lack of min_n value")
2102
2103        if len(names) > 0:
2104
2105            if self.input_data is not None:
2106
2107                self.input_data.columns = self.input_metadata[name_slot]
2108
2109                mask = [not any(r in x for r in names) for x in self.input_data.columns]
2110
2111                if len([y for y in mask if y is False]) > 0:
2112
2113                    self.input_data = self.input_data.loc[:, mask]
2114
2115            if self.normalized_data is not None:
2116
2117                self.normalized_data.columns = self.input_metadata[name_slot]
2118
2119                mask = [
2120                    not any(r in x for r in names) for x in self.normalized_data.columns
2121                ]
2122
2123                if len([y for y in mask if y is False]) > 0:
2124
2125                    self.normalized_data = self.normalized_data.loc[:, mask]
2126
2127            if self.input_metadata is not None:
2128
2129                self.input_metadata["drop"] = self.input_metadata[name_slot]
2130
2131                mask = [
2132                    not any(r in x for r in names) for x in self.input_metadata["drop"]
2133                ]
2134
2135                if len([y for y in mask if y is False]) > 0:
2136
2137                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
2138                        drop=True
2139                    )
2140
2141                self.input_metadata = self.input_metadata.drop(
2142                    columns=["drop"], errors="ignore"
2143                )
2144
2145            if self.agg_normalized_data is not None:
2146
2147                self.agg_normalized_data.columns = self.agg_metadata[name_slot]
2148
2149                mask = [
2150                    not any(r in x for r in names)
2151                    for x in self.agg_normalized_data.columns
2152                ]
2153
2154                if len([y for y in mask if y is False]) > 0:
2155
2156                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
2157
2158            if self.agg_metadata is not None:
2159
2160                self.agg_metadata["drop"] = self.agg_metadata[name_slot]
2161
2162                mask = [
2163                    not any(r in x for r in names) for x in self.agg_metadata["drop"]
2164                ]
2165
2166                if len([y for y in mask if y is False]) > 0:
2167
2168                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
2169                        drop=True
2170                    )
2171
2172                self.agg_metadata = self.agg_metadata.drop(
2173                    columns=["drop"], errors="ignore"
2174                )
2175
2176            self.gene_calculation()
2177            self.cells_calculation()
2178
2179    def load_sparse_from_projects(self, normalized_data: bool = False):
2180        """
2181        Load sparse 10x-style datasets from stored project paths, concatenate them,
2182        and populate `input_data` / `normalized_data` and `input_metadata`.
2183
2184        Parameters
2185        ----------
2186        normalized_data : bool, default False
2187            If True, store concatenated tables in `self.normalized_data`.
2188            If False, store them in `self.input_data` and normalization
2189            is needed using normalize_data() method.
2190
2191        Side Effects
2192        ------------
2193        - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv).
2194        - Concatenates all projects column-wise and sets `self.input_metadata`.
2195        - Replaces NaNs with zeros and updates `self.gene_calc`.
2196        """
2197
2198        obj = self.objects
2199
2200        full_data = pd.DataFrame()
2201        full_metadata = pd.DataFrame()
2202
2203        for ke in obj.keys():
2204            print(ke)
2205
2206            dt, met = load_sparse(path=obj[ke], name=ke)
2207
2208            full_data = pd.concat([full_data, dt], axis=1)
2209            full_metadata = pd.concat([full_metadata, met], axis=0)
2210
2211        full_data[np.isnan(full_data)] = 0
2212
2213        if normalized_data:
2214            self.normalized_data = full_data
2215            self.input_metadata = full_metadata
2216        else:
2217
2218            self.input_data = full_data
2219            self.input_metadata = full_metadata
2220
2221        self.gene_calculation()
2222        self.cells_calculation()
2223
2224    def rename_names(self, mapping: dict, slot: str = "cell_names"):
2225        """
2226        Rename entries in `self.input_metadata[slot]` according to a provided mapping.
2227
2228        Parameters
2229        ----------
2230        mapping : dict
2231            Dictionary with keys 'old_name' and 'new_name', each mapping to a list
2232            of equal length describing replacements.
2233
2234        slot : str, default 'cell_names'
2235            Metadata column to operate on.
2236
2237        Update
2238        -------
2239        Updates `self.input_metadata[slot]` in-place with renamed values.
2240
2241        Raises
2242        ------
2243        ValueError
2244            If mapping keys are incorrect, lengths differ, or some 'old_name' values
2245            are not present in the metadata column.
2246        """
2247
2248        if set(["old_name", "new_name"]) != set(mapping.keys()):
2249            raise ValueError(
2250                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2251                "each with a list of names to change."
2252            )
2253
2254        if len(mapping["old_name"]) != len(mapping["new_name"]):
2255            raise ValueError(
2256                "Mapping dictionary lists 'old_name' and 'new_name' "
2257                "must have the same length!"
2258            )
2259
2260        names_vector = list(self.input_metadata[slot])
2261
2262        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2263            raise ValueError(
2264                f"Some entries from 'old_name' do not exist in the names of slot {slot}."
2265            )
2266
2267        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2268
2269        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2270
2271        self.input_metadata[slot] = names_vector_ret
2272
2273    def rename_subclusters(self, mapping):
2274        """
2275        Rename labels stored in `self.subclusters_.subclusters` according to mapping.
2276
2277        Parameters
2278        ----------
2279        mapping : dict
2280            Mapping with keys 'old_name' and 'new_name' (lists of equal length).
2281
2282        Update
2283        -------
2284        Updates `self.subclusters_.subclusters` with renamed labels.
2285
2286        Raises
2287        ------
2288        ValueError
2289            If mapping is invalid or old names are not present.
2290        """
2291
2292        if set(["old_name", "new_name"]) != set(mapping.keys()):
2293            raise ValueError(
2294                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2295                "each with a list of names to change."
2296            )
2297
2298        if len(mapping["old_name"]) != len(mapping["new_name"]):
2299            raise ValueError(
2300                "Mapping dictionary lists 'old_name' and 'new_name' "
2301                "must have the same length!"
2302            )
2303
2304        names_vector = list(self.subclusters_.subclusters)
2305
2306        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2307            raise ValueError(
2308                "Some entries from 'old_name' do not exist in the subcluster names."
2309            )
2310
2311        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2312
2313        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2314
2315        self.subclusters_.subclusters = names_vector_ret
2316
2317    def save_sparse(
2318        self,
2319        path_to_save: str = os.getcwd(),
2320        name_slot: str = "cell_names",
2321        data_slot: str = "normalized",
2322    ):
2323        """
2324        Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
2325
2326        Parameters
2327        ----------
2328        path_to_save : str, default current working directory
2329            Directory where files will be written.
2330
2331        name_slot : str, default 'cell_names'
2332            Metadata column providing cell names for barcodes.tsv.
2333
2334        data_slot : str, default 'normalized'
2335            Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).
2336
2337        Raises
2338        ------
2339        ValueError
2340            If `data_slot` is not 'normalized' or 'count'.
2341        """
2342
2343        names = self.input_metadata[name_slot]
2344
2345        if data_slot.lower() == "normalized":
2346
2347            features = list(self.normalized_data.index)
2348            mtx = sparse.csr_matrix(self.normalized_data)
2349
2350        elif data_slot.lower() == "count":
2351
2352            features = list(self.input_data.index)
2353            mtx = sparse.csr_matrix(self.input_data)
2354
2355        else:
2356            raise ValueError("'data_slot' must be included in 'normalized' or 'count'")
2357
2358        os.makedirs(path_to_save, exist_ok=True)
2359
2360        mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx)
2361
2362        pd.Series(names).to_csv(
2363            os.path.join(path_to_save, "barcodes.tsv"),
2364            index=False,
2365            header=False,
2366            sep="\t",
2367        )
2368
2369        pd.Series(features).to_csv(
2370            os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t"
2371        )
2372
2373    def normalize_counts(
2374        self, normalize_factor: int = 100000, log_transform: bool = True
2375    ):
2376        """
2377        Normalize raw counts to counts-per-(normalize_factor)
2378        (e.g., CPM, TPM - depending on normalize_factor).
2379
2380        Parameters
2381        ----------
2382        normalize_factor : int, default 100000
2383            Scaling factor used after dividing by column sums.
2384
2385        log_transform : bool, default True
2386            If True, apply log2(x+1) transformation to normalized values.
2387
2388        Update
2389        -------
2390            Sets `self.normalized_data` to normalized values (fills NaNs with 0).
2391
2392        Raises
2393        ------
2394        ValueError
2395            If `self.input_data` is missing (cannot normalize).
2396        """
2397        if self.input_data is None:
2398            raise ValueError("Input data is missing, cannot normalize.")
2399
2400        sum_col = self.input_data.sum()
2401        self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor
2402
2403        if log_transform:
2404            # log2(x + 1) to avoid -inf for zeros
2405            self.normalized_data = np.log2(self.normalized_data + 1)
2406
    def statistic(
        self,
        cells=None,
        sets=None,
        min_exp: float = 0.01,
        min_pct: float = 0.1,
        n_proc: int = 10,
    ):
        """
        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.

        This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data`
        and `self.input_metadata`. It returns per-feature statistics including p-values,
        adjusted p-values, means, variances, effect-size measures and fold-changes.

        Parameters
        ----------
        cells : list, 'All', dict, or None
            Defines the target cells or groups for comparison (several modes supported).

        sets : 'All', dict, or None
            Alternative grouping mode (operate on `self.input_metadata['sets']`).

        min_exp : float, default 0.01
            Minimum expression threshold used when filtering features.

        min_pct : float, default 0.1
            Minimum proportion of expressing cells in the target group required to test a feature.

        n_proc : int, default 10
            Number of parallel jobs to use.

        Returns
        -------
        pandas.DataFrame or dict
            Results DataFrame (or dict containing valid/control cells + DataFrame),
            similar to `calc_DEG` interface.

        Raises
        ------
        ValueError
            If neither `cells` nor `sets` is provided, or input metadata mismatch occurs.

        Notes
        -----
        Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
        """

        # Tiny positive constant used as a last-resort pseudo-count when a
        # fold-change denominator factor would otherwise be zero or non-finite.
        offset = 1e-100

        def stat_calc(choose, feature_name):
            # Per-feature comparison of the 'target' rows vs the 'rest' rows
            # of `choose` (rows = samples, columns = features + 'DEG' label).
            target_values = choose.loc[choose["DEG"] == "target", feature_name]
            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

            # Fraction of samples expressing the feature (> 0) in each group.
            pct_valid = (target_values > 0).sum() / len(target_values)
            pct_rest = (rest_values > 0).sum() / len(rest_values)

            avg_valid = np.mean(target_values)
            avg_ctrl = np.mean(rest_values)
            sd_valid = np.std(target_values, ddof=1)
            sd_ctrl = np.std(rest_values, ddof=1)
            # Cohen's-d-style effect size using the pooled (averaged-variance) SD.
            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

            # Heuristic shortcut: when the two groups' sums coincide, skip the
            # rank test and report p = 1 (also sidesteps degenerate inputs).
            if np.sum(target_values) == np.sum(rest_values):
                p_val = 1.0
            else:
                _, p_val = stats.mannwhitneyu(
                    target_values, rest_values, alternative="two-sided"
                )

            return {
                "feature": feature_name,
                "p_val": p_val,
                "pct_valid": pct_valid,
                "pct_ctrl": pct_rest,
                "avg_valid": avg_valid,
                "avg_ctrl": avg_ctrl,
                "sd_valid": sd_valid,
                "sd_ctrl": sd_ctrl,
                "esm": esm,
            }

        def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc):

            def safe_min_half(series):
                # Half of the smallest strictly-positive value (threshold is
                # twice the smallest subnormal float, i.e. effectively > 0),
                # or 0 when no such value exists.
                filtered = series[(series > ((2**-1074)*2)) & (series.notna())]
                return filtered.min() / 2 if not filtered.empty else 0

            # Percentage of target-group samples expressing each feature
            # above `min_exp`; used to pre-filter features before testing.
            tmp_dat = choose[choose["DEG"] == "target"]
            tmp_dat = tmp_dat.drop("DEG", axis=1)

            counts = (tmp_dat > min_exp).sum(axis=0)

            total_count = tmp_dat.shape[0]

            info = pd.DataFrame(
                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
            )

            del tmp_dat

            drop_col = info["feature"][info["pct"] <= min_pct]

            # If the pct filter would remove every feature (all columns except
            # 'DEG'), fall back to dropping only completely unexpressed ones.
            if len(drop_col) + 1 == len(choose.columns):
                drop_col = info["feature"][info["pct"] == 0]

            del info

            choose = choose.drop(list(drop_col), axis=1)

            # One Mann-Whitney comparison per remaining feature, in parallel.
            results = Parallel(n_jobs=n_proc)(
                delayed(stat_calc)(choose, feature)
                for feature in tqdm(choose.columns[choose.columns != "DEG"])
            )

            df = pd.DataFrame(results)
            # Keep features expressed in at least one of the two groups.
            df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)]

            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # Benjamini-Hochberg-style adjustment: p * m / rank, capped at 1.
            # NOTE(review): no monotonicity enforcement across ranks — this is
            # a simplified variant of the standard BH step-up procedure.
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            # Pseudo-count for zero means so fold-changes stay finite: half of
            # the smallest positive mean observed in either group.
            valid_factor = safe_min_half(df["avg_valid"])
            ctrl_factor = safe_min_half(df["avg_ctrl"])

            cv_factor = min(valid_factor, ctrl_factor)

            if cv_factor == 0:
                cv_factor = max(valid_factor, ctrl_factor)

            if not np.isfinite(cv_factor) or cv_factor == 0:
                cv_factor += offset

            # Replace zero means with the pseudo-count before dividing.
            valid = df["avg_valid"].where(
                df["avg_valid"] != 0, df["avg_valid"] + cv_factor
            )
            ctrl = df["avg_ctrl"].where(
                df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor
            )

            df["FC"] = valid / ctrl

            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]

            return df

        # Samples become rows, features columns; row labels are set per mode below.
        choose = self.normalized_data.copy().T

        final_results = []

        # Mode 1: explicit list of cells vs everything else.
        if isinstance(cells, list) and sets is None:
            print("\nAnalysis started...\nComparing selected cells to the whole set...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            # If the caller's names carry no ' # set' suffix, match on bare
            # cell names and ignore set membership.
            if "#" not in cells[0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "Not include the set info (name # set) in the 'cells' list. "
                    "Only the names will be compared, without considering the set information."
                )

            labels = ["target" if idx in cells else "rest" for idx in choose.index]
            valid = list(
                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
            )

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=valid,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}

        # Mode 2: every (cell_name # set) label, one-vs-rest.
        elif cells == "All" and sets is None:
            print("\nAnalysis started...\nComparing each type of cell to others...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
            unique_labels = set(choose.index)

            for label in tqdm(unique_labels):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]
                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        # Mode 3: every set/group, one-vs-rest.
        elif cells is None and sets == "All":
            print("\nAnalysis started...\nComparing each set/group to others...")
            choose.index = self.input_metadata["sets"]
            unique_sets = set(choose.index)

            for label in tqdm(unique_sets):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]

                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        # Mode 4: pairwise comparison of two groups of sets; samples outside
        # both groups are labelled 'drop' and excluded.
        elif cells is None and isinstance(sets, dict):
            print("\nAnalysis started...\nComparing groups...")

            choose.index = self.input_metadata["sets"]

            group_list = list(sets.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            labels = [
                (
                    "target"
                    if idx in sets[group_list[0]]
                    else "rest" if idx in sets[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            return result_df

        # Mode 5: pairwise comparison of two explicit cell lists.
        elif isinstance(cells, dict) and sets is None:
            print("\nAnalysis started...\nComparing groups...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            # Same suffix convention as mode 1: no '#' means match bare names.
            if "#" not in cells[list(cells.keys())[0]][0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "Not include the set info (name # set) in the 'cells' dict. "
                    "Only the names will be compared, without considering the set information."
                )

            group_list = list(cells.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            labels = [
                (
                    "target"
                    if idx in cells[group_list[0]]
                    else "rest" if idx in cells[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )

            return result_df.reset_index(drop=True)

        else:
            raise ValueError(
                "You must specify either 'cells' or 'sets' (or both). None were provided, which is not allowed for this analysis."
            )
2713
2714    def calculate_difference_markers(
2715        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
2716    ):
2717        """
2718        Compute differential markers (var_data) if not already present.
2719
2720        Parameters
2721        ----------
2722        min_exp : float, default 0
2723            Minimum expression threshold passed to `statistic`.
2724
2725        min_pct : float, default 0.25
2726            Minimum percent expressed in target group.
2727
2728        n_proc : int, default 10
2729            Parallel jobs.
2730
2731        force : bool, default False
2732            If True, recompute even if `self.var_data` is present.
2733
2734        Update
2735        -------
2736        Sets `self.var_data` to the result of `self.statistic(...)`.
2737
2738        Raise
2739        ------
2740        ValueError if already computed and `force` is False.
2741        """
2742
2743        if self.var_data is None or force:
2744
2745            self.var_data = self.statistic(
2746                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
2747            )
2748
2749        else:
2750            raise ValueError(
2751                "self.calculate_difference_markers() has already been executed. "
2752                "The results are stored in self.var. "
2753                "If you want to recalculate with different statistics, please rerun the method with force=True."
2754            )
2755
2756    def clustering_features(
2757        self,
2758        features_list: list | None,
2759        name_slot: str = "cell_names",
2760        p_val: float = 0.05,
2761        top_n: int = 25,
2762        adj_mean: bool = True,
2763        beta: float = 0.2,
2764    ):
2765        """
2766        Prepare clustering input by selecting marker features and optionally smoothing cell values
2767        toward group means.
2768
2769        Parameters
2770        ----------
2771        features_list : list or None
2772            If provided, use this list of features. If None, features are selected
2773            from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group.
2774
2775        name_slot : str, default 'cell_names'
2776            Metadata column used for naming.
2777
2778        p_val : float, default 0.05
2779            Adjusted p-value cutoff when selecting features automatically.
2780
2781        top_n : int, default 25
2782            Number of top features per valid group to keep if `features_list` is None.
2783
2784        adj_mean : bool, default True
2785            If True, adjust cell values toward group means using `beta`.
2786
2787        beta : float, default 0.2
2788            Adjustment strength toward group mean.
2789
2790        Update
2791        ------
2792        Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset,
2793        ready for PCA/UMAP/clustering.
2794        """
2795
2796        if features_list is None or len(features_list) == 0:
2797
2798            if self.var_data is None:
2799                raise ValueError(
2800                    "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2801                )
2802
2803            df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2804            df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2805            df_tmp = (
2806                df_tmp.sort_values(
2807                    ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2808                )
2809                .groupby("valid_group")
2810                .head(top_n)
2811            )
2812
2813            feaures_list = list(set(df_tmp["feature"]))
2814
2815        data = self.get_partial_data(
2816            names=None, features=feaures_list, name_slot=name_slot
2817        )
2818        data_avg = average(data)
2819
2820        if adj_mean:
2821            data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta)
2822
2823        self.clustering_data = data
2824
2825        self.clustering_metadata = self.input_metadata
2826
2827    def average(self):
2828        """
2829        Aggregate normalized data by (cell_name, set) pairs computing the mean per group.
2830
2831        The method constructs new column names as "cell_name # set", averages columns
2832        sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`.
2833
2834        Update
2835        ------
2836        Sets `self.agg_normalized_data` (features x aggregated samples) and
2837        `self.agg_metadata` (DataFrame with 'cell_names' and 'sets').
2838        """
2839
2840        wide_data = self.normalized_data
2841
2842        wide_metadata = self.input_metadata
2843
2844        new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"]
2845
2846        wide_data.columns = list(new_names)
2847
2848        aggregated_df = wide_data.T.groupby(level=0).mean().T
2849
2850        sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns]
2851        names = [re.sub(" #.*", "", x) for x in aggregated_df.columns]
2852
2853        aggregated_df.columns = names
2854        aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets})
2855
2856        self.agg_metadata = aggregated_metadata
2857        self.agg_normalized_data = aggregated_df
2858
    def estimating_similarity(
        self, method="pearson", p_val: float = 0.05, top_n: int = 25
    ):
        """
        Estimate pairwise similarity and Euclidean distance between aggregated samples.

        Parameters
        ----------
        method : str, default 'pearson'
            Correlation method to use (passed to pandas.DataFrame.corr()).

        p_val : float, default 0.05
            Adjusted p-value cutoff used to select marker features from `self.var_data`.

        top_n : int, default 25
            Number of top features per valid group to include.

        Raises
        ------
        ValueError
            If `self.var_data` is missing.

        Update
        -------
        Computes a combined table with per-pair correlation and euclidean distance
        and stores it in `self.similarity`.
        """

        if self.var_data is None:
            raise ValueError(
                "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
            )

        # Lazily build the aggregated (per cell#set averaged) matrix if needed.
        if self.agg_normalized_data is None:
            self.average()

        metadata = self.agg_metadata
        data = self.agg_normalized_data

        # Marker selection: significant, up-regulated, top_n per group by
        # effect size then log fold-change (same recipe as clustering_features).
        df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
        df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
        df_tmp = (
            df_tmp.sort_values(
                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
            )
            .groupby("valid_group")
            .head(top_n)
        )

        # Restrict to the selected marker features (rows).
        data = data.loc[list(set(df_tmp["feature"]))]

        # With multiple sets, disambiguate columns as "name # set"; note this
        # renames columns on the sliced frame, not on self.agg_normalized_data.
        if len(set(metadata["sets"])) > 1:
            data.columns = data.columns + " # " + [x for x in metadata["sets"]]
        else:
            data = data.copy()

        # Z-score each column (sample) across the marker features.
        scaler = StandardScaler()

        scaled_data = scaler.fit_transform(data)

        scaled_df = pd.DataFrame(scaled_data, columns=data.columns)

        # Pairwise correlation between samples, in long format.
        cor = scaled_df.corr(method=method)
        cor_df = cor.stack().reset_index()
        cor_df.columns = ["cell1", "cell2", "correlation"]

        # Pairwise euclidean distances between samples (columns), long format.
        distances = pdist(scaled_df.T, metric="euclidean")
        dist_mat = pd.DataFrame(
            squareform(distances), index=scaled_df.columns, columns=scaled_df.columns
        )
        dist_df = dist_mat.stack().reset_index()
        dist_df.columns = ["cell1", "cell2", "euclidean_dist"]

        # One row per ordered pair; drop self-pairs.
        full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"])

        full = full[full["cell1"] != full["cell2"]]
        full = full.reset_index(drop=True)

        self.similarity = full
2933
2934    def similarity_plot(
2935        self,
2936        split_sets=True,
2937        set_info: bool = True,
2938        cmap="seismic",
2939        width=12,
2940        height=10,
2941    ):
2942        """
2943        Visualize pairwise similarity as a scatter plot.
2944
2945        Parameters
2946        ----------
2947        split_sets : bool, default True
2948            If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs.
2949
2950        set_info : bool, default True
2951            If True, keep the ' # set' annotation in labels; otherwise strip it.
2952
2953        cmap : str, default 'seismic'
2954            Color map for correlation (hue).
2955
2956        width : int, default 12
2957            Figure width.
2958
2959        height : int, default 10
2960            Figure height.
2961
2962        Returns
2963        -------
2964        matplotlib.figure.Figure
2965
2966        Raises
2967        ------
2968        ValueError
2969            If `self.similarity` is None.
2970
2971        Notes
2972        -----
2973        The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs.
2974        """
2975
2976        if self.similarity is None:
2977            raise ValueError(
2978                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
2979            )
2980
2981        similarity_data = self.similarity
2982
2983        if " # " in similarity_data["cell1"][0]:
2984            similarity_data["set1"] = [
2985                re.sub(".*# ", "", x) for x in similarity_data["cell1"]
2986            ]
2987            similarity_data["set2"] = [
2988                re.sub(".*# ", "", x) for x in similarity_data["cell2"]
2989            ]
2990
2991        if split_sets and " # " in similarity_data["cell1"][0]:
2992            sets = list(
2993                set(list(similarity_data["set1"]) + list(similarity_data["set2"]))
2994            )
2995
2996            mm = math.ceil(len(sets) / 2)
2997
2998            x_s = sets[0:mm]
2999            y_s = sets[mm : len(sets)]
3000
3001            similarity_data = similarity_data[similarity_data["set1"].isin(x_s)]
3002            similarity_data = similarity_data[similarity_data["set2"].isin(y_s)]
3003
3004            similarity_data = similarity_data.sort_values(["set1", "set2"])
3005
3006        if set_info is False and " # " in similarity_data["cell1"][0]:
3007            similarity_data["cell1"] = [
3008                re.sub(" #.*", "", x) for x in similarity_data["cell1"]
3009            ]
3010            similarity_data["cell2"] = [
3011                re.sub(" #.*", "", x) for x in similarity_data["cell2"]
3012            ]
3013
3014        similarity_data["-euclidean_zscore"] = -zscore(
3015            similarity_data["euclidean_dist"]
3016        )
3017
3018        similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0]
3019
3020        fig = plt.figure(figsize=(width, height))
3021        sns.scatterplot(
3022            data=similarity_data,
3023            x="cell1",
3024            y="cell2",
3025            hue="correlation",
3026            size="-euclidean_zscore",
3027            sizes=(1, 100),
3028            palette=cmap,
3029            alpha=1,
3030            edgecolor="black",
3031        )
3032
3033        plt.xticks(rotation=90)
3034        plt.yticks(rotation=0)
3035        plt.xlabel("Cell 1")
3036        plt.ylabel("Cell 2")
3037        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
3038
3039        plt.grid(True, alpha=0.6)
3040
3041        plt.tight_layout()
3042
3043        return fig
3044
3045    def spatial_similarity(
3046        self,
3047        set_info: bool = True,
3048        bandwidth=1,
3049        n_neighbors=5,
3050        min_dist=0.1,
3051        legend_split=2,
3052        point_size=100,
3053        spread=1.0,
3054        set_op_mix_ratio=1.0,
3055        local_connectivity=1,
3056        repulsion_strength=1.0,
3057        negative_sample_rate=5,
3058        threshold=0.1,
3059        width=12,
3060        height=10,
3061    ):
3062        """
3063        Create a spatial UMAP-like visualization of similarity relationships between samples.
3064
3065        Parameters
3066        ----------
3067        set_info : bool, default True
3068            If True, retain set information in labels.
3069
3070        bandwidth : float, default 1
3071            Bandwidth used by MeanShift for clustering polygons.
3072
3073        point_size : float, default 100
3074            Size of scatter points.
3075
3076        legend_split : int, default 2
3077            Number of columns in legend.
3078
3079        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.
3080
3081        threshold : float, default 0.1
3082            Minimum text distance for label adjustment to avoid overlap.
3083
3084        width : int, default 12
3085            Figure width.
3086
3087        height : int, default 10
3088            Figure height.
3089
3090        Returns
3091        -------
3092        matplotlib.figure.Figure
3093
3094        Raises
3095        ------
3096        ValueError
3097            If `self.similarity` is None.
3098
3099        Notes
3100        -----
3101        Builds a precomputed distance matrix combining correlation and euclidean distance,
3102        runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull)
3103        and arrows to indicate nearest neighbors (minimal combined distance).
3104        """
3105
3106        if self.similarity is None:
3107            raise ValueError(
3108                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
3109            )
3110
3111        similarity_data = self.similarity
3112
3113        sim = similarity_data["correlation"]
3114        sim_scaled = (sim - sim.min()) / (sim.max() - sim.min())
3115        eu_dist = similarity_data["euclidean_dist"]
3116        eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min())
3117
3118        similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled
3119
3120        # for nn target
3121        arrow_df = similarity_data.copy()
3122        arrow_df = similarity_data.loc[
3123            similarity_data.groupby("cell1")["combo_dist"].idxmin()
3124        ]
3125
3126        cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"]))
3127        combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float)
3128
3129        for _, row in similarity_data.iterrows():
3130            combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"]
3131            combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"]
3132
3133        umap_model = umap.UMAP(
3134            n_components=2,
3135            metric="precomputed",
3136            n_neighbors=n_neighbors,
3137            min_dist=min_dist,
3138            spread=spread,
3139            set_op_mix_ratio=set_op_mix_ratio,
3140            local_connectivity=set_op_mix_ratio,
3141            repulsion_strength=repulsion_strength,
3142            negative_sample_rate=negative_sample_rate,
3143            transform_seed=42,
3144            init="spectral",
3145            random_state=42,
3146            verbose=True,
3147        )
3148
3149        coords = umap_model.fit_transform(combo_matrix.values)
3150        cell_names = list(combo_matrix.index)
3151        num_cells = len(cell_names)
3152        palette = sns.color_palette("tab20c", num_cells)
3153
3154        if "#" in cell_names[0]:
3155            avsets = set(
3156                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
3157                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
3158            )
3159            num_sets = len(avsets)
3160            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
3161            color_mapping_sets = {
3162                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
3163            }
3164            color_mapping = {
3165                name: color_mapping_sets[re.sub(".*# ", "", name)]
3166                for i, name in enumerate(cell_names)
3167            }
3168        else:
3169            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}
3170
3171        meanshift = MeanShift(bandwidth=bandwidth)
3172        labels = meanshift.fit_predict(coords)
3173
3174        fig = plt.figure(figsize=(width, height))
3175        ax = plt.gca()
3176
3177        unique_labels = set(labels)
3178        cluster_palette = sns.color_palette("hls", len(unique_labels))
3179
3180        for label in unique_labels:
3181            if label == -1:
3182                continue
3183            cluster_coords = coords[labels == label]
3184            if len(cluster_coords) < 3:
3185                continue
3186
3187            hull = ConvexHull(cluster_coords)
3188            hull_points = cluster_coords[hull.vertices]
3189
3190            centroid = np.mean(hull_points, axis=0)
3191            expanded = hull_points + 0.05 * (hull_points - centroid)
3192
3193            poly = Polygon(
3194                expanded,
3195                closed=True,
3196                facecolor=cluster_palette[label],
3197                edgecolor="none",
3198                alpha=0.2,
3199                zorder=1,
3200            )
3201            ax.add_patch(poly)
3202
3203        texts = []
3204        for i, (x, y) in enumerate(coords):
3205            plt.scatter(
3206                x,
3207                y,
3208                s=point_size,
3209                color=color_mapping[cell_names[i]],
3210                edgecolors="black",
3211                linewidths=0.5,
3212                zorder=2,
3213            )
3214            texts.append(
3215                ax.text(
3216                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
3217                )
3218            )
3219
3220        def dist(p1, p2):
3221            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
3222
3223        texts_to_adjust = []
3224        for i, t1 in enumerate(texts):
3225            for j, t2 in enumerate(texts):
3226                if i >= j:
3227                    continue
3228                d = dist(
3229                    (t1.get_position()[0], t1.get_position()[1]),
3230                    (t2.get_position()[0], t2.get_position()[1]),
3231                )
3232                if d < threshold:
3233                    if t1 not in texts_to_adjust:
3234                        texts_to_adjust.append(t1)
3235                    if t2 not in texts_to_adjust:
3236                        texts_to_adjust.append(t2)
3237
3238        adjust_text(
3239            texts_to_adjust,
3240            expand_text=(1.0, 1.0),
3241            force_text=0.9,
3242            arrowprops=dict(arrowstyle="-", color="gray", lw=0.1),
3243            ax=ax,
3244        )
3245
3246        for _, row in arrow_df.iterrows():
3247            try:
3248                idx1 = cell_names.index(row["cell1"])
3249                idx2 = cell_names.index(row["cell2"])
3250            except ValueError:
3251                continue
3252            x1, y1 = coords[idx1]
3253            x2, y2 = coords[idx2]
3254            arrow = FancyArrowPatch(
3255                (x1, y1),
3256                (x2, y2),
3257                arrowstyle="->",
3258                color="gray",
3259                linewidth=1.5,
3260                alpha=0.5,
3261                mutation_scale=12,
3262                zorder=0,
3263            )
3264            ax.add_patch(arrow)
3265
3266        if set_info is False and " # " in cell_names[0]:
3267
3268            legend_elements = [
3269                Patch(
3270                    facecolor=color_mapping[name],
3271                    edgecolor="black",
3272                    label=f"{i}{re.sub(' #.*', '', name)}",
3273                )
3274                for i, name in enumerate(cell_names)
3275            ]
3276
3277        else:
3278
3279            legend_elements = [
3280                Patch(
3281                    facecolor=color_mapping[name],
3282                    edgecolor="black",
3283                    label=f"{i}{name}",
3284                )
3285                for i, name in enumerate(cell_names)
3286            ]
3287
3288        plt.legend(
3289            handles=legend_elements,
3290            title="Cells",
3291            bbox_to_anchor=(1.05, 1),
3292            loc="upper left",
3293            ncol=legend_split,
3294        )
3295
3296        plt.xlabel("UMAP 1")
3297        plt.ylabel("UMAP 2")
3298        plt.grid(False)
3299        plt.show()
3300
3301        return fig
3302
3303    # subclusters part
3304
3305    def subcluster_prepare(self, features: list, cluster: str):
3306        """
3307        Prepare a `Clustering` object for subcluster analysis on a selected parent cluster.
3308
3309        Parameters
3310        ----------
3311        features : list
3312            Features to include for subcluster analysis.
3313
3314        cluster : str
3315            Parent cluster name (used to select matching cells).
3316
3317        Update
3318        ------
3319        Initializes `self.subclusters_` as a new `Clustering` instance containing the
3320        reduced data for the given cluster and stores `current_features` and `current_cluster`.
3321        """
3322
3323        dat = self.normalized_data
3324        dat.columns = list(self.input_metadata["cell_names"])
3325
3326        dat = reduce_data(self.normalized_data, features=features, names=[cluster])
3327
3328        self.subclusters_ = Clustering(data=dat, metadata=None)
3329
3330        self.subclusters_.current_features = features
3331        self.subclusters_.current_cluster = cluster
3332
3333    def define_subclusters(
3334        self,
3335        umap_num: int = 2,
3336        eps: float = 0.5,
3337        min_samples: int = 10,
3338        n_neighbors: int = 5,
3339        min_dist: float = 0.1,
3340        spread: float = 1.0,
3341        set_op_mix_ratio: float = 1.0,
3342        local_connectivity: int = 1,
3343        repulsion_strength: float = 1.0,
3344        negative_sample_rate: int = 5,
3345        width=8,
3346        height=6,
3347    ):
3348        """
3349        Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.
3350
3351        Parameters
3352        ----------
3353        umap_num : int, default 2
3354            Number of UMAP dimensions to compute.
3355
3356        eps : float, default 0.5
3357            DBSCAN eps parameter.
3358
3359        min_samples : int, default 10
3360            DBSCAN min_samples parameter.
3361
3362        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
3363            Additional parameters passed to UMAP / plotting / MeanShift as appropriate.
3364
3365        Update
3366        -------
3367        Stores cluster labels in `self.subclusters_.subclusters`.
3368
3369        Raises
3370        ------
3371        RuntimeError
3372            If `self.subclusters_` has not been prepared.
3373        """
3374
3375        if self.subclusters_ is None:
3376            raise RuntimeError(
3377                "Nothing to return. 'self.subcluster_prepare' was not conducted!"
3378            )
3379
3380        self.subclusters_.perform_UMAP(
3381            factorize=False,
3382            umap_num=umap_num,
3383            pc_num=0,
3384            harmonized=False,
3385            n_neighbors=n_neighbors,
3386            min_dist=min_dist,
3387            spread=spread,
3388            set_op_mix_ratio=set_op_mix_ratio,
3389            local_connectivity=local_connectivity,
3390            repulsion_strength=repulsion_strength,
3391            negative_sample_rate=negative_sample_rate,
3392            width=width,
3393            height=height,
3394        )
3395
3396        fig = self.subclusters_.find_clusters_UMAP(
3397            umap_n=umap_num,
3398            eps=eps,
3399            min_samples=min_samples,
3400            width=width,
3401            height=height,
3402        )
3403
3404        clusters = self.subclusters_.return_clusters(clusters="umap")
3405
3406        self.subclusters_.subclusters = [str(x) for x in list(clusters)]
3407
3408        return fig
3409
3410    def subcluster_features_scatter(
3411        self,
3412        colors="viridis",
3413        hclust="complete",
3414        scale=False,
3415        img_width=3,
3416        img_high=5,
3417        label_size=6,
3418        size_scale=70,
3419        y_lab="Genes",
3420        legend_lab="normalized",
3421        bbox_to_anchor_scale: int = 25,
3422        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3423    ):
3424        """
3425        Create a features-scatter visualization for the subclusters (averaged and occurrence).
3426
3427        Parameters
3428        ----------
3429        colors : str, default 'viridis'
3430            Colormap name passed to `features_scatter`.
3431
3432        hclust : str or None
3433            Hierarchical clustering linkage to order rows/columns.
3434
3435        scale: bool, default False
3436            If True, expression data will be scaled (0–1) across the rows (features).
3437
3438        img_width, img_high : float
3439            Figure size.
3440
3441        label_size : int
3442            Font size for labels.
3443
3444        size_scale : int
3445            Bubble size scaling.
3446
3447        y_lab : str
3448            X axis label.
3449
3450        legend_lab : str
3451            Colorbar label.
3452
3453        bbox_to_anchor_scale : int, default=25
3454            Vertical scale (percentage) for positioning the colorbar.
3455
3456        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3457            Anchor position for the size legend (percent bubble legend).
3458
3459        Returns
3460        -------
3461        matplotlib.figure.Figure
3462
3463        Raises
3464        ------
3465        RuntimeError
3466            If subcluster preparation/definition has not been run.
3467        """
3468
3469        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3470            raise RuntimeError(
3471                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3472            )
3473
3474        dat = self.normalized_data
3475        dat.columns = list(self.input_metadata["cell_names"])
3476
3477        dat = reduce_data(
3478            self.normalized_data,
3479            features=self.subclusters_.current_features,
3480            names=[self.subclusters_.current_cluster],
3481        )
3482
3483        dat.columns = self.subclusters_.subclusters
3484
3485        avg = average(dat)
3486        occ = occurrence(dat)
3487
3488        scatter = features_scatter(
3489            expression_data=avg,
3490            occurence_data=occ,
3491            features=None,
3492            scale=scale,
3493            metadata_list=None,
3494            colors=colors,
3495            hclust=hclust,
3496            img_width=img_width,
3497            img_high=img_high,
3498            label_size=label_size,
3499            size_scale=size_scale,
3500            y_lab=y_lab,
3501            legend_lab=legend_lab,
3502            bbox_to_anchor_scale=bbox_to_anchor_scale,
3503            bbox_to_anchor_perc=bbox_to_anchor_perc,
3504        )
3505
3506        return scatter
3507
3508    def subcluster_DEG_scatter(
3509        self,
3510        top_n=3,
3511        min_exp=0,
3512        min_pct=0.25,
3513        p_val=0.05,
3514        colors="viridis",
3515        hclust="complete",
3516        scale=False,
3517        img_width=3,
3518        img_high=5,
3519        label_size=6,
3520        size_scale=70,
3521        y_lab="Genes",
3522        legend_lab="normalized",
3523        bbox_to_anchor_scale: int = 25,
3524        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3525        n_proc=10,
3526    ):
3527        """
3528        Plot top differential features (DEGs) for subclusters as a features-scatter.
3529
3530        Parameters
3531        ----------
3532        top_n : int, default 3
3533            Number of top features per subcluster to show.
3534
3535        min_exp : float, default 0
3536            Minimum expression threshold passed to `statistic`.
3537
3538        min_pct : float, default 0.25
3539            Minimum percent expressed in target group.
3540
3541        p_val: float, default 0.05
3542            Maximum p-value for visualizing features.
3543
3544        n_proc : int, default 10
3545            Parallel jobs used for DEG calculation.
3546
3547        scale: bool, default False
3548            If True, expression_data will be scaled (0–1) across the rows (features).
3549
3550        colors : str, default='viridis'
3551            Colormap for expression values.
3552
3553        hclust : str or None, default='complete'
3554            Linkage method for hierarchical clustering. If None, no clustering
3555            is performed.
3556
3557        img_width : int or float, default=8
3558            Width of the plot in inches.
3559
3560        img_high : int or float, default=5
3561            Height of the plot in inches.
3562
3563        label_size : int, default=10
3564            Font size for axis labels and ticks.
3565
3566        size_scale : int or float, default=100
3567            Scaling factor for bubble sizes.
3568
3569        y_lab : str, default='Genes'
3570            Label for the x-axis.
3571
3572        legend_lab : str, default='normalized'
3573            Label for the colorbar legend.
3574
3575        bbox_to_anchor_scale : int, default=25
3576            Vertical scale (percentage) for positioning the colorbar.
3577
3578        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3579            Anchor position for the size legend (percent bubble legend).
3580
3581        Returns
3582        -------
3583        matplotlib.figure.Figure
3584
3585        Raises
3586        ------
3587        RuntimeError
3588            If subcluster preparation/definition has not been run.
3589
3590        Notes
3591        -----
3592        Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters
3593        by p-value and effect-size, selects top features per valid group and plots them.
3594        """
3595
3596        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3597            raise RuntimeError(
3598                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3599            )
3600
3601        dat = self.normalized_data
3602        dat.columns = list(self.input_metadata["cell_names"])
3603
3604        dat = reduce_data(
3605            self.normalized_data, names=[self.subclusters_.current_cluster]
3606        )
3607
3608        dat.columns = self.subclusters_.subclusters
3609
3610        deg_stats = calc_DEG(
3611            dat,
3612            metadata_list=None,
3613            entities="All",
3614            sets=None,
3615            min_exp=min_exp,
3616            min_pct=min_pct,
3617            n_proc=n_proc,
3618        )
3619
3620        deg_stats = deg_stats[deg_stats["p_val"] <= p_val]
3621        deg_stats = deg_stats[deg_stats["log(FC)"] > 0]
3622
3623        deg_stats = (
3624            deg_stats.sort_values(
3625                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
3626            )
3627            .groupby("valid_group")
3628            .head(top_n)
3629        )
3630
3631        dat = reduce_data(dat, features=list(set(deg_stats["feature"])))
3632
3633        avg = average(dat)
3634        occ = occurrence(dat)
3635
3636        scatter = features_scatter(
3637            expression_data=avg,
3638            occurence_data=occ,
3639            features=None,
3640            metadata_list=None,
3641            colors=colors,
3642            hclust=hclust,
3643            img_width=img_width,
3644            img_high=img_high,
3645            label_size=label_size,
3646            size_scale=size_scale,
3647            y_lab=y_lab,
3648            legend_lab=legend_lab,
3649            bbox_to_anchor_scale=bbox_to_anchor_scale,
3650            bbox_to_anchor_perc=bbox_to_anchor_perc,
3651        )
3652
3653        return scatter
3654
3655    def accept_subclusters(self):
3656        """
3657        Commit subcluster labels into the main `input_metadata` by renaming cell names.
3658
3659        The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']`
3660        with the expanded names that include subcluster suffixes (via `add_subnames`),
3661        then clears `self.subclusters_`.
3662
3663        Update
3664        ------
3665        Modifies `self.input_metadata['cell_names']`.
3666
3667        Resets `self.subclusters_` to None.
3668
3669        Raises
3670        ------
3671        RuntimeError
3672            If `self.subclusters_` is not defined or subclusters were not computed.
3673        """
3674
3675        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3676            raise RuntimeError(
3677                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3678            )
3679
3680        new_meta = add_subnames(
3681            list(self.input_metadata["cell_names"]),
3682            parent_name=self.subclusters_.current_cluster,
3683            new_clusters=self.subclusters_.subclusters,
3684        )
3685
3686        self.input_metadata["cell_names"] = new_meta
3687
3688        self.subclusters_ = None
3689
3690    def scatter_plot(
3691        self,
3692        names: list | None = None,
3693        features: list | None = None,
3694        name_slot: str = "cell_names",
3695        scale=True,
3696        colors="viridis",
3697        hclust=None,
3698        img_width=15,
3699        img_high=1,
3700        label_size=10,
3701        size_scale=200,
3702        y_lab="Genes",
3703        legend_lab="log(CPM + 1)",
3704        set_box_size: float | int = 5,
3705        set_box_high: float | int = 5,
3706        bbox_to_anchor_scale=25,
3707        bbox_to_anchor_perc=(0.90, -0.24),
3708        bbox_to_anchor_group=(1.01, 0.4),
3709    ):
3710        """
3711        Create a bubble scatter plot of selected features across samples inside project.
3712
3713        Each point represents a feature-sample pair, where the color encodes the
3714        expression value and the size encodes occurrence or relative abundance.
3715        Optionally, hierarchical clustering can be applied to order rows and columns.
3716
3717        Parameters
3718        ----------
3719        names : list, str, or None
3720            Names of samples to include. If None, all samples are considered.
3721
3722        features : list, str, or None
3723            Names of features to include. If None, all features are considered.
3724
3725        name_slot : str
3726            Column in metadata to use as sample names.
3727
3728        scale: bool, default False
3729            If True, expression_data will be scaled (0–1) across the rows (features).
3730
3731        colors : str, default='viridis'
3732            Colormap for expression values.
3733
3734        hclust : str or None, default='complete'
3735            Linkage method for hierarchical clustering. If None, no clustering
3736            is performed.
3737
3738        img_width : int or float, default=8
3739            Width of the plot in inches.
3740
3741        img_high : int or float, default=5
3742            Height of the plot in inches.
3743
3744        label_size : int, default=10
3745            Font size for axis labels and ticks.
3746
3747        size_scale : int or float, default=100
3748            Scaling factor for bubble sizes.
3749
3750        y_lab : str, default='Genes'
3751            Label for the x-axis.
3752
3753        legend_lab : str, default='log(CPM + 1)'
3754            Label for the colorbar legend.
3755
3756        bbox_to_anchor_scale : int, default=25
3757            Vertical scale (percentage) for positioning the colorbar.
3758
3759        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3760            Anchor position for the size legend (percent bubble legend).
3761
3762        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
3763            Anchor position for the group legend.
3764
3765        Returns
3766        -------
3767        matplotlib.figure.Figure
3768            The generated scatter plot figure.
3769
3770        Notes
3771        -----
3772        Colors represent expression values normalized to the colormap.
3773        """
3774
3775        prtd, met = self.get_partial_data(
3776            names=names, features=features, name_slot=name_slot, inc_metadata=True
3777        )
3778
3779        prtd.columns = prtd.columns + "#" + met["sets"]
3780
3781        prtd_avg = average(prtd)
3782
3783        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]
3784
3785        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]
3786
3787        prtd_occ = occurrence(prtd)
3788
3789        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]
3790
3791        fig_scatter = features_scatter(
3792            expression_data=prtd_avg,
3793            occurence_data=prtd_occ,
3794            scale=scale,
3795            features=None,
3796            metadata_list=meta_sets,
3797            colors=colors,
3798            hclust=hclust,
3799            img_width=img_width,
3800            img_high=img_high,
3801            label_size=label_size,
3802            size_scale=size_scale,
3803            y_lab=y_lab,
3804            legend_lab=legend_lab,
3805            set_box_size=set_box_size,
3806            set_box_high=set_box_high,
3807            bbox_to_anchor_scale=bbox_to_anchor_scale,
3808            bbox_to_anchor_perc=bbox_to_anchor_perc,
3809            bbox_to_anchor_group=bbox_to_anchor_group,
3810        )
3811
3812        return fig_scatter
3813
3814    def data_composition(
3815        self,
3816        features_count: list | None,
3817        name_slot: str = "cell_names",
3818        set_sep: bool = True,
3819    ):
3820        """
3821        Compute composition of cell types in data set.
3822
3823        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
3824        within metadata entries, calculates their relative percentages, and stores
3825        the results in `self.composition_data`.
3826
3827        Parameters
3828        ----------
3829        features_count : list or None
3830            List of features (part or full names) to be counted.
3831            If None, all unique elements from the specified `name_slot` metadata field are used.
3832
3833        name_slot : str, default 'cell_names'
3834            Metadata field containing sample identifiers or labels.
3835
3836        set_sep : bool, default True
3837            If True and multiple sets are present in metadata, compute composition
3838            separately for each set.
3839
3840        Update
3841        -------
3842        Stores results in `self.composition_data` as a pandas DataFrame with:
3843        - 'name': feature name
3844        - 'n': number of occurrences
3845        - 'pct': percentage of occurrences
3846        - 'set' (if applicable): dataset identifier
3847        """
3848
3849        validated_list = list(self.input_metadata[name_slot])
3850        sets = list(self.input_metadata["sets"])
3851
3852        if features_count is None:
3853            features_count = list(set(self.input_metadata[name_slot]))
3854
3855        if set_sep and len(set(sets)) > 1:
3856
3857            final_res = pd.DataFrame()
3858
3859            for s in set(sets):
3860                print(s)
3861
3862                mask = [True if s == x else False for x in sets]
3863
3864                tmp_val_list = np.array(validated_list)
3865
3866                tmp_val_list = list(tmp_val_list[mask])
3867
3868                res_dict = {"name": [], "n": [], "set": []}
3869
3870                for f in tqdm(features_count):
3871                    res_dict["n"].append(
3872                        sum(1 for element in tmp_val_list if f in element)
3873                    )
3874                    res_dict["name"].append(f)
3875                    res_dict["set"].append(s)
3876                    res = pd.DataFrame(res_dict)
3877                    res["pct"] = res["n"] / sum(res["n"]) * 100
3878                    res["pct"] = res["pct"].round(2)
3879
3880                final_res = pd.concat([final_res, res])
3881
3882            res = final_res.sort_values(["set", "pct"], ascending=[True, False])
3883
3884        else:
3885
3886            res_dict = {"name": [], "n": []}
3887
3888            for f in tqdm(features_count):
3889                res_dict["n"].append(
3890                    sum(1 for element in validated_list if f in element)
3891                )
3892                res_dict["name"].append(f)
3893
3894            res = pd.DataFrame(res_dict)
3895            res["pct"] = res["n"] / sum(res["n"]) * 100
3896            res["pct"] = res["pct"].round(2)
3897
3898            res = res.sort_values("pct", ascending=False)
3899
3900        self.composition_data = res
3901
3902    def composition_pie(
3903        self,
3904        width=6,
3905        height=6,
3906        font_size=15,
3907        cmap: str = "tab20",
3908        legend_split_col: int = 1,
3909        offset_labels: float | int = 0.5,
3910        legend_bbox: tuple = (1.15, 0.95),
3911    ):
3912        """
3913        Visualize the composition of cell lineages using pie charts.
3914
3915        Generates pie charts showing the relative proportions of features stored
3916        in `self.composition_data`. If multiple sets are present, a separate
3917        chart is drawn for each set.
3918
3919        Parameters
3920        ----------
3921        width : int, default 6
3922            Width of the figure.
3923
3924        height : int, default 6
3925            Height of the figure (applied per set if multiple sets are plotted).
3926
3927        font_size : int, default 15
3928            Font size for labels and annotations.
3929
3930        cmap : str, default 'tab20'
3931            Colormap used for pie slices.
3932
3933        legend_split_col : int, default 1
3934            Number of columns in the legend.
3935
3936        offset_labels : float or int, default 0.5
3937            Spacing offset for label placement relative to pie slices.
3938
3939        legend_bbox : tuple, default (1.15, 0.95)
3940            Bounding box anchor position for the legend.
3941
3942        Returns
3943        -------
3944        matplotlib.figure.Figure
3945            Pie chart visualization of composition data.
3946        """
3947
3948        df = self.composition_data
3949
3950        if "set" in df.columns and len(set(df["set"])) > 1:
3951
3952            sets = list(set(df["set"]))
3953            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))
3954
3955            all_wedges = []
3956            cmap = plt.get_cmap(cmap)
3957
3958            set_nam = len(set(df["name"]))
3959
3960            legend_labels = list(set(df["name"]))
3961
3962            colors = [cmap(i / set_nam) for i in range(set_nam)]
3963
3964            cmap_dict = dict(zip(legend_labels, colors))
3965
3966            for idx, s in enumerate(sets):
3967                ax = axes[idx]
3968                tmp_df = df[df["set"] == s].reset_index(drop=True)
3969
3970                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]
3971
3972                wedges, _ = ax.pie(
3973                    tmp_df["n"],
3974                    startangle=90,
3975                    labeldistance=1.05,
3976                    colors=[cmap_dict[x] for x in tmp_df["name"]],
3977                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
3978                )
3979
3980                all_wedges.extend(wedges)
3981
3982                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
3983                n = 0
3984                for i, p in enumerate(wedges):
3985                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
3986                    y = np.sin(np.deg2rad(ang))
3987                    x = np.cos(np.deg2rad(ang))
3988                    horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
3989                    connectionstyle = f"angle,angleA=0,angleB={ang}"
3990                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
3991                    if len(labels[i]) > 0:
3992                        n += offset_labels
3993                        ax.annotate(
3994                            labels[i],
3995                            xy=(x, y),
3996                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
3997                            horizontalalignment=horizontalalignment,
3998                            fontsize=font_size,
3999                            weight="bold",
4000                            **kw,
4001                        )
4002
4003                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
4004                ax.add_artist(circle2)
4005
4006                ax.text(
4007                    0,
4008                    0,
4009                    f"{s}",
4010                    ha="center",
4011                    va="center",
4012                    fontsize=font_size,
4013                    weight="bold",
4014                )
4015
4016            legend_handles = [
4017                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4018                for label in legend_labels
4019            ]
4020
4021            fig.legend(
4022                handles=legend_handles,
4023                loc="center right",
4024                bbox_to_anchor=legend_bbox,
4025                ncol=legend_split_col,
4026                title="",
4027            )
4028
4029            plt.tight_layout()
4030            plt.show()
4031
4032        else:
4033
4034            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]
4035
4036            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]
4037
4038            cmap = plt.get_cmap(cmap)
4039            colors = [cmap(i / len(df)) for i in range(len(df))]
4040
4041            fig, ax = plt.subplots(
4042                figsize=(width, height), subplot_kw=dict(aspect="equal")
4043            )
4044
4045            wedges, _ = ax.pie(
4046                df["n"],
4047                startangle=90,
4048                labeldistance=1.05,
4049                colors=colors,
4050                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
4051            )
4052
4053            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
4054            n = 0
4055            for i, p in enumerate(wedges):
4056                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
4057                y = np.sin(np.deg2rad(ang))
4058                x = np.cos(np.deg2rad(ang))
4059                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
4060                connectionstyle = "angle,angleA=0,angleB={}".format(ang)
4061                kw["arrowprops"].update({"connectionstyle": connectionstyle})
4062                if len(labels[i]) > 0:
4063                    n += offset_labels
4064
4065                    ax.annotate(
4066                        labels[i],
4067                        xy=(x, y),
4068                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
4069                        horizontalalignment=horizontalalignment,
4070                        fontsize=font_size,
4071                        weight="bold",
4072                        **kw,
4073                    )
4074
4075            circle2 = plt.Circle((0, 0), 0.6, color="white")
4076            circle2.set_edgecolor("black")
4077
4078            p = plt.gcf()
4079            p.gca().add_artist(circle2)
4080
4081            ax.legend(
4082                wedges,
4083                legend_labels,
4084                title="",
4085                loc="center left",
4086                bbox_to_anchor=legend_bbox,
4087                ncol=legend_split_col,
4088            )
4089
4090            plt.show()
4091
4092        return fig
4093
4094    def bar_composition(
4095        self,
4096        cmap="tab20b",
4097        width=2,
4098        height=6,
4099        font_size=15,
4100        legend_split_col: int = 1,
4101        legend_bbox: tuple = (1.3, 1),
4102    ):
4103        """
4104        Visualize the composition of cell lineages using bar plots.
4105
4106        Produces bar plots showing the distribution of features stored in
4107        `self.composition_data`. If multiple sets are present, a separate
4108        bar is drawn for each set. Percentages are annotated alongside the bars.
4109
4110        Parameters
4111        ----------
4112        cmap : str, default 'tab20b'
4113            Colormap used for stacked bars.
4114
4115        width : int, default 2
4116            Width of each subplot (per set).
4117
4118        height : int, default 6
4119            Height of the figure.
4120
4121        font_size : int, default 15
4122            Font size for labels and annotations.
4123
4124        legend_split_col : int, default 1
4125            Number of columns in the legend.
4126
4127        legend_bbox : tuple, default (1.3, 1)
4128            Bounding box anchor position for the legend.
4129
4130        Returns
4131        -------
4132        matplotlib.figure.Figure
4133            Stacked bar plot visualization of composition data.
4134        """
4135
4136        df = self.composition_data
4137        df["num"] = range(1, len(df) + 1)
4138
4139        if "set" in df.columns and len(set(df["set"])) > 1:
4140
4141            sets = list(set(df["set"]))
4142            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))
4143
4144            cmap = plt.get_cmap(cmap)
4145
4146            set_nam = len(set(df["name"]))
4147
4148            legend_labels = list(set(df["name"]))
4149
4150            colors = [cmap(i / set_nam) for i in range(set_nam)]
4151
4152            cmap_dict = dict(zip(legend_labels, colors))
4153
4154            for idx, s in enumerate(sets):
4155                ax = axes[idx]
4156
4157                tmp_df = df[df["set"] == s].reset_index(drop=True)
4158
4159                values = tmp_df["n"].values
4160                total = sum(values)
4161                values = [v / total * 100 for v in values]
4162                values = [round(v, 2) for v in values]
4163
4164                idx_max = np.argmax(values)
4165                correction = 100 - sum(values)
4166                values[idx_max] += correction
4167
4168                names = tmp_df["name"].values
4169                perc = tmp_df["pct"].values
4170                nums = tmp_df["num"].values
4171
4172                bottom = 0
4173                centers = []
4174                for name, num, val, color in zip(names, nums, values, colors):
4175                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
4176                    centers.append(bottom + val / 2)
4177                    bottom += val
4178
4179                y_positions = np.linspace(centers[0], centers[-1], len(centers))
4180                x_text = -0.8
4181
4182                for y_label, y_center, pct, num in zip(
4183                    y_positions, centers, perc, nums
4184                ):
4185                    ax.annotate(
4186                        f"{pct:.1f}%",
4187                        xy=(0, y_center),
4188                        xycoords="data",
4189                        xytext=(x_text, y_label),
4190                        textcoords="data",
4191                        ha="right",
4192                        va="center",
4193                        fontsize=font_size,
4194                        arrowprops=dict(
4195                            arrowstyle="->",
4196                            lw=1,
4197                            color="black",
4198                            connectionstyle="angle3,angleA=0,angleB=90",
4199                        ),
4200                    )
4201
4202                ax.set_ylim(0, 100)
4203                ax.set_xlabel(s, fontsize=font_size)
4204                ax.xaxis.label.set_rotation(30)
4205
4206                ax.set_xticks([])
4207                ax.set_yticks([])
4208                for spine in ax.spines.values():
4209                    spine.set_visible(False)
4210
4211            legend_handles = [
4212                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4213                for label in legend_labels
4214            ]
4215
4216            fig.legend(
4217                handles=legend_handles,
4218                loc="center right",
4219                bbox_to_anchor=legend_bbox,
4220                ncol=legend_split_col,
4221                title="",
4222            )
4223
4224            plt.tight_layout()
4225            plt.show()
4226
4227        else:
4228
4229            cmap = plt.get_cmap(cmap)
4230
4231            colors = [cmap(i / len(df)) for i in range(len(df))]
4232
4233            fig, ax = plt.subplots(figsize=(width, height))
4234
4235            values = df["n"].values
4236            names = df["name"].values
4237            perc = df["pct"].values
4238            nums = df["num"].values
4239
4240            bottom = 0
4241            centers = []
4242            for name, num, val, color in zip(names, nums, values, colors):
4243                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
4244                centers.append(bottom + val / 2)
4245                bottom += val
4246
4247            y_positions = np.linspace(centers[0], centers[-1], len(centers))
4248            x_text = -0.8
4249
4250            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
4251                ax.annotate(
4252                    f"{num}) {pct}",
4253                    xy=(0, y_center),
4254                    xycoords="data",
4255                    xytext=(x_text, y_label),
4256                    textcoords="data",
4257                    ha="right",
4258                    va="center",
4259                    fontsize=9,
4260                    arrowprops=dict(
4261                        arrowstyle="->",
4262                        lw=1,
4263                        color="black",
4264                        connectionstyle="angle3,angleA=0,angleB=90",
4265                    ),
4266                )
4267
4268            ax.set_xticks([])
4269            ax.set_yticks([])
4270            for spine in ax.spines.values():
4271                spine.set_visible(False)
4272
4273            ax.legend(
4274                title="Legend",
4275                bbox_to_anchor=legend_bbox,
4276                loc="upper left",
4277                ncol=legend_split_col,
4278            )
4279
4280            plt.tight_layout()
4281            plt.show()
4282
4283        return fig
4284
4285    def cell_regression(
4286        self,
4287        cell_x: str,
4288        cell_y: str,
4289        set_x: str | None,
4290        set_y: str | None,
4291        threshold=10,
4292        image_width=12,
4293        image_high=7,
4294        color="black",
4295    ):
4296        """
4297        Perform regression analysis between two selected cells and visualize the relationship.
4298
4299        This function computes a linear regression between two specified cells from
4300        aggregated normalized data, plots the regression line with scatter points,
4301        annotates regression statistics, and highlights potential outliers.
4302
4303        Parameters
4304        ----------
4305        cell_x : str
4306            Name of the first cell (X-axis).
4307
4308        cell_y : str
4309            Name of the second cell (Y-axis).
4310
4311        set_x : str or None
4312            Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name.
4313
4314        set_y : str or None
4315            Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name.
4316
4317        threshold : int or float, default 10
4318            Threshold for detecting outliers. Points deviating from the mean or diagonal by more
4319            than this value are annotated.
4320
4321        image_width : int, default 12
4322            Width of the regression plot (in inches).
4323
4324        image_high : int, default 7
4325            Height of the regression plot (in inches).
4326
4327        color : str, default 'black'
4328            Color of the regression scatter points and line.
4329
4330        Returns
4331        -------
4332        matplotlib.figure.Figure
4333            Regression plot figure with annotated regression line, R², p-value, and outliers.
4334
4335        Raises
4336        ------
4337        ValueError
4338            If `cell_x` or `cell_y` are not found in the dataset.
4339            If multiple matches are found for a cell name and `set_x`/`set_y` are not specified.
4340
4341        Notes
4342        -----
4343        - The function automatically calls `jseq_object.average()` if aggregated data is not available.
4344        - Outliers are annotated with their corresponding index labels.
4345        - Regression is computed using `scipy.stats.linregress`.
4346
4347        Examples
4348        --------
4349        >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
4350        >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")
4351        """
4352
4353        if self.agg_normalized_data is None:
4354            self.average()
4355
4356        metadata = self.agg_metadata
4357        data = self.agg_normalized_data
4358
4359        if set_x is not None and set_y is not None:
4360            data.columns = metadata["cell_names"] + " # " + metadata["sets"]
4361            cell_x = cell_x + " # " + set_x
4362            cell_y = cell_y + " # " + set_y
4363
4364        else:
4365            data.columns = metadata["cell_names"]
4366
4367        if not cell_x in data.columns:
4368            raise ValueError("'cell_x' value not in cell names!")
4369
4370        if not cell_y in data.columns:
4371            raise ValueError("'cell_y' value not in cell names!")
4372
4373        if list(data.columns).count(cell_x) > 1:
4374            raise ValueError(
4375                f"'{cell_x}' occurs more than once. If you want to select a specific cell, "
4376                f"please also provide the corresponding 'set_x' and 'set_y' values."
4377            )
4378
4379        if list(data.columns).count(cell_y) > 1:
4380            raise ValueError(
4381                f"'{cell_y}' occurs more than once. If you want to select a specific cell, "
4382                f"please also provide the corresponding 'set_x' and 'set_y' values."
4383            )
4384
4385        fig, ax = plt.subplots(figsize=(image_width, image_high))
4386        ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color)
4387
4388        slope, intercept, r_value, p_value, _ = stats.linregress(
4389            data[cell_x], data[cell_y]
4390        )
4391        equation = "y = {:.2f}x + {:.2f}".format(slope, intercept)
4392
4393        ax.annotate(
4394            "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format(
4395                r_value**2, p_value, equation
4396            ),
4397            xy=(0.05, 0.90),
4398            xycoords="axes fraction",
4399            fontsize=12,
4400        )
4401
4402        ax.spines["top"].set_visible(False)
4403        ax.spines["right"].set_visible(False)
4404
4405        diff = []
4406        x_mean, y_mean = data[cell_x].mean(), data[cell_y].mean()
4407        for i, (xi, yi) in enumerate(zip(data[cell_x], data[cell_y])):
4408            diff.append(abs(xi - x_mean))
4409            diff.append(abs(yi - y_mean))
4410
4411        def annotate_outliers(x, y, threshold):
4412            texts = []
4413            x_mean, y_mean = x.mean(), y.mean()
4414            for i, (xi, yi) in enumerate(zip(x, y)):
4415                if (
4416                    abs(xi - x_mean) > threshold
4417                    or abs(yi - y_mean) > threshold
4418                    or abs(yi - xi) > threshold
4419                ):
4420                    text = ax.text(xi, yi, data.index[i])
4421                    texts.append(text)
4422
4423            return texts
4424
4425        texts = annotate_outliers(data[cell_x], data[cell_y], threshold)
4426
4427        adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))
4428
4429        plt.show()
4430
4431        return fig

COMPsc (Comparison of single-cell data) is a class designed for the integration, analysis, and visualization of single-cell datasets. It supports independent dataset integration, subclustering of existing clusters, marker detection, and multiple visualization strategies.

The COMPsc class provides methods for:

- Normalizing and filtering single-cell data
- Loading and saving sparse 10x-style datasets
- Computing differential expression and marker genes
- Clustering and subclustering analysis
- Visualizing similarity and spatial relationships
- Aggregating data by cell and set annotations
- Managing metadata and renaming labels
- Plotting gene detection histograms and feature scatters

Methods

project_dir(path_to_directory, project_list) Scans a directory to create a COMPsc instance mapping project names to their paths.

save_project(name, path=os.getcwd()) Saves the COMPsc object to a pickle file on disk.

load_project(path) Loads a previously saved COMPsc object from a pickle file.

reduce_cols(reg, inc_set=False) Removes columns from data tables where column names contain a specified name or partial substring.

reduce_rows(reg, inc_set=False) Removes rows from data tables where column names contain a specified feature (gene) name.

get_data(set_info=False) Returns normalized data with optional set annotations in column names.

get_metadata() Returns the stored input metadata.

get_partial_data(names=None, features=None, name_slot='cell_names') Return a subset of the data by sample names and/or features.

gene_calculation() Calculates and stores per-cell gene detection counts as a pandas Series.

gene_histograme(bins=100) Plots a histogram of genes detected per cell with an overlaid normal distribution.

gene_threshold(min_n=None, max_n=None) Filters cells based on minimum and/or maximum gene detection thresholds.

load_sparse_from_projects(normalized_data=False) Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.

rename_names(mapping, slot='cell_names') Renames entries in a specified metadata column using a provided mapping dictionary.

rename_subclusters(mapping) Renames subcluster labels using a provided mapping dictionary.

save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized') Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

normalize_data(normalize=True, normalize_factor=100000) Normalizes raw counts to counts-per-specified factor (e.g., CPM-like).

statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10) Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.

calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False) Computes and caches differential markers using the statistic method.

clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4) Prepares clustering input by selecting marker features and optionally smoothing cell values.

average() Aggregates normalized data by averaging across (cell_name, set) pairs.

estimating_similarity(method='pearson', p_val=0.05, top_n=25) Computes pairwise correlation and Euclidean distance between aggregated samples.

similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10) Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.

spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...) Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.

subcluster_prepare(features, cluster) Initializes a Clustering object for subcluster analysis on a selected parent cluster.

define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...) Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels.

subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...) Visualizes averaged expression and occurrence of features for subclusters as a scatter plot.

subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...) Plots top differential features for subclusters as a features-scatter visualization.

accept_subclusters() Commits subcluster labels to main metadata by renaming cell names and clears subcluster data.

Raises

ValueError For invalid parameters, mismatched dimensions, or missing metadata.

COMPsc(objects=None)
    def __init__(
        self,
        objects=None,
    ):
        """
        Initialize the COMPsc class for single-cell data integration and analysis.

        All analysis slots start as ``None`` and are populated lazily by the
        corresponding methods (e.g. ``normalize_data``, ``average``,
        ``estimating_similarity``).

        Parameters
        ----------
        objects : list or None, optional
            Optional list of data objects to initialize the instance with.

        Attributes
        ----------
        objects : list or None
        input_data : pandas.DataFrame or None
        input_metadata : pandas.DataFrame or None
        normalized_data : pandas.DataFrame or None
        agg_metadata : pandas.DataFrame or None
        agg_normalized_data : pandas.DataFrame or None
        similarity : pandas.DataFrame or None
        var_data : pandas.DataFrame or None
        subclusters_ : instance of Clustering class or None
        cells_calc : pandas.Series or None
        gene_calc : pandas.Series or None
        composition_data : pandas.DataFrame or None
        """

        self.objects = objects
        """Stores the input data objects."""

        self.input_data = None
        """Raw input data for clustering or integration analysis."""

        self.input_metadata = None
        """Metadata associated with the input data."""

        self.normalized_data = None
        """Normalized version of the input data."""

        self.agg_metadata = None
        """Aggregated metadata for all sets in object related to "agg_normalized_data"."""

        self.agg_normalized_data = None
        """Aggregated and normalized data across multiple sets."""

        self.similarity = None
        """Similarity data between cells across all samples and sets."""

        self.var_data = None
        """DEG analysis results summarizing variance across all samples in the object."""

        self.subclusters_ = None
        """Placeholder for information about subclusters analysis; if computed."""

        self.cells_calc = None
        """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition."""

        self.gene_calc = None
        """Number of genes detected per sample (cell), reflecting the sequencing depth."""

        self.composition_data = None
        """Data describing composition of cells across clusters or sets."""

Initialize the COMPsc class for single-cell data integration and analysis.

Parameters

objects : list or None, optional Optional list of data objects to initialize the instance with.

Attributes

-objects : list or None -input_data : pandas.DataFrame or None -input_metadata : pandas.DataFrame or None -normalized_data : pandas.DataFrame or None -agg_metadata : pandas.DataFrame or None -agg_normalized_data : pandas.DataFrame or None -similarity : pandas.DataFrame or None -var_data : pandas.DataFrame or None -subclusters_ : instance of Clustering class or None -cells_calc : pandas.Series or None -gene_calc : pandas.Series or None -composition_data : pandas.DataFrame or None

objects

Stores the input data objects.

input_data

Raw input data for clustering or integration analysis.

input_metadata

Metadata associated with the input data.

normalized_data

Normalized version of the input data.

agg_metadata

Aggregated metadata for all sets in object related to "agg_normalized_data"

agg_normalized_data

Aggregated and normalized data across multiple sets.

similarity

Similarity data between cells across all samples and sets.

var_data

DEG analysis results summarizing variance across all samples in the object.

subclusters_

Placeholder for information about subclusters analysis; if computed.

cells_calc

Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition.

gene_calc

Number of genes detected per sample (cell), reflecting the sequencing depth.

composition_data

Data describing composition of cells across clusters or sets.

@classmethod
def project_dir(cls, path_to_directory, project_list):
1208    @classmethod
1209    def project_dir(cls, path_to_directory, project_list):
1210        """
1211        Scan a directory and build a COMPsc instance mapping provided project names
1212        to their paths.
1213
1214        Parameters
1215        ----------
1216        path_to_directory : str
1217            Path containing project subfolders.
1218
1219        project_list : list[str]
1220            List of filenames (folder names) to include in the returned object map.
1221
1222        Returns
1223        -------
1224        COMPsc
1225            New COMPsc instance with `objects` populated.
1226
1227        Raises
1228        ------
1229        Exception
1230            A generic exception is caught and a message printed if scanning fails.
1231
1232        Notes
1233        -----
1234        Function attempts to match entries in `project_list` to directory
1235        names and constructs a simplified object key from the folder name.
1236        """
1237        try:
1238            objects = {}
1239            for filename in tqdm(os.listdir(path_to_directory)):
1240                for c in project_list:
1241                    f = os.path.join(path_to_directory, filename)
1242                    if c == filename and os.path.isdir(f):
1243                        objects[str(c)] = f
1244
1245            return cls(objects)
1246
1247        except:
1248            print("Something went wrong. Check the function input data and try again!")

Scan a directory and build a COMPsc instance mapping provided project names to their paths.

Parameters

path_to_directory : str Path containing project subfolders.

project_list : list[str] List of filenames (folder names) to include in the returned object map.

Returns

COMPsc New COMPsc instance with objects populated.

Raises

Exception A generic exception is caught and a message printed if scanning fails.

Notes

Function attempts to match entries in project_list to directory names and constructs a simplified object key from the folder name.

def save_project(self, name, path: str = os.getcwd()):
1250    def save_project(self, name, path: str = os.getcwd()):
1251        """
1252        Save the COMPsc object to disk using pickle.
1253
1254        Parameters
1255        ----------
1256        name : str
1257            Base filename (without extension) to use when saving.
1258
1259        path : str, default os.getcwd()
1260            Directory in which to save the project file.
1261
1262        Returns
1263        -------
1264        None
1265
1266        Side Effects
1267        ------------
1268        - Writes a file `<path>/<name>.jpkl` containing the pickled object.
1269        - Prints a confirmation message with saved path.
1270        """
1271
1272        full = os.path.join(path, f"{name}.jpkl")
1273
1274        with open(full, "wb") as f:
1275            pickle.dump(self, f)
1276
1277        print(f"Project saved as {full}")

Save the COMPsc object to disk using pickle.

Parameters

name : str Base filename (without extension) to use when saving.

path : str, default os.getcwd() Directory in which to save the project file.

Returns

None

Side Effects

  • Writes a file <path>/<name>.jpkl containing the pickled object.
  • Prints a confirmation message with saved path.
@classmethod
def load_project(cls, path):
1279    @classmethod
1280    def load_project(cls, path):
1281        """
1282        Load a previously saved COMPsc project from a pickle file.
1283
1284        Parameters
1285        ----------
1286        path : str
1287            Full path to the pickled project file.
1288
1289        Returns
1290        -------
1291        COMPsc
1292            The unpickled COMPsc object.
1293
1294        Raises
1295        ------
1296        FileNotFoundError
1297            If the provided path does not exist.
1298        """
1299
1300        if not os.path.exists(path):
1301            raise FileNotFoundError("File does not exist!")
1302        with open(path, "rb") as f:
1303            obj = pickle.load(f)
1304        return obj

Load a previously saved COMPsc project from a pickle file.

Parameters

path : str Full path to the pickled project file.

Returns

COMPsc The unpickled COMPsc object.

Raises

FileNotFoundError If the provided path does not exist.

def reduce_cols( self, reg: str | None = None, full: str | None = None, name_slot: str = 'cell_names', inc_set: bool = False):
1306    def reduce_cols(
1307        self,
1308        reg: str | None = None,
1309        full: str | None = None,
1310        name_slot: str = "cell_names",
1311        inc_set: bool = False,
1312    ):
1313        """
1314        Remove columns (cells) whose names contain a substring `reg` or
1315        full name `full` from available tables.
1316
1317        Parameters
1318        ----------
1319        reg : str | None
1320            Substring to search for in column/cell names; matching columns will be removed.
1321            If not None, `full` must be None.
1322
1323        full : str | None
1324            Full name to search for in column/cell names; matching columns will be removed.
1325            If not None, `reg` must be None.
1326
1327        name_slot : str, default 'cell_names'
1328            Column in metadata to use as sample names.
1329
1330        inc_set : bool, default False
1331            If True, column names are interpreted as 'cell_name # set' when matching.
1332
1333        Update
1334        ------------
1335        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
1336        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
1337        removing columns/rows that match `reg`.
1338
1339        Raises
1340        ------
1341        Raises ValueError if nothing matches the reduction mask.
1342        """
1343
1344        if reg is None and full is None:
1345            raise ValueError(
1346                "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!"
1347            )
1348
1349        if reg is not None and full is not None:
1350            raise ValueError(
1351                "Both 'reg' and 'full' arguments are provided. "
1352                "Please provide only one of them!\n"
1353                "'reg' is used when only part of the name must be detected.\n"
1354                "'full' is used if the full name must be detected."
1355            )
1356
1357        if reg is not None:
1358
1359            if self.input_data is not None:
1360
1361                if inc_set:
1362
1363                    self.input_data.columns = (
1364                        self.input_metadata[name_slot]
1365                        + " # "
1366                        + self.input_metadata["sets"]
1367                    )
1368
1369                else:
1370
1371                    self.input_data.columns = self.input_metadata[name_slot]
1372
1373                mask = [reg.upper() not in x.upper() for x in self.input_data.columns]
1374
1375                if len([y for y in mask if y is False]) == 0:
1376                    raise ValueError("Nothing found to reduce")
1377
1378                self.input_data = self.input_data.loc[:, mask]
1379
1380            if self.normalized_data is not None:
1381
1382                if inc_set:
1383
1384                    self.normalized_data.columns = (
1385                        self.input_metadata[name_slot]
1386                        + " # "
1387                        + self.input_metadata["sets"]
1388                    )
1389
1390                else:
1391
1392                    self.normalized_data.columns = self.input_metadata[name_slot]
1393
1394                mask = [
1395                    reg.upper() not in x.upper() for x in self.normalized_data.columns
1396                ]
1397
1398                if len([y for y in mask if y is False]) == 0:
1399                    raise ValueError("Nothing found to reduce")
1400
1401                self.normalized_data = self.normalized_data.loc[:, mask]
1402
1403            if self.input_metadata is not None:
1404
1405                if inc_set:
1406
1407                    self.input_metadata["drop"] = (
1408                        self.input_metadata[name_slot]
1409                        + " # "
1410                        + self.input_metadata["sets"]
1411                    )
1412
1413                else:
1414
1415                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1416
1417                mask = [
1418                    reg.upper() not in x.upper() for x in self.input_metadata["drop"]
1419                ]
1420
1421                if len([y for y in mask if y is False]) == 0:
1422                    raise ValueError("Nothing found to reduce")
1423
1424                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1425                    drop=True
1426                )
1427
1428                self.input_metadata = self.input_metadata.drop(
1429                    columns=["drop"], errors="ignore"
1430                )
1431
1432            if self.agg_normalized_data is not None:
1433
1434                if inc_set:
1435
1436                    self.agg_normalized_data.columns = (
1437                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1438                    )
1439
1440                else:
1441
1442                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1443
1444                mask = [
1445                    reg.upper() not in x.upper()
1446                    for x in self.agg_normalized_data.columns
1447                ]
1448
1449                if len([y for y in mask if y is False]) == 0:
1450                    raise ValueError("Nothing found to reduce")
1451
1452                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1453
1454            if self.agg_metadata is not None:
1455
1456                if inc_set:
1457
1458                    self.agg_metadata["drop"] = (
1459                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1460                    )
1461
1462                else:
1463
1464                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1465
1466                mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]]
1467
1468                if len([y for y in mask if y is False]) == 0:
1469                    raise ValueError("Nothing found to reduce")
1470
1471                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1472                    drop=True
1473                )
1474
1475                self.agg_metadata = self.agg_metadata.drop(
1476                    columns=["drop"], errors="ignore"
1477                )
1478
1479        elif full is not None:
1480
1481            if self.input_data is not None:
1482
1483                if inc_set:
1484
1485                    self.input_data.columns = (
1486                        self.input_metadata[name_slot]
1487                        + " # "
1488                        + self.input_metadata["sets"]
1489                    )
1490
1491                    if "#" not in full:
1492
1493                        self.input_data.columns = self.input_metadata[name_slot]
1494
1495                        print(
1496                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1497                            "Only the names will be compared, without considering the set information."
1498                        )
1499
1500                else:
1501
1502                    self.input_data.columns = self.input_metadata[name_slot]
1503
1504                mask = [full.upper() != x.upper() for x in self.input_data.columns]
1505
1506                if len([y for y in mask if y is False]) == 0:
1507                    raise ValueError("Nothing found to reduce")
1508
1509                self.input_data = self.input_data.loc[:, mask]
1510
1511            if self.normalized_data is not None:
1512
1513                if inc_set:
1514
1515                    self.normalized_data.columns = (
1516                        self.input_metadata[name_slot]
1517                        + " # "
1518                        + self.input_metadata["sets"]
1519                    )
1520
1521                    if "#" not in full:
1522
1523                        self.normalized_data.columns = self.input_metadata[name_slot]
1524
1525                        print(
1526                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1527                            "Only the names will be compared, without considering the set information."
1528                        )
1529
1530                else:
1531
1532                    self.normalized_data.columns = self.input_metadata[name_slot]
1533
1534                mask = [full.upper() != x.upper() for x in self.normalized_data.columns]
1535
1536                if len([y for y in mask if y is False]) == 0:
1537                    raise ValueError("Nothing found to reduce")
1538
1539                self.normalized_data = self.normalized_data.loc[:, mask]
1540
1541            if self.input_metadata is not None:
1542
1543                if inc_set:
1544
1545                    self.input_metadata["drop"] = (
1546                        self.input_metadata[name_slot]
1547                        + " # "
1548                        + self.input_metadata["sets"]
1549                    )
1550
1551                    if "#" not in full:
1552
1553                        self.input_metadata["drop"] = self.input_metadata[name_slot]
1554
1555                        print(
1556                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1557                            "Only the names will be compared, without considering the set information."
1558                        )
1559
1560                else:
1561
1562                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1563
1564                mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]]
1565
1566                if len([y for y in mask if y is False]) == 0:
1567                    raise ValueError("Nothing found to reduce")
1568
1569                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1570                    drop=True
1571                )
1572
1573                self.input_metadata = self.input_metadata.drop(
1574                    columns=["drop"], errors="ignore"
1575                )
1576
1577            if self.agg_normalized_data is not None:
1578
1579                if inc_set:
1580
1581                    self.agg_normalized_data.columns = (
1582                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1583                    )
1584
1585                    if "#" not in full:
1586
1587                        self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1588
1589                        print(
1590                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1591                            "Only the names will be compared, without considering the set information."
1592                        )
1593                else:
1594
1595                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1596
1597                mask = [
1598                    full.upper() != x.upper() for x in self.agg_normalized_data.columns
1599                ]
1600
1601                if len([y for y in mask if y is False]) == 0:
1602                    raise ValueError("Nothing found to reduce")
1603
1604                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1605
1606            if self.agg_metadata is not None:
1607
1608                if inc_set:
1609
1610                    self.agg_metadata["drop"] = (
1611                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1612                    )
1613
1614                    if "#" not in full:
1615
1616                        self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1617
1618                        print(
1619                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1620                            "Only the names will be compared, without considering the set information."
1621                        )
1622                else:
1623
1624                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1625
1626                mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]]
1627
1628                if len([y for y in mask if y is False]) == 0:
1629                    raise ValueError("Nothing found to reduce")
1630
1631                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1632                    drop=True
1633                )
1634
1635                self.agg_metadata = self.agg_metadata.drop(
1636                    columns=["drop"], errors="ignore"
1637                )
1638
1639        self.gene_calculation()
1640        self.cells_calculation()

Remove columns (cells) whose names contain a substring reg or full name full from available tables.

Parameters

reg : str | None Substring to search for in column/cell names; matching columns will be removed. If not None, full must be None.

full : str | None Full name to search for in column/cell names; matching columns will be removed. If not None, reg must be None.

name_slot : str, default 'cell_names' Column in metadata to use as sample names.

inc_set : bool, default False If True, column names are interpreted as 'cell_name # set' when matching.

Update

Mutates self.input_data, self.normalized_data, self.input_metadata, self.agg_normalized_data, and self.agg_metadata (if they exist), removing columns/rows that match reg.

Raises

Raises ValueError if nothing matches the reduction mask.

def reduce_rows(self, features_list: list):
def reduce_rows(self, features_list: list):
    """
    Remove rows (features) whose names are included in `features_list`.

    Parameters
    ----------
    features_list : list
        List of features to search for in index/gene names; matching
        entries will be removed (case-insensitive).

    Update
    ------
    Mutates `self.input_data`, `self.normalized_data`, and
    `self.agg_normalized_data` (if they exist), removing rows whose names
    are found in `features_list`.

    Raises
    ------
    ValueError
        If nothing matches the reduction mask in a processed table.

    Side Effects
    ------------
    Prints a message listing features that were not found in the data
    (based on the last table processed).
    """

    # BUG FIX: `res` was referenced after the per-table blocks even when every
    # table was None, raising NameError; it is now initialized up front.
    res = None

    if self.input_data is not None:
        res = find_features(self.input_data, features=features_list)
        to_remove = {x.upper() for x in res["included"]}
        keep = [x.upper() not in to_remove for x in self.input_data.index]
        if all(keep):
            raise ValueError("Nothing found to reduce")
        self.input_data = self.input_data.loc[keep, :]

    if self.normalized_data is not None:
        res = find_features(self.normalized_data, features=features_list)
        to_remove = {x.upper() for x in res["included"]}
        keep = [x.upper() not in to_remove for x in self.normalized_data.index]
        if all(keep):
            raise ValueError("Nothing found to reduce")
        self.normalized_data = self.normalized_data.loc[keep, :]

    if self.agg_normalized_data is not None:
        res = find_features(self.agg_normalized_data, features=features_list)
        to_remove = {x.upper() for x in res["included"]}
        keep = [x.upper() not in to_remove for x in self.agg_normalized_data.index]
        if all(keep):
            raise ValueError("Nothing found to reduce")
        self.agg_normalized_data = self.agg_normalized_data.loc[keep, :]

    if res is not None and len(res["not_included"]) > 0:
        print("\nFeatures not found:")
        for feature in res["not_included"]:
            print(feature)

    # Refresh the derived per-gene and per-cell statistics.
    self.gene_calculation()
    self.cells_calculation()

Remove rows (features) whose names are included in features_list.

Parameters

features_list : list List of features to search for in index/gene names; matching entries will be removed.

Update

Mutates self.input_data, self.normalized_data, and self.agg_normalized_data (if they exist), removing rows (features) whose names appear in features_list.

Raises

Prints a message listing features that are not found in the data.

def get_data(self, set_info: bool = False):
def get_data(self, set_info: bool = False):
    """
    Return normalized data with optional set annotation appended to column names.

    Parameters
    ----------
    set_info : bool, default False
        If True, column names are returned as "cell_name # set"; otherwise
        only the `cell_name` is used.

    Returns
    -------
    pandas.DataFrame
        A copy of `self.normalized_data` with columns renamed according to
        `set_info`.

    Raises
    ------
    AttributeError
        If `self.normalized_data` or `self.input_metadata` is missing.
    """

    # BUG FIX: work on a copy so renaming the columns does not silently
    # mutate the stored `self.normalized_data` (get_partial_data already
    # follows this pattern).
    to_return = self.normalized_data.copy()

    if set_info:
        to_return.columns = (
            self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
        )
    else:
        to_return.columns = self.input_metadata["cell_names"]

    return to_return

Return normalized data with optional set annotation appended to column names.

Parameters

set_info : bool, default False If True, column names are returned as "cell_name # set"; otherwise only the cell_name is used.

Returns

pandas.DataFrame The self.normalized_data table with columns renamed according to set_info.

Raises

AttributeError If self.normalized_data or self.input_metadata is missing.

def get_partial_data( self, names: list | str | None = None, features: list | str | None = None, name_slot: str = 'cell_names', inc_metadata: bool = False):
1741    def get_partial_data(
1742        self,
1743        names: list | str | None = None,
1744        features: list | str | None = None,
1745        name_slot: str = "cell_names",
1746        inc_metadata: bool = False,
1747    ):
1748        """
1749        Return a subset of the data filtered by sample names and/or feature names.
1750
1751        Parameters
1752        ----------
1753        names : list, str, or None
1754            Names of samples to include. If None, all samples are considered.
1755
1756        features : list, str, or None
1757            Names of features to include. If None, all features are considered.
1758
1759        name_slot : str
1760            Column in metadata to use as sample names.
1761
1762        inc_metadata : bool
1763            If True return tuple (data, metadata)
1764
1765        Returns
1766        -------
1767        pandas.DataFrame
1768            Subset of the normalized data based on the specified names and features.
1769        """
1770
1771        data = self.normalized_data.copy()
1772        metadata = self.input_metadata
1773
1774        if name_slot in self.input_metadata.columns:
1775            data.columns = self.input_metadata[name_slot]
1776        else:
1777            raise ValueError("'name_slot' not occured in data!'")
1778
1779        if isinstance(features, str):
1780            features = [features]
1781        elif features is None:
1782            features = []
1783
1784        if isinstance(names, str):
1785            names = [names]
1786        elif names is None:
1787            names = []
1788
1789        features = [x.upper() for x in features]
1790        names = [x.upper() for x in names]
1791
1792        columns_names = [x.upper() for x in data.columns]
1793        features_names = [x.upper() for x in data.index]
1794
1795        columns_bool = [True if x in names else False for x in columns_names]
1796        features_bool = [True if x in features else False for x in features_names]
1797
1798        if True not in columns_bool and True not in features_bool:
1799            print("Missing 'names' and/or 'features'. Returning full dataset instead.")
1800            
1801        if True in columns_bool:
1802            data = data.loc[:, columns_bool]
1803            metadata = metadata.loc[columns_bool, :]
1804
1805        if True in features_bool:
1806            data = data.loc[features_bool, :]
1807
1808        not_in_features = [y for y in features if y not in features_names]
1809
1810        if len(not_in_features) > 0:
1811            print('\nThe following features were not found in data:')
1812            print('\n'.join(not_in_features))
1813
1814        not_in_names = [y for y in names if y not in columns_names]
1815
1816        if len(not_in_names) > 0:
1817            print('\nThe following names were not found in data:')
1818            print('\n'.join(not_in_names))
1819
1820        if inc_metadata:
1821            return data, metadata
1822        else:
1823            return data

Return a subset of the data filtered by sample names and/or feature names.

Parameters

names : list, str, or None Names of samples to include. If None, all samples are considered.

features : list, str, or None Names of features to include. If None, all features are considered.

name_slot : str Column in metadata to use as sample names.

inc_metadata : bool If True return tuple (data, metadata)

Returns

pandas.DataFrame Subset of the normalized data based on the specified names and features.

def get_metadata(self):
def get_metadata(self):
    """
    Return the stored input metadata.

    Returns
    -------
    pandas.DataFrame
        `self.input_metadata` (may be None if not set).
    """
    return self.input_metadata

Return the stored input metadata.

Returns

pandas.DataFrame self.input_metadata (may be None if not set).

def gene_calculation(self):
def gene_calculation(self):
    """
    Calculate and store per-cell counts of detected genes.

    Binarizes the expression table (any value > 0 becomes 1) and sums over
    features (rows) to produce one count per cell (column).

    Update
    ------
    Sets `self.gene_calc` as a pandas.Series indexed by cell/column.

    Notes
    -----
    Uses `self.input_data` when available, otherwise `self.normalized_data`.
    """

    # BUG FIX: the original binarized `self.input_data.columns` (the column
    # Index object) instead of the data table itself; the elif branch shows
    # the intended pattern of copying the DataFrame.
    if self.input_data is not None:
        source = self.input_data
    elif self.normalized_data is not None:
        source = self.normalized_data
    else:
        return

    binary = source.copy()
    # Keep values <= 0 as-is (zeros contribute nothing); positives become 1.
    binary = binary.where(binary <= 0, 1)
    self.gene_calc = binary.sum(axis=0)

Calculate and store per-cell counts (e.g., number of detected genes).

The method computes a binary (presence/absence) per cell and sums across features to produce self.gene_calc.

Update

Sets self.gene_calc as a pandas.Series.

Side Effects

Uses self.input_data when available, otherwise self.normalized_data.

def gene_histograme(self, bins=100):
def gene_histograme(self, bins=100):
    """
    Plot a histogram of the number of genes detected per cell, with a
    fitted normal curve overlaid.

    Parameters
    ----------
    bins : int, default 100
        Number of histogram bins.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the histogram of gene contents.

    Notes
    -----
    Requires `self.gene_calc` to be computed prior to calling.
    """

    values = self.gene_calc
    lo, hi = min(values), max(values)

    fig, ax = plt.subplots(figsize=(8, 5))
    _, edges, _ = ax.hist(values, bins=bins, edgecolor="black", alpha=0.6)

    # Overlay a normal density rescaled to the histogram's count axis.
    mu = np.mean(values)
    sigma = np.std(values)
    grid = np.linspace(lo, hi, 1000)
    counts_curve = norm.pdf(grid, mu, sigma) * len(values) * (edges[1] - edges[0])
    ax.plot(
        grid,
        counts_curve,
        "r-",
        linewidth=2,
        label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})",
    )

    ax.set_xlabel("Value")
    ax.set_ylabel("Count")
    ax.set_title("Histogram of genes detected per cell")
    ax.set_xticks(np.linspace(lo, hi, 20))
    ax.tick_params(axis="x", rotation=90)
    ax.legend()

    return fig

Plot a histogram of the number of genes detected per cell.

Parameters

bins : int, default 100 Number of histogram bins.

Returns

matplotlib.figure.Figure Figure containing the histogram of gene contents.

Notes

Requires self.gene_calc to be computed prior to calling.

def gene_threshold(self, min_n: int | None, max_n: int | None):
1922    def gene_threshold(self, min_n: int | None, max_n: int | None):
1923        """
1924        Filter cells by gene-detection thresholds (min and/or max).
1925
1926        Parameters
1927        ----------
1928        min_n : int or None
1929            Minimum number of detected genes required to keep a cell.
1930
1931        max_n : int or None
1932            Maximum number of detected genes allowed to keep a cell.
1933
1934        Update
1935        -------
1936        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
1937        (and calls `average()` if `self.agg_normalized_data` exists).
1938
1939        Side Effects
1940        ------------
1941        Raises ValueError if both bounds are None or if filtering removes all cells.
1942        """
1943
1944        if min_n is not None and max_n is not None:
1945            mask = (self.gene_calc > min_n) & (self.gene_calc < max_n)
1946        elif min_n is None and max_n is not None:
1947            mask = self.gene_calc < max_n
1948        elif min_n is not None and max_n is None:
1949            mask = self.gene_calc > min_n
1950        else:
1951            raise ValueError("Lack of both min_n and max_n values")
1952
1953        if self.input_data is not None:
1954
1955            if len([y for y in mask if y is False]) == 0:
1956                raise ValueError("Nothing to reduce")
1957
1958            self.input_data = self.input_data.loc[:, mask.values]
1959
1960        if self.normalized_data is not None:
1961
1962            if len([y for y in mask if y is False]) == 0:
1963                raise ValueError("Nothing to reduce")
1964
1965            self.normalized_data = self.normalized_data.loc[:, mask.values]
1966
1967        if self.input_metadata is not None:
1968
1969            if len([y for y in mask if y is False]) == 0:
1970                raise ValueError("Nothing to reduce")
1971
1972            self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index(
1973                drop=True
1974            )
1975
1976            self.input_metadata = self.input_metadata.drop(
1977                columns=["drop"], errors="ignore"
1978            )
1979
1980        if self.agg_normalized_data is not None:
1981            self.average()
1982
1983        self.gene_calculation()
1984        self.cells_calculation()

Filter cells by gene-detection thresholds (min and/or max).

Parameters

min_n : int or None Minimum number of detected genes required to keep a cell.

max_n : int or None Maximum number of detected genes allowed to keep a cell.

Update

Filters self.input_data, self.normalized_data, self.input_metadata (and calls average() if self.agg_normalized_data exists).

Side Effects

Raises ValueError if both bounds are None or if filtering removes all cells.

def cells_calculation(self, name_slot='cell_names'):
def cells_calculation(self, name_slot="cell_names"):
    """
    Count the number of cells per cell name / cluster.

    Tallies how many cells carry each label found in the chosen metadata
    column.

    Parameters
    ----------
    name_slot : str, default 'cell_names'
        Column in metadata to use as sample names.

    Update
    ------
    Sets `self.cells_calc` as a pd.DataFrame with columns 'cluster' and 'n',
    sorted by descending count (value_counts order).
    """

    labels = list(self.input_metadata[name_slot])
    counts = pd.Series(labels).value_counts()
    self.cells_calc = pd.DataFrame({"cluster": counts.index, "n": counts.values})

Calculate the number of cells per cell name / cluster.

The method computes a binary (presence/absence) per cell name / cluster and sums across cells.

Parameters

name_slot : str, default 'cell_names' Column in metadata to use as sample names.

Update

Sets self.cells_calc as a pd.DataFrame.

def cell_histograme(self, name_slot: str = 'cell_names'):
def cell_histograme(self, name_slot: str = "cell_names"):
    """
    Plot a histogram of the number of cells detected per cell name (cluster),
    with a fitted normal curve overlaid.

    Parameters
    ----------
    name_slot : str, default 'cell_names'
        Column in metadata to use as sample names.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the histogram of cell contents.

    Notes
    -----
    Requires `self.cells_calc` to be computed prior to calling.
    """

    # Recompute counts only when a non-default grouping column is requested.
    if name_slot != "cell_names":
        self.cells_calculation(name_slot=name_slot)

    counts = list(self.cells_calc["n"])
    n_bins = len(set(self.cells_calc["cluster"]))
    lo, hi = min(counts), max(counts)

    fig, ax = plt.subplots(figsize=(8, 5))
    _, edges, _ = ax.hist(
        counts, bins=n_bins, edgecolor="black", color="orange", alpha=0.6
    )

    # Overlay a normal density rescaled to the histogram's count axis.
    mu = np.mean(counts)
    sigma = np.std(counts)
    grid = np.linspace(lo, hi, 1000)
    scaled = norm.pdf(grid, mu, sigma) * len(counts) * (edges[1] - edges[0])
    ax.plot(
        grid,
        scaled,
        "r-",
        linewidth=2,
        label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})",
    )

    ax.set_xlabel("Value")
    ax.set_ylabel("Count")
    ax.set_title("Histogram of cells detected per cell name / cluster")
    ax.set_xticks(np.linspace(lo, hi, 20))
    ax.tick_params(axis="x", rotation=90)
    ax.legend()

    return fig

Plot a histogram of the number of cells detected per cell name (cluster).

Parameters

name_slot : str, default 'cell_names' Column in metadata to use as sample names.

Returns

matplotlib.figure.Figure Figure containing the histogram of cell contents.

Notes

Requires self.cells_calc to be computed prior to calling.

def cluster_threshold(self, min_n: int | None, name_slot: str = 'cell_names'):
2076    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
2077        """
2078        Filter cell names / clusters by cell-detection threshold.
2079
2080        Parameters
2081        ----------
2082        min_n : int or None
2083            Minimum number of detected genes required to keep a cell.
2084
2085        name_slot : str, default 'cell_names'
2086            Column in metadata to use as sample names.
2087
2088
2089        Update
2090        -------
2091        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
2092        (and calls `average()` if `self.agg_normalized_data` exists).
2093        """
2094
2095        if name_slot != "cell_names":
2096            self.cells_calculation(name_slot=name_slot)
2097
2098        if min_n is not None:
2099            names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]
2100        else:
2101            raise ValueError("Lack of min_n value")
2102
2103        if len(names) > 0:
2104
2105            if self.input_data is not None:
2106
2107                self.input_data.columns = self.input_metadata[name_slot]
2108
2109                mask = [not any(r in x for r in names) for x in self.input_data.columns]
2110
2111                if len([y for y in mask if y is False]) > 0:
2112
2113                    self.input_data = self.input_data.loc[:, mask]
2114
2115            if self.normalized_data is not None:
2116
2117                self.normalized_data.columns = self.input_metadata[name_slot]
2118
2119                mask = [
2120                    not any(r in x for r in names) for x in self.normalized_data.columns
2121                ]
2122
2123                if len([y for y in mask if y is False]) > 0:
2124
2125                    self.normalized_data = self.normalized_data.loc[:, mask]
2126
2127            if self.input_metadata is not None:
2128
2129                self.input_metadata["drop"] = self.input_metadata[name_slot]
2130
2131                mask = [
2132                    not any(r in x for r in names) for x in self.input_metadata["drop"]
2133                ]
2134
2135                if len([y for y in mask if y is False]) > 0:
2136
2137                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
2138                        drop=True
2139                    )
2140
2141                self.input_metadata = self.input_metadata.drop(
2142                    columns=["drop"], errors="ignore"
2143                )
2144
2145            if self.agg_normalized_data is not None:
2146
2147                self.agg_normalized_data.columns = self.agg_metadata[name_slot]
2148
2149                mask = [
2150                    not any(r in x for r in names)
2151                    for x in self.agg_normalized_data.columns
2152                ]
2153
2154                if len([y for y in mask if y is False]) > 0:
2155
2156                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
2157
2158            if self.agg_metadata is not None:
2159
2160                self.agg_metadata["drop"] = self.agg_metadata[name_slot]
2161
2162                mask = [
2163                    not any(r in x for r in names) for x in self.agg_metadata["drop"]
2164                ]
2165
2166                if len([y for y in mask if y is False]) > 0:
2167
2168                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
2169                        drop=True
2170                    )
2171
2172                self.agg_metadata = self.agg_metadata.drop(
2173                    columns=["drop"], errors="ignore"
2174                )
2175
2176            self.gene_calculation()
2177            self.cells_calculation()

Filter cell names / clusters by cell-detection threshold.

Parameters

min_n : int or None Minimum number of detected cells required to keep a cell name / cluster.

name_slot : str, default 'cell_names' Column in metadata to use as sample names.

Update

Filters self.input_data, self.normalized_data, self.input_metadata, and the aggregated tables (self.agg_normalized_data, self.agg_metadata) when present.

def load_sparse_from_projects(self, normalized_data: bool = False):
2179    def load_sparse_from_projects(self, normalized_data: bool = False):
2180        """
2181        Load sparse 10x-style datasets from stored project paths, concatenate them,
2182        and populate `input_data` / `normalized_data` and `input_metadata`.
2183
2184        Parameters
2185        ----------
2186        normalized_data : bool, default False
2187            If True, store concatenated tables in `self.normalized_data`.
2188            If False, store them in `self.input_data` and normalization
2189            is needed using normalize_data() method.
2190
2191        Side Effects
2192        ------------
2193        - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv).
2194        - Concatenates all projects column-wise and sets `self.input_metadata`.
2195        - Replaces NaNs with zeros and updates `self.gene_calc`.
2196        """
2197
2198        obj = self.objects
2199
2200        full_data = pd.DataFrame()
2201        full_metadata = pd.DataFrame()
2202
2203        for ke in obj.keys():
2204            print(ke)
2205
2206            dt, met = load_sparse(path=obj[ke], name=ke)
2207
2208            full_data = pd.concat([full_data, dt], axis=1)
2209            full_metadata = pd.concat([full_metadata, met], axis=0)
2210
2211        full_data[np.isnan(full_data)] = 0
2212
2213        if normalized_data:
2214            self.normalized_data = full_data
2215            self.input_metadata = full_metadata
2216        else:
2217
2218            self.input_data = full_data
2219            self.input_metadata = full_metadata
2220
2221        self.gene_calculation()
2222        self.cells_calculation()

Load sparse 10x-style datasets from stored project paths, concatenate them, and populate input_data / normalized_data and input_metadata.

Parameters

normalized_data : bool, default False If True, store concatenated tables in self.normalized_data. If False, store them in self.input_data and normalization is needed using normalize_data() method.

Side Effects

  • Reads each project using load_sparse(...) (expects matrix.mtx, genes.tsv, barcodes.tsv).
  • Concatenates all projects column-wise and sets self.input_metadata.
  • Replaces NaNs with zeros and updates self.gene_calc.
def rename_names(self, mapping: dict, slot: str = 'cell_names'):
2224    def rename_names(self, mapping: dict, slot: str = "cell_names"):
2225        """
2226        Rename entries in `self.input_metadata[slot]` according to a provided mapping.
2227
2228        Parameters
2229        ----------
2230        mapping : dict
2231            Dictionary with keys 'old_name' and 'new_name', each mapping to a list
2232            of equal length describing replacements.
2233
2234        slot : str, default 'cell_names'
2235            Metadata column to operate on.
2236
2237        Update
2238        -------
2239        Updates `self.input_metadata[slot]` in-place with renamed values.
2240
2241        Raises
2242        ------
2243        ValueError
2244            If mapping keys are incorrect, lengths differ, or some 'old_name' values
2245            are not present in the metadata column.
2246        """
2247
2248        if set(["old_name", "new_name"]) != set(mapping.keys()):
2249            raise ValueError(
2250                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2251                "each with a list of names to change."
2252            )
2253
2254        if len(mapping["old_name"]) != len(mapping["new_name"]):
2255            raise ValueError(
2256                "Mapping dictionary lists 'old_name' and 'new_name' "
2257                "must have the same length!"
2258            )
2259
2260        names_vector = list(self.input_metadata[slot])
2261
2262        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2263            raise ValueError(
2264                f"Some entries from 'old_name' do not exist in the names of slot {slot}."
2265            )
2266
2267        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2268
2269        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2270
2271        self.input_metadata[slot] = names_vector_ret

Rename entries in self.input_metadata[slot] according to a provided mapping.

Parameters

mapping : dict Dictionary with keys 'old_name' and 'new_name', each mapping to a list of equal length describing replacements.

slot : str, default 'cell_names' Metadata column to operate on.

Update

Updates self.input_metadata[slot] in-place with renamed values.

Raises

ValueError If mapping keys are incorrect, lengths differ, or some 'old_name' values are not present in the metadata column.

def rename_subclusters(self, mapping):
2273    def rename_subclusters(self, mapping):
2274        """
2275        Rename labels stored in `self.subclusters_.subclusters` according to mapping.
2276
2277        Parameters
2278        ----------
2279        mapping : dict
2280            Mapping with keys 'old_name' and 'new_name' (lists of equal length).
2281
2282        Update
2283        -------
2284        Updates `self.subclusters_.subclusters` with renamed labels.
2285
2286        Raises
2287        ------
2288        ValueError
2289            If mapping is invalid or old names are not present.
2290        """
2291
2292        if set(["old_name", "new_name"]) != set(mapping.keys()):
2293            raise ValueError(
2294                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2295                "each with a list of names to change."
2296            )
2297
2298        if len(mapping["old_name"]) != len(mapping["new_name"]):
2299            raise ValueError(
2300                "Mapping dictionary lists 'old_name' and 'new_name' "
2301                "must have the same length!"
2302            )
2303
2304        names_vector = list(self.subclusters_.subclusters)
2305
2306        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2307            raise ValueError(
2308                "Some entries from 'old_name' do not exist in the subcluster names."
2309            )
2310
2311        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2312
2313        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2314
2315        self.subclusters_.subclusters = names_vector_ret

Rename labels stored in self.subclusters_.subclusters according to mapping.

Parameters

mapping : dict Mapping with keys 'old_name' and 'new_name' (lists of equal length).

Update

Updates self.subclusters_.subclusters with renamed labels.

Raises

ValueError If mapping is invalid or old names are not present.

def save_sparse( self, path_to_save: str = '/mnt/c/Users/merag/Git/JDti', name_slot: str = 'cell_names', data_slot: str = 'normalized'):
2317    def save_sparse(
2318        self,
2319        path_to_save: str = os.getcwd(),
2320        name_slot: str = "cell_names",
2321        data_slot: str = "normalized",
2322    ):
2323        """
2324        Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
2325
2326        Parameters
2327        ----------
2328        path_to_save : str, default current working directory
2329            Directory where files will be written.
2330
2331        name_slot : str, default 'cell_names'
2332            Metadata column providing cell names for barcodes.tsv.
2333
2334        data_slot : str, default 'normalized'
2335            Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).
2336
2337        Raises
2338        ------
2339        ValueError
2340            If `data_slot` is not 'normalized' or 'count'.
2341        """
2342
2343        names = self.input_metadata[name_slot]
2344
2345        if data_slot.lower() == "normalized":
2346
2347            features = list(self.normalized_data.index)
2348            mtx = sparse.csr_matrix(self.normalized_data)
2349
2350        elif data_slot.lower() == "count":
2351
2352            features = list(self.input_data.index)
2353            mtx = sparse.csr_matrix(self.input_data)
2354
2355        else:
2356            raise ValueError("'data_slot' must be included in 'normalized' or 'count'")
2357
2358        os.makedirs(path_to_save, exist_ok=True)
2359
2360        mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx)
2361
2362        pd.Series(names).to_csv(
2363            os.path.join(path_to_save, "barcodes.tsv"),
2364            index=False,
2365            header=False,
2366            sep="\t",
2367        )
2368
2369        pd.Series(features).to_csv(
2370            os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t"
2371        )

Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

Parameters

path_to_save : str, default current working directory Directory where files will be written.

name_slot : str, default 'cell_names' Metadata column providing cell names for barcodes.tsv.

data_slot : str, default 'normalized' Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).

Raises

ValueError If data_slot is not 'normalized' or 'count'.

def normalize_counts(self, normalize_factor: int = 100000, log_transform: bool = True):
2373    def normalize_counts(
2374        self, normalize_factor: int = 100000, log_transform: bool = True
2375    ):
2376        """
2377        Normalize raw counts to counts-per-(normalize_factor)
2378        (e.g., CPM, TPM - depending on normalize_factor).
2379
2380        Parameters
2381        ----------
2382        normalize_factor : int, default 100000
2383            Scaling factor used after dividing by column sums.
2384
2385        log_transform : bool, default True
2386            If True, apply log2(x+1) transformation to normalized values.
2387
2388        Update
2389        -------
2390            Sets `self.normalized_data` to normalized values (fills NaNs with 0).
2391
2392        Raises
2393        ------
2394        ValueError
2395            If `self.input_data` is missing (cannot normalize).
2396        """
2397        if self.input_data is None:
2398            raise ValueError("Input data is missing, cannot normalize.")
2399
2400        sum_col = self.input_data.sum()
2401        self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor
2402
2403        if log_transform:
2404            # log2(x + 1) to avoid -inf for zeros
2405            self.normalized_data = np.log2(self.normalized_data + 1)

Normalize raw counts to counts-per-(normalize_factor) (e.g., CPM, TPM - depending on normalize_factor).

Parameters

normalize_factor : int, default 100000 Scaling factor used after dividing by column sums.

log_transform : bool, default True If True, apply log2(x+1) transformation to normalized values.

Update

Sets `self.normalized_data` to normalized values (fills NaNs with 0).

Raises

ValueError If self.input_data is missing (cannot normalize).

def statistic( self, cells=None, sets=None, min_exp: float = 0.01, min_pct: float = 0.1, n_proc: int = 10):
    def statistic(
        self,
        cells=None,
        sets=None,
        min_exp: float = 0.01,
        min_pct: float = 0.1,
        n_proc: int = 10,
    ):
        """
        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.

        This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data`
        and `self.input_metadata`. It returns per-feature statistics including p-values,
        adjusted p-values, means, variances, effect-size measures and fold-changes.

        Parameters
        ----------
        cells : list, 'All', dict, or None
            Defines the target cells or groups for comparison:
            - list: the listed cells are compared against all remaining cells;
            - 'All': each unique "name # set" label is compared against the rest;
            - dict with exactly two keys: pairwise comparison (first key = target).

        sets : 'All', dict, or None
            Alternative grouping mode operating on `self.input_metadata['sets']`
            ('All' for one-vs-rest per set, two-key dict for pairwise).

        min_exp : float, default 0.01
            Minimum expression threshold used when filtering features.

        min_pct : float, default 0.1
            Minimum proportion of expressing cells in the target group required to test a feature.

        n_proc : int, default 10
            Number of parallel jobs to use.

        Returns
        -------
        pandas.DataFrame or dict
            Results DataFrame (or dict containing valid/control cells + DataFrame),
            similar to `calc_DEG` interface.

        Raises
        ------
        ValueError
            If neither `cells` nor `sets` is provided, or input metadata mismatch occurs.

        Notes
        -----
        Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
        """

        # Tiny additive constant used only to keep fold-change denominators
        # finite when no positive mean exists at all.
        offset = 1e-100

        def stat_calc(choose, feature_name):
            # One feature: split its column into target vs rest rows via the
            # 'DEG' label column attached by the caller.
            target_values = choose.loc[choose["DEG"] == "target", feature_name]
            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

            # Fraction of cells expressing (> 0) in each group.
            pct_valid = (target_values > 0).sum() / len(target_values)
            pct_rest = (rest_values > 0).sum() / len(rest_values)

            avg_valid = np.mean(target_values)
            avg_ctrl = np.mean(rest_values)
            sd_valid = np.std(target_values, ddof=1)
            sd_ctrl = np.std(rest_values, ddof=1)
            # Cohen's-d-style effect size with the two-group pooled SD.
            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

            # Equal group sums are treated as "no difference" and skip the test.
            # NOTE(review): equal sums do not strictly imply identical
            # distributions — confirm this shortcut is intentional.
            if np.sum(target_values) == np.sum(rest_values):
                p_val = 1.0
            else:
                _, p_val = stats.mannwhitneyu(
                    target_values, rest_values, alternative="two-sided"
                )

            return {
                "feature": feature_name,
                "p_val": p_val,
                "pct_valid": pct_valid,
                "pct_ctrl": pct_rest,
                "avg_valid": avg_valid,
                "avg_ctrl": avg_ctrl,
                "sd_valid": sd_valid,
                "sd_ctrl": sd_ctrl,
                "esm": esm,
            }

        def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc):

            def safe_min_half(series):
                # Half of the smallest "real" positive value, ignoring NaNs and
                # subnormal near-zero values; 0 when no such value exists.
                filtered = series[(series > ((2**-1074)*2)) & (series.notna())]
                return filtered.min() / 2 if not filtered.empty else 0

            # Percentage of target cells expressing each feature above min_exp.
            tmp_dat = choose[choose["DEG"] == "target"]
            tmp_dat = tmp_dat.drop("DEG", axis=1)

            counts = (tmp_dat > min_exp).sum(axis=0)

            total_count = tmp_dat.shape[0]

            info = pd.DataFrame(
                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
            )

            del tmp_dat

            # Drop features below the min_pct cutoff ...
            drop_col = info["feature"][info["pct"] <= min_pct]

            # ... unless that would drop every feature (the +1 accounts for the
            # 'DEG' column); then fall back to dropping only never-expressed ones.
            if len(drop_col) + 1 == len(choose.columns):
                drop_col = info["feature"][info["pct"] == 0]

            del info

            choose = choose.drop(list(drop_col), axis=1)

            # Test the surviving features in parallel (one job per feature).
            results = Parallel(n_jobs=n_proc)(
                delayed(stat_calc)(choose, feature)
                for feature in tqdm(choose.columns[choose.columns != "DEG"])
            )

            df = pd.DataFrame(results)
            # Keep features expressed in at least one of the two groups.
            df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)]

            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # BH-style adjustment: p * m / rank on the p-sorted frame, capped
            # at 1. NOTE(review): lacks the cumulative-min monotonicity step of
            # full Benjamini–Hochberg — confirm this is intended.
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            # Pseudocount for fold-change: half the smallest positive mean seen
            # in either group, falling back to `offset` if none exists.
            valid_factor = safe_min_half(df["avg_valid"])
            ctrl_factor = safe_min_half(df["avg_ctrl"])

            cv_factor = min(valid_factor, ctrl_factor)

            if cv_factor == 0:
                cv_factor = max(valid_factor, ctrl_factor)

            if not np.isfinite(cv_factor) or cv_factor == 0:
                cv_factor += offset

            # Replace zero means by the pseudocount so FC is always defined.
            valid = df["avg_valid"].where(
                df["avg_valid"] != 0, df["avg_valid"] + cv_factor
            )
            ctrl = df["avg_ctrl"].where(
                df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor
            )

            df["FC"] = valid / ctrl

            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]

            return df

        # Cells become rows, features become columns.
        choose = self.normalized_data.copy().T

        final_results = []

        # Mode 1: explicit list of target cells vs everything else.
        if isinstance(cells, list) and sets is None:
            print("\nAnalysis started...\nComparing selected cells to the whole set...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            # Entries without 'name # set' syntax are matched on name alone.
            if "#" not in cells[0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "Not include the set info (name # set) in the 'cells' list. "
                    "Only the names will be compared, without considering the set information."
                )

            labels = ["target" if idx in cells else "rest" for idx in choose.index]
            valid = list(
                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
            )

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=valid,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}

        # Mode 2: every 'name # set' label, one-vs-rest.
        elif cells == "All" and sets is None:
            print("\nAnalysis started...\nComparing each type of cell to others...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
            unique_labels = set(choose.index)

            for label in tqdm(unique_labels):
                print(f"\nCalculating statistics for {label}")
                # NOTE(review): the 'DEG' column persists on `choose` across
                # iterations; it is overwritten each pass and excluded from
                # testing, so results are unaffected.
                labels = ["target" if idx == label else "rest" for idx in choose.index]
                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        # Mode 3: every set, one-vs-rest.
        elif cells is None and sets == "All":
            print("\nAnalysis started...\nComparing each set/group to others...")
            choose.index = self.input_metadata["sets"]
            unique_sets = set(choose.index)

            for label in tqdm(unique_sets):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]

                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        # Mode 4: pairwise comparison of two groups of sets.
        elif cells is None and isinstance(sets, dict):
            print("\nAnalysis started...\nComparing groups...")

            choose.index = self.input_metadata["sets"]

            group_list = list(sets.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            # Cells in neither group are labelled 'drop' and removed below.
            labels = [
                (
                    "target"
                    if idx in sets[group_list[0]]
                    else "rest" if idx in sets[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            return result_df

        # Mode 5: pairwise comparison of two groups of cells.
        elif isinstance(cells, dict) and sets is None:
            print("\nAnalysis started...\nComparing groups...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            # Same name-only fallback as in the list mode, checked on the
            # first entry of the first group.
            if "#" not in cells[list(cells.keys())[0]][0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "Not include the set info (name # set) in the 'cells' dict. "
                    "Only the names will be compared, without considering the set information."
                )

            group_list = list(cells.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            labels = [
                (
                    "target"
                    if idx in cells[group_list[0]]
                    else "rest" if idx in cells[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )

            return result_df.reset_index(drop=True)

        else:
            raise ValueError(
                "You must specify either 'cells' or 'sets' (or both). None were provided, which is not allowed for this analysis."
            )

Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.

This is a wrapper similar to calc_DEG tailored to use self.normalized_data and self.input_metadata. It returns per-feature statistics including p-values, adjusted p-values, means, variances, effect-size measures and fold-changes.

Parameters

cells : list, 'All', dict, or None Defines the target cells or groups for comparison (several modes supported).

sets : 'All', dict, or None Alternative grouping mode (operate on self.input_metadata['sets']).

min_exp : float, default 0.01 Minimum expression threshold used when filtering features.

min_pct : float, default 0.1 Minimum proportion of expressing cells in the target group required to test a feature.

n_proc : int, default 10 Number of parallel jobs to use.

Returns

pandas.DataFrame or dict Results DataFrame (or dict containing valid/control cells + DataFrame), similar to calc_DEG interface.

Raises

ValueError If neither cells nor sets is provided, or input metadata mismatch occurs.

Notes

Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.

def calculate_difference_markers(self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False):
2714    def calculate_difference_markers(
2715        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
2716    ):
2717        """
2718        Compute differential markers (var_data) if not already present.
2719
2720        Parameters
2721        ----------
2722        min_exp : float, default 0
2723            Minimum expression threshold passed to `statistic`.
2724
2725        min_pct : float, default 0.25
2726            Minimum percent expressed in target group.
2727
2728        n_proc : int, default 10
2729            Parallel jobs.
2730
2731        force : bool, default False
2732            If True, recompute even if `self.var_data` is present.
2733
2734        Update
2735        -------
2736        Sets `self.var_data` to the result of `self.statistic(...)`.
2737
2738        Raise
2739        ------
2740        ValueError if already computed and `force` is False.
2741        """
2742
2743        if self.var_data is None or force:
2744
2745            self.var_data = self.statistic(
2746                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
2747            )
2748
2749        else:
2750            raise ValueError(
2751                "self.calculate_difference_markers() has already been executed. "
2752                "The results are stored in self.var. "
2753                "If you want to recalculate with different statistics, please rerun the method with force=True."
2754            )

Compute differential markers (var_data) if not already present.

Parameters

min_exp : float, default 0 Minimum expression threshold passed to statistic.

min_pct : float, default 0.25 Minimum percent expressed in target group.

n_proc : int, default 10 Parallel jobs.

force : bool, default False If True, recompute even if self.var_data is present.

Update

Sets self.var_data to the result of self.statistic(...).

Raise

ValueError if already computed and force is False.

def clustering_features( self, features_list: list | None, name_slot: str = 'cell_names', p_val: float = 0.05, top_n: int = 25, adj_mean: bool = True, beta: float = 0.2):
2756    def clustering_features(
2757        self,
2758        features_list: list | None,
2759        name_slot: str = "cell_names",
2760        p_val: float = 0.05,
2761        top_n: int = 25,
2762        adj_mean: bool = True,
2763        beta: float = 0.2,
2764    ):
2765        """
2766        Prepare clustering input by selecting marker features and optionally smoothing cell values
2767        toward group means.
2768
2769        Parameters
2770        ----------
2771        features_list : list or None
2772            If provided, use this list of features. If None, features are selected
2773            from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group.
2774
2775        name_slot : str, default 'cell_names'
2776            Metadata column used for naming.
2777
2778        p_val : float, default 0.05
2779            Adjusted p-value cutoff when selecting features automatically.
2780
2781        top_n : int, default 25
2782            Number of top features per valid group to keep if `features_list` is None.
2783
2784        adj_mean : bool, default True
2785            If True, adjust cell values toward group means using `beta`.
2786
2787        beta : float, default 0.2
2788            Adjustment strength toward group mean.
2789
2790        Update
2791        ------
2792        Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset,
2793        ready for PCA/UMAP/clustering.
2794        """
2795
2796        if features_list is None or len(features_list) == 0:
2797
2798            if self.var_data is None:
2799                raise ValueError(
2800                    "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2801                )
2802
2803            df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2804            df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2805            df_tmp = (
2806                df_tmp.sort_values(
2807                    ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2808                )
2809                .groupby("valid_group")
2810                .head(top_n)
2811            )
2812
2813            feaures_list = list(set(df_tmp["feature"]))
2814
2815        data = self.get_partial_data(
2816            names=None, features=feaures_list, name_slot=name_slot
2817        )
2818        data_avg = average(data)
2819
2820        if adj_mean:
2821            data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta)
2822
2823        self.clustering_data = data
2824
2825        self.clustering_metadata = self.input_metadata

Prepare clustering input by selecting marker features and optionally smoothing cell values toward group means.

Parameters

features_list : list or None If provided, use this list of features. If None, features are selected from self.var_data (adj_pval <= p_val, positive logFC) picking top_n per group.

name_slot : str, default 'cell_names' Metadata column used for naming.

p_val : float, default 0.05 Adjusted p-value cutoff when selecting features automatically.

top_n : int, default 25 Number of top features per valid group to keep if features_list is None.

adj_mean : bool, default True If True, adjust cell values toward group means using beta.

beta : float, default 0.2 Adjustment strength toward group mean.

Update

Sets self.clustering_data and self.clustering_metadata to the selected subset, ready for PCA/UMAP/clustering.

def average(self):
2827    def average(self):
2828        """
2829        Aggregate normalized data by (cell_name, set) pairs computing the mean per group.
2830
2831        The method constructs new column names as "cell_name # set", averages columns
2832        sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`.
2833
2834        Update
2835        ------
2836        Sets `self.agg_normalized_data` (features x aggregated samples) and
2837        `self.agg_metadata` (DataFrame with 'cell_names' and 'sets').
2838        """
2839
2840        wide_data = self.normalized_data
2841
2842        wide_metadata = self.input_metadata
2843
2844        new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"]
2845
2846        wide_data.columns = list(new_names)
2847
2848        aggregated_df = wide_data.T.groupby(level=0).mean().T
2849
2850        sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns]
2851        names = [re.sub(" #.*", "", x) for x in aggregated_df.columns]
2852
2853        aggregated_df.columns = names
2854        aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets})
2855
2856        self.agg_metadata = aggregated_metadata
2857        self.agg_normalized_data = aggregated_df

Aggregate normalized data by (cell_name, set) pairs computing the mean per group.

The method constructs new column names as "cell_name # set", averages columns sharing identical labels, and populates self.agg_normalized_data and self.agg_metadata.

Update

Sets self.agg_normalized_data (features x aggregated samples) and self.agg_metadata (DataFrame with 'cell_names' and 'sets').

def estimating_similarity(self, method='pearson', p_val: float = 0.05, top_n: int = 25):
2859    def estimating_similarity(
2860        self, method="pearson", p_val: float = 0.05, top_n: int = 25
2861    ):
2862        """
2863        Estimate pairwise similarity and Euclidean distance between aggregated samples.
2864
2865        Parameters
2866        ----------
2867        method : str, default 'pearson'
2868            Correlation method to use (passed to pandas.DataFrame.corr()).
2869
2870        p_val : float, default 0.05
2871            Adjusted p-value cutoff used to select marker features from `self.var_data`.
2872
2873        top_n : int, default 25
2874            Number of top features per valid group to include.
2875
2876        Update
2877        -------
2878        Computes a combined table with per-pair correlation and euclidean distance
2879        and stores it in `self.similarity`.
2880        """
2881
2882        if self.var_data is None:
2883            raise ValueError(
2884                "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2885            )
2886
2887        if self.agg_normalized_data is None:
2888            self.average()
2889
2890        metadata = self.agg_metadata
2891        data = self.agg_normalized_data
2892
2893        df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2894        df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2895        df_tmp = (
2896            df_tmp.sort_values(
2897                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2898            )
2899            .groupby("valid_group")
2900            .head(top_n)
2901        )
2902
2903        data = data.loc[list(set(df_tmp["feature"]))]
2904
2905        if len(set(metadata["sets"])) > 1:
2906            data.columns = data.columns + " # " + [x for x in metadata["sets"]]
2907        else:
2908            data = data.copy()
2909
2910        scaler = StandardScaler()
2911
2912        scaled_data = scaler.fit_transform(data)
2913
2914        scaled_df = pd.DataFrame(scaled_data, columns=data.columns)
2915
2916        cor = scaled_df.corr(method=method)
2917        cor_df = cor.stack().reset_index()
2918        cor_df.columns = ["cell1", "cell2", "correlation"]
2919
2920        distances = pdist(scaled_df.T, metric="euclidean")
2921        dist_mat = pd.DataFrame(
2922            squareform(distances), index=scaled_df.columns, columns=scaled_df.columns
2923        )
2924        dist_df = dist_mat.stack().reset_index()
2925        dist_df.columns = ["cell1", "cell2", "euclidean_dist"]
2926
2927        full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"])
2928
2929        full = full[full["cell1"] != full["cell2"]]
2930        full = full.reset_index(drop=True)
2931
2932        self.similarity = full

Estimate pairwise similarity and Euclidean distance between aggregated samples.

Parameters

method : str, default 'pearson' Correlation method to use (passed to pandas.DataFrame.corr()).

p_val : float, default 0.05 Adjusted p-value cutoff used to select marker features from self.var_data.

top_n : int, default 25 Number of top features per valid group to include.

Update

Computes a combined table with per-pair correlation and euclidean distance and stores it in self.similarity.

def similarity_plot( self, split_sets=True, set_info: bool = True, cmap='seismic', width=12, height=10):
2934    def similarity_plot(
2935        self,
2936        split_sets=True,
2937        set_info: bool = True,
2938        cmap="seismic",
2939        width=12,
2940        height=10,
2941    ):
2942        """
2943        Visualize pairwise similarity as a scatter plot.
2944
2945        Parameters
2946        ----------
2947        split_sets : bool, default True
2948            If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs.
2949
2950        set_info : bool, default True
2951            If True, keep the ' # set' annotation in labels; otherwise strip it.
2952
2953        cmap : str, default 'seismic'
2954            Color map for correlation (hue).
2955
2956        width : int, default 12
2957            Figure width.
2958
2959        height : int, default 10
2960            Figure height.
2961
2962        Returns
2963        -------
2964        matplotlib.figure.Figure
2965
2966        Raises
2967        ------
2968        ValueError
2969            If `self.similarity` is None.
2970
2971        Notes
2972        -----
2973        The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs.
2974        """
2975
2976        if self.similarity is None:
2977            raise ValueError(
2978                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
2979            )
2980
2981        similarity_data = self.similarity
2982
2983        if " # " in similarity_data["cell1"][0]:
2984            similarity_data["set1"] = [
2985                re.sub(".*# ", "", x) for x in similarity_data["cell1"]
2986            ]
2987            similarity_data["set2"] = [
2988                re.sub(".*# ", "", x) for x in similarity_data["cell2"]
2989            ]
2990
2991        if split_sets and " # " in similarity_data["cell1"][0]:
2992            sets = list(
2993                set(list(similarity_data["set1"]) + list(similarity_data["set2"]))
2994            )
2995
2996            mm = math.ceil(len(sets) / 2)
2997
2998            x_s = sets[0:mm]
2999            y_s = sets[mm : len(sets)]
3000
3001            similarity_data = similarity_data[similarity_data["set1"].isin(x_s)]
3002            similarity_data = similarity_data[similarity_data["set2"].isin(y_s)]
3003
3004            similarity_data = similarity_data.sort_values(["set1", "set2"])
3005
3006        if set_info is False and " # " in similarity_data["cell1"][0]:
3007            similarity_data["cell1"] = [
3008                re.sub(" #.*", "", x) for x in similarity_data["cell1"]
3009            ]
3010            similarity_data["cell2"] = [
3011                re.sub(" #.*", "", x) for x in similarity_data["cell2"]
3012            ]
3013
3014        similarity_data["-euclidean_zscore"] = -zscore(
3015            similarity_data["euclidean_dist"]
3016        )
3017
3018        similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0]
3019
3020        fig = plt.figure(figsize=(width, height))
3021        sns.scatterplot(
3022            data=similarity_data,
3023            x="cell1",
3024            y="cell2",
3025            hue="correlation",
3026            size="-euclidean_zscore",
3027            sizes=(1, 100),
3028            palette=cmap,
3029            alpha=1,
3030            edgecolor="black",
3031        )
3032
3033        plt.xticks(rotation=90)
3034        plt.yticks(rotation=0)
3035        plt.xlabel("Cell 1")
3036        plt.ylabel("Cell 2")
3037        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
3038
3039        plt.grid(True, alpha=0.6)
3040
3041        plt.tight_layout()
3042
3043        return fig

Visualize pairwise similarity as a scatter plot.

Parameters

split_sets : bool, default True If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs.

set_info : bool, default True If True, keep the ' # set' annotation in labels; otherwise strip it.

cmap : str, default 'seismic' Color map for correlation (hue).

width : int, default 12 Figure width.

height : int, default 10 Figure height.

Returns

matplotlib.figure.Figure

Raises

ValueError If self.similarity is None.

Notes

The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs.

def spatial_similarity( self, set_info: bool = True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=100, spread=1.0, set_op_mix_ratio=1.0, local_connectivity=1, repulsion_strength=1.0, negative_sample_rate=5, threshold=0.1, width=12, height=10):
3045    def spatial_similarity(
3046        self,
3047        set_info: bool = True,
3048        bandwidth=1,
3049        n_neighbors=5,
3050        min_dist=0.1,
3051        legend_split=2,
3052        point_size=100,
3053        spread=1.0,
3054        set_op_mix_ratio=1.0,
3055        local_connectivity=1,
3056        repulsion_strength=1.0,
3057        negative_sample_rate=5,
3058        threshold=0.1,
3059        width=12,
3060        height=10,
3061    ):
3062        """
3063        Create a spatial UMAP-like visualization of similarity relationships between samples.
3064
3065        Parameters
3066        ----------
3067        set_info : bool, default True
3068            If True, retain set information in labels.
3069
3070        bandwidth : float, default 1
3071            Bandwidth used by MeanShift for clustering polygons.
3072
3073        point_size : float, default 100
3074            Size of scatter points.
3075
3076        legend_split : int, default 2
3077            Number of columns in legend.
3078
3079        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.
3080
3081        threshold : float, default 0.1
3082            Minimum text distance for label adjustment to avoid overlap.
3083
3084        width : int, default 12
3085            Figure width.
3086
3087        height : int, default 10
3088            Figure height.
3089
3090        Returns
3091        -------
3092        matplotlib.figure.Figure
3093
3094        Raises
3095        ------
3096        ValueError
3097            If `self.similarity` is None.
3098
3099        Notes
3100        -----
3101        Builds a precomputed distance matrix combining correlation and euclidean distance,
3102        runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull)
3103        and arrows to indicate nearest neighbors (minimal combined distance).
3104        """
3105
3106        if self.similarity is None:
3107            raise ValueError(
3108                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
3109            )
3110
3111        similarity_data = self.similarity
3112
3113        sim = similarity_data["correlation"]
3114        sim_scaled = (sim - sim.min()) / (sim.max() - sim.min())
3115        eu_dist = similarity_data["euclidean_dist"]
3116        eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min())
3117
3118        similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled
3119
3120        # for nn target
3121        arrow_df = similarity_data.copy()
3122        arrow_df = similarity_data.loc[
3123            similarity_data.groupby("cell1")["combo_dist"].idxmin()
3124        ]
3125
3126        cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"]))
3127        combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float)
3128
3129        for _, row in similarity_data.iterrows():
3130            combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"]
3131            combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"]
3132
3133        umap_model = umap.UMAP(
3134            n_components=2,
3135            metric="precomputed",
3136            n_neighbors=n_neighbors,
3137            min_dist=min_dist,
3138            spread=spread,
3139            set_op_mix_ratio=set_op_mix_ratio,
3140            local_connectivity=set_op_mix_ratio,
3141            repulsion_strength=repulsion_strength,
3142            negative_sample_rate=negative_sample_rate,
3143            transform_seed=42,
3144            init="spectral",
3145            random_state=42,
3146            verbose=True,
3147        )
3148
3149        coords = umap_model.fit_transform(combo_matrix.values)
3150        cell_names = list(combo_matrix.index)
3151        num_cells = len(cell_names)
3152        palette = sns.color_palette("tab20c", num_cells)
3153
3154        if "#" in cell_names[0]:
3155            avsets = set(
3156                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
3157                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
3158            )
3159            num_sets = len(avsets)
3160            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
3161            color_mapping_sets = {
3162                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
3163            }
3164            color_mapping = {
3165                name: color_mapping_sets[re.sub(".*# ", "", name)]
3166                for i, name in enumerate(cell_names)
3167            }
3168        else:
3169            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}
3170
3171        meanshift = MeanShift(bandwidth=bandwidth)
3172        labels = meanshift.fit_predict(coords)
3173
3174        fig = plt.figure(figsize=(width, height))
3175        ax = plt.gca()
3176
3177        unique_labels = set(labels)
3178        cluster_palette = sns.color_palette("hls", len(unique_labels))
3179
3180        for label in unique_labels:
3181            if label == -1:
3182                continue
3183            cluster_coords = coords[labels == label]
3184            if len(cluster_coords) < 3:
3185                continue
3186
3187            hull = ConvexHull(cluster_coords)
3188            hull_points = cluster_coords[hull.vertices]
3189
3190            centroid = np.mean(hull_points, axis=0)
3191            expanded = hull_points + 0.05 * (hull_points - centroid)
3192
3193            poly = Polygon(
3194                expanded,
3195                closed=True,
3196                facecolor=cluster_palette[label],
3197                edgecolor="none",
3198                alpha=0.2,
3199                zorder=1,
3200            )
3201            ax.add_patch(poly)
3202
3203        texts = []
3204        for i, (x, y) in enumerate(coords):
3205            plt.scatter(
3206                x,
3207                y,
3208                s=point_size,
3209                color=color_mapping[cell_names[i]],
3210                edgecolors="black",
3211                linewidths=0.5,
3212                zorder=2,
3213            )
3214            texts.append(
3215                ax.text(
3216                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
3217                )
3218            )
3219
3220        def dist(p1, p2):
3221            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
3222
3223        texts_to_adjust = []
3224        for i, t1 in enumerate(texts):
3225            for j, t2 in enumerate(texts):
3226                if i >= j:
3227                    continue
3228                d = dist(
3229                    (t1.get_position()[0], t1.get_position()[1]),
3230                    (t2.get_position()[0], t2.get_position()[1]),
3231                )
3232                if d < threshold:
3233                    if t1 not in texts_to_adjust:
3234                        texts_to_adjust.append(t1)
3235                    if t2 not in texts_to_adjust:
3236                        texts_to_adjust.append(t2)
3237
3238        adjust_text(
3239            texts_to_adjust,
3240            expand_text=(1.0, 1.0),
3241            force_text=0.9,
3242            arrowprops=dict(arrowstyle="-", color="gray", lw=0.1),
3243            ax=ax,
3244        )
3245
3246        for _, row in arrow_df.iterrows():
3247            try:
3248                idx1 = cell_names.index(row["cell1"])
3249                idx2 = cell_names.index(row["cell2"])
3250            except ValueError:
3251                continue
3252            x1, y1 = coords[idx1]
3253            x2, y2 = coords[idx2]
3254            arrow = FancyArrowPatch(
3255                (x1, y1),
3256                (x2, y2),
3257                arrowstyle="->",
3258                color="gray",
3259                linewidth=1.5,
3260                alpha=0.5,
3261                mutation_scale=12,
3262                zorder=0,
3263            )
3264            ax.add_patch(arrow)
3265
3266        if set_info is False and " # " in cell_names[0]:
3267
3268            legend_elements = [
3269                Patch(
3270                    facecolor=color_mapping[name],
3271                    edgecolor="black",
3272                    label=f"{i}{re.sub(' #.*', '', name)}",
3273                )
3274                for i, name in enumerate(cell_names)
3275            ]
3276
3277        else:
3278
3279            legend_elements = [
3280                Patch(
3281                    facecolor=color_mapping[name],
3282                    edgecolor="black",
3283                    label=f"{i}{name}",
3284                )
3285                for i, name in enumerate(cell_names)
3286            ]
3287
3288        plt.legend(
3289            handles=legend_elements,
3290            title="Cells",
3291            bbox_to_anchor=(1.05, 1),
3292            loc="upper left",
3293            ncol=legend_split,
3294        )
3295
3296        plt.xlabel("UMAP 1")
3297        plt.ylabel("UMAP 2")
3298        plt.grid(False)
3299        plt.show()
3300
3301        return fig

Create a spatial UMAP-like visualization of similarity relationships between samples.

Parameters

set_info : bool, default True If True, retain set information in labels.

bandwidth : float, default 1 Bandwidth used by MeanShift for clustering polygons.

point_size : float, default 100 Size of scatter points.

legend_split : int, default 2 Number of columns in legend.

n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.

threshold : float, default 0.1 Minimum text distance for label adjustment to avoid overlap.

width : int, default 12 Figure width.

height : int, default 10 Figure height.

Returns

matplotlib.figure.Figure

Raises

ValueError If self.similarity is None.

Notes

Builds a precomputed distance matrix combining correlation and euclidean distance, runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull) and arrows to indicate nearest neighbors (minimal combined distance).

def subcluster_prepare(self, features: list, cluster: str):
3305    def subcluster_prepare(self, features: list, cluster: str):
3306        """
3307        Prepare a `Clustering` object for subcluster analysis on a selected parent cluster.
3308
3309        Parameters
3310        ----------
3311        features : list
3312            Features to include for subcluster analysis.
3313
3314        cluster : str
3315            Parent cluster name (used to select matching cells).
3316
3317        Update
3318        ------
3319        Initializes `self.subclusters_` as a new `Clustering` instance containing the
3320        reduced data for the given cluster and stores `current_features` and `current_cluster`.
3321        """
3322
3323        dat = self.normalized_data
3324        dat.columns = list(self.input_metadata["cell_names"])
3325
3326        dat = reduce_data(self.normalized_data, features=features, names=[cluster])
3327
3328        self.subclusters_ = Clustering(data=dat, metadata=None)
3329
3330        self.subclusters_.current_features = features
3331        self.subclusters_.current_cluster = cluster

Prepare a Clustering object for subcluster analysis on a selected parent cluster.

Parameters

features : list Features to include for subcluster analysis.

cluster : str Parent cluster name (used to select matching cells).

Update

Initializes self.subclusters_ as a new Clustering instance containing the reduced data for the given cluster and stores current_features and current_cluster.

def define_subclusters( self, umap_num: int = 2, eps: float = 0.5, min_samples: int = 10, n_neighbors: int = 5, min_dist: float = 0.1, spread: float = 1.0, set_op_mix_ratio: float = 1.0, local_connectivity: int = 1, repulsion_strength: float = 1.0, negative_sample_rate: int = 5, width=8, height=6):
3333    def define_subclusters(
3334        self,
3335        umap_num: int = 2,
3336        eps: float = 0.5,
3337        min_samples: int = 10,
3338        n_neighbors: int = 5,
3339        min_dist: float = 0.1,
3340        spread: float = 1.0,
3341        set_op_mix_ratio: float = 1.0,
3342        local_connectivity: int = 1,
3343        repulsion_strength: float = 1.0,
3344        negative_sample_rate: int = 5,
3345        width=8,
3346        height=6,
3347    ):
3348        """
3349        Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.
3350
3351        Parameters
3352        ----------
3353        umap_num : int, default 2
3354            Number of UMAP dimensions to compute.
3355
3356        eps : float, default 0.5
3357            DBSCAN eps parameter.
3358
3359        min_samples : int, default 10
3360            DBSCAN min_samples parameter.
3361
3362        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
3363            Additional parameters passed to UMAP / plotting / MeanShift as appropriate.
3364
3365        Update
3366        -------
3367        Stores cluster labels in `self.subclusters_.subclusters`.
3368
3369        Raises
3370        ------
3371        RuntimeError
3372            If `self.subclusters_` has not been prepared.
3373        """
3374
3375        if self.subclusters_ is None:
3376            raise RuntimeError(
3377                "Nothing to return. 'self.subcluster_prepare' was not conducted!"
3378            )
3379
3380        self.subclusters_.perform_UMAP(
3381            factorize=False,
3382            umap_num=umap_num,
3383            pc_num=0,
3384            harmonized=False,
3385            n_neighbors=n_neighbors,
3386            min_dist=min_dist,
3387            spread=spread,
3388            set_op_mix_ratio=set_op_mix_ratio,
3389            local_connectivity=local_connectivity,
3390            repulsion_strength=repulsion_strength,
3391            negative_sample_rate=negative_sample_rate,
3392            width=width,
3393            height=height,
3394        )
3395
3396        fig = self.subclusters_.find_clusters_UMAP(
3397            umap_n=umap_num,
3398            eps=eps,
3399            min_samples=min_samples,
3400            width=width,
3401            height=height,
3402        )
3403
3404        clusters = self.subclusters_.return_clusters(clusters="umap")
3405
3406        self.subclusters_.subclusters = [str(x) for x in list(clusters)]
3407
3408        return fig

Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.

Parameters

umap_num : int, default 2 Number of UMAP dimensions to compute.

eps : float, default 0.5 DBSCAN eps parameter.

min_samples : int, default 10 DBSCAN min_samples parameter.

n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height : Additional parameters passed to UMAP / plotting / MeanShift as appropriate.

Update

Stores cluster labels in self.subclusters_.subclusters.

Raises

RuntimeError If self.subclusters_ has not been prepared.

def subcluster_features_scatter( self, colors='viridis', hclust='complete', scale=False, img_width=3, img_high=5, label_size=6, size_scale=70, y_lab='Genes', legend_lab='normalized', bbox_to_anchor_scale: int = 25, bbox_to_anchor_perc: tuple = (0.91, 0.63)):
3410    def subcluster_features_scatter(
3411        self,
3412        colors="viridis",
3413        hclust="complete",
3414        scale=False,
3415        img_width=3,
3416        img_high=5,
3417        label_size=6,
3418        size_scale=70,
3419        y_lab="Genes",
3420        legend_lab="normalized",
3421        bbox_to_anchor_scale: int = 25,
3422        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3423    ):
3424        """
3425        Create a features-scatter visualization for the subclusters (averaged and occurrence).
3426
3427        Parameters
3428        ----------
3429        colors : str, default 'viridis'
3430            Colormap name passed to `features_scatter`.
3431
3432        hclust : str or None
3433            Hierarchical clustering linkage to order rows/columns.
3434
3435        scale: bool, default False
3436            If True, expression data will be scaled (0–1) across the rows (features).
3437
3438        img_width, img_high : float
3439            Figure size.
3440
3441        label_size : int
3442            Font size for labels.
3443
3444        size_scale : int
3445            Bubble size scaling.
3446
3447        y_lab : str
3448            X axis label.
3449
3450        legend_lab : str
3451            Colorbar label.
3452
3453        bbox_to_anchor_scale : int, default=25
3454            Vertical scale (percentage) for positioning the colorbar.
3455
3456        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3457            Anchor position for the size legend (percent bubble legend).
3458
3459        Returns
3460        -------
3461        matplotlib.figure.Figure
3462
3463        Raises
3464        ------
3465        RuntimeError
3466            If subcluster preparation/definition has not been run.
3467        """
3468
3469        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3470            raise RuntimeError(
3471                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3472            )
3473
3474        dat = self.normalized_data
3475        dat.columns = list(self.input_metadata["cell_names"])
3476
3477        dat = reduce_data(
3478            self.normalized_data,
3479            features=self.subclusters_.current_features,
3480            names=[self.subclusters_.current_cluster],
3481        )
3482
3483        dat.columns = self.subclusters_.subclusters
3484
3485        avg = average(dat)
3486        occ = occurrence(dat)
3487
3488        scatter = features_scatter(
3489            expression_data=avg,
3490            occurence_data=occ,
3491            features=None,
3492            scale=scale,
3493            metadata_list=None,
3494            colors=colors,
3495            hclust=hclust,
3496            img_width=img_width,
3497            img_high=img_high,
3498            label_size=label_size,
3499            size_scale=size_scale,
3500            y_lab=y_lab,
3501            legend_lab=legend_lab,
3502            bbox_to_anchor_scale=bbox_to_anchor_scale,
3503            bbox_to_anchor_perc=bbox_to_anchor_perc,
3504        )
3505
3506        return scatter

Create a features-scatter visualization for the subclusters (averaged and occurrence).

Parameters

colors : str, default 'viridis' Colormap name passed to features_scatter.

hclust : str or None Hierarchical clustering linkage to order rows/columns.

scale: bool, default False If True, expression data will be scaled (0–1) across the rows (features).

img_width, img_high : float Figure size.

label_size : int Font size for labels.

size_scale : int Bubble size scaling.

y_lab : str X axis label.

legend_lab : str Colorbar label.

bbox_to_anchor_scale : int, default=25 Vertical scale (percentage) for positioning the colorbar.

bbox_to_anchor_perc : tuple, default=(0.91, 0.63) Anchor position for the size legend (percent bubble legend).

Returns

matplotlib.figure.Figure

Raises

RuntimeError If subcluster preparation/definition has not been run.

def subcluster_DEG_scatter( self, top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', hclust='complete', scale=False, img_width=3, img_high=5, label_size=6, size_scale=70, y_lab='Genes', legend_lab='normalized', bbox_to_anchor_scale: int = 25, bbox_to_anchor_perc: tuple = (0.91, 0.63), n_proc=10):
3508    def subcluster_DEG_scatter(
3509        self,
3510        top_n=3,
3511        min_exp=0,
3512        min_pct=0.25,
3513        p_val=0.05,
3514        colors="viridis",
3515        hclust="complete",
3516        scale=False,
3517        img_width=3,
3518        img_high=5,
3519        label_size=6,
3520        size_scale=70,
3521        y_lab="Genes",
3522        legend_lab="normalized",
3523        bbox_to_anchor_scale: int = 25,
3524        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3525        n_proc=10,
3526    ):
3527        """
3528        Plot top differential features (DEGs) for subclusters as a features-scatter.
3529
3530        Parameters
3531        ----------
3532        top_n : int, default 3
3533            Number of top features per subcluster to show.
3534
3535        min_exp : float, default 0
3536            Minimum expression threshold passed to `statistic`.
3537
3538        min_pct : float, default 0.25
3539            Minimum percent expressed in target group.
3540
3541        p_val: float, default 0.05
3542            Maximum p-value for visualizing features.
3543
3544        n_proc : int, default 10
3545            Parallel jobs used for DEG calculation.
3546
3547        scale: bool, default False
3548            If True, expression_data will be scaled (0–1) across the rows (features).
3549
3550        colors : str, default='viridis'
3551            Colormap for expression values.
3552
3553        hclust : str or None, default='complete'
3554            Linkage method for hierarchical clustering. If None, no clustering
3555            is performed.
3556
3557        img_width : int or float, default=8
3558            Width of the plot in inches.
3559
3560        img_high : int or float, default=5
3561            Height of the plot in inches.
3562
3563        label_size : int, default=10
3564            Font size for axis labels and ticks.
3565
3566        size_scale : int or float, default=100
3567            Scaling factor for bubble sizes.
3568
3569        y_lab : str, default='Genes'
3570            Label for the x-axis.
3571
3572        legend_lab : str, default='normalized'
3573            Label for the colorbar legend.
3574
3575        bbox_to_anchor_scale : int, default=25
3576            Vertical scale (percentage) for positioning the colorbar.
3577
3578        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3579            Anchor position for the size legend (percent bubble legend).
3580
3581        Returns
3582        -------
3583        matplotlib.figure.Figure
3584
3585        Raises
3586        ------
3587        RuntimeError
3588            If subcluster preparation/definition has not been run.
3589
3590        Notes
3591        -----
3592        Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters
3593        by p-value and effect-size, selects top features per valid group and plots them.
3594        """
3595
3596        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3597            raise RuntimeError(
3598                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3599            )
3600
3601        dat = self.normalized_data
3602        dat.columns = list(self.input_metadata["cell_names"])
3603
3604        dat = reduce_data(
3605            self.normalized_data, names=[self.subclusters_.current_cluster]
3606        )
3607
3608        dat.columns = self.subclusters_.subclusters
3609
3610        deg_stats = calc_DEG(
3611            dat,
3612            metadata_list=None,
3613            entities="All",
3614            sets=None,
3615            min_exp=min_exp,
3616            min_pct=min_pct,
3617            n_proc=n_proc,
3618        )
3619
3620        deg_stats = deg_stats[deg_stats["p_val"] <= p_val]
3621        deg_stats = deg_stats[deg_stats["log(FC)"] > 0]
3622
3623        deg_stats = (
3624            deg_stats.sort_values(
3625                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
3626            )
3627            .groupby("valid_group")
3628            .head(top_n)
3629        )
3630
3631        dat = reduce_data(dat, features=list(set(deg_stats["feature"])))
3632
3633        avg = average(dat)
3634        occ = occurrence(dat)
3635
3636        scatter = features_scatter(
3637            expression_data=avg,
3638            occurence_data=occ,
3639            features=None,
3640            metadata_list=None,
3641            colors=colors,
3642            hclust=hclust,
3643            img_width=img_width,
3644            img_high=img_high,
3645            label_size=label_size,
3646            size_scale=size_scale,
3647            y_lab=y_lab,
3648            legend_lab=legend_lab,
3649            bbox_to_anchor_scale=bbox_to_anchor_scale,
3650            bbox_to_anchor_perc=bbox_to_anchor_perc,
3651        )
3652
3653        return scatter

Plot top differential features (DEGs) for subclusters as a features-scatter.

Parameters

top_n : int, default 3 Number of top features per subcluster to show.

min_exp : float, default 0 Minimum expression threshold passed to the `statistic` routine.

min_pct : float, default 0.25 Minimum percent expressed in target group.

p_val: float, default 0.05 Maximum p-value for visualizing features.

n_proc : int, default 10 Parallel jobs used for DEG calculation.

scale: bool, default False If True, expression_data will be scaled (0–1) across the rows (features).

colors : str, default='viridis' Colormap for expression values.

hclust : str or None, default='complete' Linkage method for hierarchical clustering. If None, no clustering is performed.

img_width : int or float, default=8 Width of the plot in inches.

img_high : int or float, default=5 Height of the plot in inches.

label_size : int, default=10 Font size for axis labels and ticks.

size_scale : int or float, default=100 Scaling factor for bubble sizes.

y_lab : str, default='Genes' Label for the x-axis.

legend_lab : str, default='normalized' Label for the colorbar legend.

bbox_to_anchor_scale : int, default=25 Vertical scale (percentage) for positioning the colorbar.

bbox_to_anchor_perc : tuple, default=(0.91, 0.63) Anchor position for the size legend (percent bubble legend).

Returns

matplotlib.figure.Figure

Raises

RuntimeError If subcluster preparation/definition has not been run.

Notes

Internally calls calc_DEG (or equivalent) to obtain statistics, filters by p-value and effect-size, selects top features per valid group and plots them.

def accept_subclusters(self):
3655    def accept_subclusters(self):
3656        """
3657        Commit subcluster labels into the main `input_metadata` by renaming cell names.
3658
3659        The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']`
3660        with the expanded names that include subcluster suffixes (via `add_subnames`),
3661        then clears `self.subclusters_`.
3662
3663        Update
3664        ------
3665        Modifies `self.input_metadata['cell_names']`.
3666
3667        Resets `self.subclusters_` to None.
3668
3669        Raises
3670        ------
3671        RuntimeError
3672            If `self.subclusters_` is not defined or subclusters were not computed.
3673        """
3674
3675        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3676            raise RuntimeError(
3677                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3678            )
3679
3680        new_meta = add_subnames(
3681            list(self.input_metadata["cell_names"]),
3682            parent_name=self.subclusters_.current_cluster,
3683            new_clusters=self.subclusters_.subclusters,
3684        )
3685
3686        self.input_metadata["cell_names"] = new_meta
3687
3688        self.subclusters_ = None

Commit subcluster labels into the main input_metadata by renaming cell names.

The method replaces occurrences of the parent cluster name in self.input_metadata['cell_names'] with the expanded names that include subcluster suffixes (via add_subnames), then clears self.subclusters_.

Update

Modifies self.input_metadata['cell_names'].

Resets self.subclusters_ to None.

Raises

RuntimeError If self.subclusters_ is not defined or subclusters were not computed.

def scatter_plot( self, names: list | None = None, features: list | None = None, name_slot: str = 'cell_names', scale=True, colors='viridis', hclust=None, img_width=15, img_high=1, label_size=10, size_scale=200, y_lab='Genes', legend_lab='log(CPM + 1)', set_box_size: float | int = 5, set_box_high: float | int = 5, bbox_to_anchor_scale=25, bbox_to_anchor_perc=(0.9, -0.24), bbox_to_anchor_group=(1.01, 0.4)):
3690    def scatter_plot(
3691        self,
3692        names: list | None = None,
3693        features: list | None = None,
3694        name_slot: str = "cell_names",
3695        scale=True,
3696        colors="viridis",
3697        hclust=None,
3698        img_width=15,
3699        img_high=1,
3700        label_size=10,
3701        size_scale=200,
3702        y_lab="Genes",
3703        legend_lab="log(CPM + 1)",
3704        set_box_size: float | int = 5,
3705        set_box_high: float | int = 5,
3706        bbox_to_anchor_scale=25,
3707        bbox_to_anchor_perc=(0.90, -0.24),
3708        bbox_to_anchor_group=(1.01, 0.4),
3709    ):
3710        """
3711        Create a bubble scatter plot of selected features across samples inside project.
3712
3713        Each point represents a feature-sample pair, where the color encodes the
3714        expression value and the size encodes occurrence or relative abundance.
3715        Optionally, hierarchical clustering can be applied to order rows and columns.
3716
3717        Parameters
3718        ----------
3719        names : list, str, or None
3720            Names of samples to include. If None, all samples are considered.
3721
3722        features : list, str, or None
3723            Names of features to include. If None, all features are considered.
3724
3725        name_slot : str
3726            Column in metadata to use as sample names.
3727
3728        scale: bool, default False
3729            If True, expression_data will be scaled (0–1) across the rows (features).
3730
3731        colors : str, default='viridis'
3732            Colormap for expression values.
3733
3734        hclust : str or None, default='complete'
3735            Linkage method for hierarchical clustering. If None, no clustering
3736            is performed.
3737
3738        img_width : int or float, default=8
3739            Width of the plot in inches.
3740
3741        img_high : int or float, default=5
3742            Height of the plot in inches.
3743
3744        label_size : int, default=10
3745            Font size for axis labels and ticks.
3746
3747        size_scale : int or float, default=100
3748            Scaling factor for bubble sizes.
3749
3750        y_lab : str, default='Genes'
3751            Label for the x-axis.
3752
3753        legend_lab : str, default='log(CPM + 1)'
3754            Label for the colorbar legend.
3755
3756        bbox_to_anchor_scale : int, default=25
3757            Vertical scale (percentage) for positioning the colorbar.
3758
3759        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3760            Anchor position for the size legend (percent bubble legend).
3761
3762        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
3763            Anchor position for the group legend.
3764
3765        Returns
3766        -------
3767        matplotlib.figure.Figure
3768            The generated scatter plot figure.
3769
3770        Notes
3771        -----
3772        Colors represent expression values normalized to the colormap.
3773        """
3774
3775        prtd, met = self.get_partial_data(
3776            names=names, features=features, name_slot=name_slot, inc_metadata=True
3777        )
3778
3779        prtd.columns = prtd.columns + "#" + met["sets"]
3780
3781        prtd_avg = average(prtd)
3782
3783        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]
3784
3785        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]
3786
3787        prtd_occ = occurrence(prtd)
3788
3789        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]
3790
3791        fig_scatter = features_scatter(
3792            expression_data=prtd_avg,
3793            occurence_data=prtd_occ,
3794            scale=scale,
3795            features=None,
3796            metadata_list=meta_sets,
3797            colors=colors,
3798            hclust=hclust,
3799            img_width=img_width,
3800            img_high=img_high,
3801            label_size=label_size,
3802            size_scale=size_scale,
3803            y_lab=y_lab,
3804            legend_lab=legend_lab,
3805            set_box_size=set_box_size,
3806            set_box_high=set_box_high,
3807            bbox_to_anchor_scale=bbox_to_anchor_scale,
3808            bbox_to_anchor_perc=bbox_to_anchor_perc,
3809            bbox_to_anchor_group=bbox_to_anchor_group,
3810        )
3811
3812        return fig_scatter

Create a bubble scatter plot of selected features across samples inside project.

Each point represents a feature-sample pair, where the color encodes the expression value and the size encodes occurrence or relative abundance. Optionally, hierarchical clustering can be applied to order rows and columns.

Parameters

names : list, str, or None Names of samples to include. If None, all samples are considered.

features : list, str, or None Names of features to include. If None, all features are considered.

name_slot : str Column in metadata to use as sample names.

scale: bool, default False If True, expression_data will be scaled (0–1) across the rows (features).

colors : str, default='viridis' Colormap for expression values.

hclust : str or None, default='complete' Linkage method for hierarchical clustering. If None, no clustering is performed.

img_width : int or float, default=8 Width of the plot in inches.

img_high : int or float, default=5 Height of the plot in inches.

label_size : int, default=10 Font size for axis labels and ticks.

size_scale : int or float, default=100 Scaling factor for bubble sizes.

y_lab : str, default='Genes' Label for the x-axis.

legend_lab : str, default='log(CPM + 1)' Label for the colorbar legend.

bbox_to_anchor_scale : int, default=25 Vertical scale (percentage) for positioning the colorbar.

bbox_to_anchor_perc : tuple, default=(0.91, 0.63) Anchor position for the size legend (percent bubble legend).

bbox_to_anchor_group : tuple, default=(1.01, 0.4) Anchor position for the group legend.

Returns

matplotlib.figure.Figure The generated scatter plot figure.

Notes

Colors represent expression values normalized to the colormap.

def data_composition( self, features_count: list | None, name_slot: str = 'cell_names', set_sep: bool = True):
3814    def data_composition(
3815        self,
3816        features_count: list | None,
3817        name_slot: str = "cell_names",
3818        set_sep: bool = True,
3819    ):
3820        """
3821        Compute composition of cell types in data set.
3822
3823        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
3824        within metadata entries, calculates their relative percentages, and stores
3825        the results in `self.composition_data`.
3826
3827        Parameters
3828        ----------
3829        features_count : list or None
3830            List of features (part or full names) to be counted.
3831            If None, all unique elements from the specified `name_slot` metadata field are used.
3832
3833        name_slot : str, default 'cell_names'
3834            Metadata field containing sample identifiers or labels.
3835
3836        set_sep : bool, default True
3837            If True and multiple sets are present in metadata, compute composition
3838            separately for each set.
3839
3840        Update
3841        -------
3842        Stores results in `self.composition_data` as a pandas DataFrame with:
3843        - 'name': feature name
3844        - 'n': number of occurrences
3845        - 'pct': percentage of occurrences
3846        - 'set' (if applicable): dataset identifier
3847        """
3848
3849        validated_list = list(self.input_metadata[name_slot])
3850        sets = list(self.input_metadata["sets"])
3851
3852        if features_count is None:
3853            features_count = list(set(self.input_metadata[name_slot]))
3854
3855        if set_sep and len(set(sets)) > 1:
3856
3857            final_res = pd.DataFrame()
3858
3859            for s in set(sets):
3860                print(s)
3861
3862                mask = [True if s == x else False for x in sets]
3863
3864                tmp_val_list = np.array(validated_list)
3865
3866                tmp_val_list = list(tmp_val_list[mask])
3867
3868                res_dict = {"name": [], "n": [], "set": []}
3869
3870                for f in tqdm(features_count):
3871                    res_dict["n"].append(
3872                        sum(1 for element in tmp_val_list if f in element)
3873                    )
3874                    res_dict["name"].append(f)
3875                    res_dict["set"].append(s)
3876                    res = pd.DataFrame(res_dict)
3877                    res["pct"] = res["n"] / sum(res["n"]) * 100
3878                    res["pct"] = res["pct"].round(2)
3879
3880                final_res = pd.concat([final_res, res])
3881
3882            res = final_res.sort_values(["set", "pct"], ascending=[True, False])
3883
3884        else:
3885
3886            res_dict = {"name": [], "n": []}
3887
3888            for f in tqdm(features_count):
3889                res_dict["n"].append(
3890                    sum(1 for element in validated_list if f in element)
3891                )
3892                res_dict["name"].append(f)
3893
3894            res = pd.DataFrame(res_dict)
3895            res["pct"] = res["n"] / sum(res["n"]) * 100
3896            res["pct"] = res["pct"].round(2)
3897
3898            res = res.sort_values("pct", ascending=False)
3899
3900        self.composition_data = res

Compute the composition of cell types in the data set.

This function counts the occurrences of specific cells (e.g., cell types, subtypes) within metadata entries, calculates their relative percentages, and stores the results in self.composition_data.

Parameters

features_count : list or None List of features (part or full names) to be counted. If None, all unique elements from the specified name_slot metadata field are used.

name_slot : str, default 'cell_names' Metadata field containing sample identifiers or labels.

set_sep : bool, default True If True and multiple sets are present in metadata, compute composition separately for each set.

Update

Stores results in self.composition_data as a pandas DataFrame with:

  • 'name': feature name
  • 'n': number of occurrences
  • 'pct': percentage of occurrences
  • 'set' (if applicable): dataset identifier
def composition_pie( self, width=6, height=6, font_size=15, cmap: str = 'tab20', legend_split_col: int = 1, offset_labels: float | int = 0.5, legend_bbox: tuple = (1.15, 0.95)):
3902    def composition_pie(
3903        self,
3904        width=6,
3905        height=6,
3906        font_size=15,
3907        cmap: str = "tab20",
3908        legend_split_col: int = 1,
3909        offset_labels: float | int = 0.5,
3910        legend_bbox: tuple = (1.15, 0.95),
3911    ):
3912        """
3913        Visualize the composition of cell lineages using pie charts.
3914
3915        Generates pie charts showing the relative proportions of features stored
3916        in `self.composition_data`. If multiple sets are present, a separate
3917        chart is drawn for each set.
3918
3919        Parameters
3920        ----------
3921        width : int, default 6
3922            Width of the figure.
3923
3924        height : int, default 6
3925            Height of the figure (applied per set if multiple sets are plotted).
3926
3927        font_size : int, default 15
3928            Font size for labels and annotations.
3929
3930        cmap : str, default 'tab20'
3931            Colormap used for pie slices.
3932
3933        legend_split_col : int, default 1
3934            Number of columns in the legend.
3935
3936        offset_labels : float or int, default 0.5
3937            Spacing offset for label placement relative to pie slices.
3938
3939        legend_bbox : tuple, default (1.15, 0.95)
3940            Bounding box anchor position for the legend.
3941
3942        Returns
3943        -------
3944        matplotlib.figure.Figure
3945            Pie chart visualization of composition data.
3946        """
3947
3948        df = self.composition_data
3949
3950        if "set" in df.columns and len(set(df["set"])) > 1:
3951
3952            sets = list(set(df["set"]))
3953            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))
3954
3955            all_wedges = []
3956            cmap = plt.get_cmap(cmap)
3957
3958            set_nam = len(set(df["name"]))
3959
3960            legend_labels = list(set(df["name"]))
3961
3962            colors = [cmap(i / set_nam) for i in range(set_nam)]
3963
3964            cmap_dict = dict(zip(legend_labels, colors))
3965
3966            for idx, s in enumerate(sets):
3967                ax = axes[idx]
3968                tmp_df = df[df["set"] == s].reset_index(drop=True)
3969
3970                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]
3971
3972                wedges, _ = ax.pie(
3973                    tmp_df["n"],
3974                    startangle=90,
3975                    labeldistance=1.05,
3976                    colors=[cmap_dict[x] for x in tmp_df["name"]],
3977                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
3978                )
3979
3980                all_wedges.extend(wedges)
3981
3982                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
3983                n = 0
3984                for i, p in enumerate(wedges):
3985                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
3986                    y = np.sin(np.deg2rad(ang))
3987                    x = np.cos(np.deg2rad(ang))
3988                    horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
3989                    connectionstyle = f"angle,angleA=0,angleB={ang}"
3990                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
3991                    if len(labels[i]) > 0:
3992                        n += offset_labels
3993                        ax.annotate(
3994                            labels[i],
3995                            xy=(x, y),
3996                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
3997                            horizontalalignment=horizontalalignment,
3998                            fontsize=font_size,
3999                            weight="bold",
4000                            **kw,
4001                        )
4002
4003                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
4004                ax.add_artist(circle2)
4005
4006                ax.text(
4007                    0,
4008                    0,
4009                    f"{s}",
4010                    ha="center",
4011                    va="center",
4012                    fontsize=font_size,
4013                    weight="bold",
4014                )
4015
4016            legend_handles = [
4017                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4018                for label in legend_labels
4019            ]
4020
4021            fig.legend(
4022                handles=legend_handles,
4023                loc="center right",
4024                bbox_to_anchor=legend_bbox,
4025                ncol=legend_split_col,
4026                title="",
4027            )
4028
4029            plt.tight_layout()
4030            plt.show()
4031
4032        else:
4033
4034            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]
4035
4036            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]
4037
4038            cmap = plt.get_cmap(cmap)
4039            colors = [cmap(i / len(df)) for i in range(len(df))]
4040
4041            fig, ax = plt.subplots(
4042                figsize=(width, height), subplot_kw=dict(aspect="equal")
4043            )
4044
4045            wedges, _ = ax.pie(
4046                df["n"],
4047                startangle=90,
4048                labeldistance=1.05,
4049                colors=colors,
4050                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
4051            )
4052
4053            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
4054            n = 0
4055            for i, p in enumerate(wedges):
4056                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
4057                y = np.sin(np.deg2rad(ang))
4058                x = np.cos(np.deg2rad(ang))
4059                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
4060                connectionstyle = "angle,angleA=0,angleB={}".format(ang)
4061                kw["arrowprops"].update({"connectionstyle": connectionstyle})
4062                if len(labels[i]) > 0:
4063                    n += offset_labels
4064
4065                    ax.annotate(
4066                        labels[i],
4067                        xy=(x, y),
4068                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
4069                        horizontalalignment=horizontalalignment,
4070                        fontsize=font_size,
4071                        weight="bold",
4072                        **kw,
4073                    )
4074
4075            circle2 = plt.Circle((0, 0), 0.6, color="white")
4076            circle2.set_edgecolor("black")
4077
4078            p = plt.gcf()
4079            p.gca().add_artist(circle2)
4080
4081            ax.legend(
4082                wedges,
4083                legend_labels,
4084                title="",
4085                loc="center left",
4086                bbox_to_anchor=legend_bbox,
4087                ncol=legend_split_col,
4088            )
4089
4090            plt.show()
4091
4092        return fig

Visualize the composition of cell lineages using pie charts.

Generates pie charts showing the relative proportions of features stored in self.composition_data. If multiple sets are present, a separate chart is drawn for each set.

Parameters

width : int, default 6 Width of the figure.

height : int, default 6 Height of the figure (applied per set if multiple sets are plotted).

font_size : int, default 15 Font size for labels and annotations.

cmap : str, default 'tab20' Colormap used for pie slices.

legend_split_col : int, default 1 Number of columns in the legend.

offset_labels : float or int, default 0.5 Spacing offset for label placement relative to pie slices.

legend_bbox : tuple, default (1.15, 0.95) Bounding box anchor position for the legend.

Returns

matplotlib.figure.Figure Pie chart visualization of composition data.

def bar_composition( self, cmap='tab20b', width=2, height=6, font_size=15, legend_split_col: int = 1, legend_bbox: tuple = (1.3, 1)):
4094    def bar_composition(
4095        self,
4096        cmap="tab20b",
4097        width=2,
4098        height=6,
4099        font_size=15,
4100        legend_split_col: int = 1,
4101        legend_bbox: tuple = (1.3, 1),
4102    ):
4103        """
4104        Visualize the composition of cell lineages using bar plots.
4105
4106        Produces bar plots showing the distribution of features stored in
4107        `self.composition_data`. If multiple sets are present, a separate
4108        bar is drawn for each set. Percentages are annotated alongside the bars.
4109
4110        Parameters
4111        ----------
4112        cmap : str, default 'tab20b'
4113            Colormap used for stacked bars.
4114
4115        width : int, default 2
4116            Width of each subplot (per set).
4117
4118        height : int, default 6
4119            Height of the figure.
4120
4121        font_size : int, default 15
4122            Font size for labels and annotations.
4123
4124        legend_split_col : int, default 1
4125            Number of columns in the legend.
4126
4127        legend_bbox : tuple, default (1.3, 1)
4128            Bounding box anchor position for the legend.
4129
4130        Returns
4131        -------
4132        matplotlib.figure.Figure
4133            Stacked bar plot visualization of composition data.
4134        """
4135
4136        df = self.composition_data
4137        df["num"] = range(1, len(df) + 1)
4138
4139        if "set" in df.columns and len(set(df["set"])) > 1:
4140
4141            sets = list(set(df["set"]))
4142            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))
4143
4144            cmap = plt.get_cmap(cmap)
4145
4146            set_nam = len(set(df["name"]))
4147
4148            legend_labels = list(set(df["name"]))
4149
4150            colors = [cmap(i / set_nam) for i in range(set_nam)]
4151
4152            cmap_dict = dict(zip(legend_labels, colors))
4153
4154            for idx, s in enumerate(sets):
4155                ax = axes[idx]
4156
4157                tmp_df = df[df["set"] == s].reset_index(drop=True)
4158
4159                values = tmp_df["n"].values
4160                total = sum(values)
4161                values = [v / total * 100 for v in values]
4162                values = [round(v, 2) for v in values]
4163
4164                idx_max = np.argmax(values)
4165                correction = 100 - sum(values)
4166                values[idx_max] += correction
4167
4168                names = tmp_df["name"].values
4169                perc = tmp_df["pct"].values
4170                nums = tmp_df["num"].values
4171
4172                bottom = 0
4173                centers = []
4174                for name, num, val, color in zip(names, nums, values, colors):
4175                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
4176                    centers.append(bottom + val / 2)
4177                    bottom += val
4178
4179                y_positions = np.linspace(centers[0], centers[-1], len(centers))
4180                x_text = -0.8
4181
4182                for y_label, y_center, pct, num in zip(
4183                    y_positions, centers, perc, nums
4184                ):
4185                    ax.annotate(
4186                        f"{pct:.1f}%",
4187                        xy=(0, y_center),
4188                        xycoords="data",
4189                        xytext=(x_text, y_label),
4190                        textcoords="data",
4191                        ha="right",
4192                        va="center",
4193                        fontsize=font_size,
4194                        arrowprops=dict(
4195                            arrowstyle="->",
4196                            lw=1,
4197                            color="black",
4198                            connectionstyle="angle3,angleA=0,angleB=90",
4199                        ),
4200                    )
4201
4202                ax.set_ylim(0, 100)
4203                ax.set_xlabel(s, fontsize=font_size)
4204                ax.xaxis.label.set_rotation(30)
4205
4206                ax.set_xticks([])
4207                ax.set_yticks([])
4208                for spine in ax.spines.values():
4209                    spine.set_visible(False)
4210
4211            legend_handles = [
4212                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4213                for label in legend_labels
4214            ]
4215
4216            fig.legend(
4217                handles=legend_handles,
4218                loc="center right",
4219                bbox_to_anchor=legend_bbox,
4220                ncol=legend_split_col,
4221                title="",
4222            )
4223
4224            plt.tight_layout()
4225            plt.show()
4226
4227        else:
4228
4229            cmap = plt.get_cmap(cmap)
4230
4231            colors = [cmap(i / len(df)) for i in range(len(df))]
4232
4233            fig, ax = plt.subplots(figsize=(width, height))
4234
4235            values = df["n"].values
4236            names = df["name"].values
4237            perc = df["pct"].values
4238            nums = df["num"].values
4239
4240            bottom = 0
4241            centers = []
4242            for name, num, val, color in zip(names, nums, values, colors):
4243                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
4244                centers.append(bottom + val / 2)
4245                bottom += val
4246
4247            y_positions = np.linspace(centers[0], centers[-1], len(centers))
4248            x_text = -0.8
4249
4250            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
4251                ax.annotate(
4252                    f"{num}) {pct}",
4253                    xy=(0, y_center),
4254                    xycoords="data",
4255                    xytext=(x_text, y_label),
4256                    textcoords="data",
4257                    ha="right",
4258                    va="center",
4259                    fontsize=9,
4260                    arrowprops=dict(
4261                        arrowstyle="->",
4262                        lw=1,
4263                        color="black",
4264                        connectionstyle="angle3,angleA=0,angleB=90",
4265                    ),
4266                )
4267
4268            ax.set_xticks([])
4269            ax.set_yticks([])
4270            for spine in ax.spines.values():
4271                spine.set_visible(False)
4272
4273            ax.legend(
4274                title="Legend",
4275                bbox_to_anchor=legend_bbox,
4276                loc="upper left",
4277                ncol=legend_split_col,
4278            )
4279
4280            plt.tight_layout()
4281            plt.show()
4282
4283        return fig

Visualize the composition of cell lineages using bar plots.

Produces bar plots showing the distribution of features stored in self.composition_data. If multiple sets are present, a separate bar is drawn for each set. Percentages are annotated alongside the bars.

Parameters

cmap : str, default 'tab20b' Colormap used for stacked bars.

width : int, default 2 Width of each subplot (per set).

height : int, default 6 Height of the figure.

font_size : int, default 15 Font size for labels and annotations.

legend_split_col : int, default 1 Number of columns in the legend.

legend_bbox : tuple, default (1.3, 1) Bounding box anchor position for the legend.

Returns

matplotlib.figure.Figure Stacked bar plot visualization of composition data.

def cell_regression( self, cell_x: str, cell_y: str, set_x: str | None, set_y: str | None, threshold=10, image_width=12, image_high=7, color='black'):
4285    def cell_regression(
4286        self,
4287        cell_x: str,
4288        cell_y: str,
4289        set_x: str | None,
4290        set_y: str | None,
4291        threshold=10,
4292        image_width=12,
4293        image_high=7,
4294        color="black",
4295    ):
4296        """
4297        Perform regression analysis between two selected cells and visualize the relationship.
4298
4299        This function computes a linear regression between two specified cells from
4300        aggregated normalized data, plots the regression line with scatter points,
4301        annotates regression statistics, and highlights potential outliers.
4302
4303        Parameters
4304        ----------
4305        cell_x : str
4306            Name of the first cell (X-axis).
4307
4308        cell_y : str
4309            Name of the second cell (Y-axis).
4310
4311        set_x : str or None
4312            Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name.
4313
4314        set_y : str or None
4315            Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name.
4316
4317        threshold : int or float, default 10
4318            Threshold for detecting outliers. Points deviating from the mean or diagonal by more
4319            than this value are annotated.
4320
4321        image_width : int, default 12
4322            Width of the regression plot (in inches).
4323
4324        image_high : int, default 7
4325            Height of the regression plot (in inches).
4326
4327        color : str, default 'black'
4328            Color of the regression scatter points and line.
4329
4330        Returns
4331        -------
4332        matplotlib.figure.Figure
4333            Regression plot figure with annotated regression line, R², p-value, and outliers.
4334
4335        Raises
4336        ------
4337        ValueError
4338            If `cell_x` or `cell_y` are not found in the dataset.
4339            If multiple matches are found for a cell name and `set_x`/`set_y` are not specified.
4340
4341        Notes
4342        -----
4343        - The function automatically calls `jseq_object.average()` if aggregated data is not available.
4344        - Outliers are annotated with their corresponding index labels.
4345        - Regression is computed using `scipy.stats.linregress`.
4346
4347        Examples
4348        --------
4349        >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
4350        >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")
4351        """
4352
4353        if self.agg_normalized_data is None:
4354            self.average()
4355
4356        metadata = self.agg_metadata
4357        data = self.agg_normalized_data
4358
4359        if set_x is not None and set_y is not None:
4360            data.columns = metadata["cell_names"] + " # " + metadata["sets"]
4361            cell_x = cell_x + " # " + set_x
4362            cell_y = cell_y + " # " + set_y
4363
4364        else:
4365            data.columns = metadata["cell_names"]
4366
4367        if not cell_x in data.columns:
4368            raise ValueError("'cell_x' value not in cell names!")
4369
4370        if not cell_y in data.columns:
4371            raise ValueError("'cell_y' value not in cell names!")
4372
4373        if list(data.columns).count(cell_x) > 1:
4374            raise ValueError(
4375                f"'{cell_x}' occurs more than once. If you want to select a specific cell, "
4376                f"please also provide the corresponding 'set_x' and 'set_y' values."
4377            )
4378
4379        if list(data.columns).count(cell_y) > 1:
4380            raise ValueError(
4381                f"'{cell_y}' occurs more than once. If you want to select a specific cell, "
4382                f"please also provide the corresponding 'set_x' and 'set_y' values."
4383            )
4384
4385        fig, ax = plt.subplots(figsize=(image_width, image_high))
4386        ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color)
4387
4388        slope, intercept, r_value, p_value, _ = stats.linregress(
4389            data[cell_x], data[cell_y]
4390        )
4391        equation = "y = {:.2f}x + {:.2f}".format(slope, intercept)
4392
4393        ax.annotate(
4394            "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format(
4395                r_value**2, p_value, equation
4396            ),
4397            xy=(0.05, 0.90),
4398            xycoords="axes fraction",
4399            fontsize=12,
4400        )
4401
4402        ax.spines["top"].set_visible(False)
4403        ax.spines["right"].set_visible(False)
4404
4405        diff = []
4406        x_mean, y_mean = data[cell_x].mean(), data[cell_y].mean()
4407        for i, (xi, yi) in enumerate(zip(data[cell_x], data[cell_y])):
4408            diff.append(abs(xi - x_mean))
4409            diff.append(abs(yi - y_mean))
4410
4411        def annotate_outliers(x, y, threshold):
4412            texts = []
4413            x_mean, y_mean = x.mean(), y.mean()
4414            for i, (xi, yi) in enumerate(zip(x, y)):
4415                if (
4416                    abs(xi - x_mean) > threshold
4417                    or abs(yi - y_mean) > threshold
4418                    or abs(yi - xi) > threshold
4419                ):
4420                    text = ax.text(xi, yi, data.index[i])
4421                    texts.append(text)
4422
4423            return texts
4424
4425        texts = annotate_outliers(data[cell_x], data[cell_y], threshold)
4426
4427        adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))
4428
4429        plt.show()
4430
4431        return fig

Perform regression analysis between two selected cells and visualize the relationship.

This function computes a linear regression between two specified cells from aggregated normalized data, plots the regression line with scatter points, annotates regression statistics, and highlights potential outliers.

Parameters

cell_x : str Name of the first cell (X-axis).

cell_y : str Name of the second cell (Y-axis).

set_x : str or None Dataset identifier corresponding to cell_x. If None, cell is selected only by name.

set_y : str or None Dataset identifier corresponding to cell_y. If None, cell is selected only by name.

threshold : int or float, default 10 Threshold for detecting outliers. Points deviating from the mean or diagonal by more than this value are annotated.

image_width : int, default 12 Width of the regression plot (in inches).

image_high : int, default 7 Height of the regression plot (in inches).

color : str, default 'black' Color of the regression scatter points and line.

Returns

matplotlib.figure.Figure Regression plot figure with annotated regression line, R², p-value, and outliers.

Raises

ValueError If cell_x or cell_y is not found in the dataset, or if multiple matches exist for a cell name while set_x/set_y are not specified.

Notes

  • The function automatically calls jseq_object.average() if aggregated data is not available.
  • Outliers are annotated with their corresponding index labels.
  • Regression is computed using scipy.stats.linregress.

Examples

>>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
>>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")