jdti.jdti

import math
import os
import pickle
import re

import harmonypy as harmonize
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import umap
from adjustText import adjust_text
from joblib import Parallel, delayed
from matplotlib.patches import FancyArrowPatch, Patch, Polygon
from scipy import sparse, stats
from scipy.io import mmwrite
from scipy.spatial import ConvexHull
from scipy.spatial.distance import pdist, squareform
from scipy.stats import norm, zscore
from sklearn.cluster import DBSCAN, MeanShift
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances, silhouette_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from tqdm import tqdm

from .utils import *


class Clustering:
    """
    A class for performing dimensionality reduction, clustering, and visualization
    on high-dimensional data (e.g., single-cell gene expression).

    The class provides methods for:
    - Normalizing and extracting subsets of data
    - Principal Component Analysis (PCA) and related clustering
    - Uniform Manifold Approximation and Projection (UMAP) and clustering
    - Visualization of PCA and UMAP embeddings
    - Harmonization of batch effects
    - Accessing processed data and cluster labels

    Methods
    -------
    add_data_frame(data, metadata)
        Class method to create a Clustering instance from a DataFrame and metadata.

    harmonize_sets()
        Perform batch effect harmonization on PCA data.

    perform_PCA(pc_num=100, width=8, height=6)
        Perform PCA on the dataset and visualize the first two PCs.

    knee_plot_PCA(width=8, height=6)
        Plot the cumulative variance explained by PCs to determine optimal dimensionality.

    find_clusters_PCA(pc_num=2, eps=0.5, min_samples=10, width=8, height=6, harmonized=False)
        Apply DBSCAN clustering to PCA embeddings and visualize results.

    perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...)
        Compute UMAP embeddings with optional parameter tuning.

    knee_plot_umap(eps=0.5, min_samples=10)
        Determine optimal UMAP dimensionality using silhouette scores.

    find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6)
        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

    UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...)
        Visualize UMAP embeddings with labels and optional cluster numbering.

    UMAP_feature(feature_name, features_data=None, point_size=0.6, ...)
        Plot a single feature over UMAP coordinates with customizable colormap.

    get_umap_data()
        Return the UMAP embeddings along with cluster labels if available.

    get_pca_data()
        Return the PCA results along with cluster labels if available.

    return_clusters(clusters='umap')
        Return the cluster labels for UMAP or PCA embeddings.

    Raises
    ------
    ValueError
        For invalid parameters, mismatched dimensions, or missing metadata.
    """

    def __init__(self, data, metadata):
        """
        Initialize the clustering class with data and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data for clustering. Columns are considered as samples.

        metadata : pandas.DataFrame, optional
            Metadata for the samples. If None, a default DataFrame with column
            names as 'cell_names' is created.

        Attributes
        ----------
        -clustering_data : pandas.DataFrame
        -clustering_metadata : pandas.DataFrame
        -subclusters : None or dict
        -explained_var : None or numpy.ndarray
        -cumulative_var : None or numpy.ndarray
        -pca : None or pandas.DataFrame
        -harmonized_pca : None or pandas.DataFrame
        -umap : None or pandas.DataFrame
        """

        self.clustering_data = data
        """The input data used for clustering."""

        if metadata is None:
            metadata = pd.DataFrame({"cell_names": list(data.columns)})

        self.clustering_metadata = metadata
        """Metadata associated with the samples."""

        self.subclusters = None
        """Placeholder for storing subcluster information."""

        self.explained_var = None
        """Explained variance from PCA, initialized as None."""

        self.cumulative_var = None
        """Cumulative explained variance from PCA, initialized as None."""

        self.pca = None
        """PCA-transformed data, initialized as None."""

        self.harmonized_pca = None
        """PCA data after batch effect harmonization, initialized as None."""

        self.umap = None
        """UMAP embeddings, initialized as None."""

    @classmethod
    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
        """
        Create a Clustering instance from a DataFrame and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data with features as rows and samples/cells as columns.

        metadata : pandas.DataFrame or None
            Optional metadata for the samples.
            Each row corresponds to a sample/cell, in the same order as the
            columns of `data`. Columns can contain additional information such
            as cell type, experimental condition, batch, sets, etc.

        Returns
        -------
        Clustering
            A new instance of the Clustering class.
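
        Examples
        --------
        A minimal sketch, assuming a small features-by-cells table (all names
        below are illustrative):

        >>> import pandas as pd
        >>> df = pd.DataFrame(
        ...     [[1.0, 0.0], [2.0, 3.0]],
        ...     index=["gene_a", "gene_b"],
        ...     columns=["cell_1", "cell_2"],
        ... )
        >>> meta = pd.DataFrame({"cell_names": ["cell_1", "cell_2"]})
        >>> cl = Clustering.add_data_frame(df, meta)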
        """

        return cls(data, metadata)

    def harmonize_sets(self, batch_col: str = "sets"):
        """
        Perform batch effect harmonization on PCA embeddings.

        Parameters
        ----------
        batch_col : str, default 'sets'
            Name of the column in `metadata` that contains batch information for the samples/cells.

        Returns
        -------
        None
            Updates the `harmonized_pca` attribute with harmonized data.
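
        Examples
        --------
        A minimal sketch, assuming PCA was computed beforehand and the metadata
        contains a 'sets' column (object name illustrative):

        >>> cl.perform_PCA(pc_num=30)  # doctest: +SKIP
        >>> cl.harmonize_sets(batch_col="sets")  # doctest: +SKIP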
        """

        data_mat = np.array(self.pca)

        metadata = self.clustering_metadata

        self.harmonized_pca = pd.DataFrame(
            harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr
        ).T

        self.harmonized_pca.columns = self.pca.columns

    def perform_PCA(self, pc_num: int = 100, width=8, height=6):
        """
        Perform Principal Component Analysis (PCA) on the dataset.

        This method standardizes the data, applies PCA, stores results as attributes,
        and generates a scatter plot of the first two principal components.

        Parameters
        ----------
        pc_num : int, default 100
            Number of principal components to compute.
            If 0, computes all available components.

        width : int or float, default 8
            Width of the PCA figure.

        height : int or float, default 6
            Height of the PCA figure.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot showing the first two principal components.

        Updates
        -------
        self.pca : pandas.DataFrame
            DataFrame with principal component scores for each sample.

        self.explained_var : numpy.ndarray
            Percentage of variance explained by each principal component.

        self.cumulative_var : numpy.ndarray
            Cumulative explained variance.
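
        Examples
        --------
        A minimal sketch (component count illustrative):

        >>> fig = cl.perform_PCA(pc_num=10)  # doctest: +SKIP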
        """

        scaler = StandardScaler()
        data_scaled = scaler.fit_transform(self.clustering_data.T)

        # n_components cannot exceed min(n_samples, n_features)
        max_components = min(data_scaled.shape)
        if pc_num == 0 or pc_num > max_components:
            pc_num = max_components

        pca = PCA(n_components=pc_num, random_state=42)

        principal_components = pca.fit_transform(data_scaled)

        pca_df = pd.DataFrame(
            data=principal_components,
            columns=["PC" + str(x + 1) for x in range(pc_num)],
        )

        self.explained_var = pca.explained_variance_ratio_ * 100
        self.cumulative_var = np.cumsum(self.explained_var)

        self.pca = pca_df

        fig = plt.figure(figsize=(width, height))
        plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7)
        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.grid(True)
        plt.show()

        return fig

    def knee_plot_PCA(self, width: int = 8, height: int = 6):
        """
        Plot cumulative explained variance to determine the optimal number of PCs.

        Parameters
        ----------
        width : int, default 8
            Width of the figure.

        height : int, default 6
            Height of the figure.

        Returns
        -------
        matplotlib.figure.Figure
            Line plot showing cumulative variance explained by each PC.
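
        Examples
        --------
        A minimal sketch, run after `perform_PCA`:

        >>> fig = cl.knee_plot_PCA(width=8, height=6)  # doctest: +SKIP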
        """

        fig_knee = plt.figure(figsize=(width, height))
        plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o")
        plt.xlabel("PC (n components)")
        plt.ylabel("Cumulative explained variance (%)")
        plt.grid(True)

        xticks = [1] + list(range(5, len(self.explained_var) + 1, 5))
        plt.xticks(xticks, rotation=60)

        plt.show()

        return fig_knee

    def find_clusters_PCA(
        self,
        pc_num: int = 2,
        eps: float = 0.5,
        min_samples: int = 10,
        width: int = 8,
        height: int = 6,
        harmonized: bool = False,
    ):
        """
        Apply DBSCAN clustering to PCA embeddings and visualize the results.

        This method performs density-based clustering (DBSCAN) on the PCA-reduced
        dataset. Cluster labels are stored in the object's metadata, and a scatter
        plot of the first two principal components with cluster annotations is returned.

        Parameters
        ----------
        pc_num : int, default 2
            Number of principal components to use for clustering.
            If 0, uses all available components.

        eps : float, default 0.5
            Maximum distance between two points for them to be considered
            as neighbors (DBSCAN parameter).

        min_samples : int, default 10
            Minimum number of samples required to form a cluster (DBSCAN parameter).

        width : int, default 8
            Width of the output scatter plot.

        height : int, default 6
            Height of the output scatter plot.

        harmonized : bool, default False
            If True, use harmonized PCA data (`self.harmonized_pca`).
            If False, use standard PCA results (`self.pca`).

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of the first two principal components colored by
            cluster assignments.

        Updates
        -------
        self.clustering_metadata['PCA_clusters'] : list
            Cluster labels assigned to each cell/sample.

        self.input_metadata['PCA_clusters'] : list, optional
            Cluster labels stored in input metadata (if available).
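
        Examples
        --------
        A minimal sketch, run after `perform_PCA` (parameter values illustrative):

        >>> fig = cl.find_clusters_PCA(pc_num=10, eps=0.5, min_samples=10)  # doctest: +SKIP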
        """

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)

        # Select standard or harmonized PCA without shadowing the sklearn PCA class
        pca_source = self.harmonized_pca if harmonized else self.pca

        if pc_num == 0:
            pca_data = pca_source
        else:
            pca_data = pca_source.iloc[:, 0:pc_num]

        dbscan_labels = dbscan.fit_predict(pca_data)

        pca_df = pd.DataFrame(pca_data)
        pca_df["Cluster"] = dbscan_labels

        fig = plt.figure(figsize=(width, height))

        for cluster_id in sorted(pca_df["Cluster"].unique()):
            cluster_data = pca_df[pca_df["Cluster"] == cluster_id]
            plt.scatter(
                cluster_data["PC1"],
                cluster_data["PC2"],
                label=f"Cluster {cluster_id}",
                alpha=0.7,
            )

        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))

        plt.grid(True)
        plt.show()

        self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]

        # input_metadata exists only on subclasses (e.g., COMPsc) and may be None
        try:
            self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
        except Exception:
            pass

        return fig

    def perform_UMAP(
        self,
        factorize: bool = False,
        umap_num: int = 100,
        pc_num: int = 0,
        harmonized: bool = False,
        n_neighbors: int = 5,
        min_dist: float | int = 0.1,
        spread: float | int = 1.0,
        set_op_mix_ratio: float | int = 1.0,
        local_connectivity: int = 1,
        repulsion_strength: float | int = 1.0,
        negative_sample_rate: int = 5,
        width: int = 8,
        height: int = 6,
    ):
        """
        Compute and visualize UMAP embeddings of the dataset.

        This method applies Uniform Manifold Approximation and Projection (UMAP)
        for dimensionality reduction on either raw, PCA, or harmonized PCA data.
        Results are stored as a DataFrame (`self.umap`), and a scatter plot of
        the first two dimensions is displayed.

        Parameters
        ----------
        factorize : bool, default False
            If True, categorical sample labels (from column names) are factorized
            and used as supervision in UMAP fitting.

        umap_num : int, default 100
            Number of UMAP dimensions to compute. If 0, matches the input dimension.

        pc_num : int, default 0
            Number of principal components to use as UMAP input.
            If 0, use all available components or raw data.

        harmonized : bool, default False
            If True, use harmonized PCA embeddings (`self.harmonized_pca`).
            If False, use standard PCA or raw scaled data.

        n_neighbors : int, default 5
            UMAP parameter controlling the size of the local neighborhood.

        min_dist : float, default 0.1
            UMAP parameter controlling minimum allowed distance between embedded points.

        spread : int | float, default 1.0
            Effective scale of embedded space (UMAP parameter).

        set_op_mix_ratio : int | float, default 1.0
            Interpolation parameter between union and intersection in fuzzy sets.

        local_connectivity : int, default 1
            Number of nearest neighbors assumed for each point.

        repulsion_strength : int | float, default 1.0
            Weighting applied to negative samples during optimization.

        negative_sample_rate : int, default 5
            Number of negative samples per positive sample in optimization.

        width : int, default 8
            Width of the output scatter plot.

        height : int, default 6
            Height of the output scatter plot.

        Updates
        -------
        self.umap : pandas.DataFrame
            Table of UMAP embeddings with columns `UMAP1 ... UMAPn`.

        Notes
        -----
        For supervised UMAP (`factorize=True`), categorical codes from column
        names of the dataset are used as labels.
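
        Examples
        --------
        A minimal sketch, run after `perform_PCA` (parameter values illustrative):

        >>> cl.perform_UMAP(umap_num=2, pc_num=30, n_neighbors=15)  # doctest: +SKIP
        >>> cl.umap.head()  # doctest: +SKIP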
        """

        scaler = StandardScaler()

        if pc_num == 0 and harmonized:
            data_scaled = self.harmonized_pca

        elif pc_num == 0:
            data_scaled = scaler.fit_transform(self.clustering_data.T)

        else:
            if harmonized:
                data_scaled = self.harmonized_pca.iloc[:, 0:pc_num]
            else:
                data_scaled = self.pca.iloc[:, 0:pc_num]

        # Clamp the number of UMAP dimensions to the input dimensionality
        if umap_num == 0 or umap_num > data_scaled.shape[1]:
            umap_num = data_scaled.shape[1]

        reducer = umap.UMAP(
            n_components=umap_num,
            random_state=42,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            spread=spread,
            set_op_mix_ratio=set_op_mix_ratio,
            local_connectivity=local_connectivity,
            repulsion_strength=repulsion_strength,
            negative_sample_rate=negative_sample_rate,
            n_jobs=1,
        )

        if factorize:
            embedding = reducer.fit_transform(
                X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes
            )
        else:
            embedding = reducer.fit_transform(X=data_scaled)

        umap_df = pd.DataFrame(
            embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)]
        )

        plt.figure(figsize=(width, height))
        plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.grid(True)

        plt.show()

        self.umap = umap_df

    def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):
        """
        Plot silhouette scores for different UMAP dimensions to determine optimal n_components.

        Parameters
        ----------
        eps : float, default 0.5
            DBSCAN eps parameter for clustering each UMAP dimension.

        min_samples : int, default 10
            Minimum number of samples to form a cluster in DBSCAN.

        Returns
        -------
        matplotlib.figure.Figure
            Silhouette score plot across UMAP dimensions.
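
        Examples
        --------
        A minimal sketch, run after `perform_UMAP` with several dimensions:

        >>> cl.perform_UMAP(umap_num=10)  # doctest: +SKIP
        >>> fig = cl.knee_plot_umap(eps=0.5, min_samples=10)  # doctest: +SKIP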
        """

        umap_range = range(2, len(self.umap.T) + 1)

        silhouette_scores = []
        component = []
        for n in umap_range:
            db = DBSCAN(eps=eps, min_samples=min_samples)
            labels = db.fit_predict(np.array(self.umap)[:, :n])

            # Silhouette is only defined for >= 2 clusters, excluding noise (-1)
            mask = labels != -1
            if len(set(labels[mask])) > 1:
                score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask])
            else:
                score = -1

            silhouette_scores.append(score)
            component.append(n)

        fig = plt.figure(figsize=(10, 5))
        plt.plot(component, silhouette_scores, marker="o")
        plt.xlabel("UMAP (n_components)")
        plt.ylabel("Silhouette Score")
        plt.grid(True)
        plt.xticks(range(int(min(component)), int(max(component)) + 1, 1))

        plt.show()

        return fig

    def find_clusters_UMAP(
        self,
        umap_n: int = 5,
        eps: float | int = 0.5,
        min_samples: int = 10,
        width: int = 8,
        height: int = 6,
    ):
        """
        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

        This method performs density-based clustering (DBSCAN) on the UMAP-reduced
        dataset. Cluster labels are stored in the object's metadata, and a scatter
        plot of the first two UMAP components with cluster annotations is returned.

        Parameters
        ----------
        umap_n : int, default 5
            Number of UMAP dimensions to use for DBSCAN clustering.
            Must be <= number of columns in `self.umap`.

        eps : float | int, default 0.5
            Maximum neighborhood distance between two samples for them to be considered
            as in the same cluster (DBSCAN parameter).

        min_samples : int, default 10
            Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of the first two UMAP components colored by
            cluster assignments.

        Updates
        -------
        self.clustering_metadata['UMAP_clusters'] : list
            Cluster labels assigned to each cell/sample.

        self.input_metadata['UMAP_clusters'] : list, optional
            Cluster labels stored in input metadata (if available).
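
        Examples
        --------
        A minimal sketch, run after `perform_UMAP` (parameter values illustrative):

        >>> fig = cl.find_clusters_UMAP(umap_n=2, eps=0.5, min_samples=10)  # doctest: +SKIP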
        """

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n])

        # Work on a copy so the stored embeddings keep only UMAP columns
        umap_df = self.umap.copy()
        umap_df["Cluster"] = dbscan_labels

        fig = plt.figure(figsize=(width, height))

        for cluster_id in sorted(umap_df["Cluster"].unique()):
            cluster_data = umap_df[umap_df["Cluster"] == cluster_id]
            plt.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"Cluster {cluster_id}",
                alpha=0.7,
            )

        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
        plt.grid(True)

        self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]

        # input_metadata exists only on subclasses (e.g., COMPsc) and may be None
        try:
            self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
        except Exception:
            pass

        return fig

    def UMAP_vis(
        self,
        names_slot: str = "cell_names",
        set_sep: bool = True,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        legend_split_col: int = 2,
        width: int = 8,
        height: int = 6,
        inc_num: bool = True,
    ):
        """
        Visualize UMAP embeddings with sample labels based on a specific metadata slot.

        Parameters
        ----------
        names_slot : str, default 'cell_names'
            Column in metadata to use as sample labels.

        set_sep : bool, default True
            If True, separate points by dataset.

        point_size : float, default 0.6
            Size of scatter points.

        font_size : int, default 6
            Font size for numbers on points.

        legend_split_col : int, default 2
            Number of columns in legend.

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        inc_num : bool, default True
            If True, annotate points with numeric labels.

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot figure.
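
        Examples
        --------
        A minimal sketch, run after `perform_UMAP` (slot name illustrative):

        >>> fig = cl.UMAP_vis(names_slot="cell_names", point_size=1.0)  # doctest: +SKIP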
        """

        umap_df = self.umap.iloc[:, 0:2].copy()
        umap_df.loc[:, "names"] = self.clustering_metadata[names_slot]

        if set_sep:
            if "sets" in self.clustering_metadata.columns:
                umap_df.loc[:, "dataset"] = self.clustering_metadata["sets"]
            else:
                umap_df["dataset"] = "default"

        else:
            umap_df["dataset"] = "default"

        umap_df.loc[:, "tmp_nam"] = umap_df["names"] + umap_df["dataset"]

        umap_df.loc[:, "count"] = umap_df["tmp_nam"].map(
            umap_df["tmp_nam"].value_counts()
        )

        numeric_df = (
            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
            .drop_duplicates()
            .sort_values("count", ascending=False)
        )
        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])

        umap_df = umap_df.merge(
            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
        )

        fig, ax = plt.subplots(figsize=(width, height))

        markers = ["o", "s", "^", "D", "P", "*", "X"]
        marker_map = {
            ds: markers[i % len(markers)]
            for i, ds in enumerate(umap_df["dataset"].unique())
        }

        cord_list = []

        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):
            cluster_data = umap_df[umap_df["numeric_values"] == num]

            ax.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"{num} - {nam}",
                marker=marker_map[cluster_data["dataset"].iloc[0]],
                alpha=0.6,
                s=point_size,
            )

            # Annotate each group at its most central point (minimal summed distance)
            coords = cluster_data[["UMAP1", "UMAP2"]].values

            dists = pairwise_distances(coords)

            sum_dists = dists.sum(axis=1)

            center_idx = np.argmin(sum_dists)
            center_point = coords[center_idx]

            cord_list.append(center_point)

        if inc_num:
            texts = []
            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
                texts.append(
                    ax.text(
                        x,
                        y,
                        str(num),
                        ha="center",
                        va="center",
                        fontsize=font_size,
                        color="black",
                    )
                )

            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.legend(
            title="Clusters",
            loc="center left",
            bbox_to_anchor=(1.05, 0.5),
            ncol=legend_split_col,
            markerscale=5,
        )

        ax.grid(True)

        plt.tight_layout()

        return fig

    def UMAP_feature(
        self,
        feature_name: str,
        features_data: pd.DataFrame | None = None,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        width: int = 8,
        height: int = 6,
        palette="light",
    ):
        """
        Visualize UMAP embedding with expression levels of a selected feature.

        Each point (cell) in the UMAP plot is colored according to the expression
        value of the chosen feature, enabling interpretation of spatial patterns
        of gene activity or metadata distribution in low-dimensional space.

        Parameters
        ----------
        feature_name : str
           Name of the feature to plot.

        features_data : pandas.DataFrame or None, default None
            If None, the function uses the DataFrame containing the clustering data.
            To plot features not used in clustering, provide a wider DataFrame
            containing the original feature values.

        point_size : float, default 0.6
            Size of scatter points in the plot.

        font_size : int, default 6
            Font size for axis labels and annotations.

        width : int, default 8
            Width of the matplotlib figure.

        height : int, default 6
            Height of the matplotlib figure.

        palette : str, default 'light'
            Color palette for expression visualization. Options are:
            - 'light'
            - 'dark'
            - 'green'
            - 'gray'

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot colored by feature values.
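
        Examples
        --------
        A minimal sketch (feature name illustrative):

        >>> fig = cl.UMAP_feature("GAPDH", palette="light")  # doctest: +SKIP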
        """

        umap_df = self.umap.iloc[:, 0:2].copy()

        if features_data is None:
            features_data = self.clustering_data

        if features_data.shape[1] != umap_df.shape[0]:
            raise ValueError(
                "Provided 'features_data' shape does not match the number of UMAP cells"
            )

        blist = [x.upper() == feature_name.upper() for x in features_data.index]

        if not any(blist):
            raise ValueError("Provided 'feature_name' is not included in the data")

        umap_df.loc[:, "feature"] = (
            features_data.loc[blist, :]
            .apply(lambda row: row.tolist(), axis=1)
            .values[0]
        )

        umap_df = umap_df.sort_values("feature", ascending=True)

        if palette == "light":
            palette = px.colors.sequential.Sunsetdark

        elif palette == "dark":
            palette = px.colors.sequential.thermal

        elif palette == "green":
            palette = px.colors.sequential.Aggrnyl

        elif palette == "gray":
            palette = px.colors.sequential.gray
            palette = palette[::-1]

        else:
            raise ValueError(
                'Palette not found. Use: "light", "dark", "gray", or "green"'
            )

        # Convert plotly 'rgb(r, g, b)' strings to matplotlib 0-1 RGB tuples
        converted = []
        for c in palette:
            rgb_255 = px.colors.unlabel_rgb(c)
            rgb_01 = tuple(v / 255.0 for v in rgb_255)
            converted.append(rgb_01)

        my_cmap = mcolors.ListedColormap(converted, name="custom")

        fig, ax = plt.subplots(figsize=(width, height))
        sc = ax.scatter(
            umap_df["UMAP1"],
            umap_df["UMAP2"],
            c=umap_df["feature"],
            s=point_size,
            cmap=my_cmap,
            alpha=1.0,
            edgecolors="black",
            linewidths=0.1,
        )

        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label(f"{feature_name}")

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.grid(True)

        plt.tight_layout()

        return fig

    def get_umap_data(self):
        """
        Retrieve UMAP embedding data with optional cluster labels.

        Returns the UMAP coordinates stored in `self.umap`. If clustering
        metadata is available (specifically `UMAP_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
            If available, includes an extra column 'clusters' with cluster labels.

        Notes
        -----
        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
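
        Examples
        --------
        A minimal sketch, run after `perform_UMAP` and `find_clusters_UMAP`:

        >>> coords = cl.get_umap_data()  # doctest: +SKIP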
        """

        # Copy so the stored embeddings are not mutated by the added column
        umap_data = self.umap.copy()

        try:
            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
        except Exception:
            pass

        return umap_data

    def get_pca_data(self):
        """
        Retrieve PCA embedding data with optional cluster labels.

        Returns the principal component scores stored in `self.pca`. If clustering
        metadata is available (specifically `PCA_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
            If available, includes an extra column 'clusters' with cluster labels.

        Notes
        -----
        - PCA must be computed beforehand (e.g., using `perform_PCA`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        # Copy so the stored scores are not mutated by the added column
        pca_data = self.pca.copy()

        try:
            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
        except Exception:
            pass

        return pca_data

    def return_clusters(self, clusters="umap"):
        """
        Retrieve cluster labels from UMAP or PCA clustering results.

        Parameters
        ----------
        clusters : str, default 'umap'
            Source of cluster labels to return. Must be one of:
            - 'umap': return cluster labels from UMAP embeddings.
            - 'pca' : return cluster labels from PCA embeddings.

        Returns
        -------
        pandas.Series
            Cluster labels corresponding to the selected embedding method.

        Raises
        ------
        ValueError
            If `clusters` is not 'umap' or 'pca'.

        Notes
        -----
        Requires that clustering has already been performed
        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
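
        Examples
        --------
        >>> labels = cl.return_clusters(clusters="umap")  # doctest: +SKIP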
        """

        if clusters.lower() == "umap":
            clusters_vector = self.clustering_metadata["UMAP_clusters"]
        elif clusters.lower() == "pca":
            clusters_vector = self.clustering_metadata["PCA_clusters"]
        else:
            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")

        return clusters_vector


class COMPsc(Clustering):
    """
    A class `COMPsc` (Comparison of single-cell data) designed for the integration,
    analysis, and visualization of single-cell datasets.
    The class supports independent dataset integration, subclustering of existing clusters,
    marker detection, and multiple visualization strategies.

    The COMPsc class provides methods for:

        - Normalizing and filtering single-cell data
        - Loading and saving sparse 10x-style datasets
        - Computing differential expression and marker genes
        - Clustering and subclustering analysis
        - Visualizing similarity and spatial relationships
        - Aggregating data by cell and set annotations
        - Managing metadata and renaming labels
        - Plotting gene detection histograms and feature scatters

    Methods
    -------
    project_dir(path_to_directory, project_list)
        Scans a directory to create a COMPsc instance mapping project names to their paths.

    save_project(name, path=os.getcwd())
        Saves the COMPsc object to a pickle file on disk.

    load_project(path)
        Loads a previously saved COMPsc object from a pickle file.

    reduce_cols(reg=None, full=None, name_slot='cell_names', inc_set=False)
        Removes columns from data tables where column names match a full name or contain a substring.

    reduce_rows(reg, inc_set=False)
        Removes rows from data tables where row names contain a specified feature (gene) name.

    get_data(set_info=False)
        Returns normalized data with optional set annotations in column names.

    get_metadata()
        Returns the stored input metadata.

    get_partial_data(names=None, features=None, name_slot='cell_names')
        Return a subset of the data by sample names and/or features.

    gene_calculation()
        Calculates and stores per-cell gene detection counts as a pandas Series.

    gene_histograme(bins=100)
        Plots a histogram of genes detected per cell with an overlaid normal distribution.

    gene_threshold(min_n=None, max_n=None)
        Filters cells based on minimum and/or maximum gene detection thresholds.

    load_sparse_from_projects(normalized_data=False)
        Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.

    rename_names(mapping, slot='cell_names')
        Renames entries in a specified metadata column using a provided mapping dictionary.

    rename_subclusters(mapping)
        Renames subcluster labels using a provided mapping dictionary.

    save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized')
        Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

    normalize_data(normalize=True, normalize_factor=100000)
        Normalizes raw counts to counts-per-specified factor (e.g., CPM-like).

    statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10)
        Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.

    calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False)
        Computes and caches differential markers using the statistic method.

    clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4)
        Prepares clustering input by selecting marker features and optionally smoothing cell values.

    average()
        Aggregates normalized data by averaging across (cell_name, set) pairs.

    estimating_similarity(method='pearson', p_val=0.05, top_n=25)
        Computes pairwise correlation and Euclidean distance between aggregated samples.

    similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10)
        Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.

    spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...)
        Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.

    subcluster_prepare(features, cluster)
        Initializes a Clustering object for subcluster analysis on a selected parent cluster.

    define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...)
        Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels.

    subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...)
        Visualizes averaged expression and occurrence of features for subclusters as a scatter plot.

    subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...)
        Plots top differential features for subclusters as a features-scatter visualization.

    accept_subclusters()
        Commits subcluster labels to main metadata by renaming cell names and clears subcluster data.

    Raises
    ------
    ValueError
        For invalid parameters, mismatched dimensions, or missing metadata.

    """

    def __init__(
        self,
        objects=None,
    ):
        """
        Initialize the COMPsc class for single-cell data integration and analysis.

        Parameters
        ----------
        objects : dict or None, optional
            Optional mapping of project names to their paths (as built by
            `project_dir`) used to initialize the instance.

        Attributes
        ----------
        -objects : dict or None
        -input_data : pandas.DataFrame or None
        -input_metadata : pandas.DataFrame or None
        -normalized_data : pandas.DataFrame or None
        -agg_metadata : pandas.DataFrame or None
        -agg_normalized_data : pandas.DataFrame or None
        -similarity : pandas.DataFrame or None
        -var_data : pandas.DataFrame or None
        -subclusters_ : instance of Clustering class or None
        -cells_calc : pandas.Series or None
        -gene_calc : pandas.Series or None
        -composition_data : pandas.DataFrame or None
        """

        self.objects = objects
        """Stores the input data objects."""

        self.input_data = None
        """Raw input data for clustering or integration analysis."""

        self.input_metadata = None
        """Metadata associated with the input data."""

        self.normalized_data = None
        """Normalized version of the input data."""

        self.agg_metadata = None
        """Aggregated metadata for all sets in the object, related to `agg_normalized_data`."""

        self.agg_normalized_data = None
        """Aggregated and normalized data across multiple sets."""

        self.similarity = None
        """Similarity data between cells across all samples and sets."""

        self.var_data = None
        """DEG analysis results summarizing variance across all samples in the object."""

        self.subclusters_ = None
        """Placeholder for subcluster analysis results, if computed."""

        self.cells_calc = None
        """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition."""

        self.gene_calc = None
        """Number of genes detected per sample (cell), reflecting the sequencing depth."""

        self.composition_data = None
        """Data describing composition of cells across clusters or sets."""

    @classmethod
    def project_dir(cls, path_to_directory, project_list):
        """
        Scan a directory and build a COMPsc instance mapping provided project names
        to their paths.

        Parameters
        ----------
        path_to_directory : str
            Path containing project subfolders.

        project_list : list[str]
            List of filenames (folder names) to include in the returned object map.

        Returns
        -------
        COMPsc
            New COMPsc instance with `objects` populated.

        Raises
        ------
        Exception
            A generic exception is caught and a message printed if scanning fails.

        Notes
        -----
        Function attempts to match entries in `project_list` to directory
        names and constructs a simplified object key from the folder name.
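
        Examples
        --------
        A minimal sketch, assuming the directory contains the listed project
        folders (paths and names illustrative):

        >>> proj = COMPsc.project_dir(
        ...     "/data/experiments",
        ...     ["exp1_sample1_fq", "exp1_sample2_fq"],
        ... )  # doctest: +SKIP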
        """
        try:
            objects = {}
            for filename in tqdm(os.listdir(path_to_directory)):
                for c in project_list:
                    f = os.path.join(path_to_directory, filename)
                    if c == filename and os.path.isdir(f):
                        # Build a simplified key: strip the '<prefix>_' part,
                        # then the '_fq' and '_count' suffixes
                        key = re.sub(re.sub("_.*", "", filename) + "_", "", filename)
                        key = re.sub("_fq", "", key)
                        key = re.sub("_count", "", key)
                        objects[str(key)] = f

            return cls(objects)

        except Exception:
            print("Something went wrong. Check the function input data and try again!")

    def save_project(self, name, path: str = os.getcwd()):
        """
        Save the COMPsc object to disk using pickle.

        Parameters
        ----------
        name : str
            Base filename (without extension) to use when saving.

        path : str, default os.getcwd()
            Directory in which to save the project file.

        Returns
        -------
        None

        Side Effects
        ------------
        - Writes a file `<path>/<name>.jpkl` containing the pickled object.
        - Prints a confirmation message with saved path.
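
        Examples
        --------
        A minimal sketch (path illustrative):

        >>> proj.save_project("my_analysis", path="/tmp")  # doctest: +SKIP
        Project saved as /tmp/my_analysis.jpkl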
        """

        full = os.path.join(path, f"{name}.jpkl")

        with open(full, "wb") as f:
            pickle.dump(self, f)

        print(f"Project saved as {full}")

    @classmethod
    def load_project(cls, path):
        """
        Load a previously saved COMPsc project from a pickle file.

        Parameters
        ----------
        path : str
            Full path to the pickled project file.

        Returns
        -------
        COMPsc
            The unpickled COMPsc object.

        Raises
        ------
        FileNotFoundError
            If the provided path does not exist.
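
        Examples
        --------
        A minimal sketch (path illustrative):

        >>> proj = COMPsc.load_project("/tmp/my_analysis.jpkl")  # doctest: +SKIP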
        """

        if not os.path.exists(path):
            raise FileNotFoundError("File does not exist!")
        with open(path, "rb") as f:
            obj = pickle.load(f)
        return obj

    def reduce_cols(
        self,
        reg: str | None = None,
        full: str | None = None,
        name_slot: str = "cell_names",
        inc_set: bool = False,
    ):
        """
        Remove columns (cells) whose names contain a substring `reg` or
        match a full name `full` from available tables.

        Parameters
        ----------
        reg : str | None
            Substring to search for in column/cell names; matching columns will be removed.
            If not None, `full` must be None.

        full : str | None
            Full name to search for in column/cell names; matching columns will be removed.
            If not None, `reg` must be None.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        inc_set : bool, default False
            If True, column names are interpreted as 'cell_name # set' when matching.

        Updates
        -------
        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
        removing columns/rows that match `reg` or `full`.

        Raises
        ------
        ValueError
            If nothing matches the reduction mask.
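
        Examples
        --------
        A minimal sketch (names illustrative):

        >>> proj.reduce_cols(reg="doublet")  # doctest: +SKIP
        >>> proj.reduce_cols(full="cell_1 # exp1", inc_set=True)  # doctest: +SKIP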
1359        """
1360
1361        if reg is None and full is None:
1362            raise ValueError(
1363                "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!"
1364            )
1365
1366        if reg is not None and full is not None:
1367            raise ValueError(
1368                "Both 'reg' and 'full' arguments are provided. "
1369                "Please provide only one of them!\n"
1370                "'reg' is used when only part of the name must be detected.\n"
1371                "'full' is used if the full name must be detected."
1372            )
1373
1374        if reg is not None:
1375
1376            if self.input_data is not None:
1377
1378                if inc_set:
1379
1380                    self.input_data.columns = (
1381                        self.input_metadata[name_slot]
1382                        + " # "
1383                        + self.input_metadata["sets"]
1384                    )
1385
1386                else:
1387
1388                    self.input_data.columns = self.input_metadata[name_slot]
1389
1390                mask = [reg.upper() not in x.upper() for x in self.input_data.columns]
1391
1392                if len([y for y in mask if y is False]) == 0:
1393                    raise ValueError("Nothing found to reduce")
1394
1395                self.input_data = self.input_data.loc[:, mask]
1396
1397            if self.normalized_data is not None:
1398
1399                if inc_set:
1400
1401                    self.normalized_data.columns = (
1402                        self.input_metadata[name_slot]
1403                        + " # "
1404                        + self.input_metadata["sets"]
1405                    )
1406
1407                else:
1408
1409                    self.normalized_data.columns = self.input_metadata[name_slot]
1410
1411                mask = [
1412                    reg.upper() not in x.upper() for x in self.normalized_data.columns
1413                ]
1414
1415                if len([y for y in mask if y is False]) == 0:
1416                    raise ValueError("Nothing found to reduce")
1417
1418                self.normalized_data = self.normalized_data.loc[:, mask]
1419
1420            if self.input_metadata is not None:
1421
1422                if inc_set:
1423
1424                    self.input_metadata["drop"] = (
1425                        self.input_metadata[name_slot]
1426                        + " # "
1427                        + self.input_metadata["sets"]
1428                    )
1429
1430                else:
1431
1432                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1433
1434                mask = [
1435                    reg.upper() not in x.upper() for x in self.input_metadata["drop"]
1436                ]
1437
1438                if len([y for y in mask if y is False]) == 0:
1439                    raise ValueError("Nothing found to reduce")
1440
1441                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1442                    drop=True
1443                )
1444
1445                self.input_metadata = self.input_metadata.drop(
1446                    columns=["drop"], errors="ignore"
1447                )
1448
1449            if self.agg_normalized_data is not None:
1450
1451                if inc_set:
1452
1453                    self.agg_normalized_data.columns = (
1454                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1455                    )
1456
1457                else:
1458
1459                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1460
1461                mask = [
1462                    reg.upper() not in x.upper()
1463                    for x in self.agg_normalized_data.columns
1464                ]
1465
1466                if len([y for y in mask if y is False]) == 0:
1467                    raise ValueError("Nothing found to reduce")
1468
1469                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1470
1471            if self.agg_metadata is not None:
1472
1473                if inc_set:
1474
1475                    self.agg_metadata["drop"] = (
1476                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1477                    )
1478
1479                else:
1480
1481                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1482
1483                mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]]
1484
1485                if len([y for y in mask if y is False]) == 0:
1486                    raise ValueError("Nothing found to reduce")
1487
1488                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1489                    drop=True
1490                )
1491
1492                self.agg_metadata = self.agg_metadata.drop(
1493                    columns=["drop"], errors="ignore"
1494                )
1495
1496        elif full is not None:
1497
1498            if self.input_data is not None:
1499
1500                if inc_set:
1501
1502                    self.input_data.columns = (
1503                        self.input_metadata[name_slot]
1504                        + " # "
1505                        + self.input_metadata["sets"]
1506                    )
1507
1508                    if "#" not in full:
1509
1510                        self.input_data.columns = self.input_metadata[name_slot]
1511
1512                        print(
1513                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1514                            "Only the names will be compared, without considering the set information."
1515                        )
1516
1517                else:
1518
1519                    self.input_data.columns = self.input_metadata[name_slot]
1520
1521                mask = [full.upper() != x.upper() for x in self.input_data.columns]
1522
1523                if len([y for y in mask if y is False]) == 0:
1524                    raise ValueError("Nothing found to reduce")
1525
1526                self.input_data = self.input_data.loc[:, mask]
1527
1528            if self.normalized_data is not None:
1529
1530                if inc_set:
1531
1532                    self.normalized_data.columns = (
1533                        self.input_metadata[name_slot]
1534                        + " # "
1535                        + self.input_metadata["sets"]
1536                    )
1537
1538                    if "#" not in full:
1539
1540                        self.normalized_data.columns = self.input_metadata[name_slot]
1541
1542                        print(
1543                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
1544                            "Only the names will be compared, without considering the set information."
1545                        )
1546
1547                else:
1548
1549                    self.normalized_data.columns = self.input_metadata[name_slot]
1550
1551                mask = [full.upper() != x.upper() for x in self.normalized_data.columns]
1552
1553                if len([y for y in mask if y is False]) == 0:
1554                    raise ValueError("Nothing found to reduce")
1555
1556                self.normalized_data = self.normalized_data.loc[:, mask]
1557
1558            if self.input_metadata is not None:
1559
1560                if inc_set:
1561
1562                    self.input_metadata["drop"] = (
1563                        self.input_metadata[name_slot]
1564                        + " # "
1565                        + self.input_metadata["sets"]
1566                    )
1567
1568                    if "#" not in full:
1569
1570                        self.input_metadata["drop"] = self.input_metadata[name_slot]
1571
                        print(
                            "The 'full' argument does not include the set info (name # set), while 'inc_set' is True. "
                            "Only the names will be compared, without considering the set information."
                        )
1576
1577                else:
1578
1579                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1580
1581                mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]]
1582
1583                if len([y for y in mask if y is False]) == 0:
1584                    raise ValueError("Nothing found to reduce")
1585
1586                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1587                    drop=True
1588                )
1589
1590                self.input_metadata = self.input_metadata.drop(
1591                    columns=["drop"], errors="ignore"
1592                )
1593
1594            if self.agg_normalized_data is not None:
1595
1596                if inc_set:
1597
1598                    self.agg_normalized_data.columns = (
1599                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1600                    )
1601
1602                    if "#" not in full:
1603
1604                        self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1605
                        print(
                            "The 'full' argument does not include the set info (name # set), while 'inc_set' is True. "
                            "Only the names will be compared, without considering the set information."
                        )
1610                else:
1611
1612                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1613
1614                mask = [
1615                    full.upper() != x.upper() for x in self.agg_normalized_data.columns
1616                ]
1617
1618                if len([y for y in mask if y is False]) == 0:
1619                    raise ValueError("Nothing found to reduce")
1620
1621                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1622
1623            if self.agg_metadata is not None:
1624
1625                if inc_set:
1626
1627                    self.agg_metadata["drop"] = (
1628                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1629                    )
1630
1631                    if "#" not in full:
1632
1633                        self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1634
                        print(
                            "The 'full' argument does not include the set info (name # set), while 'inc_set' is True. "
                            "Only the names will be compared, without considering the set information."
                        )
1639                else:
1640
1641                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1642
1643                mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]]
1644
1645                if len([y for y in mask if y is False]) == 0:
1646                    raise ValueError("Nothing found to reduce")
1647
1648                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1649                    drop=True
1650                )
1651
1652                self.agg_metadata = self.agg_metadata.drop(
1653                    columns=["drop"], errors="ignore"
1654                )
1655
1656        self.gene_calculation()
1657        self.cells_calculation()
1658
1659    def reduce_rows(self, features_list: list):
1660        """
1661        Remove rows (features) whose names are included in features_list.
1662
1663        Parameters
1664        ----------
1665        features_list : list
1666            List of features to search for in index/gene names; matching entries will be removed.
1667
        Update
        ------
        Mutates `self.input_data`, `self.normalized_data`, and
        `self.agg_normalized_data` (if they exist), removing rows whose
        feature names match `features_list`.

        Raises
        ------
        ValueError
            If nothing matching `features_list` is found to remove.

        Notes
        -----
        Prints a message listing features that were not found in the data.
        """
1678
        # ensure `res` is defined even when no data slots are populated
        res = {"included": [], "not_included": list(features_list)}

        if self.input_data is not None:
1680
1681            res = find_features(self.input_data, features=features_list)
1682
1683            res_list = [x.upper() for x in res["included"]]
1684
1685            mask = [x.upper() not in res_list for x in self.input_data.index]
1686
1687            if len([y for y in mask if y is False]) == 0:
1688                raise ValueError("Nothing found to reduce")
1689
1690            self.input_data = self.input_data.loc[mask, :]
1691
1692        if self.normalized_data is not None:
1693
1694            res = find_features(self.normalized_data, features=features_list)
1695
1696            res_list = [x.upper() for x in res["included"]]
1697
1698            mask = [x.upper() not in res_list for x in self.normalized_data.index]
1699
1700            if len([y for y in mask if y is False]) == 0:
1701                raise ValueError("Nothing found to reduce")
1702
1703            self.normalized_data = self.normalized_data.loc[mask, :]
1704
1705        if self.agg_normalized_data is not None:
1706
1707            res = find_features(self.agg_normalized_data, features=features_list)
1708
1709            res_list = [x.upper() for x in res["included"]]
1710
1711            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]
1712
1713            if len([y for y in mask if y is False]) == 0:
1714                raise ValueError("Nothing found to reduce")
1715
1716            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]
1717
1718        if len(res["not_included"]) > 0:
1719            print("\nFeatures not found:")
1720            for i in res["not_included"]:
1721                print(i)
1722
1723        self.gene_calculation()
1724        self.cells_calculation()
1725
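    # Usage sketch (hypothetical): dropping unwanted features from a populated
    # instance `cl`; the gene names below are placeholders, not part of the API.
    #
    #     cl.reduce_rows(["MT-CO1", "MT-ND1", "MT-CYB"])
    #     # features that are absent are reported via the printed
    #     # "Features not found:" listing
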
1726    def get_data(self, set_info: bool = False):
1727        """
1728        Return normalized data with optional set annotation appended to column names.
1729
1730        Parameters
1731        ----------
1732        set_info : bool, default False
1733            If True, column names are returned as "cell_name # set"; otherwise
1734            only the `cell_name` is used.
1735
1736        Returns
1737        -------
1738        pandas.DataFrame
1739            The `self.normalized_data` table with columns renamed according to `set_info`.
1740
1741        Raises
1742        ------
1743        AttributeError
1744            If `self.normalized_data` or `self.input_metadata` is missing.
1745        """
1746
        to_return = self.normalized_data.copy()
1748
1749        if set_info:
1750            to_return.columns = (
1751                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
1752            )
1753        else:
1754            to_return.columns = self.input_metadata["cell_names"]
1755
1756        return to_return
1757
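    # Usage sketch (hypothetical instance `cl`): retrieving the normalized
    # matrix with "name # set" column labels.
    #
    #     df = cl.get_data(set_info=True)
    #     df.columns[:3]  # e.g. "B-cell # sample1", ... (placeholder labels)
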
1758    def get_partial_data(
1759        self,
1760        names: list | str | None = None,
1761        features: list | str | None = None,
1762        name_slot: str = "cell_names",
1763        inc_metadata: bool = False,
1764    ):
1765        """
1766        Return a subset of the data filtered by sample names and/or feature names.
1767
1768        Parameters
1769        ----------
1770        names : list, str, or None
1771            Names of samples to include. If None, all samples are considered.
1772
1773        features : list, str, or None
1774            Names of features to include. If None, all features are considered.
1775
1776        name_slot : str
1777            Column in metadata to use as sample names.
1778
1779        inc_metadata : bool
1780            If True return tuple (data, metadata)
1781
1782        Returns
1783        -------
1784        pandas.DataFrame
1785            Subset of the normalized data based on the specified names and features.
1786        """
1787
1788        data = self.normalized_data.copy()
1789        metadata = self.input_metadata
1790
1791        if name_slot in self.input_metadata.columns:
1792            data.columns = self.input_metadata[name_slot]
1793        else:
            raise ValueError(f"'{name_slot}' column not found in metadata!")
1795
1796        if isinstance(features, str):
1797            features = [features]
1798        elif features is None:
1799            features = []
1800
1801        if isinstance(names, str):
1802            names = [names]
1803        elif names is None:
1804            names = []
1805
1806        features = [x.upper() for x in features]
1807        names = [x.upper() for x in names]
1808
1809        columns_names = [x.upper() for x in data.columns]
1810        features_names = [x.upper() for x in data.index]
1811
        columns_bool = [x in names for x in columns_names]
        features_bool = [x in features for x in features_names]
1814
        if True not in columns_bool and True not in features_bool:
            print("Missing 'names' and/or 'features'. Returning full dataset instead.")
            if inc_metadata:
                return data, metadata
            else:
                return data

        # apply both filters so that 'names' and 'features' can be combined
        if True in columns_bool:
            data = data.loc[:, columns_bool]
            metadata = metadata.loc[columns_bool, :]

        if True in features_bool:
            data = data.loc[features_bool, :]

        if inc_metadata:
            return data, metadata
        else:
            return data
1838
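    # Usage sketch (hypothetical instance `cl`; the sample and feature names
    # are placeholders):
    #
    #     sub = cl.get_partial_data(names=["B-cell"], features=["CD19", "MS4A1"])
    #     sub, meta = cl.get_partial_data(names=["B-cell"], inc_metadata=True)
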
1839    def get_metadata(self):
1840        """
1841        Return the stored input metadata.
1842
1843        Returns
1844        -------
1845        pandas.DataFrame
1846            `self.input_metadata` (may be None if not set).
1847        """
1848
1849        to_return = self.input_metadata
1850
1851        return to_return
1852
1853    def gene_calculation(self):
1854        """
1855        Calculate and store per-cell counts (e.g., number of detected genes).
1856
        The method binarizes the expression matrix (presence/absence) and sums
        across features to produce `self.gene_calc`.
1859
1860        Update
1861        ------
1862        Sets `self.gene_calc` as a pandas.Series.
1863
1864        Side Effects
1865        ------------
1866        Uses `self.input_data` when available, otherwise `self.normalized_data`.
1867        """
1868
1869        if self.input_data is not None:
1870
                bin_col = self.input_data.copy()
1872
1873            bin_col = bin_col.where(bin_col <= 0, 1)
1874
1875            sum_data = bin_col.sum(axis=0)
1876
1877            self.gene_calc = sum_data
1878
1879        elif self.normalized_data is not None:
1880
1881            bin_col = self.normalized_data.copy()
1882
1883            bin_col = bin_col.where(bin_col <= 0, 1)
1884
1885            sum_data = bin_col.sum(axis=0)
1886
1887            self.gene_calc = sum_data
1888
1889    def gene_histograme(self, bins=100):
1890        """
1891        Plot a histogram of the number of genes detected per cell.
1892
1893        Parameters
1894        ----------
1895        bins : int, default 100
1896            Number of histogram bins.
1897
1898        Returns
1899        -------
1900        matplotlib.figure.Figure
1901            Figure containing the histogram of gene contents.
1902
1903        Notes
1904        -----
1905        Requires `self.gene_calc` to be computed prior to calling.
1906        """
1907
1908        fig, ax = plt.subplots(figsize=(8, 5))
1909
1910        _, bin_edges, _ = ax.hist(
1911            self.gene_calc, bins=bins, edgecolor="black", alpha=0.6
1912        )
1913
1914        mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc)
1915
1916        x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000)
1917        y = norm.pdf(x, mu, sigma)
1918
1919        y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0])
1920
1921        ax.plot(
1922            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
1923        )
1924
        ax.set_xlabel("Genes detected per cell")
        ax.set_ylabel("Number of cells")
1927        ax.set_title("Histogram of genes detected per cell")
1928
1929        ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20))
1930        ax.tick_params(axis="x", rotation=90)
1931
1932        ax.legend()
1933
1934        return fig
1935
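    # Usage sketch (hypothetical instance `cl`): inspect the per-cell gene
    # counts before choosing QC thresholds; the file name is a placeholder.
    #
    #     fig = cl.gene_histograme(bins=100)
    #     fig.savefig("genes_per_cell_hist.png", bbox_inches="tight")
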
1936    def gene_threshold(self, min_n: int | None, max_n: int | None):
1937        """
1938        Filter cells by gene-detection thresholds (min and/or max).
1939
1940        Parameters
1941        ----------
1942        min_n : int or None
1943            Minimum number of detected genes required to keep a cell.
1944
1945        max_n : int or None
1946            Maximum number of detected genes allowed to keep a cell.
1947
1948        Update
1949        -------
1950        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
1951        (and calls `average()` if `self.agg_normalized_data` exists).
1952
        Raises
        ------
        ValueError
            If both bounds are None, if no cells fall outside the bounds
            (nothing to filter), or if filtering would remove all cells.
1956        """
1957
1958        if min_n is not None and max_n is not None:
1959            mask = (self.gene_calc > min_n) & (self.gene_calc < max_n)
1960        elif min_n is None and max_n is not None:
1961            mask = self.gene_calc < max_n
1962        elif min_n is not None and max_n is None:
1963            mask = self.gene_calc > min_n
1964        else:
1965            raise ValueError("Lack of both min_n and max_n values")
1966
        # hoisted checks: the mask is identical for every data slot
        if mask.all():
            raise ValueError("Nothing to reduce")

        if not mask.any():
            raise ValueError("Filtering would remove all cells")

        if self.input_data is not None:
            self.input_data = self.input_data.loc[:, mask.values]

        if self.normalized_data is not None:
            self.normalized_data = self.normalized_data.loc[:, mask.values]

        if self.input_metadata is not None:
            self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index(
                drop=True
            )

            self.input_metadata = self.input_metadata.drop(
                columns=["drop"], errors="ignore"
            )
1993
1994        if self.agg_normalized_data is not None:
1995            self.average()
1996
1997        self.gene_calculation()
1998        self.cells_calculation()
1999
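    # Usage sketch (hypothetical instance `cl`; the cutoffs are placeholders
    # chosen from the histogram above):
    #
    #     cl.gene_threshold(min_n=200, max_n=6000)   # keep 200 < genes < 6000
    #     cl.gene_threshold(min_n=200, max_n=None)   # lower bound only
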
2000    def cells_calculation(self, name_slot="cell_names"):
2001        """
        Calculate the number of cells per cell name / cluster.

        The method counts the occurrences of each cell name / cluster in the
        metadata.
2006
2007        Parameters
2008        ----------
2009        name_slot : str, default 'cell_names'
2010            Column in metadata to use as sample names.
2011
2012        Update
2013        ------
2014        Sets `self.cells_calc` as a pd.DataFrame.
2015        """
2016
2017        ls = list(self.input_metadata[name_slot])
2018
2019        df = pd.DataFrame(
2020            {
2021                "cluster": pd.Series(ls).value_counts().index,
2022                "n": pd.Series(ls).value_counts().values,
2023            }
2024        )
2025
2026        self.cells_calc = df
2027
2028    def cell_histograme(self, name_slot: str = "cell_names"):
2029        """
2030        Plot a histogram of the number of cells detected per cell name (cluster).
2031
2032        Parameters
2033        ----------
2034        name_slot : str, default 'cell_names'
2035            Column in metadata to use as sample names.
2036
2037        Returns
2038        -------
2039        matplotlib.figure.Figure
2040            Figure containing the histogram of cell contents.
2041
2042        Notes
2043        -----
2044        Requires `self.cells_calc` to be computed prior to calling.
2045        """
2046
2047        if name_slot != "cell_names":
2048            self.cells_calculation(name_slot=name_slot)
2049
2050        fig, ax = plt.subplots(figsize=(8, 5))
2051
2052        _, bin_edges, _ = ax.hist(
2053            list(self.cells_calc["n"]),
2054            bins=len(set(self.cells_calc["cluster"])),
2055            edgecolor="black",
2056            color="orange",
2057            alpha=0.6,
2058        )
2059
2060        mu, sigma = np.mean(list(self.cells_calc["n"])), np.std(
2061            list(self.cells_calc["n"])
2062        )
2063
2064        x = np.linspace(
2065            min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000
2066        )
2067        y = norm.pdf(x, mu, sigma)
2068
2069        y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0])
2070
2071        ax.plot(
2072            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
2073        )
2074
        ax.set_xlabel("Cells per cell name / cluster")
        ax.set_ylabel("Number of clusters")
2077        ax.set_title("Histogram of cells detected per cell name / cluster")
2078
2079        ax.set_xticks(
2080            np.linspace(
2081                min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20
2082            )
2083        )
2084        ax.tick_params(axis="x", rotation=90)
2085
2086        ax.legend()
2087
2088        return fig
2089
2090    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
2091        """
2092        Filter cell names / clusters by cell-detection threshold.
2093
2094        Parameters
2095        ----------
        min_n : int or None
            Minimum number of cells required to keep a cell name / cluster.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

2103        Update
2104        -------
2105        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
2106        (and calls `average()` if `self.agg_normalized_data` exists).
2107        """
2108
2109        if name_slot != "cell_names":
2110            self.cells_calculation(name_slot=name_slot)
2111
2112        if min_n is not None:
2113            names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]
2114        else:
2115            raise ValueError("Lack of min_n value")
2116
        if len(names) > 0:

            # exact-match lookup; substring matching could silently drop
            # unrelated clusters whose names contain a removed name
            names = set(names)
2118
2119            if self.input_data is not None:
2120
2121                self.input_data.columns = self.input_metadata[name_slot]
2122
                mask = [x not in names for x in self.input_data.columns]
2124
2125                if len([y for y in mask if y is False]) > 0:
2126
2127                    self.input_data = self.input_data.loc[:, mask]
2128
2129            if self.normalized_data is not None:
2130
2131                self.normalized_data.columns = self.input_metadata[name_slot]
2132
                mask = [x not in names for x in self.normalized_data.columns]
2136
2137                if len([y for y in mask if y is False]) > 0:
2138
2139                    self.normalized_data = self.normalized_data.loc[:, mask]
2140
2141            if self.input_metadata is not None:
2142
2143                self.input_metadata["drop"] = self.input_metadata[name_slot]
2144
                mask = [x not in names for x in self.input_metadata["drop"]]
2148
2149                if len([y for y in mask if y is False]) > 0:
2150
2151                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
2152                        drop=True
2153                    )
2154
2155                self.input_metadata = self.input_metadata.drop(
2156                    columns=["drop"], errors="ignore"
2157                )
2158
2159            if self.agg_normalized_data is not None:
2160
2161                self.agg_normalized_data.columns = self.agg_metadata[name_slot]
2162
                mask = [x not in names for x in self.agg_normalized_data.columns]
2167
2168                if len([y for y in mask if y is False]) > 0:
2169
2170                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
2171
2172            if self.agg_metadata is not None:
2173
2174                self.agg_metadata["drop"] = self.agg_metadata[name_slot]
2175
                mask = [x not in names for x in self.agg_metadata["drop"]]
2179
2180                if len([y for y in mask if y is False]) > 0:
2181
2182                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
2183                        drop=True
2184                    )
2185
2186                self.agg_metadata = self.agg_metadata.drop(
2187                    columns=["drop"], errors="ignore"
2188                )
2189
2190            self.gene_calculation()
2191            self.cells_calculation()
2192
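    # Usage sketch (hypothetical instance `cl`): drop clusters represented by
    # fewer than 20 cells; the cutoff is a placeholder.
    #
    #     fig = cl.cell_histograme()          # inspect cells per cluster
    #     cl.cluster_threshold(min_n=20)      # then filter small clusters
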
2193    def load_sparse_from_projects(self, normalized_data: bool = False):
2194        """
2195        Load sparse 10x-style datasets from stored project paths, concatenate them,
2196        and populate `input_data` / `normalized_data` and `input_metadata`.
2197
2198        Parameters
2199        ----------
2200        normalized_data : bool, default False
2201            If True, store concatenated tables in `self.normalized_data`.
            If False, store them in `self.input_data`; normalization can then
            be performed with the normalize_counts() method.
2204
2205        Side Effects
2206        ------------
2207        - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv).
2208        - Concatenates all projects column-wise and sets `self.input_metadata`.
2209        - Replaces NaNs with zeros and updates `self.gene_calc`.
2210        """
2211
2212        obj = self.objects
2213
2214        full_data = pd.DataFrame()
2215        full_metadata = pd.DataFrame()
2216
2217        for ke in obj.keys():
2218            print(ke)
2219
2220            dt, met = load_sparse(path=obj[ke], name=ke)
2221
2222            full_data = pd.concat([full_data, dt], axis=1)
2223            full_metadata = pd.concat([full_metadata, met], axis=0)
2224
2225        full_data[np.isnan(full_data)] = 0
2226
2227        if normalized_data:
2228            self.normalized_data = full_data
2229            self.input_metadata = full_metadata
2230        else:
2231
2232            self.input_data = full_data
2233            self.input_metadata = full_metadata
2234
2235        self.gene_calculation()
2236        self.cells_calculation()
2237
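    # Usage sketch: `self.objects` is assumed to map project names to
    # 10x-style directories (matrix.mtx / genes.tsv / barcodes.tsv); the
    # paths below are placeholders.
    #
    #     cl.objects = {"proj1": "data/proj1/", "proj2": "data/proj2/"}
    #     cl.load_sparse_from_projects(normalized_data=False)
    #     cl.normalize_counts()
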
2238    def rename_names(self, mapping: dict, slot: str = "cell_names"):
2239        """
2240        Rename entries in `self.input_metadata[slot]` according to a provided mapping.
2241
2242        Parameters
2243        ----------
2244        mapping : dict
2245            Dictionary with keys 'old_name' and 'new_name', each mapping to a list
2246            of equal length describing replacements.
2247
2248        slot : str, default 'cell_names'
2249            Metadata column to operate on.
2250
2251        Update
2252        -------
2253        Updates `self.input_metadata[slot]` in-place with renamed values.
2254
2255        Raises
2256        ------
2257        ValueError
2258            If mapping keys are incorrect, lengths differ, or some 'old_name' values
2259            are not present in the metadata column.
2260        """
2261
2262        if set(["old_name", "new_name"]) != set(mapping.keys()):
2263            raise ValueError(
2264                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2265                "each with a list of names to change."
2266            )
2267
2268        if len(mapping["old_name"]) != len(mapping["new_name"]):
2269            raise ValueError(
2270                "Mapping dictionary lists 'old_name' and 'new_name' "
2271                "must have the same length!"
2272            )
2273
2274        names_vector = list(self.input_metadata[slot])
2275
2276        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2277            raise ValueError(
2278                f"Some entries from 'old_name' do not exist in the names of slot {slot}."
2279            )
2280
2281        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2282
2283        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2284
2285        self.input_metadata[slot] = names_vector_ret
2286
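    # Usage sketch (hypothetical labels): harmonizing two cluster names in the
    # 'cell_names' metadata column, using the documented mapping format.
    #
    #     cl.rename_names(
    #         mapping={"old_name": ["Bcell", "Tcell"],
    #                  "new_name": ["B-cell", "T-cell"]},
    #         slot="cell_names",
    #     )
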
2287    def rename_subclusters(self, mapping):
2288        """
2289        Rename labels stored in `self.subclusters_.subclusters` according to mapping.
2290
2291        Parameters
2292        ----------
2293        mapping : dict
2294            Mapping with keys 'old_name' and 'new_name' (lists of equal length).
2295
2296        Update
2297        -------
2298        Updates `self.subclusters_.subclusters` with renamed labels.
2299
2300        Raises
2301        ------
2302        ValueError
2303            If mapping is invalid or old names are not present.
2304        """
2305
2306        if set(["old_name", "new_name"]) != set(mapping.keys()):
2307            raise ValueError(
2308                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2309                "each with a list of names to change."
2310            )
2311
2312        if len(mapping["old_name"]) != len(mapping["new_name"]):
2313            raise ValueError(
2314                "Mapping dictionary lists 'old_name' and 'new_name' "
2315                "must have the same length!"
2316            )
2317
2318        names_vector = list(self.subclusters_.subclusters)
2319
2320        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2321            raise ValueError(
2322                "Some entries from 'old_name' do not exist in the subcluster names."
2323            )
2324
2325        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2326
2327        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2328
2329        self.subclusters_.subclusters = names_vector_ret
2330
2331    def save_sparse(
2332        self,
2333        path_to_save: str = os.getcwd(),
2334        name_slot: str = "cell_names",
2335        data_slot: str = "normalized",
2336    ):
2337        """
2338        Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
2339
2340        Parameters
2341        ----------
2342        path_to_save : str, default current working directory
2343            Directory where files will be written.
2344
2345        name_slot : str, default 'cell_names'
2346            Metadata column providing cell names for barcodes.tsv.
2347
2348        data_slot : str, default 'normalized'
2349            Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).
2350
2351        Raises
2352        ------
2353        ValueError
2354            If `data_slot` is not 'normalized' or 'count'.
2355        """
2356
2357        names = self.input_metadata[name_slot]
2358
2359        if data_slot.lower() == "normalized":
2360
2361            features = list(self.normalized_data.index)
2362            mtx = sparse.csr_matrix(self.normalized_data)
2363
2364        elif data_slot.lower() == "count":
2365
2366            features = list(self.input_data.index)
2367            mtx = sparse.csr_matrix(self.input_data)
2368
2369        else:
            raise ValueError("'data_slot' must be either 'normalized' or 'count'")
2371
2372        os.makedirs(path_to_save, exist_ok=True)
2373
2374        mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx)
2375
2376        pd.Series(names).to_csv(
2377            os.path.join(path_to_save, "barcodes.tsv"),
2378            index=False,
2379            header=False,
2380            sep="\t",
2381        )
2382
2383        pd.Series(features).to_csv(
2384            os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t"
2385        )
2386
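    # Usage sketch (hypothetical path): exporting the normalized matrix in
    # 10x-compatible form.
    #
    #     cl.save_sparse(path_to_save="export/10x", data_slot="normalized")
    #     # writes matrix.mtx, barcodes.tsv and genes.tsv into export/10x
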
2387    def normalize_counts(
2388        self, normalize_factor: int = 100000, log_transform: bool = True
2389    ):
2390        """
        Normalize raw counts to counts per `normalize_factor` total
        (e.g., CPM when `normalize_factor=1_000_000`).
2393
2394        Parameters
2395        ----------
2396        normalize_factor : int, default 100000
2397            Scaling factor used after dividing by column sums.
2398
2399        log_transform : bool, default True
2400            If True, apply log2(x+1) transformation to normalized values.
2401
2402        Update
2403        -------
2404            Sets `self.normalized_data` to normalized values (fills NaNs with 0).
2405
2406        Raises
2407        ------
2408        ValueError
2409            If `self.input_data` is missing (cannot normalize).
2410        """
2411        if self.input_data is None:
2412            raise ValueError("Input data is missing, cannot normalize.")
2413
2414        sum_col = self.input_data.sum()
2415        self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor
2416
2417        if log_transform:
2418            # log2(x + 1) to avoid -inf for zeros
2419            self.normalized_data = np.log2(self.normalized_data + 1)
2420
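    # The transformation, restated per cell j and gene i (a description of the
    # code above, not an extra step):
    #
    #     norm[i, j] = counts[i, j] / sum_i(counts[i, j]) * normalize_factor
    #     norm[i, j] = log2(norm[i, j] + 1)   # only if log_transform
    #
    # Usage sketch (hypothetical instance `cl`):
    #
    #     cl.normalize_counts(normalize_factor=100000, log_transform=True)
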
2421    def statistic(
2422        self,
2423        cells=None,
2424        sets=None,
2425        min_exp: float = 0.01,
2426        min_pct: float = 0.1,
2427        n_proc: int = 10,
2428    ):
2429        """
2430        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.
2431
2432        This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data`
2433        and `self.input_metadata`. It returns per-feature statistics including p-values,
2434        adjusted p-values, means, variances, effect-size measures and fold-changes.
2435
2436        Parameters
2437        ----------
2438        cells : list, 'All', dict, or None
2439            Defines the target cells or groups for comparison (several modes supported).
2440
2441        sets : 'All', dict, or None
2442            Alternative grouping mode (operate on `self.input_metadata['sets']`).
2443
2444        min_exp : float, default 0.01
2445            Minimum expression threshold used when filtering features.
2446
2447        min_pct : float, default 0.1
2448            Minimum proportion of expressing cells in the target group required to test a feature.
2449
2450        n_proc : int, default 10
2451            Number of parallel jobs to use.
2452
2453        Returns
2454        -------
2455        pandas.DataFrame or dict
2456            Results DataFrame (or dict containing valid/control cells + DataFrame),
2457            similar to `calc_DEG` interface.
2458
2459        Raises
2460        ------
2461        ValueError
2462            If neither `cells` nor `sets` is provided, or input metadata mismatch occurs.
2463
2464        Notes
2465        -----
2466        Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
2467        """
2468
2469        def stat_calc(choose, feature_name):
2470            target_values = choose.loc[choose["DEG"] == "target", feature_name]
2471            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]
2472
2473            pct_valid = (target_values > 0).sum() / len(target_values)
2474            pct_rest = (rest_values > 0).sum() / len(rest_values)
2475
2476            avg_valid = np.mean(target_values)
2477            avg_ctrl = np.mean(rest_values)
2478            sd_valid = np.std(target_values, ddof=1)
2479            sd_ctrl = np.std(rest_values, ddof=1)
2480            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))
2481
            # treat identical group totals as "no difference" and skip the test
            if np.sum(target_values) == np.sum(rest_values):
                p_val = 1.0
2484            else:
2485                _, p_val = stats.mannwhitneyu(
2486                    target_values, rest_values, alternative="two-sided"
2487                )
2488
2489            return {
2490                "feature": feature_name,
2491                "p_val": p_val,
2492                "pct_valid": pct_valid,
2493                "pct_ctrl": pct_rest,
2494                "avg_valid": avg_valid,
2495                "avg_ctrl": avg_ctrl,
2496                "sd_valid": sd_valid,
2497                "sd_ctrl": sd_ctrl,
2498                "esm": esm,
2499            }
2500
2501        def prepare_and_run_stat(
2502            choose, valid_group, min_exp, min_pct, n_proc, factors
2503        ):
2504
2505            tmp_dat = choose[choose["DEG"] == "target"]
2506            tmp_dat = tmp_dat.drop("DEG", axis=1)
2507
2508            counts = (tmp_dat > min_exp).sum(axis=0)
2509
2510            total_count = tmp_dat.shape[0]
2511
2512            info = pd.DataFrame(
2513                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
2514            )
2515
2516            del tmp_dat
2517
            drop_col = info["feature"][info["pct"] <= min_pct]

            # if the pct filter would drop every feature ("+ 1" accounts for
            # the 'DEG' label column), fall back to dropping only the
            # never-expressed ones
            if len(drop_col) + 1 == len(choose.columns):
                drop_col = info["feature"][info["pct"] == 0]
2522
2523            del info
2524
2525            choose = choose.drop(list(drop_col), axis=1)
2526
2527            results = Parallel(n_jobs=n_proc)(
2528                delayed(stat_calc)(choose, feature)
2529                for feature in tqdm(choose.columns[choose.columns != "DEG"])
2530            )
2531
2532            df = pd.DataFrame(results)
2533            df["valid_group"] = valid_group
2534            df.sort_values(by="p_val", inplace=True)
2535
            # Benjamini-Hochberg-style adjustment; relies on df being sorted
            # by p_val above
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )
2540
2541            factors_df = factors.to_frame(name="factor")
2542
2543            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")
2544
2545            df["FC"] = (df["avg_valid"] + df["factor"]) / (
2546                df["avg_ctrl"] + df["factor"]
2547            )
2548            df["log(FC)"] = np.log2(df["FC"])
2549            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
2550            df = df.drop(columns=["factor"])
2551
2552            return df
2553
2554        choose = self.normalized_data.copy().T
2555        factors = self.normalized_data.copy().replace(0, np.nan).min(axis=1)
2556        factors[factors == 0] = min(factors[factors != 0])
2557
2558        final_results = []
2559
2560        if isinstance(cells, list) and sets is None:
2561            print("\nAnalysis started...\nComparing selected cells to the whole set...")
2562            choose.index = (
2563                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2564            )
2565
2566            if "#" not in cells[0]:
2567                choose.index = self.input_metadata["cell_names"]
2568
                print(
                    "The 'cells' list does not include the set info (name # set). "
                    "Only the names will be compared, without considering the set information."
                )
2573
2574            labels = ["target" if idx in cells else "rest" for idx in choose.index]
2575            valid = list(
2576                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
2577            )
2578
2579            choose["DEG"] = labels
2580            choose = choose[choose["DEG"] != "drop"]
2581
2582            result_df = prepare_and_run_stat(
2583                choose.reset_index(drop=True),
2584                valid_group=valid,
2585                min_exp=min_exp,
2586                min_pct=min_pct,
2587                n_proc=n_proc,
2588                factors=factors,
2589            )
2590            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}
2591
2592        elif cells == "All" and sets is None:
2593            print("\nAnalysis started...\nComparing each type of cell to others...")
2594            choose.index = (
2595                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2596            )
2597            unique_labels = set(choose.index)
2598
2599            for label in tqdm(unique_labels):
2600                print(f"\nCalculating statistics for {label}")
2601                labels = ["target" if idx == label else "rest" for idx in choose.index]
2602                choose["DEG"] = labels
2603                choose = choose[choose["DEG"] != "drop"]
2604                result_df = prepare_and_run_stat(
2605                    choose.copy(),
2606                    valid_group=label,
2607                    min_exp=min_exp,
2608                    min_pct=min_pct,
2609                    n_proc=n_proc,
2610                    factors=factors,
2611                )
2612                final_results.append(result_df)
2613
2614            return pd.concat(final_results, ignore_index=True)
2615
2616        elif cells is None and sets == "All":
2617            print("\nAnalysis started...\nComparing each set/group to others...")
2618            choose.index = self.input_metadata["sets"]
2619            unique_sets = set(choose.index)
2620
2621            for label in tqdm(unique_sets):
2622                print(f"\nCalculating statistics for {label}")
2623                labels = ["target" if idx == label else "rest" for idx in choose.index]
2624
2625                choose["DEG"] = labels
2626                choose = choose[choose["DEG"] != "drop"]
2627                result_df = prepare_and_run_stat(
2628                    choose.copy(),
2629                    valid_group=label,
2630                    min_exp=min_exp,
2631                    min_pct=min_pct,
2632                    n_proc=n_proc,
2633                    factors=factors,
2634                )
2635                final_results.append(result_df)
2636
2637            return pd.concat(final_results, ignore_index=True)
2638
2639        elif cells is None and isinstance(sets, dict):
2640            print("\nAnalysis started...\nComparing groups...")
2641
2642            choose.index = self.input_metadata["sets"]
2643
2644            group_list = list(sets.keys())
2645            if len(group_list) != 2:
2646                print("Only pairwise group comparison is supported.")
2647                return None
2648
2649            labels = [
2650                (
2651                    "target"
2652                    if idx in sets[group_list[0]]
2653                    else "rest" if idx in sets[group_list[1]] else "drop"
2654                )
2655                for idx in choose.index
2656            ]
2657            choose["DEG"] = labels
2658            choose = choose[choose["DEG"] != "drop"]
2659
2660            result_df = prepare_and_run_stat(
2661                choose.reset_index(drop=True),
2662                valid_group=group_list[0],
2663                min_exp=min_exp,
2664                min_pct=min_pct,
2665                n_proc=n_proc,
2666                factors=factors,
2667            )
2668            return result_df
2669
2670        elif isinstance(cells, dict) and sets is None:
2671            print("\nAnalysis started...\nComparing groups...")
2672            choose.index = (
2673                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2674            )
2675
2676            if "#" not in cells[list(cells.keys())[0]][0]:
2677                choose.index = self.input_metadata["cell_names"]
2678
                print(
                    "The 'cells' dict does not include the set info (name # set). "
                    "Only the names will be compared, without considering the set information."
                )
2683
2684            group_list = list(cells.keys())
2685            if len(group_list) != 2:
2686                print("Only pairwise group comparison is supported.")
2687                return None
2688
2689            labels = [
2690                (
2691                    "target"
2692                    if idx in cells[group_list[0]]
2693                    else "rest" if idx in cells[group_list[1]] else "drop"
2694                )
2695                for idx in choose.index
2696            ]
2697
2698            choose["DEG"] = labels
2699            choose = choose[choose["DEG"] != "drop"]
2700
2701            result_df = prepare_and_run_stat(
2702                choose.reset_index(drop=True),
2703                valid_group=group_list[0],
2704                min_exp=min_exp,
2705                min_pct=min_pct,
2706                n_proc=n_proc,
2707                factors=factors,
2708            )
2709
2710            return result_df.reset_index(drop=True)
2711
        else:
            raise ValueError(
                "You must specify exactly one of 'cells' or 'sets' "
                "(as a list, 'All', or a two-group dict); "
                "no supported combination was provided for this analysis."
            )
2716
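    # Usage sketches for the supported modes (hypothetical labels throughout):
    #
    #     res = cl.statistic(cells=["B-cell # set1"])   # selected cells vs rest
    #     res = cl.statistic(cells="All")               # each cell type vs rest
    #     res = cl.statistic(sets="All")                # each set vs rest
    #     res = cl.statistic(sets={"grp1": ["set1"], "grp2": ["set2"]})
    #     res = cl.statistic(cells={"grp1": ["B-cell # set1"],
    #                               "grp2": ["T-cell # set1"]})
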
2717    def calculate_difference_markers(
2718        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
2719    ):
2720        """
2721        Compute differential markers (var_data) if not already present.
2722
2723        Parameters
2724        ----------
2725        min_exp : float, default 0
2726            Minimum expression threshold passed to `statistic`.
2727
2728        min_pct : float, default 0.25
2729            Minimum percent expressed in target group.
2730
2731        n_proc : int, default 10
2732            Parallel jobs.
2733
2734        force : bool, default False
2735            If True, recompute even if `self.var_data` is present.
2736
2737        Update
2738        -------
2739        Sets `self.var_data` to the result of `self.statistic(...)`.
2740
2741        Raise
2742        ------
2743        ValueError if already computed and `force` is False.
2744        """
2745
2746        if self.var_data is None or force:
2747
2748            self.var_data = self.statistic(
2749                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
2750            )
2751
2752        else:
2753            raise ValueError(
2754                "self.calculate_difference_markers() has already been executed. "
                "The results are stored in self.var_data. "
2756                "If you want to recalculate with different statistics, please rerun the method with force=True."
2757            )
2758
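    # Usage sketch (hypothetical instance `cl`): compute markers once, then
    # force a recomputation with stricter filtering.
    #
    #     cl.calculate_difference_markers(min_exp=0, min_pct=0.25)
    #     cl.calculate_difference_markers(min_pct=0.5, force=True)
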
2759    def clustering_features(
2760        self,
2761        features_list: list | None,
2762        name_slot: str = "cell_names",
2763        p_val: float = 0.05,
2764        top_n: int = 25,
2765        adj_mean: bool = True,
2766        beta: float = 0.2,
2767    ):
2768        """
2769        Prepare clustering input by selecting marker features and optionally smoothing cell values
2770        toward group means.
2771
2772        Parameters
2773        ----------
2774        features_list : list or None
2775            If provided, use this list of features. If None, features are selected
2776            from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group.
2777
2778        name_slot : str, default 'cell_names'
2779            Metadata column used for naming.
2780
2781        p_val : float, default 0.05
2782            Adjusted p-value cutoff when selecting features automatically.
2783
2784        top_n : int, default 25
2785            Number of top features per valid group to keep if `features_list` is None.
2786
2787        adj_mean : bool, default True
2788            If True, adjust cell values toward group means using `beta`.
2789
2790        beta : float, default 0.2
2791            Adjustment strength toward group mean.
2792
2793        Update
2794        ------
2795        Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset,
2796        ready for PCA/UMAP/clustering.
2797        """
2798
2799        if features_list is None or len(features_list) == 0:
2800
2801            if self.var_data is None:
2802                raise ValueError(
2803                    "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2804                )
2805
2806            df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2807            df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2808            df_tmp = (
2809                df_tmp.sort_values(
2810                    ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2811                )
2812                .groupby("valid_group")
2813                .head(top_n)
2814            )
2815
            features_list = list(set(df_tmp["feature"]))
2817
        data = self.get_partial_data(
            names=None, features=features_list, name_slot=name_slot
        )
2821        data_avg = average(data)
2822
2823        if adj_mean:
2824            data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta)
2825
2826        self.clustering_data = data
2827
2828        self.clustering_metadata = self.input_metadata
2829
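    # Usage sketch (hypothetical instance `cl`): select the top 25 markers per
    # group and smooth cells toward their group means before clustering.
    #
    #     cl.calculate_difference_markers()
    #     cl.clustering_features(features_list=None, top_n=25,
    #                            adj_mean=True, beta=0.2)
    #     cl.clustering_data.shape  # features x cells, ready for PCA/UMAP
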
2830    def average(self):
2831        """
2832        Aggregate normalized data by (cell_name, set) pairs computing the mean per group.
2833
2834        The method constructs new column names as "cell_name # set", averages columns
2835        sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`.
2836
2837        Update
2838        ------
2839        Sets `self.agg_normalized_data` (features x aggregated samples) and
2840        `self.agg_metadata` (DataFrame with 'cell_names' and 'sets').
2841        """
2842
        wide_data = self.normalized_data.copy()
2844
2845        wide_metadata = self.input_metadata
2846
2847        new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"]
2848
2849        wide_data.columns = list(new_names)
2850
2851        aggregated_df = wide_data.T.groupby(wide_data.columns, axis=0).mean().T
2852
2853        sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns]
2854        names = [re.sub(" #.*", "", x) for x in aggregated_df.columns]
2855
2856        aggregated_df.columns = names
2857        aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets})
2858
2859        self.agg_metadata = aggregated_metadata
2860        self.agg_normalized_data = aggregated_df
2861
2862    def estimating_similarity(
2863        self, method="pearson", p_val: float = 0.05, top_n: int = 25
2864    ):
2865        """
2866        Estimate pairwise similarity and Euclidean distance between aggregated samples.
2867
2868        Parameters
2869        ----------
2870        method : str, default 'pearson'
2871            Correlation method to use (passed to pandas.DataFrame.corr()).
2872
2873        p_val : float, default 0.05
2874            Adjusted p-value cutoff used to select marker features from `self.var_data`.
2875
2876        top_n : int, default 25
2877            Number of top features per valid group to include.
2878
2879        Update
2880        -------
2881        Computes a combined table with per-pair correlation and euclidean distance
2882        and stores it in `self.similarity`.
2883        """
2884
2885        if self.var_data is None:
2886            raise ValueError(
2887                "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2888            )
2889
2890        if self.agg_normalized_data is None:
2891            self.average()
2892
2893        metadata = self.agg_metadata
2894        data = self.agg_normalized_data
2895
2896        df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2897        df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2898        df_tmp = (
2899            df_tmp.sort_values(
2900                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2901            )
2902            .groupby("valid_group")
2903            .head(top_n)
2904        )
2905
2906        data = data.loc[list(set(df_tmp["feature"]))]
2907
2908        if len(set(metadata["sets"])) > 1:
2909            data.columns = data.columns + " # " + [x for x in metadata["sets"]]
2910        else:
2911            data = data.copy()
2912
2913        scaler = StandardScaler()
2914
2915        scaled_data = scaler.fit_transform(data)
2916
2917        scaled_df = pd.DataFrame(scaled_data, columns=data.columns)
2918
2919        cor = scaled_df.corr(method=method)
2920        cor_df = cor.stack().reset_index()
2921        cor_df.columns = ["cell1", "cell2", "correlation"]
2922
2923        distances = pdist(scaled_df.T, metric="euclidean")
2924        dist_mat = pd.DataFrame(
2925            squareform(distances), index=scaled_df.columns, columns=scaled_df.columns
2926        )
2927        dist_df = dist_mat.stack().reset_index()
2928        dist_df.columns = ["cell1", "cell2", "euclidean_dist"]
2929
2930        full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"])
2931
2932        full = full[full["cell1"] != full["cell2"]]
2933        full = full.reset_index(drop=True)
2934
2935        self.similarity = full
2936
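    # Usage sketch (hypothetical instance `cl`): pairwise similarity over the
    # top markers; the long-format result lands in `cl.similarity`.
    #
    #     cl.estimating_similarity(method="spearman", top_n=25)
    #     cl.similarity.head()  # columns: cell1, cell2, correlation, euclidean_dist
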
2937    def similarity_plot(
2938        self,
2939        split_sets=True,
2940        set_info: bool = True,
2941        cmap="seismic",
2942        width=12,
2943        height=10,
2944    ):
2945        """
2946        Visualize pairwise similarity as a scatter plot.
2947
2948        Parameters
2949        ----------
2950        split_sets : bool, default True
2951            If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs.
2952
2953        set_info : bool, default True
2954            If True, keep the ' # set' annotation in labels; otherwise strip it.
2955
2956        cmap : str, default 'seismic'
2957            Color map for correlation (hue).
2958
2959        width : int, default 12
2960            Figure width.
2961
2962        height : int, default 10
2963            Figure height.
2964
2965        Returns
2966        -------
2967        matplotlib.figure.Figure
2968
2969        Raises
2970        ------
2971        ValueError
2972            If `self.similarity` is None.
2973
2974        Notes
2975        -----
2976        The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs.
2977        """
2978
2979        if self.similarity is None:
2980            raise ValueError(
2981                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
2982            )
2983
        similarity_data = self.similarity.copy()
2985
2986        if " # " in similarity_data["cell1"][0]:
2987            similarity_data["set1"] = [
2988                re.sub(".*# ", "", x) for x in similarity_data["cell1"]
2989            ]
2990            similarity_data["set2"] = [
2991                re.sub(".*# ", "", x) for x in similarity_data["cell2"]
2992            ]
2993
2994        if split_sets and " # " in similarity_data["cell1"][0]:
2995            sets = list(
2996                set(list(similarity_data["set1"]) + list(similarity_data["set2"]))
2997            )
2998
2999            mm = math.ceil(len(sets) / 2)
3000
3001            x_s = sets[0:mm]
3002            y_s = sets[mm : len(sets)]
3003
3004            similarity_data = similarity_data[similarity_data["set1"].isin(x_s)]
3005            similarity_data = similarity_data[similarity_data["set2"].isin(y_s)]
3006
3007            similarity_data = similarity_data.sort_values(["set1", "set2"])
3008
3009        if set_info is False and " # " in similarity_data["cell1"][0]:
3010            similarity_data["cell1"] = [
3011                re.sub(" #.*", "", x) for x in similarity_data["cell1"]
3012            ]
3013            similarity_data["cell2"] = [
3014                re.sub(" #.*", "", x) for x in similarity_data["cell2"]
3015            ]
3016
3017        similarity_data["-euclidean_zscore"] = -zscore(
3018            similarity_data["euclidean_dist"]
3019        )
3020
3021        similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0]
3022
3023        fig = plt.figure(figsize=(width, height))
3024        sns.scatterplot(
3025            data=similarity_data,
3026            x="cell1",
3027            y="cell2",
3028            hue="correlation",
3029            size="-euclidean_zscore",
3030            sizes=(1, 100),
3031            palette=cmap,
3032            alpha=1,
3033            edgecolor="black",
3034        )
3035
3036        plt.xticks(rotation=90)
3037        plt.yticks(rotation=0)
3038        plt.xlabel("Cell 1")
3039        plt.ylabel("Cell 2")
3040        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
3041
3042        plt.grid(True, alpha=0.6)
3043
3044        plt.tight_layout()
3045
3046        return fig
3047
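    # Usage sketch (hypothetical file name): plot cross-set similarity and
    # save the figure.
    #
    #     fig = cl.similarity_plot(split_sets=True, cmap="seismic")
    #     fig.savefig("similarity.png", bbox_inches="tight")
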
3048    def spatial_similarity(
3049        self,
3050        set_info: bool = True,
3051        bandwidth=1,
3052        n_neighbors=5,
3053        min_dist=0.1,
3054        legend_split=2,
3055        point_size=100,
3056        spread=1.0,
3057        set_op_mix_ratio=1.0,
3058        local_connectivity=1,
3059        repulsion_strength=1.0,
3060        negative_sample_rate=5,
3061        threshold=0.1,
3062        width=12,
3063        height=10,
3064    ):
3065        """
3066        Create a spatial UMAP-like visualization of similarity relationships between samples.
3067
3068        Parameters
3069        ----------
3070        set_info : bool, default True
3071            If True, retain set information in labels.
3072
3073        bandwidth : float, default 1
3074            Bandwidth used by MeanShift for clustering polygons.
3075
3076        point_size : float, default 100
3077            Size of scatter points.
3078
3079        legend_split : int, default 2
3080            Number of columns in legend.
3081
3082        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.
3083
3084        threshold : float, default 0.1
3085            Minimum text distance for label adjustment to avoid overlap.
3086
3087        width : int, default 12
3088            Figure width.
3089
3090        height : int, default 10
3091            Figure height.
3092
3093        Returns
3094        -------
3095        matplotlib.figure.Figure
3096
3097        Raises
3098        ------
3099        ValueError
3100            If `self.similarity` is None.
3101
3102        Notes
3103        -----
3104        Builds a precomputed distance matrix combining correlation and euclidean distance,
3105        runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull)
3106        and arrows to indicate nearest neighbors (minimal combined distance).
3107        """
3108
3109        if self.similarity is None:
3110            raise ValueError(
3111                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
3112            )
3113
        similarity_data = self.similarity.copy()
3115
3116        sim = similarity_data["correlation"]
3117        sim_scaled = (sim - sim.min()) / (sim.max() - sim.min())
3118        eu_dist = similarity_data["euclidean_dist"]
3119        eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min())
3120
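        # combined dissimilarity: the product of (1 - scaled correlation) and
        # the scaled euclidean distance, so pairs that correlate strongly or
        # lie close together get a small value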
3121        similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled
3122
3123        # nearest-neighbour target: for each cell1, keep the partner
3124        # with the minimal combined distance
3125        arrow_df = similarity_data.loc[
3126            similarity_data.groupby("cell1")["combo_dist"].idxmin()
3127        ]
3128
3129        cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"]))
3130        combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float)
3131
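        # fill a symmetric pairwise distance matrix so UMAP can consume it
        # with metric='precomputed'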
3132        for _, row in similarity_data.iterrows():
3133            combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"]
3134            combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"]
3135
3136        umap_model = umap.UMAP(
3137            n_components=2,
3138            metric="precomputed",
3139            n_neighbors=n_neighbors,
3140            min_dist=min_dist,
3141            spread=spread,
3142            set_op_mix_ratio=set_op_mix_ratio,
3143            local_connectivity=local_connectivity,
3144            repulsion_strength=repulsion_strength,
3145            negative_sample_rate=negative_sample_rate,
3146            transform_seed=42,
3147            init="spectral",
3148            random_state=42,
3149            verbose=True,
3150        )
3151
3152        coords = umap_model.fit_transform(combo_matrix.values)
3153        cell_names = list(combo_matrix.index)
3154        num_cells = len(cell_names)
3155        palette = sns.color_palette("tab20c", num_cells)
3156
3157        if "#" in cell_names[0]:
3158            avsets = set(
3159                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
3160                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
3161            )
3162            num_sets = len(avsets)
3163            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
3164            color_mapping_sets = {
3165                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
3166            }
3167            color_mapping = {
3168                name: color_mapping_sets[re.sub(".*# ", "", name)]
3169                for name in cell_names
3170            }
3171        else:
3172            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}
3173
3174        meanshift = MeanShift(bandwidth=bandwidth)
3175        labels = meanshift.fit_predict(coords)
3176
3177        fig = plt.figure(figsize=(width, height))
3178        ax = plt.gca()
3179
3180        unique_labels = set(labels)
3181        cluster_palette = sns.color_palette("hls", len(unique_labels))
3182
3183        for label in unique_labels:
3184            if label == -1:
3185                continue
3186            cluster_coords = coords[labels == label]
3187            if len(cluster_coords) < 3:
3188                continue
3189
3190            hull = ConvexHull(cluster_coords)
3191            hull_points = cluster_coords[hull.vertices]
3192
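            # push the hull vertices 5% outward from their centroid so
            # boundary points sit inside the shaded patch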
3193            centroid = np.mean(hull_points, axis=0)
3194            expanded = hull_points + 0.05 * (hull_points - centroid)
3195
3196            poly = Polygon(
3197                expanded,
3198                closed=True,
3199                facecolor=cluster_palette[label],
3200                edgecolor="none",
3201                alpha=0.2,
3202                zorder=1,
3203            )
3204            ax.add_patch(poly)
3205
3206        texts = []
3207        for i, (x, y) in enumerate(coords):
3208            plt.scatter(
3209                x,
3210                y,
3211                s=point_size,
3212                color=color_mapping[cell_names[i]],
3213                edgecolors="black",
3214                linewidths=0.5,
3215                zorder=2,
3216            )
3217            texts.append(
3218                ax.text(
3219                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
3220                )
3221            )
3222
3223        def dist(p1, p2):
3224            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
3225
3226        texts_to_adjust = []
3227        for i, t1 in enumerate(texts):
3228            for j, t2 in enumerate(texts):
3229                if i >= j:
3230                    continue
3231                d = dist(
3232                    (t1.get_position()[0], t1.get_position()[1]),
3233                    (t2.get_position()[0], t2.get_position()[1]),
3234                )
3235                if d < threshold:
3236                    if t1 not in texts_to_adjust:
3237                        texts_to_adjust.append(t1)
3238                    if t2 not in texts_to_adjust:
3239                        texts_to_adjust.append(t2)
3240
3241        adjust_text(
3242            texts_to_adjust,
3243            expand_text=(1.0, 1.0),
3244            force_text=0.9,
3245            arrowprops=dict(arrowstyle="-", color="gray", lw=0.1),
3246            ax=ax,
3247        )
3248
3249        for _, row in arrow_df.iterrows():
3250            try:
3251                idx1 = cell_names.index(row["cell1"])
3252                idx2 = cell_names.index(row["cell2"])
3253            except ValueError:
3254                continue
3255            x1, y1 = coords[idx1]
3256            x2, y2 = coords[idx2]
3257            arrow = FancyArrowPatch(
3258                (x1, y1),
3259                (x2, y2),
3260                arrowstyle="->",
3261                color="gray",
3262                linewidth=1.5,
3263                alpha=0.5,
3264                mutation_scale=12,
3265                zorder=0,
3266            )
3267            ax.add_patch(arrow)
3268
3269        if set_info is False and " # " in cell_names[0]:
3270
3271            legend_elements = [
3272                Patch(
3273                    facecolor=color_mapping[name],
3274                    edgecolor="black",
3275                    label=f"{i}{re.sub(' #.*', '', name)}",
3276                )
3277                for i, name in enumerate(cell_names)
3278            ]
3279
3280        else:
3281
3282            legend_elements = [
3283                Patch(
3284                    facecolor=color_mapping[name],
3285                    edgecolor="black",
3286                    label=f"{i}{name}",
3287                )
3288                for i, name in enumerate(cell_names)
3289            ]
3290
3291        plt.legend(
3292            handles=legend_elements,
3293            title="Cells",
3294            bbox_to_anchor=(1.05, 1),
3295            loc="upper left",
3296            ncol=legend_split,
3297        )
3298
3299        plt.xlabel("UMAP 1")
3300        plt.ylabel("UMAP 2")
3301        plt.grid(False)
3302        plt.show()
3303
3304        return fig
3305
3306    # subclusters part
3307
3308    def subcluster_prepare(self, features: list, cluster: str):
3309        """
3310        Prepare a `Clustering` object for subcluster analysis on a selected parent cluster.
3311
3312        Parameters
3313        ----------
3314        features : list
3315            Features to include for subcluster analysis.
3316
3317        cluster : str
3318            Parent cluster name (used to select matching cells).
3319
3320        Update
3321        ------
3322        Initializes `self.subclusters_` as a new `Clustering` instance containing the
3323        reduced data for the given cluster and stores `current_features` and `current_cluster`.
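
        Examples
        --------
        A minimal sketch (feature and cluster names are placeholders):

        >>> obj.subcluster_prepare(
        ...     features=["GeneA", "GeneB", "GeneC"],
        ...     cluster="Cluster_1",
        ... )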
3324        """
3325
3326        dat = self.normalized_data
3327        dat.columns = list(self.input_metadata["cell_names"])
3328
3329        dat = reduce_data(dat, features=features, names=[cluster])
3330
3331        self.subclusters_ = Clustering(data=dat, metadata=None)
3332
3333        self.subclusters_.current_features = features
3334        self.subclusters_.current_cluster = cluster
3335
3336    def define_subclusters(
3337        self,
3338        umap_num: int = 2,
3339        eps: float = 0.5,
3340        min_samples: int = 10,
3341        n_neighbors: int = 5,
3342        min_dist: float = 0.1,
3343        spread: float = 1.0,
3344        set_op_mix_ratio: float = 1.0,
3345        local_connectivity: int = 1,
3346        repulsion_strength: float = 1.0,
3347        negative_sample_rate: int = 5,
3348        width=8,
3349        height=6,
3350    ):
3351        """
3352        Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.
3353
3354        Parameters
3355        ----------
3356        umap_num : int, default 2
3357            Number of UMAP dimensions to compute.
3358
3359        eps : float, default 0.5
3360            DBSCAN eps parameter.
3361
3362        min_samples : int, default 10
3363            DBSCAN min_samples parameter.
3364
3365        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
3366            Additional parameters passed to UMAP / plotting / MeanShift as appropriate.
3367
3368        Update
3369        -------
3370        Stores cluster labels in `self.subclusters_.subclusters`.
3371
3372        Raises
3373        ------
3374        RuntimeError
3375            If `self.subclusters_` has not been prepared.
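
        Examples
        --------
        A minimal sketch, run after `subcluster_prepare` (parameter values
        are illustrative and usually need tuning):

        >>> fig = obj.define_subclusters(umap_num=2, eps=0.4, min_samples=5)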
3376        """
3377
3378        if self.subclusters_ is None:
3379            raise RuntimeError(
3380                "Nothing to return. 'self.subcluster_prepare' was not conducted!"
3381            )
3382
3383        self.subclusters_.perform_UMAP(
3384            factorize=False,
3385            umap_num=umap_num,
3386            pc_num=0,
3387            harmonized=False,
3388            n_neighbors=n_neighbors,
3389            min_dist=min_dist,
3390            spread=spread,
3391            set_op_mix_ratio=set_op_mix_ratio,
3392            local_connectivity=local_connectivity,
3393            repulsion_strength=repulsion_strength,
3394            negative_sample_rate=negative_sample_rate,
3395            width=width,
3396            height=height,
3397        )
3398
3399        fig = self.subclusters_.find_clusters_UMAP(
3400            umap_n=umap_num,
3401            eps=eps,
3402            min_samples=min_samples,
3403            width=width,
3404            height=height,
3405        )
3406
3407        clusters = self.subclusters_.return_clusters(clusters="umap")
3408
3409        self.subclusters_.subclusters = [str(x) for x in list(clusters)]
3410
3411        return fig
3412
3413    def subcluster_features_scatter(
3414        self,
3415        colors="viridis",
3416        hclust="complete",
3417        scale=False,
3418        img_width=3,
3419        img_high=5,
3420        label_size=6,
3421        size_scale=70,
3422        y_lab="Genes",
3423        legend_lab="normalized",
3424        bbox_to_anchor_scale: int = 25,
3425        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3426    ):
3427        """
3428        Create a features-scatter visualization for the subclusters (average expression and occurrence).
3429
3430        Parameters
3431        ----------
3432        colors : str, default 'viridis'
3433            Colormap name passed to `features_scatter`.
3434
3435        hclust : str or None
3436            Hierarchical clustering linkage to order rows/columns.
3437
3438        scale: bool, default False
3439            If True, expression data will be scaled (0–1) across the rows (features).
3440
3441        img_width, img_high : float
3442            Figure size.
3443
3444        label_size : int
3445            Font size for labels.
3446
3447        size_scale : int
3448            Bubble size scaling.
3449
3450        y_lab : str
3451            X axis label.
3452
3453        legend_lab : str
3454            Colorbar label.
3455
3456        bbox_to_anchor_scale : int, default=25
3457            Vertical scale (percentage) for positioning the colorbar.
3458
3459        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3460            Anchor position for the size legend (percent bubble legend).
3461
3462        Returns
3463        -------
3464        matplotlib.figure.Figure
3465
3466        Raises
3467        ------
3468        RuntimeError
3469            If subcluster preparation/definition has not been run.
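
        Examples
        --------
        A minimal sketch, run after `subcluster_prepare` and `define_subclusters`:

        >>> fig = obj.subcluster_features_scatter(scale=True)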
3470        """
3471
3472        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3473            raise RuntimeError(
3474                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3475            )
3476
3477        dat = self.normalized_data
3478        dat.columns = list(self.input_metadata["cell_names"])
3479
3480        dat = reduce_data(
3481            dat,
3482            features=self.subclusters_.current_features,
3483            names=[self.subclusters_.current_cluster],
3484        )
3485
3486        dat.columns = self.subclusters_.subclusters
3487
3488        avg = average(dat)
3489        occ = occurrence(dat)
3490
3491        scatter = features_scatter(
3492            expression_data=avg,
3493            occurence_data=occ,
3494            features=None,
3495            scale=scale,
3496            metadata_list=None,
3497            colors=colors,
3498            hclust=hclust,
3499            img_width=img_width,
3500            img_high=img_high,
3501            label_size=label_size,
3502            size_scale=size_scale,
3503            y_lab=y_lab,
3504            legend_lab=legend_lab,
3505            bbox_to_anchor_scale=bbox_to_anchor_scale,
3506            bbox_to_anchor_perc=bbox_to_anchor_perc,
3507        )
3508
3509        return scatter
3510
3511    def subcluster_DEG_scatter(
3512        self,
3513        top_n=3,
3514        min_exp=0,
3515        min_pct=0.25,
3516        p_val=0.05,
3517        colors="viridis",
3518        hclust="complete",
3519        scale=False,
3520        img_width=3,
3521        img_high=5,
3522        label_size=6,
3523        size_scale=70,
3524        y_lab="Genes",
3525        legend_lab="normalized",
3526        bbox_to_anchor_scale: int = 25,
3527        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3528        n_proc=10,
3529    ):
3530        """
3531        Plot top differential features (DEGs) for subclusters as a features-scatter.
3532
3533        Parameters
3534        ----------
3535        top_n : int, default 3
3536            Number of top features per subcluster to show.
3537
3538        min_exp : float, default 0
3539            Minimum expression threshold passed to `statistic`.
3540
3541        min_pct : float, default 0.25
3542            Minimum percent expressed in target group.
3543
3544        p_val: float, default 0.05
3545            Maximum p-value for visualizing features.
3546
3547        n_proc : int, default 10
3548            Parallel jobs used for DEG calculation.
3549
3550        scale: bool, default False
3551            If True, expression_data will be scaled (0–1) across the rows (features).
3552
3553        colors : str, default='viridis'
3554            Colormap for expression values.
3555
3556        hclust : str or None, default='complete'
3557            Linkage method for hierarchical clustering. If None, no clustering
3558            is performed.
3559
3560        img_width : int or float, default=3
3561            Width of the plot in inches.
3562
3563        img_high : int or float, default=5
3564            Height of the plot in inches.
3565
3566        label_size : int, default=6
3567            Font size for axis labels and ticks.
3568
3569        size_scale : int or float, default=70
3570            Scaling factor for bubble sizes.
3571
3572        y_lab : str, default='Genes'
3573            Label for the x-axis.
3574
3575        legend_lab : str, default='normalized'
3576            Label for the colorbar legend.
3577
3578        bbox_to_anchor_scale : int, default=25
3579            Vertical scale (percentage) for positioning the colorbar.
3580
3581        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3582            Anchor position for the size legend (percent bubble legend).
3583
3584        Returns
3585        -------
3586        matplotlib.figure.Figure
3587
3588        Raises
3589        ------
3590        RuntimeError
3591            If subcluster preparation/definition has not been run.
3592
3593        Notes
3594        -----
3595        Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters
3596        by p-value and effect-size, selects top features per valid group and plots them.
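
        Examples
        --------
        A minimal sketch (thresholds are illustrative):

        >>> fig = obj.subcluster_DEG_scatter(top_n=5, min_pct=0.3, p_val=0.01)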
3597        """
3598
3599        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3600            raise RuntimeError(
3601                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3602            )
3603
3604        dat = self.normalized_data
3605        dat.columns = list(self.input_metadata["cell_names"])
3606
3607        dat = reduce_data(
3608            dat, names=[self.subclusters_.current_cluster]
3609        )
3610
3611        dat.columns = self.subclusters_.subclusters
3612
3613        deg_stats = calc_DEG(
3614            dat,
3615            metadata_list=None,
3616            entities="All",
3617            sets=None,
3618            min_exp=min_exp,
3619            min_pct=min_pct,
3620            n_proc=n_proc,
3621        )
3622
3623        deg_stats = deg_stats[deg_stats["p_val"] <= p_val]
3624        deg_stats = deg_stats[deg_stats["log(FC)"] > 0]
3625
3626        deg_stats = (
3627            deg_stats.sort_values(
3628                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
3629            )
3630            .groupby("valid_group")
3631            .head(top_n)
3632        )
3633
3634        dat = reduce_data(dat, features=list(set(deg_stats["feature"])))
3635
3636        avg = average(dat)
3637        occ = occurrence(dat)
3638
3639        scatter = features_scatter(
3640            expression_data=avg,
3641            occurence_data=occ,
3642            features=None,
3643            metadata_list=None,
3644            colors=colors,
3645            hclust=hclust,
3646            img_width=img_width,
3647            img_high=img_high,
3648            label_size=label_size,
3649            size_scale=size_scale,
3650            y_lab=y_lab,
3651            legend_lab=legend_lab,
3652            bbox_to_anchor_scale=bbox_to_anchor_scale,
3653            bbox_to_anchor_perc=bbox_to_anchor_perc,
3654        )
3655
3656        return scatter
3657
3658    def accept_subclusters(self):
3659        """
3660        Commit subcluster labels into the main `input_metadata` by renaming cell names.
3661
3662        The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']`
3663        with the expanded names that include subcluster suffixes (via `add_subnames`),
3664        then clears `self.subclusters_`.
3665
3666        Update
3667        ------
3668        Modifies `self.input_metadata['cell_names']`.
3669
3670        Resets `self.subclusters_` to None.
3671
3672        Raises
3673        ------
3674        RuntimeError
3675            If `self.subclusters_` is not defined or subclusters were not computed.
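
        Examples
        --------
        A minimal sketch of the full subcluster round trip:

        >>> obj.subcluster_prepare(features=["GeneA", "GeneB"], cluster="Cluster_1")
        >>> _ = obj.define_subclusters(eps=0.4, min_samples=5)
        >>> obj.accept_subclusters()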
3676        """
3677
3678        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3679            raise RuntimeError(
3680                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
3681            )
3682
3683        new_meta = add_subnames(
3684            list(self.input_metadata["cell_names"]),
3685            parent_name=self.subclusters_.current_cluster,
3686            new_clusters=self.subclusters_.subclusters,
3687        )
3688
3689        self.input_metadata["cell_names"] = new_meta
3690
3691        self.subclusters_ = None
3692
3693    def scatter_plot(
3694        self,
3695        names: list | None = None,
3696        features: list | None = None,
3697        name_slot: str = "cell_names",
3698        scale=True,
3699        colors="viridis",
3700        hclust=None,
3701        img_width=15,
3702        img_high=1,
3703        label_size=10,
3704        size_scale=200,
3705        y_lab="Genes",
3706        legend_lab="log(CPM + 1)",
3707        set_box_size: float | int = 5,
3708        set_box_high: float | int = 5,
3709        bbox_to_anchor_scale=25,
3710        bbox_to_anchor_perc=(0.90, -0.24),
3711        bbox_to_anchor_group=(1.01, 0.4),
3712    ):
3713        """
3714        Create a bubble scatter plot of selected features across samples inside the project.
3715
3716        Each point represents a feature-sample pair, where the color encodes the
3717        expression value and the size encodes occurrence or relative abundance.
3718        Optionally, hierarchical clustering can be applied to order rows and columns.
3719
3720        Parameters
3721        ----------
3722        names : list, str, or None
3723            Names of samples to include. If None, all samples are considered.
3724
3725        features : list, str, or None
3726            Names of features to include. If None, all features are considered.
3727
3728        name_slot : str
3729            Column in metadata to use as sample names.
3730
3731        scale : bool, default True
3732            If True, expression_data will be scaled (0–1) across the rows (features).
3733
3734        colors : str, default='viridis'
3735            Colormap for expression values.
3736
3737        hclust : str or None, default=None
3738            Linkage method for hierarchical clustering. If None, no clustering
3739            is performed.
3740
3741        img_width : int or float, default=15
3742            Width of the plot in inches.
3743
3744        img_high : int or float, default=1
3745            Height of the plot in inches.
3746
3747        label_size : int, default=10
3748            Font size for axis labels and ticks.
3749
3750        size_scale : int or float, default=200
3751            Scaling factor for bubble sizes.
3752
3753        y_lab : str, default='Genes'
3754            Label for the x-axis.
3755
3756        legend_lab : str, default='log(CPM + 1)'
3757            Label for the colorbar legend.
3758
3759        bbox_to_anchor_scale : int, default=25
3760            Vertical scale (percentage) for positioning the colorbar.
3761
3762        bbox_to_anchor_perc : tuple, default=(0.90, -0.24)
3763            Anchor position for the size legend (percent bubble legend).
3764
3765        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
3766            Anchor position for the group legend.
3767
3768        Returns
3769        -------
3770        matplotlib.figure.Figure
3771            The generated scatter plot figure.
3772
3773        Notes
3774        -----
3775        Colors represent expression values normalized to the colormap.
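
        Examples
        --------
        A minimal sketch (sample and feature names are placeholders):

        >>> fig = obj.scatter_plot(
        ...     names=["Purkinje", "Granule"],
        ...     features=["GeneA", "GeneB"],
        ...     scale=True,
        ... )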
3776        """
3777
3778        prtd, met = self.get_partial_data(
3779            names=names, features=features, name_slot=name_slot, inc_metadata=True
3780        )
3781
3782        if scale:
3783
3784            legend_lab = "Scaled\n" + legend_lab
3785
3786            scaler = MinMaxScaler(feature_range=(0, 1))
3787            prtd = pd.DataFrame(
3788                scaler.fit_transform(prtd.T).T,
3789                index=prtd.index,
3790                columns=prtd.columns,
3791            )
3792
3793        prtd.columns = prtd.columns + "#" + met["sets"]
3794
3795        prtd_avg = average(prtd)
3796
3797        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]
3798
3799        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]
3800
3801        prtd_occ = occurrence(prtd)
3802
3803        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]
3804
3805        fig_scatter = features_scatter(
3806            expression_data=prtd_avg,
3807            occurence_data=prtd_occ,
3808            scale=scale,
3809            features=None,
3810            metadata_list=meta_sets,
3811            colors=colors,
3812            hclust=hclust,
3813            img_width=img_width,
3814            img_high=img_high,
3815            label_size=label_size,
3816            size_scale=size_scale,
3817            y_lab=y_lab,
3818            legend_lab=legend_lab,
3819            set_box_size=set_box_size,
3820            set_box_high=set_box_high,
3821            bbox_to_anchor_scale=bbox_to_anchor_scale,
3822            bbox_to_anchor_perc=bbox_to_anchor_perc,
3823            bbox_to_anchor_group=bbox_to_anchor_group,
3824        )
3825
3826        return fig_scatter
3827
3828    def data_composition(
3829        self,
3830        features_count: list | None,
3831        name_slot: str = "cell_names",
3832        set_sep: bool = True,
3833    ):
3834        """
3835        Compute composition of cell types in data set.
3836
3837        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
3838        within metadata entries, calculates their relative percentages, and stores
3839        the results in `self.composition_data`.
3840
3841        Parameters
3842        ----------
3843        features_count : list or None
3844            List of features (part or full names) to be counted.
3845            If None, all unique elements from the specified `name_slot` metadata field are used.
3846
3847        name_slot : str, default 'cell_names'
3848            Metadata field containing sample identifiers or labels.
3849
3850        set_sep : bool, default True
3851            If True and multiple sets are present in metadata, compute composition
3852            separately for each set.
3853
3854        Update
3855        -------
3856        Stores results in `self.composition_data` as a pandas DataFrame with:
3857        - 'name': feature name
3858        - 'n': number of occurrences
3859        - 'pct': percentage of occurrences
3860        - 'set' (if applicable): dataset identifier
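
        Examples
        --------
        A minimal sketch, counting all unique labels found in 'cell_names':

        >>> obj.data_composition(features_count=None)
        >>> obj.composition_data.head()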
3861        """
3862
3863        validated_list = list(self.input_metadata[name_slot])
3864        sets = list(self.input_metadata["sets"])
3865
3866        if features_count is None:
3867            features_count = list(set(self.input_metadata[name_slot]))
3868
3869        if set_sep and len(set(sets)) > 1:
3870
3871            final_res = pd.DataFrame()
3872
3873            for s in set(sets):
3874                print(s)
3875
3876                mask = [x == s for x in sets]
3877
3878                tmp_val_list = list(np.array(validated_list)[mask])
3879
3880                res_dict = {"name": [], "n": [], "set": []}
3881
3882                for f in tqdm(features_count):
3883                    res_dict["n"].append(
3884                        sum(1 for element in tmp_val_list if f in element)
3885                    )
3886                    res_dict["name"].append(f)
3887                    res_dict["set"].append(s)
3888
3889                # build the per-set frame once, after all
3890                # features are counted
3891                res = pd.DataFrame(res_dict)
3892                res["pct"] = (res["n"] / res["n"].sum() * 100).round(2)
3893
3894                final_res = pd.concat([final_res, res])
3895
3896            res = final_res.sort_values(["set", "pct"], ascending=[True, False])
3897
3898        else:
3899
3900            res_dict = {"name": [], "n": []}
3901
3902            for f in tqdm(features_count):
3903                res_dict["n"].append(
3904                    sum(1 for element in validated_list if f in element)
3905                )
3906                res_dict["name"].append(f)
3907
3908            res = pd.DataFrame(res_dict)
3909            res["pct"] = res["n"] / sum(res["n"]) * 100
3910            res["pct"] = res["pct"].round(2)
3911
3912            res = res.sort_values("pct", ascending=False)
3913
3914        self.composition_data = res
3915
3916    def composition_pie(
3917        self,
3918        width=6,
3919        height=6,
3920        font_size=15,
3921        cmap: str = "tab20",
3922        legend_split_col: int = 1,
3923        offset_labels: float | int = 0.5,
3924        legend_bbox: tuple = (1.15, 0.95),
3925    ):
3926        """
3927        Visualize the composition of cell lineages using pie charts.
3928
3929        Generates pie charts showing the relative proportions of features stored
3930        in `self.composition_data`. If multiple sets are present, a separate
3931        chart is drawn for each set.
3932
3933        Parameters
3934        ----------
3935        width : int, default 6
3936            Width of the figure.
3937
3938        height : int, default 6
3939            Height of the figure (applied per set if multiple sets are plotted).
3940
3941        font_size : int, default 15
3942            Font size for labels and annotations.
3943
3944        cmap : str, default 'tab20'
3945            Colormap used for pie slices.
3946
3947        legend_split_col : int, default 1
3948            Number of columns in the legend.
3949
3950        offset_labels : float or int, default 0.5
3951            Spacing offset for label placement relative to pie slices.
3952
3953        legend_bbox : tuple, default (1.15, 0.95)
3954            Bounding box anchor position for the legend.
3955
3956        Returns
3957        -------
3958        matplotlib.figure.Figure
3959            Pie chart visualization of composition data.
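
        Examples
        --------
        A minimal sketch, run after `data_composition`:

        >>> obj.data_composition(features_count=None)
        >>> fig = obj.composition_pie(cmap="tab20", legend_split_col=2)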
3960        """
3961
3962        df = self.composition_data
3963
3964        if "set" in df.columns and len(set(df["set"])) > 1:
3965
3966            sets = list(set(df["set"]))
3967            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))
3968
3969            all_wedges = []
3970            cmap = plt.get_cmap(cmap)
3971
3972            set_nam = len(set(df["name"]))
3973
3974            legend_labels = list(set(df["name"]))
3975
3976            colors = [cmap(i / set_nam) for i in range(set_nam)]
3977
3978            cmap_dict = dict(zip(legend_labels, colors))
3979
3980            for idx, s in enumerate(sets):
3981                ax = axes[idx]
3982                tmp_df = df[df["set"] == s].reset_index(drop=True)
3983
3984                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]
3985
3986                wedges, _ = ax.pie(
3987                    tmp_df["n"],
3988                    startangle=90,
3989                    labeldistance=1.05,
3990                    colors=[cmap_dict[x] for x in tmp_df["name"]],
3991                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
3992                )
3993
3994                all_wedges.extend(wedges)
3995
3996                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
3997                n = 0
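                # place each percentage label on the bisector angle of its
                # wedge (converted to unit-circle x/y) and push it outward by
                # a cumulative offset so neighbouring labels do not overlap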
3998                for i, p in enumerate(wedges):
3999                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
4000                    y = np.sin(np.deg2rad(ang))
4001                    x = np.cos(np.deg2rad(ang))
4002                    horizontalalignment = "left" if x >= 0 else "right"
4003                    connectionstyle = f"angle,angleA=0,angleB={ang}"
4004                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
4005                    if len(labels[i]) > 0:
4006                        n += offset_labels
4007                        ax.annotate(
4008                            labels[i],
4009                            xy=(x, y),
4010                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
4011                            horizontalalignment=horizontalalignment,
4012                            fontsize=font_size,
4013                            weight="bold",
4014                            **kw,
4015                        )
4016
4017                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
4018                ax.add_artist(circle2)
4019
4020                ax.text(
4021                    0,
4022                    0,
4023                    f"{s}",
4024                    ha="center",
4025                    va="center",
4026                    fontsize=font_size,
4027                    weight="bold",
4028                )
4029
4030            legend_handles = [
4031                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4032                for label in legend_labels
4033            ]
4034
4035            fig.legend(
4036                handles=legend_handles,
4037                loc="center right",
4038                bbox_to_anchor=legend_bbox,
4039                ncol=legend_split_col,
4040                title="",
4041            )
4042
4043            plt.tight_layout()
4044            plt.show()
4045
4046        else:
4047
4048            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]
4049
4050            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]
4051
4052            cmap = plt.get_cmap(cmap)
4053            colors = [cmap(i / len(df)) for i in range(len(df))]
4054
4055            fig, ax = plt.subplots(
4056                figsize=(width, height), subplot_kw=dict(aspect="equal")
4057            )
4058
4059            wedges, _ = ax.pie(
4060                df["n"],
4061                startangle=90,
4062                labeldistance=1.05,
4063                colors=colors,
4064                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
4065            )
4066
4067            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
4068            n = 0
4069            for i, p in enumerate(wedges):
4070                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
4071                y = np.sin(np.deg2rad(ang))
4072                x = np.cos(np.deg2rad(ang))
4073                horizontalalignment = "left" if x >= 0 else "right"
4074                connectionstyle = f"angle,angleA=0,angleB={ang}"
4075                kw["arrowprops"].update({"connectionstyle": connectionstyle})
4076                if len(labels[i]) > 0:
4077                    n += offset_labels
4078
4079                    ax.annotate(
4080                        labels[i],
4081                        xy=(x, y),
4082                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
4083                        horizontalalignment=horizontalalignment,
4084                        fontsize=font_size,
4085                        weight="bold",
4086                        **kw,
4087                    )
4088
4089            circle2 = plt.Circle((0, 0), 0.6, color="white")
4090            circle2.set_edgecolor("black")
4091
4092            p = plt.gcf()
4093            p.gca().add_artist(circle2)
4094
4095            ax.legend(
4096                wedges,
4097                legend_labels,
4098                title="",
4099                loc="center left",
4100                bbox_to_anchor=legend_bbox,
4101                ncol=legend_split_col,
4102            )
4103
4104            plt.show()
4105
4106        return fig
4107
4108    def bar_composition(
4109        self,
4110        cmap="tab20b",
4111        width=2,
4112        height=6,
4113        font_size=15,
4114        legend_split_col: int = 1,
4115        legend_bbox: tuple = (1.3, 1),
4116    ):
4117        """
4118        Visualize the composition of cell lineages using bar plots.
4119
4120        Produces bar plots showing the distribution of features stored in
4121        `self.composition_data`. If multiple sets are present, a separate
4122        bar is drawn for each set. Percentages are annotated alongside the bars.
4123
4124        Parameters
4125        ----------
4126        cmap : str, default 'tab20b'
4127            Colormap used for stacked bars.
4128
4129        width : int, default 2
4130            Width of each subplot (per set).
4131
4132        height : int, default 6
4133            Height of the figure.
4134
4135        font_size : int, default 15
4136            Font size for labels and annotations.
4137
4138        legend_split_col : int, default 1
4139            Number of columns in the legend.
4140
4141        legend_bbox : tuple, default (1.3, 1)
4142            Bounding box anchor position for the legend.
4143
4144        Returns
4145        -------
4146        matplotlib.figure.Figure
4147            Stacked bar plot visualization of composition data.
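
        Examples
        --------
        A minimal sketch, run after `data_composition`:

        >>> obj.data_composition(features_count=None)
        >>> fig = obj.bar_composition(cmap="tab20b", width=3)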
4148        """
4149
4150        df = self.composition_data
4151        df["num"] = range(1, len(df) + 1)
4152
4153        if "set" in df.columns and len(set(df["set"])) > 1:
4154
4155            sets = list(set(df["set"]))
4156            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))
4157
4158            cmap = plt.get_cmap(cmap)
4159
4160            set_nam = len(set(df["name"]))
4161
4162            legend_labels = list(set(df["name"]))
4163
4164            colors = [cmap(i / set_nam) for i in range(set_nam)]
4165
4166            cmap_dict = dict(zip(legend_labels, colors))
4167
4168            for idx, s in enumerate(sets):
4169                ax = axes[idx]
4170
4171                tmp_df = df[df["set"] == s].reset_index(drop=True)
4172
4173                values = tmp_df["n"].values
4174                total = sum(values)
4175                values = [v / total * 100 for v in values]
4176                values = [round(v, 2) for v in values]
4177
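                # absorb the rounding drift into the largest slice so the
                # stacked bar sums to exactly 100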
4178                idx_max = np.argmax(values)
4179                correction = 100 - sum(values)
4180                values[idx_max] += correction
4181
4182                names = tmp_df["name"].values
4183                perc = tmp_df["pct"].values
4184                nums = tmp_df["num"].values
4185
4186                bottom = 0
4187                centers = []
4188                for name, num, val, color in zip(names, nums, values, colors):
4189                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
4190                    centers.append(bottom + val / 2)
4191                    bottom += val
4192
4193                y_positions = np.linspace(centers[0], centers[-1], len(centers))
4194                x_text = -0.8
4195
4196                for y_label, y_center, pct in zip(y_positions, centers, perc):
4199                    ax.annotate(
4200                        f"{pct:.1f}%",
4201                        xy=(0, y_center),
4202                        xycoords="data",
4203                        xytext=(x_text, y_label),
4204                        textcoords="data",
4205                        ha="right",
4206                        va="center",
4207                        fontsize=font_size,
4208                        arrowprops=dict(
4209                            arrowstyle="->",
4210                            lw=1,
4211                            color="black",
4212                            connectionstyle="angle3,angleA=0,angleB=90",
4213                        ),
4214                    )
4215
4216                ax.set_ylim(0, 100)
4217                ax.set_xlabel(s, fontsize=font_size)
4218                ax.xaxis.label.set_rotation(30)
4219
4220                ax.set_xticks([])
4221                ax.set_yticks([])
4222                for spine in ax.spines.values():
4223                    spine.set_visible(False)
4224
4225            legend_handles = [
4226                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4227                for label in legend_labels
4228            ]
4229
4230            fig.legend(
4231                handles=legend_handles,
4232                loc="center right",
4233                bbox_to_anchor=legend_bbox,
4234                ncol=legend_split_col,
4235                title="",
4236            )
4237
4238            plt.tight_layout()
4239            plt.show()
4240
4241        else:
4242
4243            cmap = plt.get_cmap(cmap)
4244
4245            colors = [cmap(i / len(df)) for i in range(len(df))]
4246
4247            fig, ax = plt.subplots(figsize=(width, height))
4248
4249            values = df["n"].values
4250            names = df["name"].values
4251            perc = df["pct"].values
4252            nums = df["num"].values
4253
4254            bottom = 0
4255            centers = []
4256            for name, num, val, color in zip(names, nums, values, colors):
4257                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
4258                centers.append(bottom + val / 2)
4259                bottom += val
4260
4261            y_positions = np.linspace(centers[0], centers[-1], len(centers))
4262            x_text = -0.8
4263
4264            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
4265                ax.annotate(
4266                    f"{num}) {pct}",
4267                    xy=(0, y_center),
4268                    xycoords="data",
4269                    xytext=(x_text, y_label),
4270                    textcoords="data",
4271                    ha="right",
4272                    va="center",
4273                    fontsize=font_size,
4274                    arrowprops=dict(
4275                        arrowstyle="->",
4276                        lw=1,
4277                        color="black",
4278                        connectionstyle="angle3,angleA=0,angleB=90",
4279                    ),
4280                )
4281
4282            ax.set_xticks([])
4283            ax.set_yticks([])
4284            for spine in ax.spines.values():
4285                spine.set_visible(False)
4286
4287            ax.legend(
4288                title="Legend",
4289                bbox_to_anchor=legend_bbox,
4290                loc="upper left",
4291                ncol=legend_split_col,
4292            )
4293
4294            plt.tight_layout()
4295            plt.show()
4296
4297        return fig
4298
4299    def cell_regression(
4300        self,
4301        cell_x: str,
4302        cell_y: str,
4303        set_x: str | None = None,
4304        set_y: str | None = None,
4305        threshold=10,
4306        image_width=12,
4307        image_high=7,
4308        color="black",
4309    ):
4310        """
4311        Perform regression analysis between two selected cells and visualize the relationship.
4312
4313        This function computes a linear regression between two specified cells from
4314        aggregated normalized data, plots the regression line with scatter points,
4315        annotates regression statistics, and highlights potential outliers.
4316
4317        Parameters
4318        ----------
4319        cell_x : str
4320            Name of the first cell (X-axis).
4321
4322        cell_y : str
4323            Name of the second cell (Y-axis).
4324
4325        set_x : str or None
4326            Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name.
4327
4328        set_y : str or None
4329            Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name.
4330
4331        threshold : int or float, default 10
4332            Threshold for detecting outliers. Points deviating from the mean or diagonal by more
4333            than this value are annotated.
4334
4335        image_width : int, default 12
4336            Width of the regression plot (in inches).
4337
4338        image_high : int, default 7
4339            Height of the regression plot (in inches).
4340
4341        color : str, default 'black'
4342            Color of the regression scatter points and line.
4343
4344        Returns
4345        -------
4346        matplotlib.figure.Figure
4347            Regression plot figure with annotated regression line, R², p-value, and outliers.
4348
4349        Raises
4350        ------
4351        ValueError
4352            If `cell_x` or `cell_y` are not found in the dataset.
4353            If multiple matches are found for a cell name and `set_x`/`set_y` are not specified.
4354
4355        Notes
4356        -----
4357        - The function automatically calls `self.average()` if aggregated data is not available.
4358        - Outliers are annotated with their corresponding index labels.
4359        - Regression is computed using `scipy.stats.linregress`.
4360
4361        Examples
4362        --------
4363        >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
4364        >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")
4365        """
4366
4367        if self.agg_normalized_data is None:
4368            self.average()
4369
4370        metadata = self.agg_metadata
4371        data = self.agg_normalized_data
4372
4373        if set_x is not None and set_y is not None:
4374            data.columns = metadata["cell_names"] + " # " + metadata["sets"]
4375            cell_x = cell_x + " # " + set_x
4376            cell_y = cell_y + " # " + set_y
4377
4378        else:
4379            data.columns = metadata["cell_names"]
4380
4381        if cell_x not in data.columns:
4382            raise ValueError("'cell_x' value not in cell names!")
4383
4384        if cell_y not in data.columns:
4385            raise ValueError("'cell_y' value not in cell names!")
4386
4387        if list(data.columns).count(cell_x) > 1:
4388            raise ValueError(
4389                f"'{cell_x}' occurs more than once. If you want to select a specific cell, "
4390                f"please also provide the corresponding 'set_x' and 'set_y' values."
4391            )
4392
4393        if list(data.columns).count(cell_y) > 1:
4394            raise ValueError(
4395                f"'{cell_y}' occurs more than once. If you want to select a specific cell, "
4396                f"please also provide the corresponding 'set_x' and 'set_y' values."
4397            )
4398
4399        fig, ax = plt.subplots(figsize=(image_width, image_high))
4400        ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color)
4401
4402        slope, intercept, r_value, p_value, _ = stats.linregress(
4403            data[cell_x], data[cell_y]
4404        )
4405        equation = "y = {:.2f}x + {:.2f}".format(slope, intercept)
4406
4407        ax.annotate(
4408            "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format(
4409                r_value**2, p_value, equation
4410            ),
4411            xy=(0.05, 0.90),
4412            xycoords="axes fraction",
4413            fontsize=12,
4414        )
4415
4416        ax.spines["top"].set_visible(False)
4417        ax.spines["right"].set_visible(False)
4418
4425        def annotate_outliers(x, y, threshold):
4426            texts = []
4427            x_mean, y_mean = x.mean(), y.mean()
4428            for i, (xi, yi) in enumerate(zip(x, y)):
4429                if (
4430                    abs(xi - x_mean) > threshold
4431                    or abs(yi - y_mean) > threshold
4432                    or abs(yi - xi) > threshold
4433                ):
4434                    text = ax.text(xi, yi, data.index[i])
4435                    texts.append(text)
4436
4437            return texts
4438
4439        texts = annotate_outliers(data[cell_x], data[cell_y], threshold)
4440
4441        adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))
4442
4443        plt.show()
4444
4445        return fig
class Clustering:
  31class Clustering:
  32    """
  33    A class for performing dimensionality reduction, clustering, and visualization
  34    on high-dimensional data (e.g., single-cell gene expression).
  35
  36    The class provides methods for:
  37    - Normalizing and extracting subsets of data
  38    - Principal Component Analysis (PCA) and related clustering
  39    - Uniform Manifold Approximation and Projection (UMAP) and clustering
  40    - Visualization of PCA and UMAP embeddings
  41    - Harmonization of batch effects
  42    - Accessing processed data and cluster labels
  43
  44    Methods
  45    -------
  46    add_data_frame(data, metadata)
  47        Class method to create a Clustering instance from a DataFrame and metadata.
  48
  49    harmonize_sets()
  50        Perform batch effect harmonization on PCA data.
  51
  52    perform_PCA(pc_num=100, width=8, height=6)
  53        Perform PCA on the dataset and visualize the first two PCs.
  54
  55    knee_plot_PCA(width=8, height=6)
  56        Plot the cumulative variance explained by PCs to determine optimal dimensionality.
  57
  58    find_clusters_PCA(pc_num=0, eps=0.5, min_samples=10, width=8, height=6, harmonized=False)
  59        Apply DBSCAN clustering to PCA embeddings and visualize results.
  60
  61    perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...)
  62        Compute UMAP embeddings with optional parameter tuning.
  63
  64    knee_plot_umap(eps=0.5, min_samples=10)
  65        Determine optimal UMAP dimensionality using silhouette scores.
  66
  67    find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6)
  68        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
  69
  70    UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...)
  71        Visualize UMAP embeddings with labels and optional cluster numbering.
  72
  73    UMAP_feature(feature_name, features_data=None, point_size=0.6, ...)
  74        Plot a single feature over UMAP coordinates with customizable colormap.
  75
  76    get_umap_data()
  77        Return the UMAP embeddings along with cluster labels if available.
  78
  79    get_pca_data()
  80        Return the PCA results along with cluster labels if available.
  81
  82    return_clusters(clusters='umap')
  83        Return the cluster labels for UMAP or PCA embeddings.
  84
  85    Raises
  86    ------
  87    ValueError
  88        For invalid parameters, mismatched dimensions, or missing metadata.
  89    """
  90
  91    def __init__(self, data, metadata):
  92        """
  93        Initialize the clustering class with data and optional metadata.
  94
  95        Parameters
  96        ----------
  97        data : pandas.DataFrame
  98            Input data for clustering. Columns are considered as samples.
  99
 100        metadata : pandas.DataFrame, optional
 101            Metadata for the samples. If None, a default DataFrame with column
 102            names as 'cell_names' is created.
 103
 104        Attributes
 105        ----------
 106        -clustering_data : pandas.DataFrame
 107        -clustering_metadata : pandas.DataFrame
 108        -subclusters : None or dict
 109        -explained_var : None or numpy.ndarray
 110        -cumulative_var : None or numpy.ndarray
 111        -pca : None or pandas.DataFrame
 112        -harmonized_pca : None or pandas.DataFrame
 113        -umap : None or pandas.DataFrame
 114        """
 115
 116        self.clustering_data = data
 117        """The input data used for clustering."""
 118
 119        if metadata is None:
 120            metadata = pd.DataFrame({"cell_names": list(data.columns)})
 121
 122        self.clustering_metadata = metadata
 123        """Metadata associated with the samples."""
 124
 125        self.subclusters = None
 126        """Placeholder for storing subcluster information."""
 127
 128        self.explained_var = None
 129        """Explained variance from PCA, initialized as None."""
 130
 131        self.cumulative_var = None
 132        """Cumulative explained variance from PCA, initialized as None."""
 133
 134        self.pca = None
 135        """PCA-transformed data, initialized as None."""
 136
 137        self.harmonized_pca = None
 138        """PCA data after batch effect harmonization, initialized as None."""
 139
 140        self.umap = None
 141        """UMAP embeddings, initialized as None."""
 142
 143    @classmethod
 144    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
 145        """
 146        Create a Clustering instance from a DataFrame and optional metadata.
 147
 148        Parameters
 149        ----------
 150        data : pandas.DataFrame
 151            Input data with features as rows and samples/cells as columns.
 152
 153        metadata : pandas.DataFrame or None
 154            Optional metadata for the samples.
 155            Each row corresponds to a sample/cell, and column names in this DataFrame
 156            should match the sample/cell names in `data`. Columns can contain additional
 157            information such as cell type, experimental condition, batch, sets, etc.
 158
 159        Returns
 160        -------
 161        Clustering
 162            A new instance of the Clustering class.
 163        """
 164
 165        return cls(data, metadata)
 166
 167    def harmonize_sets(self, batch_col: str = "sets"):
 168        """
 169        Perform batch effect harmonization on PCA embeddings.
 170
 171        Parameters
 172        ----------
 173        batch_col : str, default 'sets'
 174            Name of the column in `metadata` that contains batch information for the samples/cells.
 175
 176        Returns
 177        -------
 178        None
 179            Updates the `harmonized_pca` attribute with harmonized data.
 180        """
 181
 182        data_mat = np.array(self.pca)
 183
 184        metadata = self.clustering_metadata
 185
 186        self.harmonized_pca = pd.DataFrame(
 187            harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr
 188        ).T
 189
 190        self.harmonized_pca.columns = self.pca.columns
 191
 192    def perform_PCA(self, pc_num: int = 100, width=8, height=6):
 193        """
 194        Perform Principal Component Analysis (PCA) on the dataset.
 195
 196        This method standardizes the data, applies PCA, stores results as attributes,
 197        and generates a scatter plot of the first two principal components.
 198
 199        Parameters
 200        ----------
 201        pc_num : int, default 100
 202            Number of principal components to compute.
 203            If 0, computes all available components.
 204
 205        width : int or float, default 8
 206            Width of the PCA figure.
 207
 208        height : int or float, default 6
 209            Height of the PCA figure.
 210
 211        Returns
 212        -------
 213        matplotlib.figure.Figure
 214            Scatter plot showing the first two principal components.
 215
 216        Updates
 217        -------
 218        self.pca : pandas.DataFrame
 219            DataFrame with principal component scores for each sample.
 220
 221        self.explained_var : numpy.ndarray
 222            Percentage of variance explained by each principal component.
 223
 224        self.cumulative_var : numpy.ndarray
 225            Cumulative explained variance.
 226        """
 227
 228        scaler = StandardScaler()
 229        data_scaled = scaler.fit_transform(self.clustering_data.T)
 230
 231        if pc_num == 0 or pc_num > data_scaled.shape[0]:
 232            pc_num = data_scaled.shape[0]
 233
 234        pca = PCA(n_components=pc_num, random_state=42)
 235
 236        principal_components = pca.fit_transform(data_scaled)
 237
 238        pca_df = pd.DataFrame(
 239            data=principal_components,
 240            columns=["PC" + str(x + 1) for x in range(pc_num)],
 241        )
 242
 243        self.explained_var = pca.explained_variance_ratio_ * 100
 244        self.cumulative_var = np.cumsum(self.explained_var)
 245
 246        self.pca = pca_df
 247
 248        fig = plt.figure(figsize=(width, height))
 249        plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7)
 250        plt.xlabel("PC 1")
 251        plt.ylabel("PC 2")
 252        plt.grid(True)
 253        plt.show()
 254
 255        return fig
 256
 257    def knee_plot_PCA(self, width: int = 8, height: int = 6):
 258        """
 259        Plot cumulative explained variance to determine the optimal number of PCs.
 260
 261        Parameters
 262        ----------
 263        width : int, default 8
 264            Width of the figure.
 265
 266        height : int or, default 6
 267            Height of the figure.
 268
 269        Returns
 270        -------
 271        matplotlib.figure.Figure
 272            Line plot showing cumulative variance explained by each PC.
 273        """
 274
 275        fig_knee = plt.figure(figsize=(width, height))
 276        plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o")
 277        plt.xlabel("PC (n components)")
 278        plt.ylabel("Cumulative explained variance (%)")
 279        plt.grid(True)
 280
 281        xticks = [1] + list(range(5, len(self.explained_var) + 1, 5))
 282        plt.xticks(xticks, rotation=60)
 283
 284        plt.show()
 285
 286        return fig_knee
 287
 288    def find_clusters_PCA(
 289        self,
 290        pc_num: int = 2,
 291        eps: float = 0.5,
 292        min_samples: int = 10,
 293        width: int = 8,
 294        height: int = 6,
 295        harmonized: bool = False,
 296    ):
 297        """
 298        Apply DBSCAN clustering to PCA embeddings and visualize the results.
 299
 300        This method performs density-based clustering (DBSCAN) on the PCA-reduced
 301        dataset. Cluster labels are stored in the object's metadata, and a scatter
 302        plot of the first two principal components with cluster annotations is returned.
 303
 304        Parameters
 305        ----------
 306        pc_num : int, default 2
 307            Number of principal components to use for clustering.
 308            If 0, uses all available components.
 309
 310        eps : float, default 0.5
 311            Maximum distance between two points for them to be considered
 312            as neighbors (DBSCAN parameter).
 313
 314        min_samples : int, default 10
 315            Minimum number of samples required to form a cluster (DBSCAN parameter).
 316
 317        width : int, default 8
 318            Width of the output scatter plot.
 319
 320        height : int, default 6
 321            Height of the output scatter plot.
 322
 323        harmonized : bool, default False
 324            If True, use harmonized PCA data (`self.harmonized_pca`).
 325            If False, use standard PCA results (`self.pca`).
 326
 327        Returns
 328        -------
 329        matplotlib.figure.Figure
 330            Scatter plot of the first two principal components colored by
 331            cluster assignments.
 332
 333        Updates
 334        -------
 335        self.clustering_metadata['PCA_clusters'] : list
 336            Cluster labels assigned to each cell/sample.
 337
 338        self.input_metadata['PCA_clusters'] : list, optional
 339            Cluster labels stored in input metadata (if available).
 340        """
 341
 342        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
 343
 344        if pc_num == 0 and harmonized:
 345            pca_data = self.harmonized_pca
 346
 347        elif pc_num == 0:
 348            pca_data = self.pca
 349
 350        else:
 351            if harmonized:
 352
 353                pca_data = self.harmonized_pca.iloc[:, 0:pc_num]
 354            else:
 355
 356                pca_data = self.pca.iloc[:, 0:pc_num]
 357
 358        dbscan_labels = dbscan.fit_predict(pca_data)
 359
 360        pca_df = pd.DataFrame(pca_data)
 361        pca_df["Cluster"] = dbscan_labels
 362
 363        fig = plt.figure(figsize=(width, height))
 364
 365        for cluster_id in sorted(pca_df["Cluster"].unique()):
 366            cluster_data = pca_df[pca_df["Cluster"] == cluster_id]
 367            plt.scatter(
 368                cluster_data["PC1"],
 369                cluster_data["PC2"],
 370                label=f"Cluster {cluster_id}",
 371                alpha=0.7,
 372            )
 373
 374        plt.xlabel("PC 1")
 375        plt.ylabel("PC 2")
 376        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
 377
 378        plt.grid(True)
 379        plt.show()
 380
 381        self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
 382
 383        try:
 384            self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
 385        except Exception:
 386            pass
 387
 388        return fig
 389
 390    def perform_UMAP(
 391        self,
 392        factorize: bool = False,
 393        umap_num: int = 100,
 394        pc_num: int = 0,
 395        harmonized: bool = False,
 396        n_neighbors: int = 5,
 397        min_dist: float | int = 0.1,
 398        spread: float | int = 1.0,
 399        set_op_mix_ratio: float | int = 1.0,
 400        local_connectivity: int = 1,
 401        repulsion_strength: float | int = 1.0,
 402        negative_sample_rate: int = 5,
 403        width: int = 8,
 404        height: int = 6,
 405    ):
 406        """
 407        Compute and visualize UMAP embeddings of the dataset.
 408
 409        This method applies Uniform Manifold Approximation and Projection (UMAP)
 410        for dimensionality reduction on either raw, PCA, or harmonized PCA data.
 411        Results are stored as a DataFrame (`self.umap`), and a scatter plot of
 412        the first two UMAP dimensions is displayed.
 413
 414        Parameters
 415        ----------
 416        factorize : bool, default False
 417            If True, categorical sample labels (from column names) are factorized
 418            and used as supervision in UMAP fitting.
 419
 420        umap_num : int, default 100
 421            Number of UMAP dimensions to compute. If 0, matches the input dimension.
 422
 423        pc_num : int, default 0
 424            Number of principal components to use as UMAP input.
 425            If 0, use all available components or raw data.
 426
 427        harmonized : bool, default False
 428            If True, use harmonized PCA embeddings (`self.harmonized_pca`).
 429            If False, use standard PCA or raw scaled data.
 430
 431        n_neighbors : int, default 5
 432            UMAP parameter controlling the size of the local neighborhood.
 433
 434        min_dist : float, default 0.1
 435            UMAP parameter controlling minimum allowed distance between embedded points.
 436
 437        spread : int | float, default 1.0
 438            Effective scale of embedded space (UMAP parameter).
 439
 440        set_op_mix_ratio : int | float, default 1.0
 441            Interpolation parameter between union and intersection in fuzzy sets.
 442
 443        local_connectivity : int, default 1
 444            Number of nearest neighbors assumed for each point.
 445
 446        repulsion_strength : int | float, default 1.0
 447            Weighting applied to negative samples during optimization.
 448
 449        negative_sample_rate : int, default 5
 450            Number of negative samples per positive sample in optimization.
 451
 452        width : int, default 8
 453            Width of the output scatter plot.
 454
 455        height : int, default 6
 456            Height of the output scatter plot.
 457
 458        Updates
 459        -------
 460        self.umap : pandas.DataFrame
 461            Table of UMAP embeddings with columns `UMAP1 ... UMAPn`.
 462
 463        Notes
 464        -----
 465        For supervised UMAP (`factorize=True`), categorical codes from column
 466        names of the dataset are used as labels.
 467        """
 468
 469        scaler = StandardScaler()
 470
 471        if pc_num == 0 and harmonized:
 472            data_scaled = self.harmonized_pca
 473
 474        elif pc_num == 0:
 475            data_scaled = scaler.fit_transform(self.clustering_data.T)
 476
 477        else:
 478            if harmonized:
 479
 480                data_scaled = self.harmonized_pca.iloc[:, 0:pc_num]
 481            else:
 482
 483                data_scaled = self.pca.iloc[:, 0:pc_num]
 484
 485        if umap_num == 0 or umap_num > data_scaled.shape[1]:
 486
 487            umap_num = data_scaled.shape[1]
 488
 489            reducer = umap.UMAP(
 490                n_components=len(data_scaled.T),
 491                random_state=42,
 492                n_neighbors=n_neighbors,
 493                min_dist=min_dist,
 494                spread=spread,
 495                set_op_mix_ratio=set_op_mix_ratio,
 496                local_connectivity=local_connectivity,
 497                repulsion_strength=repulsion_strength,
 498                negative_sample_rate=negative_sample_rate,
 499                n_jobs=1,
 500            )
 501
 502        else:
 503
 504            reducer = umap.UMAP(
 505                n_components=umap_num,
 506                random_state=42,
 507                n_neighbors=n_neighbors,
 508                min_dist=min_dist,
 509                spread=spread,
 510                set_op_mix_ratio=set_op_mix_ratio,
 511                local_connectivity=local_connectivity,
 512                repulsion_strength=repulsion_strength,
 513                negative_sample_rate=negative_sample_rate,
 514                n_jobs=1,
 515            )
 516
 517        if factorize:
 518            embedding = reducer.fit_transform(
 519                X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes
 520            )
 521        else:
 522            embedding = reducer.fit_transform(X=data_scaled)
 523
 524        umap_df = pd.DataFrame(
 525            embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)]
 526        )
 527
 528        plt.figure(figsize=(width, height))
 529        plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
 530        plt.xlabel("UMAP 1")
 531        plt.ylabel("UMAP 2")
 532        plt.grid(True)
 533
 534        plt.show()
 535
 536        self.umap = umap_df
 537
 538    def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):
 539        """
 540        Plot silhouette scores for different UMAP dimensions to determine optimal n_components.
 541
 542        Parameters
 543        ----------
 544        eps : float, default 0.5
 545            DBSCAN eps parameter for clustering each UMAP dimension.
 546
 547        min_samples : int, default 10
 548            Minimum number of samples to form a cluster in DBSCAN.
 549
 550        Returns
 551        -------
 552        matplotlib.figure.Figure
 553            Silhouette score plot across UMAP dimensions.
 554        """
 555
 556        umap_range = range(2, len(self.umap.T) + 1)
 557
 558        silhouette_scores = []
 559        component = []
 560        for n in umap_range:
 561
 562            db = DBSCAN(eps=eps, min_samples=min_samples)
 563            labels = db.fit_predict(np.array(self.umap)[:, :n])
 564
 565            mask = labels != -1
 566            if len(set(labels[mask])) > 1:
 567                score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask])
 568            else:
 569                score = -1
 570
 571            silhouette_scores.append(score)
 572            component.append(n)
 573
 574        fig = plt.figure(figsize=(10, 5))
 575        plt.plot(component, silhouette_scores, marker="o")
 576        plt.xlabel("UMAP (n_components)")
 577        plt.ylabel("Silhouette Score")
 578        plt.grid(True)
 579        plt.xticks(range(int(min(component)), int(max(component)) + 1, 1))
 580
 581        plt.show()
 582
 583        return fig
 584
 585    def find_clusters_UMAP(
 586        self,
 587        umap_n: int = 5,
 588        eps: int | float = 0.5,
 589        min_samples: int = 10,
 590        width: int = 8,
 591        height: int = 6,
 592    ):
 593        """
 594        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
 595
 596        This method performs density-based clustering (DBSCAN) on the UMAP-reduced
 597        dataset. Cluster labels are stored in the object's metadata, and a scatter
 598        plot of the first two UMAP components with cluster annotations is returned.
 599
 600        Parameters
 601        ----------
 602        umap_n : int, default 5
 603            Number of UMAP dimensions to use for DBSCAN clustering.
 604            Must be <= number of columns in `self.umap`.
 605
 606        eps : float | int, default 0.5
 607            Maximum neighborhood distance between two samples for them to be considered
 608            as in the same cluster (DBSCAN parameter).
 609
 610        min_samples : int, default 10
 611            Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).
 612
 613        width : int, default 8
 614            Figure width.
 615
 616        height : int, default 6
 617            Figure height.
 618
 619        Returns
 620        -------
 621        matplotlib.figure.Figure
 622            Scatter plot of the first two UMAP components colored by
 623            cluster assignments.
 624
 625        Updates
 626        -------
 627        self.clustering_metadata['UMAP_clusters'] : list
 628            Cluster labels assigned to each cell/sample.
 629
 630        self.input_metadata['UMAP_clusters'] : list, optional
 631            Cluster labels stored in input metadata (if available).
 632        """
 633
 634        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
 635        dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n])
 636
 637        umap_df = self.umap.copy()
 638        umap_df["Cluster"] = dbscan_labels
 639
 640        fig = plt.figure(figsize=(width, height))
 641
 642        for cluster_id in sorted(umap_df["Cluster"].unique()):
 643            cluster_data = umap_df[umap_df["Cluster"] == cluster_id]
 644            plt.scatter(
 645                cluster_data["UMAP1"],
 646                cluster_data["UMAP2"],
 647                label=f"Cluster {cluster_id}",
 648                alpha=0.7,
 649            )
 650
 651        plt.xlabel("UMAP 1")
 652        plt.ylabel("UMAP 2")
 653        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
 654        plt.grid(True)
 655
 656        self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
 657
 658        try:
 659            self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
 660        except Exception:
 661            pass
 662
 663        return fig
 664
 665    def UMAP_vis(
 666        self,
 667        names_slot: str = "cell_names",
 668        set_sep: bool = True,
 669        point_size: int | float = 0.6,
 670        font_size: int | float = 6,
 671        legend_split_col: int = 2,
 672        width: int = 8,
 673        height: int = 6,
 674        inc_num: bool = True,
 675    ):
 676        """
 677        Visualize UMAP embeddings with sample labels based on a specific metadata slot.
 678
 679        Parameters
 680        ----------
 681        names_slot : str, default 'cell_names'
 682            Column in metadata to use as sample labels.
 683
 684        set_sep : bool, default True
 685            If True, separate points by dataset.
 686
 687        point_size : float, default 0.6
 688            Size of scatter points.
 689
 690        font_size : int, default 6
 691            Font size for numbers on points.
 692
 693        legend_split_col : int, default 2
 694            Number of columns in legend.
 695
 696        width : int, default 8
 697            Figure width.
 698
 699        height : int, default 6
 700            Figure height.
 701
 702        inc_num : bool, default True
 703            If True, annotate points with numeric labels.
 704
 705        Returns
 706        -------
 707        matplotlib.figure.Figure
 708            UMAP scatter plot figure.
 709        """
 710
 711        umap_df = self.umap.iloc[:, 0:2].copy()
 712        umap_df.loc[:, "names"] = self.clustering_metadata[names_slot]
 713
 714        if set_sep:
 715
 716            if "sets" in self.clustering_metadata.columns:
 717                umap_df.loc[:, "dataset"] = self.clustering_metadata["sets"]
 718            else:
 719                umap_df["dataset"] = "default"
 720
 721        else:
 722            umap_df["dataset"] = "default"
 723
 724        umap_df.loc[:, "tmp_nam"] = umap_df["names"] + umap_df["dataset"]
 725
 726        umap_df.loc[:, "count"] = umap_df["tmp_nam"].map(
 727            umap_df["tmp_nam"].value_counts()
 728        )
 729
 730        numeric_df = (
 731            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
 732            .drop_duplicates()
 733            .sort_values("count", ascending=False)
 734        )
 735        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])
 736
 737        umap_df = umap_df.merge(
 738            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
 739        )
 740
 741        fig, ax = plt.subplots(figsize=(width, height))
 742
 743        markers = ["o", "s", "^", "D", "P", "*", "X"]
 744        marker_map = {
 745            ds: markers[i % len(markers)]
 746            for i, ds in enumerate(umap_df["dataset"].unique())
 747        }
 748
 749        cord_list = []
 750
 751        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):
 752
 753            cluster_data = umap_df[umap_df["numeric_values"] == num]
 754
 755            ax.scatter(
 756                cluster_data["UMAP1"],
 757                cluster_data["UMAP2"],
 758                label=f"{num} - {nam}",
 759                marker=marker_map[cluster_data["dataset"].iloc[0]],
 760                alpha=0.6,
 761                s=point_size,
 762            )
 763
 764            coords = cluster_data[["UMAP1", "UMAP2"]].values
 765
 766            dists = pairwise_distances(coords)
 767
 768            sum_dists = dists.sum(axis=1)
 769
 770            center_idx = np.argmin(sum_dists)
 771            center_point = coords[center_idx]
 772
 773            cord_list.append(center_point)
 774
 775        if inc_num:
 776            texts = []
 777            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
 778                texts.append(
 779                    ax.text(
 780                        x,
 781                        y,
 782                        str(num),
 783                        ha="center",
 784                        va="center",
 785                        fontsize=font_size,
 786                        color="black",
 787                    )
 788                )
 789
 790            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))
 791
 792        ax.set_xlabel("UMAP 1")
 793        ax.set_ylabel("UMAP 2")
 794
 795        ax.legend(
 796            title="Clusters",
 797            loc="center left",
 798            bbox_to_anchor=(1.05, 0.5),
 799            ncol=legend_split_col,
 800            markerscale=5,
 801        )
 802
 803        ax.grid(True)
 804
 805        plt.tight_layout()
 806
 807        return fig
 808
 809    def UMAP_feature(
 810        self,
 811        feature_name: str,
 812        features_data: pd.DataFrame | None = None,
 813        point_size: int | float = 0.6,
 814        font_size: int | float = 6,
 815        width: int = 8,
 816        height: int = 6,
 817        palette="light",
 818    ):
 819        """
 820        Visualize UMAP embedding with expression levels of a selected feature.
 821
 822        Each point (cell) in the UMAP plot is colored according to the expression
 823        value of the chosen feature, enabling interpretation of spatial patterns
 824        of gene activity or metadata distribution in low-dimensional space.
 825
 826        Parameters
 827        ----------
 828        feature_name : str
 829           Name of the feature to plot.
 830
 831        features_data : pandas.DataFrame or None, default None
 832            If None, the function uses the DataFrame containing the clustering data.
 833            To plot features not used in clustering, provide a wider DataFrame
 834            containing the original feature values.
 835
 836        point_size : float, default 0.6
 837            Size of scatter points in the plot.
 838
 839        font_size : int, default 6
 840            Font size for axis labels and annotations.
 841
 842        width : int, default 8
 843            Width of the matplotlib figure.
 844
 845        height : int, default 6
 846            Height of the matplotlib figure.
 847
 848        palette : str, default 'light'
 849            Color palette for expression visualization. Options are:
 850            - 'light'
 851            - 'dark'
 852            - 'green'
 853            - 'gray'
 854
 855        Returns
 856        -------
 857        matplotlib.figure.Figure
 858            UMAP scatter plot colored by feature values.
 859        """
 860
 861        umap_df = self.umap.iloc[:, 0:2].copy()
 862
 863        if features_data is None:
 864
 865            features_data = self.clustering_data
 866
 867        if features_data.shape[1] != umap_df.shape[0]:
 868            raise ValueError(
 869                "Input 'features_data' shape does not match the number of UMAP cells"
 870            )
 871
 872        blist = [
 873            True if x.upper() == feature_name.upper() else False
 874            for x in features_data.index
 875        ]
 876
 877        if not any(blist):
 878            raise ValueError("Input 'feature_name' is not included in the data")
 879
 880        umap_df.loc[:, "feature"] = (
 881            features_data.loc[blist, :]
 882            .apply(lambda row: row.tolist(), axis=1)
 883            .values[0]
 884        )
 885
 886        umap_df = umap_df.sort_values("feature", ascending=True)
 887
 888        import matplotlib.colors as mcolors
 889
 890        if palette == "light":
 891            palette = px.colors.sequential.Sunsetdark
 892
 893        elif palette == "dark":
 894            palette = px.colors.sequential.thermal
 895
 896        elif palette == "green":
 897            palette = px.colors.sequential.Aggrnyl
 898
 899        elif palette == "gray":
 900            palette = px.colors.sequential.gray
 901            palette = palette[::-1]
 902
 903        else:
 904            raise ValueError(
 905                'Palette not found. Use: "light", "dark", "gray", or "green"'
 906            )
 907
 908        converted = []
 909        for c in palette:
 910            rgb_255 = px.colors.unlabel_rgb(c)
 911            rgb_01 = tuple(v / 255.0 for v in rgb_255)
 912            converted.append(rgb_01)
 913
 914        my_cmap = mcolors.ListedColormap(converted, name="custom")
 915
 916        fig, ax = plt.subplots(figsize=(width, height))
 917        sc = ax.scatter(
 918            umap_df["UMAP1"],
 919            umap_df["UMAP2"],
 920            c=umap_df["feature"],
 921            s=point_size,
 922            cmap=my_cmap,
 923            alpha=1.0,
 924            edgecolors="black",
 925            linewidths=0.1,
 926        )
 927
 928        cbar = plt.colorbar(sc, ax=ax)
 929        cbar.set_label(f"{feature_name}")
 930
 931        ax.set_xlabel("UMAP 1")
 932        ax.set_ylabel("UMAP 2")
 933
 934        ax.grid(True)
 935
 936        plt.tight_layout()
 937
 938        return fig
 939
 940    def get_umap_data(self):
 941        """
 942        Retrieve UMAP embedding data with optional cluster labels.
 943
 944        Returns the UMAP coordinates stored in `self.umap`. If clustering
 945        metadata is available (specifically `UMAP_clusters`), the corresponding
 946        cluster assignments are appended as an additional column.
 947
 948        Returns
 949        -------
 950        pandas.DataFrame
 951            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
 952            If available, includes an extra column 'clusters' with cluster labels.
 953
 954        Notes
 955        -----
 956        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
 957        - Cluster labels are added only if present in `self.clustering_metadata`.
 958        """
 959
 960        umap_data = self.umap.copy()
 961
 962        try:
 963            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
 964        except Exception:
 965            pass
 966
 967        return umap_data
 968
 969    def get_pca_data(self):
 970        """
 971        Retrieve PCA embedding data with optional cluster labels.
 972
 973        Returns the principal component scores stored in `self.pca`. If clustering
 974        metadata is available (specifically `PCA_clusters`), the corresponding
 975        cluster assignments are appended as an additional column.
 976
 977        Returns
 978        -------
 979        pandas.DataFrame
 980            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
 981            If available, includes an extra column 'clusters' with cluster labels.
 982
 983        Notes
 984        -----
 985        - PCA must be computed beforehand (e.g., using `perform_PCA`).
 986        - Cluster labels are added only if present in `self.clustering_metadata`.
 987        """
 988
 989        pca_data = self.pca.copy()
 990
 991        try:
 992            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
 993        except Exception:
 994            pass
 995
 996        return pca_data
 997
 998    def return_clusters(self, clusters="umap"):
 999        """
1000        Retrieve cluster labels from UMAP or PCA clustering results.
1001
1002        Parameters
1003        ----------
1004        clusters : str, default 'umap'
1005            Source of cluster labels to return. Must be one of:
1006            - 'umap': return cluster labels from UMAP embeddings.
1007            - 'pca' : return cluster labels from PCA embeddings.
1008
1009        Returns
1010        -------
1011        list
1012            Cluster labels corresponding to the selected embedding method.
1013
1014        Raises
1015        ------
1016        ValueError
1017            If `clusters` is not 'umap' or 'pca'.
1018
1019        Notes
1020        -----
1021        Requires that clustering has already been performed
1022        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
1023        """
1024
1025        if clusters.lower() == "umap":
1026            clusters_vector = self.clustering_metadata["UMAP_clusters"]
1027        elif clusters.lower() == "pca":
1028            clusters_vector = self.clustering_metadata["PCA_clusters"]
1029        else:
1030            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")
1031
1032        return clusters_vector

A class for performing dimensionality reduction, clustering, and visualization on high-dimensional data (e.g., single-cell gene expression).

The class provides methods for:

  • Normalizing and extracting subsets of data
  • Principal Component Analysis (PCA) and related clustering
  • Uniform Manifold Approximation and Projection (UMAP) and clustering
  • Visualization of PCA and UMAP embeddings
  • Harmonization of batch effects
  • Accessing processed data and cluster labels

Methods

add_data_frame(data, metadata) Class method to create a Clustering instance from a DataFrame and metadata.

harmonize_sets() Perform batch effect harmonization on PCA data.

perform_PCA(pc_num=100, width=8, height=6) Perform PCA on the dataset and visualize the first two PCs.

knee_plot_PCA(width=8, height=6) Plot the cumulative variance explained by PCs to determine optimal dimensionality.

find_clusters_PCA(pc_num=2, eps=0.5, min_samples=10, width=8, height=6, harmonized=False) Apply DBSCAN clustering to PCA embeddings and visualize results.

perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...) Compute UMAP embeddings with optional parameter tuning.

knee_plot_umap(eps=0.5, min_samples=10) Determine optimal UMAP dimensionality using silhouette scores.

find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6) Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...) Visualize UMAP embeddings with labels and optional cluster numbering.

UMAP_feature(feature_name, features_data=None, point_size=0.6, ...) Plot a single feature over UMAP coordinates with customizable colormap.

get_umap_data() Return the UMAP embeddings along with cluster labels if available.

get_pca_data() Return the PCA results along with cluster labels if available.

return_clusters(clusters='umap') Return the cluster labels for UMAP or PCA embeddings.

Raises

ValueError For invalid parameters, mismatched dimensions, or missing metadata.
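
A typical end-to-end workflow, as a hedged sketch (`expr` and `meta` stand in for your own expression matrix and metadata; see the constructor example below for their expected shapes):

    cl = Clustering.add_data_frame(expr, meta)   # features as rows, cells as columns
    cl.perform_PCA(pc_num=50)
    cl.knee_plot_PCA()                           # inspect, then choose a PC count
    cl.perform_UMAP(umap_num=5, pc_num=20)
    cl.find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10)
    labels = cl.return_clusters(clusters="umap")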

Clustering(data, metadata)

Initialize the clustering class with data and optional metadata.

Parameters

data : pandas.DataFrame Input data for clustering. Columns are considered as samples.

metadata : pandas.DataFrame, optional Metadata for the samples. If None, a default DataFrame whose 'cell_names' column holds the column names of data is created.

Attributes

  • clustering_data : pandas.DataFrame
  • clustering_metadata : pandas.DataFrame
  • subclusters : None or dict
  • explained_var : None or numpy.ndarray
  • cumulative_var : None or numpy.ndarray
  • pca : None or pandas.DataFrame
  • harmonized_pca : None or pandas.DataFrame
  • umap : None or pandas.DataFrame

clustering_data

The input data used for clustering.

clustering_metadata

Metadata associated with the samples.

subclusters

Placeholder for storing subcluster information.

explained_var

Explained variance from PCA, initialized as None.

cumulative_var

Cumulative explained variance from PCA, initialized as None.

pca

PCA-transformed data, initialized as None.

harmonized_pca

PCA data after batch effect harmonization, initialized as None.

umap

UMAP embeddings, initialized as None.

@classmethod
def add_data_frame( cls, data: pandas.core.frame.DataFrame, metadata: pandas.core.frame.DataFrame | None):

Create a Clustering instance from a DataFrame and optional metadata.

Parameters

data : pandas.DataFrame Input data with features as rows and samples/cells as columns.

metadata : pandas.DataFrame or None Optional metadata for the samples. Each row corresponds to a sample/cell, in the same order as the sample/cell columns of data. Columns can carry additional information such as cell type, experimental condition, batch, sets, etc.

Returns

Clustering A new instance of the Clustering class.
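
A minimal construction sketch (all values are illustrative; features are rows and samples/cells are columns):

    import pandas as pd

    data = pd.DataFrame(
        [[0.1, 2.3, 0.0, 1.2],
         [5.0, 0.2, 3.1, 0.4],
         [1.1, 1.0, 0.9, 2.2]],
        index=["GeneA", "GeneB", "GeneC"],   # features
        columns=["c1", "c2", "c3", "c4"],    # cells
    )
    metadata = pd.DataFrame({
        "cell_names": ["c1", "c2", "c3", "c4"],
        "sets": ["batch1", "batch1", "batch2", "batch2"],  # optional batch labels
    })

    cl = Clustering.add_data_frame(data, metadata)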

def harmonize_sets(self, batch_col: str = 'sets'):

Perform batch effect harmonization on PCA embeddings.

Parameters

batch_col : str, default 'sets' Name of the column in metadata that contains batch information for the samples/cells.

Returns

None Updates the harmonized_pca attribute with harmonized data.
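
A sketch of the intended call order (PCA must already exist, since harmonization runs on self.pca; the 'sets' metadata column is assumed to hold batch labels):

    cl.perform_PCA(pc_num=50)
    cl.harmonize_sets(batch_col="sets")
    # cl.harmonized_pca now holds the batch-corrected PC scores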

def perform_PCA(self, pc_num: int = 100, width=8, height=6):

Perform Principal Component Analysis (PCA) on the dataset.

This method standardizes the data, applies PCA, stores results as attributes, and generates a scatter plot of the first two principal components.

Parameters

pc_num : int, default 100 Number of principal components to compute. If 0, computes all available components.

width : int or float, default 8 Width of the PCA figure.

height : int or float, default 6 Height of the PCA figure.

Returns

matplotlib.figure.Figure Scatter plot showing the first two principal components.

Updates

self.pca : pandas.DataFrame DataFrame with principal component scores for each sample.

self.explained_var : numpy.ndarray Percentage of variance explained by each principal component.

self.cumulative_var : numpy.ndarray Cumulative explained variance.
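
For example (a minimal sketch; the returned matplotlib Figure can be saved or styled further):

    fig = cl.perform_PCA(pc_num=50)
    fig.savefig("pca_scatter.png", dpi=150)
    print(cl.explained_var[:5])   # percent of variance for the first five PCs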

def knee_plot_PCA(self, width: int = 8, height: int = 6):

Plot cumulative explained variance to determine the optimal number of PCs.

Parameters

width : int, default 8 Width of the figure.

height : int, default 6 Height of the figure.

Returns

matplotlib.figure.Figure Line plot showing cumulative variance explained by each PC.
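
The curve is read visually, e.g.:

    fig = cl.knee_plot_PCA()
    # choose the point where the cumulative-variance curve levels off,
    # then pass that value as pc_num to find_clusters_PCA or perform_UMAP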

def find_clusters_PCA( self, pc_num: int = 2, eps: float = 0.5, min_samples: int = 10, width: int = 8, height: int = 6, harmonized: bool = False):

Apply DBSCAN clustering to PCA embeddings and visualize the results.

This method performs density-based clustering (DBSCAN) on the PCA-reduced dataset. Cluster labels are stored in the object's metadata, and a scatter plot of the first two principal components with cluster annotations is returned.

Parameters

pc_num : int, default 2 Number of principal components to use for clustering. If 0, uses all available components.

eps : float, default 0.5 Maximum distance between two points for them to be considered as neighbors (DBSCAN parameter).

min_samples : int, default 10 Minimum number of samples required to form a cluster (DBSCAN parameter).

width : int, default 8 Width of the output scatter plot.

height : int, default 6 Height of the output scatter plot.

harmonized : bool, default False If True, use harmonized PCA data (self.harmonized_pca). If False, use standard PCA results (self.pca).

Returns

matplotlib.figure.Figure Scatter plot of the first two principal components colored by cluster assignments.

Updates

self.clustering_metadata['PCA_clusters'] : list Cluster labels assigned to each cell/sample.

self.input_metadata['PCA_clusters'] : list, optional Cluster labels stored in input metadata (if available).
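
A usage sketch (DBSCAN marks noise points as -1, stored here as the string '-1'):

    fig = cl.find_clusters_PCA(pc_num=20, eps=0.5, min_samples=10)
    print(cl.clustering_metadata["PCA_clusters"][:10])

    # the same call on batch-corrected scores, assuming harmonize_sets() was run:
    fig = cl.find_clusters_PCA(pc_num=20, harmonized=True)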

def perform_UMAP( self, factorize: bool = False, umap_num: int = 100, pc_num: int = 0, harmonized: bool = False, n_neighbors: int = 5, min_dist: float | int = 0.1, spread: float | int = 1.0, set_op_mix_ratio: float | int = 1.0, local_connectivity: int = 1, repulsion_strength: float | int = 1.0, negative_sample_rate: int = 5, width: int = 8, height: int = 6):

Compute and visualize UMAP embeddings of the dataset.

This method applies Uniform Manifold Approximation and Projection (UMAP) for dimensionality reduction on either raw, PCA, or harmonized PCA data. Results are stored as a DataFrame (self.umap), and a scatter plot of the first two UMAP dimensions is displayed.

Parameters

factorize : bool, default False If True, categorical sample labels (from column names) are factorized and used as supervision in UMAP fitting.

umap_num : int, default 100 Number of UMAP dimensions to compute. If 0, matches the input dimension.

pc_num : int, default 0 Number of principal components to use as UMAP input. If 0, use all available components or raw data.

harmonized : bool, default False If True, use harmonized PCA embeddings (self.harmonized_pca). If False, use standard PCA or raw scaled data.

n_neighbors : int, default 5 UMAP parameter controlling the size of the local neighborhood.

min_dist : float, default 0.1 UMAP parameter controlling minimum allowed distance between embedded points.

spread : int | float, default 1.0 Effective scale of embedded space (UMAP parameter).

set_op_mix_ratio : int | float, default 1.0 Interpolation parameter between union and intersection in fuzzy sets.

local_connectivity : int, default 1 Number of nearest neighbors assumed for each point.

repulsion_strength : int | float, default 1.0 Weighting applied to negative samples during optimization.

negative_sample_rate : int, default 5 Number of negative samples per positive sample in optimization.

width : int, default 8 Width of the output scatter plot.

height : int, default 6 Height of the output scatter plot.

Updates

self.umap : pandas.DataFrame Table of UMAP embeddings with columns UMAP1 ... UMAPn.

Notes

For supervised UMAP (factorize=True), categorical codes from column names of the dataset are used as labels.
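
A hedged sketch of the common PCA-then-UMAP pattern:

    cl.perform_PCA(pc_num=50)
    cl.perform_UMAP(umap_num=5, pc_num=20, n_neighbors=15, min_dist=0.1)
    print(cl.umap.columns.tolist())   # ['UMAP1', ..., 'UMAP5']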

def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):

Plot silhouette scores for different UMAP dimensions to determine optimal n_components.

Parameters

eps : float, default 0.5 DBSCAN eps parameter for clustering each UMAP dimension.

min_samples : int, default 10 Minimum number of samples to form a cluster in DBSCAN.

Returns

matplotlib.figure.Figure Silhouette score plot across UMAP dimensions.
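
A sketch (use the same eps and min_samples you plan to pass to find_clusters_UMAP, so the silhouette scores reflect the clustering you will actually run):

    fig = cl.knee_plot_umap(eps=0.5, min_samples=10)
    # pick the n_components with the highest silhouette score as umap_n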

def find_clusters_UMAP( self, umap_n: int = 5, eps: int | float = 0.5, min_samples: int = 10, width: int = 8, height: int = 6):

Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

This method performs density-based clustering (DBSCAN) on the UMAP-reduced dataset. Cluster labels are stored in the object's metadata, and a scatter plot of the first two UMAP components with cluster annotations is returned.

Parameters

umap_n : int, default 5 Number of UMAP dimensions to use for DBSCAN clustering. Must be <= number of columns in self.umap.

eps : float | int, default 0.5 Maximum neighborhood distance between two samples for them to be considered as in the same cluster (DBSCAN parameter).

min_samples : int, default 10 Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).

width : int, default 8 Figure width.

height : int, default 6 Figure height.

Returns

matplotlib.figure.Figure Scatter plot of the first two UMAP components colored by cluster assignments.

Updates

self.clustering_metadata['UMAP_clusters'] : list Cluster labels assigned to each cell/sample.

self.input_metadata['UMAP_clusters'] : list, optional Cluster labels stored in input metadata (if available).
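
Continuing the sketch above:

    fig = cl.find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10)
    labels = cl.return_clusters(clusters="umap")   # string labels; '-1' marks noise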

def UMAP_vis( self, names_slot: str = 'cell_names', set_sep: bool = True, point_size: int | float = 0.6, font_size: int | float = 6, legend_split_col: int = 2, width: int = 8, height: int = 6, inc_num: bool = True):
665    def UMAP_vis(
666        self,
667        names_slot: str = "cell_names",
668        set_sep: bool = True,
669        point_size: int | float = 0.6,
670        font_size: int | float = 6,
671        legend_split_col: int = 2,
672        width: int = 8,
673        height: int = 6,
674        inc_num: bool = True,
675    ):
676        """
677        Visualize UMAP embeddings with sample labels based on a specific metadata slot.
678
679        Parameters
680        ----------
681        names_slot : str, default 'cell_names'
682            Column in metadata to use as sample labels.
683
684        set_sep : bool, default True
685            If True, separate points by dataset.
686
687        point_size : float, default 0.6
688            Size of scatter points.
689
690        font_size : int, default 6
691            Font size for numbers on points.
692
693        legend_split_col : int, default 2
694            Number of columns in legend.
695
696        width : int, default 8
697            Figure width.
698
699        height : int, default 6
700            Figure height.
701
702        inc_num : bool, default True
703            If True, annotate points with numeric labels.
704
705        Returns
706        -------
707        matplotlib.figure.Figure
708            UMAP scatter plot figure.
709        """
710
711        umap_df = self.umap.iloc[:, 0:2].copy()
712        umap_df.loc[:, "names"] = self.clustering_metadata[names_slot]
713
714        if set_sep:
715
716            if "sets" in self.clustering_metadata.columns:
717                umap_df.loc[:, "dataset"] = self.clustering_metadata["sets"]
718            else:
719                umap_df["dataset"] = "default"
720
721        else:
722            umap_df["dataset"] = "default"
723
724        umap_df.loc[:, "tmp_nam"] = umap_df["names"] + umap_df["dataset"]
725
726        umap_df.loc[:, "count"] = umap_df["tmp_nam"].map(
727            umap_df["tmp_nam"].value_counts()
728        )
729
730        numeric_df = (
731            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
732            .drop_duplicates()
733            .sort_values("count", ascending=False)
734        )
735        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])
736
737        umap_df = umap_df.merge(
738            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
739        )
740
741        fig, ax = plt.subplots(figsize=(width, height))
742
743        markers = ["o", "s", "^", "D", "P", "*", "X"]
744        marker_map = {
745            ds: markers[i % len(markers)]
746            for i, ds in enumerate(umap_df["dataset"].unique())
747        }
748
749        cord_list = []
750
751        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):
752
753            cluster_data = umap_df[umap_df["numeric_values"] == num]
754
755            ax.scatter(
756                cluster_data["UMAP1"],
757                cluster_data["UMAP2"],
758                label=f"{num} - {nam}",
759                marker=marker_map[cluster_data["dataset"].iloc[0]],
760                alpha=0.6,
761                s=point_size,
762            )
763
764            coords = cluster_data[["UMAP1", "UMAP2"]].values
765
766            dists = pairwise_distances(coords)
767
768            sum_dists = dists.sum(axis=1)
769
770            center_idx = np.argmin(sum_dists)
771            center_point = coords[center_idx]
772
773            cord_list.append(center_point)
774
775        if inc_num:
776            texts = []
777            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
778                texts.append(
779                    ax.text(
780                        x,
781                        y,
782                        str(num),
783                        ha="center",
784                        va="center",
785                        fontsize=font_size,
786                        color="black",
787                    )
788                )
789
790            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))
791
792        ax.set_xlabel("UMAP 1")
793        ax.set_ylabel("UMAP 2")
794
795        ax.legend(
796            title="Clusters",
797            loc="center left",
798            bbox_to_anchor=(1.05, 0.5),
799            ncol=legend_split_col,
800            markerscale=5,
801        )
802
803        ax.grid(True)
804
805        plt.tight_layout()
806
807        return fig
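
A usage sketch (hypothetical instance `cl` with a computed UMAP embedding and a
'cell_names' column in its metadata):

    # one marker shape per dataset; numeric annotations are placed at each
    # group's most central point (smallest summed pairwise distance)
    fig = cl.UMAP_vis(names_slot="cell_names", set_sep=True,
                      point_size=0.6, legend_split_col=2, inc_num=True)
    fig.savefig("umap_labeled.png", dpi=300, bbox_inches="tight")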

809    def UMAP_feature(
810        self,
811        feature_name: str,
812        features_data: pd.DataFrame | None,
813        point_size: int | float = 0.6,
814        font_size: int | float = 6,
815        width: int = 8,
816        height: int = 6,
817        palette="light",
818    ):
819        """
820        Visualize UMAP embedding with expression levels of a selected feature.
821
822        Each point (cell) in the UMAP plot is colored according to the expression
823        value of the chosen feature, enabling interpretation of spatial patterns
824        of gene activity or metadata distribution in low-dimensional space.
825
826        Parameters
827        ----------
828        feature_name : str
829           Name of the feature to plot.
830
831        features_data : pandas.DataFrame or None, default None
832            If None, the function uses the DataFrame containing the clustering data.
833            To plot features not used in clustering, provide a wider DataFrame
834            containing the original feature values.
835
836        point_size : float, default 0.6
837            Size of scatter points in the plot.
838
839        font_size : int, default 6
840            Font size for axis labels and annotations.
841
842        width : int, default 8
843            Width of the matplotlib figure.
844
845        height : int, default 6
846            Height of the matplotlib figure.
847
848        palette : str, default 'light'
849            Color palette for expression visualization. Options are:
850            - 'light'
851            - 'dark'
852            - 'green'
853            - 'gray'
854
855        Returns
856        -------
857        matplotlib.figure.Figure
858            UMAP scatter plot colored by feature values.
859        """
860
861        umap_df = self.umap.iloc[:, 0:2].copy()
862
863        if features_data is None:
864
865            features_data = self.clustering_data
866
867        if features_data.shape[1] != umap_df.shape[0]:
868            raise ValueError(
869                "Provided 'features_data' shape does not match the number of UMAP cells"
870            )
871
872        blist = [
873            True if x.upper() == feature_name.upper() else False
874            for x in features_data.index
875        ]
876
877        if not any(blist):
878            raise ValueError("Provided 'feature_name' is not included in the data")
879
880        umap_df.loc[:, "feature"] = (
881            features_data.loc[blist, :]
882            .apply(lambda row: row.tolist(), axis=1)
883            .values[0]
884        )
885
886        umap_df = umap_df.sort_values("feature", ascending=True)
887
888        import matplotlib.colors as mcolors
889
890        if palette == "light":
891            palette = px.colors.sequential.Sunsetdark
892
893        elif palette == "dark":
894            palette = px.colors.sequential.thermal
895
896        elif palette == "green":
897            palette = px.colors.sequential.Aggrnyl
898
899        elif palette == "gray":
900            palette = px.colors.sequential.gray
901            palette = palette[::-1]
902
903        else:
904            raise ValueError(
905                'Palette not found. Use: "light", "dark", "gray", or "green"'
906            )
907
908        converted = []
909        for c in palette:
910            rgb_255 = px.colors.unlabel_rgb(c)
911            rgb_01 = tuple(v / 255.0 for v in rgb_255)
912            converted.append(rgb_01)
913
914        my_cmap = mcolors.ListedColormap(converted, name="custom")
915
916        fig, ax = plt.subplots(figsize=(width, height))
917        sc = ax.scatter(
918            umap_df["UMAP1"],
919            umap_df["UMAP2"],
920            c=umap_df["feature"],
921            s=point_size,
922            cmap=my_cmap,
923            alpha=1.0,
924            edgecolors="black",
925            linewidths=0.1,
926        )
927
928        cbar = plt.colorbar(sc, ax=ax)
929        cbar.set_label(f"{feature_name}")
930
931        ax.set_xlabel("UMAP 1")
932        ax.set_ylabel("UMAP 2")
933
934        ax.grid(True)
935
936        plt.tight_layout()
937
938        return fig
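
A usage sketch; the gene name 'ACTB' and the DataFrame `full_expression_df`
(features as rows, one column per UMAP cell) are hypothetical:

    # color cells by a feature that was part of the clustering input
    fig = cl.UMAP_feature("ACTB", features_data=None, palette="light")

    # for a feature outside the clustering input, pass a wider matrix whose
    # columns match the UMAP cells
    fig = cl.UMAP_feature("ACTB", features_data=full_expression_df, palette="green")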

940    def get_umap_data(self):
941        """
942        Retrieve UMAP embedding data with optional cluster labels.
943
944        Returns the UMAP coordinates stored in `self.umap`. If clustering
945        metadata is available (specifically `UMAP_clusters`), the corresponding
946        cluster assignments are appended as an additional column.
947
948        Returns
949        -------
950        pandas.DataFrame
951            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
952            If available, includes an extra column 'clusters' with cluster labels.
953
954        Notes
955        -----
956        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
957        - Cluster labels are added only if present in `self.clustering_metadata`.
958        """
959
960        umap_data = self.umap.copy()
961
962        try:
963            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
964        except Exception:  # cluster labels not computed yet
965            pass
966
967        return umap_data

969    def get_pca_data(self):
970        """
971        Retrieve PCA embedding data with optional cluster labels.
972
973        Returns the principal component scores stored in `self.pca`. If clustering
974        metadata is available (specifically `PCA_clusters`), the corresponding
975        cluster assignments are appended as an additional column.
976
977        Returns
978        -------
979        pandas.DataFrame
980            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
981            If available, includes an extra column 'clusters' with cluster labels.
982
983        Notes
984        -----
985        - PCA must be computed beforehand (e.g., using `perform_PCA`).
986        - Cluster labels are added only if present in `self.clustering_metadata`.
987        """
988
989        pca_data = self.pca.copy()
990
991        try:
992            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
993        except Exception:  # cluster labels not computed yet
994            pass
995
996        return pca_data

 998    def return_clusters(self, clusters="umap"):
 999        """
1000        Retrieve cluster labels from UMAP or PCA clustering results.
1001
1002        Parameters
1003        ----------
1004        clusters : str, default 'umap'
1005            Source of cluster labels to return. Must be one of:
1006            - 'umap': return cluster labels from UMAP embeddings.
1007            - 'pca' : return cluster labels from PCA embeddings.
1008
1009        Returns
1010        -------
1011        list
1012            Cluster labels corresponding to the selected embedding method.
1013
1014        Raises
1015        ------
1016        ValueError
1017            If `clusters` is not 'umap' or 'pca'.
1018
1019        Notes
1020        -----
1021        Requires that clustering has already been performed
1022        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
1023        """
1024
1025        if clusters.lower() == "umap":
1026            clusters_vector = self.clustering_metadata["UMAP_clusters"]
1027        elif clusters.lower() == "pca":
1028            clusters_vector = self.clustering_metadata["PCA_clusters"]
1029        else:
1030            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")
1031
1032        return clusters_vector
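
A short sketch of the accessors (hypothetical instance `cl`):

    umap_df = cl.get_umap_data()         # UMAP1, UMAP2, ... (+ 'clusters' if computed)
    pca_df = cl.get_pca_data()           # PC1, PC2, ... (+ 'clusters' if computed)
    labels = cl.return_clusters("umap")  # ValueError for anything other than 'umap'/'pca'
    umap_df.to_csv("umap_embedding.csv")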

1035class COMPsc(Clustering):
1036    """
1037    `COMPsc` (Comparison of single-cell data) is a class for the integration,
1038    analysis, and visualization of single-cell datasets.
1039    It supports independent dataset integration, subclustering of existing
1040    clusters, marker detection, and multiple visualization strategies.
1041
1042    The COMPsc class provides methods for:
1043
1044        - Normalizing and filtering single-cell data
1045        - Loading and saving sparse 10x-style datasets
1046        - Computing differential expression and marker genes
1047        - Clustering and subclustering analysis
1048        - Visualizing similarity and spatial relationships
1049        - Aggregating data by cell and set annotations
1050        - Managing metadata and renaming labels
1051        - Plotting gene detection histograms and feature scatters
1052
1053    Methods
1054    -------
1055    project_dir(path_to_directory, project_list)
1056        Scans a directory to create a COMPsc instance mapping project names to their paths.
1057
1058    save_project(name, path=os.getcwd())
1059        Saves the COMPsc object to a pickle file on disk.
1060
1061    load_project(path)
1062        Loads a previously saved COMPsc object from a pickle file.
1063
1064    reduce_cols(reg=None, full=None, name_slot='cell_names', inc_set=False)
1065        Removes columns (cells) from data tables whose names contain a substring (`reg`) or match a full name (`full`).
1066
1067    reduce_rows(features_list)
1068        Removes rows (features) from data tables whose names match the provided feature (gene) names.
1069
1070    get_data(set_info=False)
1071        Returns normalized data with optional set annotations in column names.
1072
1073    get_metadata()
1074        Returns the stored input metadata.
1075
1076    get_partial_data(names=None, features=None, name_slot='cell_names')
1077        Return a subset of the data by sample names and/or features.
1078
1079    gene_calculation()
1080        Calculates and stores per-cell gene detection counts as a pandas Series.
1081
1082    gene_histograme(bins=100)
1083        Plots a histogram of genes detected per cell with an overlaid normal distribution.
1084
1085    gene_threshold(min_n=None, max_n=None)
1086        Filters cells based on minimum and/or maximum gene detection thresholds.
1087
1088    load_sparse_from_projects(normalized_data=False)
1089        Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.
1090
1091    rename_names(mapping, slot='cell_names')
1092        Renames entries in a specified metadata column using a provided mapping dictionary.
1093
1094    rename_subclusters(mapping)
1095        Renames subcluster labels using a provided mapping dictionary.
1096
1097    save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized')
1098        Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
1099
1100    normalize_data(normalize=True, normalize_factor=100000)
1101        Normalizes raw counts to counts-per-specified factor (e.g., CPM-like).
1102
1103    statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10)
1104        Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.
1105
1106    calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False)
1107        Computes and caches differential markers using the statistic method.
1108
1109    clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4)
1110        Prepares clustering input by selecting marker features and optionally smoothing cell values.
1111
1112    average()
1113        Aggregates normalized data by averaging across (cell_name, set) pairs.
1114
1115    estimating_similarity(method='pearson', p_val=0.05, top_n=25)
1116        Computes pairwise correlation and Euclidean distance between aggregated samples.
1117
1118    similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10)
1119        Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.
1120
1121    spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...)
1122        Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.
1123
1124    subcluster_prepare(features, cluster)
1125        Initializes a Clustering object for subcluster analysis on a selected parent cluster.
1126
1127    define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...)
1128        Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels.
1129
1130    subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...)
1131        Visualizes averaged expression and occurrence of features for subclusters as a scatter plot.
1132
1133    subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...)
1134        Plots top differential features for subclusters as a features-scatter visualization.
1135
1136    accept_subclusters()
1137        Commits subcluster labels to main metadata by renaming cell names and clears subcluster data.
1138
1139    Raises
1140    ------
1141    ValueError
1142        For invalid parameters, mismatched dimensions, or missing metadata.
1143
1144    """
1145
1146    def __init__(
1147        self,
1148        objects=None,
1149    ):
1150        """
1151        Initialize the COMPsc class for single-cell data integration and analysis.
1152
1153        Parameters
1154        ----------
1155        objects : list or None, optional
1156            Optional list of data objects to initialize the instance with.
1157
1158        Attributes
1159        ----------
1160        objects : list or None
1161        input_data : pandas.DataFrame or None
1162        input_metadata : pandas.DataFrame or None
1163        normalized_data : pandas.DataFrame or None
1164        agg_metadata : pandas.DataFrame or None
1165        agg_normalized_data : pandas.DataFrame or None
1166        similarity : pandas.DataFrame or None
1167        var_data : pandas.DataFrame or None
1168        subclusters_ : Clustering or None
1169        cells_calc : pandas.DataFrame or None
1170        gene_calc : pandas.Series or None
1171        composition_data : pandas.DataFrame or None
1172        """
1173
1174        self.objects = objects
1175        """ Stores the input data objects."""
1176
1177        self.input_data = None
1178        """Raw input data for clustering or integration analysis."""
1179
1180        self.input_metadata = None
1181        """Metadata associated with the input data."""
1182
1183        self.normalized_data = None
1184        """Normalized version of the input data."""
1185
1186        self.agg_metadata = None
1187        """Aggregated metadata for all sets in the object, related to "agg_normalized_data"."""
1188
1189        self.agg_normalized_data = None
1190        """Aggregated and normalized data across multiple sets."""
1191
1192        self.similarity = None
1193        """Similarity data between cells across all samples and sets."""
1194
1195        self.var_data = None
1196        """DEG analysis results summarizing variance across all samples in the object."""
1197
1198        self.subclusters_ = None
1199        """Placeholder for subcluster analysis results, if computed."""
1200
1201        self.cells_calc = None
1202        """Number of cells per cell name / cluster, reflecting data composition."""
1203
1204        self.gene_calc = None
1205        """Number of genes detected per sample (cell), reflecting the sequencing depth."""
1206
1207        self.composition_data = None
1208        """Data describing composition of cells across clusters or sets."""
1209
1210    @classmethod
1211    def project_dir(cls, path_to_directory, project_list):
1212        """
1213        Scan a directory and build a COMPsc instance mapping provided project names
1214        to their paths.
1215
1216        Parameters
1217        ----------
1218        path_to_directory : str
1219            Path containing project subfolders.
1220
1221        project_list : list[str]
1222            List of filenames (folder names) to include in the returned object map.
1223
1224        Returns
1225        -------
1226        COMPsc
1227            New COMPsc instance with `objects` populated.
1228
1229        Raises
1230        ------
1231        Exception
1232            A generic exception is caught and a message printed if scanning fails.
1233
1234        Notes
1235        -----
1236        Function attempts to match entries in `project_list` to directory
1237        names and constructs a simplified object key from the folder name.
1238        """
1239        try:
1240            objects = {}
1241            for filename in tqdm(os.listdir(path_to_directory)):
1242                for c in project_list:
1243                    f = os.path.join(path_to_directory, filename)
1244                    if c == filename and os.path.isdir(f):
1245                        objects[
1246                            str(
1247                                re.sub(
1248                                    "_count",
1249                                    "",
1250                                    re.sub(
1251                                        "_fq",
1252                                        "",
1253                                        re.sub(
1254                                            re.sub("_.*", "", filename) + "_",
1255                                            "",
1256                                            filename,
1257                                        ),
1258                                    ),
1259                                )
1260                            )
1261                        ] = f
1262
1263            return cls(objects)
1264
1265        except Exception as e:
1266            print(f"Something went wrong ({e}). Check the function input data and try again!")
1267
1268    def save_project(self, name, path: str = os.getcwd()):
1269        """
1270        Save the COMPsc object to disk using pickle.
1271
1272        Parameters
1273        ----------
1274        name : str
1275            Base filename (without extension) to use when saving.
1276
1277        path : str, default os.getcwd()
1278            Directory in which to save the project file.
1279
1280        Returns
1281        -------
1282        None
1283
1284        Side Effects
1285        ------------
1286        - Writes a file `<path>/<name>.jpkl` containing the pickled object.
1287        - Prints a confirmation message with saved path.
1288        """
1289
1290        full = os.path.join(path, f"{name}.jpkl")
1291
1292        with open(full, "wb") as f:
1293            pickle.dump(self, f)
1294
1295        print(f"Project saved as {full}")
1296
1297    @classmethod
1298    def load_project(cls, path):
1299        """
1300        Load a previously saved COMPsc project from a pickle file.
1301
1302        Parameters
1303        ----------
1304        path : str
1305            Full path to the pickled project file.
1306
1307        Returns
1308        -------
1309        COMPsc
1310            The unpickled COMPsc object.
1311
1312        Raises
1313        ------
1314        FileNotFoundError
1315            If the provided path does not exist.
1316        """
1317
1318        if not os.path.exists(path):
1319            raise FileNotFoundError("File does not exist!")
1320        with open(path, "rb") as f:
1321            obj = pickle.load(f)
1322        return obj
1323
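
A round-trip sketch (hypothetical COMPsc instance `proj`; the file name is
illustrative):

    proj.save_project("my_analysis")                  # writes ./my_analysis.jpkl
    proj = COMPsc.load_project("my_analysis.jpkl")    # restore in a later session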
1324    def reduce_cols(
1325        self,
1326        reg: str | None = None,
1327        full: str | None = None,
1328        name_slot: str = "cell_names",
1329        inc_set: bool = False,
1330    ):
1331        """
1332        Remove columns (cells) whose names contain a substring `reg` or
1333        full name `full` from available tables.
1334
1335        Parameters
1336        ----------
1337        reg : str | None
1338            Substring to search for in column/cell names; matching columns will be removed.
1339            If not None, `full` must be None.
1340
1341        full : str | None
1342            Full name to search for in column/cell names; matching columns will be removed.
1343            If not None, `reg` must be None.
1344
1345        name_slot : str, default 'cell_names'
1346            Column in metadata to use as sample names.
1347
1348        inc_set : bool, default False
1349            If True, column names are interpreted as 'cell_name # set' when matching.
1350
1351        Update
1352        ------
1353        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
1354        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
1355        removing columns whose names match `reg` or `full`.
1356
1357        Raises
1358        ------
1359        ValueError if nothing matches the reduction mask.
1360        """
1361
1362        if reg is None and full is None:
1363            raise ValueError(
1364                "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!"
1365            )
1366
1367        if reg is not None and full is not None:
1368            raise ValueError(
1369                "Both 'reg' and 'full' arguments are provided. "
1370                "Please provide only one of them!\n"
1371                "'reg' is used when only part of the name must be detected.\n"
1372                "'full' is used if the full name must be detected."
1373            )
1374
1375        if reg is not None:
1376
1377            if self.input_data is not None:
1378
1379                if inc_set:
1380
1381                    self.input_data.columns = (
1382                        self.input_metadata[name_slot]
1383                        + " # "
1384                        + self.input_metadata["sets"]
1385                    )
1386
1387                else:
1388
1389                    self.input_data.columns = self.input_metadata[name_slot]
1390
1391                mask = [reg.upper() not in x.upper() for x in self.input_data.columns]
1392
1393                if len([y for y in mask if y is False]) == 0:
1394                    raise ValueError("Nothing found to reduce")
1395
1396                self.input_data = self.input_data.loc[:, mask]
1397
1398            if self.normalized_data is not None:
1399
1400                if inc_set:
1401
1402                    self.normalized_data.columns = (
1403                        self.input_metadata[name_slot]
1404                        + " # "
1405                        + self.input_metadata["sets"]
1406                    )
1407
1408                else:
1409
1410                    self.normalized_data.columns = self.input_metadata[name_slot]
1411
1412                mask = [
1413                    reg.upper() not in x.upper() for x in self.normalized_data.columns
1414                ]
1415
1416                if len([y for y in mask if y is False]) == 0:
1417                    raise ValueError("Nothing found to reduce")
1418
1419                self.normalized_data = self.normalized_data.loc[:, mask]
1420
1421            if self.input_metadata is not None:
1422
1423                if inc_set:
1424
1425                    self.input_metadata["drop"] = (
1426                        self.input_metadata[name_slot]
1427                        + " # "
1428                        + self.input_metadata["sets"]
1429                    )
1430
1431                else:
1432
1433                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1434
1435                mask = [
1436                    reg.upper() not in x.upper() for x in self.input_metadata["drop"]
1437                ]
1438
1439                if len([y for y in mask if y is False]) == 0:
1440                    raise ValueError("Nothing found to reduce")
1441
1442                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1443                    drop=True
1444                )
1445
1446                self.input_metadata = self.input_metadata.drop(
1447                    columns=["drop"], errors="ignore"
1448                )
1449
1450            if self.agg_normalized_data is not None:
1451
1452                if inc_set:
1453
1454                    self.agg_normalized_data.columns = (
1455                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1456                    )
1457
1458                else:
1459
1460                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1461
1462                mask = [
1463                    reg.upper() not in x.upper()
1464                    for x in self.agg_normalized_data.columns
1465                ]
1466
1467                if len([y for y in mask if y is False]) == 0:
1468                    raise ValueError("Nothing found to reduce")
1469
1470                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1471
1472            if self.agg_metadata is not None:
1473
1474                if inc_set:
1475
1476                    self.agg_metadata["drop"] = (
1477                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1478                    )
1479
1480                else:
1481
1482                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1483
1484                mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]]
1485
1486                if len([y for y in mask if y is False]) == 0:
1487                    raise ValueError("Nothing found to reduce")
1488
1489                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1490                    drop=True
1491                )
1492
1493                self.agg_metadata = self.agg_metadata.drop(
1494                    columns=["drop"], errors="ignore"
1495                )
1496
1497        elif full is not None:
1498
1499            if self.input_data is not None:
1500
1501                if inc_set:
1502
1503                    self.input_data.columns = (
1504                        self.input_metadata[name_slot]
1505                        + " # "
1506                        + self.input_metadata["sets"]
1507                    )
1508
1509                    if "#" not in full:
1510
1511                        self.input_data.columns = self.input_metadata[name_slot]
1512
1513                        print(
1514                            "The 'full' argument does not include the set info (name # set) while 'inc_set' is True. "
1515                            "Only the names will be compared, without considering the set information."
1516                        )
1517
1518                else:
1519
1520                    self.input_data.columns = self.input_metadata[name_slot]
1521
1522                mask = [full.upper() != x.upper() for x in self.input_data.columns]
1523
1524                if len([y for y in mask if y is False]) == 0:
1525                    raise ValueError("Nothing found to reduce")
1526
1527                self.input_data = self.input_data.loc[:, mask]
1528
1529            if self.normalized_data is not None:
1530
1531                if inc_set:
1532
1533                    self.normalized_data.columns = (
1534                        self.input_metadata[name_slot]
1535                        + " # "
1536                        + self.input_metadata["sets"]
1537                    )
1538
1539                    if "#" not in full:
1540
1541                        self.normalized_data.columns = self.input_metadata[name_slot]
1542
1543                        print(
1544                            "The 'full' argument does not include the set info (name # set) while 'inc_set' is True. "
1545                            "Only the names will be compared, without considering the set information."
1546                        )
1547
1548                else:
1549
1550                    self.normalized_data.columns = self.input_metadata[name_slot]
1551
1552                mask = [full.upper() != x.upper() for x in self.normalized_data.columns]
1553
1554                if len([y for y in mask if y is False]) == 0:
1555                    raise ValueError("Nothing found to reduce")
1556
1557                self.normalized_data = self.normalized_data.loc[:, mask]
1558
1559            if self.input_metadata is not None:
1560
1561                if inc_set:
1562
1563                    self.input_metadata["drop"] = (
1564                        self.input_metadata[name_slot]
1565                        + " # "
1566                        + self.input_metadata["sets"]
1567                    )
1568
1569                    if "#" not in full:
1570
1571                        self.input_metadata["drop"] = self.input_metadata[name_slot]
1572
1573                        print(
1574                            "The 'full' argument does not include the set info (name # set) while 'inc_set' is True. "
1575                            "Only the names will be compared, without considering the set information."
1576                        )
1577
1578                else:
1579
1580                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1581
1582                mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]]
1583
1584                if len([y for y in mask if y is False]) == 0:
1585                    raise ValueError("Nothing found to reduce")
1586
1587                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1588                    drop=True
1589                )
1590
1591                self.input_metadata = self.input_metadata.drop(
1592                    columns=["drop"], errors="ignore"
1593                )
1594
1595            if self.agg_normalized_data is not None:
1596
1597                if inc_set:
1598
1599                    self.agg_normalized_data.columns = (
1600                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1601                    )
1602
1603                    if "#" not in full:
1604
1605                        self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1606
1607                        print(
1608                            "The 'full' argument does not include the set info (name # set) while 'inc_set' is True. "
1609                            "Only the names will be compared, without considering the set information."
1610                        )
1611                else:
1612
1613                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1614
1615                mask = [
1616                    full.upper() != x.upper() for x in self.agg_normalized_data.columns
1617                ]
1618
1619                if len([y for y in mask if y is False]) == 0:
1620                    raise ValueError("Nothing found to reduce")
1621
1622                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1623
1624            if self.agg_metadata is not None:
1625
1626                if inc_set:
1627
1628                    self.agg_metadata["drop"] = (
1629                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1630                    )
1631
1632                    if "#" not in full:
1633
1634                        self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1635
1636                        print(
1637                            "The 'full' argument does not include the set info (name # set) while 'inc_set' is True. "
1638                            "Only the names will be compared, without considering the set information."
1639                        )
1640                else:
1641
1642                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1643
1644                mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]]
1645
1646                if len([y for y in mask if y is False]) == 0:
1647                    raise ValueError("Nothing found to reduce")
1648
1649                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1650                    drop=True
1651                )
1652
1653                self.agg_metadata = self.agg_metadata.drop(
1654                    columns=["drop"], errors="ignore"
1655                )
1656
1657        self.gene_calculation()
1658        self.cells_calculation()
1659
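
A usage sketch for both reduction modes (the cell names are hypothetical):

    # substring match: drop every cell whose name contains 'doublet'
    proj.reduce_cols(reg="doublet")

    # exact match; with inc_set=True, names are compared as 'cell_name # set'
    proj.reduce_cols(full="T_cell # batch1", inc_set=True)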
1660    def reduce_rows(self, features_list: list):
1661        """
1662        Remove rows (features) whose names are included in features_list.
1663
1664        Parameters
1665        ----------
1666        features_list : list
1667            List of features to search for in index/gene names; matching entries will be removed.
1668
1669        Update
1670        ------
1671        Mutates `self.input_data`, `self.normalized_data`, and
1672        `self.agg_normalized_data` (if they exist), removing rows whose names appear in `features_list`.
1673
1674        Raises
1675        ------
1676        ValueError
1677            If nothing matches the reduction mask. Features not found in the data are reported via a printed message.
1678        """
1679
1680        if self.input_data is not None:
1681
1682            res = find_features(self.input_data, features=features_list)
1683
1684            res_list = [x.upper() for x in res["included"]]
1685
1686            mask = [x.upper() not in res_list for x in self.input_data.index]
1687
1688            if len([y for y in mask if y is False]) == 0:
1689                raise ValueError("Nothing found to reduce")
1690
1691            self.input_data = self.input_data.loc[mask, :]
1692
1693        if self.normalized_data is not None:
1694
1695            res = find_features(self.normalized_data, features=features_list)
1696
1697            res_list = [x.upper() for x in res["included"]]
1698
1699            mask = [x.upper() not in res_list for x in self.normalized_data.index]
1700
1701            if len([y for y in mask if y is False]) == 0:
1702                raise ValueError("Nothing found to reduce")
1703
1704            self.normalized_data = self.normalized_data.loc[mask, :]
1705
1706        if self.agg_normalized_data is not None:
1707
1708            res = find_features(self.agg_normalized_data, features=features_list)
1709
1710            res_list = [x.upper() for x in res["included"]]
1711
1712            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]
1713
1714            if len([y for y in mask if y is False]) == 0:
1715                raise ValueError("Nothing found to reduce")
1716
1717            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]
1718
1719        if len(res["not_included"]) > 0:
1720            print("\nFeatures not found:")
1721            for i in res["not_included"]:
1722                print(i)
1723
1724        self.gene_calculation()
1725        self.cells_calculation()
1726
1727    def get_data(self, set_info: bool = False):
1728        """
1729        Return normalized data with optional set annotation appended to column names.
1730
1731        Parameters
1732        ----------
1733        set_info : bool, default False
1734            If True, column names are returned as "cell_name # set"; otherwise
1735            only the `cell_name` is used.
1736
1737        Returns
1738        -------
1739        pandas.DataFrame
1740            The `self.normalized_data` table with columns renamed according to `set_info`.
1741
1742        Raises
1743        ------
1744        AttributeError
1745            If `self.normalized_data` or `self.input_metadata` is missing.
1746        """
1747
1748        to_return = self.normalized_data.copy()  # copy so renaming columns does not mutate the stored table
1749
1750        if set_info:
1751            to_return.columns = (
1752                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
1753            )
1754        else:
1755            to_return.columns = self.input_metadata["cell_names"]
1756
1757        return to_return
1758
1759    def get_partial_data(
1760        self,
1761        names: list | str | None = None,
1762        features: list | str | None = None,
1763        name_slot: str = "cell_names",
1764        inc_metadata: bool = False,
1765    ):
1766        """
1767        Return a subset of the data filtered by sample names and/or feature names.
1768
1769        Parameters
1770        ----------
1771        names : list, str, or None
1772            Names of samples to include. If None, all samples are considered.
1773
1774        features : list, str, or None
1775            Names of features to include. If None, all features are considered.
1776
1777        name_slot : str
1778            Column in metadata to use as sample names.
1779
1780        inc_metadata : bool
1781            If True return tuple (data, metadata)
1782
1783        Returns
1784        -------
1785        pandas.DataFrame
1786            Subset of the normalized data based on the specified names and features.
1787        """
1788
1789        data = self.normalized_data.copy()
1790        metadata = self.input_metadata
1791
1792        if name_slot in self.input_metadata.columns:
1793            data.columns = self.input_metadata[name_slot]
1794        else:
1795            raise ValueError("'name_slot' not found in metadata columns!")
1796
1797        if isinstance(features, str):
1798            features = [features]
1799        elif features is None:
1800            features = []
1801
1802        if isinstance(names, str):
1803            names = [names]
1804        elif names is None:
1805            names = []
1806
1807        features = [x.upper() for x in features]
1808        names = [x.upper() for x in names]
1809
1810        columns_names = [x.upper() for x in data.columns]
1811        features_names = [x.upper() for x in data.index]
1812
1813        columns_bool = [x in names for x in columns_names]
1814        features_bool = [x in features for x in features_names]
1815
1816        if True not in columns_bool and True not in features_bool:
1817            print("Missing 'names' and/or 'features'. Returning full dataset instead.")
1818            if inc_metadata:
1819                return data, metadata
1820            else:
1821                return data
1822
1823        # apply both filters when both 'names' and 'features' are provided
1824        if True in columns_bool:
1825            data = data.loc[:, columns_bool]
1826            metadata = metadata.loc[columns_bool, :]
1827
1828        if True in features_bool:
1829            data = data.loc[features_bool, :]
1830
1831        if inc_metadata:
1832            return data, metadata
1833        else:
1834            return data
1839
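
A usage sketch (sample and gene names are hypothetical):

    # subset by sample names only
    sub = proj.get_partial_data(names=["T_cell", "B_cell"])

    # subset by names and features, also returning the matching metadata
    sub, meta = proj.get_partial_data(names=["T_cell"],
                                      features=["CD3E", "MS4A1"],
                                      inc_metadata=True)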
1840    def get_metadata(self):
1841        """
1842        Return the stored input metadata.
1843
1844        Returns
1845        -------
1846        pandas.DataFrame
1847            `self.input_metadata` (may be None if not set).
1848        """
1849
1850        to_return = self.input_metadata
1851
1852        return to_return
1853
1854    def gene_calculation(self):
1855        """
1856        Calculate and store per-cell counts (e.g., number of detected genes).
1857
1858        The method computes a binary (presence/absence) per cell and sums across
1859        features to produce `self.gene_calc`.
1860
1861        Update
1862        ------
1863        Sets `self.gene_calc` as a pandas.Series.
1864
1865        Side Effects
1866        ------------
1867        Uses `self.input_data` when available, otherwise `self.normalized_data`.
1868        """
1869
1870        if self.input_data is not None:
1871
1872            bin_col = self.input_data.copy()
1873
1874            bin_col = bin_col.where(bin_col <= 0, 1)  # binarize: 1 where a gene is detected
1875
1876            sum_data = bin_col.sum(axis=0)
1877
1878            self.gene_calc = sum_data
1879
1880        elif self.normalized_data is not None:
1881
1882            bin_col = self.normalized_data.copy()
1883
1884            bin_col = bin_col.where(bin_col <= 0, 1)
1885
1886            sum_data = bin_col.sum(axis=0)
1887
1888            self.gene_calc = sum_data
1889
1890    def gene_histograme(self, bins=100):
1891        """
1892        Plot a histogram of the number of genes detected per cell.
1893
1894        Parameters
1895        ----------
1896        bins : int, default 100
1897            Number of histogram bins.
1898
1899        Returns
1900        -------
1901        matplotlib.figure.Figure
1902            Figure containing the histogram of gene contents.
1903
1904        Notes
1905        -----
1906        Requires `self.gene_calc` to be computed prior to calling.
1907        """
1908
1909        fig, ax = plt.subplots(figsize=(8, 5))
1910
1911        _, bin_edges, _ = ax.hist(
1912            self.gene_calc, bins=bins, edgecolor="black", alpha=0.6
1913        )
1914
1915        mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc)
1916
1917        x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000)
1918        y = norm.pdf(x, mu, sigma)
1919
1920        y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0])
1921
1922        ax.plot(
1923            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
1924        )
1925
1926        ax.set_xlabel("Genes detected per cell")
1927        ax.set_ylabel("Number of cells")
1928        ax.set_title("Histogram of genes detected per cell")
1929
1930        ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20))
1931        ax.tick_params(axis="x", rotation=90)
1932
1933        ax.legend()
1934
1935        return fig
1936
1937    def gene_threshold(self, min_n: int | None, max_n: int | None):
1938        """
1939        Filter cells by gene-detection thresholds (min and/or max).
1940
1941        Parameters
1942        ----------
1943        min_n : int or None
1944            Minimum number of detected genes required to keep a cell.
1945
1946        max_n : int or None
1947            Maximum number of detected genes allowed to keep a cell.
1948
1949        Update
1950        ------
1951        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
1952        (and calls `average()` if `self.agg_normalized_data` exists).
1953
1954        Raises
1955        ------
1956        ValueError if both bounds are None or if no cells fall outside the thresholds.
1957        """
1958
1959        if min_n is not None and max_n is not None:
1960            mask = (self.gene_calc > min_n) & (self.gene_calc < max_n)
1961        elif min_n is None and max_n is not None:
1962            mask = self.gene_calc < max_n
1963        elif min_n is not None and max_n is None:
1964            mask = self.gene_calc > min_n
1965        else:
1966            raise ValueError("Lack of both min_n and max_n values")
1967
1968        if self.input_data is not None:
1969
1970            if len([y for y in mask if y is False]) == 0:
1971                raise ValueError("Nothing to reduce")
1972
1973            self.input_data = self.input_data.loc[:, mask.values]
1974
1975        if self.normalized_data is not None:
1976
1977            if len([y for y in mask if y is False]) == 0:
1978                raise ValueError("Nothing to reduce")
1979
1980            self.normalized_data = self.normalized_data.loc[:, mask.values]
1981
1982        if self.input_metadata is not None:
1983
1984            if len([y for y in mask if y is False]) == 0:
1985                raise ValueError("Nothing to reduce")
1986
1987            self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index(
1988                drop=True
1989            )
1990
1991            self.input_metadata = self.input_metadata.drop(
1992                columns=["drop"], errors="ignore"
1993            )
1994
1995        if self.agg_normalized_data is not None:
1996            self.average()
1997
1998        self.gene_calculation()
1999        self.cells_calculation()
2000
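
A typical QC pass (the thresholds are illustrative; read them off the
histogram for your own data):

    proj.gene_calculation()                # per-cell detected-gene counts
    fig = proj.gene_histograme(bins=100)   # inspect the distribution
    proj.gene_threshold(min_n=500, max_n=6000)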
2001    def cells_calculation(self, name_slot="cell_names"):
2002        """
2003        Calculate the number of cells per cell name / cluster.
2004
2005        The method counts the occurrences of each cell name / cluster
2006        across all cells.
2007
2008        Parameters
2009        ----------
2010        name_slot : str, default 'cell_names'
2011            Column in metadata to use as sample names.
2012
2013        Update
2014        ------
2015        Sets `self.cells_calc` as a pd.DataFrame.
2016        """
2017
2018        ls = list(self.input_metadata[name_slot])
2019
2020        df = pd.DataFrame(
2021            {
2022                "cluster": pd.Series(ls).value_counts().index,
2023                "n": pd.Series(ls).value_counts().values,
2024            }
2025        )
2026
2027        self.cells_calc = df
2028
2029    def cell_histograme(self, name_slot: str = "cell_names"):
2030        """
2031        Plot a histogram of the number of cells detected per cell name (cluster).
2032
2033        Parameters
2034        ----------
2035        name_slot : str, default 'cell_names'
2036            Column in metadata to use as sample names.
2037
2038        Returns
2039        -------
2040        matplotlib.figure.Figure
2041            Figure containing the histogram of cell contents.
2042
2043        Notes
2044        -----
2045        Requires `self.cells_calc` to be computed prior to calling.
2046        """
2047
2048        if name_slot != "cell_names":
2049            self.cells_calculation(name_slot=name_slot)
2050
2051        fig, ax = plt.subplots(figsize=(8, 5))
2052
2053        _, bin_edges, _ = ax.hist(
2054            list(self.cells_calc["n"]),
2055            bins=len(set(self.cells_calc["cluster"])),
2056            edgecolor="black",
2057            color="orange",
2058            alpha=0.6,
2059        )
2060
2061        mu, sigma = np.mean(list(self.cells_calc["n"])), np.std(
2062            list(self.cells_calc["n"])
2063        )
2064
2065        x = np.linspace(
2066            min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000
2067        )
2068        y = norm.pdf(x, mu, sigma)
2069
2070        y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0])
2071
2072        ax.plot(
2073            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
2074        )
2075
2076        ax.set_xlabel("Cells per cell name / cluster")
2077        ax.set_ylabel("Number of clusters")
2078        ax.set_title("Histogram of cells detected per cell name / cluster")
2079
2080        ax.set_xticks(
2081            np.linspace(
2082                min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20
2083            )
2084        )
2085        ax.tick_params(axis="x", rotation=90)
2086
2087        ax.legend()
2088
2089        return fig
2090
2091    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
2092        """
2093        Filter cell names / clusters by cell-detection threshold.
2094
2095        Parameters
2096        ----------
2097        min_n : int or None
2098            Minimum number of cells required to keep a cell name / cluster.
2099
2100        name_slot : str, default 'cell_names'
2101            Column in metadata to use as sample names.
2102
2103
2104        Update
2105        -------
2106        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
2107        (and calls `average()` if `self.agg_normalized_data` exists).
2108        """
2109
2110        if name_slot != "cell_names":
2111            self.cells_calculation(name_slot=name_slot)
2112
2113        if min_n is not None:
2114            names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]
2115        else:
2116            raise ValueError("Lack of min_n value")
2117
2118        if len(names) > 0:
2119
2120            if self.input_data is not None:
2121
2122                self.input_data.columns = self.input_metadata[name_slot]
2123
2124                mask = [not any(r in x for r in names) for x in self.input_data.columns]
2125
2126                if len([y for y in mask if y is False]) > 0:
2127
2128                    self.input_data = self.input_data.loc[:, mask]
2129
2130            if self.normalized_data is not None:
2131
2132                self.normalized_data.columns = self.input_metadata[name_slot]
2133
2134                mask = [
2135                    not any(r in x for r in names) for x in self.normalized_data.columns
2136                ]
2137
2138                if len([y for y in mask if y is False]) > 0:
2139
2140                    self.normalized_data = self.normalized_data.loc[:, mask]
2141
2142            if self.input_metadata is not None:
2143
2144                self.input_metadata["drop"] = self.input_metadata[name_slot]
2145
2146                mask = [
2147                    not any(r in x for r in names) for x in self.input_metadata["drop"]
2148                ]
2149
2150                if not all(mask):
2151
2152                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
2153                        drop=True
2154                    )
2155
2156                self.input_metadata = self.input_metadata.drop(
2157                    columns=["drop"], errors="ignore"
2158                )
2159
2160            if self.agg_normalized_data is not None:
2161
2162                self.agg_normalized_data.columns = self.agg_metadata[name_slot]
2163
2164                mask = [
2165                    not any(r in x for r in names)
2166                    for x in self.agg_normalized_data.columns
2167                ]
2168
2169                if not all(mask):
2170
2171                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
2172
2173            if self.agg_metadata is not None:
2174
2175                self.agg_metadata["drop"] = self.agg_metadata[name_slot]
2176
2177                mask = [
2178                    not any(r in x for r in names) for x in self.agg_metadata["drop"]
2179                ]
2180
2181                if not all(mask):
2182
2183                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
2184                        drop=True
2185                    )
2186
2187                self.agg_metadata = self.agg_metadata.drop(
2188                    columns=["drop"], errors="ignore"
2189                )
2190
2191            self.gene_calculation()
2192            self.cells_calculation()
2193
2194    def load_sparse_from_projects(self, normalized_data: bool = False):
2195        """
2196        Load sparse 10x-style datasets from stored project paths, concatenate them,
2197        and populate `input_data` / `normalized_data` and `input_metadata`.
2198
2199        Parameters
2200        ----------
2201        normalized_data : bool, default False
2202            If True, store concatenated tables in `self.normalized_data`.
2203            If False, store them in `self.input_data`; normalization can then
2204            be performed with the `normalize_counts()` method.
2205
2206        Side Effects
2207        ------------
2208        - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv).
2209        - Concatenates all projects column-wise and sets `self.input_metadata`.
2210        - Replaces NaNs with zeros and updates `self.gene_calc`.
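
        Examples
        --------
        A minimal sketch, assuming `self.objects` already maps project names to
        10x-style folders:

        >>> cl.load_sparse_from_projects(normalized_data=False)
        >>> cl.normalize_counts()   # raw counts still require normalization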
2211        """
2212
2213        obj = self.objects
2214
2215        full_data = pd.DataFrame()
2216        full_metadata = pd.DataFrame()
2217
2218        for ke in obj.keys():
2219            print(ke)
2220
2221            dt, met = load_sparse(path=obj[ke], name=ke)
2222
2223            full_data = pd.concat([full_data, dt], axis=1)
2224            full_metadata = pd.concat([full_metadata, met], axis=0)
2225
2226        full_data[np.isnan(full_data)] = 0
2227
2228        if normalized_data:
2229            self.normalized_data = full_data
2230            self.input_metadata = full_metadata
2231        else:
2232
2233            self.input_data = full_data
2234            self.input_metadata = full_metadata
2235
2236        self.gene_calculation()
2237        self.cells_calculation()
2238
2239    def rename_names(self, mapping: dict, slot: str = "cell_names"):
2240        """
2241        Rename entries in `self.input_metadata[slot]` according to a provided mapping.
2242
2243        Parameters
2244        ----------
2245        mapping : dict
2246            Dictionary with keys 'old_name' and 'new_name', each mapping to a list
2247            of equal length describing replacements.
2248
2249        slot : str, default 'cell_names'
2250            Metadata column to operate on.
2251
2252        Update
2253        -------
2254        Updates `self.input_metadata[slot]` in-place with renamed values.
2255
2256        Raises
2257        ------
2258        ValueError
2259            If mapping keys are incorrect, lengths differ, or some 'old_name' values
2260            are not present in the metadata column.
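
        Examples
        --------
        A minimal sketch ('cluster_0' / 'T cells' are hypothetical labels):

        >>> cl.rename_names(
        ...     mapping={"old_name": ["cluster_0"], "new_name": ["T cells"]},
        ...     slot="cell_names",
        ... )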
2261        """
2262
2263        if set(["old_name", "new_name"]) != set(mapping.keys()):
2264            raise ValueError(
2265                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2266                "each with a list of names to change."
2267            )
2268
2269        if len(mapping["old_name"]) != len(mapping["new_name"]):
2270            raise ValueError(
2271                "Mapping dictionary lists 'old_name' and 'new_name' "
2272                "must have the same length!"
2273            )
2274
2275        names_vector = list(self.input_metadata[slot])
2276
2277        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2278            raise ValueError(
2279                f"Some entries from 'old_name' do not exist in the names of slot {slot}."
2280            )
2281
2282        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2283
2284        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2285
2286        self.input_metadata[slot] = names_vector_ret
2287
2288    def rename_subclusters(self, mapping):
2289        """
2290        Rename labels stored in `self.subclusters_.subclusters` according to mapping.
2291
2292        Parameters
2293        ----------
2294        mapping : dict
2295            Mapping with keys 'old_name' and 'new_name' (lists of equal length).
2296
2297        Update
2298        -------
2299        Updates `self.subclusters_.subclusters` with renamed labels.
2300
2301        Raises
2302        ------
2303        ValueError
2304            If mapping is invalid or old names are not present.
2305        """
2306
2307        if set(["old_name", "new_name"]) != set(mapping.keys()):
2308            raise ValueError(
2309                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2310                "each with a list of names to change."
2311            )
2312
2313        if len(mapping["old_name"]) != len(mapping["new_name"]):
2314            raise ValueError(
2315                "Mapping dictionary lists 'old_name' and 'new_name' "
2316                "must have the same length!"
2317            )
2318
2319        names_vector = list(self.subclusters_.subclusters)
2320
2321        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2322            raise ValueError(
2323                "Some entries from 'old_name' do not exist in the subcluster names."
2324            )
2325
2326        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2327
2328        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2329
2330        self.subclusters_.subclusters = names_vector_ret
2331
2332    def save_sparse(
2333        self,
2334        path_to_save: str = os.getcwd(),
2335        name_slot: str = "cell_names",
2336        data_slot: str = "normalized",
2337    ):
2338        """
2339        Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
2340
2341        Parameters
2342        ----------
2343        path_to_save : str, default current working directory
2344            Directory where files will be written.
2345
2346        name_slot : str, default 'cell_names'
2347            Metadata column providing cell names for barcodes.tsv.
2348
2349        data_slot : str, default 'normalized'
2350            Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).
2351
2352        Raises
2353        ------
2354        ValueError
2355            If `data_slot` is not 'normalized' or 'count'.
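
        Examples
        --------
        A minimal sketch; the output directory is hypothetical:

        >>> cl.save_sparse(path_to_save="out/", data_slot="count")   # writes matrix.mtx, barcodes.tsv, genes.tsv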
2356        """
2357
2358        names = self.input_metadata[name_slot]
2359
2360        if data_slot.lower() == "normalized":
2361
2362            features = list(self.normalized_data.index)
2363            mtx = sparse.csr_matrix(self.normalized_data)
2364
2365        elif data_slot.lower() == "count":
2366
2367            features = list(self.input_data.index)
2368            mtx = sparse.csr_matrix(self.input_data)
2369
2370        else:
2371            raise ValueError("'data_slot' must be either 'normalized' or 'count'.")
2372
2373        os.makedirs(path_to_save, exist_ok=True)
2374
2375        mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx)
2376
2377        pd.Series(names).to_csv(
2378            os.path.join(path_to_save, "barcodes.tsv"),
2379            index=False,
2380            header=False,
2381            sep="\t",
2382        )
2383
2384        pd.Series(features).to_csv(
2385            os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t"
2386        )
2387
2388    def normalize_counts(
2389        self, normalize_factor: int = 100000, log_transform: bool = True
2390    ):
2391        """
2392        Normalize raw counts to counts-per-(normalize_factor)
2393        (e.g., counts per hundred thousand by default, or CPM when normalize_factor=1e6).
2394
2395        Parameters
2396        ----------
2397        normalize_factor : int, default 100000
2398            Scaling factor used after dividing by column sums.
2399
2400        log_transform : bool, default True
2401            If True, apply log2(x+1) transformation to normalized values.
2402
2403        Update
2404        -------
2405            Sets `self.normalized_data` to normalized values (fills NaNs with 0).
2406
2407        Raises
2408        ------
2409        ValueError
2410            If `self.input_data` is missing (cannot normalize).
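
        Examples
        --------
        Each value becomes `x / column_sum * normalize_factor`, optionally
        log2(x + 1)-transformed: a count of 5 in a cell with 50,000 total
        counts maps to 5 / 50000 * 100000 = 10, then log2(11) ≈ 3.46.

        >>> cl.normalize_counts(normalize_factor=1_000_000)   # CPM-style scaling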
2411        """
2412        if self.input_data is None:
2413            raise ValueError("Input data is missing, cannot normalize.")
2414
2415        sum_col = self.input_data.sum()
2416        self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor
2417
2418        if log_transform:
2419            # log2(x + 1) to avoid -inf for zeros
2420            self.normalized_data = np.log2(self.normalized_data + 1)
2421
2422    def statistic(
2423        self,
2424        cells=None,
2425        sets=None,
2426        min_exp: float = 0.01,
2427        min_pct: float = 0.1,
2428        n_proc: int = 10,
2429    ):
2430        """
2431        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.
2432
2433        This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data`
2434        and `self.input_metadata`. It returns per-feature statistics including p-values,
2435        adjusted p-values, means, variances, effect-size measures and fold-changes.
2436
2437        Parameters
2438        ----------
2439        cells : list, 'All', dict, or None
2440            Defines the target cells or groups for comparison (several modes supported).
2441
2442        sets : 'All', dict, or None
2443            Alternative grouping mode (operates on `self.input_metadata['sets']`).
2444
2445        min_exp : float, default 0.01
2446            Minimum expression threshold used when filtering features.
2447
2448        min_pct : float, default 0.1
2449            Minimum proportion of expressing cells in the target group required to test a feature.
2450
2451        n_proc : int, default 10
2452            Number of parallel jobs to use.
2453
2454        Returns
2455        -------
2456        pandas.DataFrame or dict
2457            Results DataFrame (or dict containing valid/control cells + DataFrame),
2458            similar to `calc_DEG` interface.
2459
2460        Raises
2461        ------
2462        ValueError
2463            If neither `cells` nor `sets` is provided, or input metadata mismatch occurs.
2464
2465        Notes
2466        -----
2467        Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
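
        Examples
        --------
        Hedged sketches of the main modes (all labels are hypothetical):

        >>> res = cl.statistic(cells=["NK # setA"])    # selected cells vs the rest
        >>> res = cl.statistic(cells="All")            # every cell type vs the rest
        >>> res = cl.statistic(sets={"grp1": ["setA"], "grp2": ["setB"]})   # pairwise sets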
2468        """
2469
2470        def stat_calc(choose, feature_name):
2471            target_values = choose.loc[choose["DEG"] == "target", feature_name]
2472            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]
2473
2474            pct_valid = (target_values > 0).sum() / len(target_values)
2475            pct_rest = (rest_values > 0).sum() / len(rest_values)
2476
2477            avg_valid = np.mean(target_values)
2478            avg_ctrl = np.mean(rest_values)
2479            sd_valid = np.std(target_values, ddof=1)
2480            sd_ctrl = np.std(rest_values, ddof=1)
2481            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))
2482
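            # guard: identical group totals (e.g., an all-zero feature) are
            # reported as p = 1 instead of calling mannwhitneyu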
2483            if np.sum(target_values) == np.sum(rest_values):
2484                p_val = 1.0
2485            else:
2486                _, p_val = stats.mannwhitneyu(
2487                    target_values, rest_values, alternative="two-sided"
2488                )
2489
2490            return {
2491                "feature": feature_name,
2492                "p_val": p_val,
2493                "pct_valid": pct_valid,
2494                "pct_ctrl": pct_rest,
2495                "avg_valid": avg_valid,
2496                "avg_ctrl": avg_ctrl,
2497                "sd_valid": sd_valid,
2498                "sd_ctrl": sd_ctrl,
2499                "esm": esm,
2500            }
2501
2502        def prepare_and_run_stat(
2503            choose, valid_group, min_exp, min_pct, n_proc, factors
2504        ):
2505
2506            tmp_dat = choose[choose["DEG"] == "target"]
2507            tmp_dat = tmp_dat.drop("DEG", axis=1)
2508
2509            counts = (tmp_dat > min_exp).sum(axis=0)
2510
2511            total_count = tmp_dat.shape[0]
2512
2513            info = pd.DataFrame(
2514                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
2515            )
2516
2517            del tmp_dat
2518
2519            drop_col = info["feature"][info["pct"] <= min_pct]
2520
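            # if the pct filter would drop every feature (leaving only the DEG
            # column), relax it and drop only features with no expressing cells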
2521            if len(drop_col) + 1 == len(choose.columns):
2522                drop_col = info["feature"][info["pct"] == 0]
2523
2524            del info
2525
2526            choose = choose.drop(list(drop_col), axis=1)
2527
2528            results = Parallel(n_jobs=n_proc)(
2529                delayed(stat_calc)(choose, feature)
2530                for feature in tqdm(choose.columns[choose.columns != "DEG"])
2531            )
2532
2533            df = pd.DataFrame(results)
2534            df["valid_group"] = valid_group
2535            df.sort_values(by="p_val", inplace=True)
2536
2537            num_tests = len(df)
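            # rank-scaled (Benjamini-Hochberg-style) adjustment of the sorted
            # p-values, capped at 1; the cumulative-minimum step of full BH is omitted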
2538            df["adj_pval"] = np.minimum(
2539                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
2540            )
2541
2542            factors_df = factors.to_frame(name="factor")
2543
2544            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")
2545
2546            df["FC"] = (df["avg_valid"] + df["factor"]) / (
2547                df["avg_ctrl"] + df["factor"]
2548            )
2549            df["log(FC)"] = np.log2(df["FC"])
2550            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
2551            df = df.drop(columns=["factor"])
2552
2553            return df
2554
2555        choose = self.normalized_data.copy().T
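        # per-feature pseudo-count: the smallest non-zero value of each feature,
        # later added to both group means to stabilize the fold-change ratio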
2556        factors = self.normalized_data.copy().replace(0, np.nan).min(axis=1)
2557        factors = factors.fillna(factors.min())  # all-zero features yield NaN above; fall back to the smallest observed factor
2558
2559        final_results = []
2560
2561        if isinstance(cells, list) and sets is None:
2562            print("\nAnalysis started...\nComparing selected cells to the whole set...")
2563            choose.index = (
2564                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2565            )
2566
2567            if "#" not in cells[0]:
2568                choose.index = self.input_metadata["cell_names"]
2569
2570                print(
2571                    "The 'cells' list does not include set info (name # set). "
2572                    "Only names will be compared, without considering set information."
2573                )
2574
2575            labels = ["target" if idx in cells else "rest" for idx in choose.index]
2576            valid = list(
2577                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
2578            )
2579
2580            choose["DEG"] = labels
2581            choose = choose[choose["DEG"] != "drop"]
2582
2583            result_df = prepare_and_run_stat(
2584                choose.reset_index(drop=True),
2585                valid_group=valid,
2586                min_exp=min_exp,
2587                min_pct=min_pct,
2588                n_proc=n_proc,
2589                factors=factors,
2590            )
2591            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}
2592
2593        elif cells == "All" and sets is None:
2594            print("\nAnalysis started...\nComparing each type of cell to others...")
2595            choose.index = (
2596                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2597            )
2598            unique_labels = set(choose.index)
2599
2600            for label in tqdm(unique_labels):
2601                print(f"\nCalculating statistics for {label}")
2602                labels = ["target" if idx == label else "rest" for idx in choose.index]
2603                choose["DEG"] = labels
2604                choose = choose[choose["DEG"] != "drop"]
2605                result_df = prepare_and_run_stat(
2606                    choose.copy(),
2607                    valid_group=label,
2608                    min_exp=min_exp,
2609                    min_pct=min_pct,
2610                    n_proc=n_proc,
2611                    factors=factors,
2612                )
2613                final_results.append(result_df)
2614
2615            return pd.concat(final_results, ignore_index=True)
2616
2617        elif cells is None and sets == "All":
2618            print("\nAnalysis started...\nComparing each set/group to others...")
2619            choose.index = self.input_metadata["sets"]
2620            unique_sets = set(choose.index)
2621
2622            for label in tqdm(unique_sets):
2623                print(f"\nCalculating statistics for {label}")
2624                labels = ["target" if idx == label else "rest" for idx in choose.index]
2625
2626                choose["DEG"] = labels
2627                choose = choose[choose["DEG"] != "drop"]
2628                result_df = prepare_and_run_stat(
2629                    choose.copy(),
2630                    valid_group=label,
2631                    min_exp=min_exp,
2632                    min_pct=min_pct,
2633                    n_proc=n_proc,
2634                    factors=factors,
2635                )
2636                final_results.append(result_df)
2637
2638            return pd.concat(final_results, ignore_index=True)
2639
2640        elif cells is None and isinstance(sets, dict):
2641            print("\nAnalysis started...\nComparing groups...")
2642
2643            choose.index = self.input_metadata["sets"]
2644
2645            group_list = list(sets.keys())
2646            if len(group_list) != 2:
2647                print("Only pairwise group comparison is supported.")
2648                return None
2649
2650            labels = [
2651                (
2652                    "target"
2653                    if idx in sets[group_list[0]]
2654                    else "rest" if idx in sets[group_list[1]] else "drop"
2655                )
2656                for idx in choose.index
2657            ]
2658            choose["DEG"] = labels
2659            choose = choose[choose["DEG"] != "drop"]
2660
2661            result_df = prepare_and_run_stat(
2662                choose.reset_index(drop=True),
2663                valid_group=group_list[0],
2664                min_exp=min_exp,
2665                min_pct=min_pct,
2666                n_proc=n_proc,
2667                factors=factors,
2668            )
2669            return result_df
2670
2671        elif isinstance(cells, dict) and sets is None:
2672            print("\nAnalysis started...\nComparing groups...")
2673            choose.index = (
2674                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2675            )
2676
2677            if "#" not in cells[list(cells.keys())[0]][0]:
2678                choose.index = self.input_metadata["cell_names"]
2679
2680                print(
2681                    "The 'cells' dict does not include set info (name # set). "
2682                    "Only names will be compared, without considering set information."
2683                )
2684
2685            group_list = list(cells.keys())
2686            if len(group_list) != 2:
2687                print("Only pairwise group comparison is supported.")
2688                return None
2689
2690            labels = [
2691                (
2692                    "target"
2693                    if idx in cells[group_list[0]]
2694                    else "rest" if idx in cells[group_list[1]] else "drop"
2695                )
2696                for idx in choose.index
2697            ]
2698
2699            choose["DEG"] = labels
2700            choose = choose[choose["DEG"] != "drop"]
2701
2702            result_df = prepare_and_run_stat(
2703                choose.reset_index(drop=True),
2704                valid_group=group_list[0],
2705                min_exp=min_exp,
2706                min_pct=min_pct,
2707                n_proc=n_proc,
2708                factors=factors,
2709            )
2710
2711            return result_df.reset_index(drop=True)
2712
2713        else:
2714            raise ValueError(
2715                "Invalid 'cells'/'sets' combination: specify exactly one of 'cells' (list, dict, or 'All') or 'sets' (dict or 'All')."
2716            )
2717
2718    def calculate_difference_markers(
2719        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
2720    ):
2721        """
2722        Compute differential markers (var_data) if not already present.
2723
2724        Parameters
2725        ----------
2726        min_exp : float, default 0
2727            Minimum expression threshold passed to `statistic`.
2728
2729        min_pct : float, default 0.25
2730            Minimum percent expressed in target group.
2731
2732        n_proc : int, default 10
2733            Parallel jobs.
2734
2735        force : bool, default False
2736            If True, recompute even if `self.var_data` is present.
2737
2738        Update
2739        -------
2740        Sets `self.var_data` to the result of `self.statistic(...)`.
2741
2742        Raises
2743        ------
2744        ValueError if already computed and `force` is False.
2745        """
2746
2747        if self.var_data is None or force:
2748
2749            self.var_data = self.statistic(
2750                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
2751            )
2752
2753        else:
2754            raise ValueError(
2755                "self.calculate_difference_markers() has already been executed. "
2756                "The results are stored in self.var_data. "
2757                "If you want to recalculate with different statistics, please rerun the method with force=True."
2758            )
2759
2760    def clustering_features(
2761        self,
2762        features_list: list | None,
2763        name_slot: str = "cell_names",
2764        p_val: float = 0.05,
2765        top_n: int = 25,
2766        adj_mean: bool = True,
2767        beta: float = 0.2,
2768    ):
2769        """
2770        Prepare clustering input by selecting marker features and optionally smoothing cell values
2771        toward group means.
2772
2773        Parameters
2774        ----------
2775        features_list : list or None
2776            If provided, use this list of features. If None, features are selected
2777            from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group.
2778
2779        name_slot : str, default 'cell_names'
2780            Metadata column used for naming.
2781
2782        p_val : float, default 0.05
2783            Adjusted p-value cutoff when selecting features automatically.
2784
2785        top_n : int, default 25
2786            Number of top features per valid group to keep if `features_list` is None.
2787
2788        adj_mean : bool, default True
2789            If True, adjust cell values toward group means using `beta`.
2790
2791        beta : float, default 0.2
2792            Adjustment strength toward group mean.
2793
2794        Update
2795        ------
2796        Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset,
2797        ready for PCA/UMAP/clustering.
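
        Examples
        --------
        A minimal sketch with automatically selected markers:

        >>> cl.calculate_difference_markers()   # populates self.var_data
        >>> cl.clustering_features(features_list=None, top_n=25, beta=0.2)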
2798        """
2799
2800        if features_list is None or len(features_list) == 0:
2801
2802            if self.var_data is None:
2803                raise ValueError(
2804                    "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2805                )
2806
2807            df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2808            df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2809            df_tmp = (
2810                df_tmp.sort_values(
2811                    ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2812                )
2813                .groupby("valid_group")
2814                .head(top_n)
2815            )
2816
2817            features_list = list(set(df_tmp["feature"]))
2818
2819        data = self.get_partial_data(
2820            names=None, features=features_list, name_slot=name_slot
2821        )
2822        data_avg = average(data)
2823
2824        if adj_mean:
2825            data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta)
2826
2827        self.clustering_data = data
2828
2829        self.clustering_metadata = self.input_metadata
2830
2831    def average(self):
2832        """
2833        Aggregate normalized data by (cell_name, set) pairs computing the mean per group.
2834
2835        The method constructs new column names as "cell_name # set", averages columns
2836        sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`.
2837
2838        Update
2839        ------
2840        Sets `self.agg_normalized_data` (features x aggregated samples) and
2841        `self.agg_metadata` (DataFrame with 'cell_names' and 'sets').
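
        Examples
        --------
        With hypothetical labels, two columns named "B cells # setA" are averaged
        into one "B cells" column, and `agg_metadata` records ('B cells', 'setA').

        >>> cl.average()
        >>> cl.agg_normalized_data.shape   # features x unique (cell_name, set) pairs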
2842        """
2843
2844        wide_data = self.normalized_data.copy()  # copy so the renaming below does not mutate self.normalized_data
2845
2846        wide_metadata = self.input_metadata
2847
2848        new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"]
2849
2850        wide_data.columns = list(new_names)
2851
2852        aggregated_df = wide_data.T.groupby(wide_data.columns).mean().T
2853
2854        sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns]
2855        names = [re.sub(" #.*", "", x) for x in aggregated_df.columns]
2856
2857        aggregated_df.columns = names
2858        aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets})
2859
2860        self.agg_metadata = aggregated_metadata
2861        self.agg_normalized_data = aggregated_df
2862
2863    def estimating_similarity(
2864        self, method="pearson", p_val: float = 0.05, top_n: int = 25
2865    ):
2866        """
2867        Estimate pairwise similarity and Euclidean distance between aggregated samples.
2868
2869        Parameters
2870        ----------
2871        method : str, default 'pearson'
2872            Correlation method to use (passed to pandas.DataFrame.corr()).
2873
2874        p_val : float, default 0.05
2875            Adjusted p-value cutoff used to select marker features from `self.var_data`.
2876
2877        top_n : int, default 25
2878            Number of top features per valid group to include.
2879
2880        Update
2881        -------
2882        Computes a combined table with per-pair correlation and euclidean distance
2883        and stores it in `self.similarity`.
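
        Examples
        --------
        >>> cl.estimating_similarity(method="spearman", top_n=50)
        >>> cl.similarity.head()   # columns: cell1, cell2, correlation, euclidean_dist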
2884        """
2885
2886        if self.var_data is None:
2887            raise ValueError(
2888                "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2889            )
2890
2891        if self.agg_normalized_data is None:
2892            self.average()
2893
2894        metadata = self.agg_metadata
2895        data = self.agg_normalized_data
2896
2897        df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2898        df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2899        df_tmp = (
2900            df_tmp.sort_values(
2901                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2902            )
2903            .groupby("valid_group")
2904            .head(top_n)
2905        )
2906
2907        data = data.loc[list(set(df_tmp["feature"]))]
2908
2909        if len(set(metadata["sets"])) > 1:
2910            data.columns = data.columns + " # " + list(metadata["sets"])
2913
2914        scaler = StandardScaler()
2915
2916        scaled_data = scaler.fit_transform(data)
2917
2918        scaled_df = pd.DataFrame(scaled_data, columns=data.columns)
2919
2920        cor = scaled_df.corr(method=method)
2921        cor_df = cor.stack().reset_index()
2922        cor_df.columns = ["cell1", "cell2", "correlation"]
2923
2924        distances = pdist(scaled_df.T, metric="euclidean")
2925        dist_mat = pd.DataFrame(
2926            squareform(distances), index=scaled_df.columns, columns=scaled_df.columns
2927        )
2928        dist_df = dist_mat.stack().reset_index()
2929        dist_df.columns = ["cell1", "cell2", "euclidean_dist"]
2930
2931        full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"])
2932
2933        full = full[full["cell1"] != full["cell2"]]
2934        full = full.reset_index(drop=True)
2935
2936        self.similarity = full
2937
2938    def similarity_plot(
2939        self,
2940        split_sets=True,
2941        set_info: bool = True,
2942        cmap="seismic",
2943        width=12,
2944        height=10,
2945    ):
2946        """
2947        Visualize pairwise similarity as a scatter plot.
2948
2949        Parameters
2950        ----------
2951        split_sets : bool, default True
2952            If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs.
2953
2954        set_info : bool, default True
2955            If True, keep the ' # set' annotation in labels; otherwise strip it.
2956
2957        cmap : str, default 'seismic'
2958            Color map for correlation (hue).
2959
2960        width : int, default 12
2961            Figure width.
2962
2963        height : int, default 10
2964            Figure height.
2965
2966        Returns
2967        -------
2968        matplotlib.figure.Figure
2969
2970        Raises
2971        ------
2972        ValueError
2973            If `self.similarity` is None.
2974
2975        Notes
2976        -----
2977        The function keeps only pairs whose negated z-scored euclidean distance is > 0 (i.e., below-average distance) to focus on closer pairs.
2978        """
2979
2980        if self.similarity is None:
2981            raise ValueError(
2982                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
2983            )
2984
2985        similarity_data = self.similarity
2986
2987        if " # " in similarity_data["cell1"][0]:
2988            similarity_data["set1"] = [
2989                re.sub(".*# ", "", x) for x in similarity_data["cell1"]
2990            ]
2991            similarity_data["set2"] = [
2992                re.sub(".*# ", "", x) for x in similarity_data["cell2"]
2993            ]
2994
2995        if split_sets and " # " in similarity_data["cell1"][0]:
2996            sets = list(
2997                set(list(similarity_data["set1"]) + list(similarity_data["set2"]))
2998            )
2999
3000            mm = math.ceil(len(sets) / 2)
3001
3002            x_s = sets[0:mm]
3003            y_s = sets[mm : len(sets)]
3004
3005            similarity_data = similarity_data[similarity_data["set1"].isin(x_s)]
3006            similarity_data = similarity_data[similarity_data["set2"].isin(y_s)]
3007
3008            similarity_data = similarity_data.sort_values(["set1", "set2"])
3009
3010        if set_info is False and " # " in similarity_data["cell1"][0]:
3011            similarity_data["cell1"] = [
3012                re.sub(" #.*", "", x) for x in similarity_data["cell1"]
3013            ]
3014            similarity_data["cell2"] = [
3015                re.sub(" #.*", "", x) for x in similarity_data["cell2"]
3016            ]
3017
3018        similarity_data["-euclidean_zscore"] = -zscore(
3019            similarity_data["euclidean_dist"]
3020        )
3021
3022        similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0]
3023
3024        fig = plt.figure(figsize=(width, height))
3025        sns.scatterplot(
3026            data=similarity_data,
3027            x="cell1",
3028            y="cell2",
3029            hue="correlation",
3030            size="-euclidean_zscore",
3031            sizes=(1, 100),
3032            palette=cmap,
3033            alpha=1,
3034            edgecolor="black",
3035        )
3036
3037        plt.xticks(rotation=90)
3038        plt.yticks(rotation=0)
3039        plt.xlabel("Cell 1")
3040        plt.ylabel("Cell 2")
3041        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
3042
3043        plt.grid(True, alpha=0.6)
3044
3045        plt.tight_layout()
3046
3047        return fig
3048
3049    def spatial_similarity(
3050        self,
3051        set_info: bool = True,
3052        bandwidth=1,
3053        n_neighbors=5,
3054        min_dist=0.1,
3055        legend_split=2,
3056        point_size=100,
3057        spread=1.0,
3058        set_op_mix_ratio=1.0,
3059        local_connectivity=1,
3060        repulsion_strength=1.0,
3061        negative_sample_rate=5,
3062        threshold=0.1,
3063        width=12,
3064        height=10,
3065    ):
3066        """
3067        Create a spatial UMAP-like visualization of similarity relationships between samples.
3068
3069        Parameters
3070        ----------
3071        set_info : bool, default True
3072            If True, retain set information in labels.
3073
3074        bandwidth : float, default 1
3075            Bandwidth used by MeanShift for clustering polygons.
3076
3077        point_size : float, default 100
3078            Size of scatter points.
3079
3080        legend_split : int, default 2
3081            Number of columns in legend.
3082
3083        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.
3084
3085        threshold : float, default 0.1
3086            Minimum text distance for label adjustment to avoid overlap.
3087
3088        width : int, default 12
3089            Figure width.
3090
3091        height : int, default 10
3092            Figure height.
3093
3094        Returns
3095        -------
3096        matplotlib.figure.Figure
3097
3098        Raises
3099        ------
3100        ValueError
3101            If `self.similarity` is None.
3102
3103        Notes
3104        -----
3105        Builds a precomputed distance matrix combining correlation and euclidean distance,
3106        runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull)
3107        and arrows to indicate nearest neighbors (minimal combined distance).
3108        """
3109
3110        if self.similarity is None:
3111            raise ValueError(
3112                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
3113            )
3114
3115        similarity_data = self.similarity
3116
3117        sim = similarity_data["correlation"]
3118        sim_scaled = (sim - sim.min()) / (sim.max() - sim.min())
3119        eu_dist = similarity_data["euclidean_dist"]
3120        eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min())
3121
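        # combined dissimilarity: (1 - scaled correlation) weighted by the scaled
        # euclidean distance, so a pair scores low when correlated or spatially near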
3122        similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled
3123
3124        # nearest-neighbour target per cell:
3125        # the pair with the minimal combined distance
3126        arrow_df = similarity_data.loc[
3127            similarity_data.groupby("cell1")["combo_dist"].idxmin()
3128        ]
3129
3130        cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"]))
3131        combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float)
3132
3133        for _, row in similarity_data.iterrows():
3134            combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"]
3135            combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"]
3136
3137        umap_model = umap.UMAP(
3138            n_components=2,
3139            metric="precomputed",
3140            n_neighbors=n_neighbors,
3141            min_dist=min_dist,
3142            spread=spread,
3143            set_op_mix_ratio=set_op_mix_ratio,
3144            local_connectivity=local_connectivity,
3145            repulsion_strength=repulsion_strength,
3146            negative_sample_rate=negative_sample_rate,
3147            transform_seed=42,
3148            init="spectral",
3149            random_state=42,
3150            verbose=True,
3151        )
3152
3153        coords = umap_model.fit_transform(combo_matrix.values)
3154        cell_names = list(combo_matrix.index)
3155        num_cells = len(cell_names)
3156        palette = sns.color_palette("tab20c", num_cells)
3157
3158        if "#" in cell_names[0]:
3159            avsets = set(
3160                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
3161                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
3162            )
3163            num_sets = len(avsets)
3164            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
3165            color_mapping_sets = {
3166                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
3167            }
3168            color_mapping = {
3169                name: color_mapping_sets[re.sub(".*# ", "", name)]
3170                for name in cell_names
3171            }
3172        else:
3173            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}
3174
3175        meanshift = MeanShift(bandwidth=bandwidth)
3176        labels = meanshift.fit_predict(coords)
3177
3178        fig = plt.figure(figsize=(width, height))
3179        ax = plt.gca()
3180
3181        unique_labels = set(labels)
3182        cluster_palette = sns.color_palette("hls", len(unique_labels))
3183
3184        for label in unique_labels:
3185            if label == -1:
3186                continue
3187            cluster_coords = coords[labels == label]
3188            if len(cluster_coords) < 3:
3189                continue
3190
3191            hull = ConvexHull(cluster_coords)
3192            hull_points = cluster_coords[hull.vertices]
3193
3194            centroid = np.mean(hull_points, axis=0)
3195            expanded = hull_points + 0.05 * (hull_points - centroid)
3196
3197            poly = Polygon(
3198                expanded,
3199                closed=True,
3200                facecolor=cluster_palette[label],
3201                edgecolor="none",
3202                alpha=0.2,
3203                zorder=1,
3204            )
3205            ax.add_patch(poly)
3206
3207        texts = []
3208        for i, (x, y) in enumerate(coords):
3209            plt.scatter(
3210                x,
3211                y,
3212                s=point_size,
3213                color=color_mapping[cell_names[i]],
3214                edgecolors="black",
3215                linewidths=0.5,
3216                zorder=2,
3217            )
3218            texts.append(
3219                ax.text(
3220                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
3221                )
3222            )
3223
3224        def dist(p1, p2):
3225            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
3226
3227        texts_to_adjust = []
3228        for i, t1 in enumerate(texts):
3229            for j, t2 in enumerate(texts):
3230                if i >= j:
3231                    continue
3232                d = dist(
3233                    (t1.get_position()[0], t1.get_position()[1]),
3234                    (t2.get_position()[0], t2.get_position()[1]),
3235                )
3236                if d < threshold:
3237                    if t1 not in texts_to_adjust:
3238                        texts_to_adjust.append(t1)
3239                    if t2 not in texts_to_adjust:
3240                        texts_to_adjust.append(t2)
3241
3242        adjust_text(
3243            texts_to_adjust,
3244            expand_text=(1.0, 1.0),
3245            force_text=0.9,
3246            arrowprops=dict(arrowstyle="-", color="gray", lw=0.1),
3247            ax=ax,
3248        )
3249
3250        for _, row in arrow_df.iterrows():
3251            try:
3252                idx1 = cell_names.index(row["cell1"])
3253                idx2 = cell_names.index(row["cell2"])
3254            except ValueError:
3255                continue
3256            x1, y1 = coords[idx1]
3257            x2, y2 = coords[idx2]
3258            arrow = FancyArrowPatch(
3259                (x1, y1),
3260                (x2, y2),
3261                arrowstyle="->",
3262                color="gray",
3263                linewidth=1.5,
3264                alpha=0.5,
3265                mutation_scale=12,
3266                zorder=0,
3267            )
3268            ax.add_patch(arrow)
3269
3270        if set_info is False and " # " in cell_names[0]:
3271
3272            legend_elements = [
3273                Patch(
3274                    facecolor=color_mapping[name],
3275                    edgecolor="black",
3276                    label=f"{i}{re.sub(' #.*', '', name)}",
3277                )
3278                for i, name in enumerate(cell_names)
3279            ]
3280
3281        else:
3282
3283            legend_elements = [
3284                Patch(
3285                    facecolor=color_mapping[name],
3286                    edgecolor="black",
3287                    label=f"{i}{name}",
3288                )
3289                for i, name in enumerate(cell_names)
3290            ]
3291
3292        plt.legend(
3293            handles=legend_elements,
3294            title="Cells",
3295            bbox_to_anchor=(1.05, 1),
3296            loc="upper left",
3297            ncol=legend_split,
3298        )
3299
3300        plt.xlabel("UMAP 1")
3301        plt.ylabel("UMAP 2")
3302        plt.grid(False)
3303        plt.show()
3304
3305        return fig
3306
3307    # subclusters part
3308
3309    def subcluster_prepare(self, features: list, cluster: str):
3310        """
3311        Prepare a `Clustering` object for subcluster analysis on a selected parent cluster.
3312
3313        Parameters
3314        ----------
3315        features : list
3316            Features to include for subcluster analysis.
3317
3318        cluster : str
3319            Parent cluster name (used to select matching cells).
3320
3321        Update
3322        ------
3323        Initializes `self.subclusters_` as a new `Clustering` instance containing the
3324        reduced data for the given cluster and stores `current_features` and `current_cluster`.
3325        """
3326
3327        dat = self.normalized_data  # alias: the renaming below also applies before reduce_data
3328        dat.columns = list(self.input_metadata["cell_names"])
3329
3330        dat = reduce_data(self.normalized_data, features=features, names=[cluster])
3331
3332        self.subclusters_ = Clustering(data=dat, metadata=None)
3333
3334        self.subclusters_.current_features = features
3335        self.subclusters_.current_cluster = cluster
3336
3337    def define_subclusters(
3338        self,
3339        umap_num: int = 2,
3340        eps: float = 0.5,
3341        min_samples: int = 10,
3342        n_neighbors: int = 5,
3343        min_dist: float = 0.1,
3344        spread: float = 1.0,
3345        set_op_mix_ratio: float = 1.0,
3346        local_connectivity: int = 1,
3347        repulsion_strength: float = 1.0,
3348        negative_sample_rate: int = 5,
3349        width=8,
3350        height=6,
3351    ):
3352        """
3353        Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.
3354
3355        Parameters
3356        ----------
3357        umap_num : int, default 2
3358            Number of UMAP dimensions to compute.
3359
3360        eps : float, default 0.5
3361            DBSCAN eps parameter.
3362
3363        min_samples : int, default 10
3364            DBSCAN min_samples parameter.
3365
3366        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
3367            Additional parameters passed to UMAP and to the plotting helpers.
3368
3369        Update
3370        -------
3371        Stores cluster labels in `self.subclusters_.subclusters`.
3372
3373        Raises
3374        ------
3375        RuntimeError
3376            If `self.subclusters_` has not been prepared.
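
        Examples
        --------
        A hedged sketch ('cluster_1' and the gene list are hypothetical):

        >>> cl.subcluster_prepare(features=["GeneA", "GeneB"], cluster="cluster_1")
        >>> fig = cl.define_subclusters(umap_num=2, eps=0.5, min_samples=10)
        >>> cl.accept_subclusters()   # commit the labels into input_metadata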
3377        """
3378
3379        if self.subclusters_ is None:
3380            raise RuntimeError(
3381                "Nothing to return: 'self.subcluster_prepare' was not run!"
3382            )
3383
3384        self.subclusters_.perform_UMAP(
3385            factorize=False,
3386            umap_num=umap_num,
3387            pc_num=0,
3388            harmonized=False,
3389            n_neighbors=n_neighbors,
3390            min_dist=min_dist,
3391            spread=spread,
3392            set_op_mix_ratio=set_op_mix_ratio,
3393            local_connectivity=local_connectivity,
3394            repulsion_strength=repulsion_strength,
3395            negative_sample_rate=negative_sample_rate,
3396            width=width,
3397            height=height,
3398        )
3399
3400        fig = self.subclusters_.find_clusters_UMAP(
3401            umap_n=umap_num,
3402            eps=eps,
3403            min_samples=min_samples,
3404            width=width,
3405            height=height,
3406        )
3407
3408        clusters = self.subclusters_.return_clusters(clusters="umap")
3409
3410        self.subclusters_.subclusters = [str(x) for x in list(clusters)]
3411
3412        return fig
3413
3414    def subcluster_features_scatter(
3415        self,
3416        colors="viridis",
3417        hclust="complete",
3418        scale=False,
3419        img_width=3,
3420        img_high=5,
3421        label_size=6,
3422        size_scale=70,
3423        y_lab="Genes",
3424        legend_lab="normalized",
3425        bbox_to_anchor_scale: int = 25,
3426        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3427    ):
3428        """
3429        Create a features-scatter visualization for the subclusters (averaged and occurrence).
3430
3431        Parameters
3432        ----------
3433        colors : str, default 'viridis'
3434            Colormap name passed to `features_scatter`.
3435
3436        hclust : str or None
3437            Hierarchical clustering linkage to order rows/columns.
3438
3439        scale: bool, default False
3440            If True, expression data will be scaled (0–1) across the rows (features).
3441
3442        img_width, img_high : float
3443            Figure size.
3444
3445        label_size : int
3446            Font size for labels.
3447
3448        size_scale : int
3449            Bubble size scaling.
3450
3451        y_lab : str
3452            X axis label.
3453
3454        legend_lab : str
3455            Colorbar label.
3456
3457        bbox_to_anchor_scale : int, default=25
3458            Vertical scale (percentage) for positioning the colorbar.
3459
3460        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3461            Anchor position for the size legend (percent bubble legend).
3462
3463        Returns
3464        -------
3465        matplotlib.figure.Figure
3466
3467        Raises
3468        ------
3469        RuntimeError
3470            If subcluster preparation/definition has not been run.
3471        """
3472
3473        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3474            raise RuntimeError(
3475                "Nothing to return: the 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not run!"
3476            )
3477
3478        dat = self.normalized_data
3479        dat.columns = list(self.input_metadata["cell_names"])
3480
3481        dat = reduce_data(
3482            self.normalized_data,
3483            features=self.subclusters_.current_features,
3484            names=[self.subclusters_.current_cluster],
3485        )
3486
3487        dat.columns = self.subclusters_.subclusters
3488
3489        avg = average(dat)
3490        occ = occurrence(dat)
3491
3492        scatter = features_scatter(
3493            expression_data=avg,
3494            occurence_data=occ,
3495            features=None,
3496            scale=scale,
3497            metadata_list=None,
3498            colors=colors,
3499            hclust=hclust,
3500            img_width=img_width,
3501            img_high=img_high,
3502            label_size=label_size,
3503            size_scale=size_scale,
3504            y_lab=y_lab,
3505            legend_lab=legend_lab,
3506            bbox_to_anchor_scale=bbox_to_anchor_scale,
3507            bbox_to_anchor_perc=bbox_to_anchor_perc,
3508        )
3509
3510        return scatter
3511
3512    def subcluster_DEG_scatter(
3513        self,
3514        top_n=3,
3515        min_exp=0,
3516        min_pct=0.25,
3517        p_val=0.05,
3518        colors="viridis",
3519        hclust="complete",
3520        scale=False,
3521        img_width=3,
3522        img_high=5,
3523        label_size=6,
3524        size_scale=70,
3525        y_lab="Genes",
3526        legend_lab="normalized",
3527        bbox_to_anchor_scale: int = 25,
3528        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3529        n_proc=10,
3530    ):
3531        """
3532        Plot top differential features (DEGs) for subclusters as a features-scatter.
3533
3534        Parameters
3535        ----------
3536        top_n : int, default 3
3537            Number of top features per subcluster to show.
3538
3539        min_exp : float, default 0
3540            Minimum expression threshold passed to `statistic`.
3541
3542        min_pct : float, default 0.25
3543            Minimum percent expressed in target group.
3544
3545        p_val: float, default 0.05
3546            Maximum p-value for visualizing features.
3547
3548        n_proc : int, default 10
3549            Parallel jobs used for DEG calculation.
3550
3551        scale: bool, default False
3552            If True, expression_data will be scaled (0–1) across the rows (features).
3553
3554        colors : str, default='viridis'
3555            Colormap for expression values.
3556
3557        hclust : str or None, default='complete'
3558            Linkage method for hierarchical clustering. If None, no clustering
3559            is performed.
3560
3561        img_width : int or float, default=3
3562            Width of the plot in inches.
3563
3564        img_high : int or float, default=5
3565            Height of the plot in inches.
3566
3567        label_size : int, default=6
3568            Font size for axis labels and ticks.
3569
3570        size_scale : int or float, default=70
3571            Scaling factor for bubble sizes.
3572
3573        y_lab : str, default='Genes'
3574            Label for the x-axis.
3575
3576        legend_lab : str, default='normalized'
3577            Label for the colorbar legend.
3578
3579        bbox_to_anchor_scale : int, default=25
3580            Vertical scale (percentage) for positioning the colorbar.
3581
3582        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3583            Anchor position for the size legend (percent bubble legend).
3584
3585        Returns
3586        -------
3587        matplotlib.figure.Figure
3588
3589        Raises
3590        ------
3591        RuntimeError
3592            If subcluster preparation/definition has not been run.
3593
3594        Notes
3595        -----
3596        Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters
3597        by p-value and effect-size, selects top features per valid group and plots them.
3598        """
3599
3600        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3601            raise RuntimeError(
3602                "Nothing to return: the 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not run!"
3603            )
3604
3605        dat = self.normalized_data
3606        dat.columns = list(self.input_metadata["cell_names"])
3607
3608        dat = reduce_data(
3609            self.normalized_data, names=[self.subclusters_.current_cluster]
3610        )
3611
3612        dat.columns = self.subclusters_.subclusters
3613
3614        deg_stats = calc_DEG(
3615            dat,
3616            metadata_list=None,
3617            entities="All",
3618            sets=None,
3619            min_exp=min_exp,
3620            min_pct=min_pct,
3621            n_proc=n_proc,
3622        )
3623
3624        deg_stats = deg_stats[deg_stats["p_val"] <= p_val]
3625        deg_stats = deg_stats[deg_stats["log(FC)"] > 0]
3626
3627        deg_stats = (
3628            deg_stats.sort_values(
3629                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
3630            )
3631            .groupby("valid_group")
3632            .head(top_n)
3633        )
3634
3635        dat = reduce_data(dat, features=list(set(deg_stats["feature"])))
3636
3637        avg = average(dat)
3638        occ = occurrence(dat)
3639
3640        scatter = features_scatter(
3641            expression_data=avg,
3642            occurence_data=occ,
3643            features=None,
3644            metadata_list=None,
3645            colors=colors,
3646            hclust=hclust,
3647            img_width=img_width,
3648            img_high=img_high,
3649            label_size=label_size,
3650            size_scale=size_scale,
3651            y_lab=y_lab,
3652            legend_lab=legend_lab,
3653            bbox_to_anchor_scale=bbox_to_anchor_scale,
3654            bbox_to_anchor_perc=bbox_to_anchor_perc,
3655        )
3656
3657        return scatter
3658
3659    def accept_subclusters(self):
3660        """
3661        Commit subcluster labels into the main `input_metadata` by renaming cell names.
3662
3663        The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']`
3664        with the expanded names that include subcluster suffixes (via `add_subnames`),
3665        then clears `self.subclusters_`.
3666
3667        Update
3668        ------
3669        Modifies `self.input_metadata['cell_names']`.
3670
3671        Resets `self.subclusters_` to None.
3672
3673        Raises
3674        ------
3675        RuntimeError
3676            If `self.subclusters_` is not defined or subclusters were not computed.
3677        """
3678
3679        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3680            raise RuntimeError(
3681                "Nothing to return: the 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not run!"
3682            )
3683
3684        new_meta = add_subnames(
3685            list(self.input_metadata["cell_names"]),
3686            parent_name=self.subclusters_.current_cluster,
3687            new_clusters=self.subclusters_.subclusters,
3688        )
3689
3690        self.input_metadata["cell_names"] = new_meta
3691
3692        self.subclusters_ = None
3693
3694    def scatter_plot(
3695        self,
3696        names: list | None = None,
3697        features: list | None = None,
3698        name_slot: str = "cell_names",
3699        scale=True,
3700        colors="viridis",
3701        hclust=None,
3702        img_width=15,
3703        img_high=1,
3704        label_size=10,
3705        size_scale=200,
3706        y_lab="Genes",
3707        legend_lab="log(CPM + 1)",
3708        set_box_size: float | int = 5,
3709        set_box_high: float | int = 5,
3710        bbox_to_anchor_scale=25,
3711        bbox_to_anchor_perc=(0.90, -0.24),
3712        bbox_to_anchor_group=(1.01, 0.4),
3713    ):
3714        """
3715        Create a bubble scatter plot of selected features across samples inside project.
3716
3717        Each point represents a feature-sample pair, where the color encodes the
3718        expression value and the size encodes occurrence or relative abundance.
3719        Optionally, hierarchical clustering can be applied to order rows and columns.
3720
3721        Parameters
3722        ----------
3723        names : list, str, or None
3724            Names of samples to include. If None, all samples are considered.
3725
3726        features : list, str, or None
3727            Names of features to include. If None, all features are considered.
3728
3729        name_slot : str
3730            Column in metadata to use as sample names.
3731
3732        scale : bool, default True
3733            If True, expression_data will be scaled (0–1) across the rows (features).
3734
3735        colors : str, default='viridis'
3736            Colormap for expression values.
3737
3738        hclust : str or None, default=None
3739            Linkage method for hierarchical clustering. If None, no clustering
3740            is performed.
3741
3742        img_width : int or float, default=15
3743            Width of the plot in inches.
3744
3745        img_high : int or float, default=1
3746            Height of the plot in inches.
3747
3748        label_size : int, default=10
3749            Font size for axis labels and ticks.
3750
3751        size_scale : int or float, default=200
3752            Scaling factor for bubble sizes.
3753
3754        y_lab : str, default='Genes'
3755            Label for the y-axis.
3756
3757        legend_lab : str, default='log(CPM + 1)'
3758            Label for the colorbar legend.
3759
3760        bbox_to_anchor_scale : int, default=25
3761            Vertical scale (percentage) for positioning the colorbar.
3762
3763        bbox_to_anchor_perc : tuple, default=(0.90, -0.24)
3764            Anchor position for the size legend (percent bubble legend).
3765
3766        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
3767            Anchor position for the group legend.
3768
3769        Returns
3770        -------
3771        matplotlib.figure.Figure
3772            The generated scatter plot figure.
3773
3774        Notes
3775        -----
3776        Colors represent expression values normalized to the colormap.
3777        """
3778
3779        prtd, met = self.get_partial_data(
3780            names=names, features=features, name_slot=name_slot, inc_metadata=True
3781        )
3782
3783        if scale:
3784
3785            legend_lab = "Scaled\n" + legend_lab
3786
3787            scaler = MinMaxScaler(feature_range=(0, 1))
3788            prtd = pd.DataFrame(
3789                scaler.fit_transform(prtd.T).T,
3790                index=prtd.index,
3791                columns=prtd.columns,
3792            )
3793
3794        prtd.columns = prtd.columns + "#" + met["sets"]
3795
3796        prtd_avg = average(prtd)
3797
3798        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]
3799
3800        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]
3801
3802        prtd_occ = occurrence(prtd)
3803
3804        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]
3805
3806        fig_scatter = features_scatter(
3807            expression_data=prtd_avg,
3808            occurence_data=prtd_occ,
3809            scale=scale,
3810            features=None,
3811            metadata_list=meta_sets,
3812            colors=colors,
3813            hclust=hclust,
3814            img_width=img_width,
3815            img_high=img_high,
3816            label_size=label_size,
3817            size_scale=size_scale,
3818            y_lab=y_lab,
3819            legend_lab=legend_lab,
3820            set_box_size=set_box_size,
3821            set_box_high=set_box_high,
3822            bbox_to_anchor_scale=bbox_to_anchor_scale,
3823            bbox_to_anchor_perc=bbox_to_anchor_perc,
3824            bbox_to_anchor_group=bbox_to_anchor_group,
3825        )
3826
3827        return fig_scatter
3828
3829    def data_composition(
3830        self,
3831        features_count: list | None,
3832        name_slot: str = "cell_names",
3833        set_sep: bool = True,
3834    ):
3835        """
3836        Compute the composition of cell types in the dataset.
3837
3838        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
3839        within metadata entries, calculates their relative percentages, and stores
3840        the results in `self.composition_data`.
3841
3842        Parameters
3843        ----------
3844        features_count : list or None
3845            List of features (part or full names) to be counted.
3846            If None, all unique elements from the specified `name_slot` metadata field are used.
3847
3848        name_slot : str, default 'cell_names'
3849            Metadata field containing sample identifiers or labels.
3850
3851        set_sep : bool, default True
3852            If True and multiple sets are present in metadata, compute composition
3853            separately for each set.
3854
3855        Update
3856        -------
3857        Stores results in `self.composition_data` as a pandas DataFrame with:
3858        - 'name': feature name
3859        - 'n': number of occurrences
3860        - 'pct': percentage of occurrences
3861        - 'set' (if applicable): dataset identifier
3862        """
3863
3864        validated_list = list(self.input_metadata[name_slot])
3865        sets = list(self.input_metadata["sets"])
3866
3867        if features_count is None:
3868            features_count = list(set(self.input_metadata[name_slot]))
3869
3870        if set_sep and len(set(sets)) > 1:
3871
3872            final_res = pd.DataFrame()
3873
3874            for s in set(sets):
3875                print(s)
3876
3877                mask = [s == x for x in sets]
3878
3879                tmp_val_list = np.array(validated_list)
3880
3881                tmp_val_list = list(tmp_val_list[mask])
3882
3883                res_dict = {"name": [], "n": [], "set": []}
3884
3885                for f in tqdm(features_count):
3886                    res_dict["n"].append(
3887                        sum(1 for element in tmp_val_list if f in element)
3888                    )
3889                    res_dict["name"].append(f)
3890                    res_dict["set"].append(s)
3891
3892                res = pd.DataFrame(res_dict)
3893                res["pct"] = res["n"] / sum(res["n"]) * 100
3894                res["pct"] = res["pct"].round(2)
3895                final_res = pd.concat([final_res, res])
3896
3897            res = final_res.sort_values(["set", "pct"], ascending=[True, False])
3898
3899        else:
3900
3901            res_dict = {"name": [], "n": []}
3902
3903            for f in tqdm(features_count):
3904                res_dict["n"].append(
3905                    sum(1 for element in validated_list if f in element)
3906                )
3907                res_dict["name"].append(f)
3908
3909            res = pd.DataFrame(res_dict)
3910            res["pct"] = res["n"] / sum(res["n"]) * 100
3911            res["pct"] = res["pct"].round(2)
3912
3913            res = res.sort_values("pct", ascending=False)
3914
3915        self.composition_data = res
3916
3917    def composition_pie(
3918        self,
3919        width=6,
3920        height=6,
3921        font_size=15,
3922        cmap: str = "tab20",
3923        legend_split_col: int = 1,
3924        offset_labels: float | int = 0.5,
3925        legend_bbox: tuple = (1.15, 0.95),
3926    ):
3927        """
3928        Visualize the composition of cell lineages using pie charts.
3929
3930        Generates pie charts showing the relative proportions of features stored
3931        in `self.composition_data`. If multiple sets are present, a separate
3932        chart is drawn for each set.
3933
3934        Parameters
3935        ----------
3936        width : int, default 6
3937            Width of the figure.
3938
3939        height : int, default 6
3940            Height of the figure (applied per set if multiple sets are plotted).
3941
3942        font_size : int, default 15
3943            Font size for labels and annotations.
3944
3945        cmap : str, default 'tab20'
3946            Colormap used for pie slices.
3947
3948        legend_split_col : int, default 1
3949            Number of columns in the legend.
3950
3951        offset_labels : float or int, default 0.5
3952            Spacing offset for label placement relative to pie slices.
3953
3954        legend_bbox : tuple, default (1.15, 0.95)
3955            Bounding box anchor position for the legend.
3956
3957        Returns
3958        -------
3959        matplotlib.figure.Figure
3960            Pie chart visualization of composition data.
3961        """
3962
3963        df = self.composition_data
3964
3965        if "set" in df.columns and len(set(df["set"])) > 1:
3966
3967            sets = list(set(df["set"]))
3968            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))
3969
3970            all_wedges = []
3971            cmap = plt.get_cmap(cmap)
3972
3973            set_nam = len(set(df["name"]))
3974
3975            legend_labels = list(set(df["name"]))
3976
3977            colors = [cmap(i / set_nam) for i in range(set_nam)]
3978
3979            cmap_dict = dict(zip(legend_labels, colors))
3980
3981            for idx, s in enumerate(sets):
3982                ax = axes[idx]
3983                tmp_df = df[df["set"] == s].reset_index(drop=True)
3984
3985                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]
3986
3987                wedges, _ = ax.pie(
3988                    tmp_df["n"],
3989                    startangle=90,
3990                    labeldistance=1.05,
3991                    colors=[cmap_dict[x] for x in tmp_df["name"]],
3992                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
3993                )
3994
3995                all_wedges.extend(wedges)
3996
3997                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
3998                n = 0
3999                for i, p in enumerate(wedges):
4000                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
4001                    y = np.sin(np.deg2rad(ang))
4002                    x = np.cos(np.deg2rad(ang))
4003                    horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
4004                    connectionstyle = f"angle,angleA=0,angleB={ang}"
4005                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
4006                    if len(labels[i]) > 0:
4007                        n += offset_labels
4008                        ax.annotate(
4009                            labels[i],
4010                            xy=(x, y),
4011                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
4012                            horizontalalignment=horizontalalignment,
4013                            fontsize=font_size,
4014                            weight="bold",
4015                            **kw,
4016                        )
4017
4018                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
4019                ax.add_artist(circle2)
4020
4021                ax.text(
4022                    0,
4023                    0,
4024                    f"{s}",
4025                    ha="center",
4026                    va="center",
4027                    fontsize=font_size,
4028                    weight="bold",
4029                )
4030
4031            legend_handles = [
4032                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4033                for label in legend_labels
4034            ]
4035
4036            fig.legend(
4037                handles=legend_handles,
4038                loc="center right",
4039                bbox_to_anchor=legend_bbox,
4040                ncol=legend_split_col,
4041                title="",
4042            )
4043
4044            plt.tight_layout()
4045            plt.show()
4046
4047        else:
4048
4049            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]
4050
4051            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]
4052
4053            cmap = plt.get_cmap(cmap)
4054            colors = [cmap(i / len(df)) for i in range(len(df))]
4055
4056            fig, ax = plt.subplots(
4057                figsize=(width, height), subplot_kw=dict(aspect="equal")
4058            )
4059
4060            wedges, _ = ax.pie(
4061                df["n"],
4062                startangle=90,
4063                labeldistance=1.05,
4064                colors=colors,
4065                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
4066            )
4067
4068            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
4069            n = 0
4070            for i, p in enumerate(wedges):
4071                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
4072                y = np.sin(np.deg2rad(ang))
4073                x = np.cos(np.deg2rad(ang))
4074                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
4075                connectionstyle = "angle,angleA=0,angleB={}".format(ang)
4076                kw["arrowprops"].update({"connectionstyle": connectionstyle})
4077                if len(labels[i]) > 0:
4078                    n += offset_labels
4079
4080                    ax.annotate(
4081                        labels[i],
4082                        xy=(x, y),
4083                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
4084                        horizontalalignment=horizontalalignment,
4085                        fontsize=font_size,
4086                        weight="bold",
4087                        **kw,
4088                    )
4089
4090            circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
4091            ax.add_artist(circle2)
4095
4096            ax.legend(
4097                wedges,
4098                legend_labels,
4099                title="",
4100                loc="center left",
4101                bbox_to_anchor=legend_bbox,
4102                ncol=legend_split_col,
4103            )
4104
4105            plt.show()
4106
4107        return fig
4108
4109    def bar_composition(
4110        self,
4111        cmap="tab20b",
4112        width=2,
4113        height=6,
4114        font_size=15,
4115        legend_split_col: int = 1,
4116        legend_bbox: tuple = (1.3, 1),
4117    ):
4118        """
4119        Visualize the composition of cell lineages using bar plots.
4120
4121        Produces bar plots showing the distribution of features stored in
4122        `self.composition_data`. If multiple sets are present, a separate
4123        bar is drawn for each set. Percentages are annotated alongside the bars.
4124
4125        Parameters
4126        ----------
4127        cmap : str, default 'tab20b'
4128            Colormap used for stacked bars.
4129
4130        width : int, default 2
4131            Width of each subplot (per set).
4132
4133        height : int, default 6
4134            Height of the figure.
4135
4136        font_size : int, default 15
4137            Font size for labels and annotations.
4138
4139        legend_split_col : int, default 1
4140            Number of columns in the legend.
4141
4142        legend_bbox : tuple, default (1.3, 1)
4143            Bounding box anchor position for the legend.
4144
4145        Returns
4146        -------
4147        matplotlib.figure.Figure
4148            Stacked bar plot visualization of composition data.
4149        """
4150
4151        df = self.composition_data
4152        df["num"] = range(1, len(df) + 1)
4153
4154        if "set" in df.columns and len(set(df["set"])) > 1:
4155
4156            sets = list(set(df["set"]))
4157            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))
4158
4159            cmap = plt.get_cmap(cmap)
4160
4161            set_nam = len(set(df["name"]))
4162
4163            legend_labels = list(set(df["name"]))
4164
4165            colors = [cmap(i / set_nam) for i in range(set_nam)]
4166
4167            cmap_dict = dict(zip(legend_labels, colors))
4168
4169            for idx, s in enumerate(sets):
4170                ax = axes[idx]
4171
4172                tmp_df = df[df["set"] == s].reset_index(drop=True)
4173
4174                values = tmp_df["n"].values
4175                total = sum(values)
4176                values = [v / total * 100 for v in values]
4177                values = [round(v, 2) for v in values]
4178
4179                idx_max = np.argmax(values)
4180                correction = 100 - sum(values)
4181                values[idx_max] += correction
4182
4183                names = tmp_df["name"].values
4184                perc = tmp_df["pct"].values
4185                nums = tmp_df["num"].values
4186
4187                bottom = 0
4188                centers = []
4189                for name, num, val in zip(names, nums, values):
4190                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
4191                    centers.append(bottom + val / 2)
4192                    bottom += val
4193
4194                y_positions = np.linspace(centers[0], centers[-1], len(centers))
4195                x_text = -0.8
4196
4197                for y_label, y_center, pct, num in zip(
4198                    y_positions, centers, perc, nums
4199                ):
4200                    ax.annotate(
4201                        f"{pct:.1f}%",
4202                        xy=(0, y_center),
4203                        xycoords="data",
4204                        xytext=(x_text, y_label),
4205                        textcoords="data",
4206                        ha="right",
4207                        va="center",
4208                        fontsize=font_size,
4209                        arrowprops=dict(
4210                            arrowstyle="->",
4211                            lw=1,
4212                            color="black",
4213                            connectionstyle="angle3,angleA=0,angleB=90",
4214                        ),
4215                    )
4216
4217                ax.set_ylim(0, 100)
4218                ax.set_xlabel(s, fontsize=font_size)
4219                ax.xaxis.label.set_rotation(30)
4220
4221                ax.set_xticks([])
4222                ax.set_yticks([])
4223                for spine in ax.spines.values():
4224                    spine.set_visible(False)
4225
4226            legend_handles = [
4227                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4228                for label in legend_labels
4229            ]
4230
4231            fig.legend(
4232                handles=legend_handles,
4233                loc="center right",
4234                bbox_to_anchor=legend_bbox,
4235                ncol=legend_split_col,
4236                title="",
4237            )
4238
4239            plt.tight_layout()
4240            plt.show()
4241
4242        else:
4243
4244            cmap = plt.get_cmap(cmap)
4245
4246            colors = [cmap(i / len(df)) for i in range(len(df))]
4247
4248            fig, ax = plt.subplots(figsize=(width, height))
4249
4250            values = df["n"].values
4251            names = df["name"].values
4252            perc = df["pct"].values
4253            nums = df["num"].values
4254
4255            bottom = 0
4256            centers = []
4257            for name, num, val, color in zip(names, nums, values, colors):
4258                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
4259                centers.append(bottom + val / 2)
4260                bottom += val
4261
4262            y_positions = np.linspace(centers[0], centers[-1], len(centers))
4263            x_text = -0.8
4264
4265            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
4266                ax.annotate(
4267                    f"{num}) {pct:.1f}%",
4268                    xy=(0, y_center),
4269                    xycoords="data",
4270                    xytext=(x_text, y_label),
4271                    textcoords="data",
4272                    ha="right",
4273                    va="center",
4274                    fontsize=9,
4275                    arrowprops=dict(
4276                        arrowstyle="->",
4277                        lw=1,
4278                        color="black",
4279                        connectionstyle="angle3,angleA=0,angleB=90",
4280                    ),
4281                )
4282
4283            ax.set_xticks([])
4284            ax.set_yticks([])
4285            for spine in ax.spines.values():
4286                spine.set_visible(False)
4287
4288            ax.legend(
4289                title="Legend",
4290                bbox_to_anchor=legend_bbox,
4291                loc="upper left",
4292                ncol=legend_split_col,
4293            )
4294
4295            plt.tight_layout()
4296            plt.show()
4297
4298        return fig
4299
4300    def cell_regression(
4301        self,
4302        cell_x: str,
4303        cell_y: str,
4304        set_x: str | None,
4305        set_y: str | None,
4306        threshold=10,
4307        image_width=12,
4308        image_high=7,
4309        color="black",
4310    ):
4311        """
4312        Perform regression analysis between two selected cells and visualize the relationship.
4313
4314        This function computes a linear regression between two specified cells from
4315        aggregated normalized data, plots the regression line with scatter points,
4316        annotates regression statistics, and highlights potential outliers.
4317
4318        Parameters
4319        ----------
4320        cell_x : str
4321            Name of the first cell (X-axis).
4322
4323        cell_y : str
4324            Name of the second cell (Y-axis).
4325
4326        set_x : str or None
4327            Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name.
4328
4329        set_y : str or None
4330            Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name.
4331
4332        threshold : int or float, default 10
4333            Threshold for detecting outliers. Points deviating from the mean or diagonal by more
4334            than this value are annotated.
4335
4336        image_width : int, default 12
4337            Width of the regression plot (in inches).
4338
4339        image_high : int, default 7
4340            Height of the regression plot (in inches).
4341
4342        color : str, default 'black'
4343            Color of the regression scatter points and line.
4344
4345        Returns
4346        -------
4347        matplotlib.figure.Figure
4348            Regression plot figure with annotated regression line, R², p-value, and outliers.
4349
4350        Raises
4351        ------
4352        ValueError
4353            If `cell_x` or `cell_y` are not found in the dataset.
4354            If multiple matches are found for a cell name and `set_x`/`set_y` are not specified.
4355
4356        Notes
4357        -----
4358        - The function automatically calls `jseq_object.average()` if aggregated data is not available.
4359        - Outliers are annotated with their corresponding index labels.
4360        - Regression is computed using `scipy.stats.linregress`.
4361
4362        Examples
4363        --------
4364        >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
4365        >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")
4366        """
4367
4368        if self.agg_normalized_data is None:
4369            self.average()
4370
4371        metadata = self.agg_metadata
4372        data = self.agg_normalized_data
4373
4374        if set_x is not None and set_y is not None:
4375            data.columns = metadata["cell_names"] + " # " + metadata["sets"]
4376            cell_x = cell_x + " # " + set_x
4377            cell_y = cell_y + " # " + set_y
4378
4379        else:
4380            data.columns = metadata["cell_names"]
4381
4382        if cell_x not in data.columns:
4383            raise ValueError("'cell_x' value not in cell names!")
4384
4385        if cell_y not in data.columns:
4386            raise ValueError("'cell_y' value not in cell names!")
4387
4388        if list(data.columns).count(cell_x) > 1:
4389            raise ValueError(
4390                f"'{cell_x}' occurs more than once. If you want to select a specific cell, "
4391                f"please also provide the corresponding 'set_x' and 'set_y' values."
4392            )
4393
4394        if list(data.columns).count(cell_y) > 1:
4395            raise ValueError(
4396                f"'{cell_y}' occurs more than once. If you want to select a specific cell, "
4397                f"please also provide the corresponding 'set_x' and 'set_y' values."
4398            )
4399
4400        fig, ax = plt.subplots(figsize=(image_width, image_high))
4401        ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color)
4402
4403        slope, intercept, r_value, p_value, _ = stats.linregress(
4404            data[cell_x], data[cell_y]
4405        )
4406        equation = "y = {:.2f}x + {:.2f}".format(slope, intercept)
4407
4408        ax.annotate(
4409            "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format(
4410                r_value**2, p_value, equation
4411            ),
4412            xy=(0.05, 0.90),
4413            xycoords="axes fraction",
4414            fontsize=12,
4415        )
4416
4417        ax.spines["top"].set_visible(False)
4418        ax.spines["right"].set_visible(False)
4419
4425
4426        def annotate_outliers(x, y, threshold):
4427            texts = []
4428            x_mean, y_mean = x.mean(), y.mean()
4429            for i, (xi, yi) in enumerate(zip(x, y)):
4430                if (
4431                    abs(xi - x_mean) > threshold
4432                    or abs(yi - y_mean) > threshold
4433                    or abs(yi - xi) > threshold
4434                ):
4435                    text = ax.text(xi, yi, data.index[i])
4436                    texts.append(text)
4437
4438            return texts
4439
4440        texts = annotate_outliers(data[cell_x], data[cell_y], threshold)
4441
4442        adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))
4443
4444        plt.show()
4445
4446        return fig
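
The composition methods above are typically used as a chain: data_composition fills self.composition_data, which composition_pie and bar_composition then visualize. A minimal, hedged sketch (parameter values are illustrative, not prescriptive):

>>> cmp.data_composition(features_count=None, name_slot='cell_names', set_sep=True)
>>> fig_pie = cmp.composition_pie(cmap='tab20', legend_split_col=2)
>>> fig_bar = cmp.bar_composition(cmap='tab20b', width=2, height=6)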

COMPsc (Comparison of single-cell data) is a class designed for the integration, analysis, and visualization of single-cell datasets. The class supports independent dataset integration, subclustering of existing clusters, marker detection, and multiple visualization strategies.

The COMPsc class provides methods for:

- Normalizing and filtering single-cell data
- Loading and saving sparse 10x-style datasets
- Computing differential expression and marker genes
- Clustering and subclustering analysis
- Visualizing similarity and spatial relationships
- Aggregating data by cell and set annotations
- Managing metadata and renaming labels
- Plotting gene detection histograms and feature scatters

Methods

project_dir(path_to_directory, project_list) Scans a directory to create a COMPsc instance mapping project names to their paths.

save_project(name, path=os.getcwd()) Saves the COMPsc object to a pickle file on disk.

load_project(path) Loads a previously saved COMPsc object from a pickle file.

reduce_cols(reg=None, full=None, name_slot='cell_names', inc_set=False) Removes columns (cells) from data tables whose names contain the substring reg or exactly match full.

reduce_rows(features_list) Removes rows (features) from data tables whose names match entries in features_list.

get_data(set_info=False) Returns normalized data with optional set annotations in column names.

get_metadata() Returns the stored input metadata.

get_partial_data(names=None, features=None, name_slot='cell_names', inc_metadata=False) Returns a subset of the data by sample names and/or features.

gene_calculation() Calculates and stores per-cell gene detection counts as a pandas Series.

gene_histograme(bins=100) Plots a histogram of genes detected per cell with an overlaid normal distribution.

gene_threshold(min_n=None, max_n=None) Filters cells based on minimum and/or maximum gene detection thresholds.

load_sparse_from_projects(normalized_data=False) Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.

rename_names(mapping, slot='cell_names') Renames entries in a specified metadata column using a provided mapping dictionary.

rename_subclusters(mapping) Renames subcluster labels using a provided mapping dictionary.

save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized') Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

normalize_data(normalize=True, normalize_factor=100000) Normalizes raw counts to counts-per-specified factor (e.g., CPM-like).

statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10) Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.

calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False) Computes and caches differential markers using the statistic method.

clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4) Prepares clustering input by selecting marker features and optionally smoothing cell values.

average() Aggregates normalized data by averaging across (cell_name, set) pairs.

estimating_similarity(method='pearson', p_val=0.05, top_n=25) Computes pairwise correlation and Euclidean distance between aggregated samples.

similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10) Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.

spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...) Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.

subcluster_prepare(features, cluster) Initializes a Clustering object for subcluster analysis on a selected parent cluster.

define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...) Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels.

subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...) Visualizes averaged expression and occurrence of features for subclusters as a scatter plot.

subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...) Plots top differential features for subclusters as a features-scatter visualization.

accept_subclusters() Commits subcluster labels to main metadata by renaming cell names and clears subcluster data.

Raises

ValueError For invalid parameters, mismatched dimensions, or missing metadata.
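
For orientation, a hedged end-to-end sketch using the methods listed above (paths, project names, and parameter values are hypothetical):

>>> cmp = COMPsc.project_dir('/data/sc', ['exp1_count', 'exp2_count'])
>>> cmp.load_sparse_from_projects()
>>> cmp.normalize_data(normalize_factor=100000)
>>> cmp.calculate_difference_markers(min_exp=0, min_pct=0.25)
>>> cmp.average()
>>> cmp.estimating_similarity(method='pearson', p_val=0.05, top_n=25)
>>> fig = cmp.similarity_plot(split_sets=True)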

COMPsc(objects=None)
1146    def __init__(
1147        self,
1148        objects=None,
1149    ):
1150        """
1151        Initialize the COMPsc class for single-cell data integration and analysis.
1152
1153        Parameters
1154        ----------
1155        objects : list or None, optional
1156            Optional list of data objects to initialize the instance with.
1157
1158        Attributes
1159        ----------
1160        objects : list or None
1161        input_data : pandas.DataFrame or None
1162        input_metadata : pandas.DataFrame or None
1163        normalized_data : pandas.DataFrame or None
1164        agg_metadata : pandas.DataFrame or None
1165        agg_normalized_data : pandas.DataFrame or None
1166        similarity : pandas.DataFrame or None
1167        var_data : pandas.DataFrame or None
1168        subclusters_ : instance of Clustering class or None
1169        cells_calc : pandas.Series or None
1170        gene_calc : pandas.Series or None
1171        composition_data : pandas.DataFrame or None
1172        """
1173
1174        self.objects = objects
1175        """ Stores the input data objects."""
1176
1177        self.input_data = None
1178        """Raw input data for clustering or integration analysis."""
1179
1180        self.input_metadata = None
1181        """Metadata associated with the input data."""
1182
1183        self.normalized_data = None
1184        """Normalized version of the input data."""
1185
1186        self.agg_metadata = None
1187        '''Aggregated metadata for all sets in object related to "agg_normalized_data"'''
1188
1189        self.agg_normalized_data = None
1190        """Aggregated and normalized data across multiple sets."""
1191
1192        self.similarity = None
1193        """Similarity data between cells across all samples and sets."""
1194
1195        self.var_data = None
1196        """DEG analysis results summarizing variance across all samples in the object."""
1197
1198        self.subclusters_ = None
1199        """Placeholder for subcluster analysis results, if computed."""
1200
1201        self.cells_calc = None
1202        """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition."""
1203
1204        self.gene_calc = None
1205        """Number of genes detected per sample (cell), reflecting the sequencing depth."""
1206
1207        self.composition_data = None
1208        """Data describing composition of cells across clusters or sets."""

Initialize the COMPsc class for single-cell data integration and analysis.

Parameters

objects : list or None, optional Optional list of data objects to initialize the instance with.

Attributes

objects : list or None
input_data : pandas.DataFrame or None
input_metadata : pandas.DataFrame or None
normalized_data : pandas.DataFrame or None
agg_metadata : pandas.DataFrame or None
agg_normalized_data : pandas.DataFrame or None
similarity : pandas.DataFrame or None
var_data : pandas.DataFrame or None
subclusters_ : instance of Clustering class or None
cells_calc : pandas.Series or None
gene_calc : pandas.Series or None
composition_data : pandas.DataFrame or None

objects

Stores the input data objects.

input_data

Raw input data for clustering or integration analysis.

input_metadata

Metadata associated with the input data.

normalized_data

Normalized version of the input data.

agg_metadata

Aggregated metadata for all sets in object related to "agg_normalized_data"

agg_normalized_data

Aggregated and normalized data across multiple sets.

similarity

Similarity data between cells across all samples and sets.

var_data

DEG analysis results summarizing variance across all samples in the object.

subclusters_

Placeholder for subcluster analysis results, if computed.

cells_calc

Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition.

gene_calc

Number of genes detected per sample (cell), reflecting the sequencing depth.

composition_data

Data describing composition of cells across clusters or sets.
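
A COMPsc instance can also be built directly from a mapping of set names to project paths, as produced by project_dir (names and paths below are hypothetical):

>>> cmp = COMPsc(objects={'cortex': '/data/sc/mouse_cortex_count'})
>>> cmp.normalized_data is None
True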

@classmethod
def project_dir(cls, path_to_directory, project_list):
1210    @classmethod
1211    def project_dir(cls, path_to_directory, project_list):
1212        """
1213        Scan a directory and build a COMPsc instance mapping provided project names
1214        to their paths.
1215
1216        Parameters
1217        ----------
1218        path_to_directory : str
1219            Path containing project subfolders.
1220
1221        project_list : list[str]
1222            List of filenames (folder names) to include in the returned object map.
1223
1224        Returns
1225        -------
1226        COMPsc
1227            New COMPsc instance with `objects` populated.
1228
1229        Raises
1230        ------
1231        Exception
1232            A generic exception is caught and a message printed if scanning fails.
1233
1234        Notes
1235        -----
1236        Function attempts to match entries in `project_list` to directory
1237        names and constructs a simplified object key from the folder name.
1238        """
1239        try:
1240            objects = {}
1241            for filename in tqdm(os.listdir(path_to_directory)):
1242                for c in project_list:
1243                    f = os.path.join(path_to_directory, filename)
1244                    if c == filename and os.path.isdir(f):
1245                        objects[
1246                            str(
1247                                re.sub(
1248                                    "_count",
1249                                    "",
1250                                    re.sub(
1251                                        "_fq",
1252                                        "",
1253                                        re.sub(
1254                                            re.sub("_.*", "", filename) + "_",
1255                                            "",
1256                                            filename,
1257                                        ),
1258                                    ),
1259                                )
1260                            )
1261                        ] = f
1262
1263            return cls(objects)
1264
1265        except Exception as e:
1266            print(f"Something went wrong ({e}). Check the function input data and try again!")

Scan a directory and build a COMPsc instance mapping provided project names to their paths.

Parameters

path_to_directory : str Path containing project subfolders.

project_list : list[str] List of filenames (folder names) to include in the returned object map.

Returns

COMPsc New COMPsc instance with objects populated.

Raises

Exception A generic exception is caught and a message printed if scanning fails.

Notes

Function attempts to match entries in project_list to directory names and constructs a simplified object key from the folder name.
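
Assuming folders follow a `<prefix>_<name>_count` naming scheme handled by the substitutions above, a hypothetical call and the resulting map would look like:

>>> cmp = COMPsc.project_dir('/data/sc', ['mouse_cortex_count', 'mouse_cereb_count'])
>>> cmp.objects
{'cortex': '/data/sc/mouse_cortex_count', 'cereb': '/data/sc/mouse_cereb_count'}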

def save_project(self, name, path: str = os.getcwd()):
1268    def save_project(self, name, path: str = os.getcwd()):
1269        """
1270        Save the COMPsc object to disk using pickle.
1271
1272        Parameters
1273        ----------
1274        name : str
1275            Base filename (without extension) to use when saving.
1276
1277        path : str, default os.getcwd()
1278            Directory in which to save the project file.
1279
1280        Returns
1281        -------
1282        None
1283
1284        Side Effects
1285        ------------
1286        - Writes a file `<path>/<name>.jpkl` containing the pickled object.
1287        - Prints a confirmation message with saved path.
1288        """
1289
1290        full = os.path.join(path, f"{name}.jpkl")
1291
1292        with open(full, "wb") as f:
1293            pickle.dump(self, f)
1294
1295        print(f"Project saved as {full}")

Save the COMPsc object to disk using pickle.

Parameters

name : str Base filename (without extension) to use when saving.

path : str, default os.getcwd() Directory in which to save the project file.

Returns

None

Side Effects

  • Writes a file <path>/<name>.jpkl containing the pickled object.
  • Prints a confirmation message with saved path.
@classmethod
def load_project(cls, path):
1297    @classmethod
1298    def load_project(cls, path):
1299        """
1300        Load a previously saved COMPsc project from a pickle file.
1301
1302        Parameters
1303        ----------
1304        path : str
1305            Full path to the pickled project file.
1306
1307        Returns
1308        -------
1309        COMPsc
1310            The unpickled COMPsc object.
1311
1312        Raises
1313        ------
1314        FileNotFoundError
1315            If the provided path does not exist.
1316        """
1317
1318        if not os.path.exists(path):
1319            raise FileNotFoundError("File does not exist!")
1320        with open(path, "rb") as f:
1321            obj = pickle.load(f)
1322        return obj

Load a previously saved COMPsc project from a pickle file.

Parameters

path : str Full path to the pickled project file.

Returns

COMPsc The unpickled COMPsc object.

Raises

FileNotFoundError If the provided path does not exist.
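
A hedged save/load round trip (the target directory is an assumption):

>>> cmp.save_project('cortex_vs_cereb', path='/tmp')
Project saved as /tmp/cortex_vs_cereb.jpkl
>>> cmp = COMPsc.load_project('/tmp/cortex_vs_cereb.jpkl')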

def reduce_cols( self, reg: str | None = None, full: str | None = None, name_slot: str = 'cell_names', inc_set: bool = False):
1324    def reduce_cols(
1325        self,
1326        reg: str | None = None,
1327        full: str | None = None,
1328        name_slot: str = "cell_names",
1329        inc_set: bool = False,
1330    ):
1331        """
1332        Remove columns (cells) whose names contain a substring `reg` or
1333        full name `full` from available tables.
1334
1335        Parameters
1336        ----------
1337        reg : str | None
1338            Substring to search for in column/cell names; matching columns will be removed.
1339            If not None, `full` must be None.
1340
1341        full : str | None
1342            Full name to search for in column/cell names; matching columns will be removed.
1343            If not None, `reg` must be None.
1344
1345        name_slot : str, default 'cell_names'
1346            Column in metadata to use as sample names.
1347
1348        inc_set : bool, default False
1349            If True, column names are interpreted as 'cell_name # set' when matching.
1350
1351        Update
1352        ------
1353        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
1354        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
1355        removing columns (and corresponding metadata rows) that match `reg` or `full`.
1356
1357        Raises
1358        ------
1359        ValueError, if nothing matches the reduction mask.
1360        """
1361
1362        if reg is None and full is None:
1363            raise ValueError(
1364                "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!"
1365            )
1366
1367        if reg is not None and full is not None:
1368            raise ValueError(
1369                "Both 'reg' and 'full' arguments are provided. "
1370                "Please provide only one of them!\n"
1371                "'reg' is used when only part of the name must be detected.\n"
1372                "'full' is used if the full name must be detected."
1373            )
1374
1375        if reg is not None:
1376
1377            if self.input_data is not None:
1378
1379                if inc_set:
1380
1381                    self.input_data.columns = (
1382                        self.input_metadata[name_slot]
1383                        + " # "
1384                        + self.input_metadata["sets"]
1385                    )
1386
1387                else:
1388
1389                    self.input_data.columns = self.input_metadata[name_slot]
1390
1391                mask = [reg.upper() not in x.upper() for x in self.input_data.columns]
1392
1393                if len([y for y in mask if y is False]) == 0:
1394                    raise ValueError("Nothing found to reduce")
1395
1396                self.input_data = self.input_data.loc[:, mask]
1397
1398            if self.normalized_data is not None:
1399
1400                if inc_set:
1401
1402                    self.normalized_data.columns = (
1403                        self.input_metadata[name_slot]
1404                        + " # "
1405                        + self.input_metadata["sets"]
1406                    )
1407
1408                else:
1409
1410                    self.normalized_data.columns = self.input_metadata[name_slot]
1411
1412                mask = [
1413                    reg.upper() not in x.upper() for x in self.normalized_data.columns
1414                ]
1415
1416                if len([y for y in mask if y is False]) == 0:
1417                    raise ValueError("Nothing found to reduce")
1418
1419                self.normalized_data = self.normalized_data.loc[:, mask]
1420
1421            if self.input_metadata is not None:
1422
1423                if inc_set:
1424
1425                    self.input_metadata["drop"] = (
1426                        self.input_metadata[name_slot]
1427                        + " # "
1428                        + self.input_metadata["sets"]
1429                    )
1430
1431                else:
1432
1433                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1434
1435                mask = [
1436                    reg.upper() not in x.upper() for x in self.input_metadata["drop"]
1437                ]
1438
1439                if len([y for y in mask if y is False]) == 0:
1440                    raise ValueError("Nothing found to reduce")
1441
1442                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1443                    drop=True
1444                )
1445
1446                self.input_metadata = self.input_metadata.drop(
1447                    columns=["drop"], errors="ignore"
1448                )
1449
1450            if self.agg_normalized_data is not None:
1451
1452                if inc_set:
1453
1454                    self.agg_normalized_data.columns = (
1455                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1456                    )
1457
1458                else:
1459
1460                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1461
1462                mask = [
1463                    reg.upper() not in x.upper()
1464                    for x in self.agg_normalized_data.columns
1465                ]
1466
1467                if len([y for y in mask if y is False]) == 0:
1468                    raise ValueError("Nothing found to reduce")
1469
1470                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1471
1472            if self.agg_metadata is not None:
1473
1474                if inc_set:
1475
1476                    self.agg_metadata["drop"] = (
1477                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1478                    )
1479
1480                else:
1481
1482                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1483
1484                mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]]
1485
1486                if len([y for y in mask if y is False]) == 0:
1487                    raise ValueError("Nothing found to reduce")
1488
1489                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1490                    drop=True
1491                )
1492
1493                self.agg_metadata = self.agg_metadata.drop(
1494                    columns=["drop"], errors="ignore"
1495                )
1496
1497        elif full is not None:
1498
1499            if self.input_data is not None:
1500
1501                if inc_set:
1502
1503                    self.input_data.columns = (
1504                        self.input_metadata[name_slot]
1505                        + " # "
1506                        + self.input_metadata["sets"]
1507                    )
1508
1509                    if "#" not in full:
1510
1511                        self.input_data.columns = self.input_metadata[name_slot]
1512
1513                        print(
1514                            "The 'full' argument does not include set info (name # set) while 'inc_set' is True. "
1515                            "Only the names will be compared, without considering the set information."
1516                        )
1517
1518                else:
1519
1520                    self.input_data.columns = self.input_metadata[name_slot]
1521
1522                mask = [full.upper() != x.upper() for x in self.input_data.columns]
1523
1524                if len([y for y in mask if y is False]) == 0:
1525                    raise ValueError("Nothing found to reduce")
1526
1527                self.input_data = self.input_data.loc[:, mask]
1528
1529            if self.normalized_data is not None:
1530
1531                if inc_set:
1532
1533                    self.normalized_data.columns = (
1534                        self.input_metadata[name_slot]
1535                        + " # "
1536                        + self.input_metadata["sets"]
1537                    )
1538
1539                    if "#" not in full:
1540
1541                        self.normalized_data.columns = self.input_metadata[name_slot]
1542
1543                        print(
1544                            "The 'full' argument does not include set info (name # set) while 'inc_set' is True. "
1545                            "Only the names will be compared, without considering the set information."
1546                        )
1547
1548                else:
1549
1550                    self.normalized_data.columns = self.input_metadata[name_slot]
1551
1552                mask = [full.upper() != x.upper() for x in self.normalized_data.columns]
1553
1554                if len([y for y in mask if y is False]) == 0:
1555                    raise ValueError("Nothing found to reduce")
1556
1557                self.normalized_data = self.normalized_data.loc[:, mask]
1558
1559            if self.input_metadata is not None:
1560
1561                if inc_set:
1562
1563                    self.input_metadata["drop"] = (
1564                        self.input_metadata[name_slot]
1565                        + " # "
1566                        + self.input_metadata["sets"]
1567                    )
1568
1569                    if "#" not in full:
1570
1571                        self.input_metadata["drop"] = self.input_metadata[name_slot]
1572
1573                        print(
1574                            "The 'full' argument does not include set info (name # set) while 'inc_set' is True. "
1575                            "Only the names will be compared, without considering the set information."
1576                        )
1577
1578                else:
1579
1580                    self.input_metadata["drop"] = self.input_metadata[name_slot]
1581
1582                mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]]
1583
1584                if len([y for y in mask if y is False]) == 0:
1585                    raise ValueError("Nothing found to reduce")
1586
1587                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
1588                    drop=True
1589                )
1590
1591                self.input_metadata = self.input_metadata.drop(
1592                    columns=["drop"], errors="ignore"
1593                )
1594
1595            if self.agg_normalized_data is not None:
1596
1597                if inc_set:
1598
1599                    self.agg_normalized_data.columns = (
1600                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1601                    )
1602
1603                    if "#" not in full:
1604
1605                        self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1606
1607                        print(
1608                            "The 'full' argument does not include set info (name # set) while 'inc_set' is True. "
1609                            "Only the names will be compared, without considering the set information."
1610                        )
1611                else:
1612
1613                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]
1614
1615                mask = [
1616                    full.upper() != x.upper() for x in self.agg_normalized_data.columns
1617                ]
1618
1619                if len([y for y in mask if y is False]) == 0:
1620                    raise ValueError("Nothing found to reduce")
1621
1622                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
1623
1624            if self.agg_metadata is not None:
1625
1626                if inc_set:
1627
1628                    self.agg_metadata["drop"] = (
1629                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
1630                    )
1631
1632                    if "#" not in full:
1633
1634                        self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1635
1636                        print(
1637                            "The 'full' argument does not include set info (name # set) while 'inc_set' is True. "
1638                            "Only the names will be compared, without considering the set information."
1639                        )
1640                else:
1641
1642                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]
1643
1644                mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]]
1645
1646                if len([y for y in mask if y is False]) == 0:
1647                    raise ValueError("Nothing found to reduce")
1648
1649                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
1650                    drop=True
1651                )
1652
1653                self.agg_metadata = self.agg_metadata.drop(
1654                    columns=["drop"], errors="ignore"
1655                )
1656
1657        self.gene_calculation()
1658        self.cells_calculation()

Remove columns (cells) whose names contain a substring reg or full name full from available tables.

Parameters

reg : str | None Substring to search for in column/cell names; matching columns will be removed. If not None, full must be None.

full : str | None Full name to search for in column/cell names; matching columns will be removed. If not None, reg must be None.

name_slot : str, default 'cell_names' Column in metadata to use as sample names.

inc_set : bool, default False If True, column names are interpreted as 'cell_name # set' when matching.

Update

Mutates self.input_data, self.normalized_data, self.input_metadata, self.agg_normalized_data, and self.agg_metadata (if they exist), removing columns (and corresponding metadata rows) that match reg or full.

Raises

ValueError, if nothing matches the reduction mask.
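
A hypothetical illustration of the two matching modes (cell and set names are assumptions):

>>> cmp.reduce_cols(reg='doublet')                          # drops every cell whose name contains 'doublet'
>>> cmp.reduce_cols(full='Purkinje # Exp1', inc_set=True)   # drops the exact 'name # set' column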

def reduce_rows(self, features_list: list):
1660    def reduce_rows(self, features_list: list):
1661        """
1662        Remove rows (features) whose names are included in features_list.
1663
1664        Parameters
1665        ----------
1666        features_list : list
1667            List of features to search for in index/gene names; matching entries will be removed.
1668
1669        Update
1670        ------
1671        Mutates `self.input_data`, `self.normalized_data`, and
1672        `self.agg_normalized_data` (if they exist), removing rows (features)
1673        whose names match entries in `features_list`.
1674
1675        Raises
1676        ------
1677        ValueError, if nothing is found to reduce. Features not found in the data are printed.
1678        """
1679
1680        if self.input_data is not None:
1681
1682            res = find_features(self.input_data, features=features_list)
1683
1684            res_list = [x.upper() for x in res["included"]]
1685
1686            mask = [x.upper() not in res_list for x in self.input_data.index]
1687
1688            if len([y for y in mask if y is False]) == 0:
1689                raise ValueError("Nothing found to reduce")
1690
1691            self.input_data = self.input_data.loc[mask, :]
1692
1693        if self.normalized_data is not None:
1694
1695            res = find_features(self.normalized_data, features=features_list)
1696
1697            res_list = [x.upper() for x in res["included"]]
1698
1699            mask = [x.upper() not in res_list for x in self.normalized_data.index]
1700
1701            if len([y for y in mask if y is False]) == 0:
1702                raise ValueError("Nothing found to reduce")
1703
1704            self.normalized_data = self.normalized_data.loc[mask, :]
1705
1706        if self.agg_normalized_data is not None:
1707
1708            res = find_features(self.agg_normalized_data, features=features_list)
1709
1710            res_list = [x.upper() for x in res["included"]]
1711
1712            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]
1713
1714            if len([y for y in mask if y is False]) == 0:
1715                raise ValueError("Nothing found to reduce")
1716
1717            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]
1718
1719        if len(res["not_included"]) > 0:
1720            print("\nFeatures not found:")
1721            for i in res["not_included"]:
1722                print(i)
1723
1724        self.gene_calculation()
1725        self.cells_calculation()

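A minimal usage sketch (assuming `cl` is an already-populated `Clustering` instance; the gene names are illustrative). Matching against the feature index is case-insensitive, and features that are not found are listed on stdout:

    cl.reduce_rows(features_list=["MT-CO1", "MT-ND1", "ACTB"])
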
1727    def get_data(self, set_info: bool = False):
1728        """
1729        Return normalized data with optional set annotation appended to column names.
1730
1731        Parameters
1732        ----------
1733        set_info : bool, default False
1734            If True, column names are returned as "cell_name # set"; otherwise
1735            only the `cell_name` is used.
1736
1737        Returns
1738        -------
1739        pandas.DataFrame
1740            The `self.normalized_data` table with columns renamed according to `set_info`.
1741
1742        Raises
1743        ------
1744        AttributeError
1745            If `self.normalized_data` or `self.input_metadata` is missing.
1746        """
1747
1748        to_return = self.normalized_data.copy()
1749
1750        if set_info:
1751            to_return.columns = (
1752                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
1753            )
1754        else:
1755            to_return.columns = self.input_metadata["cell_names"]
1756
1757        return to_return

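A short sketch of the two naming modes (assuming `cl` is a populated `Clustering` instance):

    df = cl.get_data()                     # columns are plain cell names
    df_sets = cl.get_data(set_info=True)   # columns are 'cell_name # set'
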
1759    def get_partial_data(
1760        self,
1761        names: list | str | None = None,
1762        features: list | str | None = None,
1763        name_slot: str = "cell_names",
1764        inc_metadata: bool = False,
1765    ):
1766        """
1767        Return a subset of the data filtered by sample names and/or feature names.
1768
1769        Parameters
1770        ----------
1771        names : list, str, or None
1772            Names of samples to include. If None, all samples are considered.
1773
1774        features : list, str, or None
1775            Names of features to include. If None, all features are considered.
1776
1777        name_slot : str
1778            Column in metadata to use as sample names.
1779
1780        inc_metadata : bool
1781            If True, return a tuple (data, metadata) instead of the data alone.
1782
1783        Returns
1784        -------
1785        pandas.DataFrame
1786            Subset of the normalized data based on the specified names and features.
1787        """
1788
1789        data = self.normalized_data.copy()
1790        metadata = self.input_metadata
1791
1792        if name_slot in self.input_metadata.columns:
1793            data.columns = self.input_metadata[name_slot]
1794        else:
1795            raise ValueError("'name_slot' not found in the metadata columns!")
1796
1797        if isinstance(features, str):
1798            features = [features]
1799        elif features is None:
1800            features = []
1801
1802        if isinstance(names, str):
1803            names = [names]
1804        elif names is None:
1805            names = []
1806
1807        features = [x.upper() for x in features]
1808        names = [x.upper() for x in names]
1809
1810        columns_names = [x.upper() for x in data.columns]
1811        features_names = [x.upper() for x in data.index]
1812
1813        columns_bool = [True if x in names else False for x in columns_names]
1814        features_bool = [True if x in features else False for x in features_names]
1815
1816        if True not in columns_bool and True not in features_bool:
1817            print("No matching 'names' or 'features' found. Returning the full dataset instead.")
1818            if inc_metadata:
1819                return data, metadata
1820            else:
1821                return data
1822
1823        if True in columns_bool:
1824            data = data.loc[:, columns_bool]
1825            metadata = metadata.loc[columns_bool, :]
1826
1827        # apply the feature filter after the sample filter so that
1828        # 'names' and 'features' can be combined in a single call
1829        if True in features_bool:
1830            data = data.loc[features_bool, :]
1831
1832        if inc_metadata:
1833            return data, metadata
1834        else:
1835            return data

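A hedged usage sketch (the sample and feature names below are illustrative; matching is case-insensitive, and both filters may be combined):

    sub, meta = cl.get_partial_data(
        names=["T_cell", "B_cell"],
        features=["CD3E", "MS4A1"],
        name_slot="cell_names",
        inc_metadata=True,
    )
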
1840    def get_metadata(self):
1841        """
1842        Return the stored input metadata.
1843
1844        Returns
1845        -------
1846        pandas.DataFrame
1847            `self.input_metadata` (may be None if not set).
1848        """
1849
1850        to_return = self.input_metadata
1851
1852        return to_return

1854    def gene_calculation(self):
1855        """
1856        Calculate and store per-cell counts (e.g., number of detected genes).
1857
1858        The method computes a binary (presence/absence) per cell and sums across
1859        features to produce `self.gene_calc`.
1860
1861        Update
1862        ------
1863        Sets `self.gene_calc` as a pandas.Series.
1864
1865        Side Effects
1866        ------------
1867        Uses `self.input_data` when available, otherwise `self.normalized_data`.
1868        """
1869
1870        if self.input_data is not None:
1871
1872            bin_col = self.input_data.copy()
1873
1874            bin_col = bin_col.where(bin_col <= 0, 1)
1875
1876            sum_data = bin_col.sum(axis=0)
1877
1878            self.gene_calc = sum_data
1879
1880        elif self.normalized_data is not None:
1881
1882            bin_col = self.normalized_data.copy()
1883
1884            bin_col = bin_col.where(bin_col <= 0, 1)
1885
1886            sum_data = bin_col.sum(axis=0)
1887
1888            self.gene_calc = sum_data

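The per-cell count is equivalent to thresholding at zero and summing over features; a toy sketch of the same computation in plain pandas:

    import pandas as pd

    toy = pd.DataFrame(
        {"cell_1": [0, 2, 5], "cell_2": [1, 0, 0]},
        index=["GENE_A", "GENE_B", "GENE_C"],
    )
    genes_per_cell = (toy > 0).sum(axis=0)  # cell_1 -> 2, cell_2 -> 1
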
1890    def gene_histograme(self, bins=100):
1891        """
1892        Plot a histogram of the number of genes detected per cell.
1893
1894        Parameters
1895        ----------
1896        bins : int, default 100
1897            Number of histogram bins.
1898
1899        Returns
1900        -------
1901        matplotlib.figure.Figure
1902            Figure containing the histogram of genes detected per cell.
1903
1904        Notes
1905        -----
1906        Requires `self.gene_calc` to be computed prior to calling.
1907        """
1908
1909        fig, ax = plt.subplots(figsize=(8, 5))
1910
1911        _, bin_edges, _ = ax.hist(
1912            self.gene_calc, bins=bins, edgecolor="black", alpha=0.6
1913        )
1914
1915        mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc)
1916
1917        x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000)
1918        y = norm.pdf(x, mu, sigma)
1919
1920        y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0])
1921
1922        ax.plot(
1923            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
1924        )
1925
1926        ax.set_xlabel("Value")
1927        ax.set_ylabel("Count")
1928        ax.set_title("Histogram of genes detected per cell")
1929
1930        ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20))
1931        ax.tick_params(axis="x", rotation=90)
1932
1933        ax.legend()
1934
1935        return fig

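Typical usage (a sketch; the output path is illustrative):

    fig = cl.gene_histograme(bins=50)
    fig.savefig("genes_per_cell.png", dpi=150, bbox_inches="tight")
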
1937    def gene_threshold(self, min_n: int | None, max_n: int | None):
1938        """
1939        Filter cells by gene-detection thresholds (min and/or max).
1940
1941        Parameters
1942        ----------
1943        min_n : int or None
1944            Minimum number of detected genes required to keep a cell.
1945
1946        max_n : int or None
1947            Maximum number of detected genes allowed to keep a cell.
1948
1949        Update
1950        -------
1951        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
1952        (and calls `average()` if `self.agg_normalized_data` exists).
1953
1954        Side Effects
1955        ------------
1956        Raises ValueError if both bounds are None or if filtering removes all cells.
1957        """
1958
1959        if min_n is not None and max_n is not None:
1960            mask = (self.gene_calc > min_n) & (self.gene_calc < max_n)
1961        elif min_n is None and max_n is not None:
1962            mask = self.gene_calc < max_n
1963        elif min_n is not None and max_n is None:
1964            mask = self.gene_calc > min_n
1965        else:
1966            raise ValueError("At least one of 'min_n' or 'max_n' must be provided")
1967
1968        if self.input_data is not None:
1969
1970            if len([y for y in mask if y is False]) == 0:
1971                raise ValueError("Nothing to reduce")
1972
1973            self.input_data = self.input_data.loc[:, mask.values]
1974
1975        if self.normalized_data is not None:
1976
1977            if len([y for y in mask if y is False]) == 0:
1978                raise ValueError("Nothing to reduce")
1979
1980            self.normalized_data = self.normalized_data.loc[:, mask.values]
1981
1982        if self.input_metadata is not None:
1983
1984            if len([y for y in mask if y is False]) == 0:
1985                raise ValueError("Nothing to reduce")
1986
1987            self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index(
1988                drop=True
1989            )
1990
1991            self.input_metadata = self.input_metadata.drop(
1992                columns=["drop"], errors="ignore"
1993            )
1994
1995        if self.agg_normalized_data is not None:
1996            self.average()
1997
1998        self.gene_calculation()
1999        self.cells_calculation()

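A usage sketch with illustrative QC bounds; either bound may be None to filter on one side only:

    cl.gene_threshold(min_n=200, max_n=6000)   # keep cells with 200 < detected genes < 6000
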
2001    def cells_calculation(self, name_slot="cell_names"):
2002        """
2003        Calculate the number of cells per cell name / cluster.
2004
2005        The method counts the occurrences of each cell name / cluster label
2006        in the metadata column and stores the result as a two-column table.
2007
2008        Parameters
2009        ----------
2010        name_slot : str, default 'cell_names'
2011            Column in metadata to use as sample names.
2012
2013        Update
2014        ------
2015        Sets `self.cells_calc` as a pd.DataFrame.
2016        """
2017
2018        ls = list(self.input_metadata[name_slot])
2019
2020        df = pd.DataFrame(
2021            {
2022                "cluster": pd.Series(ls).value_counts().index,
2023                "n": pd.Series(ls).value_counts().values,
2024            }
2025        )
2026
2027        self.cells_calc = df

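A short sketch; the resulting table has one row per label with its cell count:

    cl.cells_calculation(name_slot="cell_names")
    print(cl.cells_calc.head())   # columns: 'cluster', 'n'
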
2029    def cell_histograme(self, name_slot: str = "cell_names"):
2030        """
2031        Plot a histogram of the number of cells detected per cell name (cluster).
2032
2033        Parameters
2034        ----------
2035        name_slot : str, default 'cell_names'
2036            Column in metadata to use as sample names.
2037
2038        Returns
2039        -------
2040        matplotlib.figure.Figure
2041            Figure containing the histogram of cells per cell name / cluster.
2042
2043        Notes
2044        -----
2045        Requires `self.cells_calc` to be computed prior to calling.
2046        """
2047
2048        if name_slot != "cell_names":
2049            self.cells_calculation(name_slot=name_slot)
2050
2051        fig, ax = plt.subplots(figsize=(8, 5))
2052
2053        _, bin_edges, _ = ax.hist(
2054            list(self.cells_calc["n"]),
2055            bins=len(set(self.cells_calc["cluster"])),
2056            edgecolor="black",
2057            color="orange",
2058            alpha=0.6,
2059        )
2060
2061        mu, sigma = np.mean(list(self.cells_calc["n"])), np.std(
2062            list(self.cells_calc["n"])
2063        )
2064
2065        x = np.linspace(
2066            min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000
2067        )
2068        y = norm.pdf(x, mu, sigma)
2069
2070        y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0])
2071
2072        ax.plot(
2073            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
2074        )
2075
2076        ax.set_xlabel("Value")
2077        ax.set_ylabel("Count")
2078        ax.set_title("Histogram of cells detected per cell name / cluster")
2079
2080        ax.set_xticks(
2081            np.linspace(
2082                min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20
2083            )
2084        )
2085        ax.tick_params(axis="x", rotation=90)
2086
2087        ax.legend()
2088
2089        return fig

2091    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
2092        """
2093        Filter cell names / clusters by cell-detection threshold.
2094
2095        Parameters
2096        ----------
2097        min_n : int or None
2098            Minimum number of cells required to keep a cell name / cluster.
2099
2100        name_slot : str, default 'cell_names'
2101            Column in metadata to use as sample names.
2102
2103
2104        Update
2105        -------
2106        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
2107        (and calls `average()` if `self.agg_normalized_data` exists).
2108        """
2109
2110        if name_slot != "cell_names":
2111            self.cells_calculation(name_slot=name_slot)
2112
2113        if min_n is not None:
2114            names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]
2115        else:
2116            raise ValueError("'min_n' value is required")
2117
2118        if len(names) > 0:
2119
2120            if self.input_data is not None:
2121
2122                self.input_data.columns = self.input_metadata[name_slot]
2123
2124                mask = [not any(r in x for r in names) for x in self.input_data.columns]
2125
2126                if len([y for y in mask if y is False]) > 0:
2127
2128                    self.input_data = self.input_data.loc[:, mask]
2129
2130            if self.normalized_data is not None:
2131
2132                self.normalized_data.columns = self.input_metadata[name_slot]
2133
2134                mask = [
2135                    not any(r in x for r in names) for x in self.normalized_data.columns
2136                ]
2137
2138                if len([y for y in mask if y is False]) > 0:
2139
2140                    self.normalized_data = self.normalized_data.loc[:, mask]
2141
2142            if self.input_metadata is not None:
2143
2144                self.input_metadata["drop"] = self.input_metadata[name_slot]
2145
2146                mask = [
2147                    not any(r in x for r in names) for x in self.input_metadata["drop"]
2148                ]
2149
2150                if len([y for y in mask if y is False]) > 0:
2151
2152                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
2153                        drop=True
2154                    )
2155
2156                self.input_metadata = self.input_metadata.drop(
2157                    columns=["drop"], errors="ignore"
2158                )
2159
2160            if self.agg_normalized_data is not None:
2161
2162                self.agg_normalized_data.columns = self.agg_metadata[name_slot]
2163
2164                mask = [
2165                    not any(r in x for r in names)
2166                    for x in self.agg_normalized_data.columns
2167                ]
2168
2169                if len([y for y in mask if y is False]) > 0:
2170
2171                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]
2172
2173            if self.agg_metadata is not None:
2174
2175                self.agg_metadata["drop"] = self.agg_metadata[name_slot]
2176
2177                mask = [
2178                    not any(r in x for r in names) for x in self.agg_metadata["drop"]
2179                ]
2180
2181                if len([y for y in mask if y is False]) > 0:
2182
2183                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
2184                        drop=True
2185                    )
2186
2187                self.agg_metadata = self.agg_metadata.drop(
2188                    columns=["drop"], errors="ignore"
2189                )
2190
2191            self.gene_calculation()
2192            self.cells_calculation()

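A usage sketch with an illustrative threshold; counts are recomputed afterwards by `cells_calculation()`:

    cl.cluster_threshold(min_n=20, name_slot="cell_names")   # drop groups with < 20 cells
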
2194    def load_sparse_from_projects(self, normalized_data: bool = False):
2195        """
2196        Load sparse 10x-style datasets from stored project paths, concatenate them,
2197        and populate `input_data` / `normalized_data` and `input_metadata`.
2198
2199        Parameters
2200        ----------
2201        normalized_data : bool, default False
2202            If True, store concatenated tables in `self.normalized_data`.
2203            If False, store them in `self.input_data`; normalization can then
2204            be performed with the `normalize_counts()` method.
2205
2206        Side Effects
2207        ------------
2208        - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv).
2209        - Concatenates all projects column-wise and sets `self.input_metadata`.
2210        - Replaces NaNs with zeros and updates `self.gene_calc`.
2211        """
2212
2213        obj = self.objects
2214
2215        full_data = pd.DataFrame()
2216        full_metadata = pd.DataFrame()
2217
2218        for ke in obj.keys():
2219            print(ke)
2220
2221            dt, met = load_sparse(path=obj[ke], name=ke)
2222
2223            full_data = pd.concat([full_data, dt], axis=1)
2224            full_metadata = pd.concat([full_metadata, met], axis=0)
2225
2226        full_data[np.isnan(full_data)] = 0
2227
2228        if normalized_data:
2229            self.normalized_data = full_data
2230            self.input_metadata = full_metadata
2231        else:
2232
2233            self.input_data = full_data
2234            self.input_metadata = full_metadata
2235
2236        self.gene_calculation()
2237        self.cells_calculation()

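A minimal end-to-end sketch (assuming project paths were registered on the instance beforehand):

    cl.load_sparse_from_projects(normalized_data=False)   # raw counts into input_data
    cl.normalize_counts(normalize_factor=100000, log_transform=True)
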
2239    def rename_names(self, mapping: dict, slot: str = "cell_names"):
2240        """
2241        Rename entries in `self.input_metadata[slot]` according to a provided mapping.
2242
2243        Parameters
2244        ----------
2245        mapping : dict
2246            Dictionary with keys 'old_name' and 'new_name', each mapping to a list
2247            of equal length describing replacements.
2248
2249        slot : str, default 'cell_names'
2250            Metadata column to operate on.
2251
2252        Update
2253        -------
2254        Updates `self.input_metadata[slot]` in-place with renamed values.
2255
2256        Raises
2257        ------
2258        ValueError
2259            If mapping keys are incorrect, lengths differ, or some 'old_name' values
2260            are not present in the metadata column.
2261        """
2262
2263        if set(["old_name", "new_name"]) != set(mapping.keys()):
2264            raise ValueError(
2265                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2266                "each with a list of names to change."
2267            )
2268
2269        if len(mapping["old_name"]) != len(mapping["new_name"]):
2270            raise ValueError(
2271                "Mapping dictionary lists 'old_name' and 'new_name' "
2272                "must have the same length!"
2273            )
2274
2275        names_vector = list(self.input_metadata[slot])
2276
2277        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2278            raise ValueError(
2279                f"Some entries from 'old_name' do not exist in the names of slot {slot}."
2280            )
2281
2282        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2283
2284        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2285
2286        self.input_metadata[slot] = names_vector_ret

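A usage sketch; the old/new labels below are illustrative:

    cl.rename_names(
        mapping={"old_name": ["Tcell", "Bcell"], "new_name": ["T_cell", "B_cell"]},
        slot="cell_names",
    )
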
2288    def rename_subclusters(self, mapping):
2289        """
2290        Rename labels stored in `self.subclusters_.subclusters` according to mapping.
2291
2292        Parameters
2293        ----------
2294        mapping : dict
2295            Mapping with keys 'old_name' and 'new_name' (lists of equal length).
2296
2297        Update
2298        -------
2299        Updates `self.subclusters_.subclusters` with renamed labels.
2300
2301        Raises
2302        ------
2303        ValueError
2304            If mapping is invalid or old names are not present.
2305        """
2306
2307        if set(["old_name", "new_name"]) != set(mapping.keys()):
2308            raise ValueError(
2309                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
2310                "each with a list of names to change."
2311            )
2312
2313        if len(mapping["old_name"]) != len(mapping["new_name"]):
2314            raise ValueError(
2315                "Mapping dictionary lists 'old_name' and 'new_name' "
2316                "must have the same length!"
2317            )
2318
2319        names_vector = list(self.subclusters_.subclusters)
2320
2321        if not all(elem in names_vector for elem in list(mapping["old_name"])):
2322            raise ValueError(
2323                "Some entries from 'old_name' do not exist in the subcluster names."
2324            )
2325
2326        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))
2327
2328        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]
2329
2330        self.subclusters_.subclusters = names_vector_ret

2332    def save_sparse(
2333        self,
2334        path_to_save: str = os.getcwd(),
2335        name_slot: str = "cell_names",
2336        data_slot: str = "normalized",
2337    ):
2338        """
2339        Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
2340
2341        Parameters
2342        ----------
2343        path_to_save : str, default current working directory
2344            Directory where files will be written.
2345
2346        name_slot : str, default 'cell_names'
2347            Metadata column providing cell names for barcodes.tsv.
2348
2349        data_slot : str, default 'normalized'
2350            Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).
2351
2352        Raises
2353        ------
2354        ValueError
2355            If `data_slot` is not 'normalized' or 'count'.
2356        """
2357
2358        names = self.input_metadata[name_slot]
2359
2360        if data_slot.lower() == "normalized":
2361
2362            features = list(self.normalized_data.index)
2363            mtx = sparse.csr_matrix(self.normalized_data)
2364
2365        elif data_slot.lower() == "count":
2366
2367            features = list(self.input_data.index)
2368            mtx = sparse.csr_matrix(self.input_data)
2369
2370        else:
2371            raise ValueError("'data_slot' must be one of 'normalized' or 'count'")
2372
2373        os.makedirs(path_to_save, exist_ok=True)
2374
2375        mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx)
2376
2377        pd.Series(names).to_csv(
2378            os.path.join(path_to_save, "barcodes.tsv"),
2379            index=False,
2380            header=False,
2381            sep="\t",
2382        )
2383
2384        pd.Series(features).to_csv(
2385            os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t"
2386        )

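A usage sketch (the output directory is illustrative):

    cl.save_sparse(path_to_save="export", name_slot="cell_names", data_slot="normalized")
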
2388    def normalize_counts(
2389        self, normalize_factor: int = 100000, log_transform: bool = True
2390    ):
2391        """
2392        Normalize raw counts so that each cell (column) sums to `normalize_factor`
2393        (e.g., CPM when `normalize_factor=1000000`).
2394
2395        Parameters
2396        ----------
2397        normalize_factor : int, default 100000
2398            Scaling factor used after dividing by column sums.
2399
2400        log_transform : bool, default True
2401            If True, apply log2(x+1) transformation to normalized values.
2402
2403        Update
2404        -------
2405            Sets `self.normalized_data` to normalized values (fills NaNs with 0).
2406
2407        Raises
2408        ------
2409        ValueError
2410            If `self.input_data` is missing (cannot normalize).
2411        """
2412        if self.input_data is None:
2413            raise ValueError("Input data is missing, cannot normalize.")
2414
2415        sum_col = self.input_data.sum()
2416        self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor
2417
2418        if log_transform:
2419            # log2(x + 1) to avoid -inf for zeros
2420            self.normalized_data = np.log2(self.normalized_data + 1)

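The transformation maps each raw count x in a cell with total count S to log2(x / S * normalize_factor + 1) when `log_transform=True`; a toy check of the arithmetic:

    import numpy as np
    import pandas as pd

    counts = pd.DataFrame({"cell_1": [90, 10]}, index=["GENE_A", "GENE_B"])
    scaled = counts.div(counts.sum()) * 100000    # GENE_A -> 90000.0
    logged = np.log2(scaled + 1)                  # GENE_A -> ~16.46
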
2422    def statistic(
2423        self,
2424        cells=None,
2425        sets=None,
2426        min_exp: float = 0.01,
2427        min_pct: float = 0.1,
2428        n_proc: int = 10,
2429    ):
2430        """
2431        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.
2432
2433        This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data`
2434        and `self.input_metadata`. It returns per-feature statistics including p-values,
2435        adjusted p-values, means, variances, effect-size measures and fold-changes.
2436
2437        Parameters
2438        ----------
2439        cells : list, 'All', dict, or None
2440            Defines the target cells or groups for comparison (several modes supported).
2441
2442        sets : 'All', dict, or None
2443            Alternative grouping mode (operate on `self.input_metadata['sets']`).
2444
2445        min_exp : float, default 0.01
2446            Minimum expression threshold used when filtering features.
2447
2448        min_pct : float, default 0.1
2449            Minimum proportion of expressing cells in the target group required to test a feature.
2450
2451        n_proc : int, default 10
2452            Number of parallel jobs to use.
2453
2454        Returns
2455        -------
2456        pandas.DataFrame or dict
2457            Results DataFrame (or dict containing valid/control cells + DataFrame),
2458            similar to `calc_DEG` interface.
2459
2460        Raises
2461        ------
2462        ValueError
2463            If neither `cells` nor `sets` is provided, or input metadata mismatch occurs.
2464
2465        Notes
2466        -----
2467        Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
2468        """
2469
2470        def stat_calc(choose, feature_name):
2471            target_values = choose.loc[choose["DEG"] == "target", feature_name]
2472            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]
2473
2474            pct_valid = (target_values > 0).sum() / len(target_values)
2475            pct_rest = (rest_values > 0).sum() / len(rest_values)
2476
2477            avg_valid = np.mean(target_values)
2478            avg_ctrl = np.mean(rest_values)
2479            sd_valid = np.std(target_values, ddof=1)
2480            sd_ctrl = np.std(rest_values, ddof=1)
2481            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))
2482
2483            if np.sum(target_values) == np.sum(rest_values):
2484                p_val = 1.0
2485            else:
2486                _, p_val = stats.mannwhitneyu(
2487                    target_values, rest_values, alternative="two-sided"
2488                )
2489
2490            return {
2491                "feature": feature_name,
2492                "p_val": p_val,
2493                "pct_valid": pct_valid,
2494                "pct_ctrl": pct_rest,
2495                "avg_valid": avg_valid,
2496                "avg_ctrl": avg_ctrl,
2497                "sd_valid": sd_valid,
2498                "sd_ctrl": sd_ctrl,
2499                "esm": esm,
2500            }
2501
2502        def prepare_and_run_stat(
2503            choose, valid_group, min_exp, min_pct, n_proc, factors
2504        ):
2505
2506            tmp_dat = choose[choose["DEG"] == "target"]
2507            tmp_dat = tmp_dat.drop("DEG", axis=1)
2508
2509            counts = (tmp_dat > min_exp).sum(axis=0)
2510
2511            total_count = tmp_dat.shape[0]
2512
2513            info = pd.DataFrame(
2514                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
2515            )
2516
2517            del tmp_dat
2518
2519            drop_col = info["feature"][info["pct"] <= min_pct]
2520
2521            if len(drop_col) + 1 == len(choose.columns):
2522                drop_col = info["feature"][info["pct"] == 0]
2523
2524            del info
2525
2526            choose = choose.drop(list(drop_col), axis=1)
2527
2528            results = Parallel(n_jobs=n_proc)(
2529                delayed(stat_calc)(choose, feature)
2530                for feature in tqdm(choose.columns[choose.columns != "DEG"])
2531            )
2532
2533            df = pd.DataFrame(results)
2534            df["valid_group"] = valid_group
2535            df.sort_values(by="p_val", inplace=True)
2536
2537            num_tests = len(df)
2538            df["adj_pval"] = np.minimum(
2539                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
2540            )
2541
2542            factors_df = factors.to_frame(name="factor")
2543
2544            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")
2545
2546            df["FC"] = (df["avg_valid"] + df["factor"]) / (
2547                df["avg_ctrl"] + df["factor"]
2548            )
2549            df["log(FC)"] = np.log2(df["FC"])
2550            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
2551            df = df.drop(columns=["factor"])
2552
2553            return df
2554
2555        choose = self.normalized_data.copy().T
2556        factors = self.normalized_data.copy().replace(0, np.nan).min(axis=1)
2557        factors[factors == 0] = min(factors[factors != 0])
2558
2559        final_results = []
2560
2561        if isinstance(cells, list) and sets is None:
2562            print("\nAnalysis started...\nComparing selected cells to the whole set...")
2563            choose.index = (
2564                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2565            )
2566
2567            if "#" not in cells[0]:
2568                choose.index = self.input_metadata["cell_names"]
2569
2570                print(
2571                    "No set info (name # set) found in the 'cells' list. "
2572                    "Only the names will be compared, without considering the set information."
2573                )
2574
2575            labels = ["target" if idx in cells else "rest" for idx in choose.index]
2576            valid = list(
2577                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
2578            )
2579
2580            choose["DEG"] = labels
2581            choose = choose[choose["DEG"] != "drop"]
2582
2583            result_df = prepare_and_run_stat(
2584                choose.reset_index(drop=True),
2585                valid_group=valid,
2586                min_exp=min_exp,
2587                min_pct=min_pct,
2588                n_proc=n_proc,
2589                factors=factors,
2590            )
2591            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}
2592
2593        elif cells == "All" and sets is None:
2594            print("\nAnalysis started...\nComparing each type of cell to others...")
2595            choose.index = (
2596                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2597            )
2598            unique_labels = set(choose.index)
2599
2600            for label in tqdm(unique_labels):
2601                print(f"\nCalculating statistics for {label}")
2602                labels = ["target" if idx == label else "rest" for idx in choose.index]
2603                choose["DEG"] = labels
2604                choose = choose[choose["DEG"] != "drop"]
2605                result_df = prepare_and_run_stat(
2606                    choose.copy(),
2607                    valid_group=label,
2608                    min_exp=min_exp,
2609                    min_pct=min_pct,
2610                    n_proc=n_proc,
2611                    factors=factors,
2612                )
2613                final_results.append(result_df)
2614
2615            return pd.concat(final_results, ignore_index=True)
2616
2617        elif cells is None and sets == "All":
2618            print("\nAnalysis started...\nComparing each set/group to others...")
2619            choose.index = self.input_metadata["sets"]
2620            unique_sets = set(choose.index)
2621
2622            for label in tqdm(unique_sets):
2623                print(f"\nCalculating statistics for {label}")
2624                labels = ["target" if idx == label else "rest" for idx in choose.index]
2625
2626                choose["DEG"] = labels
2627                choose = choose[choose["DEG"] != "drop"]
2628                result_df = prepare_and_run_stat(
2629                    choose.copy(),
2630                    valid_group=label,
2631                    min_exp=min_exp,
2632                    min_pct=min_pct,
2633                    n_proc=n_proc,
2634                    factors=factors,
2635                )
2636                final_results.append(result_df)
2637
2638            return pd.concat(final_results, ignore_index=True)
2639
2640        elif cells is None and isinstance(sets, dict):
2641            print("\nAnalysis started...\nComparing groups...")
2642
2643            choose.index = self.input_metadata["sets"]
2644
2645            group_list = list(sets.keys())
2646            if len(group_list) != 2:
2647                print("Only pairwise group comparison is supported.")
2648                return None
2649
2650            labels = [
2651                (
2652                    "target"
2653                    if idx in sets[group_list[0]]
2654                    else "rest" if idx in sets[group_list[1]] else "drop"
2655                )
2656                for idx in choose.index
2657            ]
2658            choose["DEG"] = labels
2659            choose = choose[choose["DEG"] != "drop"]
2660
2661            result_df = prepare_and_run_stat(
2662                choose.reset_index(drop=True),
2663                valid_group=group_list[0],
2664                min_exp=min_exp,
2665                min_pct=min_pct,
2666                n_proc=n_proc,
2667                factors=factors,
2668            )
2669            return result_df
2670
2671        elif isinstance(cells, dict) and sets is None:
2672            print("\nAnalysis started...\nComparing groups...")
2673            choose.index = (
2674                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
2675            )
2676
2677            if "#" not in cells[list(cells.keys())[0]][0]:
2678                choose.index = self.input_metadata["cell_names"]
2679
2680                print(
2681                    "No set info (name # set) found in the 'cells' dict. "
2682                    "Only the names will be compared, without considering the set information."
2683                )
2684
2685            group_list = list(cells.keys())
2686            if len(group_list) != 2:
2687                print("Only pairwise group comparison is supported.")
2688                return None
2689
2690            labels = [
2691                (
2692                    "target"
2693                    if idx in cells[group_list[0]]
2694                    else "rest" if idx in cells[group_list[1]] else "drop"
2695                )
2696                for idx in choose.index
2697            ]
2698
2699            choose["DEG"] = labels
2700            choose = choose[choose["DEG"] != "drop"]
2701
2702            result_df = prepare_and_run_stat(
2703                choose.reset_index(drop=True),
2704                valid_group=group_list[0],
2705                min_exp=min_exp,
2706                min_pct=min_pct,
2707                n_proc=n_proc,
2708                factors=factors,
2709            )
2710
2711            return result_df.reset_index(drop=True)
2712
2713        else:
2714            raise ValueError(
2715                "You must specify either 'cells' or 'sets' (or both). None were provided, which is not allowed for this analysis."
2716            )

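A sketch of the pairwise-dict mode (group names and members are illustrative; each member is written as 'cell_name # set' when set information should be used):

    res = cl.statistic(
        cells={"groupA": ["T_cell # set1"], "groupB": ["B_cell # set1"]},
        sets=None,
        min_exp=0.01,
        min_pct=0.1,
        n_proc=4,
    )
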
2718    def calculate_difference_markers(
2719        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
2720    ):
2721        """
2722        Compute differential markers (var_data) if not already present.
2723
2724        Parameters
2725        ----------
2726        min_exp : float, default 0
2727            Minimum expression threshold passed to `statistic`.
2728
2729        min_pct : float, default 0.25
2730            Minimum percent expressed in target group.
2731
2732        n_proc : int, default 10
2733            Parallel jobs.
2734
2735        force : bool, default False
2736            If True, recompute even if `self.var_data` is present.
2737
2738        Update
2739        -------
2740        Sets `self.var_data` to the result of `self.statistic(...)`.
2741
2742        Raises
2743        ------
2744        ValueError if already computed and `force` is False.
2745        """
2746
2747        if self.var_data is None or force:
2748
2749            self.var_data = self.statistic(
2750                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
2751            )
2752
2753        else:
2754            raise ValueError(
2755                "self.calculate_difference_markers() has already been executed. "
2756                "The results are stored in self.var_data. "
2757                "If you want to recalculate with different statistics, please rerun the method with force=True."
2758            )

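Typical usage; results land in `self.var_data` with one block of rows per `valid_group`:

    cl.calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=4)
    markers = cl.var_data
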
2760    def clustering_features(
2761        self,
2762        features_list: list | None,
2763        name_slot: str = "cell_names",
2764        p_val: float = 0.05,
2765        top_n: int = 25,
2766        adj_mean: bool = True,
2767        beta: float = 0.2,
2768    ):
2769        """
2770        Prepare clustering input by selecting marker features and optionally smoothing cell values
2771        toward group means.
2772
2773        Parameters
2774        ----------
2775        features_list : list or None
2776            If provided, use this list of features. If None, features are selected
2777            from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group.
2778
2779        name_slot : str, default 'cell_names'
2780            Metadata column used for naming.
2781
2782        p_val : float, default 0.05
2783            Adjusted p-value cutoff when selecting features automatically.
2784
2785        top_n : int, default 25
2786            Number of top features per valid group to keep if `features_list` is None.
2787
2788        adj_mean : bool, default True
2789            If True, adjust cell values toward group means using `beta`.
2790
2791        beta : float, default 0.2
2792            Adjustment strength toward group mean.
2793
2794        Update
2795        ------
2796        Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset,
2797        ready for PCA/UMAP/clustering.
2798        """
2799
2800        if features_list is None or len(features_list) == 0:
2801
2802            if self.var_data is None:
2803                raise ValueError(
2804                    "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2805                )
2806
2807            df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2808            df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2809            df_tmp = (
2810                df_tmp.sort_values(
2811                    ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2812                )
2813                .groupby("valid_group")
2814                .head(top_n)
2815            )
2816
2817            features_list = list(set(df_tmp["feature"]))
2818
2819        data = self.get_partial_data(
2820            names=None, features=features_list, name_slot=name_slot
2821        )
2822        data_avg = average(data)
2823
2824        if adj_mean:
2825            data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta)
2826
2827        self.clustering_data = data
2828
2829        self.clustering_metadata = self.input_metadata

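A usage sketch of automatic marker selection (requires `self.var_data`; see `calculate_difference_markers` above):

    cl.clustering_features(
        features_list=None, p_val=0.05, top_n=25, adj_mean=True, beta=0.2
    )
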
2831    def average(self):
2832        """
2833        Aggregate normalized data by (cell_name, set) pairs computing the mean per group.
2834
2835        The method constructs new column names as "cell_name # set", averages columns
2836        sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`.
2837
2838        Update
2839        ------
2840        Sets `self.agg_normalized_data` (features x aggregated samples) and
2841        `self.agg_metadata` (DataFrame with 'cell_names' and 'sets').
2842        """
2843
2844        wide_data = self.normalized_data.copy()
2845
2846        wide_metadata = self.input_metadata
2847
2848        new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"]
2849
2850        wide_data.columns = list(new_names)
2851
2852        aggregated_df = wide_data.T.groupby(wide_data.columns).mean().T
2853
2854        sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns]
2855        names = [re.sub(" #.*", "", x) for x in aggregated_df.columns]
2856
2857        aggregated_df.columns = names
2858        aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets})
2859
2860        self.agg_metadata = aggregated_metadata
2861        self.agg_normalized_data = aggregated_df

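Typical usage; one aggregated column remains per (cell_name, set) pair:

    cl.average()
    print(cl.agg_normalized_data.shape)
    print(cl.agg_metadata.head())   # columns: 'cell_names', 'sets'
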
2863    def estimating_similarity(
2864        self, method="pearson", p_val: float = 0.05, top_n: int = 25
2865    ):
2866        """
2867        Estimate pairwise similarity and Euclidean distance between aggregated samples.
2868
2869        Parameters
2870        ----------
2871        method : str, default 'pearson'
2872            Correlation method to use (passed to pandas.DataFrame.corr()).
2873
2874        p_val : float, default 0.05
2875            Adjusted p-value cutoff used to select marker features from `self.var_data`.
2876
2877        top_n : int, default 25
2878            Number of top features per valid group to include.
2879
2880        Update
2881        -------
2882        Computes a combined table with per-pair correlation and euclidean distance
2883        and stores it in `self.similarity`.
2884        """
2885
2886        if self.var_data is None:
2887            raise ValueError(
2888                "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2889            )
2890
2891        if self.agg_normalized_data is None:
2892            self.average()
2893
2894        metadata = self.agg_metadata
2895        data = self.agg_normalized_data
2896
2897        df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
2898        df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
2899        df_tmp = (
2900            df_tmp.sort_values(
2901                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
2902            )
2903            .groupby("valid_group")
2904            .head(top_n)
2905        )
2906
2907        data = data.loc[list(set(df_tmp["feature"]))]
2908
2909        if len(set(metadata["sets"])) > 1:
2910            data.columns = data.columns + " # " + [x for x in metadata["sets"]]
2911        else:
2912            data = data.copy()
2913
2914        scaler = StandardScaler()
2915
2916        scaled_data = scaler.fit_transform(data)
2917
2918        scaled_df = pd.DataFrame(scaled_data, columns=data.columns)
2919
2920        cor = scaled_df.corr(method=method)
2921        cor_df = cor.stack().reset_index()
2922        cor_df.columns = ["cell1", "cell2", "correlation"]
2923
2924        distances = pdist(scaled_df.T, metric="euclidean")
2925        dist_mat = pd.DataFrame(
2926            squareform(distances), index=scaled_df.columns, columns=scaled_df.columns
2927        )
2928        dist_df = dist_mat.stack().reset_index()
2929        dist_df.columns = ["cell1", "cell2", "euclidean_dist"]
2930
2931        full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"])
2932
2933        full = full[full["cell1"] != full["cell2"]]
2934        full = full.reset_index(drop=True)
2935
2936        self.similarity = full

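A usage sketch; any correlation method accepted by pandas.DataFrame.corr() (e.g., 'pearson', 'spearman', 'kendall') can be passed:

    cl.estimating_similarity(method="spearman", p_val=0.05, top_n=25)
    print(cl.similarity.head())   # cell1, cell2, correlation, euclidean_dist
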
2938    def similarity_plot(
2939        self,
2940        split_sets=True,
2941        set_info: bool = True,
2942        cmap="seismic",
2943        width=12,
2944        height=10,
2945    ):
2946        """
2947        Visualize pairwise similarity as a scatter plot.
2948
2949        Parameters
2950        ----------
2951        split_sets : bool, default True
2952            If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs.
2953
2954        set_info : bool, default True
2955            If True, keep the ' # set' annotation in labels; otherwise strip it.
2956
2957        cmap : str, default 'seismic'
2958            Color map for correlation (hue).
2959
2960        width : int, default 12
2961            Figure width.
2962
2963        height : int, default 10
2964            Figure height.
2965
2966        Returns
2967        -------
2968        matplotlib.figure.Figure
2969
2970        Raises
2971        ------
2972        ValueError
2973            If `self.similarity` is None.
2974
2975        Notes
2976        -----
2977        The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs.
2978        """
2979
2980        if self.similarity is None:
2981            raise ValueError(
2982                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
2983            )
2984
2985        similarity_data = self.similarity.copy()
2986
2987        if " # " in similarity_data["cell1"][0]:
2988            similarity_data["set1"] = [
2989                re.sub(".*# ", "", x) for x in similarity_data["cell1"]
2990            ]
2991            similarity_data["set2"] = [
2992                re.sub(".*# ", "", x) for x in similarity_data["cell2"]
2993            ]
2994
2995        if split_sets and " # " in similarity_data["cell1"][0]:
2996            sets = list(
2997                set(list(similarity_data["set1"]) + list(similarity_data["set2"]))
2998            )
2999
3000            mm = math.ceil(len(sets) / 2)
3001
3002            x_s = sets[0:mm]
3003            y_s = sets[mm : len(sets)]
3004
3005            similarity_data = similarity_data[similarity_data["set1"].isin(x_s)]
3006            similarity_data = similarity_data[similarity_data["set2"].isin(y_s)]
3007
3008            similarity_data = similarity_data.sort_values(["set1", "set2"])
3009
3010        if set_info is False and " # " in similarity_data["cell1"].iloc[0]:
3011            similarity_data["cell1"] = [
3012                re.sub(" #.*", "", x) for x in similarity_data["cell1"]
3013            ]
3014            similarity_data["cell2"] = [
3015                re.sub(" #.*", "", x) for x in similarity_data["cell2"]
3016            ]
3017
3018        similarity_data["-euclidean_zscore"] = -zscore(
3019            similarity_data["euclidean_dist"]
3020        )
3021
3022        similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0]
3023
3024        fig = plt.figure(figsize=(width, height))
3025        sns.scatterplot(
3026            data=similarity_data,
3027            x="cell1",
3028            y="cell2",
3029            hue="correlation",
3030            size="-euclidean_zscore",
3031            sizes=(1, 100),
3032            palette=cmap,
3033            alpha=1,
3034            edgecolor="black",
3035        )
3036
3037        plt.xticks(rotation=90)
3038        plt.yticks(rotation=0)
3039        plt.xlabel("Cell 1")
3040        plt.ylabel("Cell 2")
3041        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
3042
3043        plt.grid(True, alpha=0.6)
3044
3045        plt.tight_layout()
3046
3047        return fig

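A hedged usage sketch (assumes `cl.estimating_similarity()` has already been run on a hypothetical Clustering instance `cl`):

    # scatter of cross-set pairs: hue = correlation, size = closeness
    fig = cl.similarity_plot(split_sets=True, set_info=True, cmap="seismic")
    fig.savefig("similarity_scatter.png", dpi=300, bbox_inches="tight")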

def spatial_similarity( self, set_info: bool = True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=100, spread=1.0, set_op_mix_ratio=1.0, local_connectivity=1, repulsion_strength=1.0, negative_sample_rate=5, threshold=0.1, width=12, height=10):
3049    def spatial_similarity(
3050        self,
3051        set_info: bool = True,
3052        bandwidth=1,
3053        n_neighbors=5,
3054        min_dist=0.1,
3055        legend_split=2,
3056        point_size=100,
3057        spread=1.0,
3058        set_op_mix_ratio=1.0,
3059        local_connectivity=1,
3060        repulsion_strength=1.0,
3061        negative_sample_rate=5,
3062        threshold=0.1,
3063        width=12,
3064        height=10,
3065    ):
3066        """
3067        Create a spatial UMAP-like visualization of similarity relationships between samples.
3068
3069        Parameters
3070        ----------
3071        set_info : bool, default True
3072            If True, retain set information in labels.
3073
3074        bandwidth : float, default 1
3075            Bandwidth used by MeanShift for clustering polygons.
3076
3077        point_size : float, default 100
3078            Size of scatter points.
3079
3080        legend_split : int, default 2
3081            Number of columns in legend.
3082
3083        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.
3084
3085        threshold : float, default 0.1
3086            Minimum text distance for label adjustment to avoid overlap.
3087
3088        width : int, default 12
3089            Figure width.
3090
3091        height : int, default 10
3092            Figure height.
3093
3094        Returns
3095        -------
3096        matplotlib.figure.Figure
3097
3098        Raises
3099        ------
3100        ValueError
3101            If `self.similarity` is None.
3102
3103        Notes
3104        -----
3105        Builds a precomputed distance matrix combining correlation and euclidean distance,
3106        runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull)
3107        and arrows to indicate nearest neighbors (minimal combined distance).
3108        """
3109
3110        if self.similarity is None:
3111            raise ValueError(
3112                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
3113            )
3114
3115        similarity_data = self.similarity.copy()
3116
3117        sim = similarity_data["correlation"]
3118        sim_scaled = (sim - sim.min()) / (sim.max() - sim.min())
3119        eu_dist = similarity_data["euclidean_dist"]
3120        eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min())
3121
3122        similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled
3123
3124        # nearest-neighbour arrow targets: for each cell, the partner
3125        # with the smallest combined distance
3126        arrow_df = similarity_data.loc[
3127            similarity_data.groupby("cell1")["combo_dist"].idxmin()
3128        ]
3129
3130        cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"]))
3131        combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float)
3132
3133        for _, row in similarity_data.iterrows():
3134            combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"]
3135            combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"]
3136
3137        umap_model = umap.UMAP(
3138            n_components=2,
3139            metric="precomputed",
3140            n_neighbors=n_neighbors,
3141            min_dist=min_dist,
3142            spread=spread,
3143            set_op_mix_ratio=set_op_mix_ratio,
3144            local_connectivity=local_connectivity,
3145            repulsion_strength=repulsion_strength,
3146            negative_sample_rate=negative_sample_rate,
3147            transform_seed=42,
3148            init="spectral",
3149            random_state=42,
3150            verbose=True,
3151        )
3152
3153        coords = umap_model.fit_transform(combo_matrix.values)
3154        cell_names = list(combo_matrix.index)
3155        num_cells = len(cell_names)
3156        palette = sns.color_palette("tab20c", num_cells)
3157
3158        if "#" in cell_names[0]:
3159            avsets = set(
3160                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
3161                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
3162            )
3163            num_sets = len(avsets)
3164            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
3165            color_mapping_sets = {
3166                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
3167            }
3168            color_mapping = {
3169                name: color_mapping_sets[re.sub(".*# ", "", name)]
3170                for name in cell_names
3171            }
3172        else:
3173            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}
3174
3175        meanshift = MeanShift(bandwidth=bandwidth)
3176        labels = meanshift.fit_predict(coords)
3177
3178        fig = plt.figure(figsize=(width, height))
3179        ax = plt.gca()
3180
3181        unique_labels = set(labels)
3182        cluster_palette = sns.color_palette("hls", len(unique_labels))
3183
3184        for label in unique_labels:
3185            if label == -1:
3186                continue
3187            cluster_coords = coords[labels == label]
3188            if len(cluster_coords) < 3:
3189                continue
3190
3191            hull = ConvexHull(cluster_coords)
3192            hull_points = cluster_coords[hull.vertices]
3193
3194            centroid = np.mean(hull_points, axis=0)
3195            expanded = hull_points + 0.05 * (hull_points - centroid)
3196
3197            poly = Polygon(
3198                expanded,
3199                closed=True,
3200                facecolor=cluster_palette[label],
3201                edgecolor="none",
3202                alpha=0.2,
3203                zorder=1,
3204            )
3205            ax.add_patch(poly)
3206
3207        texts = []
3208        for i, (x, y) in enumerate(coords):
3209            plt.scatter(
3210                x,
3211                y,
3212                s=point_size,
3213                color=color_mapping[cell_names[i]],
3214                edgecolors="black",
3215                linewidths=0.5,
3216                zorder=2,
3217            )
3218            texts.append(
3219                ax.text(
3220                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
3221                )
3222            )
3223
3224        def dist(p1, p2):
3225            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
3226
3227        texts_to_adjust = []
3228        for i, t1 in enumerate(texts):
3229            for j, t2 in enumerate(texts):
3230                if i >= j:
3231                    continue
3232                d = dist(
3233                    (t1.get_position()[0], t1.get_position()[1]),
3234                    (t2.get_position()[0], t2.get_position()[1]),
3235                )
3236                if d < threshold:
3237                    if t1 not in texts_to_adjust:
3238                        texts_to_adjust.append(t1)
3239                    if t2 not in texts_to_adjust:
3240                        texts_to_adjust.append(t2)
3241
3242        adjust_text(
3243            texts_to_adjust,
3244            expand_text=(1.0, 1.0),
3245            force_text=0.9,
3246            arrowprops=dict(arrowstyle="-", color="gray", lw=0.1),
3247            ax=ax,
3248        )
3249
3250        for _, row in arrow_df.iterrows():
3251            try:
3252                idx1 = cell_names.index(row["cell1"])
3253                idx2 = cell_names.index(row["cell2"])
3254            except ValueError:
3255                continue
3256            x1, y1 = coords[idx1]
3257            x2, y2 = coords[idx2]
3258            arrow = FancyArrowPatch(
3259                (x1, y1),
3260                (x2, y2),
3261                arrowstyle="->",
3262                color="gray",
3263                linewidth=1.5,
3264                alpha=0.5,
3265                mutation_scale=12,
3266                zorder=0,
3267            )
3268            ax.add_patch(arrow)
3269
3270        if set_info is False and " # " in cell_names[0]:
3271
3272            legend_elements = [
3273                Patch(
3274                    facecolor=color_mapping[name],
3275                    edgecolor="black",
3276                    label=f"{i}) {re.sub(' #.*', '', name)}",
3277                )
3278                for i, name in enumerate(cell_names)
3279            ]
3280
3281        else:
3282
3283            legend_elements = [
3284                Patch(
3285                    facecolor=color_mapping[name],
3286                    edgecolor="black",
3287                    label=f"{i}) {name}",
3288                )
3289                for i, name in enumerate(cell_names)
3290            ]
3291
3292        plt.legend(
3293            handles=legend_elements,
3294            title="Cells",
3295            bbox_to_anchor=(1.05, 1),
3296            loc="upper left",
3297            ncol=legend_split,
3298        )
3299
3300        plt.xlabel("UMAP 1")
3301        plt.ylabel("UMAP 2")
3302        plt.grid(False)
3303        plt.show()
3304
3305        return fig

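The combined distance from the Notes can be illustrated on its own; the snippet below is a toy sketch of that formula, followed by a hypothetical call on an existing Clustering instance `cl`:

    import pandas as pd

    # combo_dist = (1 - minmax(correlation)) * minmax(euclidean_dist)
    df = pd.DataFrame({"correlation": [0.9, 0.2, -0.5],
                       "euclidean_dist": [1.0, 4.0, 9.0]})
    sim = (df["correlation"] - df["correlation"].min()) / (
        df["correlation"].max() - df["correlation"].min()
    )
    eu = (df["euclidean_dist"] - df["euclidean_dist"].min()) / (
        df["euclidean_dist"].max() - df["euclidean_dist"].min()
    )
    df["combo_dist"] = (1 - sim) * eu  # 0 for the most similar, closest pair

    fig = cl.spatial_similarity(bandwidth=1, n_neighbors=5, min_dist=0.1)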

def subcluster_prepare(self, features: list, cluster: str):
3309    def subcluster_prepare(self, features: list, cluster: str):
3310        """
3311        Prepare a `Clustering` object for subcluster analysis on a selected parent cluster.
3312
3313        Parameters
3314        ----------
3315        features : list
3316            Features to include for subcluster analysis.
3317
3318        cluster : str
3319            Parent cluster name (used to select matching cells).
3320
3321        Update
3322        ------
3323        Initializes `self.subclusters_` as a new `Clustering` instance containing the
3324        reduced data for the given cluster and stores `current_features` and `current_cluster`.
3325        """
3326
3327        dat = self.normalized_data.copy()
3328        dat.columns = list(self.input_metadata["cell_names"])
3329
3330        dat = reduce_data(dat, features=features, names=[cluster])
3331
3332        self.subclusters_ = Clustering(data=dat, metadata=None)
3333
3334        self.subclusters_.current_features = features
3335        self.subclusters_.current_cluster = cluster

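A minimal sketch (hypothetical cluster and marker names):

    markers = ["GeneA", "GeneB", "GeneC"]   # placeholder feature names
    cl.subcluster_prepare(features=markers, cluster="Cluster_3")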

def define_subclusters( self, umap_num: int = 2, eps: float = 0.5, min_samples: int = 10, n_neighbors: int = 5, min_dist: float = 0.1, spread: float = 1.0, set_op_mix_ratio: float = 1.0, local_connectivity: int = 1, repulsion_strength: float = 1.0, negative_sample_rate: int = 5, width=8, height=6):
3337    def define_subclusters(
3338        self,
3339        umap_num: int = 2,
3340        eps: float = 0.5,
3341        min_samples: int = 10,
3342        n_neighbors: int = 5,
3343        min_dist: float = 0.1,
3344        spread: float = 1.0,
3345        set_op_mix_ratio: float = 1.0,
3346        local_connectivity: int = 1,
3347        repulsion_strength: float = 1.0,
3348        negative_sample_rate: int = 5,
3349        width=8,
3350        height=6,
3351    ):
3352        """
3353        Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.
3354
3355        Parameters
3356        ----------
3357        umap_num : int, default 2
3358            Number of UMAP dimensions to compute.
3359
3360        eps : float, default 0.5
3361            DBSCAN eps parameter.
3362
3363        min_samples : int, default 10
3364            DBSCAN min_samples parameter.
3365
3366        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
3367            Additional parameters passed to UMAP / plotting / MeanShift as appropriate.
3368
3369        Update
3370        ------
3371        Stores cluster labels in `self.subclusters_.subclusters`.
3372
3373        Raises
3374        ------
3375        RuntimeError
3376            If `self.subclusters_` has not been prepared.
3377        """
3378
3379        if self.subclusters_ is None:
3380            raise RuntimeError(
3381                "Nothing to return. 'self.subcluster_prepare' was not run!"
3382            )
3383
3384        self.subclusters_.perform_UMAP(
3385            factorize=False,
3386            umap_num=umap_num,
3387            pc_num=0,
3388            harmonized=False,
3389            n_neighbors=n_neighbors,
3390            min_dist=min_dist,
3391            spread=spread,
3392            set_op_mix_ratio=set_op_mix_ratio,
3393            local_connectivity=local_connectivity,
3394            repulsion_strength=repulsion_strength,
3395            negative_sample_rate=negative_sample_rate,
3396            width=width,
3397            height=height,
3398        )
3399
3400        fig = self.subclusters_.find_clusters_UMAP(
3401            umap_n=umap_num,
3402            eps=eps,
3403            min_samples=min_samples,
3404            width=width,
3405            height=height,
3406        )
3407
3408        clusters = self.subclusters_.return_clusters(clusters="umap")
3409
3410        self.subclusters_.subclusters = [str(x) for x in list(clusters)]
3411
3412        return fig

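A hedged usage sketch, continuing the `subcluster_prepare` example above:

    fig = cl.define_subclusters(umap_num=2, eps=0.5, min_samples=10)

    # labels are stored as strings, e.g. {'0', '1', '-1'}
    print(set(cl.subclusters_.subclusters))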

def subcluster_features_scatter( self, colors='viridis', hclust='complete', scale=False, img_width=3, img_high=5, label_size=6, size_scale=70, y_lab='Genes', legend_lab='normalized', bbox_to_anchor_scale: int = 25, bbox_to_anchor_perc: tuple = (0.91, 0.63)):
3414    def subcluster_features_scatter(
3415        self,
3416        colors="viridis",
3417        hclust="complete",
3418        scale=False,
3419        img_width=3,
3420        img_high=5,
3421        label_size=6,
3422        size_scale=70,
3423        y_lab="Genes",
3424        legend_lab="normalized",
3425        bbox_to_anchor_scale: int = 25,
3426        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3427    ):
3428        """
3429        Create a features-scatter visualization for the subclusters (averaged and occurrence).
3430
3431        Parameters
3432        ----------
3433        colors : str, default 'viridis'
3434            Colormap name passed to `features_scatter`.
3435
3436        hclust : str or None
3437            Hierarchical clustering linkage to order rows/columns.
3438
3439        scale: bool, default False
3440            If True, expression data will be scaled (0–1) across the rows (features).
3441
3442        img_width, img_high : float
3443            Figure size.
3444
3445        label_size : int
3446            Font size for labels.
3447
3448        size_scale : int
3449            Bubble size scaling.
3450
3451        y_lab : str
3452            X axis label.
3453
3454        legend_lab : str
3455            Colorbar label.
3456
3457        bbox_to_anchor_scale : int, default=25
3458            Vertical scale (percentage) for positioning the colorbar.
3459
3460        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3461            Anchor position for the size legend (percent bubble legend).
3462
3463        Returns
3464        -------
3465        matplotlib.figure.Figure
3466
3467        Raises
3468        ------
3469        RuntimeError
3470            If subcluster preparation/definition has not been run.
3471        """
3472
3473        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3474            raise RuntimeError(
3475                "Nothing to return. The 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not run!"
3476            )
3477
3478        dat = self.normalized_data.copy()
3479        dat.columns = list(self.input_metadata["cell_names"])
3480
3481        dat = reduce_data(
3482            dat,
3483            features=self.subclusters_.current_features,
3484            names=[self.subclusters_.current_cluster],
3485        )
3486
3487        dat.columns = self.subclusters_.subclusters
3488
3489        avg = average(dat)
3490        occ = occurrence(dat)
3491
3492        scatter = features_scatter(
3493            expression_data=avg,
3494            occurence_data=occ,
3495            features=None,
3496            scale=scale,
3497            metadata_list=None,
3498            colors=colors,
3499            hclust=hclust,
3500            img_width=img_width,
3501            img_high=img_high,
3502            label_size=label_size,
3503            size_scale=size_scale,
3504            y_lab=y_lab,
3505            legend_lab=legend_lab,
3506            bbox_to_anchor_scale=bbox_to_anchor_scale,
3507            bbox_to_anchor_perc=bbox_to_anchor_perc,
3508        )
3509
3510        return scatter

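A hedged usage sketch (assumes the prepare/define steps above were run):

    fig = cl.subcluster_features_scatter(colors="viridis",
                                         hclust="complete", scale=True)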

def subcluster_DEG_scatter( self, top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', hclust='complete', scale=False, img_width=3, img_high=5, label_size=6, size_scale=70, y_lab='Genes', legend_lab='normalized', bbox_to_anchor_scale: int = 25, bbox_to_anchor_perc: tuple = (0.91, 0.63), n_proc=10):
3512    def subcluster_DEG_scatter(
3513        self,
3514        top_n=3,
3515        min_exp=0,
3516        min_pct=0.25,
3517        p_val=0.05,
3518        colors="viridis",
3519        hclust="complete",
3520        scale=False,
3521        img_width=3,
3522        img_high=5,
3523        label_size=6,
3524        size_scale=70,
3525        y_lab="Genes",
3526        legend_lab="normalized",
3527        bbox_to_anchor_scale: int = 25,
3528        bbox_to_anchor_perc: tuple = (0.91, 0.63),
3529        n_proc=10,
3530    ):
3531        """
3532        Plot top differential features (DEGs) for subclusters as a features-scatter.
3533
3534        Parameters
3535        ----------
3536        top_n : int, default 3
3537            Number of top features per subcluster to show.
3538
3539        min_exp : float, default 0
3540            Minimum expression threshold passed to the DEG statistics (`calc_DEG`).
3541
3542        min_pct : float, default 0.25
3543            Minimum percent expressed in target group.
3544
3545        p_val: float, default 0.05
3546            Maximum p-value for visualizing features.
3547
3548        n_proc : int, default 10
3549            Parallel jobs used for DEG calculation.
3550
3551        scale: bool, default False
3552            If True, expression_data will be scaled (0–1) across the rows (features).
3553
3554        colors : str, default='viridis'
3555            Colormap for expression values.
3556
3557        hclust : str or None, default='complete'
3558            Linkage method for hierarchical clustering. If None, no clustering
3559            is performed.
3560
3561        img_width : int or float, default=3
3562            Width of the plot in inches.
3563
3564        img_high : int or float, default=5
3565            Height of the plot in inches.
3566
3567        label_size : int, default=6
3568            Font size for axis labels and ticks.
3569
3570        size_scale : int or float, default=70
3571            Scaling factor for bubble sizes.
3572
3573        y_lab : str, default='Genes'
3574            Label for the x-axis.
3575
3576        legend_lab : str, default='normalized'
3577            Label for the colorbar legend.
3578
3579        bbox_to_anchor_scale : int, default=25
3580            Vertical scale (percentage) for positioning the colorbar.
3581
3582        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
3583            Anchor position for the size legend (percent bubble legend).
3584
3585        Returns
3586        -------
3587        matplotlib.figure.Figure
3588
3589        Raises
3590        ------
3591        RuntimeError
3592            If subcluster preparation/definition has not been run.
3593
3594        Notes
3595        -----
3596        Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters
3597        by p-value and effect-size, selects top features per valid group and plots them.
3598        """
3599
3600        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3601            raise RuntimeError(
3602                "Nothing to return. The 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not run!"
3603            )
3604
3605        dat = self.normalized_data.copy()
3606        dat.columns = list(self.input_metadata["cell_names"])
3607
3608        dat = reduce_data(
3609            dat, names=[self.subclusters_.current_cluster]
3610        )
3611
3612        dat.columns = self.subclusters_.subclusters
3613
3614        deg_stats = calc_DEG(
3615            dat,
3616            metadata_list=None,
3617            entities="All",
3618            sets=None,
3619            min_exp=min_exp,
3620            min_pct=min_pct,
3621            n_proc=n_proc,
3622        )
3623
3624        deg_stats = deg_stats[deg_stats["p_val"] <= p_val]
3625        deg_stats = deg_stats[deg_stats["log(FC)"] > 0]
3626
3627        deg_stats = (
3628            deg_stats.sort_values(
3629                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
3630            )
3631            .groupby("valid_group")
3632            .head(top_n)
3633        )
3634
3635        dat = reduce_data(dat, features=list(set(deg_stats["feature"])))
3636
3637        avg = average(dat)
3638        occ = occurrence(dat)
3639
3640        scatter = features_scatter(
3641            expression_data=avg,
3642            occurence_data=occ,
3643            features=None, scale=scale,
3644            metadata_list=None,
3645            colors=colors,
3646            hclust=hclust,
3647            img_width=img_width,
3648            img_high=img_high,
3649            label_size=label_size,
3650            size_scale=size_scale,
3651            y_lab=y_lab,
3652            legend_lab=legend_lab,
3653            bbox_to_anchor_scale=bbox_to_anchor_scale,
3654            bbox_to_anchor_perc=bbox_to_anchor_perc,
3655        )
3656
3657        return scatter

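A hedged usage sketch (assumes the prepare/define steps above were run):

    # top 3 up-regulated features per subcluster at p <= 0.05
    fig = cl.subcluster_DEG_scatter(top_n=3, min_pct=0.25, p_val=0.05,
                                    n_proc=10)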

def accept_subclusters(self):
3659    def accept_subclusters(self):
3660        """
3661        Commit subcluster labels into the main `input_metadata` by renaming cell names.
3662
3663        The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']`
3664        with the expanded names that include subcluster suffixes (via `add_subnames`),
3665        then clears `self.subclusters_`.
3666
3667        Update
3668        ------
3669        Modifies `self.input_metadata['cell_names']`.
3670
3671        Resets `self.subclusters_` to None.
3672
3673        Raises
3674        ------
3675        RuntimeError
3676            If `self.subclusters_` is not defined or subclusters were not computed.
3677        """
3678
3679        if self.subclusters_ is None or self.subclusters_.subclusters is None:
3680            raise RuntimeError(
3681                "Nothing to return. The 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not run!"
3682            )
3683
3684        new_meta = add_subnames(
3685            list(self.input_metadata["cell_names"]),
3686            parent_name=self.subclusters_.current_cluster,
3687            new_clusters=self.subclusters_.subclusters,
3688        )
3689
3690        self.input_metadata["cell_names"] = new_meta
3691
3692        self.subclusters_ = None

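The full subclustering round-trip, as a hypothetical sketch (placeholder names):

    markers = ["GeneA", "GeneB", "GeneC"]
    cl.subcluster_prepare(features=markers, cluster="Cluster_3")
    cl.define_subclusters(eps=0.5, min_samples=10)
    cl.accept_subclusters()  # 'Cluster_3' cells now carry subcluster suffixes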

def scatter_plot( self, names: list | None = None, features: list | None = None, name_slot: str = 'cell_names', scale=True, colors='viridis', hclust=None, img_width=15, img_high=1, label_size=10, size_scale=200, y_lab='Genes', legend_lab='log(CPM + 1)', set_box_size: float | int = 5, set_box_high: float | int = 5, bbox_to_anchor_scale=25, bbox_to_anchor_perc=(0.9, -0.24), bbox_to_anchor_group=(1.01, 0.4)):
3694    def scatter_plot(
3695        self,
3696        names: list | None = None,
3697        features: list | None = None,
3698        name_slot: str = "cell_names",
3699        scale=True,
3700        colors="viridis",
3701        hclust=None,
3702        img_width=15,
3703        img_high=1,
3704        label_size=10,
3705        size_scale=200,
3706        y_lab="Genes",
3707        legend_lab="log(CPM + 1)",
3708        set_box_size: float | int = 5,
3709        set_box_high: float | int = 5,
3710        bbox_to_anchor_scale=25,
3711        bbox_to_anchor_perc=(0.90, -0.24),
3712        bbox_to_anchor_group=(1.01, 0.4),
3713    ):
3714        """
3715        Create a bubble scatter plot of selected features across the samples in the project.
3716
3717        Each point represents a feature-sample pair, where the color encodes the
3718        expression value and the size encodes occurrence or relative abundance.
3719        Optionally, hierarchical clustering can be applied to order rows and columns.
3720
3721        Parameters
3722        ----------
3723        names : list, str, or None
3724            Names of samples to include. If None, all samples are considered.
3725
3726        features : list, str, or None
3727            Names of features to include. If None, all features are considered.
3728
3729        name_slot : str
3730            Column in metadata to use as sample names.
3731
3732        scale : bool, default True
3733            If True, expression_data will be scaled (0–1) across the rows (features).
3734
3735        colors : str, default='viridis'
3736            Colormap for expression values.
3737
3738        hclust : str or None, default=None
3739            Linkage method for hierarchical clustering. If None, no clustering
3740            is performed.
3741
3742        img_width : int or float, default=15
3743            Width of the plot in inches.
3744
3745        img_high : int or float, default=1
3746            Height of the plot in inches.
3747
3748        label_size : int, default=10
3749            Font size for axis labels and ticks.
3750
3751        size_scale : int or float, default=200
3752            Scaling factor for bubble sizes.
3753
3754        y_lab : str, default='Genes'
3755            Label for the x-axis.
3756
3757        legend_lab : str, default='log(CPM + 1)'
3758            Label for the colorbar legend.
3759
3760        bbox_to_anchor_scale : int, default=25
3761            Vertical scale (percentage) for positioning the colorbar.
3762
3763        bbox_to_anchor_perc : tuple, default=(0.90, -0.24)
3764            Anchor position for the size legend (percent bubble legend).
3765
3766        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
3767            Anchor position for the group legend.
3768
3769        Returns
3770        -------
3771        matplotlib.figure.Figure
3772            The generated scatter plot figure.
3773
3774        Notes
3775        -----
3776        Colors represent expression values normalized to the colormap.
3777        """
3778
3779        prtd, met = self.get_partial_data(
3780            names=names, features=features, name_slot=name_slot, inc_metadata=True
3781        )
3782
3783        if scale:
3784
3785            legend_lab = "Scaled\n" + legend_lab
3786
3787            scaler = MinMaxScaler(feature_range=(0, 1))
3788            prtd = pd.DataFrame(
3789                scaler.fit_transform(prtd.T).T,
3790                index=prtd.index,
3791                columns=prtd.columns,
3792            )
3793
3794        prtd.columns = prtd.columns + "#" + met["sets"]
3795
3796        prtd_avg = average(prtd)
3797
3798        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]
3799
3800        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]
3801
3802        prtd_occ = occurrence(prtd)
3803
3804        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]
3805
3806        fig_scatter = features_scatter(
3807            expression_data=prtd_avg,
3808            occurence_data=prtd_occ,
3809            scale=scale,
3810            features=None,
3811            metadata_list=meta_sets,
3812            colors=colors,
3813            hclust=hclust,
3814            img_width=img_width,
3815            img_high=img_high,
3816            label_size=label_size,
3817            size_scale=size_scale,
3818            y_lab=y_lab,
3819            legend_lab=legend_lab,
3820            set_box_size=set_box_size,
3821            set_box_high=set_box_high,
3822            bbox_to_anchor_scale=bbox_to_anchor_scale,
3823            bbox_to_anchor_perc=bbox_to_anchor_perc,
3824            bbox_to_anchor_group=bbox_to_anchor_group,
3825        )
3826
3827        return fig_scatter

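A hedged usage sketch (placeholder feature names):

    fig = cl.scatter_plot(features=["GeneA", "GeneB"], scale=True,
                          hclust="complete")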

def data_composition( self, features_count: list | None, name_slot: str = 'cell_names', set_sep: bool = True):
3829    def data_composition(
3830        self,
3831        features_count: list | None,
3832        name_slot: str = "cell_names",
3833        set_sep: bool = True,
3834    ):
3835        """
3836        Compute the composition of cell types in the data set.
3837
3838        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
3839        within metadata entries, calculates their relative percentages, and stores
3840        the results in `self.composition_data`.
3841
3842        Parameters
3843        ----------
3844        features_count : list or None
3845            List of features (part or full names) to be counted.
3846            If None, all unique elements from the specified `name_slot` metadata field are used.
3847
3848        name_slot : str, default 'cell_names'
3849            Metadata field containing sample identifiers or labels.
3850
3851        set_sep : bool, default True
3852            If True and multiple sets are present in metadata, compute composition
3853            separately for each set.
3854
3855        Update
3856        ------
3857        Stores results in `self.composition_data` as a pandas DataFrame with:
3858        - 'name': feature name
3859        - 'n': number of occurrences
3860        - 'pct': percentage of occurrences
3861        - 'set' (if applicable): dataset identifier
3862        """
3863
3864        validated_list = list(self.input_metadata[name_slot])
3865        sets = list(self.input_metadata["sets"])
3866
3867        if features_count is None:
3868            features_count = list(set(self.input_metadata[name_slot]))
3869
3870        if set_sep and len(set(sets)) > 1:
3871
3872            final_res = pd.DataFrame()
3873
3874            for s in set(sets):
3875                print(s)
3876
3877                mask = [x == s for x in sets]
3878
3879                tmp_val_list = np.array(validated_list)
3880
3881                tmp_val_list = list(tmp_val_list[mask])
3882
3883                res_dict = {"name": [], "n": [], "set": []}
3884
3885                for f in tqdm(features_count):
3886                    res_dict["n"].append(
3887                        sum(1 for element in tmp_val_list if f in element)
3888                    )
3889                    res_dict["name"].append(f)
3890                    res_dict["set"].append(s)
3891
3892                # build the per-set table once, after all features are counted
3893                res = pd.DataFrame(res_dict)
3894                res["pct"] = (res["n"] / sum(res["n"]) * 100).round(2)
3895                final_res = pd.concat([final_res, res])
3896
3897            res = final_res.sort_values(["set", "pct"], ascending=[True, False])
3898
3899        else:
3900
3901            res_dict = {"name": [], "n": []}
3902
3903            for f in tqdm(features_count):
3904                res_dict["n"].append(
3905                    sum(1 for element in validated_list if f in element)
3906                )
3907                res_dict["name"].append(f)
3908
3909            res = pd.DataFrame(res_dict)
3910            res["pct"] = res["n"] / sum(res["n"]) * 100
3911            res["pct"] = res["pct"].round(2)
3912
3913            res = res.sort_values("pct", ascending=False)
3914
3915        self.composition_data = res

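A hedged usage sketch (count every label found in 'cell_names', per set):

    cl.data_composition(features_count=None, name_slot="cell_names",
                        set_sep=True)
    print(cl.composition_data.head())
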
def composition_pie( self, width=6, height=6, font_size=15, cmap: str = 'tab20', legend_split_col: int = 1, offset_labels: float | int = 0.5, legend_bbox: tuple = (1.15, 0.95)):
3917    def composition_pie(
3918        self,
3919        width=6,
3920        height=6,
3921        font_size=15,
3922        cmap: str = "tab20",
3923        legend_split_col: int = 1,
3924        offset_labels: float | int = 0.5,
3925        legend_bbox: tuple = (1.15, 0.95),
3926    ):
3927        """
3928        Visualize the composition of cell lineages using pie charts.
3929
3930        Generates pie charts showing the relative proportions of features stored
3931        in `self.composition_data`. If multiple sets are present, a separate
3932        chart is drawn for each set.
3933
3934        Parameters
3935        ----------
3936        width : int, default 6
3937            Width of the figure.
3938
3939        height : int, default 6
3940            Height of the figure (applied per set if multiple sets are plotted).
3941
3942        font_size : int, default 15
3943            Font size for labels and annotations.
3944
3945        cmap : str, default 'tab20'
3946            Colormap used for pie slices.
3947
3948        legend_split_col : int, default 1
3949            Number of columns in the legend.
3950
3951        offset_labels : float or int, default 0.5
3952            Spacing offset for label placement relative to pie slices.
3953
3954        legend_bbox : tuple, default (1.15, 0.95)
3955            Bounding box anchor position for the legend.
3956
3957        Returns
3958        -------
3959        matplotlib.figure.Figure
3960            Pie chart visualization of composition data.
3961        """
3962
3963        df = self.composition_data
3964
3965        if "set" in df.columns and len(set(df["set"])) > 1:
3966
3967            sets = list(set(df["set"]))
3968            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))
3969
3970            all_wedges = []
3971            cmap = plt.get_cmap(cmap)
3972
3973            set_nam = len(set(df["name"]))
3974
3975            legend_labels = list(set(df["name"]))
3976
3977            colors = [cmap(i / set_nam) for i in range(set_nam)]
3978
3979            cmap_dict = dict(zip(legend_labels, colors))
3980
3981            for idx, s in enumerate(sets):
3982                ax = axes[idx]
3983                tmp_df = df[df["set"] == s].reset_index(drop=True)
3984
3985                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]
3986
3987                wedges, _ = ax.pie(
3988                    tmp_df["n"],
3989                    startangle=90,
3990                    labeldistance=1.05,
3991                    colors=[cmap_dict[x] for x in tmp_df["name"]],
3992                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
3993                )
3994
3995                all_wedges.extend(wedges)
3996
3997                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
3998                n = 0
3999                for i, p in enumerate(wedges):
4000                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
4001                    y = np.sin(np.deg2rad(ang))
4002                    x = np.cos(np.deg2rad(ang))
4003                    horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
4004                    connectionstyle = f"angle,angleA=0,angleB={ang}"
4005                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
4006                    if len(labels[i]) > 0:
4007                        n += offset_labels
4008                        ax.annotate(
4009                            labels[i],
4010                            xy=(x, y),
4011                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
4012                            horizontalalignment=horizontalalignment,
4013                            fontsize=font_size,
4014                            weight="bold",
4015                            **kw,
4016                        )
4017
4018                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
4019                ax.add_artist(circle2)
4020
4021                ax.text(
4022                    0,
4023                    0,
4024                    f"{s}",
4025                    ha="center",
4026                    va="center",
4027                    fontsize=font_size,
4028                    weight="bold",
4029                )
4030
4031            legend_handles = [
4032                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4033                for label in legend_labels
4034            ]
4035
4036            fig.legend(
4037                handles=legend_handles,
4038                loc="center right",
4039                bbox_to_anchor=legend_bbox,
4040                ncol=legend_split_col,
4041                title="",
4042            )
4043
4044            plt.tight_layout()
4045            plt.show()
4046
4047        else:
4048
4049            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]
4050
4051            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]
4052
4053            cmap = plt.get_cmap(cmap)
4054            colors = [cmap(i / len(df)) for i in range(len(df))]
4055
4056            fig, ax = plt.subplots(
4057                figsize=(width, height), subplot_kw=dict(aspect="equal")
4058            )
4059
4060            wedges, _ = ax.pie(
4061                df["n"],
4062                startangle=90,
4063                labeldistance=1.05,
4064                colors=colors,
4065                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
4066            )
4067
4068            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
4069            n = 0
4070            for i, p in enumerate(wedges):
4071                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
4072                y = np.sin(np.deg2rad(ang))
4073                x = np.cos(np.deg2rad(ang))
4074                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
4075                connectionstyle = "angle,angleA=0,angleB={}".format(ang)
4076                kw["arrowprops"].update({"connectionstyle": connectionstyle})
4077                if len(labels[i]) > 0:
4078                    n += offset_labels
4079
4080                    ax.annotate(
4081                        labels[i],
4082                        xy=(x, y),
4083                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
4084                        horizontalalignment=horizontalalignment,
4085                        fontsize=font_size,
4086                        weight="bold",
4087                        **kw,
4088                    )
4089
4090            # donut hole, matching the multi-set branch above
4091            circle2 = plt.Circle((0, 0), 0.6, color="white")
4092            circle2.set_edgecolor("black")
4093
4094            ax.add_artist(circle2)
4095
4096            ax.legend(
4097                wedges,
4098                legend_labels,
4099                title="",
4100                loc="center left",
4101                bbox_to_anchor=legend_bbox,
4102                ncol=legend_split_col,
4103            )
4104
4105            plt.show()
4106
4107        return fig

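A hedged usage sketch (assumes `cl.data_composition(...)` was run first):

    fig = cl.composition_pie(cmap="tab20", legend_split_col=2)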

def bar_composition( self, cmap='tab20b', width=2, height=6, font_size=15, legend_split_col: int = 1, legend_bbox: tuple = (1.3, 1)):
4109    def bar_composition(
4110        self,
4111        cmap="tab20b",
4112        width=2,
4113        height=6,
4114        font_size=15,
4115        legend_split_col: int = 1,
4116        legend_bbox: tuple = (1.3, 1),
4117    ):
4118        """
4119        Visualize the composition of cell lineages using bar plots.
4120
4121        Produces bar plots showing the distribution of features stored in
4122        `self.composition_data`. If multiple sets are present, a separate
4123        bar is drawn for each set. Percentages are annotated alongside the bars.
4124
4125        Parameters
4126        ----------
4127        cmap : str, default 'tab20b'
4128            Colormap used for stacked bars.
4129
4130        width : int, default 2
4131            Width of each subplot (per set).
4132
4133        height : int, default 6
4134            Height of the figure.
4135
4136        font_size : int, default 15
4137            Font size for labels and annotations.
4138
4139        legend_split_col : int, default 1
4140            Number of columns in the legend.
4141
4142        legend_bbox : tuple, default (1.3, 1)
4143            Bounding box anchor position for the legend.
4144
4145        Returns
4146        -------
4147        matplotlib.figure.Figure
4148            Stacked bar plot visualization of composition data.
4149        """
4150
4151        df = self.composition_data
4152        df["num"] = range(1, len(df) + 1)
4153
4154        if "set" in df.columns and len(set(df["set"])) > 1:
4155
4156            sets = list(set(df["set"]))
4157            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))
4158
4159            cmap = plt.get_cmap(cmap)
4160
4161            set_nam = len(set(df["name"]))
4162
4163            legend_labels = list(set(df["name"]))
4164
4165            colors = [cmap(i / set_nam) for i in range(set_nam)]
4166
4167            cmap_dict = dict(zip(legend_labels, colors))
4168
4169            for idx, s in enumerate(sets):
4170                ax = axes[idx]
4171
4172                tmp_df = df[df["set"] == s].reset_index(drop=True)
4173
4174                values = tmp_df["n"].values
4175                total = sum(values)
4176                values = [v / total * 100 for v in values]
4177                values = [round(v, 2) for v in values]
4178
4179                idx_max = np.argmax(values)
4180                correction = 100 - sum(values)
4181                values[idx_max] += correction
4182
4183                names = tmp_df["name"].values
4184                perc = tmp_df["pct"].values
4185                nums = tmp_df["num"].values
4186
4187                bottom = 0
4188                centers = []
4189                for name, num, val, color in zip(names, nums, values, colors):
4190                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
4191                    centers.append(bottom + val / 2)
4192                    bottom += val
4193
4194                y_positions = np.linspace(centers[0], centers[-1], len(centers))
4195                x_text = -0.8
4196
4197                for y_label, y_center, pct, num in zip(
4198                    y_positions, centers, perc, nums
4199                ):
4200                    ax.annotate(
4201                        f"{pct:.1f}%",
4202                        xy=(0, y_center),
4203                        xycoords="data",
4204                        xytext=(x_text, y_label),
4205                        textcoords="data",
4206                        ha="right",
4207                        va="center",
4208                        fontsize=font_size,
4209                        arrowprops=dict(
4210                            arrowstyle="->",
4211                            lw=1,
4212                            color="black",
4213                            connectionstyle="angle3,angleA=0,angleB=90",
4214                        ),
4215                    )
4216
4217                ax.set_ylim(0, 100)
4218                ax.set_xlabel(s, fontsize=font_size)
4219                ax.xaxis.label.set_rotation(30)
4220
4221                ax.set_xticks([])
4222                ax.set_yticks([])
4223                for spine in ax.spines.values():
4224                    spine.set_visible(False)
4225
4226            legend_handles = [
4227                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
4228                for label in legend_labels
4229            ]
4230
4231            fig.legend(
4232                handles=legend_handles,
4233                loc="center right",
4234                bbox_to_anchor=legend_bbox,
4235                ncol=legend_split_col,
4236                title="",
4237            )
4238
4239            plt.tight_layout()
4240            plt.show()
4241
4242        else:
4243
4244            cmap = plt.get_cmap(cmap)
4245
4246            colors = [cmap(i / len(df)) for i in range(len(df))]
4247
4248            fig, ax = plt.subplots(figsize=(width, height))
4249
4250            values = df["n"].values
4251            names = df["name"].values
4252            perc = df["pct"].values
4253            nums = df["num"].values
4254
4255            bottom = 0
4256            centers = []
4257            for name, num, val, color in zip(names, nums, values, colors):
4258                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
4259                centers.append(bottom + val / 2)
4260                bottom += val
4261
4262            y_positions = np.linspace(centers[0], centers[-1], len(centers))
4263            x_text = -0.8
4264
4265            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
4266                ax.annotate(
4267                    f"{num}) {pct}%",
4268                    xy=(0, y_center),
4269                    xycoords="data",
4270                    xytext=(x_text, y_label),
4271                    textcoords="data",
4272                    ha="right",
4273                    va="center",
4274                    fontsize=font_size,
4275                    arrowprops=dict(
4276                        arrowstyle="->",
4277                        lw=1,
4278                        color="black",
4279                        connectionstyle="angle3,angleA=0,angleB=90",
4280                    ),
4281                )
4282
4283            ax.set_xticks([])
4284            ax.set_yticks([])
4285            for spine in ax.spines.values():
4286                spine.set_visible(False)
4287
4288            ax.legend(
4289                title="Legend",
4290                bbox_to_anchor=legend_bbox,
4291                loc="upper left",
4292                ncol=legend_split_col,
4293            )
4294
4295            plt.tight_layout()
4296            plt.show()
4297
4298        return fig
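
Both branches of this plotting routine place their percentage labels with the same leader-line pattern: record each stacked segment's midpoint while drawing, spread the text anchors evenly with np.linspace so labels on thin segments cannot collide, and connect each label to its segment with an "angle3" elbow arrow. The following self-contained sketch isolates that pattern on toy numbers; the "Lineage A"-"Lineage D" names and the percentages are placeholders, not data from this package:

import matplotlib.pyplot as plt
import numpy as np

# Toy composition: percentages for four hypothetical lineages (sum to 100).
pct = [45.0, 30.0, 15.0, 10.0]
labels = ["Lineage A", "Lineage B", "Lineage C", "Lineage D"]

fig, ax = plt.subplots(figsize=(2, 6))

# Stack the bar and remember each segment's midpoint (the arrow target).
bottom = 0.0
centers = []
for value in pct:
    ax.bar(0, value, bottom=bottom, width=0.8)
    centers.append(bottom + value / 2)
    bottom += value

# Evenly spaced anchors avoid label collisions on thin segments; angle3
# draws an elbow from the label out to the segment midpoint.
y_positions = np.linspace(centers[0], centers[-1], len(centers))
for y_label, y_center, value, label in zip(y_positions, centers, pct, labels):
    ax.annotate(
        f"{label}: {value:.1f}%",
        xy=(0, y_center), xycoords="data",
        xytext=(-0.8, y_label), textcoords="data",
        ha="right", va="center",
        arrowprops=dict(arrowstyle="->", lw=1, color="black",
                        connectionstyle="angle3,angleA=0,angleB=90"),
    )

ax.set_ylim(0, 100)
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
    spine.set_visible(False)
plt.show()

Anchoring the labels at evenly spaced positions, rather than at the segment midpoints themselves, is what keeps the annotations legible when one lineage contributes only a few percent of the total.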

    def cell_regression(
        self,
        cell_x: str,
        cell_y: str,
        set_x: str | None = None,
        set_y: str | None = None,
        threshold=10,
        image_width=12,
        image_high=7,
        color="black",
    ):
        """
        Perform regression analysis between two selected cells and visualize the relationship.

        This function computes a linear regression between two specified cells from
        aggregated normalized data, plots the regression line with scatter points,
        annotates regression statistics, and highlights potential outliers.

        Parameters
        ----------
        cell_x : str
            Name of the first cell (X-axis).

        cell_y : str
            Name of the second cell (Y-axis).

        set_x : str or None, default None
            Dataset identifier corresponding to `cell_x`. If None, the cell is selected by name only.

        set_y : str or None, default None
            Dataset identifier corresponding to `cell_y`. If None, the cell is selected by name only.

        threshold : int or float, default 10
            Threshold for detecting outliers. Points deviating from the mean or the diagonal by more
            than this value are annotated.

        image_width : int, default 12
            Width of the regression plot (in inches).

        image_high : int, default 7
            Height of the regression plot (in inches).

        color : str, default 'black'
            Color of the regression scatter points and line.

        Returns
        -------
        matplotlib.figure.Figure
            Regression plot figure with annotated regression line, R², p-value, and outliers.

        Raises
        ------
        ValueError
            If `cell_x` or `cell_y` is not found in the dataset.
            If multiple matches are found for a cell name and `set_x`/`set_y` are not specified.

        Notes
        -----
        - The function automatically calls `jseq_object.average()` if aggregated data is not available.
        - Outliers are annotated with their corresponding index labels.
        - Regression is computed using `scipy.stats.linregress`.

        Examples
        --------
        >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
        >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")
        """

        if self.agg_normalized_data is None:
            self.average()

        metadata = self.agg_metadata
        # Work on a copy so renaming columns does not mutate the stored table.
        data = self.agg_normalized_data.copy()

        if set_x is not None and set_y is not None:
            data.columns = metadata["cell_names"] + " # " + metadata["sets"]
            cell_x = cell_x + " # " + set_x
            cell_y = cell_y + " # " + set_y

        else:
            data.columns = metadata["cell_names"]

        if cell_x not in data.columns:
            raise ValueError("'cell_x' value not in cell names!")

        if cell_y not in data.columns:
            raise ValueError("'cell_y' value not in cell names!")

        if list(data.columns).count(cell_x) > 1:
            raise ValueError(
                f"'{cell_x}' occurs more than once. If you want to select a specific cell, "
                f"please also provide the corresponding 'set_x' and 'set_y' values."
            )

        if list(data.columns).count(cell_y) > 1:
            raise ValueError(
                f"'{cell_y}' occurs more than once. If you want to select a specific cell, "
                f"please also provide the corresponding 'set_x' and 'set_y' values."
            )

        fig, ax = plt.subplots(figsize=(image_width, image_high))
        sns.regplot(x=cell_x, y=cell_y, data=data, color=color, ax=ax)

        slope, intercept, r_value, p_value, _ = stats.linregress(
            data[cell_x], data[cell_y]
        )
        equation = "y = {:.2f}x + {:.2f}".format(slope, intercept)

        ax.annotate(
            "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format(
                r_value**2, p_value, equation
            ),
            xy=(0.05, 0.90),
            xycoords="axes fraction",
            fontsize=12,
        )

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        def annotate_outliers(x, y, threshold):
            # Label points far from either mean or from the x == y diagonal.
            texts = []
            x_mean, y_mean = x.mean(), y.mean()
            for i, (xi, yi) in enumerate(zip(x, y)):
                if (
                    abs(xi - x_mean) > threshold
                    or abs(yi - y_mean) > threshold
                    or abs(yi - xi) > threshold
                ):
                    text = ax.text(xi, yi, data.index[i])
                    texts.append(text)

            return texts

        texts = annotate_outliers(data[cell_x], data[cell_y], threshold)

        adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

        plt.show()

        return fig
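
The outlier-labelling step can be reproduced in isolation. Below is a minimal sketch on synthetic points (random values and hypothetical "gene_*" labels, nothing from a real dataset) that mirrors what cell_regression does: fit with scipy.stats.linregress, annotate the statistics on the axes, flag points that sit far from either mean or from the x == y diagonal, and let adjustText untangle the overlapping labels:

import matplotlib.pyplot as plt
import numpy as np
from adjustText import adjust_text
from scipy.stats import linregress

# Synthetic pseudo-expression values for two hypothetical cells.
rng = np.random.default_rng(0)
x = rng.normal(50, 15, 40)
y = x + rng.normal(0, 5, 40)
y[3] += 30  # plant one obvious outlier
labels = [f"gene_{i}" for i in range(len(x))]

fig, ax = plt.subplots(figsize=(12, 7))
ax.scatter(x, y, color="black", s=15)

# Fit and draw the regression line, annotating the same statistics
# that cell_regression reports.
res = linregress(x, y)
xs = np.linspace(x.min(), x.max(), 2)
ax.plot(xs, res.slope * xs + res.intercept, color="black")
ax.annotate(
    f"R-squared = {res.rvalue**2:.2f}\nP-value = {res.pvalue:.2f}\n"
    f"y = {res.slope:.2f}x + {res.intercept:.2f}",
    xy=(0.05, 0.90), xycoords="axes fraction", fontsize=12,
)

# Flag points far from either mean or from the diagonal, then let
# adjustText push the labels apart.
threshold = 10
texts = [
    ax.text(xi, yi, label)
    for xi, yi, label in zip(x, y, labels)
    if abs(xi - x.mean()) > threshold
    or abs(yi - y.mean()) > threshold
    or abs(yi - xi) > threshold
]
adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))
plt.show()

Note that threshold plays a double role here: it is both a distance from the centroid and a tolerance around perfect agreement between the two cells, so a single value flags global outliers and pairwise disagreements alike.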
