jdti.jdti
import math
import os
import pickle
import re

import harmonypy as harmonize
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import umap
from adjustText import adjust_text
from joblib import Parallel, delayed
from matplotlib.patches import FancyArrowPatch, Patch, Polygon
from scipy import sparse
from scipy.io import mmwrite
from scipy.spatial import ConvexHull
from scipy.spatial.distance import pdist, squareform
from scipy.stats import norm, stats, zscore
from sklearn.cluster import DBSCAN, MeanShift
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances, silhouette_score
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from .utils import *


class Clustering:
    """
    A class for performing dimensionality reduction, clustering, and visualization
    on high-dimensional data (e.g., single-cell gene expression).

    The class provides methods for:
    - Normalizing and extracting subsets of data
    - Principal Component Analysis (PCA) and related clustering
    - Uniform Manifold Approximation and Projection (UMAP) and clustering
    - Visualization of PCA and UMAP embeddings
    - Harmonization of batch effects
    - Accessing processed data and cluster labels

    Methods
    -------
    add_data_frame(data, metadata)
        Class method to create a Clustering instance from a DataFrame and metadata.

    harmonize_sets()
        Perform batch effect harmonization on PCA data.

    perform_PCA(pc_num=100, width=8, height=6)
        Perform PCA on the dataset and visualize the first two PCs.

    knee_plot_PCA(width=8, height=6)
        Plot the cumulative variance explained by PCs to determine optimal dimensionality.

    find_clusters_PCA(pc_num=2, eps=0.5, min_samples=10, width=8, height=6, harmonized=False)
        Apply DBSCAN clustering to PCA embeddings and visualize results.

    perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...)
        Compute UMAP embeddings with optional parameter tuning.

    knee_plot_umap(eps=0.5, min_samples=10)
        Determine optimal UMAP dimensionality using silhouette scores.

    find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6)
        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

    UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...)
        Visualize UMAP embeddings with labels and optional cluster numbering.

    UMAP_feature(feature_name, features_data=None, point_size=0.6, ...)
        Plot a single feature over UMAP coordinates with customizable colormap.

    get_umap_data()
        Return the UMAP embeddings along with cluster labels if available.

    get_pca_data()
        Return the PCA results along with cluster labels if available.

    return_clusters(clusters='umap')
        Return the cluster labels for UMAP or PCA embeddings.

    Raises
    ------
    ValueError
        For invalid parameters, mismatched dimensions, or missing metadata.
    """

    def __init__(self, data, metadata):
        """
        Initialize the clustering class with data and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data for clustering. Columns are considered as samples.

        metadata : pandas.DataFrame, optional
            Metadata for the samples. If None, a default DataFrame with column
            names as 'cell_names' is created.

        Attributes
        ----------
        -clustering_data : pandas.DataFrame
        -clustering_metadata : pandas.DataFrame
        -subclusters : None or dict
        -explained_var : None or numpy.ndarray
        -cumulative_var : None or numpy.ndarray
        -pca : None or pandas.DataFrame
        -harmonized_pca : None or pandas.DataFrame
        -umap : None or pandas.DataFrame
        """

        self.clustering_data = data
        """The input data used for clustering."""

        if metadata is None:
            # Fall back to a minimal metadata table keyed by the sample columns.
            metadata = pd.DataFrame({"cell_names": list(data.columns)})

        self.clustering_metadata = metadata
        """Metadata associated with the samples."""

        self.subclusters = None
        """Placeholder for storing subcluster information."""

        self.explained_var = None
        """Explained variance from PCA, initialized as None."""

        self.cumulative_var = None
        """Cumulative explained variance from PCA, initialized as None."""

        self.pca = None
        """PCA-transformed data, initialized as None."""

        self.harmonized_pca = None
        """PCA data after batch effect harmonization, initialized as None."""

        self.umap = None
        """UMAP embeddings, initialized as None."""

    @classmethod
    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
        """
        Create a Clustering instance from a DataFrame and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data with features as rows and samples/cells as columns.

        metadata : pandas.DataFrame or None
            Optional metadata for the samples.
            Each row corresponds to a sample/cell, and column names in this DataFrame
            should match the sample/cell names in `data`. Columns can contain additional
            information such as cell type, experimental condition, batch, sets, etc.

        Returns
        -------
        Clustering
            A new instance of the Clustering class.
        """

        return cls(data, metadata)

    def harmonize_sets(self, batch_col: str = "sets"):
        """
        Perform batch effect harmonization on PCA embeddings.

        Parameters
        ----------
        batch_col : str, default 'sets'
            Name of the column in `metadata` that contains batch information for the samples/cells.

        Returns
        -------
        None
            Updates the `harmonized_pca` attribute with harmonized data.

        Raises
        ------
        ValueError
            If PCA has not been computed yet (run `perform_PCA` first).
        """

        if self.pca is None:
            raise ValueError("PCA data not found. Run 'perform_PCA' before 'harmonize_sets'.")

        data_mat = np.array(self.pca)

        metadata = self.clustering_metadata

        # harmonypy returns components x cells in Z_corr; transpose back to cells x components.
        self.harmonized_pca = pd.DataFrame(
            harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr
        ).T

        self.harmonized_pca.columns = self.pca.columns

    def perform_PCA(self, pc_num: int = 100, width=8, height=6):
        """
        Perform Principal Component Analysis (PCA) on the dataset.

        This method standardizes the data, applies PCA, stores results as attributes,
        and generates a scatter plot of the first two principal components.

        Parameters
        ----------
        pc_num : int, default 100
            Number of principal components to compute.
            If 0, computes all available components.

        width : int or float, default 8
            Width of the PCA figure.

        height : int or float, default 6
            Height of the PCA figure.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot showing the first two principal components.

        Updates
        -------
        self.pca : pandas.DataFrame
            DataFrame with principal component scores for each sample.

        self.explained_var : numpy.ndarray
            Percentage of variance explained by each principal component.

        self.cumulative_var : numpy.ndarray
            Cumulative explained variance.
        """

        scaler = StandardScaler()
        data_scaled = scaler.fit_transform(self.clustering_data.T)

        # sklearn requires n_components <= min(n_samples, n_features);
        # clamp against both dimensions, not just the sample count.
        max_pc = min(data_scaled.shape)
        if pc_num == 0 or pc_num > max_pc:
            pc_num = max_pc

        pca = PCA(n_components=pc_num, random_state=42)

        principal_components = pca.fit_transform(data_scaled)

        pca_df = pd.DataFrame(
            data=principal_components,
            columns=["PC" + str(x + 1) for x in range(pc_num)],
        )

        self.explained_var = pca.explained_variance_ratio_ * 100
        self.cumulative_var = np.cumsum(self.explained_var)

        self.pca = pca_df

        fig = plt.figure(figsize=(width, height))
        plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7)
        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.grid(True)
        plt.show()

        return fig

    def knee_plot_PCA(self, width: int = 8, height: int = 6):
        """
        Plot cumulative explained variance to determine the optimal number of PCs.

        Parameters
        ----------
        width : int, default 8
            Width of the figure.

        height : int, default 6
            Height of the figure.

        Returns
        -------
        matplotlib.figure.Figure
            Line plot showing cumulative variance explained by each PC.

        Raises
        ------
        ValueError
            If PCA has not been computed yet (run `perform_PCA` first).
        """

        if self.explained_var is None:
            raise ValueError("PCA data not found. Run 'perform_PCA' before 'knee_plot_PCA'.")

        fig_knee = plt.figure(figsize=(width, height))
        plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o")
        plt.xlabel("PC (n components)")
        plt.ylabel("Cumulative explained variance (%)")
        plt.grid(True)

        # Tick at PC1 and then every fifth component for readability.
        xticks = [1] + list(range(5, len(self.explained_var) + 1, 5))
        plt.xticks(xticks, rotation=60)

        plt.show()

        return fig_knee

    def find_clusters_PCA(
        self,
        pc_num: int = 2,
        eps: float = 0.5,
        min_samples: int = 10,
        width: int = 8,
        height: int = 6,
        harmonized: bool = False,
    ):
        """
        Apply DBSCAN clustering to PCA embeddings and visualize the results.

        This method performs density-based clustering (DBSCAN) on the PCA-reduced
        dataset. Cluster labels are stored in the object's metadata, and a scatter
        plot of the first two principal components with cluster annotations is returned.

        Parameters
        ----------
        pc_num : int, default 2
            Number of principal components to use for clustering.
            If 0, uses all available components.

        eps : float, default 0.5
            Maximum distance between two points for them to be considered
            as neighbors (DBSCAN parameter).

        min_samples : int, default 10
            Minimum number of samples required to form a cluster (DBSCAN parameter).

        width : int, default 8
            Width of the output scatter plot.

        height : int, default 6
            Height of the output scatter plot.

        harmonized : bool, default False
            If True, use harmonized PCA data (`self.harmonized_pca`).
            If False, use standard PCA results (`self.pca`).

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of the first two principal components colored by
            cluster assignments.

        Updates
        -------
        self.clustering_metadata['PCA_clusters'] : list
            Cluster labels assigned to each cell/sample.

        self.input_metadata['PCA_clusters'] : list, optional
            Cluster labels stored in input metadata (if available).
        """

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)

        # NOTE: renamed from 'PCA' to avoid shadowing sklearn.decomposition.PCA.
        source = self.harmonized_pca if harmonized else self.pca

        if pc_num == 0:
            pca_input = source
        else:
            pca_input = source.iloc[:, 0:pc_num]

        dbscan_labels = dbscan.fit_predict(pca_input)

        pca_df = pd.DataFrame(pca_input)
        pca_df["Cluster"] = dbscan_labels

        fig = plt.figure(figsize=(width, height))

        for cluster_id in sorted(pca_df["Cluster"].unique()):
            cluster_data = pca_df[pca_df["Cluster"] == cluster_id]
            plt.scatter(
                cluster_data["PC1"],
                cluster_data["PC2"],
                label=f"Cluster {cluster_id}",
                alpha=0.7,
            )

        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))

        plt.grid(True)
        plt.show()

        self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]

        # Best-effort mirror into input_metadata: plain Clustering instances do
        # not define it (AttributeError), and subclasses may hold None or a
        # mismatched table (TypeError/ValueError). Anything else should surface.
        try:
            self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
        except (AttributeError, TypeError, ValueError):
            pass

        return fig

    def perform_UMAP(
        self,
        factorize: bool = False,
        umap_num: int = 100,
        pc_num: int = 0,
        harmonized: bool = False,
        n_neighbors: int = 5,
        min_dist: float | int = 0.1,
        spread: float | int = 1.0,
        set_op_mix_ratio: float | int = 1.0,
        local_connectivity: int = 1,
        repulsion_strength: float | int = 1.0,
        negative_sample_rate: int = 5,
        width: int = 8,
        height: int = 6,
    ):
        """
        Compute and visualize UMAP embeddings of the dataset.

        This method applies Uniform Manifold Approximation and Projection (UMAP)
        for dimensionality reduction on either raw, PCA, or harmonized PCA data.
        Results are stored as a DataFrame (`self.umap`).

        Parameters
        ----------
        factorize : bool, default False
            If True, categorical sample labels (from column names) are factorized
            and used as supervision in UMAP fitting.

        umap_num : int, default 100
            Number of UMAP dimensions to compute. If 0, matches the input dimension.

        pc_num : int, default 0
            Number of principal components to use as UMAP input.
            If 0, use all available components or raw data.

        harmonized : bool, default False
            If True, use harmonized PCA embeddings (`self.harmonized_pca`).
            If False, use standard PCA or raw scaled data.

        n_neighbors : int, default 5
            UMAP parameter controlling the size of the local neighborhood.

        min_dist : float, default 0.1
            UMAP parameter controlling minimum allowed distance between embedded points.

        spread : int | float, default 1.0
            Effective scale of embedded space (UMAP parameter).

        set_op_mix_ratio : int | float, default 1.0
            Interpolation parameter between union and intersection in fuzzy sets.

        local_connectivity : int, default 1
            Number of nearest neighbors assumed for each point.

        repulsion_strength : int | float, default 1.0
            Weighting applied to negative samples during optimization.

        negative_sample_rate : int, default 5
            Number of negative samples per positive sample in optimization.

        width : int, default 8
            Width of the output scatter plot.

        height : int, default 6
            Height of the output scatter plot.

        Updates
        -------
        self.umap : pandas.DataFrame
            Table of UMAP embeddings with columns `UMAP1 ... UMAPn`.

        Notes
        -----
        For supervised UMAP (`factorize=True`), categorical codes from column
        names of the dataset are used as labels.
        """

        scaler = StandardScaler()

        if pc_num == 0 and harmonized:
            data_scaled = self.harmonized_pca

        elif pc_num == 0:
            data_scaled = scaler.fit_transform(self.clustering_data.T)

        else:
            if harmonized:
                data_scaled = self.harmonized_pca.iloc[:, 0:pc_num]
            else:
                data_scaled = self.pca.iloc[:, 0:pc_num]

        # Clamp to the input dimensionality; after this, a single reducer
        # construction covers both cases (the original duplicated the call).
        if umap_num == 0 or umap_num > data_scaled.shape[1]:
            umap_num = data_scaled.shape[1]

        reducer = umap.UMAP(
            n_components=umap_num,
            random_state=42,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            spread=spread,
            set_op_mix_ratio=set_op_mix_ratio,
            local_connectivity=local_connectivity,
            repulsion_strength=repulsion_strength,
            negative_sample_rate=negative_sample_rate,
            n_jobs=1,
        )

        if factorize:
            # Supervised UMAP: use categorical codes of column names as labels.
            embedding = reducer.fit_transform(
                X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes
            )
        else:
            embedding = reducer.fit_transform(X=data_scaled)

        umap_df = pd.DataFrame(
            embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)]
        )

        plt.figure(figsize=(width, height))
        plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.grid(True)

        plt.show()

        self.umap = umap_df

    def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):
        """
        Plot silhouette scores for different UMAP dimensions to determine optimal n_components.

        Parameters
        ----------
        eps : float, default 0.5
            DBSCAN eps parameter for clustering each UMAP dimension.

        min_samples : int, default 10
            Minimum number of samples to form a cluster in DBSCAN.

        Returns
        -------
        matplotlib.figure.Figure
            Silhouette score plot across UMAP dimensions.
        """

        umap_range = range(2, len(self.umap.T) + 1)

        silhouette_scores = []
        component = []
        for n in umap_range:

            db = DBSCAN(eps=eps, min_samples=min_samples)
            labels = db.fit_predict(np.array(self.umap)[:, :n])

            # Exclude DBSCAN noise points (-1) from the silhouette computation.
            mask = labels != -1
            if len(set(labels[mask])) > 1:
                score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask])
            else:
                # Fewer than two real clusters: silhouette undefined, mark as -1.
                score = -1

            silhouette_scores.append(score)
            component.append(n)

        fig = plt.figure(figsize=(10, 5))
        plt.plot(component, silhouette_scores, marker="o")
        plt.xlabel("UMAP (n_components)")
        plt.ylabel("Silhouette Score")
        plt.grid(True)
        plt.xticks(range(int(min(component)), int(max(component)) + 1, 1))

        plt.show()

        return fig

    def find_clusters_UMAP(
        self,
        umap_n: int = 5,
        eps: float | int = 0.5,
        min_samples: int = 10,
        width: int = 8,
        height: int = 6,
    ):
        """
        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

        This method performs density-based clustering (DBSCAN) on the UMAP-reduced
        dataset. Cluster labels are stored in the object's metadata, and a scatter
        plot of the first two UMAP components with cluster annotations is returned.

        Parameters
        ----------
        umap_n : int, default 5
            Number of UMAP dimensions to use for DBSCAN clustering.
            Must be <= number of columns in `self.umap`.

        eps : float | int, default 0.5
            Maximum neighborhood distance between two samples for them to be considered
            as in the same cluster (DBSCAN parameter).

        min_samples : int, default 10
            Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of the first two UMAP components colored by
            cluster assignments.

        Updates
        -------
        self.clustering_metadata['UMAP_clusters'] : list
            Cluster labels assigned to each cell/sample.

        self.input_metadata['UMAP_clusters'] : list, optional
            Cluster labels stored in input metadata (if available).
        """

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n])

        # Work on a copy: the original aliased self.umap and permanently added a
        # 'Cluster' column, which then polluted knee_plot_umap / later UMAP use.
        umap_df = self.umap.copy()
        umap_df["Cluster"] = dbscan_labels

        fig = plt.figure(figsize=(width, height))

        for cluster_id in sorted(umap_df["Cluster"].unique()):
            cluster_data = umap_df[umap_df["Cluster"] == cluster_id]
            plt.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"Cluster {cluster_id}",
                alpha=0.7,
            )

        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
        plt.grid(True)

        self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]

        # Best-effort mirror into input_metadata (see find_clusters_PCA).
        try:
            self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
        except (AttributeError, TypeError, ValueError):
            pass

        return fig

    def UMAP_vis(
        self,
        names_slot: str = "cell_names",
        set_sep: bool = True,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        legend_split_col: int = 2,
        width: int = 8,
        height: int = 6,
        inc_num: bool = True,
    ):
        """
        Visualize UMAP embeddings with sample labels based on a specific metadata slot.

        Parameters
        ----------
        names_slot : str, default 'cell_names'
            Column in metadata to use as sample labels.

        set_sep : bool, default True
            If True, separate points by dataset.

        point_size : float, default 0.6
            Size of scatter points.

        font_size : int, default 6
            Font size for numbers on points.

        legend_split_col : int, default 2
            Number of columns in legend.

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        inc_num : bool, default True
            If True, annotate points with numeric labels.

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot figure.
        """

        umap_df = self.umap.iloc[:, 0:2].copy()
        umap_df["names"] = list(self.clustering_metadata[names_slot])

        if set_sep:

            if "sets" in list(self.clustering_metadata.columns):
                umap_df["dataset"] = list(self.clustering_metadata["sets"])
            else:
                umap_df["dataset"] = "default"

        else:
            umap_df["dataset"] = "default"

        # Unique (name, dataset) key used to number groups by decreasing size.
        umap_df["tmp_nam"] = list(umap_df["names"] + umap_df["dataset"])

        umap_df["count"] = umap_df["tmp_nam"].map(umap_df["tmp_nam"].value_counts())

        numeric_df = (
            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
            .drop_duplicates()
            .sort_values("count", ascending=False)
        )
        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])

        umap_df = umap_df.merge(
            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
        )

        fig, ax = plt.subplots(figsize=(width, height))

        # Cycle through marker shapes per dataset so sets stay distinguishable.
        markers = ["o", "s", "^", "D", "P", "*", "X"]
        marker_map = {
            ds: markers[i % len(markers)]
            for i, ds in enumerate(umap_df["dataset"].unique())
        }

        cord_list = []

        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):

            cluster_data = umap_df[umap_df["numeric_values"] == num]

            ax.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"{num} - {nam}",
                marker=marker_map[cluster_data["dataset"].iloc[0]],
                alpha=0.6,
                s=point_size,
            )

            coords = cluster_data[["UMAP1", "UMAP2"]].values

            # Medoid of the group (point minimizing total distance to the rest)
            # is used as the anchor for the numeric annotation.
            dists = pairwise_distances(coords)

            sum_dists = dists.sum(axis=1)

            center_idx = np.argmin(sum_dists)
            center_point = coords[center_idx]

            cord_list.append(center_point)

        if inc_num:
            texts = []
            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
                texts.append(
                    ax.text(
                        x,
                        y,
                        str(num),
                        ha="center",
                        va="center",
                        fontsize=font_size,
                        color="black",
                    )
                )

            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.legend(
            title="Clusters",
            loc="center left",
            bbox_to_anchor=(1.05, 0.5),
            ncol=legend_split_col,
            markerscale=5,
        )

        ax.grid(True)

        plt.tight_layout()

        return fig

    def UMAP_feature(
        self,
        feature_name: str,
        features_data: pd.DataFrame | None = None,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        width: int = 8,
        height: int = 6,
        palette="light",
    ):
        """
        Visualize UMAP embedding with expression levels of a selected feature.

        Each point (cell) in the UMAP plot is colored according to the expression
        value of the chosen feature, enabling interpretation of spatial patterns
        of gene activity or metadata distribution in low-dimensional space.

        Parameters
        ----------
        feature_name : str
            Name of the feature to plot.

        features_data : pandas.DataFrame or None, default None
            If None, the function uses the DataFrame containing the clustering data.
            To plot features not used in clustering, provide a wider DataFrame
            containing the original feature values.

        point_size : float, default 0.6
            Size of scatter points in the plot.

        font_size : int, default 6
            Font size for axis labels and annotations.

        width : int, default 8
            Width of the matplotlib figure.

        height : int, default 6
            Height of the matplotlib figure.

        palette : str, default 'light'
            Color palette for expression visualization. Options are:
            - 'light'
            - 'dark'
            - 'green'
            - 'gray'

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot colored by feature values.

        Raises
        ------
        ValueError
            If `features_data` does not match the number of UMAP cells, if
            `feature_name` is not found, or if `palette` is unknown.
        """

        umap_df = self.umap.iloc[:, 0:2].copy()

        if features_data is None:

            features_data = self.clustering_data

        if features_data.shape[1] != umap_df.shape[0]:
            raise ValueError(
                "Imputed 'features_data' shape does not match the number of UMAP cells"
            )

        # Case-insensitive match of the requested feature against the row index.
        blist = [
            True if x.upper() == feature_name.upper() else False
            for x in features_data.index
        ]

        if not any(blist):
            raise ValueError("Imputed feature_name is not included in the data")

        umap_df.loc[:, "feature"] = (
            features_data.loc[blist, :]
            .apply(lambda row: row.tolist(), axis=1)
            .values[0]
        )

        # Sort ascending so high-expression points are drawn last (on top).
        umap_df = umap_df.sort_values("feature", ascending=True)

        import matplotlib.colors as mcolors

        if palette == "light":
            palette = px.colors.sequential.Sunsetdark

        elif palette == "dark":
            palette = px.colors.sequential.thermal

        elif palette == "green":
            palette = px.colors.sequential.Aggrnyl

        elif palette == "gray":
            palette = px.colors.sequential.gray
            palette = palette[::-1]

        else:
            raise ValueError(
                'Palette not found. Use: "light", "dark", "gray", or "green"'
            )

        # Convert plotly 'rgb(r,g,b)' strings (0-255) to matplotlib 0-1 tuples.
        converted = []
        for c in palette:
            rgb_255 = px.colors.unlabel_rgb(c)
            rgb_01 = tuple(v / 255.0 for v in rgb_255)
            converted.append(rgb_01)

        my_cmap = mcolors.ListedColormap(converted, name="custom")

        fig, ax = plt.subplots(figsize=(width, height))
        sc = ax.scatter(
            umap_df["UMAP1"],
            umap_df["UMAP2"],
            c=umap_df["feature"],
            s=point_size,
            cmap=my_cmap,
            alpha=1.0,
            edgecolors="black",
            linewidths=0.1,
        )

        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label(f"{feature_name}")

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.grid(True)

        plt.tight_layout()

        return fig

    def get_umap_data(self):
        """
        Retrieve UMAP embedding data with optional cluster labels.

        Returns a copy of the UMAP coordinates stored in `self.umap`. If clustering
        metadata is available (specifically `UMAP_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame or None
            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
            If available, includes an extra column 'clusters' with cluster labels.
            None if UMAP has not been computed yet.

        Notes
        -----
        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        # Preserve the original behavior of returning None before perform_UMAP.
        if self.umap is None:
            return None

        # Copy so callers' frames don't alias (and mutate) the stored embeddings.
        umap_data = self.umap.copy()

        if "UMAP_clusters" in self.clustering_metadata.columns:
            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]

        return umap_data

    def get_pca_data(self):
        """
        Retrieve PCA embedding data with optional cluster labels.

        Returns a copy of the principal component scores stored in `self.pca`.
        If clustering metadata is available (specifically `PCA_clusters`), the
        corresponding cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame or None
            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
            If available, includes an extra column 'clusters' with cluster labels.
            None if PCA has not been computed yet.

        Notes
        -----
        - PCA must be computed beforehand (e.g., using `perform_PCA`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        # Preserve the original behavior of returning None before perform_PCA.
        if self.pca is None:
            return None

        # Copy so callers' frames don't alias (and mutate) the stored scores.
        pca_data = self.pca.copy()

        if "PCA_clusters" in self.clustering_metadata.columns:
            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]

        return pca_data

    def return_clusters(self, clusters="umap"):
        """
        Retrieve cluster labels from UMAP or PCA clustering results.

        Parameters
        ----------
        clusters : str, default 'umap'
            Source of cluster labels to return. Must be one of:
            - 'umap': return cluster labels from UMAP embeddings.
            - 'pca' : return cluster labels from PCA embeddings.

        Returns
        -------
        list
            Cluster labels corresponding to the selected embedding method.

        Raises
        ------
        ValueError
            If `clusters` is not 'umap' or 'pca'.

        Notes
        -----
        Requires that clustering has already been performed
        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
        """

        if clusters.lower() == "umap":
            clusters_vector = self.clustering_metadata["UMAP_clusters"]
        elif clusters.lower() == "pca":
            clusters_vector = self.clustering_metadata["PCA_clusters"]
        else:
            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")

        return clusters_vector


# NOTE(review): the original file continues with `class COMPsc(Clustering)` here;
# its definition is truncated in this view and is intentionally left untouched.
1038 1039 The COMPsc class provides methods for: 1040 1041 - Normalizing and filtering single-cell data 1042 - Loading and saving sparse 10x-style datasets 1043 - Computing differential expression and marker genes 1044 - Clustering and subclustering analysis 1045 - Visualizing similarity and spatial relationships 1046 - Aggregating data by cell and set annotations 1047 - Managing metadata and renaming labels 1048 - Plotting gene detection histograms and feature scatters 1049 1050 Methods 1051 ------- 1052 project_dir(path_to_directory, project_list) 1053 Scans a directory to create a COMPsc instance mapping project names to their paths. 1054 1055 save_project(name, path=os.getcwd()) 1056 Saves the COMPsc object to a pickle file on disk. 1057 1058 load_project(path) 1059 Loads a previously saved COMPsc object from a pickle file. 1060 1061 reduce_cols(reg, inc_set=False) 1062 Removes columns from data tables where column names contain a specified name or partial substring. 1063 1064 reduce_rows(reg, inc_set=False) 1065 Removes rows from data tables where column names contain a specified feature (gene) name. 1066 1067 get_data(set_info=False) 1068 Returns normalized data with optional set annotations in column names. 1069 1070 get_metadata() 1071 Returns the stored input metadata. 1072 1073 get_partial_data(names=None, features=None, name_slot='cell_names') 1074 Return a subset of the data by sample names and/or features. 1075 1076 gene_calculation() 1077 Calculates and stores per-cell gene detection counts as a pandas Series. 1078 1079 gene_histograme(bins=100) 1080 Plots a histogram of genes detected per cell with an overlaid normal distribution. 1081 1082 gene_threshold(min_n=None, max_n=None) 1083 Filters cells based on minimum and/or maximum gene detection thresholds. 1084 1085 load_sparse_from_projects(normalized_data=False) 1086 Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data. 
1087 1088 rename_names(mapping, slot='cell_names') 1089 Renames entries in a specified metadata column using a provided mapping dictionary. 1090 1091 rename_subclusters(mapping) 1092 Renames subcluster labels using a provided mapping dictionary. 1093 1094 save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized') 1095 Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv). 1096 1097 normalize_data(normalize=True, normalize_factor=100000) 1098 Normalizes raw counts to counts-per-specified factor (e.g., CPM-like). 1099 1100 statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10) 1101 Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups. 1102 1103 calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False) 1104 Computes and caches differential markers using the statistic method. 1105 1106 clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4) 1107 Prepares clustering input by selecting marker features and optionally smoothing cell values. 1108 1109 average() 1110 Aggregates normalized data by averaging across (cell_name, set) pairs. 1111 1112 estimating_similarity(method='pearson', p_val=0.05, top_n=25) 1113 Computes pairwise correlation and Euclidean distance between aggregated samples. 1114 1115 similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10) 1116 Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size. 1117 1118 spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...) 1119 Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows. 1120 1121 subcluster_prepare(features, cluster) 1122 Initializes a Clustering object for subcluster analysis on a selected parent cluster. 
1123 1124 define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...) 1125 Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels. 1126 1127 subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...) 1128 Visualizes averaged expression and occurrence of features for subclusters as a scatter plot. 1129 1130 subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...) 1131 Plots top differential features for subclusters as a features-scatter visualization. 1132 1133 accept_subclusters() 1134 Commits subcluster labels to main metadata by renaming cell names and clears subcluster data. 1135 1136 Raises 1137 ------ 1138 ValueError 1139 For invalid parameters, mismatched dimensions, or missing metadata. 1140 1141 """ 1142 1143 def __init__( 1144 self, 1145 objects=None, 1146 ): 1147 """ 1148 Initialize the COMPsc class for single-cell data integration and analysis. 1149 1150 Parameters 1151 ---------- 1152 objects : list or None, optional 1153 Optional list of data objects to initialize the instance with. 
1154 1155 Attributes 1156 ---------- 1157 -objects : list or None 1158 -input_data : pandas.DataFrame or None 1159 -input_metadata : pandas.DataFrame or None 1160 -normalized_data : pandas.DataFrame or None 1161 -agg_metadata : pandas.DataFrame or None 1162 -agg_normalized_data : pandas.DataFrame or None 1163 -similarity : pandas.DataFrame or None 1164 -var_data : pandas.DataFrame or None 1165 -subclusters_ : instance of Clustering class or None 1166 -cells_calc : pandas.Series or None 1167 -gene_calc : pandas.Series or None 1168 -composition_data : pandas.DataFrame or None 1169 """ 1170 1171 self.objects = objects 1172 """ Stores the input data objects.""" 1173 1174 self.input_data = None 1175 """Raw input data for clustering or integration analysis.""" 1176 1177 self.input_metadata = None 1178 """Metadata associated with the input data.""" 1179 1180 self.normalized_data = None 1181 """Normalized version of the input data.""" 1182 1183 self.agg_metadata = None 1184 '''Aggregated metadata for all sets in object related to "agg_normalized_data"''' 1185 1186 self.agg_normalized_data = None 1187 """Aggregated and normalized data across multiple sets.""" 1188 1189 self.similarity = None 1190 """Similarity data between cells across all samples. 
and sets""" 1191 1192 self.var_data = None 1193 """DEG analysis results summarizing variance across all samples in the object.""" 1194 1195 self.subclusters_ = None 1196 """Placeholder for information about subclusters analysis; if computed.""" 1197 1198 self.cells_calc = None 1199 """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition.""" 1200 1201 self.gene_calc = None 1202 """Number of genes detected per sample (cell), reflecting the sequencing depth.""" 1203 1204 self.composition_data = None 1205 """Data describing composition of cells across clusters or sets.""" 1206 1207 @classmethod 1208 def project_dir(cls, path_to_directory, project_list): 1209 """ 1210 Scan a directory and build a COMPsc instance mapping provided project names 1211 to their paths. 1212 1213 Parameters 1214 ---------- 1215 path_to_directory : str 1216 Path containing project subfolders. 1217 1218 project_list : list[str] 1219 List of filenames (folder names) to include in the returned object map. 1220 1221 Returns 1222 ------- 1223 COMPsc 1224 New COMPsc instance with `objects` populated. 1225 1226 Raises 1227 ------ 1228 Exception 1229 A generic exception is caught and a message printed if scanning fails. 1230 1231 Notes 1232 ----- 1233 Function attempts to match entries in `project_list` to directory 1234 names and constructs a simplified object key from the folder name. 1235 """ 1236 try: 1237 objects = {} 1238 for filename in tqdm(os.listdir(path_to_directory)): 1239 for c in project_list: 1240 f = os.path.join(path_to_directory, filename) 1241 if c == filename and os.path.isdir(f): 1242 objects[str(c)] = f 1243 1244 return cls(objects) 1245 1246 except: 1247 print("Something went wrong. Check the function input data and try again!") 1248 1249 def save_project(self, name, path: str = os.getcwd()): 1250 """ 1251 Save the COMPsc object to disk using pickle. 
1252 1253 Parameters 1254 ---------- 1255 name : str 1256 Base filename (without extension) to use when saving. 1257 1258 path : str, default os.getcwd() 1259 Directory in which to save the project file. 1260 1261 Returns 1262 ------- 1263 None 1264 1265 Side Effects 1266 ------------ 1267 - Writes a file `<path>/<name>.jpkl` containing the pickled object. 1268 - Prints a confirmation message with saved path. 1269 """ 1270 1271 full = os.path.join(path, f"{name}.jpkl") 1272 1273 with open(full, "wb") as f: 1274 pickle.dump(self, f) 1275 1276 print(f"Project saved as {full}") 1277 1278 @classmethod 1279 def load_project(cls, path): 1280 """ 1281 Load a previously saved COMPsc project from a pickle file. 1282 1283 Parameters 1284 ---------- 1285 path : str 1286 Full path to the pickled project file. 1287 1288 Returns 1289 ------- 1290 COMPsc 1291 The unpickled COMPsc object. 1292 1293 Raises 1294 ------ 1295 FileNotFoundError 1296 If the provided path does not exist. 1297 """ 1298 1299 if not os.path.exists(path): 1300 raise FileNotFoundError("File does not exist!") 1301 with open(path, "rb") as f: 1302 obj = pickle.load(f) 1303 return obj 1304 1305 def reduce_cols( 1306 self, 1307 reg: str | None = None, 1308 full: str | None = None, 1309 name_slot: str = "cell_names", 1310 inc_set: bool = False, 1311 ): 1312 """ 1313 Remove columns (cells) whose names contain a substring `reg` or 1314 full name `full` from available tables. 1315 1316 Parameters 1317 ---------- 1318 reg : str | None 1319 Substring to search for in column/cell names; matching columns will be removed. 1320 If not None, `full` must be None. 1321 1322 full : str | None 1323 Full name to search for in column/cell names; matching columns will be removed. 1324 If not None, `reg` must be None. 1325 1326 name_slot : str, default 'cell_names' 1327 Column in metadata to use as sample names. 1328 1329 inc_set : bool, default False 1330 If True, column names are interpreted as 'cell_name # set' when matching. 
        Update
        ------------
        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
        removing columns/rows that match `reg`.

        Raises
        ------
        Raises ValueError if nothing matches the reduction mask.
        """

        # Exactly one of `reg` (substring match) and `full` (exact match)
        # must be provided.
        if reg is None and full is None:
            raise ValueError(
                "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!"
            )

        if reg is not None and full is not None:
            raise ValueError(
                "Both 'reg' and 'full' arguments are provided. "
                "Please provide only one of them!\n"
                "'reg' is used when only part of the name must be detected.\n"
                "'full' is used if the full name must be detected."
            )

        if reg is not None:

            # --- substring mode: drop every column whose (upper-cased) name
            # contains `reg`. Each table is re-labelled from metadata first so
            # the mask is computed on current names.
            if self.input_data is not None:

                if inc_set:

                    self.input_data.columns = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                else:

                    self.input_data.columns = self.input_metadata[name_slot]

                # True = keep (case-insensitive "does not contain" test).
                mask = [reg.upper() not in x.upper() for x in self.input_data.columns]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.input_data = self.input_data.loc[:, mask]

            if self.normalized_data is not None:

                if inc_set:

                    self.normalized_data.columns = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                else:

                    self.normalized_data.columns = self.input_metadata[name_slot]

                mask = [
                    reg.upper() not in x.upper() for x in self.normalized_data.columns
                ]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.normalized_data = self.normalized_data.loc[:, mask]

            if self.input_metadata is not None:

                # Temporary "drop" column holds the name string used for
                # matching; it is removed again below.
                if inc_set:

                    self.input_metadata["drop"] = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                else:

                    self.input_metadata["drop"] = self.input_metadata[name_slot]

                mask = [
                    reg.upper() not in x.upper() for x in self.input_metadata["drop"]
                ]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
                    drop=True
                )

                self.input_metadata = self.input_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            if self.agg_normalized_data is not None:

                if inc_set:

                    self.agg_normalized_data.columns = (
                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
                    )

                else:

                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                mask = [
                    reg.upper() not in x.upper()
                    for x in self.agg_normalized_data.columns
                ]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]

            if self.agg_metadata is not None:

                if inc_set:

                    self.agg_metadata["drop"] = (
                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
                    )

                else:

                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
                    drop=True
                )

                self.agg_metadata = self.agg_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

        elif full is not None:

            # --- exact mode: drop columns whose name equals `full`
            # (case-insensitive). If `inc_set` is requested but `full` does
            # not contain the '# set' suffix, fall back to bare names and warn.
            if self.input_data is not None:

                if inc_set:

                    self.input_data.columns = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                    if "#" not in full:

                        self.input_data.columns = self.input_metadata[name_slot]

                        print(
                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
                            "Only the names will be compared, without considering the set information."
                        )

                else:

                    self.input_data.columns = self.input_metadata[name_slot]

                mask = [full.upper() != x.upper() for x in self.input_data.columns]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.input_data = self.input_data.loc[:, mask]

            if self.normalized_data is not None:

                if inc_set:

                    self.normalized_data.columns = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                    if "#" not in full:

                        self.normalized_data.columns = self.input_metadata[name_slot]

                        print(
                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
                            "Only the names will be compared, without considering the set information."
                        )

                else:

                    self.normalized_data.columns = self.input_metadata[name_slot]

                mask = [full.upper() != x.upper() for x in self.normalized_data.columns]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.normalized_data = self.normalized_data.loc[:, mask]

            if self.input_metadata is not None:

                if inc_set:

                    self.input_metadata["drop"] = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                    if "#" not in full:

                        self.input_metadata["drop"] = self.input_metadata[name_slot]

                        print(
                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
                            "Only the names will be compared, without considering the set information."
                        )

                else:

                    self.input_metadata["drop"] = self.input_metadata[name_slot]

                mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
                    drop=True
                )

                self.input_metadata = self.input_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            if self.agg_normalized_data is not None:

                if inc_set:

                    self.agg_normalized_data.columns = (
                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
                    )

                    if "#" not in full:

                        self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                        print(
                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
                            "Only the names will be compared, without considering the set information."
                        )
                else:

                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                mask = [
                    full.upper() != x.upper() for x in self.agg_normalized_data.columns
                ]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]

            if self.agg_metadata is not None:

                if inc_set:

                    self.agg_metadata["drop"] = (
                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
                    )

                    if "#" not in full:

                        self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                        print(
                            "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True"
                            "Only the names will be compared, without considering the set information."
                        )
                else:

                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]]

                if len([y for y in mask if y is False]) == 0:
                    raise ValueError("Nothing found to reduce")

                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
                    drop=True
                )

                self.agg_metadata = self.agg_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

        # Derived per-cell statistics are stale after any reduction.
        self.gene_calculation()
        self.cells_calculation()

    def reduce_rows(self, features_list: list):
        """
        Remove rows (features) whose names are included in features_list.

        Parameters
        ----------
        features_list : list
            List of features to search for in index/gene names; matching entries will be removed.

        Update
        ------------
        Mutates `self.input_data`, `self.normalized_data`, and
        `self.agg_normalized_data` (if they exist), removing rows whose
        index names match entries of `features_list`.

        Raises
        ------
        Prints a message listing features that are not found in the data.
        """

        if self.input_data is not None:

            # find_features resolves requested names against the index and
            # reports which were matched ("included") vs missing.
            res = find_features(self.input_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            # True = keep row (case-insensitive comparison).
            mask = [x.upper() not in res_list for x in self.input_data.index]

            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing found to reduce")

            self.input_data = self.input_data.loc[mask, :]

        if self.normalized_data is not None:

            res = find_features(self.normalized_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.normalized_data.index]

            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing found to reduce")

            self.normalized_data = self.normalized_data.loc[mask, :]

        if self.agg_normalized_data is not None:

            res = find_features(self.agg_normalized_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]

            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing found to reduce")

            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]

        # NOTE(review): `res` is taken from the last table processed above; if
        # all three tables are None this raises NameError — TODO confirm that
        # at least one table is always present here.
        if len(res["not_included"]) > 0:
            print("\nFeatures not found:")
            for i in res["not_included"]:
                print(i)

        # Derived per-cell statistics are stale after any reduction.
        self.gene_calculation()
        self.cells_calculation()

    def get_data(self, set_info: bool = False):
        """
        Return normalized data with optional set annotation appended to column names.

        Parameters
        ----------
        set_info : bool, default False
            If True, column names are returned as "cell_name # set"; otherwise
            only the `cell_name` is used.

        Returns
        -------
        pandas.DataFrame
            The `self.normalized_data` table with columns renamed according to `set_info`.
        Raises
        ------
        AttributeError
            If `self.normalized_data` or `self.input_metadata` is missing.
        """

        # NOTE(review): this is a reference to `self.normalized_data`, not a
        # copy — the column renaming below also mutates the stored table.
        # Confirm this side effect is intended before changing it.
        to_return = self.normalized_data

        if set_info:
            to_return.columns = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
        else:
            to_return.columns = self.input_metadata["cell_names"]

        return to_return

    def get_partial_data(
        self,
        names: list | str | None = None,
        features: list | str | None = None,
        name_slot: str = "cell_names",
        inc_metadata: bool = False,
    ):
        """
        Return a subset of the data filtered by sample names and/or feature names.

        Parameters
        ----------
        names : list, str, or None
            Names of samples to include. If None, all samples are considered.

        features : list, str, or None
            Names of features to include. If None, all features are considered.

        name_slot : str
            Column in metadata to use as sample names.

        inc_metadata : bool
            If True return tuple (data, metadata)

        Returns
        -------
        pandas.DataFrame
            Subset of the normalized data based on the specified names and features.
1768 """ 1769 1770 data = self.normalized_data.copy() 1771 metadata = self.input_metadata 1772 1773 if name_slot in self.input_metadata.columns: 1774 data.columns = self.input_metadata[name_slot] 1775 else: 1776 raise ValueError("'name_slot' not occured in data!'") 1777 1778 if isinstance(features, str): 1779 features = [features] 1780 elif features is None: 1781 features = [] 1782 1783 if isinstance(names, str): 1784 names = [names] 1785 elif names is None: 1786 names = [] 1787 1788 features = [x.upper() for x in features] 1789 names = [x.upper() for x in names] 1790 1791 columns_names = [x.upper() for x in data.columns] 1792 features_names = [x.upper() for x in data.index] 1793 1794 columns_bool = [True if x in names else False for x in columns_names] 1795 features_bool = [True if x in features else False for x in features_names] 1796 1797 if True not in columns_bool and True not in features_bool: 1798 print("Missing 'names' and/or 'features'. Returning full dataset instead.") 1799 1800 if True in columns_bool: 1801 data = data.loc[:, columns_bool] 1802 metadata = metadata.loc[columns_bool, :] 1803 1804 if True in features_bool: 1805 data = data.loc[features_bool, :] 1806 1807 not_in_features = [y for y in features if y not in features_names] 1808 1809 if len(not_in_features) > 0: 1810 print('\nThe following features were not found in data:') 1811 print('\n'.join(not_in_features)) 1812 1813 not_in_names = [y for y in names if y not in columns_names] 1814 1815 if len(not_in_names) > 0: 1816 print('\nThe following names were not found in data:') 1817 print('\n'.join(not_in_names)) 1818 1819 if inc_metadata: 1820 return data, metadata 1821 else: 1822 return data 1823 1824 def get_metadata(self): 1825 """ 1826 Return the stored input metadata. 1827 1828 Returns 1829 ------- 1830 pandas.DataFrame 1831 `self.input_metadata` (may be None if not set). 
1832 """ 1833 1834 to_return = self.input_metadata 1835 1836 return to_return 1837 1838 def gene_calculation(self): 1839 """ 1840 Calculate and store per-cell counts (e.g., number of detected genes). 1841 1842 The method computes a binary (presence/absence) per cell and sums across 1843 features to produce `self.gene_calc`. 1844 1845 Update 1846 ------ 1847 Sets `self.gene_calc` as a pandas.Series. 1848 1849 Side Effects 1850 ------------ 1851 Uses `self.input_data` when available, otherwise `self.normalized_data`. 1852 """ 1853 1854 if self.input_data is not None: 1855 1856 bin_col = self.input_data.columns.copy() 1857 1858 bin_col = bin_col.where(bin_col <= 0, 1) 1859 1860 sum_data = bin_col.sum(axis=0) 1861 1862 self.gene_calc = sum_data 1863 1864 elif self.normalized_data is not None: 1865 1866 bin_col = self.normalized_data.copy() 1867 1868 bin_col = bin_col.where(bin_col <= 0, 1) 1869 1870 sum_data = bin_col.sum(axis=0) 1871 1872 self.gene_calc = sum_data 1873 1874 def gene_histograme(self, bins=100): 1875 """ 1876 Plot a histogram of the number of genes detected per cell. 1877 1878 Parameters 1879 ---------- 1880 bins : int, default 100 1881 Number of histogram bins. 1882 1883 Returns 1884 ------- 1885 matplotlib.figure.Figure 1886 Figure containing the histogram of gene contents. 1887 1888 Notes 1889 ----- 1890 Requires `self.gene_calc` to be computed prior to calling. 
        """

        fig, ax = plt.subplots(figsize=(8, 5))

        # Keep the bin edges so the normal curve below can be scaled from a
        # probability density to histogram counts.
        _, bin_edges, _ = ax.hist(
            self.gene_calc, bins=bins, edgecolor="black", alpha=0.6
        )

        mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc)

        x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000)
        y = norm.pdf(x, mu, sigma)

        # density -> counts: multiply by N * bin_width.
        y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0])

        ax.plot(
            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
        )

        ax.set_xlabel("Value")
        ax.set_ylabel("Count")
        ax.set_title("Histogram of genes detected per cell")

        ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20))
        ax.tick_params(axis="x", rotation=90)

        ax.legend()

        return fig

    def gene_threshold(self, min_n: int | None, max_n: int | None):
        """
        Filter cells by gene-detection thresholds (min and/or max).

        Parameters
        ----------
        min_n : int or None
            Minimum number of detected genes required to keep a cell.

        max_n : int or None
            Maximum number of detected genes allowed to keep a cell.

        Update
        -------
        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
        (and calls `average()` if `self.agg_normalized_data` exists).

        Side Effects
        ------------
        Raises ValueError if both bounds are None or if filtering removes all cells.
        """

        # Boolean mask of cells to KEEP (strict inequalities on both bounds).
        if min_n is not None and max_n is not None:
            mask = (self.gene_calc > min_n) & (self.gene_calc < max_n)
        elif min_n is None and max_n is not None:
            mask = self.gene_calc < max_n
        elif min_n is not None and max_n is None:
            mask = self.gene_calc > min_n
        else:
            raise ValueError("Lack of both min_n and max_n values")

        if self.input_data is not None:

            # NOTE(review): `y is False` is an identity test; if `mask` holds
            # numpy.bool_ values it never matches, so this guard may never
            # raise — TODO confirm intended behavior.
            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing to reduce")

            self.input_data = self.input_data.loc[:, mask.values]

        if self.normalized_data is not None:

            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing to reduce")

            self.normalized_data = self.normalized_data.loc[:, mask.values]

        if self.input_metadata is not None:

            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing to reduce")

            self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index(
                drop=True
            )

            # Remove the temporary "drop" helper column if a previous
            # reduction left one behind.
            self.input_metadata = self.input_metadata.drop(
                columns=["drop"], errors="ignore"
            )

        if self.agg_normalized_data is not None:
            # Aggregated tables are derived data; recompute after filtering.
            self.average()

        self.gene_calculation()
        self.cells_calculation()

    def cells_calculation(self, name_slot="cell_names"):
        """
        Calculate number of cells per cell name / cluster.

        The method computes a binary (presence/absence) per cell name / cluster and sums across
        cells.

        Parameters
        ----------
        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        Update
        ------
        Sets `self.cells_calc` as a pd.DataFrame.
        """

        ls = list(self.input_metadata[name_slot])

        # value_counts yields per-name cell counts, sorted descending.
        df = pd.DataFrame(
            {
                "cluster": pd.Series(ls).value_counts().index,
                "n": pd.Series(ls).value_counts().values,
            }
        )

        self.cells_calc = df

    def cell_histograme(self, name_slot: str = "cell_names"):
        """
        Plot a histogram of the number of cells detected per cell name (cluster).

        Parameters
        ----------
        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        Returns
        -------
        matplotlib.figure.Figure
            Figure containing the histogram of cell contents.

        Notes
        -----
        Requires `self.cells_calc` to be computed prior to calling.
        """

        # Recompute counts when grouping by a non-default metadata column.
        if name_slot != "cell_names":
            self.cells_calculation(name_slot=name_slot)

        fig, ax = plt.subplots(figsize=(8, 5))

        # One bin per distinct cluster name; keep bin edges for pdf scaling.
        _, bin_edges, _ = ax.hist(
            list(self.cells_calc["n"]),
            bins=len(set(self.cells_calc["cluster"])),
            edgecolor="black",
            color="orange",
            alpha=0.6,
        )

        mu, sigma = np.mean(list(self.cells_calc["n"])), np.std(
            list(self.cells_calc["n"])
        )

        x = np.linspace(
            min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000
        )
        y = norm.pdf(x, mu, sigma)

        # density -> counts: multiply by N * bin_width.
        y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0])

        ax.plot(
            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
        )

        ax.set_xlabel("Value")
        ax.set_ylabel("Count")
        ax.set_title("Histogram of cells detected per cell name / cluster")

        ax.set_xticks(
            np.linspace(
                min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20
            )
        )
        ax.tick_params(axis="x", rotation=90)

        ax.legend()

        return fig

    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
        """
        Filter cell names / clusters by
        cell-detection threshold.

        Parameters
        ----------
        min_n : int or None
            Minimum number of detected genes required to keep a cell.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.


        Update
        -------
        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`
        (and calls `average()` if `self.agg_normalized_data` exists).
        """

        if name_slot != "cell_names":
            self.cells_calculation(name_slot=name_slot)

        if min_n is not None:
            # Cluster names that fall below the threshold (to be removed).
            names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]
        else:
            raise ValueError("Lack of min_n value")

        if len(names) > 0:

            if self.input_data is not None:

                self.input_data.columns = self.input_metadata[name_slot]

                # NOTE(review): substring matching (`r in x`) can also drop
                # names that merely *contain* a flagged cluster name (e.g.
                # "1" matches "10") — TODO confirm this is intended.
                mask = [not any(r in x for r in names) for x in self.input_data.columns]

                if len([y for y in mask if y is False]) > 0:

                    self.input_data = self.input_data.loc[:, mask]

            if self.normalized_data is not None:

                self.normalized_data.columns = self.input_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.normalized_data.columns
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.normalized_data = self.normalized_data.loc[:, mask]

            if self.input_metadata is not None:

                # Temporary "drop" column holds the name string used for
                # matching; removed again below.
                self.input_metadata["drop"] = self.input_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.input_metadata["drop"]
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
                        drop=True
                    )

                self.input_metadata = self.input_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            if self.agg_normalized_data is not None:

                self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                mask = [
                    not any(r in x for r in names)
                    for x in self.agg_normalized_data.columns
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]

            if self.agg_metadata is not None:

                self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.agg_metadata["drop"]
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
                        drop=True
                    )

                self.agg_metadata = self.agg_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

        # Derived per-cell statistics are stale after any reduction.
        self.gene_calculation()
        self.cells_calculation()

    def load_sparse_from_projects(self, normalized_data: bool = False):
        """
        Load sparse 10x-style datasets from stored project paths, concatenate them,
        and populate `input_data` / `normalized_data` and `input_metadata`.

        Parameters
        ----------
        normalized_data : bool, default False
            If True, store concatenated tables in `self.normalized_data`.
            If False, store them in `self.input_data` and normalization
            is needed using normalize_data() method.

        Side Effects
        ------------
        - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv).
        - Concatenates all projects column-wise and sets `self.input_metadata`.
        - Replaces NaNs with zeros and updates `self.gene_calc`.
        """

        obj = self.objects

        full_data = pd.DataFrame()
        full_metadata = pd.DataFrame()

        # Load each project and append: data column-wise (cells),
        # metadata row-wise.
        for ke in obj.keys():
            print(ke)

            dt, met = load_sparse(path=obj[ke], name=ke)

            full_data = pd.concat([full_data, dt], axis=1)
            full_metadata = pd.concat([full_metadata, met], axis=0)

        # Genes absent from some projects come out as NaN after concat;
        # treat them as zero counts.
        full_data[np.isnan(full_data)] = 0

        if normalized_data:
            self.normalized_data = full_data
            self.input_metadata = full_metadata
        else:

            self.input_data = full_data
            self.input_metadata = full_metadata

        self.gene_calculation()
        self.cells_calculation()

    def rename_names(self, mapping: dict, slot: str = "cell_names"):
        """
        Rename entries in `self.input_metadata[slot]` according to a provided mapping.

        Parameters
        ----------
        mapping : dict
            Dictionary with keys 'old_name' and 'new_name', each mapping to a list
            of equal length describing replacements.

        slot : str, default 'cell_names'
            Metadata column to operate on.

        Update
        -------
        Updates `self.input_metadata[slot]` in-place with renamed values.

        Raises
        ------
        ValueError
            If mapping keys are incorrect, lengths differ, or some 'old_name' values
            are not present in the metadata column.
        """

        if set(["old_name", "new_name"]) != set(mapping.keys()):
            raise ValueError(
                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
                "each with a list of names to change."
            )

        if len(mapping["old_name"]) != len(mapping["new_name"]):
            raise ValueError(
                "Mapping dictionary lists 'old_name' and 'new_name' "
                "must have the same length!"
            )

        names_vector = list(self.input_metadata[slot])

        # Every old name must actually occur before any renaming happens.
        if not all(elem in names_vector for elem in list(mapping["old_name"])):
            raise ValueError(
                f"Some entries from 'old_name' do not exist in the names of slot {slot}."
            )

        # old -> new lookup; names absent from the mapping pass through unchanged
        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))

        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]

        self.input_metadata[slot] = names_vector_ret

    def rename_subclusters(self, mapping):
        """
        Rename labels stored in `self.subclusters_.subclusters` according to mapping.

        Parameters
        ----------
        mapping : dict
            Mapping with keys 'old_name' and 'new_name' (lists of equal length).

        Update
        -------
        Updates `self.subclusters_.subclusters` with renamed labels.

        Raises
        ------
        ValueError
            If mapping is invalid or old names are not present.
        """

        if set(["old_name", "new_name"]) != set(mapping.keys()):
            raise ValueError(
                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
                "each with a list of names to change."
            )

        if len(mapping["old_name"]) != len(mapping["new_name"]):
            raise ValueError(
                "Mapping dictionary lists 'old_name' and 'new_name' "
                "must have the same length!"
            )

        names_vector = list(self.subclusters_.subclusters)

        if not all(elem in names_vector for elem in list(mapping["old_name"])):
            raise ValueError(
                "Some entries from 'old_name' do not exist in the subcluster names."
            )

        # old -> new lookup; labels absent from the mapping pass through unchanged
        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))

        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]

        self.subclusters_.subclusters = names_vector_ret

    def save_sparse(
        self,
        # NOTE(review): os.getcwd() is evaluated once at class-definition time,
        # not per call — callers in a changed working directory should pass the
        # path explicitly; confirm before relying on the default.
        path_to_save: str = os.getcwd(),
        name_slot: str = "cell_names",
        data_slot: str = "normalized",
    ):
        """
        Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

        Parameters
        ----------
        path_to_save : str, default current working directory
            Directory where files will be written.

        name_slot : str, default 'cell_names'
            Metadata column providing cell names for barcodes.tsv.

        data_slot : str, default 'normalized'
            Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).

        Raises
        ------
        ValueError
            If `data_slot` is not 'normalized' or 'count'.
        """

        names = self.input_metadata[name_slot]

        if data_slot.lower() == "normalized":

            features = list(self.normalized_data.index)
            mtx = sparse.csr_matrix(self.normalized_data)

        elif data_slot.lower() == "count":

            features = list(self.input_data.index)
            mtx = sparse.csr_matrix(self.input_data)

        else:
            raise ValueError("'data_slot' must be included in 'normalized' or 'count'")

        os.makedirs(path_to_save, exist_ok=True)

        mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx)

        # 10x convention: one barcode / one feature per line, no header
        pd.Series(names).to_csv(
            os.path.join(path_to_save, "barcodes.tsv"),
            index=False,
            header=False,
            sep="\t",
        )

        pd.Series(features).to_csv(
            os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t"
        )

    def normalize_counts(
        self, normalize_factor: int = 100000, log_transform: bool = True
    ):
        """
        Normalize raw counts to counts-per-(normalize_factor)
        (e.g., CPM, TPM - depending on normalize_factor).

        Parameters
        ----------
        normalize_factor : int, default 100000
            Scaling factor used after dividing by column sums.

        log_transform : bool, default True
            If True, apply log2(x+1) transformation to normalized values.

        Update
        -------
        Sets `self.normalized_data` to normalized values (fills NaNs with 0).

        Raises
        ------
        ValueError
            If `self.input_data` is missing (cannot normalize).
        """
        if self.input_data is None:
            raise ValueError("Input data is missing, cannot normalize.")

        # per-sample library-size normalization; all-zero columns divide to NaN
        # and are reset to 0 by fillna
        sum_col = self.input_data.sum()
        self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor

        if log_transform:
            # log2(x + 1) to avoid -inf for zeros
            self.normalized_data = np.log2(self.normalized_data + 1)

    def statistic(
        self,
        cells=None,
        sets=None,
        min_exp: float = 0.01,
        min_pct: float = 0.1,
        n_proc: int = 10,
    ):
        """
        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.

        This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data`
        and `self.input_metadata`. It returns per-feature statistics including p-values,
        adjusted p-values, means, variances, effect-size measures and fold-changes.

        Parameters
        ----------
        cells : list, 'All', dict, or None
            Defines the target cells or groups for comparison (several modes supported).

        sets : 'All', dict, or None
            Alternative grouping mode (operate on `self.input_metadata['sets']`).

        min_exp : float, default 0.01
            Minimum expression threshold used when filtering features.

        min_pct : float, default 0.1
            Minimum proportion of expressing cells in the target group required to test a feature.

        n_proc : int, default 10
            Number of parallel jobs to use.

        Returns
        -------
        pandas.DataFrame or dict
            Results DataFrame (or dict containing valid/control cells + DataFrame),
            similar to `calc_DEG` interface.

        Raises
        ------
        ValueError
            If neither `cells` nor `sets` is provided, or input metadata mismatch occurs.

        Notes
        -----
        Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
        """

        # pseudo-count fallback used only when no positive mean exists at all
        offset = 1e-100

        def stat_calc(choose, feature_name):
            # per-feature comparison of the 'target' rows vs the 'rest' rows
            target_values = choose.loc[choose["DEG"] == "target", feature_name]
            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

            # fraction of expressing (value > 0) cells in each group
            pct_valid = (target_values > 0).sum() / len(target_values)
            pct_rest = (rest_values > 0).sum() / len(rest_values)

            avg_valid = np.mean(target_values)
            avg_ctrl = np.mean(rest_values)
            sd_valid = np.std(target_values, ddof=1)
            sd_ctrl = np.std(rest_values, ddof=1)
            # Cohen's-d-style effect size with pooled (averaged-variance) SD
            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

            # NOTE(review): equal group sums are used as a cheap shortcut for
            # "no difference" (p = 1.0) — equal sums do not strictly imply equal
            # distributions; confirm this approximation is intended.
            if np.sum(target_values) == np.sum(rest_values):
                p_val = 1.0
            else:
                _, p_val = stats.mannwhitneyu(
                    target_values, rest_values, alternative="two-sided"
                )

            return {
                "feature": feature_name,
                "p_val": p_val,
                "pct_valid": pct_valid,
                "pct_ctrl": pct_rest,
                "avg_valid": avg_valid,
                "avg_ctrl": avg_ctrl,
                "sd_valid": sd_valid,
                "sd_ctrl": sd_ctrl,
                "esm": esm,
            }

        def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc):

            def safe_min_half(series):
                # half of the smallest value that is strictly above two times
                # the smallest subnormal float (i.e. effectively > 0), skipping NaNs
                filtered = series[(series > ((2**-1074) * 2)) & (series.notna())]
                return filtered.min() / 2 if not filtered.empty else 0

            tmp_dat = choose[choose["DEG"] == "target"]
            tmp_dat = tmp_dat.drop("DEG", axis=1)

            # number of target cells expressing each feature above min_exp
            counts = (tmp_dat > min_exp).sum(axis=0)

            total_count = tmp_dat.shape[0]

            info = pd.DataFrame(
                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
            )

            del tmp_dat

            drop_col = info["feature"][info["pct"] <= min_pct]

            # if the pct filter would drop every feature (all columns except
            # 'DEG'), relax it to dropping only completely unexpressed features
            if len(drop_col) + 1 == len(choose.columns):
                drop_col = info["feature"][info["pct"] == 0]

            del info

            choose = choose.drop(list(drop_col), axis=1)

            results = Parallel(n_jobs=n_proc)(
                delayed(stat_calc)(choose, feature)
                for feature in tqdm(choose.columns[choose.columns != "DEG"])
            )

            df = pd.DataFrame(results)
            # keep features expressed in at least one of the two groups
            df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)]

            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # Benjamini–Hochberg-style rank adjustment (p * m / rank, capped at 1);
            # NOTE(review): unlike canonical BH this skips the cumulative-minimum
            # step, so adjusted p-values are not guaranteed monotone — confirm.
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            # pseudo-count for fold-change: half the smallest positive group mean
            valid_factor = safe_min_half(df["avg_valid"])
            ctrl_factor = safe_min_half(df["avg_ctrl"])

            cv_factor = min(valid_factor, ctrl_factor)

            if cv_factor == 0:
                cv_factor = max(valid_factor, ctrl_factor)

            if not np.isfinite(cv_factor) or cv_factor == 0:
                cv_factor += offset

            # add the pseudo-count only where the mean is exactly zero so the
            # ratio below never divides by zero
            valid = df["avg_valid"].where(
                df["avg_valid"] != 0, df["avg_valid"] + cv_factor
            )
            ctrl = df["avg_ctrl"].where(
                df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor
            )

            df["FC"] = valid / ctrl

            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]

            return df

        # cells become rows, features become columns
        choose = self.normalized_data.copy().T

        final_results = []

        if isinstance(cells, list) and sets is None:
            print("\nAnalysis started...\nComparing selected cells to the whole set...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            # entries without a ' # set' suffix are matched on the bare name only
            if "#" not in cells[0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "Not include the set info (name # set) in the 'cells' list. "
                    "Only the names will be compared, without considering the set information."
                )

            labels = ["target" if idx in cells else "rest" for idx in choose.index]
            valid = list(
                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
            )

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=valid,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}

        elif cells == "All" and sets is None:
            print("\nAnalysis started...\nComparing each type of cell to others...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
            unique_labels = set(choose.index)

            # one-vs-rest comparison for every distinct "name # set" label
            for label in tqdm(unique_labels):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]
                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        elif cells is None and sets == "All":
            print("\nAnalysis started...\nComparing each set/group to others...")
            choose.index = self.input_metadata["sets"]
            unique_sets = set(choose.index)

            # one-vs-rest comparison for every distinct set label
            for label in tqdm(unique_sets):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]

                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        elif cells is None and isinstance(sets, dict):
            print("\nAnalysis started...\nComparing groups...")

            choose.index = self.input_metadata["sets"]

            group_list = list(sets.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            # first dict entry -> target, second -> rest; everything else dropped
            labels = [
                (
                    "target"
                    if idx in sets[group_list[0]]
                    else "rest" if idx in sets[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            return result_df

        elif isinstance(cells, dict) and sets is None:
            print("\nAnalysis started...\nComparing groups...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            # entries without a ' # set' suffix are matched on the bare name only
            if "#" not in cells[list(cells.keys())[0]][0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "Not include the set info (name # set) in the 'cells' dict. "
                    "Only the names will be compared, without considering the set information."
                )

            group_list = list(cells.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            # first dict entry -> target, second -> rest; everything else dropped
            labels = [
                (
                    "target"
                    if idx in cells[group_list[0]]
                    else "rest" if idx in cells[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )

            return result_df.reset_index(drop=True)

        else:
            raise ValueError(
                "You must specify either 'cells' or 'sets' (or both). None were provided, which is not allowed for this analysis."
            )

    def calculate_difference_markers(
        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
    ):
        """
        Compute differential markers (var_data) if not already present.

        Parameters
        ----------
        min_exp : float, default 0
            Minimum expression threshold passed to `statistic`.

        min_pct : float, default 0.25
            Minimum percent expressed in target group.

        n_proc : int, default 10
            Parallel jobs.

        force : bool, default False
            If True, recompute even if `self.var_data` is present.

        Update
        -------
        Sets `self.var_data` to the result of `self.statistic(...)`.

        Raise
        ------
        ValueError if already computed and `force` is False.
        """

        if self.var_data is None or force:

            # one-vs-rest markers for every "cell_name # set" label
            self.var_data = self.statistic(
                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
            )

        else:
            raise ValueError(
                "self.calculate_difference_markers() has already been executed. "
                "The results are stored in self.var. "
                "If you want to recalculate with different statistics, please rerun the method with force=True."
2753 ) 2754 2755 def clustering_features( 2756 self, 2757 features_list: list | None, 2758 name_slot: str = "cell_names", 2759 p_val: float = 0.05, 2760 top_n: int = 25, 2761 adj_mean: bool = True, 2762 beta: float = 0.2, 2763 ): 2764 """ 2765 Prepare clustering input by selecting marker features and optionally smoothing cell values 2766 toward group means. 2767 2768 Parameters 2769 ---------- 2770 features_list : list or None 2771 If provided, use this list of features. If None, features are selected 2772 from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group. 2773 2774 name_slot : str, default 'cell_names' 2775 Metadata column used for naming. 2776 2777 p_val : float, default 0.05 2778 Adjusted p-value cutoff when selecting features automatically. 2779 2780 top_n : int, default 25 2781 Number of top features per valid group to keep if `features_list` is None. 2782 2783 adj_mean : bool, default True 2784 If True, adjust cell values toward group means using `beta`. 2785 2786 beta : float, default 0.2 2787 Adjustment strength toward group mean. 2788 2789 Update 2790 ------ 2791 Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset, 2792 ready for PCA/UMAP/clustering. 2793 """ 2794 2795 if features_list is None or len(features_list) == 0: 2796 2797 if self.var_data is None: 2798 raise ValueError( 2799 "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first." 
2800 ) 2801 2802 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2803 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2804 df_tmp = ( 2805 df_tmp.sort_values( 2806 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2807 ) 2808 .groupby("valid_group") 2809 .head(top_n) 2810 ) 2811 2812 feaures_list = list(set(df_tmp["feature"])) 2813 2814 data = self.get_partial_data( 2815 names=None, features=feaures_list, name_slot=name_slot 2816 ) 2817 data_avg = average(data) 2818 2819 if adj_mean: 2820 data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta) 2821 2822 self.clustering_data = data 2823 2824 self.clustering_metadata = self.input_metadata 2825 2826 def average(self): 2827 """ 2828 Aggregate normalized data by (cell_name, set) pairs computing the mean per group. 2829 2830 The method constructs new column names as "cell_name # set", averages columns 2831 sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`. 2832 2833 Update 2834 ------ 2835 Sets `self.agg_normalized_data` (features x aggregated samples) and 2836 `self.agg_metadata` (DataFrame with 'cell_names' and 'sets'). 
2837 """ 2838 2839 wide_data = self.normalized_data 2840 2841 wide_metadata = self.input_metadata 2842 2843 new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"] 2844 2845 wide_data.columns = list(new_names) 2846 2847 aggregated_df = wide_data.T.groupby(level=0).mean().T 2848 2849 sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns] 2850 names = [re.sub(" #.*", "", x) for x in aggregated_df.columns] 2851 2852 aggregated_df.columns = names 2853 aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets}) 2854 2855 self.agg_metadata = aggregated_metadata 2856 self.agg_normalized_data = aggregated_df 2857 2858 def estimating_similarity( 2859 self, method="pearson", p_val: float = 0.05, top_n: int = 25 2860 ): 2861 """ 2862 Estimate pairwise similarity and Euclidean distance between aggregated samples. 2863 2864 Parameters 2865 ---------- 2866 method : str, default 'pearson' 2867 Correlation method to use (passed to pandas.DataFrame.corr()). 2868 2869 p_val : float, default 0.05 2870 Adjusted p-value cutoff used to select marker features from `self.var_data`. 2871 2872 top_n : int, default 25 2873 Number of top features per valid group to include. 2874 2875 Update 2876 ------- 2877 Computes a combined table with per-pair correlation and euclidean distance 2878 and stores it in `self.similarity`. 2879 """ 2880 2881 if self.var_data is None: 2882 raise ValueError( 2883 "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first." 
2884 ) 2885 2886 if self.agg_normalized_data is None: 2887 self.average() 2888 2889 metadata = self.agg_metadata 2890 data = self.agg_normalized_data 2891 2892 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2893 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2894 df_tmp = ( 2895 df_tmp.sort_values( 2896 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2897 ) 2898 .groupby("valid_group") 2899 .head(top_n) 2900 ) 2901 2902 data = data.loc[list(set(df_tmp["feature"]))] 2903 2904 if len(set(metadata["sets"])) > 1: 2905 data.columns = data.columns + " # " + [x for x in metadata["sets"]] 2906 else: 2907 data = data.copy() 2908 2909 scaler = StandardScaler() 2910 2911 scaled_data = scaler.fit_transform(data) 2912 2913 scaled_df = pd.DataFrame(scaled_data, columns=data.columns) 2914 2915 cor = scaled_df.corr(method=method) 2916 cor_df = cor.stack().reset_index() 2917 cor_df.columns = ["cell1", "cell2", "correlation"] 2918 2919 distances = pdist(scaled_df.T, metric="euclidean") 2920 dist_mat = pd.DataFrame( 2921 squareform(distances), index=scaled_df.columns, columns=scaled_df.columns 2922 ) 2923 dist_df = dist_mat.stack().reset_index() 2924 dist_df.columns = ["cell1", "cell2", "euclidean_dist"] 2925 2926 full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"]) 2927 2928 full = full[full["cell1"] != full["cell2"]] 2929 full = full.reset_index(drop=True) 2930 2931 self.similarity = full 2932 2933 def similarity_plot( 2934 self, 2935 split_sets=True, 2936 set_info: bool = True, 2937 cmap="seismic", 2938 width=12, 2939 height=10, 2940 ): 2941 """ 2942 Visualize pairwise similarity as a scatter plot. 2943 2944 Parameters 2945 ---------- 2946 split_sets : bool, default True 2947 If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs. 2948 2949 set_info : bool, default True 2950 If True, keep the ' # set' annotation in labels; otherwise strip it. 

        cmap : str, default 'seismic'
            Color map for correlation (hue).

        width : int, default 12
            Figure width.

        height : int, default 10
            Figure height.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If `self.similarity` is None.

        Notes
        -----
        The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs.
        """

        if self.similarity is None:
            raise ValueError(
                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
            )

        # NOTE(review): this binds the stored frame directly, so the 'set1'/'set2'
        # columns added below persist on self.similarity — confirm intended.
        similarity_data = self.similarity

        if " # " in similarity_data["cell1"][0]:
            similarity_data["set1"] = [
                re.sub(".*# ", "", x) for x in similarity_data["cell1"]
            ]
            similarity_data["set2"] = [
                re.sub(".*# ", "", x) for x in similarity_data["cell2"]
            ]

        if split_sets and " # " in similarity_data["cell1"][0]:
            sets = list(
                set(list(similarity_data["set1"]) + list(similarity_data["set2"]))
            )

            # split the sets into two halves: one for the x axis, one for y
            mm = math.ceil(len(sets) / 2)

            x_s = sets[0:mm]
            y_s = sets[mm : len(sets)]

            similarity_data = similarity_data[similarity_data["set1"].isin(x_s)]
            similarity_data = similarity_data[similarity_data["set2"].isin(y_s)]

            similarity_data = similarity_data.sort_values(["set1", "set2"])

        if set_info is False and " # " in similarity_data["cell1"][0]:
            # strip the ' # set' suffix from the displayed labels
            similarity_data["cell1"] = [
                re.sub(" #.*", "", x) for x in similarity_data["cell1"]
            ]
            similarity_data["cell2"] = [
                re.sub(" #.*", "", x) for x in similarity_data["cell2"]
            ]

        # negated z-score so that larger values mean closer pairs
        similarity_data["-euclidean_zscore"] = -zscore(
            similarity_data["euclidean_dist"]
        )

        similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0]

        fig = plt.figure(figsize=(width, height))
        sns.scatterplot(
            data=similarity_data,
            x="cell1",
            y="cell2",
            hue="correlation",
            size="-euclidean_zscore",
            sizes=(1, 100),
            palette=cmap,
            alpha=1,
            edgecolor="black",
        )

        plt.xticks(rotation=90)
        plt.yticks(rotation=0)
        plt.xlabel("Cell 1")
        plt.ylabel("Cell 2")
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

        plt.grid(True, alpha=0.6)

        plt.tight_layout()

        return fig

    def spatial_similarity(
        self,
        set_info: bool = True,
        bandwidth=1,
        n_neighbors=5,
        min_dist=0.1,
        legend_split=2,
        point_size=100,
        spread=1.0,
        set_op_mix_ratio=1.0,
        local_connectivity=1,
        repulsion_strength=1.0,
        negative_sample_rate=5,
        threshold=0.1,
        width=12,
        height=10,
    ):
        """
        Create a spatial UMAP-like visualization of similarity relationships between samples.

        Parameters
        ----------
        set_info : bool, default True
            If True, retain set information in labels.

        bandwidth : float, default 1
            Bandwidth used by MeanShift for clustering polygons.

        point_size : float, default 100
            Size of scatter points.

        legend_split : int, default 2
            Number of columns in legend.

        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.

        threshold : float, default 0.1
            Minimum text distance for label adjustment to avoid overlap.

        width : int, default 12
            Figure width.

        height : int, default 10
            Figure height.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If `self.similarity` is None.
3097 3098 Notes 3099 ----- 3100 Builds a precomputed distance matrix combining correlation and euclidean distance, 3101 runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull) 3102 and arrows to indicate nearest neighbors (minimal combined distance). 3103 """ 3104 3105 if self.similarity is None: 3106 raise ValueError( 3107 "Similarity data is missing. Please calculate similarity using self.estimating_similarity." 3108 ) 3109 3110 similarity_data = self.similarity 3111 3112 sim = similarity_data["correlation"] 3113 sim_scaled = (sim - sim.min()) / (sim.max() - sim.min()) 3114 eu_dist = similarity_data["euclidean_dist"] 3115 eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min()) 3116 3117 similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled 3118 3119 # for nn target 3120 arrow_df = similarity_data.copy() 3121 arrow_df = similarity_data.loc[ 3122 similarity_data.groupby("cell1")["combo_dist"].idxmin() 3123 ] 3124 3125 cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"])) 3126 combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float) 3127 3128 for _, row in similarity_data.iterrows(): 3129 combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"] 3130 combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"] 3131 3132 umap_model = umap.UMAP( 3133 n_components=2, 3134 metric="precomputed", 3135 n_neighbors=n_neighbors, 3136 min_dist=min_dist, 3137 spread=spread, 3138 set_op_mix_ratio=set_op_mix_ratio, 3139 local_connectivity=set_op_mix_ratio, 3140 repulsion_strength=repulsion_strength, 3141 negative_sample_rate=negative_sample_rate, 3142 transform_seed=42, 3143 init="spectral", 3144 random_state=42, 3145 verbose=True, 3146 ) 3147 3148 coords = umap_model.fit_transform(combo_matrix.values) 3149 cell_names = list(combo_matrix.index) 3150 num_cells = len(cell_names) 3151 palette = sns.color_palette("tab20c", num_cells) 3152 3153 if "#" in 
cell_names[0]:
            # colour per set: pick evenly spaced palette indices, one per set
            avsets = set(
                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
            )
            num_sets = len(avsets)
            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
            # NOTE(review): iterating a set here makes the set -> colour pairing
            # order-dependent and potentially non-reproducible across runs — confirm.
            color_mapping_sets = {
                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
            }
            color_mapping = {
                name: color_mapping_sets[re.sub(".*# ", "", name)]
                for i, name in enumerate(cell_names)
            }
        else:
            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}

        # cluster the 2D embedding to draw hull overlays per cluster
        meanshift = MeanShift(bandwidth=bandwidth)
        labels = meanshift.fit_predict(coords)

        fig = plt.figure(figsize=(width, height))
        ax = plt.gca()

        unique_labels = set(labels)
        cluster_palette = sns.color_palette("hls", len(unique_labels))

        for label in unique_labels:
            if label == -1:
                continue
            cluster_coords = coords[labels == label]
            # ConvexHull needs at least 3 points in 2D
            if len(cluster_coords) < 3:
                continue

            hull = ConvexHull(cluster_coords)
            hull_points = cluster_coords[hull.vertices]

            # inflate the hull by 5% around its centroid for visual padding
            centroid = np.mean(hull_points, axis=0)
            expanded = hull_points + 0.05 * (hull_points - centroid)

            poly = Polygon(
                expanded,
                closed=True,
                facecolor=cluster_palette[label],
                edgecolor="none",
                alpha=0.2,
                zorder=1,
            )
            ax.add_patch(poly)

        # scatter each sample and annotate it with its index (legend maps index -> name)
        texts = []
        for i, (x, y) in enumerate(coords):
            plt.scatter(
                x,
                y,
                s=point_size,
                color=color_mapping[cell_names[i]],
                edgecolors="black",
                linewidths=0.5,
                zorder=2,
            )
            texts.append(
                ax.text(
                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
                )
            )

        def dist(p1, p2):
            # plain 2D euclidean distance between two (x, y) points
            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)

        # only labels closer than `threshold` to another label get repositioned
        texts_to_adjust = []
        for i, t1 in enumerate(texts):
            for j, t2 in enumerate(texts):
                if i >= j:
                    continue
                d = dist(
                    (t1.get_position()[0], t1.get_position()[1]),
                    (t2.get_position()[0], t2.get_position()[1]),
                )
                if d < threshold:
                    if t1 not in texts_to_adjust:
                        texts_to_adjust.append(t1)
                    if t2 not in texts_to_adjust:
                        texts_to_adjust.append(t2)

        adjust_text(
            texts_to_adjust,
            expand_text=(1.0, 1.0),
            force_text=0.9,
            arrowprops=dict(arrowstyle="-", color="gray", lw=0.1),
            ax=ax,
        )

        # nearest-neighbour arrows (minimal combined distance per cell)
        for _, row in arrow_df.iterrows():
            try:
                idx1 = cell_names.index(row["cell1"])
                idx2 = cell_names.index(row["cell2"])
            except ValueError:
                continue
            x1, y1 = coords[idx1]
            x2, y2 = coords[idx2]
            arrow = FancyArrowPatch(
                (x1, y1),
                (x2, y2),
                arrowstyle="->",
                color="gray",
                linewidth=1.5,
                alpha=0.5,
                mutation_scale=12,
                zorder=0,
            )
            ax.add_patch(arrow)

        if set_info is False and " # " in cell_names[0]:

            legend_elements = [
                Patch(
                    facecolor=color_mapping[name],
                    edgecolor="black",
                    label=f"{i} – {re.sub(' #.*', '', name)}",
                )
                for i, name in enumerate(cell_names)
            ]

        else:

            legend_elements = [
                Patch(
                    facecolor=color_mapping[name],
                    edgecolor="black",
                    label=f"{i} – {name}",
                )
                for i, name in enumerate(cell_names)
            ]

        plt.legend(
            handles=legend_elements,
            title="Cells",
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            ncol=legend_split,
        )

        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.grid(False)
        plt.show()

        return fig

    # subclusters part

    def subcluster_prepare(self, features: list, cluster: str):
        """
        Prepare a `Clustering` object for subcluster analysis on a selected parent cluster.

        Parameters
        ----------
        features : list
            Features to include for subcluster analysis.

        cluster : str
            Parent cluster name (used to select matching cells).

        Update
        ------
        Initializes `self.subclusters_` as a new `Clustering` instance containing the
        reduced data for the given cluster and stores `current_features` and `current_cluster`.
        """

        dat = self.normalized_data
        # NOTE(review): this renames the columns of self.normalized_data in
        # place (dat is a reference, not a copy) so that reduce_data below can
        # match on cell names — confirm the permanent rename is intended.
        dat.columns = list(self.input_metadata["cell_names"])

        dat = reduce_data(self.normalized_data, features=features, names=[cluster])

        # nested Clustering instance dedicated to the selected parent cluster
        self.subclusters_ = Clustering(data=dat, metadata=None)

        self.subclusters_.current_features = features
        self.subclusters_.current_cluster = cluster

    def define_subclusters(
        self,
        umap_num: int = 2,
        eps: float = 0.5,
        min_samples: int = 10,
        n_neighbors: int = 5,
        min_dist: float = 0.1,
        spread: float = 1.0,
        set_op_mix_ratio: float = 1.0,
        local_connectivity: int = 1,
        repulsion_strength: float = 1.0,
        negative_sample_rate: int = 5,
        width=8,
        height=6,
    ):
        """
        Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.

        Parameters
        ----------
        umap_num : int, default 2
            Number of UMAP dimensions to compute.

        eps : float, default 0.5
            DBSCAN eps parameter.

        min_samples : int, default 10
            DBSCAN min_samples parameter.

        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
            Additional parameters passed to UMAP / plotting / MeanShift as appropriate.

        Returns
        -------
        matplotlib.figure.Figure
            The cluster visualization produced by `find_clusters_UMAP`.

        Update
        -------
        Stores cluster labels in `self.subclusters_.subclusters`.

        Raises
        ------
        RuntimeError
            If `self.subclusters_` has not been prepared.
        """

        if self.subclusters_ is None:
            raise RuntimeError(
                "Nothing to return. 'self.subcluster_prepare' was not conducted!"
            )

        # UMAP directly on the (already reduced) subcluster data — no PCA step
        self.subclusters_.perform_UMAP(
            factorize=False,
            umap_num=umap_num,
            pc_num=0,
            harmonized=False,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            spread=spread,
            set_op_mix_ratio=set_op_mix_ratio,
            local_connectivity=local_connectivity,
            repulsion_strength=repulsion_strength,
            negative_sample_rate=negative_sample_rate,
            width=width,
            height=height,
        )

        fig = self.subclusters_.find_clusters_UMAP(
            umap_n=umap_num,
            eps=eps,
            min_samples=min_samples,
            width=width,
            height=height,
        )

        clusters = self.subclusters_.return_clusters(clusters="umap")

        # store labels as strings for downstream renaming / plotting
        self.subclusters_.subclusters = [str(x) for x in list(clusters)]

        return fig

    def subcluster_features_scatter(
        self,
        colors="viridis",
        hclust="complete",
        scale=False,
        img_width=3,
        img_high=5,
        label_size=6,
        size_scale=70,
        y_lab="Genes",
        legend_lab="normalized",
        bbox_to_anchor_scale: int = 25,
        bbox_to_anchor_perc: tuple = (0.91, 0.63),
    ):
        """
        Create a features-scatter visualization for the subclusters (averaged and occurrence).

        Parameters
        ----------
        colors : str, default 'viridis'
            Colormap name passed to `features_scatter`.

        hclust : str or None
            Hierarchical clustering linkage to order rows/columns.

        scale: bool, default False
            If True, expression data will be scaled (0–1) across the rows (features).

        img_width, img_high : float
            Figure size.

        label_size : int
            Font size for labels.

        size_scale : int
            Bubble size scaling.

        y_lab : str
            X axis label.

        legend_lab : str
            Colorbar label.

        bbox_to_anchor_scale : int, default=25
            Vertical scale (percentage) for positioning the colorbar.
3454 3455 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3456 Anchor position for the size legend (percent bubble legend). 3457 3458 Returns 3459 ------- 3460 matplotlib.figure.Figure 3461 3462 Raises 3463 ------ 3464 RuntimeError 3465 If subcluster preparation/definition has not been run. 3466 """ 3467 3468 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3469 raise RuntimeError( 3470 "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!" 3471 ) 3472 3473 dat = self.normalized_data 3474 dat.columns = list(self.input_metadata["cell_names"]) 3475 3476 dat = reduce_data( 3477 self.normalized_data, 3478 features=self.subclusters_.current_features, 3479 names=[self.subclusters_.current_cluster], 3480 ) 3481 3482 dat.columns = self.subclusters_.subclusters 3483 3484 avg = average(dat) 3485 occ = occurrence(dat) 3486 3487 scatter = features_scatter( 3488 expression_data=avg, 3489 occurence_data=occ, 3490 features=None, 3491 scale=scale, 3492 metadata_list=None, 3493 colors=colors, 3494 hclust=hclust, 3495 img_width=img_width, 3496 img_high=img_high, 3497 label_size=label_size, 3498 size_scale=size_scale, 3499 y_lab=y_lab, 3500 legend_lab=legend_lab, 3501 bbox_to_anchor_scale=bbox_to_anchor_scale, 3502 bbox_to_anchor_perc=bbox_to_anchor_perc, 3503 ) 3504 3505 return scatter 3506 3507 def subcluster_DEG_scatter( 3508 self, 3509 top_n=3, 3510 min_exp=0, 3511 min_pct=0.25, 3512 p_val=0.05, 3513 colors="viridis", 3514 hclust="complete", 3515 scale=False, 3516 img_width=3, 3517 img_high=5, 3518 label_size=6, 3519 size_scale=70, 3520 y_lab="Genes", 3521 legend_lab="normalized", 3522 bbox_to_anchor_scale: int = 25, 3523 bbox_to_anchor_perc: tuple = (0.91, 0.63), 3524 n_proc=10, 3525 ): 3526 """ 3527 Plot top differential features (DEGs) for subclusters as a features-scatter. 3528 3529 Parameters 3530 ---------- 3531 top_n : int, default 3 3532 Number of top features per subcluster to show. 
3533 3534 min_exp : float, default 0 3535 Minimum expression threshold passed to `statistic`. 3536 3537 min_pct : float, default 0.25 3538 Minimum percent expressed in target group. 3539 3540 p_val: float, default 0.05 3541 Maximum p-value for visualizing features. 3542 3543 n_proc : int, default 10 3544 Parallel jobs used for DEG calculation. 3545 3546 scale: bool, default False 3547 If True, expression_data will be scaled (0–1) across the rows (features). 3548 3549 colors : str, default='viridis' 3550 Colormap for expression values. 3551 3552 hclust : str or None, default='complete' 3553 Linkage method for hierarchical clustering. If None, no clustering 3554 is performed. 3555 3556 img_width : int or float, default=8 3557 Width of the plot in inches. 3558 3559 img_high : int or float, default=5 3560 Height of the plot in inches. 3561 3562 label_size : int, default=10 3563 Font size for axis labels and ticks. 3564 3565 size_scale : int or float, default=100 3566 Scaling factor for bubble sizes. 3567 3568 y_lab : str, default='Genes' 3569 Label for the x-axis. 3570 3571 legend_lab : str, default='normalized' 3572 Label for the colorbar legend. 3573 3574 bbox_to_anchor_scale : int, default=25 3575 Vertical scale (percentage) for positioning the colorbar. 3576 3577 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3578 Anchor position for the size legend (percent bubble legend). 3579 3580 Returns 3581 ------- 3582 matplotlib.figure.Figure 3583 3584 Raises 3585 ------ 3586 RuntimeError 3587 If subcluster preparation/definition has not been run. 3588 3589 Notes 3590 ----- 3591 Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters 3592 by p-value and effect-size, selects top features per valid group and plots them. 3593 """ 3594 3595 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3596 raise RuntimeError( 3597 "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!" 
3598 ) 3599 3600 dat = self.normalized_data 3601 dat.columns = list(self.input_metadata["cell_names"]) 3602 3603 dat = reduce_data( 3604 self.normalized_data, names=[self.subclusters_.current_cluster] 3605 ) 3606 3607 dat.columns = self.subclusters_.subclusters 3608 3609 deg_stats = calc_DEG( 3610 dat, 3611 metadata_list=None, 3612 entities="All", 3613 sets=None, 3614 min_exp=min_exp, 3615 min_pct=min_pct, 3616 n_proc=n_proc, 3617 ) 3618 3619 deg_stats = deg_stats[deg_stats["p_val"] <= p_val] 3620 deg_stats = deg_stats[deg_stats["log(FC)"] > 0] 3621 3622 deg_stats = ( 3623 deg_stats.sort_values( 3624 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 3625 ) 3626 .groupby("valid_group") 3627 .head(top_n) 3628 ) 3629 3630 dat = reduce_data(dat, features=list(set(deg_stats["feature"]))) 3631 3632 avg = average(dat) 3633 occ = occurrence(dat) 3634 3635 scatter = features_scatter( 3636 expression_data=avg, 3637 occurence_data=occ, 3638 features=None, 3639 metadata_list=None, 3640 colors=colors, 3641 hclust=hclust, 3642 img_width=img_width, 3643 img_high=img_high, 3644 label_size=label_size, 3645 size_scale=size_scale, 3646 y_lab=y_lab, 3647 legend_lab=legend_lab, 3648 bbox_to_anchor_scale=bbox_to_anchor_scale, 3649 bbox_to_anchor_perc=bbox_to_anchor_perc, 3650 ) 3651 3652 return scatter 3653 3654 def accept_subclusters(self): 3655 """ 3656 Commit subcluster labels into the main `input_metadata` by renaming cell names. 3657 3658 The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']` 3659 with the expanded names that include subcluster suffixes (via `add_subnames`), 3660 then clears `self.subclusters_`. 3661 3662 Update 3663 ------ 3664 Modifies `self.input_metadata['cell_names']`. 3665 3666 Resets `self.subclusters_` to None. 3667 3668 Raises 3669 ------ 3670 RuntimeError 3671 If `self.subclusters_` is not defined or subclusters were not computed. 
        """

        if self.subclusters_ is None or self.subclusters_.subclusters is None:
            raise RuntimeError(
                "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!"
            )

        # Expand the parent cluster's cell names with subcluster suffixes.
        new_meta = add_subnames(
            list(self.input_metadata["cell_names"]),
            parent_name=self.subclusters_.current_cluster,
            new_clusters=self.subclusters_.subclusters,
        )

        self.input_metadata["cell_names"] = new_meta

        # Labels are committed; drop the working subcluster object.
        self.subclusters_ = None

    def scatter_plot(
        self,
        names: list | None = None,
        features: list | None = None,
        name_slot: str = "cell_names",
        scale=True,
        colors="viridis",
        hclust=None,
        img_width=15,
        img_high=1,
        label_size=10,
        size_scale=200,
        y_lab="Genes",
        legend_lab="log(CPM + 1)",
        set_box_size: float | int = 5,
        set_box_high: float | int = 5,
        bbox_to_anchor_scale=25,
        bbox_to_anchor_perc=(0.90, -0.24),
        bbox_to_anchor_group=(1.01, 0.4),
    ):
        """
        Create a bubble scatter plot of selected features across samples inside project.

        Each point represents a feature-sample pair, where the color encodes the
        expression value and the size encodes occurrence or relative abundance.
        Optionally, hierarchical clustering can be applied to order rows and columns.

        Parameters
        ----------
        names : list, str, or None
            Names of samples to include. If None, all samples are considered.

        features : list, str, or None
            Names of features to include. If None, all features are considered.

        name_slot : str
            Column in metadata to use as sample names.

        scale: bool, default True
            If True, expression_data will be scaled (0–1) across the rows (features).

        colors : str, default='viridis'
            Colormap for expression values.

        hclust : str or None, default=None
            Linkage method for hierarchical clustering.
            If None, no clustering
            is performed.

        img_width : int or float, default=15
            Width of the plot in inches.

        img_high : int or float, default=1
            Height of the plot in inches.

        label_size : int, default=10
            Font size for axis labels and ticks.

        size_scale : int or float, default=200
            Scaling factor for bubble sizes.

        y_lab : str, default='Genes'
            Label for the x-axis.

        legend_lab : str, default='log(CPM + 1)'
            Label for the colorbar legend.

        set_box_size : float or int, default=5
            Size of the per-set marker boxes passed to `features_scatter`.

        set_box_high : float or int, default=5
            Height of the per-set marker boxes passed to `features_scatter`.

        bbox_to_anchor_scale : int, default=25
            Vertical scale (percentage) for positioning the colorbar.

        bbox_to_anchor_perc : tuple, default=(0.90, -0.24)
            Anchor position for the size legend (percent bubble legend).

        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
            Anchor position for the group legend.

        Returns
        -------
        matplotlib.figure.Figure
            The generated scatter plot figure.

        Notes
        -----
        Colors represent expression values normalized to the colormap.
        """

        prtd, met = self.get_partial_data(
            names=names, features=features, name_slot=name_slot, inc_metadata=True
        )

        # Tag each column with its set ("name#set") so averaging keeps sets apart.
        prtd.columns = prtd.columns + "#" + met["sets"]

        prtd_avg = average(prtd)

        # Recover per-column set labels, then strip the "#set" suffix again.
        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]

        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]

        prtd_occ = occurrence(prtd)

        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]

        fig_scatter = features_scatter(
            expression_data=prtd_avg,
            occurence_data=prtd_occ,
            scale=scale,
            features=None,
            metadata_list=meta_sets,
            colors=colors,
            hclust=hclust,
            img_width=img_width,
            img_high=img_high,
            label_size=label_size,
            size_scale=size_scale,
            y_lab=y_lab,
            legend_lab=legend_lab,
            set_box_size=set_box_size,
            set_box_high=set_box_high,
            bbox_to_anchor_scale=bbox_to_anchor_scale,
            bbox_to_anchor_perc=bbox_to_anchor_perc,
            bbox_to_anchor_group=bbox_to_anchor_group,
        )

        return fig_scatter

    def data_composition(
        self,
        features_count: list | None,
        name_slot: str = "cell_names",
        set_sep: bool = True,
    ):
        """
        Compute composition of cell types in data set.

        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
        within metadata entries, calculates their relative percentages, and stores
        the results in `self.composition_data`.

        Parameters
        ----------
        features_count : list or None
            List of features (part or full names) to be counted.
            If None, all unique elements from the specified `name_slot` metadata field are used.

        name_slot : str, default 'cell_names'
            Metadata field containing sample identifiers or labels.

        set_sep : bool, default True
            If True and multiple sets are present in metadata, compute composition
            separately for each set.

        Update
        -------
        Stores results in `self.composition_data` as a pandas DataFrame with:
            - 'name': feature name
            - 'n': number of occurrences
            - 'pct': percentage of occurrences
            - 'set' (if applicable): dataset identifier
        """

        validated_list = list(self.input_metadata[name_slot])
        sets = list(self.input_metadata["sets"])

        # Default: count every distinct label present in the chosen name slot.
        if features_count is None:
            features_count = list(set(self.input_metadata[name_slot]))

        if set_sep and len(set(sets)) > 1:

            final_res = pd.DataFrame()

            for s in set(sets):
                print(s)

                mask = [True if s == x else False for x in sets]

                tmp_val_list = np.array(validated_list)

                tmp_val_list = list(tmp_val_list[mask])

                res_dict = {"name": [], "n": [], "set": []}

                # Substring matching: a cell counts for feature `f` if `f` occurs
                # anywhere in its label (supports partial names).
                for f in tqdm(features_count):
                    res_dict["n"].append(
                        sum(1 for element in tmp_val_list if f in element)
                    )
                    res_dict["name"].append(f)
                    res_dict["set"].append(s)
                res = pd.DataFrame(res_dict)
                res["pct"] = res["n"] / sum(res["n"]) * 100
                res["pct"] = res["pct"].round(2)

                final_res = pd.concat([final_res, res])

            res = final_res.sort_values(["set", "pct"], ascending=[True, False])

        else:

            res_dict = {"name": [], "n": []}

            for f in tqdm(features_count):
                res_dict["n"].append(
                    sum(1 for element in validated_list if f in element)
                )
                res_dict["name"].append(f)

            res = pd.DataFrame(res_dict)
            res["pct"] = res["n"] / sum(res["n"]) * 100
            res["pct"] = res["pct"].round(2)

            res = res.sort_values("pct", ascending=False)

        self.composition_data = res

    def composition_pie(
        self,
        width=6,
        height=6,
        font_size=15,
        cmap: str = "tab20",
        legend_split_col: int = 1,
        offset_labels: float | int = 0.5,
        legend_bbox: tuple = (1.15, 0.95),
    ):
        """
        Visualize the composition of cell lineages using pie charts.

        Generates pie charts showing the relative proportions of features stored
        in `self.composition_data`. If multiple sets are present, a separate
        chart is drawn for each set.

        Parameters
        ----------
        width : int, default 6
            Width of the figure.

        height : int, default 6
            Height of the figure (applied per set if multiple sets are plotted).

        font_size : int, default 15
            Font size for labels and annotations.

        cmap : str, default 'tab20'
            Colormap used for pie slices.

        legend_split_col : int, default 1
            Number of columns in the legend.

        offset_labels : float or int, default 0.5
            Spacing offset for label placement relative to pie slices.

        legend_bbox : tuple, default (1.15, 0.95)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Pie chart visualization of composition data.
        """

        df = self.composition_data

        # Multi-set layout: one donut per set, stacked vertically.
        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))

            all_wedges = []
            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))

            # NOTE(review): set() iteration order is arbitrary, so the name->color
            # mapping can differ between runs; confirm whether stable colors matter.
            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]

            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                ax = axes[idx]
                tmp_df = df[df["set"] == s].reset_index(drop=True)

                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]

                wedges, _ = ax.pie(
                    tmp_df["n"],
                    startangle=90,
                    labeldistance=1.05,
                    colors=[cmap_dict[x] for x in tmp_df["name"]],
                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
                )

                all_wedges.extend(wedges)

                # Place each percentage label outside its wedge, connected by an
                # elbow line; `n` accumulates an outward offset to reduce overlap.
                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
                n = 0
                for i, p in enumerate(wedges):
                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
                    y = np.sin(np.deg2rad(ang))
                    x = np.cos(np.deg2rad(ang))
                    horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                    connectionstyle = f"angle,angleA=0,angleB={ang}"
                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
                    if len(labels[i]) > 0:
                        n += offset_labels
                        ax.annotate(
                            labels[i],
                            xy=(x, y),
                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
                            horizontalalignment=horizontalalignment,
                            fontsize=font_size,
                            weight="bold",
                            **kw,
                        )

                # White inner circle turns the pie into a donut; the set name
                # is printed in the hole.
                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
                ax.add_artist(circle2)

                ax.text(
                    0,
                    0,
                    f"{s}",
                    ha="center",
                    va="center",
                    fontsize=font_size,
                    weight="bold",
                )

            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]

            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]

            cmap = plt.get_cmap(cmap)
            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(
                figsize=(width, height), subplot_kw=dict(aspect="equal")
            )

            wedges, _ = ax.pie(
                df["n"],
                startangle=90,
                labeldistance=1.05,
                colors=colors,
                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
            )

            # Same elbow-label placement as the multi-set branch.
            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
            n = 0
            for i, p in enumerate(wedges):
                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
                y = np.sin(np.deg2rad(ang))
                x = np.cos(np.deg2rad(ang))
                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                connectionstyle = "angle,angleA=0,angleB={}".format(ang)
                kw["arrowprops"].update({"connectionstyle": connectionstyle})
                if len(labels[i]) > 0:
                    n += offset_labels

                    ax.annotate(
                        labels[i],
                        xy=(x, y),
                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
                        horizontalalignment=horizontalalignment,
                        fontsize=font_size,
                        weight="bold",
                        **kw,
                    )

            circle2 = plt.Circle((0, 0), 0.6, color="white")
            circle2.set_edgecolor("black")

            p = plt.gcf()
            p.gca().add_artist(circle2)

            ax.legend(
                wedges,
                legend_labels,
                title="",
                loc="center left",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
            )

            plt.show()

        return fig

    def bar_composition(
        self,
        cmap="tab20b",
        width=2,
        height=6,
        font_size=15,
        legend_split_col: int = 1,
        legend_bbox: tuple = (1.3, 1),
    ):
        """
        Visualize the composition of cell lineages using bar plots.

        Produces bar plots showing the distribution of features stored in
        `self.composition_data`. If multiple sets are present, a separate
        bar is drawn for each set. Percentages are annotated alongside the bars.

        Parameters
        ----------
        cmap : str, default 'tab20b'
            Colormap used for stacked bars.

        width : int, default 2
            Width of each subplot (per set).

        height : int, default 6
            Height of the figure.

        font_size : int, default 15
            Font size for labels and annotations.

        legend_split_col : int, default 1
            Number of columns in the legend.

        legend_bbox : tuple, default (1.3, 1)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Stacked bar plot visualization of composition data.
        """

        df = self.composition_data
        # NOTE(review): adds a 'num' column to the stored composition_data in place.
        df["num"] = range(1, len(df) + 1)

        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))

            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))

            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]

            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                ax = axes[idx]

                tmp_df = df[df["set"] == s].reset_index(drop=True)

                values = tmp_df["n"].values
                total = sum(values)
                values = [v / total * 100 for v in values]
                values = [round(v, 2) for v in values]

                # Force rounded percentages to sum to exactly 100 by pushing the
                # rounding error onto the largest segment.
                idx_max = np.argmax(values)
                correction = 100 - sum(values)
                values[idx_max] += correction

                names = tmp_df["name"].values
                perc = tmp_df["pct"].values
                nums = tmp_df["num"].values

                bottom = 0
                centers = []
                # NOTE(review): `num` and `color` are unused in this loop (segment
                # color comes from cmap_dict); zip also truncates to len(colors) --
                # verify that is intended.
                for name, num, val, color in zip(names, nums, values, colors):
                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
                    centers.append(bottom + val / 2)
                    bottom += val

                # Spread label anchors evenly between the first and last segment
                # centers to avoid collisions.
                y_positions = np.linspace(centers[0], centers[-1], len(centers))
                x_text = -0.8

                for y_label, y_center, pct, num in zip(
                    y_positions, centers, perc, nums
                ):
                    ax.annotate(
                        f"{pct:.1f}%",
                        xy=(0, y_center),
                        xycoords="data",
                        xytext=(x_text, y_label),
                        textcoords="data",
                        ha="right",
                        va="center",
                        fontsize=font_size,
                        arrowprops=dict(
                            arrowstyle="->",
                            lw=1,
                            color="black",
                            connectionstyle="angle3,angleA=0,angleB=90",
                        ),
                    )

                ax.set_ylim(0, 100)
                ax.set_xlabel(s, fontsize=font_size)
                ax.xaxis.label.set_rotation(30)

                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            cmap = plt.get_cmap(cmap)

            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(figsize=(width, height))

            values = df["n"].values
            names = df["name"].values
            perc = df["pct"].values
            nums = df["num"].values

            bottom = 0
            centers = []
            for name, num, val, color in zip(names, nums, values, colors):
                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
                centers.append(bottom + val / 2)
                bottom += val

            y_positions = np.linspace(centers[0], centers[-1], len(centers))
            x_text = -0.8

            # NOTE(review): label here is "num) pct" without a '%' sign, unlike the
            # multi-set branch's "{pct:.1f}%" -- confirm the inconsistency is wanted.
            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
                ax.annotate(
                    f"{num}) {pct}",
                    xy=(0, y_center),
                    xycoords="data",
                    xytext=(x_text, y_label),
                    textcoords="data",
                    ha="right",
                    va="center",
                    fontsize=9,
                    arrowprops=dict(
                        arrowstyle="->",
                        lw=1,
                        color="black",
                        connectionstyle="angle3,angleA=0,angleB=90",
                    ),
                )

            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            ax.legend(
                title="Legend",
                bbox_to_anchor=legend_bbox,
                loc="upper left",
                ncol=legend_split_col,
            )

            plt.tight_layout()
            plt.show()

        return fig

    def cell_regression(
        self,
        cell_x: str,
        cell_y: str,
        set_x: str | None,
        set_y: str | None,
        threshold=10,
        image_width=12,
        image_high=7,
        color="black",
    ):
        """
        Perform regression analysis between two selected cells and visualize the relationship.

        This function computes a linear regression between two specified cells from
        aggregated normalized data, plots the regression line with scatter points,
        annotates regression statistics, and highlights potential outliers.

        Parameters
        ----------
        cell_x : str
            Name of the first cell (X-axis).

        cell_y : str
            Name of the second cell (Y-axis).

        set_x : str or None
            Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name.

        set_y : str or None
            Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name.

        threshold : int or float, default 10
            Threshold for detecting outliers. Points deviating from the mean or diagonal by more
            than this value are annotated.

        image_width : int, default 12
            Width of the regression plot (in inches).

        image_high : int, default 7
            Height of the regression plot (in inches).
4325 4326 color : str, default 'black' 4327 Color of the regression scatter points and line. 4328 4329 Returns 4330 ------- 4331 matplotlib.figure.Figure 4332 Regression plot figure with annotated regression line, R², p-value, and outliers. 4333 4334 Raises 4335 ------ 4336 ValueError 4337 If `cell_x` or `cell_y` are not found in the dataset. 4338 If multiple matches are found for a cell name and `set_x`/`set_y` are not specified. 4339 4340 Notes 4341 ----- 4342 - The function automatically calls `jseq_object.average()` if aggregated data is not available. 4343 - Outliers are annotated with their corresponding index labels. 4344 - Regression is computed using `scipy.stats.linregress`. 4345 4346 Examples 4347 -------- 4348 >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2") 4349 >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue") 4350 """ 4351 4352 if self.agg_normalized_data is None: 4353 self.average() 4354 4355 metadata = self.agg_metadata 4356 data = self.agg_normalized_data 4357 4358 if set_x is not None and set_y is not None: 4359 data.columns = metadata["cell_names"] + " # " + metadata["sets"] 4360 cell_x = cell_x + " # " + set_x 4361 cell_y = cell_y + " # " + set_y 4362 4363 else: 4364 data.columns = metadata["cell_names"] 4365 4366 if not cell_x in data.columns: 4367 raise ValueError("'cell_x' value not in cell names!") 4368 4369 if not cell_y in data.columns: 4370 raise ValueError("'cell_y' value not in cell names!") 4371 4372 if list(data.columns).count(cell_x) > 1: 4373 raise ValueError( 4374 f"'{cell_x}' occurs more than once. If you want to select a specific cell, " 4375 f"please also provide the corresponding 'set_x' and 'set_y' values." 4376 ) 4377 4378 if list(data.columns).count(cell_y) > 1: 4379 raise ValueError( 4380 f"'{cell_y}' occurs more than once. If you want to select a specific cell, " 4381 f"please also provide the corresponding 'set_x' and 'set_y' values." 
4382 ) 4383 4384 fig, ax = plt.subplots(figsize=(image_width, image_high)) 4385 ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color) 4386 4387 slope, intercept, r_value, p_value, _ = stats.linregress( 4388 data[cell_x], data[cell_y] 4389 ) 4390 equation = "y = {:.2f}x + {:.2f}".format(slope, intercept) 4391 4392 ax.annotate( 4393 "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format( 4394 r_value**2, p_value, equation 4395 ), 4396 xy=(0.05, 0.90), 4397 xycoords="axes fraction", 4398 fontsize=12, 4399 ) 4400 4401 ax.spines["top"].set_visible(False) 4402 ax.spines["right"].set_visible(False) 4403 4404 diff = [] 4405 x_mean, y_mean = data[cell_x].mean(), data[cell_y].mean() 4406 for i, (xi, yi) in enumerate(zip(data[cell_x], data[cell_y])): 4407 diff.append(abs(xi - x_mean)) 4408 diff.append(abs(yi - y_mean)) 4409 4410 def annotate_outliers(x, y, threshold): 4411 texts = [] 4412 x_mean, y_mean = x.mean(), y.mean() 4413 for i, (xi, yi) in enumerate(zip(x, y)): 4414 if ( 4415 abs(xi - x_mean) > threshold 4416 or abs(yi - y_mean) > threshold 4417 or abs(yi - xi) > threshold 4418 ): 4419 text = ax.text(xi, yi, data.index[i]) 4420 texts.append(text) 4421 4422 return texts 4423 4424 texts = annotate_outliers(data[cell_x], data[cell_y], threshold) 4425 4426 adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5)) 4427 4428 plt.show() 4429 4430 return fig
class Clustering:
    """
    A class for performing dimensionality reduction, clustering, and visualization
    on high-dimensional data (e.g., single-cell gene expression).

    The class provides methods for:
    - Normalizing and extracting subsets of data
    - Principal Component Analysis (PCA) and related clustering
    - Uniform Manifold Approximation and Projection (UMAP) and clustering
    - Visualization of PCA and UMAP embeddings
    - Harmonization of batch effects
    - Accessing processed data and cluster labels

    Methods
    -------
    add_data_frame(data, metadata)
        Class method to create a Clustering instance from a DataFrame and metadata.

    harmonize_sets()
        Perform batch effect harmonization on PCA data.

    perform_PCA(pc_num=100, width=8, height=6)
        Perform PCA on the dataset and visualize the first two PCs.

    knee_plot_PCA(width=8, height=6)
        Plot the cumulative variance explained by PCs to determine optimal dimensionality.

    find_clusters_PCA(pc_num=0, eps=0.5, min_samples=10, width=8, height=6, harmonized=False)
        Apply DBSCAN clustering to PCA embeddings and visualize results.

    perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...)
        Compute UMAP embeddings with optional parameter tuning.

    knee_plot_umap(eps=0.5, min_samples=10)
        Determine optimal UMAP dimensionality using silhouette scores.

    find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6)
        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

    UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...)
        Visualize UMAP embeddings with labels and optional cluster numbering.

    UMAP_feature(feature_name, features_data=None, point_size=0.6, ...)
        Plot a single feature over UMAP coordinates with customizable colormap.

    get_umap_data()
        Return the UMAP embeddings along with cluster labels if available.

    get_pca_data()
        Return the PCA results along with cluster labels if available.

    return_clusters(clusters='umap')
        Return the cluster labels for UMAP or PCA embeddings.

    Raises
    ------
    ValueError
        For invalid parameters, mismatched dimensions, or missing metadata.
    """

    def __init__(self, data, metadata):
        """
        Initialize the clustering class with data and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data for clustering. Columns are considered as samples.

        metadata : pandas.DataFrame, optional
            Metadata for the samples. If None, a default DataFrame with column
            names as 'cell_names' is created.

        Attributes
        ----------
        -clustering_data : pandas.DataFrame
        -clustering_metadata : pandas.DataFrame
        -subclusters : None or dict
        -explained_var : None or numpy.ndarray
        -cumulative_var : None or numpy.ndarray
        -pca : None or pandas.DataFrame
        -harmonized_pca : None or pandas.DataFrame
        -umap : None or pandas.DataFrame
        """

        self.clustering_data = data
        """The input data used for clustering."""

        # With no metadata supplied, the data's column labels become the cell names.
        if metadata is None:
            metadata = pd.DataFrame({"cell_names": list(data.columns)})

        self.clustering_metadata = metadata
        """Metadata associated with the samples."""

        self.subclusters = None
        """Placeholder for storing subcluster information."""

        self.explained_var = None
        """Explained variance from PCA, initialized as None."""

        self.cumulative_var = None
        """Cumulative explained variance from PCA, initialized as None."""

        self.pca = None
        """PCA-transformed data, initialized as None."""

        self.harmonized_pca = None
        """PCA data after batch effect harmonization, initialized as None."""

        self.umap = None
        """UMAP embeddings, initialized as None."""

    @classmethod
    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
        """
        Create a
Clustering instance from a DataFrame and optional metadata. 147 148 Parameters 149 ---------- 150 data : pandas.DataFrame 151 Input data with features as rows and samples/cells as columns. 152 153 metadata : pandas.DataFrame or None 154 Optional metadata for the samples. 155 Each row corresponds to a sample/cell, and column names in this DataFrame 156 should match the sample/cell names in `data`. Columns can contain additional 157 information such as cell type, experimental condition, batch, sets, etc. 158 159 Returns 160 ------- 161 Clustering 162 A new instance of the Clustering class. 163 """ 164 165 return cls(data, metadata) 166 167 def harmonize_sets(self, batch_col: str = "sets"): 168 """ 169 Perform batch effect harmonization on PCA embeddings. 170 171 Parameters 172 ---------- 173 batch_col : str, default 'sets' 174 Name of the column in `metadata` that contains batch information for the samples/cells. 175 176 Returns 177 ------- 178 None 179 Updates the `harmonized_pca` attribute with harmonized data. 180 """ 181 182 data_mat = np.array(self.pca) 183 184 metadata = self.clustering_metadata 185 186 self.harmonized_pca = pd.DataFrame( 187 harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr 188 ).T 189 190 self.harmonized_pca.columns = self.pca.columns 191 192 def perform_PCA(self, pc_num: int = 100, width=8, height=6): 193 """ 194 Perform Principal Component Analysis (PCA) on the dataset. 195 196 This method standardizes the data, applies PCA, stores results as attributes, 197 and generates a scatter plot of the first two principal components. 198 199 Parameters 200 ---------- 201 pc_num : int, default 100 202 Number of principal components to compute. 203 If 0, computes all available components. 204 205 width : int or float, default 8 206 Width of the PCA figure. 207 208 height : int or float, default 6 209 Height of the PCA figure. 
210 211 Returns 212 ------- 213 matplotlib.figure.Figure 214 Scatter plot showing the first two principal components. 215 216 Updates 217 ------- 218 self.pca : pandas.DataFrame 219 DataFrame with principal component scores for each sample. 220 221 self.explained_var : numpy.ndarray 222 Percentage of variance explained by each principal component. 223 224 self.cumulative_var : numpy.ndarray 225 Cumulative explained variance. 226 """ 227 228 scaler = StandardScaler() 229 data_scaled = scaler.fit_transform(self.clustering_data.T) 230 231 if pc_num == 0 or pc_num > data_scaled.shape[0]: 232 pc_num = data_scaled.shape[0] 233 234 pca = PCA(n_components=pc_num, random_state=42) 235 236 principal_components = pca.fit_transform(data_scaled) 237 238 pca_df = pd.DataFrame( 239 data=principal_components, 240 columns=["PC" + str(x + 1) for x in range(pc_num)], 241 ) 242 243 self.explained_var = pca.explained_variance_ratio_ * 100 244 self.cumulative_var = np.cumsum(self.explained_var) 245 246 self.pca = pca_df 247 248 fig = plt.figure(figsize=(width, height)) 249 plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7) 250 plt.xlabel("PC 1") 251 plt.ylabel("PC 2") 252 plt.grid(True) 253 plt.show() 254 255 return fig 256 257 def knee_plot_PCA(self, width: int = 8, height: int = 6): 258 """ 259 Plot cumulative explained variance to determine the optimal number of PCs. 260 261 Parameters 262 ---------- 263 width : int, default 8 264 Width of the figure. 265 266 height : int or, default 6 267 Height of the figure. 268 269 Returns 270 ------- 271 matplotlib.figure.Figure 272 Line plot showing cumulative variance explained by each PC. 
273 """ 274 275 fig_knee = plt.figure(figsize=(width, height)) 276 plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o") 277 plt.xlabel("PC (n components)") 278 plt.ylabel("Cumulative explained variance (%)") 279 plt.grid(True) 280 281 xticks = [1] + list(range(5, len(self.explained_var) + 1, 5)) 282 plt.xticks(xticks, rotation=60) 283 284 plt.show() 285 286 return fig_knee 287 288 def find_clusters_PCA( 289 self, 290 pc_num: int = 2, 291 eps: float = 0.5, 292 min_samples: int = 10, 293 width: int = 8, 294 height: int = 6, 295 harmonized: bool = False, 296 ): 297 """ 298 Apply DBSCAN clustering to PCA embeddings and visualize the results. 299 300 This method performs density-based clustering (DBSCAN) on the PCA-reduced 301 dataset. Cluster labels are stored in the object's metadata, and a scatter 302 plot of the first two principal components with cluster annotations is returned. 303 304 Parameters 305 ---------- 306 pc_num : int, default 2 307 Number of principal components to use for clustering. 308 If 0, uses all available components. 309 310 eps : float, default 0.5 311 Maximum distance between two points for them to be considered 312 as neighbors (DBSCAN parameter). 313 314 min_samples : int, default 10 315 Minimum number of samples required to form a cluster (DBSCAN parameter). 316 317 width : int, default 8 318 Width of the output scatter plot. 319 320 height : int, default 6 321 Height of the output scatter plot. 322 323 harmonized : bool, default False 324 If True, use harmonized PCA data (`self.harmonized_pca`). 325 If False, use standard PCA results (`self.pca`). 326 327 Returns 328 ------- 329 matplotlib.figure.Figure 330 Scatter plot of the first two principal components colored by 331 cluster assignments. 332 333 Updates 334 ------- 335 self.clustering_metadata['PCA_clusters'] : list 336 Cluster labels assigned to each cell/sample. 
337 338 self.input_metadata['PCA_clusters'] : list, optional 339 Cluster labels stored in input metadata (if available). 340 """ 341 342 dbscan = DBSCAN(eps=eps, min_samples=min_samples) 343 344 if pc_num == 0 and harmonized: 345 PCA = self.harmonized_pca 346 347 elif pc_num == 0: 348 PCA = self.pca 349 350 else: 351 if harmonized: 352 353 PCA = self.harmonized_pca.iloc[:, 0:pc_num] 354 else: 355 356 PCA = self.pca.iloc[:, 0:pc_num] 357 358 dbscan_labels = dbscan.fit_predict(PCA) 359 360 pca_df = pd.DataFrame(PCA) 361 pca_df["Cluster"] = dbscan_labels 362 363 fig = plt.figure(figsize=(width, height)) 364 365 for cluster_id in sorted(pca_df["Cluster"].unique()): 366 cluster_data = pca_df[pca_df["Cluster"] == cluster_id] 367 plt.scatter( 368 cluster_data["PC1"], 369 cluster_data["PC2"], 370 label=f"Cluster {cluster_id}", 371 alpha=0.7, 372 ) 373 374 plt.xlabel("PC 1") 375 plt.ylabel("PC 2") 376 plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5)) 377 378 plt.grid(True) 379 plt.show() 380 381 self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels] 382 383 try: 384 self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels] 385 except: 386 pass 387 388 return fig 389 390 def perform_UMAP( 391 self, 392 factorize: bool = False, 393 umap_num: int = 100, 394 pc_num: int = 0, 395 harmonized: bool = False, 396 n_neighbors: int = 5, 397 min_dist: float | int = 0.1, 398 spread: float | int = 1.0, 399 set_op_mix_ratio: float | int = 1.0, 400 local_connectivity: int = 1, 401 repulsion_strength: float | int = 1.0, 402 negative_sample_rate: int = 5, 403 width: int = 8, 404 height: int = 6, 405 ): 406 """ 407 Compute and visualize UMAP embeddings of the dataset. 408 409 This method applies Uniform Manifold Approximation and Projection (UMAP) 410 for dimensionality reduction on either raw, PCA, or harmonized PCA data. 411 Results are stored as a DataFrame (`self.umap`) and a scatter plot figure 412 (`self.UMAP_plot`). 
413 414 Parameters 415 ---------- 416 factorize : bool, default False 417 If True, categorical sample labels (from column names) are factorized 418 and used as supervision in UMAP fitting. 419 420 umap_num : int, default 100 421 Number of UMAP dimensions to compute. If 0, matches the input dimension. 422 423 pc_num : int, default 0 424 Number of principal components to use as UMAP input. 425 If 0, use all available components or raw data. 426 427 harmonized : bool, default False 428 If True, use harmonized PCA embeddings (`self.harmonized_pca`). 429 If False, use standard PCA or raw scaled data. 430 431 n_neighbors : int, default 5 432 UMAP parameter controlling the size of the local neighborhood. 433 434 min_dist : float, default 0.1 435 UMAP parameter controlling minimum allowed distance between embedded points. 436 437 spread : int | float, default 1.0 438 Effective scale of embedded space (UMAP parameter). 439 440 set_op_mix_ratio : int | float, default 1.0 441 Interpolation parameter between union and intersection in fuzzy sets. 442 443 local_connectivity : int, default 1 444 Number of nearest neighbors assumed for each point. 445 446 repulsion_strength : int | float, default 1.0 447 Weighting applied to negative samples during optimization. 448 449 negative_sample_rate : int, default 5 450 Number of negative samples per positive sample in optimization. 451 452 width : int, default 8 453 Width of the output scatter plot. 454 455 height : int, default 6 456 Height of the output scatter plot. 457 458 Updates 459 ------- 460 self.umap : pandas.DataFrame 461 Table of UMAP embeddings with columns `UMAP1 ... UMAPn`. 462 463 Notes 464 ----- 465 For supervised UMAP (`factorize=True`), categorical codes from column 466 names of the dataset are used as labels. 
467 """ 468 469 scaler = StandardScaler() 470 471 if pc_num == 0 and harmonized: 472 data_scaled = self.harmonized_pca 473 474 elif pc_num == 0: 475 data_scaled = scaler.fit_transform(self.clustering_data.T) 476 477 else: 478 if harmonized: 479 480 data_scaled = self.harmonized_pca.iloc[:, 0:pc_num] 481 else: 482 483 data_scaled = self.pca.iloc[:, 0:pc_num] 484 485 if umap_num == 0 or umap_num > data_scaled.shape[1]: 486 487 umap_num = data_scaled.shape[1] 488 489 reducer = umap.UMAP( 490 n_components=len(data_scaled.T), 491 random_state=42, 492 n_neighbors=n_neighbors, 493 min_dist=min_dist, 494 spread=spread, 495 set_op_mix_ratio=set_op_mix_ratio, 496 local_connectivity=local_connectivity, 497 repulsion_strength=repulsion_strength, 498 negative_sample_rate=negative_sample_rate, 499 n_jobs=1, 500 ) 501 502 else: 503 504 reducer = umap.UMAP( 505 n_components=umap_num, 506 random_state=42, 507 n_neighbors=n_neighbors, 508 min_dist=min_dist, 509 spread=spread, 510 set_op_mix_ratio=set_op_mix_ratio, 511 local_connectivity=local_connectivity, 512 repulsion_strength=repulsion_strength, 513 negative_sample_rate=negative_sample_rate, 514 n_jobs=1, 515 ) 516 517 if factorize: 518 embedding = reducer.fit_transform( 519 X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes 520 ) 521 else: 522 embedding = reducer.fit_transform(X=data_scaled) 523 524 umap_df = pd.DataFrame( 525 embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)] 526 ) 527 528 plt.figure(figsize=(width, height)) 529 plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7) 530 plt.xlabel("UMAP 1") 531 plt.ylabel("UMAP 2") 532 plt.grid(True) 533 534 plt.show() 535 536 self.umap = umap_df 537 538 def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10): 539 """ 540 Plot silhouette scores for different UMAP dimensions to determine optimal n_components. 
541 542 Parameters 543 ---------- 544 eps : float, default 0.5 545 DBSCAN eps parameter for clustering each UMAP dimension. 546 547 min_samples : int, default 10 548 Minimum number of samples to form a cluster in DBSCAN. 549 550 Returns 551 ------- 552 matplotlib.figure.Figure 553 Silhouette score plot across UMAP dimensions. 554 """ 555 556 umap_range = range(2, len(self.umap.T) + 1) 557 558 silhouette_scores = [] 559 component = [] 560 for n in umap_range: 561 562 db = DBSCAN(eps=eps, min_samples=min_samples) 563 labels = db.fit_predict(np.array(self.umap)[:, :n]) 564 565 mask = labels != -1 566 if len(set(labels[mask])) > 1: 567 score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask]) 568 else: 569 score = -1 570 571 silhouette_scores.append(score) 572 component.append(n) 573 574 fig = plt.figure(figsize=(10, 5)) 575 plt.plot(component, silhouette_scores, marker="o") 576 plt.xlabel("UMAP (n_components)") 577 plt.ylabel("Silhouette Score") 578 plt.grid(True) 579 plt.xticks(range(int(min(component)), int(max(component)) + 1, 1)) 580 581 plt.show() 582 583 return fig 584 585 def find_clusters_UMAP( 586 self, 587 umap_n: int = 5, 588 eps: float | float = 0.5, 589 min_samples: int = 10, 590 width: int = 8, 591 height: int = 6, 592 ): 593 """ 594 Apply DBSCAN clustering on UMAP embeddings and visualize clusters. 595 596 This method performs density-based clustering (DBSCAN) on the UMAP-reduced 597 dataset. Cluster labels are stored in the object's metadata, and a scatter 598 plot of the first two UMAP components with cluster annotations is returned. 599 600 Parameters 601 ---------- 602 umap_n : int, default 5 603 Number of UMAP dimensions to use for DBSCAN clustering. 604 Must be <= number of columns in `self.umap`. 605 606 eps : float | int, default 0.5 607 Maximum neighborhood distance between two samples for them to be considered 608 as in the same cluster (DBSCAN parameter). 
609 610 min_samples : int, default 10 611 Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter). 612 613 width : int, default 8 614 Figure width. 615 616 height : int, default 6 617 Figure height. 618 619 Returns 620 ------- 621 matplotlib.figure.Figure 622 Scatter plot of the first two UMAP components colored by 623 cluster assignments. 624 625 Updates 626 ------- 627 self.clustering_metadata['UMAP_clusters'] : list 628 Cluster labels assigned to each cell/sample. 629 630 self.input_metadata['UMAP_clusters'] : list, optional 631 Cluster labels stored in input metadata (if available). 632 """ 633 634 dbscan = DBSCAN(eps=eps, min_samples=min_samples) 635 dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n]) 636 637 umap_df = self.umap 638 umap_df["Cluster"] = dbscan_labels 639 640 fig = plt.figure(figsize=(width, height)) 641 642 for cluster_id in sorted(umap_df["Cluster"].unique()): 643 cluster_data = umap_df[umap_df["Cluster"] == cluster_id] 644 plt.scatter( 645 cluster_data["UMAP1"], 646 cluster_data["UMAP2"], 647 label=f"Cluster {cluster_id}", 648 alpha=0.7, 649 ) 650 651 plt.xlabel("UMAP 1") 652 plt.ylabel("UMAP 2") 653 plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5)) 654 plt.grid(True) 655 656 self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels] 657 658 try: 659 self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels] 660 except: 661 pass 662 663 return fig 664 665 def UMAP_vis( 666 self, 667 names_slot: str = "cell_names", 668 set_sep: bool = True, 669 point_size: int | float = 0.6, 670 font_size: int | float = 6, 671 legend_split_col: int = 2, 672 width: int = 8, 673 height: int = 6, 674 inc_num: bool = True, 675 ): 676 """ 677 Visualize UMAP embeddings with sample labels based on specyfic metadata slot. 678 679 Parameters 680 ---------- 681 names_slot : str, default 'cell_names' 682 Column in metadata to use as sample labels. 
683 684 set_sep : bool, default True 685 If True, separate points by dataset. 686 687 point_size : float, default 0.6 688 Size of scatter points. 689 690 font_size : int, default 6 691 Font size for numbers on points. 692 693 legend_split_col : int, default 2 694 Number of columns in legend. 695 696 width : int, default 8 697 Figure width. 698 699 height : int, default 6 700 Figure height. 701 702 inc_num : bool, default True 703 If True, annotate points with numeric labels. 704 705 Returns 706 ------- 707 matplotlib.figure.Figure 708 UMAP scatter plot figure. 709 """ 710 711 umap_df = self.umap.iloc[:, 0:2].copy() 712 umap_df["names"] = list(self.clustering_metadata[names_slot]) 713 714 if set_sep: 715 716 if "sets" in list(self.clustering_metadata.columns): 717 umap_df["dataset"] = list(self.clustering_metadata["sets"]) 718 else: 719 umap_df["dataset"] = "default" 720 721 else: 722 umap_df["dataset"] = "default" 723 724 umap_df["tmp_nam"] = list(umap_df["names"] + umap_df["dataset"]) 725 726 umap_df["count"] = umap_df["tmp_nam"].map(umap_df["tmp_nam"].value_counts()) 727 728 numeric_df = ( 729 pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy()) 730 .drop_duplicates() 731 .sort_values("count", ascending=False) 732 ) 733 numeric_df["numeric_values"] = range(0, numeric_df.shape[0]) 734 735 umap_df = umap_df.merge( 736 numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left" 737 ) 738 739 fig, ax = plt.subplots(figsize=(width, height)) 740 741 markers = ["o", "s", "^", "D", "P", "*", "X"] 742 marker_map = { 743 ds: markers[i % len(markers)] 744 for i, ds in enumerate(umap_df["dataset"].unique()) 745 } 746 747 cord_list = [] 748 749 for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]): 750 751 cluster_data = umap_df[umap_df["numeric_values"] == num] 752 753 ax.scatter( 754 cluster_data["UMAP1"], 755 cluster_data["UMAP2"], 756 label=f"{num} - {nam}", 757 marker=marker_map[cluster_data["dataset"].iloc[0]], 758 alpha=0.6, 759 
s=point_size, 760 ) 761 762 coords = cluster_data[["UMAP1", "UMAP2"]].values 763 764 dists = pairwise_distances(coords) 765 766 sum_dists = dists.sum(axis=1) 767 768 center_idx = np.argmin(sum_dists) 769 center_point = coords[center_idx] 770 771 cord_list.append(center_point) 772 773 if inc_num: 774 texts = [] 775 for (x, y), num in zip(cord_list, numeric_df["numeric_values"]): 776 texts.append( 777 ax.text( 778 x, 779 y, 780 str(num), 781 ha="center", 782 va="center", 783 fontsize=font_size, 784 color="black", 785 ) 786 ) 787 788 adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5)) 789 790 ax.set_xlabel("UMAP 1") 791 ax.set_ylabel("UMAP 2") 792 793 ax.legend( 794 title="Clusters", 795 loc="center left", 796 bbox_to_anchor=(1.05, 0.5), 797 ncol=legend_split_col, 798 markerscale=5, 799 ) 800 801 ax.grid(True) 802 803 plt.tight_layout() 804 805 return fig 806 807 def UMAP_feature( 808 self, 809 feature_name: str, 810 features_data: pd.DataFrame | None, 811 point_size: int | float = 0.6, 812 font_size: int | float = 6, 813 width: int = 8, 814 height: int = 6, 815 palette="light", 816 ): 817 """ 818 Visualize UMAP embedding with expression levels of a selected feature. 819 820 Each point (cell) in the UMAP plot is colored according to the expression 821 value of the chosen feature, enabling interpretation of spatial patterns 822 of gene activity or metadata distribution in low-dimensional space. 823 824 Parameters 825 ---------- 826 feature_name : str 827 Name of the feature to plot. 828 829 features_data : pandas.DataFrame or None, default None 830 If None, the function uses the DataFrame containing the clustering data. 831 To plot features not used in clustering, provide a wider DataFrame 832 containing the original feature values. 833 834 point_size : float, default 0.6 835 Size of scatter points in the plot. 836 837 font_size : int, default 6 838 Font size for axis labels and annotations. 
839 840 width : int, default 8 841 Width of the matplotlib figure. 842 843 height : int, default 6 844 Height of the matplotlib figure. 845 846 palette : str, default 'light' 847 Color palette for expression visualization. Options are: 848 - 'light' 849 - 'dark' 850 - 'green' 851 - 'gray' 852 853 Returns 854 ------- 855 matplotlib.figure.Figure 856 UMAP scatter plot colored by feature values. 857 """ 858 859 umap_df = self.umap.iloc[:, 0:2].copy() 860 861 if features_data is None: 862 863 features_data = self.clustering_data 864 865 if features_data.shape[1] != umap_df.shape[0]: 866 raise ValueError( 867 "Imputed 'features_data' shape does not match the number of UMAP cells" 868 ) 869 870 blist = [ 871 True if x.upper() == feature_name.upper() else False 872 for x in features_data.index 873 ] 874 875 if not any(blist): 876 raise ValueError("Imputed feature_name is not included in the data") 877 878 umap_df.loc[:, "feature"] = ( 879 features_data.loc[blist, :] 880 .apply(lambda row: row.tolist(), axis=1) 881 .values[0] 882 ) 883 884 umap_df = umap_df.sort_values("feature", ascending=True) 885 886 import matplotlib.colors as mcolors 887 888 if palette == "light": 889 palette = px.colors.sequential.Sunsetdark 890 891 elif palette == "dark": 892 palette = px.colors.sequential.thermal 893 894 elif palette == "green": 895 palette = px.colors.sequential.Aggrnyl 896 897 elif palette == "gray": 898 palette = px.colors.sequential.gray 899 palette = palette[::-1] 900 901 else: 902 raise ValueError( 903 'Palette not found. 
Use: "light", "dark", "gray", or "green"' 904 ) 905 906 converted = [] 907 for c in palette: 908 rgb_255 = px.colors.unlabel_rgb(c) 909 rgb_01 = tuple(v / 255.0 for v in rgb_255) 910 converted.append(rgb_01) 911 912 my_cmap = mcolors.ListedColormap(converted, name="custom") 913 914 fig, ax = plt.subplots(figsize=(width, height)) 915 sc = ax.scatter( 916 umap_df["UMAP1"], 917 umap_df["UMAP2"], 918 c=umap_df["feature"], 919 s=point_size, 920 cmap=my_cmap, 921 alpha=1.0, 922 edgecolors="black", 923 linewidths=0.1, 924 ) 925 926 cbar = plt.colorbar(sc, ax=ax) 927 cbar.set_label(f"{feature_name}") 928 929 ax.set_xlabel("UMAP 1") 930 ax.set_ylabel("UMAP 2") 931 932 ax.grid(True) 933 934 plt.tight_layout() 935 936 return fig 937 938 def get_umap_data(self): 939 """ 940 Retrieve UMAP embedding data with optional cluster labels. 941 942 Returns the UMAP coordinates stored in `self.umap`. If clustering 943 metadata is available (specifically `UMAP_clusters`), the corresponding 944 cluster assignments are appended as an additional column. 945 946 Returns 947 ------- 948 pandas.DataFrame 949 DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...). 950 If available, includes an extra column 'clusters' with cluster labels. 951 952 Notes 953 ----- 954 - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`). 955 - Cluster labels are added only if present in `self.clustering_metadata`. 956 """ 957 958 umap_data = self.umap 959 960 try: 961 umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"] 962 except: 963 pass 964 965 return umap_data 966 967 def get_pca_data(self): 968 """ 969 Retrieve PCA embedding data with optional cluster labels. 970 971 Returns the principal component scores stored in `self.pca`. If clustering 972 metadata is available (specifically `PCA_clusters`), the corresponding 973 cluster assignments are appended as an additional column. 
974 975 Returns 976 ------- 977 pandas.DataFrame 978 DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...). 979 If available, includes an extra column 'clusters' with cluster labels. 980 981 Notes 982 ----- 983 - PCA must be computed beforehand (e.g., using `perform_PCA`). 984 - Cluster labels are added only if present in `self.clustering_metadata`. 985 """ 986 987 pca_data = self.pca 988 989 try: 990 pca_data["clusters"] = self.clustering_metadata["PCA_clusters"] 991 except: 992 pass 993 994 return pca_data 995 996 def return_clusters(self, clusters="umap"): 997 """ 998 Retrieve cluster labels from UMAP or PCA clustering results. 999 1000 Parameters 1001 ---------- 1002 clusters : str, default 'umap' 1003 Source of cluster labels to return. Must be one of: 1004 - 'umap': return cluster labels from UMAP embeddings. 1005 - 'pca' : return cluster labels from PCA embeddings. 1006 1007 Returns 1008 ------- 1009 list 1010 Cluster labels corresponding to the selected embedding method. 1011 1012 Raises 1013 ------ 1014 ValueError 1015 If `clusters` is not 'umap' or 'pca'. 1016 1017 Notes 1018 ----- 1019 Requires that clustering has already been performed 1020 (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`). 1021 """ 1022 1023 if clusters.lower() == "umap": 1024 clusters_vector = self.clustering_metadata["UMAP_clusters"] 1025 elif clusters.lower() == "pca": 1026 clusters_vector = self.clustering_metadata["PCA_clusters"] 1027 else: 1028 raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.") 1029 1030 return clusters_vector
A class for performing dimensionality reduction, clustering, and visualization on high-dimensional data (e.g., single-cell gene expression).
The class provides methods for:
- Normalizing and extracting subsets of data
- Principal Component Analysis (PCA) and related clustering
- Uniform Manifold Approximation and Projection (UMAP) and clustering
- Visualization of PCA and UMAP embeddings
- Harmonization of batch effects
- Accessing processed data and cluster labels
Methods
add_data_frame(data, metadata) Class method to create a Clustering instance from a DataFrame and metadata.
harmonize_sets() Perform batch effect harmonization on PCA data.
perform_PCA(pc_num=100, width=8, height=6) Perform PCA on the dataset and visualize the first two PCs.
knee_plot_PCA(width=8, height=6) Plot the cumulative variance explained by PCs to determine optimal dimensionality.
find_clusters_PCA(pc_num=2, eps=0.5, min_samples=10, width=8, height=6, harmonized=False) Apply DBSCAN clustering to PCA embeddings and visualize results.
perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...) Compute UMAP embeddings with optional parameter tuning.
knee_plot_umap(eps=0.5, min_samples=10) Determine optimal UMAP dimensionality using silhouette scores.
find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6) Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...) Visualize UMAP embeddings with labels and optional cluster numbering.
UMAP_feature(feature_name, features_data=None, point_size=0.6, ...) Plot a single feature over UMAP coordinates with customizable colormap.
get_umap_data() Return the UMAP embeddings along with cluster labels if available.
get_pca_data() Return the PCA results along with cluster labels if available.
return_clusters(clusters='umap') Return the cluster labels for UMAP or PCA embeddings.
Raises
ValueError For invalid parameters, mismatched dimensions, or missing metadata.
91 def __init__(self, data, metadata): 92 """ 93 Initialize the clustering class with data and optional metadata. 94 95 Parameters 96 ---------- 97 data : pandas.DataFrame 98 Input data for clustering. Columns are considered as samples. 99 100 metadata : pandas.DataFrame, optional 101 Metadata for the samples. If None, a default DataFrame with column 102 names as 'cell_names' is created. 103 104 Attributes 105 ---------- 106 -clustering_data : pandas.DataFrame 107 -clustering_metadata : pandas.DataFrame 108 -subclusters : None or dict 109 -explained_var : None or numpy.ndarray 110 -cumulative_var : None or numpy.ndarray 111 -pca : None or pandas.DataFrame 112 -harmonized_pca : None or pandas.DataFrame 113 -umap : None or pandas.DataFrame 114 """ 115 116 self.clustering_data = data 117 """The input data used for clustering.""" 118 119 if metadata is None: 120 metadata = pd.DataFrame({"cell_names": list(data.columns)}) 121 122 self.clustering_metadata = metadata 123 """Metadata associated with the samples.""" 124 125 self.subclusters = None 126 """Placeholder for storing subcluster information.""" 127 128 self.explained_var = None 129 """Explained variance from PCA, initialized as None.""" 130 131 self.cumulative_var = None 132 """Cumulative explained variance from PCA, initialized as None.""" 133 134 self.pca = None 135 """PCA-transformed data, initialized as None.""" 136 137 self.harmonized_pca = None 138 """PCA data after batch effect harmonization, initialized as None.""" 139 140 self.umap = None 141 """UMAP embeddings, initialized as None."""
Initialize the clustering class with data and optional metadata.
Parameters
data : pandas.DataFrame Input data for clustering. Columns are considered as samples.
metadata : pandas.DataFrame, optional Metadata for the samples. If None, a default DataFrame with column names as 'cell_names' is created.
Attributes
-clustering_data : pandas.DataFrame -clustering_metadata : pandas.DataFrame -subclusters : None or dict -explained_var : None or numpy.ndarray -cumulative_var : None or numpy.ndarray -pca : None or pandas.DataFrame -harmonized_pca : None or pandas.DataFrame -umap : None or pandas.DataFrame
143 @classmethod 144 def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None): 145 """ 146 Create a Clustering instance from a DataFrame and optional metadata. 147 148 Parameters 149 ---------- 150 data : pandas.DataFrame 151 Input data with features as rows and samples/cells as columns. 152 153 metadata : pandas.DataFrame or None 154 Optional metadata for the samples. 155 Each row corresponds to a sample/cell, and column names in this DataFrame 156 should match the sample/cell names in `data`. Columns can contain additional 157 information such as cell type, experimental condition, batch, sets, etc. 158 159 Returns 160 ------- 161 Clustering 162 A new instance of the Clustering class. 163 """ 164 165 return cls(data, metadata)
Create a Clustering instance from a DataFrame and optional metadata.
Parameters
data : pandas.DataFrame Input data with features as rows and samples/cells as columns.
metadata : pandas.DataFrame or None
Optional metadata for the samples.
Each row corresponds to a sample/cell, and column names in this DataFrame
should match the sample/cell names in data. Columns can contain additional
information such as cell type, experimental condition, batch, sets, etc.
Returns
Clustering A new instance of the Clustering class.
167 def harmonize_sets(self, batch_col: str = "sets"): 168 """ 169 Perform batch effect harmonization on PCA embeddings. 170 171 Parameters 172 ---------- 173 batch_col : str, default 'sets' 174 Name of the column in `metadata` that contains batch information for the samples/cells. 175 176 Returns 177 ------- 178 None 179 Updates the `harmonized_pca` attribute with harmonized data. 180 """ 181 182 data_mat = np.array(self.pca) 183 184 metadata = self.clustering_metadata 185 186 self.harmonized_pca = pd.DataFrame( 187 harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr 188 ).T 189 190 self.harmonized_pca.columns = self.pca.columns
Perform batch effect harmonization on PCA embeddings.
Parameters
batch_col : str, default 'sets'
Name of the column in metadata that contains batch information for the samples/cells.
Returns
None
Updates the harmonized_pca attribute with harmonized data.
192 def perform_PCA(self, pc_num: int = 100, width=8, height=6): 193 """ 194 Perform Principal Component Analysis (PCA) on the dataset. 195 196 This method standardizes the data, applies PCA, stores results as attributes, 197 and generates a scatter plot of the first two principal components. 198 199 Parameters 200 ---------- 201 pc_num : int, default 100 202 Number of principal components to compute. 203 If 0, computes all available components. 204 205 width : int or float, default 8 206 Width of the PCA figure. 207 208 height : int or float, default 6 209 Height of the PCA figure. 210 211 Returns 212 ------- 213 matplotlib.figure.Figure 214 Scatter plot showing the first two principal components. 215 216 Updates 217 ------- 218 self.pca : pandas.DataFrame 219 DataFrame with principal component scores for each sample. 220 221 self.explained_var : numpy.ndarray 222 Percentage of variance explained by each principal component. 223 224 self.cumulative_var : numpy.ndarray 225 Cumulative explained variance. 226 """ 227 228 scaler = StandardScaler() 229 data_scaled = scaler.fit_transform(self.clustering_data.T) 230 231 if pc_num == 0 or pc_num > data_scaled.shape[0]: 232 pc_num = data_scaled.shape[0] 233 234 pca = PCA(n_components=pc_num, random_state=42) 235 236 principal_components = pca.fit_transform(data_scaled) 237 238 pca_df = pd.DataFrame( 239 data=principal_components, 240 columns=["PC" + str(x + 1) for x in range(pc_num)], 241 ) 242 243 self.explained_var = pca.explained_variance_ratio_ * 100 244 self.cumulative_var = np.cumsum(self.explained_var) 245 246 self.pca = pca_df 247 248 fig = plt.figure(figsize=(width, height)) 249 plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7) 250 plt.xlabel("PC 1") 251 plt.ylabel("PC 2") 252 plt.grid(True) 253 plt.show() 254 255 return fig
Perform Principal Component Analysis (PCA) on the dataset.
This method standardizes the data, applies PCA, stores results as attributes, and generates a scatter plot of the first two principal components.
Parameters
pc_num : int, default 100 Number of principal components to compute. If 0, computes all available components.
width : int or float, default 8 Width of the PCA figure.
height : int or float, default 6 Height of the PCA figure.
Returns
matplotlib.figure.Figure Scatter plot showing the first two principal components.
Updates
self.pca : pandas.DataFrame DataFrame with principal component scores for each sample.
self.explained_var : numpy.ndarray Percentage of variance explained by each principal component.
self.cumulative_var : numpy.ndarray Cumulative explained variance.
257 def knee_plot_PCA(self, width: int = 8, height: int = 6): 258 """ 259 Plot cumulative explained variance to determine the optimal number of PCs. 260 261 Parameters 262 ---------- 263 width : int, default 8 264 Width of the figure. 265 266 height : int or, default 6 267 Height of the figure. 268 269 Returns 270 ------- 271 matplotlib.figure.Figure 272 Line plot showing cumulative variance explained by each PC. 273 """ 274 275 fig_knee = plt.figure(figsize=(width, height)) 276 plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o") 277 plt.xlabel("PC (n components)") 278 plt.ylabel("Cumulative explained variance (%)") 279 plt.grid(True) 280 281 xticks = [1] + list(range(5, len(self.explained_var) + 1, 5)) 282 plt.xticks(xticks, rotation=60) 283 284 plt.show() 285 286 return fig_knee
Plot cumulative explained variance to determine the optimal number of PCs.
Parameters
width : int, default 8 Width of the figure.
height : int, default 6 Height of the figure.
Returns
matplotlib.figure.Figure Line plot showing cumulative variance explained by each PC.
288 def find_clusters_PCA( 289 self, 290 pc_num: int = 2, 291 eps: float = 0.5, 292 min_samples: int = 10, 293 width: int = 8, 294 height: int = 6, 295 harmonized: bool = False, 296 ): 297 """ 298 Apply DBSCAN clustering to PCA embeddings and visualize the results. 299 300 This method performs density-based clustering (DBSCAN) on the PCA-reduced 301 dataset. Cluster labels are stored in the object's metadata, and a scatter 302 plot of the first two principal components with cluster annotations is returned. 303 304 Parameters 305 ---------- 306 pc_num : int, default 2 307 Number of principal components to use for clustering. 308 If 0, uses all available components. 309 310 eps : float, default 0.5 311 Maximum distance between two points for them to be considered 312 as neighbors (DBSCAN parameter). 313 314 min_samples : int, default 10 315 Minimum number of samples required to form a cluster (DBSCAN parameter). 316 317 width : int, default 8 318 Width of the output scatter plot. 319 320 height : int, default 6 321 Height of the output scatter plot. 322 323 harmonized : bool, default False 324 If True, use harmonized PCA data (`self.harmonized_pca`). 325 If False, use standard PCA results (`self.pca`). 326 327 Returns 328 ------- 329 matplotlib.figure.Figure 330 Scatter plot of the first two principal components colored by 331 cluster assignments. 332 333 Updates 334 ------- 335 self.clustering_metadata['PCA_clusters'] : list 336 Cluster labels assigned to each cell/sample. 337 338 self.input_metadata['PCA_clusters'] : list, optional 339 Cluster labels stored in input metadata (if available). 
340 """ 341 342 dbscan = DBSCAN(eps=eps, min_samples=min_samples) 343 344 if pc_num == 0 and harmonized: 345 PCA = self.harmonized_pca 346 347 elif pc_num == 0: 348 PCA = self.pca 349 350 else: 351 if harmonized: 352 353 PCA = self.harmonized_pca.iloc[:, 0:pc_num] 354 else: 355 356 PCA = self.pca.iloc[:, 0:pc_num] 357 358 dbscan_labels = dbscan.fit_predict(PCA) 359 360 pca_df = pd.DataFrame(PCA) 361 pca_df["Cluster"] = dbscan_labels 362 363 fig = plt.figure(figsize=(width, height)) 364 365 for cluster_id in sorted(pca_df["Cluster"].unique()): 366 cluster_data = pca_df[pca_df["Cluster"] == cluster_id] 367 plt.scatter( 368 cluster_data["PC1"], 369 cluster_data["PC2"], 370 label=f"Cluster {cluster_id}", 371 alpha=0.7, 372 ) 373 374 plt.xlabel("PC 1") 375 plt.ylabel("PC 2") 376 plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5)) 377 378 plt.grid(True) 379 plt.show() 380 381 self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels] 382 383 try: 384 self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels] 385 except: 386 pass 387 388 return fig
Apply DBSCAN clustering to PCA embeddings and visualize the results.
This method performs density-based clustering (DBSCAN) on the PCA-reduced dataset. Cluster labels are stored in the object's metadata, and a scatter plot of the first two principal components with cluster annotations is returned.
Parameters
pc_num : int, default 2 Number of principal components to use for clustering. If 0, uses all available components.
eps : float, default 0.5 Maximum distance between two points for them to be considered as neighbors (DBSCAN parameter).
min_samples : int, default 10 Minimum number of samples required to form a cluster (DBSCAN parameter).
width : int, default 8 Width of the output scatter plot.
height : int, default 6 Height of the output scatter plot.
harmonized : bool, default False
If True, use harmonized PCA data (self.harmonized_pca).
If False, use standard PCA results (self.pca).
Returns
matplotlib.figure.Figure Scatter plot of the first two principal components colored by cluster assignments.
Updates
self.clustering_metadata['PCA_clusters'] : list Cluster labels assigned to each cell/sample.
self.input_metadata['PCA_clusters'] : list, optional Cluster labels stored in input metadata (if available).
390 def perform_UMAP( 391 self, 392 factorize: bool = False, 393 umap_num: int = 100, 394 pc_num: int = 0, 395 harmonized: bool = False, 396 n_neighbors: int = 5, 397 min_dist: float | int = 0.1, 398 spread: float | int = 1.0, 399 set_op_mix_ratio: float | int = 1.0, 400 local_connectivity: int = 1, 401 repulsion_strength: float | int = 1.0, 402 negative_sample_rate: int = 5, 403 width: int = 8, 404 height: int = 6, 405 ): 406 """ 407 Compute and visualize UMAP embeddings of the dataset. 408 409 This method applies Uniform Manifold Approximation and Projection (UMAP) 410 for dimensionality reduction on either raw, PCA, or harmonized PCA data. 411 Results are stored as a DataFrame (`self.umap`) and a scatter plot figure 412 (`self.UMAP_plot`). 413 414 Parameters 415 ---------- 416 factorize : bool, default False 417 If True, categorical sample labels (from column names) are factorized 418 and used as supervision in UMAP fitting. 419 420 umap_num : int, default 100 421 Number of UMAP dimensions to compute. If 0, matches the input dimension. 422 423 pc_num : int, default 0 424 Number of principal components to use as UMAP input. 425 If 0, use all available components or raw data. 426 427 harmonized : bool, default False 428 If True, use harmonized PCA embeddings (`self.harmonized_pca`). 429 If False, use standard PCA or raw scaled data. 430 431 n_neighbors : int, default 5 432 UMAP parameter controlling the size of the local neighborhood. 433 434 min_dist : float, default 0.1 435 UMAP parameter controlling minimum allowed distance between embedded points. 436 437 spread : int | float, default 1.0 438 Effective scale of embedded space (UMAP parameter). 439 440 set_op_mix_ratio : int | float, default 1.0 441 Interpolation parameter between union and intersection in fuzzy sets. 442 443 local_connectivity : int, default 1 444 Number of nearest neighbors assumed for each point. 
445 446 repulsion_strength : int | float, default 1.0 447 Weighting applied to negative samples during optimization. 448 449 negative_sample_rate : int, default 5 450 Number of negative samples per positive sample in optimization. 451 452 width : int, default 8 453 Width of the output scatter plot. 454 455 height : int, default 6 456 Height of the output scatter plot. 457 458 Updates 459 ------- 460 self.umap : pandas.DataFrame 461 Table of UMAP embeddings with columns `UMAP1 ... UMAPn`. 462 463 Notes 464 ----- 465 For supervised UMAP (`factorize=True`), categorical codes from column 466 names of the dataset are used as labels. 467 """ 468 469 scaler = StandardScaler() 470 471 if pc_num == 0 and harmonized: 472 data_scaled = self.harmonized_pca 473 474 elif pc_num == 0: 475 data_scaled = scaler.fit_transform(self.clustering_data.T) 476 477 else: 478 if harmonized: 479 480 data_scaled = self.harmonized_pca.iloc[:, 0:pc_num] 481 else: 482 483 data_scaled = self.pca.iloc[:, 0:pc_num] 484 485 if umap_num == 0 or umap_num > data_scaled.shape[1]: 486 487 umap_num = data_scaled.shape[1] 488 489 reducer = umap.UMAP( 490 n_components=len(data_scaled.T), 491 random_state=42, 492 n_neighbors=n_neighbors, 493 min_dist=min_dist, 494 spread=spread, 495 set_op_mix_ratio=set_op_mix_ratio, 496 local_connectivity=local_connectivity, 497 repulsion_strength=repulsion_strength, 498 negative_sample_rate=negative_sample_rate, 499 n_jobs=1, 500 ) 501 502 else: 503 504 reducer = umap.UMAP( 505 n_components=umap_num, 506 random_state=42, 507 n_neighbors=n_neighbors, 508 min_dist=min_dist, 509 spread=spread, 510 set_op_mix_ratio=set_op_mix_ratio, 511 local_connectivity=local_connectivity, 512 repulsion_strength=repulsion_strength, 513 negative_sample_rate=negative_sample_rate, 514 n_jobs=1, 515 ) 516 517 if factorize: 518 embedding = reducer.fit_transform( 519 X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes 520 ) 521 else: 522 embedding = 
reducer.fit_transform(X=data_scaled) 523 524 umap_df = pd.DataFrame( 525 embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)] 526 ) 527 528 plt.figure(figsize=(width, height)) 529 plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7) 530 plt.xlabel("UMAP 1") 531 plt.ylabel("UMAP 2") 532 plt.grid(True) 533 534 plt.show() 535 536 self.umap = umap_df
Compute and visualize UMAP embeddings of the dataset.
This method applies Uniform Manifold Approximation and Projection (UMAP)
for dimensionality reduction on either raw, PCA, or harmonized PCA data.
Results are stored as a DataFrame (self.umap) and a scatter plot figure
(self.UMAP_plot).
Parameters
factorize : bool, default False If True, categorical sample labels (from column names) are factorized and used as supervision in UMAP fitting.
umap_num : int, default 100 Number of UMAP dimensions to compute. If 0, matches the input dimension.
pc_num : int, default 0 Number of principal components to use as UMAP input. If 0, use all available components or raw data.
harmonized : bool, default False
If True, use harmonized PCA embeddings (self.harmonized_pca).
If False, use standard PCA or raw scaled data.
n_neighbors : int, default 5 UMAP parameter controlling the size of the local neighborhood.
min_dist : float, default 0.1 UMAP parameter controlling minimum allowed distance between embedded points.
spread : int | float, default 1.0 Effective scale of embedded space (UMAP parameter).
set_op_mix_ratio : int | float, default 1.0 Interpolation parameter between union and intersection in fuzzy sets.
local_connectivity : int, default 1 Number of nearest neighbors assumed for each point.
repulsion_strength : int | float, default 1.0 Weighting applied to negative samples during optimization.
negative_sample_rate : int, default 5 Number of negative samples per positive sample in optimization.
width : int, default 8 Width of the output scatter plot.
height : int, default 6 Height of the output scatter plot.
Updates
self.umap : pandas.DataFrame
Table of UMAP embeddings with columns UMAP1 ... UMAPn.
Notes
For supervised UMAP (factorize=True), categorical codes from column
names of the dataset are used as labels.
538 def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10): 539 """ 540 Plot silhouette scores for different UMAP dimensions to determine optimal n_components. 541 542 Parameters 543 ---------- 544 eps : float, default 0.5 545 DBSCAN eps parameter for clustering each UMAP dimension. 546 547 min_samples : int, default 10 548 Minimum number of samples to form a cluster in DBSCAN. 549 550 Returns 551 ------- 552 matplotlib.figure.Figure 553 Silhouette score plot across UMAP dimensions. 554 """ 555 556 umap_range = range(2, len(self.umap.T) + 1) 557 558 silhouette_scores = [] 559 component = [] 560 for n in umap_range: 561 562 db = DBSCAN(eps=eps, min_samples=min_samples) 563 labels = db.fit_predict(np.array(self.umap)[:, :n]) 564 565 mask = labels != -1 566 if len(set(labels[mask])) > 1: 567 score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask]) 568 else: 569 score = -1 570 571 silhouette_scores.append(score) 572 component.append(n) 573 574 fig = plt.figure(figsize=(10, 5)) 575 plt.plot(component, silhouette_scores, marker="o") 576 plt.xlabel("UMAP (n_components)") 577 plt.ylabel("Silhouette Score") 578 plt.grid(True) 579 plt.xticks(range(int(min(component)), int(max(component)) + 1, 1)) 580 581 plt.show() 582 583 return fig
Plot silhouette scores for different UMAP dimensions to determine optimal n_components.
Parameters
eps : float, default 0.5 DBSCAN eps parameter for clustering each UMAP dimension.
min_samples : int, default 10 Minimum number of samples to form a cluster in DBSCAN.
Returns
matplotlib.figure.Figure Silhouette score plot across UMAP dimensions.
585 def find_clusters_UMAP( 586 self, 587 umap_n: int = 5, 588 eps: float | float = 0.5, 589 min_samples: int = 10, 590 width: int = 8, 591 height: int = 6, 592 ): 593 """ 594 Apply DBSCAN clustering on UMAP embeddings and visualize clusters. 595 596 This method performs density-based clustering (DBSCAN) on the UMAP-reduced 597 dataset. Cluster labels are stored in the object's metadata, and a scatter 598 plot of the first two UMAP components with cluster annotations is returned. 599 600 Parameters 601 ---------- 602 umap_n : int, default 5 603 Number of UMAP dimensions to use for DBSCAN clustering. 604 Must be <= number of columns in `self.umap`. 605 606 eps : float | int, default 0.5 607 Maximum neighborhood distance between two samples for them to be considered 608 as in the same cluster (DBSCAN parameter). 609 610 min_samples : int, default 10 611 Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter). 612 613 width : int, default 8 614 Figure width. 615 616 height : int, default 6 617 Figure height. 618 619 Returns 620 ------- 621 matplotlib.figure.Figure 622 Scatter plot of the first two UMAP components colored by 623 cluster assignments. 624 625 Updates 626 ------- 627 self.clustering_metadata['UMAP_clusters'] : list 628 Cluster labels assigned to each cell/sample. 629 630 self.input_metadata['UMAP_clusters'] : list, optional 631 Cluster labels stored in input metadata (if available). 
632 """ 633 634 dbscan = DBSCAN(eps=eps, min_samples=min_samples) 635 dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n]) 636 637 umap_df = self.umap 638 umap_df["Cluster"] = dbscan_labels 639 640 fig = plt.figure(figsize=(width, height)) 641 642 for cluster_id in sorted(umap_df["Cluster"].unique()): 643 cluster_data = umap_df[umap_df["Cluster"] == cluster_id] 644 plt.scatter( 645 cluster_data["UMAP1"], 646 cluster_data["UMAP2"], 647 label=f"Cluster {cluster_id}", 648 alpha=0.7, 649 ) 650 651 plt.xlabel("UMAP 1") 652 plt.ylabel("UMAP 2") 653 plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5)) 654 plt.grid(True) 655 656 self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels] 657 658 try: 659 self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels] 660 except: 661 pass 662 663 return fig
Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
This method performs density-based clustering (DBSCAN) on the UMAP-reduced dataset. Cluster labels are stored in the object's metadata, and a scatter plot of the first two UMAP components with cluster annotations is returned.
Parameters
umap_n : int, default 5
Number of UMAP dimensions to use for DBSCAN clustering.
Must be <= number of columns in self.umap.
eps : float | int, default 0.5 Maximum neighborhood distance between two samples for them to be considered as in the same cluster (DBSCAN parameter).
min_samples : int, default 10 Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).
width : int, default 8 Figure width.
height : int, default 6 Figure height.
Returns
matplotlib.figure.Figure Scatter plot of the first two UMAP components colored by cluster assignments.
Updates
self.clustering_metadata['UMAP_clusters'] : list Cluster labels assigned to each cell/sample.
self.input_metadata['UMAP_clusters'] : list, optional Cluster labels stored in input metadata (if available).
    def UMAP_vis(
        self,
        names_slot: str = "cell_names",
        set_sep: bool = True,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        legend_split_col: int = 2,
        width: int = 8,
        height: int = 6,
        inc_num: bool = True,
    ):
        """
        Visualize UMAP embeddings with sample labels based on a specific metadata slot.

        Points are grouped by (label, dataset) pairs; each group gets a
        numeric id (ordered by group size, largest first), its own legend
        entry, and — optionally — a numeric annotation placed at the group's
        most central point.

        Parameters
        ----------
        names_slot : str, default 'cell_names'
            Column in metadata to use as sample labels.

        set_sep : bool, default True
            If True, separate points by dataset (uses the 'sets' metadata
            column when present).

        point_size : float, default 0.6
            Size of scatter points.

        font_size : int, default 6
            Font size for numbers on points.

        legend_split_col : int, default 2
            Number of columns in legend.

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        inc_num : bool, default True
            If True, annotate points with numeric labels.

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot figure.
        """

        # First two UMAP coordinates plus the chosen label column.
        umap_df = self.umap.iloc[:, 0:2].copy()
        umap_df["names"] = list(self.clustering_metadata[names_slot])

        # Attach a dataset column: the 'sets' metadata column when requested
        # and available, otherwise a single 'default' group.
        if set_sep:

            if "sets" in list(self.clustering_metadata.columns):
                umap_df["dataset"] = list(self.clustering_metadata["sets"])
            else:
                umap_df["dataset"] = "default"

        else:
            umap_df["dataset"] = "default"

        # Composite key: one group per (label, dataset) string concatenation.
        umap_df["tmp_nam"] = list(umap_df["names"] + umap_df["dataset"])

        # Per-row group size, used to order groups by abundance.
        umap_df["count"] = umap_df["tmp_nam"].map(umap_df["tmp_nam"].value_counts())

        # One row per group, largest groups first; numeric id follows that order.
        numeric_df = (
            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
            .drop_duplicates()
            .sort_values("count", ascending=False)
        )
        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])

        # Bring the numeric group id back onto every row; how='left' keeps
        # the left frame's row order, so rows stay aligned with coordinates.
        umap_df = umap_df.merge(
            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
        )

        fig, ax = plt.subplots(figsize=(width, height))

        # Cycle a fixed marker set across datasets (reused modulo its length).
        markers = ["o", "s", "^", "D", "P", "*", "X"]
        marker_map = {
            ds: markers[i % len(markers)]
            for i, ds in enumerate(umap_df["dataset"].unique())
        }

        cord_list = []

        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):

            cluster_data = umap_df[umap_df["numeric_values"] == num]

            ax.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"{num} - {nam}",
                marker=marker_map[cluster_data["dataset"].iloc[0]],
                alpha=0.6,
                s=point_size,
            )

            coords = cluster_data[["UMAP1", "UMAP2"]].values

            # Medoid of the group: the point minimizing the sum of pairwise
            # distances to all other points — used as the annotation anchor.
            dists = pairwise_distances(coords)

            sum_dists = dists.sum(axis=1)

            center_idx = np.argmin(sum_dists)
            center_point = coords[center_idx]

            cord_list.append(center_point)

        if inc_num:
            # Place the numeric id at each group's medoid and let adjust_text
            # move overlapping labels apart, drawing thin gray connectors.
            texts = []
            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
                texts.append(
                    ax.text(
                        x,
                        y,
                        str(num),
                        ha="center",
                        va="center",
                        fontsize=font_size,
                        color="black",
                    )
                )

            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        # Legend outside the axes; markerscale enlarges the tiny points.
        ax.legend(
            title="Clusters",
            loc="center left",
            bbox_to_anchor=(1.05, 0.5),
            ncol=legend_split_col,
            markerscale=5,
        )

        ax.grid(True)

        plt.tight_layout()

        return fig
Visualize UMAP embeddings with sample labels based on a specific metadata slot.
Parameters
names_slot : str, default 'cell_names' Column in metadata to use as sample labels.
set_sep : bool, default True If True, separate points by dataset.
point_size : float, default 0.6 Size of scatter points.
font_size : int, default 6 Font size for numbers on points.
legend_split_col : int, default 2 Number of columns in legend.
width : int, default 8 Figure width.
height : int, default 6 Figure height.
inc_num : bool, default True If True, annotate points with numeric labels.
Returns
matplotlib.figure.Figure UMAP scatter plot figure.
807 def UMAP_feature( 808 self, 809 feature_name: str, 810 features_data: pd.DataFrame | None, 811 point_size: int | float = 0.6, 812 font_size: int | float = 6, 813 width: int = 8, 814 height: int = 6, 815 palette="light", 816 ): 817 """ 818 Visualize UMAP embedding with expression levels of a selected feature. 819 820 Each point (cell) in the UMAP plot is colored according to the expression 821 value of the chosen feature, enabling interpretation of spatial patterns 822 of gene activity or metadata distribution in low-dimensional space. 823 824 Parameters 825 ---------- 826 feature_name : str 827 Name of the feature to plot. 828 829 features_data : pandas.DataFrame or None, default None 830 If None, the function uses the DataFrame containing the clustering data. 831 To plot features not used in clustering, provide a wider DataFrame 832 containing the original feature values. 833 834 point_size : float, default 0.6 835 Size of scatter points in the plot. 836 837 font_size : int, default 6 838 Font size for axis labels and annotations. 839 840 width : int, default 8 841 Width of the matplotlib figure. 842 843 height : int, default 6 844 Height of the matplotlib figure. 845 846 palette : str, default 'light' 847 Color palette for expression visualization. Options are: 848 - 'light' 849 - 'dark' 850 - 'green' 851 - 'gray' 852 853 Returns 854 ------- 855 matplotlib.figure.Figure 856 UMAP scatter plot colored by feature values. 
857 """ 858 859 umap_df = self.umap.iloc[:, 0:2].copy() 860 861 if features_data is None: 862 863 features_data = self.clustering_data 864 865 if features_data.shape[1] != umap_df.shape[0]: 866 raise ValueError( 867 "Imputed 'features_data' shape does not match the number of UMAP cells" 868 ) 869 870 blist = [ 871 True if x.upper() == feature_name.upper() else False 872 for x in features_data.index 873 ] 874 875 if not any(blist): 876 raise ValueError("Imputed feature_name is not included in the data") 877 878 umap_df.loc[:, "feature"] = ( 879 features_data.loc[blist, :] 880 .apply(lambda row: row.tolist(), axis=1) 881 .values[0] 882 ) 883 884 umap_df = umap_df.sort_values("feature", ascending=True) 885 886 import matplotlib.colors as mcolors 887 888 if palette == "light": 889 palette = px.colors.sequential.Sunsetdark 890 891 elif palette == "dark": 892 palette = px.colors.sequential.thermal 893 894 elif palette == "green": 895 palette = px.colors.sequential.Aggrnyl 896 897 elif palette == "gray": 898 palette = px.colors.sequential.gray 899 palette = palette[::-1] 900 901 else: 902 raise ValueError( 903 'Palette not found. Use: "light", "dark", "gray", or "green"' 904 ) 905 906 converted = [] 907 for c in palette: 908 rgb_255 = px.colors.unlabel_rgb(c) 909 rgb_01 = tuple(v / 255.0 for v in rgb_255) 910 converted.append(rgb_01) 911 912 my_cmap = mcolors.ListedColormap(converted, name="custom") 913 914 fig, ax = plt.subplots(figsize=(width, height)) 915 sc = ax.scatter( 916 umap_df["UMAP1"], 917 umap_df["UMAP2"], 918 c=umap_df["feature"], 919 s=point_size, 920 cmap=my_cmap, 921 alpha=1.0, 922 edgecolors="black", 923 linewidths=0.1, 924 ) 925 926 cbar = plt.colorbar(sc, ax=ax) 927 cbar.set_label(f"{feature_name}") 928 929 ax.set_xlabel("UMAP 1") 930 ax.set_ylabel("UMAP 2") 931 932 ax.grid(True) 933 934 plt.tight_layout() 935 936 return fig
Visualize UMAP embedding with expression levels of a selected feature.
Each point (cell) in the UMAP plot is colored according to the expression value of the chosen feature, enabling interpretation of spatial patterns of gene activity or metadata distribution in low-dimensional space.
Parameters
feature_name : str Name of the feature to plot.
features_data : pandas.DataFrame or None, default None If None, the function uses the DataFrame containing the clustering data. To plot features not used in clustering, provide a wider DataFrame containing the original feature values.
point_size : float, default 0.6 Size of scatter points in the plot.
font_size : int, default 6 Font size for axis labels and annotations.
width : int, default 8 Width of the matplotlib figure.
height : int, default 6 Height of the matplotlib figure.
palette : str, default 'light' Color palette for expression visualization. Options are: - 'light' - 'dark' - 'green' - 'gray'
Returns
matplotlib.figure.Figure UMAP scatter plot colored by feature values.
938 def get_umap_data(self): 939 """ 940 Retrieve UMAP embedding data with optional cluster labels. 941 942 Returns the UMAP coordinates stored in `self.umap`. If clustering 943 metadata is available (specifically `UMAP_clusters`), the corresponding 944 cluster assignments are appended as an additional column. 945 946 Returns 947 ------- 948 pandas.DataFrame 949 DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...). 950 If available, includes an extra column 'clusters' with cluster labels. 951 952 Notes 953 ----- 954 - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`). 955 - Cluster labels are added only if present in `self.clustering_metadata`. 956 """ 957 958 umap_data = self.umap 959 960 try: 961 umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"] 962 except: 963 pass 964 965 return umap_data
Retrieve UMAP embedding data with optional cluster labels.
Returns the UMAP coordinates stored in self.umap. If clustering
metadata is available (specifically UMAP_clusters), the corresponding
cluster assignments are appended as an additional column.
Returns
pandas.DataFrame DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...). If available, includes an extra column 'clusters' with cluster labels.
Notes
- UMAP embeddings must be computed beforehand (e.g., using perform_UMAP).
- Cluster labels are added only if present in self.clustering_metadata.
967 def get_pca_data(self): 968 """ 969 Retrieve PCA embedding data with optional cluster labels. 970 971 Returns the principal component scores stored in `self.pca`. If clustering 972 metadata is available (specifically `PCA_clusters`), the corresponding 973 cluster assignments are appended as an additional column. 974 975 Returns 976 ------- 977 pandas.DataFrame 978 DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...). 979 If available, includes an extra column 'clusters' with cluster labels. 980 981 Notes 982 ----- 983 - PCA must be computed beforehand (e.g., using `perform_PCA`). 984 - Cluster labels are added only if present in `self.clustering_metadata`. 985 """ 986 987 pca_data = self.pca 988 989 try: 990 pca_data["clusters"] = self.clustering_metadata["PCA_clusters"] 991 except: 992 pass 993 994 return pca_data
Retrieve PCA embedding data with optional cluster labels.
Returns the principal component scores stored in self.pca. If clustering
metadata is available (specifically PCA_clusters), the corresponding
cluster assignments are appended as an additional column.
Returns
pandas.DataFrame DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...). If available, includes an extra column 'clusters' with cluster labels.
Notes
- PCA must be computed beforehand (e.g., using perform_PCA).
- Cluster labels are added only if present in self.clustering_metadata.
996 def return_clusters(self, clusters="umap"): 997 """ 998 Retrieve cluster labels from UMAP or PCA clustering results. 999 1000 Parameters 1001 ---------- 1002 clusters : str, default 'umap' 1003 Source of cluster labels to return. Must be one of: 1004 - 'umap': return cluster labels from UMAP embeddings. 1005 - 'pca' : return cluster labels from PCA embeddings. 1006 1007 Returns 1008 ------- 1009 list 1010 Cluster labels corresponding to the selected embedding method. 1011 1012 Raises 1013 ------ 1014 ValueError 1015 If `clusters` is not 'umap' or 'pca'. 1016 1017 Notes 1018 ----- 1019 Requires that clustering has already been performed 1020 (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`). 1021 """ 1022 1023 if clusters.lower() == "umap": 1024 clusters_vector = self.clustering_metadata["UMAP_clusters"] 1025 elif clusters.lower() == "pca": 1026 clusters_vector = self.clustering_metadata["PCA_clusters"] 1027 else: 1028 raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.") 1029 1030 return clusters_vector
Retrieve cluster labels from UMAP or PCA clustering results.
Parameters
clusters : str, default 'umap' Source of cluster labels to return. Must be one of: - 'umap': return cluster labels from UMAP embeddings. - 'pca' : return cluster labels from PCA embeddings.
Returns
list Cluster labels corresponding to the selected embedding method.
Raises
ValueError
If clusters is not 'umap' or 'pca'.
Notes
Requires that clustering has already been performed
(e.g., using find_clusters_UMAP or find_clusters_PCA).
class COMPsc(Clustering):
    """
    COMPsc (Comparison of single-cell data): integration, analysis and
    visualization of single-cell datasets.

    Supports independent dataset integration, subclustering of existing
    clusters, marker detection and multiple visualization strategies:

    - Normalizing and filtering single-cell data
    - Loading and saving sparse 10x-style datasets
    - Computing differential expression and marker genes
    - Clustering and subclustering analysis
    - Visualizing similarity and spatial relationships
    - Aggregating data by cell and set annotations
    - Managing metadata and renaming labels
    - Plotting gene detection histograms and feature scatters

    Methods
    -------
    project_dir(path_to_directory, project_list)
        Scan a directory and map project names to their paths.
    save_project(name, path=os.getcwd())
        Pickle the object to disk.
    load_project(path)
        Load a previously pickled COMPsc object.
    reduce_cols(reg=None, full=None, name_slot='cell_names', inc_set=False)
        Drop columns whose names match a substring or a full name.
    reduce_rows(features_list)
        Drop rows (features/genes) by name.
    get_data(set_info=False)
        Normalized data with optional 'name # set' column labels.
    get_metadata()
        Stored input metadata.
    get_partial_data(names=None, features=None, name_slot='cell_names')
        Subset of the data by sample names and/or features.
    gene_calculation()
        Per-cell gene detection counts.
    gene_histograme(bins=100)
        Histogram of genes detected per cell with a fitted normal curve.
    gene_threshold(min_n=None, max_n=None)
        Filter cells by gene-detection thresholds.
    load_sparse_from_projects(normalized_data=False)
        Load and concatenate sparse 10x-style project datasets.
    rename_names(mapping, slot='cell_names')
        Rename metadata entries via an old->new mapping.
    rename_subclusters(mapping)
        Rename subcluster labels via an old->new mapping.
    save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized')
        Export 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
    normalize_data(normalize=True, normalize_factor=100000)
        Normalize raw counts to counts-per-specified factor.
    statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10)
        Per-feature differential expression statistics (Mann-Whitney U).
    calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False)
        Compute and cache differential markers.
    clustering_features(...)
        Prepare clustering input from selected marker features.
    average()
        Aggregate normalized data across (cell_name, set) pairs.
    estimating_similarity(method='pearson', p_val=0.05, top_n=25)
        Pairwise correlation and Euclidean distance between aggregated samples.
    similarity_plot(...), spatial_similarity(...)
        Similarity visualizations.
    subcluster_prepare(features, cluster), define_subclusters(...),
    subcluster_features_scatter(...), subcluster_DEG_scatter(...),
    accept_subclusters()
        Subcluster analysis workflow.

    Raises
    ------
    ValueError
        For invalid parameters, mismatched dimensions, or missing metadata.
    """

    def __init__(self, objects=None):
        """
        Initialize the COMPsc container for single-cell data integration.

        Parameters
        ----------
        objects : dict | list | None, optional
            Data objects to start from (e.g. the project-name -> path mapping
            built by `project_dir`).
        """

        # Input data objects (e.g. project name -> path mapping).
        self.objects = objects

        # Raw input data (features x cells).
        self.input_data = None

        # Metadata describing the input cells.
        self.input_metadata = None

        # Normalized version of the input data.
        self.normalized_data = None

        # Aggregated metadata matching `agg_normalized_data`.
        self.agg_metadata = None

        # Aggregated and normalized data across multiple sets.
        self.agg_normalized_data = None

        # Pairwise similarity between cells across all samples and sets.
        self.similarity = None

        # DEG analysis results summarizing variance across all samples.
        self.var_data = None

        # Clustering instance used for subcluster analysis, when computed.
        self.subclusters_ = None

        # Number of cells per sample/cluster (data composition).
        self.cells_calc = None

        # Number of genes detected per cell (sequencing depth proxy).
        self.gene_calc = None

        # Composition of cells across clusters or sets.
        self.composition_data = None
and sets""" 1192 1193 self.var_data = None 1194 """DEG analysis results summarizing variance across all samples in the object.""" 1195 1196 self.subclusters_ = None 1197 """Placeholder for information about subclusters analysis; if computed.""" 1198 1199 self.cells_calc = None 1200 """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition.""" 1201 1202 self.gene_calc = None 1203 """Number of genes detected per sample (cell), reflecting the sequencing depth.""" 1204 1205 self.composition_data = None 1206 """Data describing composition of cells across clusters or sets.""" 1207 1208 @classmethod 1209 def project_dir(cls, path_to_directory, project_list): 1210 """ 1211 Scan a directory and build a COMPsc instance mapping provided project names 1212 to their paths. 1213 1214 Parameters 1215 ---------- 1216 path_to_directory : str 1217 Path containing project subfolders. 1218 1219 project_list : list[str] 1220 List of filenames (folder names) to include in the returned object map. 1221 1222 Returns 1223 ------- 1224 COMPsc 1225 New COMPsc instance with `objects` populated. 1226 1227 Raises 1228 ------ 1229 Exception 1230 A generic exception is caught and a message printed if scanning fails. 1231 1232 Notes 1233 ----- 1234 Function attempts to match entries in `project_list` to directory 1235 names and constructs a simplified object key from the folder name. 1236 """ 1237 try: 1238 objects = {} 1239 for filename in tqdm(os.listdir(path_to_directory)): 1240 for c in project_list: 1241 f = os.path.join(path_to_directory, filename) 1242 if c == filename and os.path.isdir(f): 1243 objects[str(c)] = f 1244 1245 return cls(objects) 1246 1247 except: 1248 print("Something went wrong. Check the function input data and try again!") 1249 1250 def save_project(self, name, path: str = os.getcwd()): 1251 """ 1252 Save the COMPsc object to disk using pickle. 
1253 1254 Parameters 1255 ---------- 1256 name : str 1257 Base filename (without extension) to use when saving. 1258 1259 path : str, default os.getcwd() 1260 Directory in which to save the project file. 1261 1262 Returns 1263 ------- 1264 None 1265 1266 Side Effects 1267 ------------ 1268 - Writes a file `<path>/<name>.jpkl` containing the pickled object. 1269 - Prints a confirmation message with saved path. 1270 """ 1271 1272 full = os.path.join(path, f"{name}.jpkl") 1273 1274 with open(full, "wb") as f: 1275 pickle.dump(self, f) 1276 1277 print(f"Project saved as {full}") 1278 1279 @classmethod 1280 def load_project(cls, path): 1281 """ 1282 Load a previously saved COMPsc project from a pickle file. 1283 1284 Parameters 1285 ---------- 1286 path : str 1287 Full path to the pickled project file. 1288 1289 Returns 1290 ------- 1291 COMPsc 1292 The unpickled COMPsc object. 1293 1294 Raises 1295 ------ 1296 FileNotFoundError 1297 If the provided path does not exist. 1298 """ 1299 1300 if not os.path.exists(path): 1301 raise FileNotFoundError("File does not exist!") 1302 with open(path, "rb") as f: 1303 obj = pickle.load(f) 1304 return obj 1305 1306 def reduce_cols( 1307 self, 1308 reg: str | None = None, 1309 full: str | None = None, 1310 name_slot: str = "cell_names", 1311 inc_set: bool = False, 1312 ): 1313 """ 1314 Remove columns (cells) whose names contain a substring `reg` or 1315 full name `full` from available tables. 1316 1317 Parameters 1318 ---------- 1319 reg : str | None 1320 Substring to search for in column/cell names; matching columns will be removed. 1321 If not None, `full` must be None. 1322 1323 full : str | None 1324 Full name to search for in column/cell names; matching columns will be removed. 1325 If not None, `reg` must be None. 1326 1327 name_slot : str, default 'cell_names' 1328 Column in metadata to use as sample names. 1329 1330 inc_set : bool, default False 1331 If True, column names are interpreted as 'cell_name # set' when matching. 
1332 1333 Update 1334 ------------ 1335 Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`, 1336 `self.agg_normalized_data`, and `self.agg_metadata` (if they exist), 1337 removing columns/rows that match `reg`. 1338 1339 Raises 1340 ------ 1341 Raises ValueError if nothing matches the reduction mask. 1342 """ 1343 1344 if reg is None and full is None: 1345 raise ValueError( 1346 "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!" 1347 ) 1348 1349 if reg is not None and full is not None: 1350 raise ValueError( 1351 "Both 'reg' and 'full' arguments are provided. " 1352 "Please provide only one of them!\n" 1353 "'reg' is used when only part of the name must be detected.\n" 1354 "'full' is used if the full name must be detected." 1355 ) 1356 1357 if reg is not None: 1358 1359 if self.input_data is not None: 1360 1361 if inc_set: 1362 1363 self.input_data.columns = ( 1364 self.input_metadata[name_slot] 1365 + " # " 1366 + self.input_metadata["sets"] 1367 ) 1368 1369 else: 1370 1371 self.input_data.columns = self.input_metadata[name_slot] 1372 1373 mask = [reg.upper() not in x.upper() for x in self.input_data.columns] 1374 1375 if len([y for y in mask if y is False]) == 0: 1376 raise ValueError("Nothing found to reduce") 1377 1378 self.input_data = self.input_data.loc[:, mask] 1379 1380 if self.normalized_data is not None: 1381 1382 if inc_set: 1383 1384 self.normalized_data.columns = ( 1385 self.input_metadata[name_slot] 1386 + " # " 1387 + self.input_metadata["sets"] 1388 ) 1389 1390 else: 1391 1392 self.normalized_data.columns = self.input_metadata[name_slot] 1393 1394 mask = [ 1395 reg.upper() not in x.upper() for x in self.normalized_data.columns 1396 ] 1397 1398 if len([y for y in mask if y is False]) == 0: 1399 raise ValueError("Nothing found to reduce") 1400 1401 self.normalized_data = self.normalized_data.loc[:, mask] 1402 1403 if self.input_metadata is not None: 1404 1405 if inc_set: 1406 1407 
self.input_metadata["drop"] = ( 1408 self.input_metadata[name_slot] 1409 + " # " 1410 + self.input_metadata["sets"] 1411 ) 1412 1413 else: 1414 1415 self.input_metadata["drop"] = self.input_metadata[name_slot] 1416 1417 mask = [ 1418 reg.upper() not in x.upper() for x in self.input_metadata["drop"] 1419 ] 1420 1421 if len([y for y in mask if y is False]) == 0: 1422 raise ValueError("Nothing found to reduce") 1423 1424 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 1425 drop=True 1426 ) 1427 1428 self.input_metadata = self.input_metadata.drop( 1429 columns=["drop"], errors="ignore" 1430 ) 1431 1432 if self.agg_normalized_data is not None: 1433 1434 if inc_set: 1435 1436 self.agg_normalized_data.columns = ( 1437 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1438 ) 1439 1440 else: 1441 1442 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1443 1444 mask = [ 1445 reg.upper() not in x.upper() 1446 for x in self.agg_normalized_data.columns 1447 ] 1448 1449 if len([y for y in mask if y is False]) == 0: 1450 raise ValueError("Nothing found to reduce") 1451 1452 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 1453 1454 if self.agg_metadata is not None: 1455 1456 if inc_set: 1457 1458 self.agg_metadata["drop"] = ( 1459 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1460 ) 1461 1462 else: 1463 1464 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1465 1466 mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]] 1467 1468 if len([y for y in mask if y is False]) == 0: 1469 raise ValueError("Nothing found to reduce") 1470 1471 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 1472 drop=True 1473 ) 1474 1475 self.agg_metadata = self.agg_metadata.drop( 1476 columns=["drop"], errors="ignore" 1477 ) 1478 1479 elif full is not None: 1480 1481 if self.input_data is not None: 1482 1483 if inc_set: 1484 1485 self.input_data.columns = ( 1486 
self.input_metadata[name_slot] 1487 + " # " 1488 + self.input_metadata["sets"] 1489 ) 1490 1491 if "#" not in full: 1492 1493 self.input_data.columns = self.input_metadata[name_slot] 1494 1495 print( 1496 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1497 "Only the names will be compared, without considering the set information." 1498 ) 1499 1500 else: 1501 1502 self.input_data.columns = self.input_metadata[name_slot] 1503 1504 mask = [full.upper() != x.upper() for x in self.input_data.columns] 1505 1506 if len([y for y in mask if y is False]) == 0: 1507 raise ValueError("Nothing found to reduce") 1508 1509 self.input_data = self.input_data.loc[:, mask] 1510 1511 if self.normalized_data is not None: 1512 1513 if inc_set: 1514 1515 self.normalized_data.columns = ( 1516 self.input_metadata[name_slot] 1517 + " # " 1518 + self.input_metadata["sets"] 1519 ) 1520 1521 if "#" not in full: 1522 1523 self.normalized_data.columns = self.input_metadata[name_slot] 1524 1525 print( 1526 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1527 "Only the names will be compared, without considering the set information." 
1528 ) 1529 1530 else: 1531 1532 self.normalized_data.columns = self.input_metadata[name_slot] 1533 1534 mask = [full.upper() != x.upper() for x in self.normalized_data.columns] 1535 1536 if len([y for y in mask if y is False]) == 0: 1537 raise ValueError("Nothing found to reduce") 1538 1539 self.normalized_data = self.normalized_data.loc[:, mask] 1540 1541 if self.input_metadata is not None: 1542 1543 if inc_set: 1544 1545 self.input_metadata["drop"] = ( 1546 self.input_metadata[name_slot] 1547 + " # " 1548 + self.input_metadata["sets"] 1549 ) 1550 1551 if "#" not in full: 1552 1553 self.input_metadata["drop"] = self.input_metadata[name_slot] 1554 1555 print( 1556 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1557 "Only the names will be compared, without considering the set information." 1558 ) 1559 1560 else: 1561 1562 self.input_metadata["drop"] = self.input_metadata[name_slot] 1563 1564 mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]] 1565 1566 if len([y for y in mask if y is False]) == 0: 1567 raise ValueError("Nothing found to reduce") 1568 1569 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 1570 drop=True 1571 ) 1572 1573 self.input_metadata = self.input_metadata.drop( 1574 columns=["drop"], errors="ignore" 1575 ) 1576 1577 if self.agg_normalized_data is not None: 1578 1579 if inc_set: 1580 1581 self.agg_normalized_data.columns = ( 1582 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1583 ) 1584 1585 if "#" not in full: 1586 1587 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1588 1589 print( 1590 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1591 "Only the names will be compared, without considering the set information." 
1592 ) 1593 else: 1594 1595 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1596 1597 mask = [ 1598 full.upper() != x.upper() for x in self.agg_normalized_data.columns 1599 ] 1600 1601 if len([y for y in mask if y is False]) == 0: 1602 raise ValueError("Nothing found to reduce") 1603 1604 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 1605 1606 if self.agg_metadata is not None: 1607 1608 if inc_set: 1609 1610 self.agg_metadata["drop"] = ( 1611 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1612 ) 1613 1614 if "#" not in full: 1615 1616 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1617 1618 print( 1619 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1620 "Only the names will be compared, without considering the set information." 1621 ) 1622 else: 1623 1624 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1625 1626 mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]] 1627 1628 if len([y for y in mask if y is False]) == 0: 1629 raise ValueError("Nothing found to reduce") 1630 1631 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 1632 drop=True 1633 ) 1634 1635 self.agg_metadata = self.agg_metadata.drop( 1636 columns=["drop"], errors="ignore" 1637 ) 1638 1639 self.gene_calculation() 1640 self.cells_calculation() 1641 1642 def reduce_rows(self, features_list: list): 1643 """ 1644 Remove rows (features) whose names are included in features_list. 1645 1646 Parameters 1647 ---------- 1648 features_list : list 1649 List of features to search for in index/gene names; matching entries will be removed. 1650 1651 Update 1652 ------------ 1653 Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`, 1654 `self.agg_normalized_data`, and `self.agg_metadata` (if they exist), 1655 removing columns/rows that match `reg`. 1656 1657 Raises 1658 ------ 1659 Prints a message listing features that are not found in the data. 
1660 """ 1661 1662 if self.input_data is not None: 1663 1664 res = find_features(self.input_data, features=features_list) 1665 1666 res_list = [x.upper() for x in res["included"]] 1667 1668 mask = [x.upper() not in res_list for x in self.input_data.index] 1669 1670 if len([y for y in mask if y is False]) == 0: 1671 raise ValueError("Nothing found to reduce") 1672 1673 self.input_data = self.input_data.loc[mask, :] 1674 1675 if self.normalized_data is not None: 1676 1677 res = find_features(self.normalized_data, features=features_list) 1678 1679 res_list = [x.upper() for x in res["included"]] 1680 1681 mask = [x.upper() not in res_list for x in self.normalized_data.index] 1682 1683 if len([y for y in mask if y is False]) == 0: 1684 raise ValueError("Nothing found to reduce") 1685 1686 self.normalized_data = self.normalized_data.loc[mask, :] 1687 1688 if self.agg_normalized_data is not None: 1689 1690 res = find_features(self.agg_normalized_data, features=features_list) 1691 1692 res_list = [x.upper() for x in res["included"]] 1693 1694 mask = [x.upper() not in res_list for x in self.agg_normalized_data.index] 1695 1696 if len([y for y in mask if y is False]) == 0: 1697 raise ValueError("Nothing found to reduce") 1698 1699 self.agg_normalized_data = self.agg_normalized_data.loc[mask, :] 1700 1701 if len(res["not_included"]) > 0: 1702 print("\nFeatures not found:") 1703 for i in res["not_included"]: 1704 print(i) 1705 1706 self.gene_calculation() 1707 self.cells_calculation() 1708 1709 def get_data(self, set_info: bool = False): 1710 """ 1711 Return normalized data with optional set annotation appended to column names. 1712 1713 Parameters 1714 ---------- 1715 set_info : bool, default False 1716 If True, column names are returned as "cell_name # set"; otherwise 1717 only the `cell_name` is used. 1718 1719 Returns 1720 ------- 1721 pandas.DataFrame 1722 The `self.normalized_data` table with columns renamed according to `set_info`. 
1723 1724 Raises 1725 ------ 1726 AttributeError 1727 If `self.normalized_data` or `self.input_metadata` is missing. 1728 """ 1729 1730 to_return = self.normalized_data 1731 1732 if set_info: 1733 to_return.columns = ( 1734 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 1735 ) 1736 else: 1737 to_return.columns = self.input_metadata["cell_names"] 1738 1739 return to_return 1740 1741 def get_partial_data( 1742 self, 1743 names: list | str | None = None, 1744 features: list | str | None = None, 1745 name_slot: str = "cell_names", 1746 inc_metadata: bool = False, 1747 ): 1748 """ 1749 Return a subset of the data filtered by sample names and/or feature names. 1750 1751 Parameters 1752 ---------- 1753 names : list, str, or None 1754 Names of samples to include. If None, all samples are considered. 1755 1756 features : list, str, or None 1757 Names of features to include. If None, all features are considered. 1758 1759 name_slot : str 1760 Column in metadata to use as sample names. 1761 1762 inc_metadata : bool 1763 If True return tuple (data, metadata) 1764 1765 Returns 1766 ------- 1767 pandas.DataFrame 1768 Subset of the normalized data based on the specified names and features. 
1769 """ 1770 1771 data = self.normalized_data.copy() 1772 metadata = self.input_metadata 1773 1774 if name_slot in self.input_metadata.columns: 1775 data.columns = self.input_metadata[name_slot] 1776 else: 1777 raise ValueError("'name_slot' not occured in data!'") 1778 1779 if isinstance(features, str): 1780 features = [features] 1781 elif features is None: 1782 features = [] 1783 1784 if isinstance(names, str): 1785 names = [names] 1786 elif names is None: 1787 names = [] 1788 1789 features = [x.upper() for x in features] 1790 names = [x.upper() for x in names] 1791 1792 columns_names = [x.upper() for x in data.columns] 1793 features_names = [x.upper() for x in data.index] 1794 1795 columns_bool = [True if x in names else False for x in columns_names] 1796 features_bool = [True if x in features else False for x in features_names] 1797 1798 if True not in columns_bool and True not in features_bool: 1799 print("Missing 'names' and/or 'features'. Returning full dataset instead.") 1800 1801 if True in columns_bool: 1802 data = data.loc[:, columns_bool] 1803 metadata = metadata.loc[columns_bool, :] 1804 1805 if True in features_bool: 1806 data = data.loc[features_bool, :] 1807 1808 not_in_features = [y for y in features if y not in features_names] 1809 1810 if len(not_in_features) > 0: 1811 print('\nThe following features were not found in data:') 1812 print('\n'.join(not_in_features)) 1813 1814 not_in_names = [y for y in names if y not in columns_names] 1815 1816 if len(not_in_names) > 0: 1817 print('\nThe following names were not found in data:') 1818 print('\n'.join(not_in_names)) 1819 1820 if inc_metadata: 1821 return data, metadata 1822 else: 1823 return data 1824 1825 def get_metadata(self): 1826 """ 1827 Return the stored input metadata. 1828 1829 Returns 1830 ------- 1831 pandas.DataFrame 1832 `self.input_metadata` (may be None if not set). 
1833 """ 1834 1835 to_return = self.input_metadata 1836 1837 return to_return 1838 1839 def gene_calculation(self): 1840 """ 1841 Calculate and store per-cell counts (e.g., number of detected genes). 1842 1843 The method computes a binary (presence/absence) per cell and sums across 1844 features to produce `self.gene_calc`. 1845 1846 Update 1847 ------ 1848 Sets `self.gene_calc` as a pandas.Series. 1849 1850 Side Effects 1851 ------------ 1852 Uses `self.input_data` when available, otherwise `self.normalized_data`. 1853 """ 1854 1855 if self.input_data is not None: 1856 1857 bin_col = self.input_data.columns.copy() 1858 1859 bin_col = bin_col.where(bin_col <= 0, 1) 1860 1861 sum_data = bin_col.sum(axis=0) 1862 1863 self.gene_calc = sum_data 1864 1865 elif self.normalized_data is not None: 1866 1867 bin_col = self.normalized_data.copy() 1868 1869 bin_col = bin_col.where(bin_col <= 0, 1) 1870 1871 sum_data = bin_col.sum(axis=0) 1872 1873 self.gene_calc = sum_data 1874 1875 def gene_histograme(self, bins=100): 1876 """ 1877 Plot a histogram of the number of genes detected per cell. 1878 1879 Parameters 1880 ---------- 1881 bins : int, default 100 1882 Number of histogram bins. 1883 1884 Returns 1885 ------- 1886 matplotlib.figure.Figure 1887 Figure containing the histogram of gene contents. 1888 1889 Notes 1890 ----- 1891 Requires `self.gene_calc` to be computed prior to calling. 
1892 """ 1893 1894 fig, ax = plt.subplots(figsize=(8, 5)) 1895 1896 _, bin_edges, _ = ax.hist( 1897 self.gene_calc, bins=bins, edgecolor="black", alpha=0.6 1898 ) 1899 1900 mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc) 1901 1902 x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000) 1903 y = norm.pdf(x, mu, sigma) 1904 1905 y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0]) 1906 1907 ax.plot( 1908 x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})" 1909 ) 1910 1911 ax.set_xlabel("Value") 1912 ax.set_ylabel("Count") 1913 ax.set_title("Histogram of genes detected per cell") 1914 1915 ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20)) 1916 ax.tick_params(axis="x", rotation=90) 1917 1918 ax.legend() 1919 1920 return fig 1921 1922 def gene_threshold(self, min_n: int | None, max_n: int | None): 1923 """ 1924 Filter cells by gene-detection thresholds (min and/or max). 1925 1926 Parameters 1927 ---------- 1928 min_n : int or None 1929 Minimum number of detected genes required to keep a cell. 1930 1931 max_n : int or None 1932 Maximum number of detected genes allowed to keep a cell. 1933 1934 Update 1935 ------- 1936 Filters `self.input_data`, `self.normalized_data`, `self.input_metadata` 1937 (and calls `average()` if `self.agg_normalized_data` exists). 1938 1939 Side Effects 1940 ------------ 1941 Raises ValueError if both bounds are None or if filtering removes all cells. 
1942 """ 1943 1944 if min_n is not None and max_n is not None: 1945 mask = (self.gene_calc > min_n) & (self.gene_calc < max_n) 1946 elif min_n is None and max_n is not None: 1947 mask = self.gene_calc < max_n 1948 elif min_n is not None and max_n is None: 1949 mask = self.gene_calc > min_n 1950 else: 1951 raise ValueError("Lack of both min_n and max_n values") 1952 1953 if self.input_data is not None: 1954 1955 if len([y for y in mask if y is False]) == 0: 1956 raise ValueError("Nothing to reduce") 1957 1958 self.input_data = self.input_data.loc[:, mask.values] 1959 1960 if self.normalized_data is not None: 1961 1962 if len([y for y in mask if y is False]) == 0: 1963 raise ValueError("Nothing to reduce") 1964 1965 self.normalized_data = self.normalized_data.loc[:, mask.values] 1966 1967 if self.input_metadata is not None: 1968 1969 if len([y for y in mask if y is False]) == 0: 1970 raise ValueError("Nothing to reduce") 1971 1972 self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index( 1973 drop=True 1974 ) 1975 1976 self.input_metadata = self.input_metadata.drop( 1977 columns=["drop"], errors="ignore" 1978 ) 1979 1980 if self.agg_normalized_data is not None: 1981 self.average() 1982 1983 self.gene_calculation() 1984 self.cells_calculation() 1985 1986 def cells_calculation(self, name_slot="cell_names"): 1987 """ 1988 Calculate number of cells per call name / cluster. 1989 1990 The method computes a binary (presence/absence) per cell name / cluster and sums across 1991 cells. 1992 1993 Parameters 1994 ---------- 1995 name_slot : str, default 'cell_names' 1996 Column in metadata to use as sample names. 1997 1998 Update 1999 ------ 2000 Sets `self.cells_calc` as a pd.DataFrame. 
2001 """ 2002 2003 ls = list(self.input_metadata[name_slot]) 2004 2005 df = pd.DataFrame( 2006 { 2007 "cluster": pd.Series(ls).value_counts().index, 2008 "n": pd.Series(ls).value_counts().values, 2009 } 2010 ) 2011 2012 self.cells_calc = df 2013 2014 def cell_histograme(self, name_slot: str = "cell_names"): 2015 """ 2016 Plot a histogram of the number of cells detected per cell name (cluster). 2017 2018 Parameters 2019 ---------- 2020 name_slot : str, default 'cell_names' 2021 Column in metadata to use as sample names. 2022 2023 Returns 2024 ------- 2025 matplotlib.figure.Figure 2026 Figure containing the histogram of cell contents. 2027 2028 Notes 2029 ----- 2030 Requires `self.cells_calc` to be computed prior to calling. 2031 """ 2032 2033 if name_slot != "cell_names": 2034 self.cells_calculation(name_slot=name_slot) 2035 2036 fig, ax = plt.subplots(figsize=(8, 5)) 2037 2038 _, bin_edges, _ = ax.hist( 2039 list(self.cells_calc["n"]), 2040 bins=len(set(self.cells_calc["cluster"])), 2041 edgecolor="black", 2042 color="orange", 2043 alpha=0.6, 2044 ) 2045 2046 mu, sigma = np.mean(list(self.cells_calc["n"])), np.std( 2047 list(self.cells_calc["n"]) 2048 ) 2049 2050 x = np.linspace( 2051 min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000 2052 ) 2053 y = norm.pdf(x, mu, sigma) 2054 2055 y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0]) 2056 2057 ax.plot( 2058 x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})" 2059 ) 2060 2061 ax.set_xlabel("Value") 2062 ax.set_ylabel("Count") 2063 ax.set_title("Histogram of cells detected per cell name / cluster") 2064 2065 ax.set_xticks( 2066 np.linspace( 2067 min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20 2068 ) 2069 ) 2070 ax.tick_params(axis="x", rotation=90) 2071 2072 ax.legend() 2073 2074 return fig 2075 2076 def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"): 2077 """ 2078 Filter cell names / clusters by 
cell-detection threshold. 2079 2080 Parameters 2081 ---------- 2082 min_n : int or None 2083 Minimum number of detected genes required to keep a cell. 2084 2085 name_slot : str, default 'cell_names' 2086 Column in metadata to use as sample names. 2087 2088 2089 Update 2090 ------- 2091 Filters `self.input_data`, `self.normalized_data`, `self.input_metadata` 2092 (and calls `average()` if `self.agg_normalized_data` exists). 2093 """ 2094 2095 if name_slot != "cell_names": 2096 self.cells_calculation(name_slot=name_slot) 2097 2098 if min_n is not None: 2099 names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n] 2100 else: 2101 raise ValueError("Lack of min_n value") 2102 2103 if len(names) > 0: 2104 2105 if self.input_data is not None: 2106 2107 self.input_data.columns = self.input_metadata[name_slot] 2108 2109 mask = [not any(r in x for r in names) for x in self.input_data.columns] 2110 2111 if len([y for y in mask if y is False]) > 0: 2112 2113 self.input_data = self.input_data.loc[:, mask] 2114 2115 if self.normalized_data is not None: 2116 2117 self.normalized_data.columns = self.input_metadata[name_slot] 2118 2119 mask = [ 2120 not any(r in x for r in names) for x in self.normalized_data.columns 2121 ] 2122 2123 if len([y for y in mask if y is False]) > 0: 2124 2125 self.normalized_data = self.normalized_data.loc[:, mask] 2126 2127 if self.input_metadata is not None: 2128 2129 self.input_metadata["drop"] = self.input_metadata[name_slot] 2130 2131 mask = [ 2132 not any(r in x for r in names) for x in self.input_metadata["drop"] 2133 ] 2134 2135 if len([y for y in mask if y is False]) > 0: 2136 2137 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 2138 drop=True 2139 ) 2140 2141 self.input_metadata = self.input_metadata.drop( 2142 columns=["drop"], errors="ignore" 2143 ) 2144 2145 if self.agg_normalized_data is not None: 2146 2147 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 2148 2149 mask = [ 2150 not any(r in x for 
r in names) 2151 for x in self.agg_normalized_data.columns 2152 ] 2153 2154 if len([y for y in mask if y is False]) > 0: 2155 2156 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 2157 2158 if self.agg_metadata is not None: 2159 2160 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 2161 2162 mask = [ 2163 not any(r in x for r in names) for x in self.agg_metadata["drop"] 2164 ] 2165 2166 if len([y for y in mask if y is False]) > 0: 2167 2168 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 2169 drop=True 2170 ) 2171 2172 self.agg_metadata = self.agg_metadata.drop( 2173 columns=["drop"], errors="ignore" 2174 ) 2175 2176 self.gene_calculation() 2177 self.cells_calculation() 2178 2179 def load_sparse_from_projects(self, normalized_data: bool = False): 2180 """ 2181 Load sparse 10x-style datasets from stored project paths, concatenate them, 2182 and populate `input_data` / `normalized_data` and `input_metadata`. 2183 2184 Parameters 2185 ---------- 2186 normalized_data : bool, default False 2187 If True, store concatenated tables in `self.normalized_data`. 2188 If False, store them in `self.input_data` and normalization 2189 is needed using normalize_data() method. 2190 2191 Side Effects 2192 ------------ 2193 - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv). 2194 - Concatenates all projects column-wise and sets `self.input_metadata`. 2195 - Replaces NaNs with zeros and updates `self.gene_calc`. 
2196 """ 2197 2198 obj = self.objects 2199 2200 full_data = pd.DataFrame() 2201 full_metadata = pd.DataFrame() 2202 2203 for ke in obj.keys(): 2204 print(ke) 2205 2206 dt, met = load_sparse(path=obj[ke], name=ke) 2207 2208 full_data = pd.concat([full_data, dt], axis=1) 2209 full_metadata = pd.concat([full_metadata, met], axis=0) 2210 2211 full_data[np.isnan(full_data)] = 0 2212 2213 if normalized_data: 2214 self.normalized_data = full_data 2215 self.input_metadata = full_metadata 2216 else: 2217 2218 self.input_data = full_data 2219 self.input_metadata = full_metadata 2220 2221 self.gene_calculation() 2222 self.cells_calculation() 2223 2224 def rename_names(self, mapping: dict, slot: str = "cell_names"): 2225 """ 2226 Rename entries in `self.input_metadata[slot]` according to a provided mapping. 2227 2228 Parameters 2229 ---------- 2230 mapping : dict 2231 Dictionary with keys 'old_name' and 'new_name', each mapping to a list 2232 of equal length describing replacements. 2233 2234 slot : str, default 'cell_names' 2235 Metadata column to operate on. 2236 2237 Update 2238 ------- 2239 Updates `self.input_metadata[slot]` in-place with renamed values. 2240 2241 Raises 2242 ------ 2243 ValueError 2244 If mapping keys are incorrect, lengths differ, or some 'old_name' values 2245 are not present in the metadata column. 2246 """ 2247 2248 if set(["old_name", "new_name"]) != set(mapping.keys()): 2249 raise ValueError( 2250 "Mapping dictionary must contain keys 'old_name' and 'new_name', " 2251 "each with a list of names to change." 2252 ) 2253 2254 if len(mapping["old_name"]) != len(mapping["new_name"]): 2255 raise ValueError( 2256 "Mapping dictionary lists 'old_name' and 'new_name' " 2257 "must have the same length!" 2258 ) 2259 2260 names_vector = list(self.input_metadata[slot]) 2261 2262 if not all(elem in names_vector for elem in list(mapping["old_name"])): 2263 raise ValueError( 2264 f"Some entries from 'old_name' do not exist in the names of slot {slot}." 
2265 ) 2266 2267 replace_dict = dict(zip(mapping["old_name"], mapping["new_name"])) 2268 2269 names_vector_ret = [replace_dict.get(item, item) for item in names_vector] 2270 2271 self.input_metadata[slot] = names_vector_ret 2272 2273 def rename_subclusters(self, mapping): 2274 """ 2275 Rename labels stored in `self.subclusters_.subclusters` according to mapping. 2276 2277 Parameters 2278 ---------- 2279 mapping : dict 2280 Mapping with keys 'old_name' and 'new_name' (lists of equal length). 2281 2282 Update 2283 ------- 2284 Updates `self.subclusters_.subclusters` with renamed labels. 2285 2286 Raises 2287 ------ 2288 ValueError 2289 If mapping is invalid or old names are not present. 2290 """ 2291 2292 if set(["old_name", "new_name"]) != set(mapping.keys()): 2293 raise ValueError( 2294 "Mapping dictionary must contain keys 'old_name' and 'new_name', " 2295 "each with a list of names to change." 2296 ) 2297 2298 if len(mapping["old_name"]) != len(mapping["new_name"]): 2299 raise ValueError( 2300 "Mapping dictionary lists 'old_name' and 'new_name' " 2301 "must have the same length!" 2302 ) 2303 2304 names_vector = list(self.subclusters_.subclusters) 2305 2306 if not all(elem in names_vector for elem in list(mapping["old_name"])): 2307 raise ValueError( 2308 "Some entries from 'old_name' do not exist in the subcluster names." 2309 ) 2310 2311 replace_dict = dict(zip(mapping["old_name"], mapping["new_name"])) 2312 2313 names_vector_ret = [replace_dict.get(item, item) for item in names_vector] 2314 2315 self.subclusters_.subclusters = names_vector_ret 2316 2317 def save_sparse( 2318 self, 2319 path_to_save: str = os.getcwd(), 2320 name_slot: str = "cell_names", 2321 data_slot: str = "normalized", 2322 ): 2323 """ 2324 Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv). 2325 2326 Parameters 2327 ---------- 2328 path_to_save : str, default current working directory 2329 Directory where files will be written. 
2330 2331 name_slot : str, default 'cell_names' 2332 Metadata column providing cell names for barcodes.tsv. 2333 2334 data_slot : str, default 'normalized' 2335 Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data). 2336 2337 Raises 2338 ------ 2339 ValueError 2340 If `data_slot` is not 'normalized' or 'count'. 2341 """ 2342 2343 names = self.input_metadata[name_slot] 2344 2345 if data_slot.lower() == "normalized": 2346 2347 features = list(self.normalized_data.index) 2348 mtx = sparse.csr_matrix(self.normalized_data) 2349 2350 elif data_slot.lower() == "count": 2351 2352 features = list(self.input_data.index) 2353 mtx = sparse.csr_matrix(self.input_data) 2354 2355 else: 2356 raise ValueError("'data_slot' must be included in 'normalized' or 'count'") 2357 2358 os.makedirs(path_to_save, exist_ok=True) 2359 2360 mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx) 2361 2362 pd.Series(names).to_csv( 2363 os.path.join(path_to_save, "barcodes.tsv"), 2364 index=False, 2365 header=False, 2366 sep="\t", 2367 ) 2368 2369 pd.Series(features).to_csv( 2370 os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t" 2371 ) 2372 2373 def normalize_counts( 2374 self, normalize_factor: int = 100000, log_transform: bool = True 2375 ): 2376 """ 2377 Normalize raw counts to counts-per-(normalize_factor) 2378 (e.g., CPM, TPM - depending on normalize_factor). 2379 2380 Parameters 2381 ---------- 2382 normalize_factor : int, default 100000 2383 Scaling factor used after dividing by column sums. 2384 2385 log_transform : bool, default True 2386 If True, apply log2(x+1) transformation to normalized values. 2387 2388 Update 2389 ------- 2390 Sets `self.normalized_data` to normalized values (fills NaNs with 0). 2391 2392 Raises 2393 ------ 2394 ValueError 2395 If `self.input_data` is missing (cannot normalize). 
2396 """ 2397 if self.input_data is None: 2398 raise ValueError("Input data is missing, cannot normalize.") 2399 2400 sum_col = self.input_data.sum() 2401 self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor 2402 2403 if log_transform: 2404 # log2(x + 1) to avoid -inf for zeros 2405 self.normalized_data = np.log2(self.normalized_data + 1) 2406 2407 def statistic( 2408 self, 2409 cells=None, 2410 sets=None, 2411 min_exp: float = 0.01, 2412 min_pct: float = 0.1, 2413 n_proc: int = 10, 2414 ): 2415 """ 2416 Compute per-feature statistics (Mann–Whitney U) comparing target vs rest. 2417 2418 This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data` 2419 and `self.input_metadata`. It returns per-feature statistics including p-values, 2420 adjusted p-values, means, variances, effect-size measures and fold-changes. 2421 2422 Parameters 2423 ---------- 2424 cells : list, 'All', dict, or None 2425 Defines the target cells or groups for comparison (several modes supported). 2426 2427 sets : 'All', dict, or None 2428 Alternative grouping mode (operate on `self.input_metadata['sets']`). 2429 2430 min_exp : float, default 0.01 2431 Minimum expression threshold used when filtering features. 2432 2433 min_pct : float, default 0.1 2434 Minimum proportion of expressing cells in the target group required to test a feature. 2435 2436 n_proc : int, default 10 2437 Number of parallel jobs to use. 2438 2439 Returns 2440 ------- 2441 pandas.DataFrame or dict 2442 Results DataFrame (or dict containing valid/control cells + DataFrame), 2443 similar to `calc_DEG` interface. 2444 2445 Raises 2446 ------ 2447 ValueError 2448 If neither `cells` nor `sets` is provided, or input metadata mismatch occurs. 2449 2450 Notes 2451 ----- 2452 Multiple modes supported: single-list entities, 'All', pairwise dicts, etc. 
2453 """ 2454 2455 offset = 1e-100 2456 2457 def stat_calc(choose, feature_name): 2458 target_values = choose.loc[choose["DEG"] == "target", feature_name] 2459 rest_values = choose.loc[choose["DEG"] == "rest", feature_name] 2460 2461 pct_valid = (target_values > 0).sum() / len(target_values) 2462 pct_rest = (rest_values > 0).sum() / len(rest_values) 2463 2464 avg_valid = np.mean(target_values) 2465 avg_ctrl = np.mean(rest_values) 2466 sd_valid = np.std(target_values, ddof=1) 2467 sd_ctrl = np.std(rest_values, ddof=1) 2468 esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2)) 2469 2470 if np.sum(target_values) == np.sum(rest_values): 2471 p_val = 1.0 2472 else: 2473 _, p_val = stats.mannwhitneyu( 2474 target_values, rest_values, alternative="two-sided" 2475 ) 2476 2477 return { 2478 "feature": feature_name, 2479 "p_val": p_val, 2480 "pct_valid": pct_valid, 2481 "pct_ctrl": pct_rest, 2482 "avg_valid": avg_valid, 2483 "avg_ctrl": avg_ctrl, 2484 "sd_valid": sd_valid, 2485 "sd_ctrl": sd_ctrl, 2486 "esm": esm, 2487 } 2488 2489 def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc): 2490 2491 def safe_min_half(series): 2492 filtered = series[(series > ((2**-1074)*2)) & (series.notna())] 2493 return filtered.min() / 2 if not filtered.empty else 0 2494 2495 tmp_dat = choose[choose["DEG"] == "target"] 2496 tmp_dat = tmp_dat.drop("DEG", axis=1) 2497 2498 counts = (tmp_dat > min_exp).sum(axis=0) 2499 2500 total_count = tmp_dat.shape[0] 2501 2502 info = pd.DataFrame( 2503 {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)} 2504 ) 2505 2506 del tmp_dat 2507 2508 drop_col = info["feature"][info["pct"] <= min_pct] 2509 2510 if len(drop_col) + 1 == len(choose.columns): 2511 drop_col = info["feature"][info["pct"] == 0] 2512 2513 del info 2514 2515 choose = choose.drop(list(drop_col), axis=1) 2516 2517 results = Parallel(n_jobs=n_proc)( 2518 delayed(stat_calc)(choose, feature) 2519 for feature in tqdm(choose.columns[choose.columns 
!= "DEG"]) 2520 ) 2521 2522 df = pd.DataFrame(results) 2523 df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)] 2524 2525 df["valid_group"] = valid_group 2526 df.sort_values(by="p_val", inplace=True) 2527 2528 num_tests = len(df) 2529 df["adj_pval"] = np.minimum( 2530 1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1) 2531 ) 2532 2533 valid_factor = safe_min_half(df["avg_valid"]) 2534 ctrl_factor = safe_min_half(df["avg_ctrl"]) 2535 2536 cv_factor = min(valid_factor, ctrl_factor) 2537 2538 if cv_factor == 0: 2539 cv_factor = max(valid_factor, ctrl_factor) 2540 2541 if not np.isfinite(cv_factor) or cv_factor == 0: 2542 cv_factor += offset 2543 2544 valid = df["avg_valid"].where( 2545 df["avg_valid"] != 0, df["avg_valid"] + cv_factor 2546 ) 2547 ctrl = df["avg_ctrl"].where( 2548 df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor 2549 ) 2550 2551 df["FC"] = valid / ctrl 2552 2553 df["log(FC)"] = np.log2(df["FC"]) 2554 df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"] 2555 2556 return df 2557 2558 choose = self.normalized_data.copy().T 2559 2560 final_results = [] 2561 2562 if isinstance(cells, list) and sets is None: 2563 print("\nAnalysis started...\nComparing selected cells to the whole set...") 2564 choose.index = ( 2565 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 2566 ) 2567 2568 if "#" not in cells[0]: 2569 choose.index = self.input_metadata["cell_names"] 2570 2571 print( 2572 "Not include the set info (name # set) in the 'cells' list. " 2573 "Only the names will be compared, without considering the set information." 
2574 ) 2575 2576 labels = ["target" if idx in cells else "rest" for idx in choose.index] 2577 valid = list( 2578 set(choose.index[[i for i, x in enumerate(labels) if x == "target"]]) 2579 ) 2580 2581 choose["DEG"] = labels 2582 choose = choose[choose["DEG"] != "drop"] 2583 2584 result_df = prepare_and_run_stat( 2585 choose.reset_index(drop=True), 2586 valid_group=valid, 2587 min_exp=min_exp, 2588 min_pct=min_pct, 2589 n_proc=n_proc, 2590 ) 2591 return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df} 2592 2593 elif cells == "All" and sets is None: 2594 print("\nAnalysis started...\nComparing each type of cell to others...") 2595 choose.index = ( 2596 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 2597 ) 2598 unique_labels = set(choose.index) 2599 2600 for label in tqdm(unique_labels): 2601 print(f"\nCalculating statistics for {label}") 2602 labels = ["target" if idx == label else "rest" for idx in choose.index] 2603 choose["DEG"] = labels 2604 choose = choose[choose["DEG"] != "drop"] 2605 result_df = prepare_and_run_stat( 2606 choose.copy(), 2607 valid_group=label, 2608 min_exp=min_exp, 2609 min_pct=min_pct, 2610 n_proc=n_proc, 2611 ) 2612 final_results.append(result_df) 2613 2614 return pd.concat(final_results, ignore_index=True) 2615 2616 elif cells is None and sets == "All": 2617 print("\nAnalysis started...\nComparing each set/group to others...") 2618 choose.index = self.input_metadata["sets"] 2619 unique_sets = set(choose.index) 2620 2621 for label in tqdm(unique_sets): 2622 print(f"\nCalculating statistics for {label}") 2623 labels = ["target" if idx == label else "rest" for idx in choose.index] 2624 2625 choose["DEG"] = labels 2626 choose = choose[choose["DEG"] != "drop"] 2627 result_df = prepare_and_run_stat( 2628 choose.copy(), 2629 valid_group=label, 2630 min_exp=min_exp, 2631 min_pct=min_pct, 2632 n_proc=n_proc, 2633 ) 2634 final_results.append(result_df) 2635 2636 return pd.concat(final_results, 
ignore_index=True) 2637 2638 elif cells is None and isinstance(sets, dict): 2639 print("\nAnalysis started...\nComparing groups...") 2640 2641 choose.index = self.input_metadata["sets"] 2642 2643 group_list = list(sets.keys()) 2644 if len(group_list) != 2: 2645 print("Only pairwise group comparison is supported.") 2646 return None 2647 2648 labels = [ 2649 ( 2650 "target" 2651 if idx in sets[group_list[0]] 2652 else "rest" if idx in sets[group_list[1]] else "drop" 2653 ) 2654 for idx in choose.index 2655 ] 2656 choose["DEG"] = labels 2657 choose = choose[choose["DEG"] != "drop"] 2658 2659 result_df = prepare_and_run_stat( 2660 choose.reset_index(drop=True), 2661 valid_group=group_list[0], 2662 min_exp=min_exp, 2663 min_pct=min_pct, 2664 n_proc=n_proc, 2665 ) 2666 return result_df 2667 2668 elif isinstance(cells, dict) and sets is None: 2669 print("\nAnalysis started...\nComparing groups...") 2670 choose.index = ( 2671 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 2672 ) 2673 2674 if "#" not in cells[list(cells.keys())[0]][0]: 2675 choose.index = self.input_metadata["cell_names"] 2676 2677 print( 2678 "Not include the set info (name # set) in the 'cells' dict. " 2679 "Only the names will be compared, without considering the set information." 
2680 ) 2681 2682 group_list = list(cells.keys()) 2683 if len(group_list) != 2: 2684 print("Only pairwise group comparison is supported.") 2685 return None 2686 2687 labels = [ 2688 ( 2689 "target" 2690 if idx in cells[group_list[0]] 2691 else "rest" if idx in cells[group_list[1]] else "drop" 2692 ) 2693 for idx in choose.index 2694 ] 2695 2696 choose["DEG"] = labels 2697 choose = choose[choose["DEG"] != "drop"] 2698 2699 result_df = prepare_and_run_stat( 2700 choose.reset_index(drop=True), 2701 valid_group=group_list[0], 2702 min_exp=min_exp, 2703 min_pct=min_pct, 2704 n_proc=n_proc, 2705 ) 2706 2707 return result_df.reset_index(drop=True) 2708 2709 else: 2710 raise ValueError( 2711 "You must specify either 'cells' or 'sets' (or both). None were provided, which is not allowed for this analysis." 2712 ) 2713 2714 def calculate_difference_markers( 2715 self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False 2716 ): 2717 """ 2718 Compute differential markers (var_data) if not already present. 2719 2720 Parameters 2721 ---------- 2722 min_exp : float, default 0 2723 Minimum expression threshold passed to `statistic`. 2724 2725 min_pct : float, default 0.25 2726 Minimum percent expressed in target group. 2727 2728 n_proc : int, default 10 2729 Parallel jobs. 2730 2731 force : bool, default False 2732 If True, recompute even if `self.var_data` is present. 2733 2734 Update 2735 ------- 2736 Sets `self.var_data` to the result of `self.statistic(...)`. 2737 2738 Raise 2739 ------ 2740 ValueError if already computed and `force` is False. 2741 """ 2742 2743 if self.var_data is None or force: 2744 2745 self.var_data = self.statistic( 2746 cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc 2747 ) 2748 2749 else: 2750 raise ValueError( 2751 "self.calculate_difference_markers() has already been executed. " 2752 "The results are stored in self.var. " 2753 "If you want to recalculate with different statistics, please rerun the method with force=True." 
2754 ) 2755 2756 def clustering_features( 2757 self, 2758 features_list: list | None, 2759 name_slot: str = "cell_names", 2760 p_val: float = 0.05, 2761 top_n: int = 25, 2762 adj_mean: bool = True, 2763 beta: float = 0.2, 2764 ): 2765 """ 2766 Prepare clustering input by selecting marker features and optionally smoothing cell values 2767 toward group means. 2768 2769 Parameters 2770 ---------- 2771 features_list : list or None 2772 If provided, use this list of features. If None, features are selected 2773 from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group. 2774 2775 name_slot : str, default 'cell_names' 2776 Metadata column used for naming. 2777 2778 p_val : float, default 0.05 2779 Adjusted p-value cutoff when selecting features automatically. 2780 2781 top_n : int, default 25 2782 Number of top features per valid group to keep if `features_list` is None. 2783 2784 adj_mean : bool, default True 2785 If True, adjust cell values toward group means using `beta`. 2786 2787 beta : float, default 0.2 2788 Adjustment strength toward group mean. 2789 2790 Update 2791 ------ 2792 Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset, 2793 ready for PCA/UMAP/clustering. 2794 """ 2795 2796 if features_list is None or len(features_list) == 0: 2797 2798 if self.var_data is None: 2799 raise ValueError( 2800 "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first." 
2801 ) 2802 2803 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2804 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2805 df_tmp = ( 2806 df_tmp.sort_values( 2807 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2808 ) 2809 .groupby("valid_group") 2810 .head(top_n) 2811 ) 2812 2813 feaures_list = list(set(df_tmp["feature"])) 2814 2815 data = self.get_partial_data( 2816 names=None, features=feaures_list, name_slot=name_slot 2817 ) 2818 data_avg = average(data) 2819 2820 if adj_mean: 2821 data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta) 2822 2823 self.clustering_data = data 2824 2825 self.clustering_metadata = self.input_metadata 2826 2827 def average(self): 2828 """ 2829 Aggregate normalized data by (cell_name, set) pairs computing the mean per group. 2830 2831 The method constructs new column names as "cell_name # set", averages columns 2832 sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`. 2833 2834 Update 2835 ------ 2836 Sets `self.agg_normalized_data` (features x aggregated samples) and 2837 `self.agg_metadata` (DataFrame with 'cell_names' and 'sets'). 
2838 """ 2839 2840 wide_data = self.normalized_data 2841 2842 wide_metadata = self.input_metadata 2843 2844 new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"] 2845 2846 wide_data.columns = list(new_names) 2847 2848 aggregated_df = wide_data.T.groupby(level=0).mean().T 2849 2850 sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns] 2851 names = [re.sub(" #.*", "", x) for x in aggregated_df.columns] 2852 2853 aggregated_df.columns = names 2854 aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets}) 2855 2856 self.agg_metadata = aggregated_metadata 2857 self.agg_normalized_data = aggregated_df 2858 2859 def estimating_similarity( 2860 self, method="pearson", p_val: float = 0.05, top_n: int = 25 2861 ): 2862 """ 2863 Estimate pairwise similarity and Euclidean distance between aggregated samples. 2864 2865 Parameters 2866 ---------- 2867 method : str, default 'pearson' 2868 Correlation method to use (passed to pandas.DataFrame.corr()). 2869 2870 p_val : float, default 0.05 2871 Adjusted p-value cutoff used to select marker features from `self.var_data`. 2872 2873 top_n : int, default 25 2874 Number of top features per valid group to include. 2875 2876 Update 2877 ------- 2878 Computes a combined table with per-pair correlation and euclidean distance 2879 and stores it in `self.similarity`. 2880 """ 2881 2882 if self.var_data is None: 2883 raise ValueError( 2884 "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first." 
2885 ) 2886 2887 if self.agg_normalized_data is None: 2888 self.average() 2889 2890 metadata = self.agg_metadata 2891 data = self.agg_normalized_data 2892 2893 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2894 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2895 df_tmp = ( 2896 df_tmp.sort_values( 2897 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2898 ) 2899 .groupby("valid_group") 2900 .head(top_n) 2901 ) 2902 2903 data = data.loc[list(set(df_tmp["feature"]))] 2904 2905 if len(set(metadata["sets"])) > 1: 2906 data.columns = data.columns + " # " + [x for x in metadata["sets"]] 2907 else: 2908 data = data.copy() 2909 2910 scaler = StandardScaler() 2911 2912 scaled_data = scaler.fit_transform(data) 2913 2914 scaled_df = pd.DataFrame(scaled_data, columns=data.columns) 2915 2916 cor = scaled_df.corr(method=method) 2917 cor_df = cor.stack().reset_index() 2918 cor_df.columns = ["cell1", "cell2", "correlation"] 2919 2920 distances = pdist(scaled_df.T, metric="euclidean") 2921 dist_mat = pd.DataFrame( 2922 squareform(distances), index=scaled_df.columns, columns=scaled_df.columns 2923 ) 2924 dist_df = dist_mat.stack().reset_index() 2925 dist_df.columns = ["cell1", "cell2", "euclidean_dist"] 2926 2927 full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"]) 2928 2929 full = full[full["cell1"] != full["cell2"]] 2930 full = full.reset_index(drop=True) 2931 2932 self.similarity = full 2933 2934 def similarity_plot( 2935 self, 2936 split_sets=True, 2937 set_info: bool = True, 2938 cmap="seismic", 2939 width=12, 2940 height=10, 2941 ): 2942 """ 2943 Visualize pairwise similarity as a scatter plot. 2944 2945 Parameters 2946 ---------- 2947 split_sets : bool, default True 2948 If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs. 2949 2950 set_info : bool, default True 2951 If True, keep the ' # set' annotation in labels; otherwise strip it. 
2952 2953 cmap : str, default 'seismic' 2954 Color map for correlation (hue). 2955 2956 width : int, default 12 2957 Figure width. 2958 2959 height : int, default 10 2960 Figure height. 2961 2962 Returns 2963 ------- 2964 matplotlib.figure.Figure 2965 2966 Raises 2967 ------ 2968 ValueError 2969 If `self.similarity` is None. 2970 2971 Notes 2972 ----- 2973 The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs. 2974 """ 2975 2976 if self.similarity is None: 2977 raise ValueError( 2978 "Similarity data is missing. Please calculate similarity using self.estimating_similarity." 2979 ) 2980 2981 similarity_data = self.similarity 2982 2983 if " # " in similarity_data["cell1"][0]: 2984 similarity_data["set1"] = [ 2985 re.sub(".*# ", "", x) for x in similarity_data["cell1"] 2986 ] 2987 similarity_data["set2"] = [ 2988 re.sub(".*# ", "", x) for x in similarity_data["cell2"] 2989 ] 2990 2991 if split_sets and " # " in similarity_data["cell1"][0]: 2992 sets = list( 2993 set(list(similarity_data["set1"]) + list(similarity_data["set2"])) 2994 ) 2995 2996 mm = math.ceil(len(sets) / 2) 2997 2998 x_s = sets[0:mm] 2999 y_s = sets[mm : len(sets)] 3000 3001 similarity_data = similarity_data[similarity_data["set1"].isin(x_s)] 3002 similarity_data = similarity_data[similarity_data["set2"].isin(y_s)] 3003 3004 similarity_data = similarity_data.sort_values(["set1", "set2"]) 3005 3006 if set_info is False and " # " in similarity_data["cell1"][0]: 3007 similarity_data["cell1"] = [ 3008 re.sub(" #.*", "", x) for x in similarity_data["cell1"] 3009 ] 3010 similarity_data["cell2"] = [ 3011 re.sub(" #.*", "", x) for x in similarity_data["cell2"] 3012 ] 3013 3014 similarity_data["-euclidean_zscore"] = -zscore( 3015 similarity_data["euclidean_dist"] 3016 ) 3017 3018 similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0] 3019 3020 fig = plt.figure(figsize=(width, height)) 3021 sns.scatterplot( 3022 data=similarity_data, 3023 x="cell1", 3024 
y="cell2", 3025 hue="correlation", 3026 size="-euclidean_zscore", 3027 sizes=(1, 100), 3028 palette=cmap, 3029 alpha=1, 3030 edgecolor="black", 3031 ) 3032 3033 plt.xticks(rotation=90) 3034 plt.yticks(rotation=0) 3035 plt.xlabel("Cell 1") 3036 plt.ylabel("Cell 2") 3037 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") 3038 3039 plt.grid(True, alpha=0.6) 3040 3041 plt.tight_layout() 3042 3043 return fig 3044 3045 def spatial_similarity( 3046 self, 3047 set_info: bool = True, 3048 bandwidth=1, 3049 n_neighbors=5, 3050 min_dist=0.1, 3051 legend_split=2, 3052 point_size=100, 3053 spread=1.0, 3054 set_op_mix_ratio=1.0, 3055 local_connectivity=1, 3056 repulsion_strength=1.0, 3057 negative_sample_rate=5, 3058 threshold=0.1, 3059 width=12, 3060 height=10, 3061 ): 3062 """ 3063 Create a spatial UMAP-like visualization of similarity relationships between samples. 3064 3065 Parameters 3066 ---------- 3067 set_info : bool, default True 3068 If True, retain set information in labels. 3069 3070 bandwidth : float, default 1 3071 Bandwidth used by MeanShift for clustering polygons. 3072 3073 point_size : float, default 100 3074 Size of scatter points. 3075 3076 legend_split : int, default 2 3077 Number of columns in legend. 3078 3079 n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP. 3080 3081 threshold : float, default 0.1 3082 Minimum text distance for label adjustment to avoid overlap. 3083 3084 width : int, default 12 3085 Figure width. 3086 3087 height : int, default 10 3088 Figure height. 3089 3090 Returns 3091 ------- 3092 matplotlib.figure.Figure 3093 3094 Raises 3095 ------ 3096 ValueError 3097 If `self.similarity` is None. 
3098 3099 Notes 3100 ----- 3101 Builds a precomputed distance matrix combining correlation and euclidean distance, 3102 runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull) 3103 and arrows to indicate nearest neighbors (minimal combined distance). 3104 """ 3105 3106 if self.similarity is None: 3107 raise ValueError( 3108 "Similarity data is missing. Please calculate similarity using self.estimating_similarity." 3109 ) 3110 3111 similarity_data = self.similarity 3112 3113 sim = similarity_data["correlation"] 3114 sim_scaled = (sim - sim.min()) / (sim.max() - sim.min()) 3115 eu_dist = similarity_data["euclidean_dist"] 3116 eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min()) 3117 3118 similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled 3119 3120 # for nn target 3121 arrow_df = similarity_data.copy() 3122 arrow_df = similarity_data.loc[ 3123 similarity_data.groupby("cell1")["combo_dist"].idxmin() 3124 ] 3125 3126 cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"])) 3127 combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float) 3128 3129 for _, row in similarity_data.iterrows(): 3130 combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"] 3131 combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"] 3132 3133 umap_model = umap.UMAP( 3134 n_components=2, 3135 metric="precomputed", 3136 n_neighbors=n_neighbors, 3137 min_dist=min_dist, 3138 spread=spread, 3139 set_op_mix_ratio=set_op_mix_ratio, 3140 local_connectivity=set_op_mix_ratio, 3141 repulsion_strength=repulsion_strength, 3142 negative_sample_rate=negative_sample_rate, 3143 transform_seed=42, 3144 init="spectral", 3145 random_state=42, 3146 verbose=True, 3147 ) 3148 3149 coords = umap_model.fit_transform(combo_matrix.values) 3150 cell_names = list(combo_matrix.index) 3151 num_cells = len(cell_names) 3152 palette = sns.color_palette("tab20c", num_cells) 3153 3154 if "#" in 
cell_names[0]: 3155 avsets = set( 3156 [re.sub(".*# ", "", x) for x in similarity_data["cell1"]] 3157 + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]] 3158 ) 3159 num_sets = len(avsets) 3160 color_indices = [i * len(palette) // num_sets for i in range(num_sets)] 3161 color_mapping_sets = { 3162 set_name: palette[i] for i, set_name in zip(color_indices, avsets) 3163 } 3164 color_mapping = { 3165 name: color_mapping_sets[re.sub(".*# ", "", name)] 3166 for i, name in enumerate(cell_names) 3167 } 3168 else: 3169 color_mapping = {name: palette[i] for i, name in enumerate(cell_names)} 3170 3171 meanshift = MeanShift(bandwidth=bandwidth) 3172 labels = meanshift.fit_predict(coords) 3173 3174 fig = plt.figure(figsize=(width, height)) 3175 ax = plt.gca() 3176 3177 unique_labels = set(labels) 3178 cluster_palette = sns.color_palette("hls", len(unique_labels)) 3179 3180 for label in unique_labels: 3181 if label == -1: 3182 continue 3183 cluster_coords = coords[labels == label] 3184 if len(cluster_coords) < 3: 3185 continue 3186 3187 hull = ConvexHull(cluster_coords) 3188 hull_points = cluster_coords[hull.vertices] 3189 3190 centroid = np.mean(hull_points, axis=0) 3191 expanded = hull_points + 0.05 * (hull_points - centroid) 3192 3193 poly = Polygon( 3194 expanded, 3195 closed=True, 3196 facecolor=cluster_palette[label], 3197 edgecolor="none", 3198 alpha=0.2, 3199 zorder=1, 3200 ) 3201 ax.add_patch(poly) 3202 3203 texts = [] 3204 for i, (x, y) in enumerate(coords): 3205 plt.scatter( 3206 x, 3207 y, 3208 s=point_size, 3209 color=color_mapping[cell_names[i]], 3210 edgecolors="black", 3211 linewidths=0.5, 3212 zorder=2, 3213 ) 3214 texts.append( 3215 ax.text( 3216 x, y, str(i), ha="center", va="center", fontsize=8, color="black" 3217 ) 3218 ) 3219 3220 def dist(p1, p2): 3221 return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) 3222 3223 texts_to_adjust = [] 3224 for i, t1 in enumerate(texts): 3225 for j, t2 in enumerate(texts): 3226 if i >= j: 3227 continue 3228 
d = dist( 3229 (t1.get_position()[0], t1.get_position()[1]), 3230 (t2.get_position()[0], t2.get_position()[1]), 3231 ) 3232 if d < threshold: 3233 if t1 not in texts_to_adjust: 3234 texts_to_adjust.append(t1) 3235 if t2 not in texts_to_adjust: 3236 texts_to_adjust.append(t2) 3237 3238 adjust_text( 3239 texts_to_adjust, 3240 expand_text=(1.0, 1.0), 3241 force_text=0.9, 3242 arrowprops=dict(arrowstyle="-", color="gray", lw=0.1), 3243 ax=ax, 3244 ) 3245 3246 for _, row in arrow_df.iterrows(): 3247 try: 3248 idx1 = cell_names.index(row["cell1"]) 3249 idx2 = cell_names.index(row["cell2"]) 3250 except ValueError: 3251 continue 3252 x1, y1 = coords[idx1] 3253 x2, y2 = coords[idx2] 3254 arrow = FancyArrowPatch( 3255 (x1, y1), 3256 (x2, y2), 3257 arrowstyle="->", 3258 color="gray", 3259 linewidth=1.5, 3260 alpha=0.5, 3261 mutation_scale=12, 3262 zorder=0, 3263 ) 3264 ax.add_patch(arrow) 3265 3266 if set_info is False and " # " in cell_names[0]: 3267 3268 legend_elements = [ 3269 Patch( 3270 facecolor=color_mapping[name], 3271 edgecolor="black", 3272 label=f"{i} – {re.sub(' #.*', '', name)}", 3273 ) 3274 for i, name in enumerate(cell_names) 3275 ] 3276 3277 else: 3278 3279 legend_elements = [ 3280 Patch( 3281 facecolor=color_mapping[name], 3282 edgecolor="black", 3283 label=f"{i} – {name}", 3284 ) 3285 for i, name in enumerate(cell_names) 3286 ] 3287 3288 plt.legend( 3289 handles=legend_elements, 3290 title="Cells", 3291 bbox_to_anchor=(1.05, 1), 3292 loc="upper left", 3293 ncol=legend_split, 3294 ) 3295 3296 plt.xlabel("UMAP 1") 3297 plt.ylabel("UMAP 2") 3298 plt.grid(False) 3299 plt.show() 3300 3301 return fig 3302 3303 # subclusters part 3304 3305 def subcluster_prepare(self, features: list, cluster: str): 3306 """ 3307 Prepare a `Clustering` object for subcluster analysis on a selected parent cluster. 3308 3309 Parameters 3310 ---------- 3311 features : list 3312 Features to include for subcluster analysis. 
3313 3314 cluster : str 3315 Parent cluster name (used to select matching cells). 3316 3317 Update 3318 ------ 3319 Initializes `self.subclusters_` as a new `Clustering` instance containing the 3320 reduced data for the given cluster and stores `current_features` and `current_cluster`. 3321 """ 3322 3323 dat = self.normalized_data 3324 dat.columns = list(self.input_metadata["cell_names"]) 3325 3326 dat = reduce_data(self.normalized_data, features=features, names=[cluster]) 3327 3328 self.subclusters_ = Clustering(data=dat, metadata=None) 3329 3330 self.subclusters_.current_features = features 3331 self.subclusters_.current_cluster = cluster 3332 3333 def define_subclusters( 3334 self, 3335 umap_num: int = 2, 3336 eps: float = 0.5, 3337 min_samples: int = 10, 3338 n_neighbors: int = 5, 3339 min_dist: float = 0.1, 3340 spread: float = 1.0, 3341 set_op_mix_ratio: float = 1.0, 3342 local_connectivity: int = 1, 3343 repulsion_strength: float = 1.0, 3344 negative_sample_rate: int = 5, 3345 width=8, 3346 height=6, 3347 ): 3348 """ 3349 Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset. 3350 3351 Parameters 3352 ---------- 3353 umap_num : int, default 2 3354 Number of UMAP dimensions to compute. 3355 3356 eps : float, default 0.5 3357 DBSCAN eps parameter. 3358 3359 min_samples : int, default 10 3360 DBSCAN min_samples parameter. 3361 3362 n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height : 3363 Additional parameters passed to UMAP / plotting / MeanShift as appropriate. 3364 3365 Update 3366 ------- 3367 Stores cluster labels in `self.subclusters_.subclusters`. 3368 3369 Raises 3370 ------ 3371 RuntimeError 3372 If `self.subclusters_` has not been prepared. 3373 """ 3374 3375 if self.subclusters_ is None: 3376 raise RuntimeError( 3377 "Nothing to return. 'self.subcluster_prepare' was not conducted!" 
3378 ) 3379 3380 self.subclusters_.perform_UMAP( 3381 factorize=False, 3382 umap_num=umap_num, 3383 pc_num=0, 3384 harmonized=False, 3385 n_neighbors=n_neighbors, 3386 min_dist=min_dist, 3387 spread=spread, 3388 set_op_mix_ratio=set_op_mix_ratio, 3389 local_connectivity=local_connectivity, 3390 repulsion_strength=repulsion_strength, 3391 negative_sample_rate=negative_sample_rate, 3392 width=width, 3393 height=height, 3394 ) 3395 3396 fig = self.subclusters_.find_clusters_UMAP( 3397 umap_n=umap_num, 3398 eps=eps, 3399 min_samples=min_samples, 3400 width=width, 3401 height=height, 3402 ) 3403 3404 clusters = self.subclusters_.return_clusters(clusters="umap") 3405 3406 self.subclusters_.subclusters = [str(x) for x in list(clusters)] 3407 3408 return fig 3409 3410 def subcluster_features_scatter( 3411 self, 3412 colors="viridis", 3413 hclust="complete", 3414 scale=False, 3415 img_width=3, 3416 img_high=5, 3417 label_size=6, 3418 size_scale=70, 3419 y_lab="Genes", 3420 legend_lab="normalized", 3421 bbox_to_anchor_scale: int = 25, 3422 bbox_to_anchor_perc: tuple = (0.91, 0.63), 3423 ): 3424 """ 3425 Create a features-scatter visualization for the subclusters (averaged and occurrence). 3426 3427 Parameters 3428 ---------- 3429 colors : str, default 'viridis' 3430 Colormap name passed to `features_scatter`. 3431 3432 hclust : str or None 3433 Hierarchical clustering linkage to order rows/columns. 3434 3435 scale: bool, default False 3436 If True, expression data will be scaled (0–1) across the rows (features). 3437 3438 img_width, img_high : float 3439 Figure size. 3440 3441 label_size : int 3442 Font size for labels. 3443 3444 size_scale : int 3445 Bubble size scaling. 3446 3447 y_lab : str 3448 X axis label. 3449 3450 legend_lab : str 3451 Colorbar label. 3452 3453 bbox_to_anchor_scale : int, default=25 3454 Vertical scale (percentage) for positioning the colorbar. 
3455 3456 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3457 Anchor position for the size legend (percent bubble legend). 3458 3459 Returns 3460 ------- 3461 matplotlib.figure.Figure 3462 3463 Raises 3464 ------ 3465 RuntimeError 3466 If subcluster preparation/definition has not been run. 3467 """ 3468 3469 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3470 raise RuntimeError( 3471 "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!" 3472 ) 3473 3474 dat = self.normalized_data 3475 dat.columns = list(self.input_metadata["cell_names"]) 3476 3477 dat = reduce_data( 3478 self.normalized_data, 3479 features=self.subclusters_.current_features, 3480 names=[self.subclusters_.current_cluster], 3481 ) 3482 3483 dat.columns = self.subclusters_.subclusters 3484 3485 avg = average(dat) 3486 occ = occurrence(dat) 3487 3488 scatter = features_scatter( 3489 expression_data=avg, 3490 occurence_data=occ, 3491 features=None, 3492 scale=scale, 3493 metadata_list=None, 3494 colors=colors, 3495 hclust=hclust, 3496 img_width=img_width, 3497 img_high=img_high, 3498 label_size=label_size, 3499 size_scale=size_scale, 3500 y_lab=y_lab, 3501 legend_lab=legend_lab, 3502 bbox_to_anchor_scale=bbox_to_anchor_scale, 3503 bbox_to_anchor_perc=bbox_to_anchor_perc, 3504 ) 3505 3506 return scatter 3507 3508 def subcluster_DEG_scatter( 3509 self, 3510 top_n=3, 3511 min_exp=0, 3512 min_pct=0.25, 3513 p_val=0.05, 3514 colors="viridis", 3515 hclust="complete", 3516 scale=False, 3517 img_width=3, 3518 img_high=5, 3519 label_size=6, 3520 size_scale=70, 3521 y_lab="Genes", 3522 legend_lab="normalized", 3523 bbox_to_anchor_scale: int = 25, 3524 bbox_to_anchor_perc: tuple = (0.91, 0.63), 3525 n_proc=10, 3526 ): 3527 """ 3528 Plot top differential features (DEGs) for subclusters as a features-scatter. 3529 3530 Parameters 3531 ---------- 3532 top_n : int, default 3 3533 Number of top features per subcluster to show. 
3534 3535 min_exp : float, default 0 3536 Minimum expression threshold passed to `statistic`. 3537 3538 min_pct : float, default 0.25 3539 Minimum percent expressed in target group. 3540 3541 p_val: float, default 0.05 3542 Maximum p-value for visualizing features. 3543 3544 n_proc : int, default 10 3545 Parallel jobs used for DEG calculation. 3546 3547 scale: bool, default False 3548 If True, expression_data will be scaled (0–1) across the rows (features). 3549 3550 colors : str, default='viridis' 3551 Colormap for expression values. 3552 3553 hclust : str or None, default='complete' 3554 Linkage method for hierarchical clustering. If None, no clustering 3555 is performed. 3556 3557 img_width : int or float, default=8 3558 Width of the plot in inches. 3559 3560 img_high : int or float, default=5 3561 Height of the plot in inches. 3562 3563 label_size : int, default=10 3564 Font size for axis labels and ticks. 3565 3566 size_scale : int or float, default=100 3567 Scaling factor for bubble sizes. 3568 3569 y_lab : str, default='Genes' 3570 Label for the x-axis. 3571 3572 legend_lab : str, default='normalized' 3573 Label for the colorbar legend. 3574 3575 bbox_to_anchor_scale : int, default=25 3576 Vertical scale (percentage) for positioning the colorbar. 3577 3578 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3579 Anchor position for the size legend (percent bubble legend). 3580 3581 Returns 3582 ------- 3583 matplotlib.figure.Figure 3584 3585 Raises 3586 ------ 3587 RuntimeError 3588 If subcluster preparation/definition has not been run. 3589 3590 Notes 3591 ----- 3592 Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters 3593 by p-value and effect-size, selects top features per valid group and plots them. 3594 """ 3595 3596 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3597 raise RuntimeError( 3598 "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!" 
3599 ) 3600 3601 dat = self.normalized_data 3602 dat.columns = list(self.input_metadata["cell_names"]) 3603 3604 dat = reduce_data( 3605 self.normalized_data, names=[self.subclusters_.current_cluster] 3606 ) 3607 3608 dat.columns = self.subclusters_.subclusters 3609 3610 deg_stats = calc_DEG( 3611 dat, 3612 metadata_list=None, 3613 entities="All", 3614 sets=None, 3615 min_exp=min_exp, 3616 min_pct=min_pct, 3617 n_proc=n_proc, 3618 ) 3619 3620 deg_stats = deg_stats[deg_stats["p_val"] <= p_val] 3621 deg_stats = deg_stats[deg_stats["log(FC)"] > 0] 3622 3623 deg_stats = ( 3624 deg_stats.sort_values( 3625 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 3626 ) 3627 .groupby("valid_group") 3628 .head(top_n) 3629 ) 3630 3631 dat = reduce_data(dat, features=list(set(deg_stats["feature"]))) 3632 3633 avg = average(dat) 3634 occ = occurrence(dat) 3635 3636 scatter = features_scatter( 3637 expression_data=avg, 3638 occurence_data=occ, 3639 features=None, 3640 metadata_list=None, 3641 colors=colors, 3642 hclust=hclust, 3643 img_width=img_width, 3644 img_high=img_high, 3645 label_size=label_size, 3646 size_scale=size_scale, 3647 y_lab=y_lab, 3648 legend_lab=legend_lab, 3649 bbox_to_anchor_scale=bbox_to_anchor_scale, 3650 bbox_to_anchor_perc=bbox_to_anchor_perc, 3651 ) 3652 3653 return scatter 3654 3655 def accept_subclusters(self): 3656 """ 3657 Commit subcluster labels into the main `input_metadata` by renaming cell names. 3658 3659 The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']` 3660 with the expanded names that include subcluster suffixes (via `add_subnames`), 3661 then clears `self.subclusters_`. 3662 3663 Update 3664 ------ 3665 Modifies `self.input_metadata['cell_names']`. 3666 3667 Resets `self.subclusters_` to None. 3668 3669 Raises 3670 ------ 3671 RuntimeError 3672 If `self.subclusters_` is not defined or subclusters were not computed. 
    def scatter_plot(
        self,
        names: list | None = None,
        features: list | None = None,
        name_slot: str = "cell_names",
        scale=True,
        colors="viridis",
        hclust=None,
        img_width=15,
        img_high=1,
        label_size=10,
        size_scale=200,
        y_lab="Genes",
        legend_lab="log(CPM + 1)",
        set_box_size: float | int = 5,
        set_box_high: float | int = 5,
        bbox_to_anchor_scale=25,
        bbox_to_anchor_perc=(0.90, -0.24),
        bbox_to_anchor_group=(1.01, 0.4),
    ):
        """
        Create a bubble scatter plot of selected features across samples inside project.

        Each point represents a feature-sample pair, where the color encodes the
        expression value and the size encodes occurrence or relative abundance.
        Optionally, hierarchical clustering can be applied to order rows and columns.

        Parameters
        ----------
        names : list, str, or None
            Names of samples to include. If None, all samples are considered.

        features : list, str, or None
            Names of features to include. If None, all features are considered.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        scale : bool, default True
            If True, expression data will be scaled (0–1) across the rows (features).

        colors : str, default 'viridis'
            Colormap for expression values.

        hclust : str or None, default None
            Linkage method for hierarchical clustering. If None, no clustering
            is performed.

        img_width : int or float, default 15
            Width of the plot in inches.

        img_high : int or float, default 1
            Height of the plot in inches.

        label_size : int, default 10
            Font size for axis labels and ticks.

        size_scale : int or float, default 200
            Scaling factor for bubble sizes.

        y_lab : str, default 'Genes'
            Label for the x-axis.

        legend_lab : str, default 'log(CPM + 1)'
            Label for the colorbar legend.

        set_box_size, set_box_high : float or int, default 5
            Size of the set-annotation boxes passed to `features_scatter`.

        bbox_to_anchor_scale : int, default 25
            Vertical scale (percentage) for positioning the colorbar.

        bbox_to_anchor_perc : tuple, default (0.90, -0.24)
            Anchor position for the size legend (percent bubble legend).

        bbox_to_anchor_group : tuple, default (1.01, 0.4)
            Anchor position for the group legend.

        Returns
        -------
        matplotlib.figure.Figure
            The generated scatter plot figure.

        Notes
        -----
        Colors represent expression values normalized to the colormap.
        """

        prtd, met = self.get_partial_data(
            names=names, features=features, name_slot=name_slot, inc_metadata=True
        )

        # Tag every column with its set ("name#set") so averaging/occurrence are
        # computed per (sample, set) pair; the tags are stripped again below.
        prtd.columns = prtd.columns + "#" + met["sets"]

        prtd_avg = average(prtd)

        # Recover the set annotation for each aggregated column, then strip tags.
        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]

        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]

        prtd_occ = occurrence(prtd)

        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]

        fig_scatter = features_scatter(
            expression_data=prtd_avg,
            occurence_data=prtd_occ,
            scale=scale,
            features=None,
            metadata_list=meta_sets,
            colors=colors,
            hclust=hclust,
            img_width=img_width,
            img_high=img_high,
            label_size=label_size,
            size_scale=size_scale,
            y_lab=y_lab,
            legend_lab=legend_lab,
            set_box_size=set_box_size,
            set_box_high=set_box_high,
            bbox_to_anchor_scale=bbox_to_anchor_scale,
            bbox_to_anchor_perc=bbox_to_anchor_perc,
            bbox_to_anchor_group=bbox_to_anchor_group,
        )

        return fig_scatter
3773 """ 3774 3775 prtd, met = self.get_partial_data( 3776 names=names, features=features, name_slot=name_slot, inc_metadata=True 3777 ) 3778 3779 prtd.columns = prtd.columns + "#" + met["sets"] 3780 3781 prtd_avg = average(prtd) 3782 3783 meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns] 3784 3785 prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns] 3786 3787 prtd_occ = occurrence(prtd) 3788 3789 prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns] 3790 3791 fig_scatter = features_scatter( 3792 expression_data=prtd_avg, 3793 occurence_data=prtd_occ, 3794 scale=scale, 3795 features=None, 3796 metadata_list=meta_sets, 3797 colors=colors, 3798 hclust=hclust, 3799 img_width=img_width, 3800 img_high=img_high, 3801 label_size=label_size, 3802 size_scale=size_scale, 3803 y_lab=y_lab, 3804 legend_lab=legend_lab, 3805 set_box_size=set_box_size, 3806 set_box_high=set_box_high, 3807 bbox_to_anchor_scale=bbox_to_anchor_scale, 3808 bbox_to_anchor_perc=bbox_to_anchor_perc, 3809 bbox_to_anchor_group=bbox_to_anchor_group, 3810 ) 3811 3812 return fig_scatter 3813 3814 def data_composition( 3815 self, 3816 features_count: list | None, 3817 name_slot: str = "cell_names", 3818 set_sep: bool = True, 3819 ): 3820 """ 3821 Compute composition of cell types in data set. 3822 3823 This function counts the occurrences of specific cells (e.g., cell types, subtypes) 3824 within metadata entries, calculates their relative percentages, and stores 3825 the results in `self.composition_data`. 3826 3827 Parameters 3828 ---------- 3829 features_count : list or None 3830 List of features (part or full names) to be counted. 3831 If None, all unique elements from the specified `name_slot` metadata field are used. 3832 3833 name_slot : str, default 'cell_names' 3834 Metadata field containing sample identifiers or labels. 
    def composition_pie(
        self,
        width=6,
        height=6,
        font_size=15,
        cmap: str = "tab20",
        legend_split_col: int = 1,
        offset_labels: float | int = 0.5,
        legend_bbox: tuple = (1.15, 0.95),
    ):
        """
        Visualize the composition of cell lineages using pie (donut) charts.

        Generates pie charts showing the relative proportions of features stored
        in `self.composition_data`. If multiple sets are present, a separate
        chart is drawn for each set, stacked vertically.

        Parameters
        ----------
        width : int, default 6
            Width of the figure.

        height : int, default 6
            Height of the figure (applied per set if multiple sets are plotted).

        font_size : int, default 15
            Font size for labels and annotations.

        cmap : str, default 'tab20'
            Colormap used for pie slices.

        legend_split_col : int, default 1
            Number of columns in the legend.

        offset_labels : float or int, default 0.5
            Spacing offset for label placement relative to pie slices.

        legend_bbox : tuple, default (1.15, 0.95)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Pie chart visualization of composition data.
        """

        df = self.composition_data

        # Multi-set branch: one donut per set, sharing a single color mapping.
        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))

            all_wedges = []
            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))

            # NOTE(review): legend order comes from set() iteration, so it is
            # not deterministic across runs — confirm whether ordering matters.
            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]

            # One fixed color per feature name, shared by every set's chart.
            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                ax = axes[idx]
                tmp_df = df[df["set"] == s].reset_index(drop=True)

                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]

                wedges, _ = ax.pie(
                    tmp_df["n"],
                    startangle=90,
                    labeldistance=1.05,
                    colors=[cmap_dict[x] for x in tmp_df["name"]],
                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
                )

                all_wedges.extend(wedges)

                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
                n = 0
                for i, p in enumerate(wedges):
                    # Mid-angle of the wedge -> unit-circle point for the leader line.
                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
                    y = np.sin(np.deg2rad(ang))
                    x = np.cos(np.deg2rad(ang))
                    horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                    connectionstyle = f"angle,angleA=0,angleB={ang}"
                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
                    if len(labels[i]) > 0:
                        # Push each successive label farther out to reduce overlap.
                        n += offset_labels
                        ax.annotate(
                            labels[i],
                            xy=(x, y),
                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
                            horizontalalignment=horizontalalignment,
                            fontsize=font_size,
                            weight="bold",
                            **kw,
                        )

                # White circle over the center turns the pie into a donut.
                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
                ax.add_artist(circle2)

                # Set name printed in the donut hole.
                ax.text(
                    0,
                    0,
                    f"{s}",
                    ha="center",
                    va="center",
                    fontsize=font_size,
                    weight="bold",
                )

            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            # Single-set branch: one donut, legend built from row order.
            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]

            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]

            cmap = plt.get_cmap(cmap)
            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(
                figsize=(width, height), subplot_kw=dict(aspect="equal")
            )

            wedges, _ = ax.pie(
                df["n"],
                startangle=90,
                labeldistance=1.05,
                colors=colors,
                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
            )

            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
            n = 0
            for i, p in enumerate(wedges):
                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
                y = np.sin(np.deg2rad(ang))
                x = np.cos(np.deg2rad(ang))
                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                connectionstyle = "angle,angleA=0,angleB={}".format(ang)
                kw["arrowprops"].update({"connectionstyle": connectionstyle})
                if len(labels[i]) > 0:
                    n += offset_labels

                    ax.annotate(
                        labels[i],
                        xy=(x, y),
                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
                        horizontalalignment=horizontalalignment,
                        fontsize=font_size,
                        weight="bold",
                        **kw,
                    )

            circle2 = plt.Circle((0, 0), 0.6, color="white")
            circle2.set_edgecolor("black")

            p = plt.gcf()
            p.gca().add_artist(circle2)

            ax.legend(
                wedges,
                legend_labels,
                title="",
                loc="center left",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
            )

            plt.show()

        return fig
    def bar_composition(
        self,
        cmap="tab20b",
        width=2,
        height=6,
        font_size=15,
        legend_split_col: int = 1,
        legend_bbox: tuple = (1.3, 1),
    ):
        """
        Visualize the composition of cell lineages using stacked bar plots.

        Produces bar plots showing the distribution of features stored in
        `self.composition_data`. If multiple sets are present, a separate
        bar is drawn for each set. Percentages are annotated alongside the bars.

        Parameters
        ----------
        cmap : str, default 'tab20b'
            Colormap used for stacked bars.

        width : int, default 2
            Width of each subplot (per set).

        height : int, default 6
            Height of the figure.

        font_size : int, default 15
            Font size for labels and annotations.

        legend_split_col : int, default 1
            Number of columns in the legend.

        legend_bbox : tuple, default (1.3, 1)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Stacked bar plot visualization of composition data.
        """

        df = self.composition_data
        # NOTE(review): this adds a 'num' column to self.composition_data in
        # place — the stored table is mutated by plotting.
        df["num"] = range(1, len(df) + 1)

        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))

            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))

            # NOTE(review): legend order comes from set() iteration and is
            # therefore nondeterministic across runs.
            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]

            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                ax = axes[idx]

                tmp_df = df[df["set"] == s].reset_index(drop=True)

                # Re-normalize counts to percentages for this set.
                values = tmp_df["n"].values
                total = sum(values)
                values = [v / total * 100 for v in values]
                values = [round(v, 2) for v in values]

                # Rounding can make the stack miss 100 exactly; dump the
                # residual into the largest segment so bars align.
                idx_max = np.argmax(values)
                correction = 100 - sum(values)
                values[idx_max] += correction

                names = tmp_df["name"].values
                perc = tmp_df["pct"].values
                nums = tmp_df["num"].values

                bottom = 0
                centers = []
                # NOTE(review): `color` from zip is unused here — segment color
                # actually comes from cmap_dict[name].
                for name, num, val, color in zip(names, nums, values, colors):
                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
                    centers.append(bottom + val / 2)
                    bottom += val

                # Spread labels evenly while arrows point to segment centers.
                y_positions = np.linspace(centers[0], centers[-1], len(centers))
                x_text = -0.8

                for y_label, y_center, pct, num in zip(
                    y_positions, centers, perc, nums
                ):
                    ax.annotate(
                        f"{pct:.1f}%",
                        xy=(0, y_center),
                        xycoords="data",
                        xytext=(x_text, y_label),
                        textcoords="data",
                        ha="right",
                        va="center",
                        fontsize=font_size,
                        arrowprops=dict(
                            arrowstyle="->",
                            lw=1,
                            color="black",
                            connectionstyle="angle3,angleA=0,angleB=90",
                        ),
                    )

                ax.set_ylim(0, 100)
                ax.set_xlabel(s, fontsize=font_size)
                ax.xaxis.label.set_rotation(30)

                # Hide axes chrome; the bar itself carries the information.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            cmap = plt.get_cmap(cmap)

            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(figsize=(width, height))

            # Single bar stacked from raw counts (no 0-100 normalization here).
            values = df["n"].values
            names = df["name"].values
            perc = df["pct"].values
            nums = df["num"].values

            bottom = 0
            centers = []
            for name, num, val, color in zip(names, nums, values, colors):
                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
                centers.append(bottom + val / 2)
                bottom += val

            y_positions = np.linspace(centers[0], centers[-1], len(centers))
            x_text = -0.8

            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
                ax.annotate(
                    f"{num}) {pct}",
                    xy=(0, y_center),
                    xycoords="data",
                    xytext=(x_text, y_label),
                    textcoords="data",
                    ha="right",
                    va="center",
                    fontsize=9,
                    arrowprops=dict(
                        arrowstyle="->",
                        lw=1,
                        color="black",
                        connectionstyle="angle3,angleA=0,angleB=90",
                    ),
                )

            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            ax.legend(
                title="Legend",
                bbox_to_anchor=legend_bbox,
                loc="upper left",
                ncol=legend_split_col,
            )

            plt.tight_layout()
            plt.show()

        return fig
f"{num}) {pct}", 4253 xy=(0, y_center), 4254 xycoords="data", 4255 xytext=(x_text, y_label), 4256 textcoords="data", 4257 ha="right", 4258 va="center", 4259 fontsize=9, 4260 arrowprops=dict( 4261 arrowstyle="->", 4262 lw=1, 4263 color="black", 4264 connectionstyle="angle3,angleA=0,angleB=90", 4265 ), 4266 ) 4267 4268 ax.set_xticks([]) 4269 ax.set_yticks([]) 4270 for spine in ax.spines.values(): 4271 spine.set_visible(False) 4272 4273 ax.legend( 4274 title="Legend", 4275 bbox_to_anchor=legend_bbox, 4276 loc="upper left", 4277 ncol=legend_split_col, 4278 ) 4279 4280 plt.tight_layout() 4281 plt.show() 4282 4283 return fig 4284 4285 def cell_regression( 4286 self, 4287 cell_x: str, 4288 cell_y: str, 4289 set_x: str | None, 4290 set_y: str | None, 4291 threshold=10, 4292 image_width=12, 4293 image_high=7, 4294 color="black", 4295 ): 4296 """ 4297 Perform regression analysis between two selected cells and visualize the relationship. 4298 4299 This function computes a linear regression between two specified cells from 4300 aggregated normalized data, plots the regression line with scatter points, 4301 annotates regression statistics, and highlights potential outliers. 4302 4303 Parameters 4304 ---------- 4305 cell_x : str 4306 Name of the first cell (X-axis). 4307 4308 cell_y : str 4309 Name of the second cell (Y-axis). 4310 4311 set_x : str or None 4312 Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name. 4313 4314 set_y : str or None 4315 Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name. 4316 4317 threshold : int or float, default 10 4318 Threshold for detecting outliers. Points deviating from the mean or diagonal by more 4319 than this value are annotated. 4320 4321 image_width : int, default 12 4322 Width of the regression plot (in inches). 4323 4324 image_high : int, default 7 4325 Height of the regression plot (in inches). 
4326 4327 color : str, default 'black' 4328 Color of the regression scatter points and line. 4329 4330 Returns 4331 ------- 4332 matplotlib.figure.Figure 4333 Regression plot figure with annotated regression line, R², p-value, and outliers. 4334 4335 Raises 4336 ------ 4337 ValueError 4338 If `cell_x` or `cell_y` are not found in the dataset. 4339 If multiple matches are found for a cell name and `set_x`/`set_y` are not specified. 4340 4341 Notes 4342 ----- 4343 - The function automatically calls `jseq_object.average()` if aggregated data is not available. 4344 - Outliers are annotated with their corresponding index labels. 4345 - Regression is computed using `scipy.stats.linregress`. 4346 4347 Examples 4348 -------- 4349 >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2") 4350 >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue") 4351 """ 4352 4353 if self.agg_normalized_data is None: 4354 self.average() 4355 4356 metadata = self.agg_metadata 4357 data = self.agg_normalized_data 4358 4359 if set_x is not None and set_y is not None: 4360 data.columns = metadata["cell_names"] + " # " + metadata["sets"] 4361 cell_x = cell_x + " # " + set_x 4362 cell_y = cell_y + " # " + set_y 4363 4364 else: 4365 data.columns = metadata["cell_names"] 4366 4367 if not cell_x in data.columns: 4368 raise ValueError("'cell_x' value not in cell names!") 4369 4370 if not cell_y in data.columns: 4371 raise ValueError("'cell_y' value not in cell names!") 4372 4373 if list(data.columns).count(cell_x) > 1: 4374 raise ValueError( 4375 f"'{cell_x}' occurs more than once. If you want to select a specific cell, " 4376 f"please also provide the corresponding 'set_x' and 'set_y' values." 4377 ) 4378 4379 if list(data.columns).count(cell_y) > 1: 4380 raise ValueError( 4381 f"'{cell_y}' occurs more than once. If you want to select a specific cell, " 4382 f"please also provide the corresponding 'set_x' and 'set_y' values." 
4383 ) 4384 4385 fig, ax = plt.subplots(figsize=(image_width, image_high)) 4386 ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color) 4387 4388 slope, intercept, r_value, p_value, _ = stats.linregress( 4389 data[cell_x], data[cell_y] 4390 ) 4391 equation = "y = {:.2f}x + {:.2f}".format(slope, intercept) 4392 4393 ax.annotate( 4394 "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format( 4395 r_value**2, p_value, equation 4396 ), 4397 xy=(0.05, 0.90), 4398 xycoords="axes fraction", 4399 fontsize=12, 4400 ) 4401 4402 ax.spines["top"].set_visible(False) 4403 ax.spines["right"].set_visible(False) 4404 4405 diff = [] 4406 x_mean, y_mean = data[cell_x].mean(), data[cell_y].mean() 4407 for i, (xi, yi) in enumerate(zip(data[cell_x], data[cell_y])): 4408 diff.append(abs(xi - x_mean)) 4409 diff.append(abs(yi - y_mean)) 4410 4411 def annotate_outliers(x, y, threshold): 4412 texts = [] 4413 x_mean, y_mean = x.mean(), y.mean() 4414 for i, (xi, yi) in enumerate(zip(x, y)): 4415 if ( 4416 abs(xi - x_mean) > threshold 4417 or abs(yi - y_mean) > threshold 4418 or abs(yi - xi) > threshold 4419 ): 4420 text = ax.text(xi, yi, data.index[i]) 4421 texts.append(text) 4422 4423 return texts 4424 4425 texts = annotate_outliers(data[cell_x], data[cell_y], threshold) 4426 4427 adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5)) 4428 4429 plt.show() 4430 4431 return fig
COMPsc (Comparison of single-cell data): a class designed for the integration,
analysis, and visualization of single-cell datasets.
The class supports independent dataset integration, subclustering of existing clusters,
marker detection, and multiple visualization strategies.
The COMPsc class provides methods for:
- Normalizing and filtering single-cell data
- Loading and saving sparse 10x-style datasets
- Computing differential expression and marker genes
- Clustering and subclustering analysis
- Visualizing similarity and spatial relationships
- Aggregating data by cell and set annotations
- Managing metadata and renaming labels
- Plotting gene detection histograms and feature scatters
Methods
project_dir(path_to_directory, project_list) Scans a directory to create a COMPsc instance mapping project names to their paths.
save_project(name, path=os.getcwd()) Saves the COMPsc object to a pickle file on disk.
load_project(path) Loads a previously saved COMPsc object from a pickle file.
reduce_cols(reg, inc_set=False) Removes columns from data tables where column names contain a specified name or partial substring.
reduce_rows(reg, inc_set=False) Removes rows from data tables where column names contain a specified feature (gene) name.
get_data(set_info=False) Returns normalized data with optional set annotations in column names.
get_metadata() Returns the stored input metadata.
get_partial_data(names=None, features=None, name_slot='cell_names') Return a subset of the data by sample names and/or features.
gene_calculation() Calculates and stores per-cell gene detection counts as a pandas Series.
gene_histograme(bins=100) Plots a histogram of genes detected per cell with an overlaid normal distribution.
gene_threshold(min_n=None, max_n=None) Filters cells based on minimum and/or maximum gene detection thresholds.
load_sparse_from_projects(normalized_data=False) Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.
rename_names(mapping, slot='cell_names') Renames entries in a specified metadata column using a provided mapping dictionary.
rename_subclusters(mapping) Renames subcluster labels using a provided mapping dictionary.
save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized') Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
normalize_data(normalize=True, normalize_factor=100000) Normalizes raw counts to counts-per-specified factor (e.g., CPM-like).
statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10) Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.
calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False) Computes and caches differential markers using the statistic method.
clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4) Prepares clustering input by selecting marker features and optionally smoothing cell values.
average() Aggregates normalized data by averaging across (cell_name, set) pairs.
estimating_similarity(method='pearson', p_val=0.05, top_n=25) Computes pairwise correlation and Euclidean distance between aggregated samples.
similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10) Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.
spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...) Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.
subcluster_prepare(features, cluster) Initializes a Clustering object for subcluster analysis on a selected parent cluster.
define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...) Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels.
subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...) Visualizes averaged expression and occurrence of features for subclusters as a scatter plot.
subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...) Plots top differential features for subclusters as a features-scatter visualization.
accept_subclusters() Commits subcluster labels to main metadata by renaming cell names and clears subcluster data.
Raises
ValueError For invalid parameters, mismatched dimensions, or missing metadata.
1144 def __init__( 1145 self, 1146 objects=None, 1147 ): 1148 """ 1149 Initialize the COMPsc class for single-cell data integration and analysis. 1150 1151 Parameters 1152 ---------- 1153 objects : list or None, optional 1154 Optional list of data objects to initialize the instance with. 1155 1156 Attributes 1157 ---------- 1158 -objects : list or None 1159 -input_data : pandas.DataFrame or None 1160 -input_metadata : pandas.DataFrame or None 1161 -normalized_data : pandas.DataFrame or None 1162 -agg_metadata : pandas.DataFrame or None 1163 -agg_normalized_data : pandas.DataFrame or None 1164 -similarity : pandas.DataFrame or None 1165 -var_data : pandas.DataFrame or None 1166 -subclusters_ : instance of Clustering class or None 1167 -cells_calc : pandas.Series or None 1168 -gene_calc : pandas.Series or None 1169 -composition_data : pandas.DataFrame or None 1170 """ 1171 1172 self.objects = objects 1173 """ Stores the input data objects.""" 1174 1175 self.input_data = None 1176 """Raw input data for clustering or integration analysis.""" 1177 1178 self.input_metadata = None 1179 """Metadata associated with the input data.""" 1180 1181 self.normalized_data = None 1182 """Normalized version of the input data.""" 1183 1184 self.agg_metadata = None 1185 '''Aggregated metadata for all sets in object related to "agg_normalized_data"''' 1186 1187 self.agg_normalized_data = None 1188 """Aggregated and normalized data across multiple sets.""" 1189 1190 self.similarity = None 1191 """Similarity data between cells across all samples. 
and sets""" 1192 1193 self.var_data = None 1194 """DEG analysis results summarizing variance across all samples in the object.""" 1195 1196 self.subclusters_ = None 1197 """Placeholder for information about subclusters analysis; if computed.""" 1198 1199 self.cells_calc = None 1200 """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition.""" 1201 1202 self.gene_calc = None 1203 """Number of genes detected per sample (cell), reflecting the sequencing depth.""" 1204 1205 self.composition_data = None 1206 """Data describing composition of cells across clusters or sets."""
Initialize the COMPsc class for single-cell data integration and analysis.
Parameters
objects : list or None, optional Optional list of data objects to initialize the instance with.
Attributes
-objects : list or None -input_data : pandas.DataFrame or None -input_metadata : pandas.DataFrame or None -normalized_data : pandas.DataFrame or None -agg_metadata : pandas.DataFrame or None -agg_normalized_data : pandas.DataFrame or None -similarity : pandas.DataFrame or None -var_data : pandas.DataFrame or None -subclusters_ : instance of Clustering class or None -cells_calc : pandas.Series or None -gene_calc : pandas.Series or None -composition_data : pandas.DataFrame or None
Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition.
1208 @classmethod 1209 def project_dir(cls, path_to_directory, project_list): 1210 """ 1211 Scan a directory and build a COMPsc instance mapping provided project names 1212 to their paths. 1213 1214 Parameters 1215 ---------- 1216 path_to_directory : str 1217 Path containing project subfolders. 1218 1219 project_list : list[str] 1220 List of filenames (folder names) to include in the returned object map. 1221 1222 Returns 1223 ------- 1224 COMPsc 1225 New COMPsc instance with `objects` populated. 1226 1227 Raises 1228 ------ 1229 Exception 1230 A generic exception is caught and a message printed if scanning fails. 1231 1232 Notes 1233 ----- 1234 Function attempts to match entries in `project_list` to directory 1235 names and constructs a simplified object key from the folder name. 1236 """ 1237 try: 1238 objects = {} 1239 for filename in tqdm(os.listdir(path_to_directory)): 1240 for c in project_list: 1241 f = os.path.join(path_to_directory, filename) 1242 if c == filename and os.path.isdir(f): 1243 objects[str(c)] = f 1244 1245 return cls(objects) 1246 1247 except: 1248 print("Something went wrong. Check the function input data and try again!")
Scan a directory and build a COMPsc instance mapping provided project names to their paths.
Parameters
path_to_directory : str Path containing project subfolders.
project_list : list[str] List of filenames (folder names) to include in the returned object map.
Returns
COMPsc
New COMPsc instance with objects populated.
Raises
Exception A generic exception is caught and a message printed if scanning fails.
Notes
Function attempts to match entries in project_list to directory
names and constructs a simplified object key from the folder name.
1250 def save_project(self, name, path: str = os.getcwd()): 1251 """ 1252 Save the COMPsc object to disk using pickle. 1253 1254 Parameters 1255 ---------- 1256 name : str 1257 Base filename (without extension) to use when saving. 1258 1259 path : str, default os.getcwd() 1260 Directory in which to save the project file. 1261 1262 Returns 1263 ------- 1264 None 1265 1266 Side Effects 1267 ------------ 1268 - Writes a file `<path>/<name>.jpkl` containing the pickled object. 1269 - Prints a confirmation message with saved path. 1270 """ 1271 1272 full = os.path.join(path, f"{name}.jpkl") 1273 1274 with open(full, "wb") as f: 1275 pickle.dump(self, f) 1276 1277 print(f"Project saved as {full}")
Save the COMPsc object to disk using pickle.
Parameters
name : str Base filename (without extension) to use when saving.
path : str, default os.getcwd() Directory in which to save the project file.
Returns
None
Side Effects
- Writes a file `<path>/<name>.jpkl` containing the pickled object.
- Prints a confirmation message with the saved path.
1279 @classmethod 1280 def load_project(cls, path): 1281 """ 1282 Load a previously saved COMPsc project from a pickle file. 1283 1284 Parameters 1285 ---------- 1286 path : str 1287 Full path to the pickled project file. 1288 1289 Returns 1290 ------- 1291 COMPsc 1292 The unpickled COMPsc object. 1293 1294 Raises 1295 ------ 1296 FileNotFoundError 1297 If the provided path does not exist. 1298 """ 1299 1300 if not os.path.exists(path): 1301 raise FileNotFoundError("File does not exist!") 1302 with open(path, "rb") as f: 1303 obj = pickle.load(f) 1304 return obj
Load a previously saved COMPsc project from a pickle file.
Parameters
path : str Full path to the pickled project file.
Returns
COMPsc The unpickled COMPsc object.
Raises
FileNotFoundError If the provided path does not exist.
1306 def reduce_cols( 1307 self, 1308 reg: str | None = None, 1309 full: str | None = None, 1310 name_slot: str = "cell_names", 1311 inc_set: bool = False, 1312 ): 1313 """ 1314 Remove columns (cells) whose names contain a substring `reg` or 1315 full name `full` from available tables. 1316 1317 Parameters 1318 ---------- 1319 reg : str | None 1320 Substring to search for in column/cell names; matching columns will be removed. 1321 If not None, `full` must be None. 1322 1323 full : str | None 1324 Full name to search for in column/cell names; matching columns will be removed. 1325 If not None, `reg` must be None. 1326 1327 name_slot : str, default 'cell_names' 1328 Column in metadata to use as sample names. 1329 1330 inc_set : bool, default False 1331 If True, column names are interpreted as 'cell_name # set' when matching. 1332 1333 Update 1334 ------------ 1335 Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`, 1336 `self.agg_normalized_data`, and `self.agg_metadata` (if they exist), 1337 removing columns/rows that match `reg`. 1338 1339 Raises 1340 ------ 1341 Raises ValueError if nothing matches the reduction mask. 1342 """ 1343 1344 if reg is None and full is None: 1345 raise ValueError( 1346 "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!" 1347 ) 1348 1349 if reg is not None and full is not None: 1350 raise ValueError( 1351 "Both 'reg' and 'full' arguments are provided. " 1352 "Please provide only one of them!\n" 1353 "'reg' is used when only part of the name must be detected.\n" 1354 "'full' is used if the full name must be detected." 
1355 ) 1356 1357 if reg is not None: 1358 1359 if self.input_data is not None: 1360 1361 if inc_set: 1362 1363 self.input_data.columns = ( 1364 self.input_metadata[name_slot] 1365 + " # " 1366 + self.input_metadata["sets"] 1367 ) 1368 1369 else: 1370 1371 self.input_data.columns = self.input_metadata[name_slot] 1372 1373 mask = [reg.upper() not in x.upper() for x in self.input_data.columns] 1374 1375 if len([y for y in mask if y is False]) == 0: 1376 raise ValueError("Nothing found to reduce") 1377 1378 self.input_data = self.input_data.loc[:, mask] 1379 1380 if self.normalized_data is not None: 1381 1382 if inc_set: 1383 1384 self.normalized_data.columns = ( 1385 self.input_metadata[name_slot] 1386 + " # " 1387 + self.input_metadata["sets"] 1388 ) 1389 1390 else: 1391 1392 self.normalized_data.columns = self.input_metadata[name_slot] 1393 1394 mask = [ 1395 reg.upper() not in x.upper() for x in self.normalized_data.columns 1396 ] 1397 1398 if len([y for y in mask if y is False]) == 0: 1399 raise ValueError("Nothing found to reduce") 1400 1401 self.normalized_data = self.normalized_data.loc[:, mask] 1402 1403 if self.input_metadata is not None: 1404 1405 if inc_set: 1406 1407 self.input_metadata["drop"] = ( 1408 self.input_metadata[name_slot] 1409 + " # " 1410 + self.input_metadata["sets"] 1411 ) 1412 1413 else: 1414 1415 self.input_metadata["drop"] = self.input_metadata[name_slot] 1416 1417 mask = [ 1418 reg.upper() not in x.upper() for x in self.input_metadata["drop"] 1419 ] 1420 1421 if len([y for y in mask if y is False]) == 0: 1422 raise ValueError("Nothing found to reduce") 1423 1424 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 1425 drop=True 1426 ) 1427 1428 self.input_metadata = self.input_metadata.drop( 1429 columns=["drop"], errors="ignore" 1430 ) 1431 1432 if self.agg_normalized_data is not None: 1433 1434 if inc_set: 1435 1436 self.agg_normalized_data.columns = ( 1437 self.agg_metadata[name_slot] + " # " + 
self.agg_metadata["sets"] 1438 ) 1439 1440 else: 1441 1442 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1443 1444 mask = [ 1445 reg.upper() not in x.upper() 1446 for x in self.agg_normalized_data.columns 1447 ] 1448 1449 if len([y for y in mask if y is False]) == 0: 1450 raise ValueError("Nothing found to reduce") 1451 1452 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 1453 1454 if self.agg_metadata is not None: 1455 1456 if inc_set: 1457 1458 self.agg_metadata["drop"] = ( 1459 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1460 ) 1461 1462 else: 1463 1464 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1465 1466 mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]] 1467 1468 if len([y for y in mask if y is False]) == 0: 1469 raise ValueError("Nothing found to reduce") 1470 1471 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 1472 drop=True 1473 ) 1474 1475 self.agg_metadata = self.agg_metadata.drop( 1476 columns=["drop"], errors="ignore" 1477 ) 1478 1479 elif full is not None: 1480 1481 if self.input_data is not None: 1482 1483 if inc_set: 1484 1485 self.input_data.columns = ( 1486 self.input_metadata[name_slot] 1487 + " # " 1488 + self.input_metadata["sets"] 1489 ) 1490 1491 if "#" not in full: 1492 1493 self.input_data.columns = self.input_metadata[name_slot] 1494 1495 print( 1496 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1497 "Only the names will be compared, without considering the set information." 
1498 ) 1499 1500 else: 1501 1502 self.input_data.columns = self.input_metadata[name_slot] 1503 1504 mask = [full.upper() != x.upper() for x in self.input_data.columns] 1505 1506 if len([y for y in mask if y is False]) == 0: 1507 raise ValueError("Nothing found to reduce") 1508 1509 self.input_data = self.input_data.loc[:, mask] 1510 1511 if self.normalized_data is not None: 1512 1513 if inc_set: 1514 1515 self.normalized_data.columns = ( 1516 self.input_metadata[name_slot] 1517 + " # " 1518 + self.input_metadata["sets"] 1519 ) 1520 1521 if "#" not in full: 1522 1523 self.normalized_data.columns = self.input_metadata[name_slot] 1524 1525 print( 1526 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1527 "Only the names will be compared, without considering the set information." 1528 ) 1529 1530 else: 1531 1532 self.normalized_data.columns = self.input_metadata[name_slot] 1533 1534 mask = [full.upper() != x.upper() for x in self.normalized_data.columns] 1535 1536 if len([y for y in mask if y is False]) == 0: 1537 raise ValueError("Nothing found to reduce") 1538 1539 self.normalized_data = self.normalized_data.loc[:, mask] 1540 1541 if self.input_metadata is not None: 1542 1543 if inc_set: 1544 1545 self.input_metadata["drop"] = ( 1546 self.input_metadata[name_slot] 1547 + " # " 1548 + self.input_metadata["sets"] 1549 ) 1550 1551 if "#" not in full: 1552 1553 self.input_metadata["drop"] = self.input_metadata[name_slot] 1554 1555 print( 1556 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1557 "Only the names will be compared, without considering the set information." 
1558 ) 1559 1560 else: 1561 1562 self.input_metadata["drop"] = self.input_metadata[name_slot] 1563 1564 mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]] 1565 1566 if len([y for y in mask if y is False]) == 0: 1567 raise ValueError("Nothing found to reduce") 1568 1569 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 1570 drop=True 1571 ) 1572 1573 self.input_metadata = self.input_metadata.drop( 1574 columns=["drop"], errors="ignore" 1575 ) 1576 1577 if self.agg_normalized_data is not None: 1578 1579 if inc_set: 1580 1581 self.agg_normalized_data.columns = ( 1582 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1583 ) 1584 1585 if "#" not in full: 1586 1587 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1588 1589 print( 1590 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1591 "Only the names will be compared, without considering the set information." 1592 ) 1593 else: 1594 1595 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1596 1597 mask = [ 1598 full.upper() != x.upper() for x in self.agg_normalized_data.columns 1599 ] 1600 1601 if len([y for y in mask if y is False]) == 0: 1602 raise ValueError("Nothing found to reduce") 1603 1604 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 1605 1606 if self.agg_metadata is not None: 1607 1608 if inc_set: 1609 1610 self.agg_metadata["drop"] = ( 1611 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1612 ) 1613 1614 if "#" not in full: 1615 1616 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1617 1618 print( 1619 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1620 "Only the names will be compared, without considering the set information." 
1621 ) 1622 else: 1623 1624 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1625 1626 mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]] 1627 1628 if len([y for y in mask if y is False]) == 0: 1629 raise ValueError("Nothing found to reduce") 1630 1631 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 1632 drop=True 1633 ) 1634 1635 self.agg_metadata = self.agg_metadata.drop( 1636 columns=["drop"], errors="ignore" 1637 ) 1638 1639 self.gene_calculation() 1640 self.cells_calculation()
Remove columns (cells) whose names contain a substring reg or
full name full from available tables.
Parameters
reg : str | None
Substring to search for in column/cell names; matching columns will be removed.
If not None, full must be None.
full : str | None
Full name to search for in column/cell names; matching columns will be removed.
If not None, reg must be None.
name_slot : str, default 'cell_names' Column in metadata to use as sample names.
inc_set : bool, default False If True, column names are interpreted as 'cell_name # set' when matching.
Update
Mutates self.input_data, self.normalized_data, self.input_metadata,
self.agg_normalized_data, and self.agg_metadata (if they exist),
removing columns/rows that match reg.
Raises
Raises ValueError if nothing matches the reduction mask.
    def reduce_rows(self, features_list: list):
        """
        Remove rows (features/genes) whose names are included in `features_list`.

        Parameters
        ----------
        features_list : list
            List of features to search for in index/gene names; matching
            entries will be removed (matching is case-insensitive, via
            `find_features`).

        Update
        ------------
        Mutates `self.input_data`, `self.normalized_data`, and
        `self.agg_normalized_data` (if they exist), removing rows whose
        feature names appear in `features_list`. Recomputes
        `gene_calc`/`cells_calc` afterwards.

        Raises
        ------
        ValueError
            If nothing matches the reduction mask in a given table.
            Additionally prints a message listing features that are not
            found in the data.
        """

        if self.input_data is not None:

            # `find_features` (project helper) splits the request into
            # features present ("included") and absent ("not_included").
            res = find_features(self.input_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            # Keep rows whose (upper-cased) name is NOT among the matches.
            mask = [x.upper() not in res_list for x in self.input_data.index]

            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing found to reduce")

            self.input_data = self.input_data.loc[mask, :]

        if self.normalized_data is not None:

            res = find_features(self.normalized_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.normalized_data.index]

            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing found to reduce")

            self.normalized_data = self.normalized_data.loc[mask, :]

        if self.agg_normalized_data is not None:

            res = find_features(self.agg_normalized_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]

            if len([y for y in mask if y is False]) == 0:
                raise ValueError("Nothing found to reduce")

            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]

        # NOTE(review): `res` is whichever table was processed last; if no
        # table is set this raises NameError — confirm intended precondition.
        if len(res["not_included"]) > 0:
            print("\nFeatures not found:")
            for i in res["not_included"]:
                print(i)

        # Refresh per-cell gene counts and per-cluster cell counts.
        self.gene_calculation()
        self.cells_calculation()
Remove rows (features) whose names are included in features_list.
Parameters
features_list : list List of features to search for in index/gene names; matching entries will be removed.
Update
Mutates self.input_data, self.normalized_data, and self.agg_normalized_data (if they exist), removing rows whose feature names appear in features_list.
Raises
ValueError if nothing matches the reduction mask; additionally prints a message listing features that are not found in the data.
1709 def get_data(self, set_info: bool = False): 1710 """ 1711 Return normalized data with optional set annotation appended to column names. 1712 1713 Parameters 1714 ---------- 1715 set_info : bool, default False 1716 If True, column names are returned as "cell_name # set"; otherwise 1717 only the `cell_name` is used. 1718 1719 Returns 1720 ------- 1721 pandas.DataFrame 1722 The `self.normalized_data` table with columns renamed according to `set_info`. 1723 1724 Raises 1725 ------ 1726 AttributeError 1727 If `self.normalized_data` or `self.input_metadata` is missing. 1728 """ 1729 1730 to_return = self.normalized_data 1731 1732 if set_info: 1733 to_return.columns = ( 1734 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 1735 ) 1736 else: 1737 to_return.columns = self.input_metadata["cell_names"] 1738 1739 return to_return
Return normalized data with optional set annotation appended to column names.
Parameters
set_info : bool, default False
If True, column names are returned as "cell_name # set"; otherwise
only the cell_name is used.
Returns
pandas.DataFrame
The self.normalized_data table with columns renamed according to set_info.
Raises
AttributeError
If self.normalized_data or self.input_metadata is missing.
1741 def get_partial_data( 1742 self, 1743 names: list | str | None = None, 1744 features: list | str | None = None, 1745 name_slot: str = "cell_names", 1746 inc_metadata: bool = False, 1747 ): 1748 """ 1749 Return a subset of the data filtered by sample names and/or feature names. 1750 1751 Parameters 1752 ---------- 1753 names : list, str, or None 1754 Names of samples to include. If None, all samples are considered. 1755 1756 features : list, str, or None 1757 Names of features to include. If None, all features are considered. 1758 1759 name_slot : str 1760 Column in metadata to use as sample names. 1761 1762 inc_metadata : bool 1763 If True return tuple (data, metadata) 1764 1765 Returns 1766 ------- 1767 pandas.DataFrame 1768 Subset of the normalized data based on the specified names and features. 1769 """ 1770 1771 data = self.normalized_data.copy() 1772 metadata = self.input_metadata 1773 1774 if name_slot in self.input_metadata.columns: 1775 data.columns = self.input_metadata[name_slot] 1776 else: 1777 raise ValueError("'name_slot' not occured in data!'") 1778 1779 if isinstance(features, str): 1780 features = [features] 1781 elif features is None: 1782 features = [] 1783 1784 if isinstance(names, str): 1785 names = [names] 1786 elif names is None: 1787 names = [] 1788 1789 features = [x.upper() for x in features] 1790 names = [x.upper() for x in names] 1791 1792 columns_names = [x.upper() for x in data.columns] 1793 features_names = [x.upper() for x in data.index] 1794 1795 columns_bool = [True if x in names else False for x in columns_names] 1796 features_bool = [True if x in features else False for x in features_names] 1797 1798 if True not in columns_bool and True not in features_bool: 1799 print("Missing 'names' and/or 'features'. 
Returning full dataset instead.") 1800 1801 if True in columns_bool: 1802 data = data.loc[:, columns_bool] 1803 metadata = metadata.loc[columns_bool, :] 1804 1805 if True in features_bool: 1806 data = data.loc[features_bool, :] 1807 1808 not_in_features = [y for y in features if y not in features_names] 1809 1810 if len(not_in_features) > 0: 1811 print('\nThe following features were not found in data:') 1812 print('\n'.join(not_in_features)) 1813 1814 not_in_names = [y for y in names if y not in columns_names] 1815 1816 if len(not_in_names) > 0: 1817 print('\nThe following names were not found in data:') 1818 print('\n'.join(not_in_names)) 1819 1820 if inc_metadata: 1821 return data, metadata 1822 else: 1823 return data
Return a subset of the data filtered by sample names and/or feature names.
Parameters
names : list, str, or None Names of samples to include. If None, all samples are considered.
features : list, str, or None Names of features to include. If None, all features are considered.
name_slot : str Column in metadata to use as sample names.
inc_metadata : bool If True return tuple (data, metadata)
Returns
pandas.DataFrame Subset of the normalized data based on the specified names and features.
1825 def get_metadata(self): 1826 """ 1827 Return the stored input metadata. 1828 1829 Returns 1830 ------- 1831 pandas.DataFrame 1832 `self.input_metadata` (may be None if not set). 1833 """ 1834 1835 to_return = self.input_metadata 1836 1837 return to_return
Return the stored input metadata.
Returns
pandas.DataFrame
self.input_metadata (may be None if not set).
1839 def gene_calculation(self): 1840 """ 1841 Calculate and store per-cell counts (e.g., number of detected genes). 1842 1843 The method computes a binary (presence/absence) per cell and sums across 1844 features to produce `self.gene_calc`. 1845 1846 Update 1847 ------ 1848 Sets `self.gene_calc` as a pandas.Series. 1849 1850 Side Effects 1851 ------------ 1852 Uses `self.input_data` when available, otherwise `self.normalized_data`. 1853 """ 1854 1855 if self.input_data is not None: 1856 1857 bin_col = self.input_data.columns.copy() 1858 1859 bin_col = bin_col.where(bin_col <= 0, 1) 1860 1861 sum_data = bin_col.sum(axis=0) 1862 1863 self.gene_calc = sum_data 1864 1865 elif self.normalized_data is not None: 1866 1867 bin_col = self.normalized_data.copy() 1868 1869 bin_col = bin_col.where(bin_col <= 0, 1) 1870 1871 sum_data = bin_col.sum(axis=0) 1872 1873 self.gene_calc = sum_data
Calculate and store per-cell counts (e.g., number of detected genes).
The method computes a binary (presence/absence) per cell and sums across
features to produce self.gene_calc.
Update
Sets self.gene_calc as a pandas.Series.
Side Effects
Uses self.input_data when available, otherwise self.normalized_data.
1875 def gene_histograme(self, bins=100): 1876 """ 1877 Plot a histogram of the number of genes detected per cell. 1878 1879 Parameters 1880 ---------- 1881 bins : int, default 100 1882 Number of histogram bins. 1883 1884 Returns 1885 ------- 1886 matplotlib.figure.Figure 1887 Figure containing the histogram of gene contents. 1888 1889 Notes 1890 ----- 1891 Requires `self.gene_calc` to be computed prior to calling. 1892 """ 1893 1894 fig, ax = plt.subplots(figsize=(8, 5)) 1895 1896 _, bin_edges, _ = ax.hist( 1897 self.gene_calc, bins=bins, edgecolor="black", alpha=0.6 1898 ) 1899 1900 mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc) 1901 1902 x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000) 1903 y = norm.pdf(x, mu, sigma) 1904 1905 y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0]) 1906 1907 ax.plot( 1908 x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})" 1909 ) 1910 1911 ax.set_xlabel("Value") 1912 ax.set_ylabel("Count") 1913 ax.set_title("Histogram of genes detected per cell") 1914 1915 ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20)) 1916 ax.tick_params(axis="x", rotation=90) 1917 1918 ax.legend() 1919 1920 return fig
Plot a histogram of the number of genes detected per cell.
Parameters
bins : int, default 100 Number of histogram bins.
Returns
matplotlib.figure.Figure Figure containing the histogram of gene contents.
Notes
Requires self.gene_calc to be computed prior to calling.
1922 def gene_threshold(self, min_n: int | None, max_n: int | None): 1923 """ 1924 Filter cells by gene-detection thresholds (min and/or max). 1925 1926 Parameters 1927 ---------- 1928 min_n : int or None 1929 Minimum number of detected genes required to keep a cell. 1930 1931 max_n : int or None 1932 Maximum number of detected genes allowed to keep a cell. 1933 1934 Update 1935 ------- 1936 Filters `self.input_data`, `self.normalized_data`, `self.input_metadata` 1937 (and calls `average()` if `self.agg_normalized_data` exists). 1938 1939 Side Effects 1940 ------------ 1941 Raises ValueError if both bounds are None or if filtering removes all cells. 1942 """ 1943 1944 if min_n is not None and max_n is not None: 1945 mask = (self.gene_calc > min_n) & (self.gene_calc < max_n) 1946 elif min_n is None and max_n is not None: 1947 mask = self.gene_calc < max_n 1948 elif min_n is not None and max_n is None: 1949 mask = self.gene_calc > min_n 1950 else: 1951 raise ValueError("Lack of both min_n and max_n values") 1952 1953 if self.input_data is not None: 1954 1955 if len([y for y in mask if y is False]) == 0: 1956 raise ValueError("Nothing to reduce") 1957 1958 self.input_data = self.input_data.loc[:, mask.values] 1959 1960 if self.normalized_data is not None: 1961 1962 if len([y for y in mask if y is False]) == 0: 1963 raise ValueError("Nothing to reduce") 1964 1965 self.normalized_data = self.normalized_data.loc[:, mask.values] 1966 1967 if self.input_metadata is not None: 1968 1969 if len([y for y in mask if y is False]) == 0: 1970 raise ValueError("Nothing to reduce") 1971 1972 self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index( 1973 drop=True 1974 ) 1975 1976 self.input_metadata = self.input_metadata.drop( 1977 columns=["drop"], errors="ignore" 1978 ) 1979 1980 if self.agg_normalized_data is not None: 1981 self.average() 1982 1983 self.gene_calculation() 1984 self.cells_calculation()
Filter cells by gene-detection thresholds (min and/or max).
Parameters
min_n : int or None Minimum number of detected genes required to keep a cell.
max_n : int or None Maximum number of detected genes allowed to keep a cell.
Update
Filters self.input_data, self.normalized_data, self.input_metadata
(and calls average() if self.agg_normalized_data exists).
Side Effects
Raises ValueError if both bounds are None or if filtering removes all cells.
1986 def cells_calculation(self, name_slot="cell_names"): 1987 """ 1988 Calculate number of cells per call name / cluster. 1989 1990 The method computes a binary (presence/absence) per cell name / cluster and sums across 1991 cells. 1992 1993 Parameters 1994 ---------- 1995 name_slot : str, default 'cell_names' 1996 Column in metadata to use as sample names. 1997 1998 Update 1999 ------ 2000 Sets `self.cells_calc` as a pd.DataFrame. 2001 """ 2002 2003 ls = list(self.input_metadata[name_slot]) 2004 2005 df = pd.DataFrame( 2006 { 2007 "cluster": pd.Series(ls).value_counts().index, 2008 "n": pd.Series(ls).value_counts().values, 2009 } 2010 ) 2011 2012 self.cells_calc = df
Calculate number of cells per cell name / cluster.
The method computes a binary (presence/absence) per cell name / cluster and sums across cells.
Parameters
name_slot : str, default 'cell_names' Column in metadata to use as sample names.
Update
Sets self.cells_calc as a pd.DataFrame.
2014 def cell_histograme(self, name_slot: str = "cell_names"): 2015 """ 2016 Plot a histogram of the number of cells detected per cell name (cluster). 2017 2018 Parameters 2019 ---------- 2020 name_slot : str, default 'cell_names' 2021 Column in metadata to use as sample names. 2022 2023 Returns 2024 ------- 2025 matplotlib.figure.Figure 2026 Figure containing the histogram of cell contents. 2027 2028 Notes 2029 ----- 2030 Requires `self.cells_calc` to be computed prior to calling. 2031 """ 2032 2033 if name_slot != "cell_names": 2034 self.cells_calculation(name_slot=name_slot) 2035 2036 fig, ax = plt.subplots(figsize=(8, 5)) 2037 2038 _, bin_edges, _ = ax.hist( 2039 list(self.cells_calc["n"]), 2040 bins=len(set(self.cells_calc["cluster"])), 2041 edgecolor="black", 2042 color="orange", 2043 alpha=0.6, 2044 ) 2045 2046 mu, sigma = np.mean(list(self.cells_calc["n"])), np.std( 2047 list(self.cells_calc["n"]) 2048 ) 2049 2050 x = np.linspace( 2051 min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000 2052 ) 2053 y = norm.pdf(x, mu, sigma) 2054 2055 y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0]) 2056 2057 ax.plot( 2058 x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})" 2059 ) 2060 2061 ax.set_xlabel("Value") 2062 ax.set_ylabel("Count") 2063 ax.set_title("Histogram of cells detected per cell name / cluster") 2064 2065 ax.set_xticks( 2066 np.linspace( 2067 min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20 2068 ) 2069 ) 2070 ax.tick_params(axis="x", rotation=90) 2071 2072 ax.legend() 2073 2074 return fig
Plot a histogram of the number of cells detected per cell name (cluster).
Parameters
name_slot : str, default 'cell_names' Column in metadata to use as sample names.
Returns
matplotlib.figure.Figure Figure containing the histogram of cell contents.
Notes
Requires self.cells_calc to be computed prior to calling.
    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
        """
        Drop cell names / clusters whose cell count is below `min_n`.

        Parameters
        ----------
        min_n : int or None
            Minimum number of cells a cluster must contain to be kept; clusters
            with fewer cells (per `self.cells_calc['n']`) are removed. Required.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        Update
        ------
        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`,
        and, when present, `self.agg_normalized_data` / `self.agg_metadata`;
        then refreshes the cached gene and cell counts.

        Raises
        ------
        ValueError
            If `min_n` is None.
        """

        # Refresh per-cluster counts when a non-default slot is requested.
        if name_slot != "cell_names":
            self.cells_calculation(name_slot=name_slot)

        if min_n is not None:
            # Cluster labels falling below the threshold.
            names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]
        else:
            raise ValueError("Lack of min_n value")

        if len(names) > 0:

            if self.input_data is not None:

                self.input_data.columns = self.input_metadata[name_slot]

                # NOTE(review): membership is tested by substring ('r in x'),
                # so a dropped label that is a substring of another (e.g. 'T'
                # vs 'T2') removes both - confirm this is intended.
                mask = [not any(r in x for r in names) for x in self.input_data.columns]

                # Only re-slice when at least one column is actually dropped.
                if len([y for y in mask if y is False]) > 0:

                    self.input_data = self.input_data.loc[:, mask]

            if self.normalized_data is not None:

                self.normalized_data.columns = self.input_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.normalized_data.columns
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.normalized_data = self.normalized_data.loc[:, mask]

            if self.input_metadata is not None:

                # Temporary helper column so rows can be matched by name.
                self.input_metadata["drop"] = self.input_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.input_metadata["drop"]
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
                        drop=True
                    )

                # Remove the helper column again.
                self.input_metadata = self.input_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            if self.agg_normalized_data is not None:

                # NOTE(review): assumes `name_slot` exists in agg_metadata,
                # which only carries 'cell_names'/'sets' - confirm for custom slots.
                self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                mask = [
                    not any(r in x for r in names)
                    for x in self.agg_normalized_data.columns
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]

            if self.agg_metadata is not None:

                self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                mask = [
                    not any(r in x for r in names) for x in self.agg_metadata["drop"]
                ]

                if len([y for y in mask if y is False]) > 0:

                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
                        drop=True
                    )

                self.agg_metadata = self.agg_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            # Refresh cached per-gene and per-cluster counts after filtering.
            self.gene_calculation()
            self.cells_calculation()
Filter cell names / clusters by cell-detection threshold.
Parameters
min_n : int or None Minimum number of cells required to keep a cell name / cluster; clusters with fewer cells are removed.
name_slot : str, default 'cell_names' Column in metadata to use as sample names.
Update
Filters self.input_data, self.normalized_data, self.input_metadata
(and calls average() if self.agg_normalized_data exists).
2179 def load_sparse_from_projects(self, normalized_data: bool = False): 2180 """ 2181 Load sparse 10x-style datasets from stored project paths, concatenate them, 2182 and populate `input_data` / `normalized_data` and `input_metadata`. 2183 2184 Parameters 2185 ---------- 2186 normalized_data : bool, default False 2187 If True, store concatenated tables in `self.normalized_data`. 2188 If False, store them in `self.input_data` and normalization 2189 is needed using normalize_data() method. 2190 2191 Side Effects 2192 ------------ 2193 - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv). 2194 - Concatenates all projects column-wise and sets `self.input_metadata`. 2195 - Replaces NaNs with zeros and updates `self.gene_calc`. 2196 """ 2197 2198 obj = self.objects 2199 2200 full_data = pd.DataFrame() 2201 full_metadata = pd.DataFrame() 2202 2203 for ke in obj.keys(): 2204 print(ke) 2205 2206 dt, met = load_sparse(path=obj[ke], name=ke) 2207 2208 full_data = pd.concat([full_data, dt], axis=1) 2209 full_metadata = pd.concat([full_metadata, met], axis=0) 2210 2211 full_data[np.isnan(full_data)] = 0 2212 2213 if normalized_data: 2214 self.normalized_data = full_data 2215 self.input_metadata = full_metadata 2216 else: 2217 2218 self.input_data = full_data 2219 self.input_metadata = full_metadata 2220 2221 self.gene_calculation() 2222 self.cells_calculation()
Load sparse 10x-style datasets from stored project paths, concatenate them,
and populate input_data / normalized_data and input_metadata.
Parameters
normalized_data : bool, default False
If True, store concatenated tables in self.normalized_data.
If False, store them in self.input_data and normalization
is needed using normalize_data() method.
Side Effects
- Reads each project using
load_sparse(...)(expects matrix.mtx, genes.tsv, barcodes.tsv). - Concatenates all projects column-wise and sets
self.input_metadata. - Replaces NaNs with zeros and updates
self.gene_calc.
2224 def rename_names(self, mapping: dict, slot: str = "cell_names"): 2225 """ 2226 Rename entries in `self.input_metadata[slot]` according to a provided mapping. 2227 2228 Parameters 2229 ---------- 2230 mapping : dict 2231 Dictionary with keys 'old_name' and 'new_name', each mapping to a list 2232 of equal length describing replacements. 2233 2234 slot : str, default 'cell_names' 2235 Metadata column to operate on. 2236 2237 Update 2238 ------- 2239 Updates `self.input_metadata[slot]` in-place with renamed values. 2240 2241 Raises 2242 ------ 2243 ValueError 2244 If mapping keys are incorrect, lengths differ, or some 'old_name' values 2245 are not present in the metadata column. 2246 """ 2247 2248 if set(["old_name", "new_name"]) != set(mapping.keys()): 2249 raise ValueError( 2250 "Mapping dictionary must contain keys 'old_name' and 'new_name', " 2251 "each with a list of names to change." 2252 ) 2253 2254 if len(mapping["old_name"]) != len(mapping["new_name"]): 2255 raise ValueError( 2256 "Mapping dictionary lists 'old_name' and 'new_name' " 2257 "must have the same length!" 2258 ) 2259 2260 names_vector = list(self.input_metadata[slot]) 2261 2262 if not all(elem in names_vector for elem in list(mapping["old_name"])): 2263 raise ValueError( 2264 f"Some entries from 'old_name' do not exist in the names of slot {slot}." 2265 ) 2266 2267 replace_dict = dict(zip(mapping["old_name"], mapping["new_name"])) 2268 2269 names_vector_ret = [replace_dict.get(item, item) for item in names_vector] 2270 2271 self.input_metadata[slot] = names_vector_ret
Rename entries in self.input_metadata[slot] according to a provided mapping.
Parameters
mapping : dict Dictionary with keys 'old_name' and 'new_name', each mapping to a list of equal length describing replacements.
slot : str, default 'cell_names' Metadata column to operate on.
Update
Updates self.input_metadata[slot] in-place with renamed values.
Raises
ValueError If mapping keys are incorrect, lengths differ, or some 'old_name' values are not present in the metadata column.
2273 def rename_subclusters(self, mapping): 2274 """ 2275 Rename labels stored in `self.subclusters_.subclusters` according to mapping. 2276 2277 Parameters 2278 ---------- 2279 mapping : dict 2280 Mapping with keys 'old_name' and 'new_name' (lists of equal length). 2281 2282 Update 2283 ------- 2284 Updates `self.subclusters_.subclusters` with renamed labels. 2285 2286 Raises 2287 ------ 2288 ValueError 2289 If mapping is invalid or old names are not present. 2290 """ 2291 2292 if set(["old_name", "new_name"]) != set(mapping.keys()): 2293 raise ValueError( 2294 "Mapping dictionary must contain keys 'old_name' and 'new_name', " 2295 "each with a list of names to change." 2296 ) 2297 2298 if len(mapping["old_name"]) != len(mapping["new_name"]): 2299 raise ValueError( 2300 "Mapping dictionary lists 'old_name' and 'new_name' " 2301 "must have the same length!" 2302 ) 2303 2304 names_vector = list(self.subclusters_.subclusters) 2305 2306 if not all(elem in names_vector for elem in list(mapping["old_name"])): 2307 raise ValueError( 2308 "Some entries from 'old_name' do not exist in the subcluster names." 2309 ) 2310 2311 replace_dict = dict(zip(mapping["old_name"], mapping["new_name"])) 2312 2313 names_vector_ret = [replace_dict.get(item, item) for item in names_vector] 2314 2315 self.subclusters_.subclusters = names_vector_ret
Rename labels stored in self.subclusters_.subclusters according to mapping.
Parameters
mapping : dict Mapping with keys 'old_name' and 'new_name' (lists of equal length).
Update
Updates self.subclusters_.subclusters with renamed labels.
Raises
ValueError If mapping is invalid or old names are not present.
2317 def save_sparse( 2318 self, 2319 path_to_save: str = os.getcwd(), 2320 name_slot: str = "cell_names", 2321 data_slot: str = "normalized", 2322 ): 2323 """ 2324 Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv). 2325 2326 Parameters 2327 ---------- 2328 path_to_save : str, default current working directory 2329 Directory where files will be written. 2330 2331 name_slot : str, default 'cell_names' 2332 Metadata column providing cell names for barcodes.tsv. 2333 2334 data_slot : str, default 'normalized' 2335 Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data). 2336 2337 Raises 2338 ------ 2339 ValueError 2340 If `data_slot` is not 'normalized' or 'count'. 2341 """ 2342 2343 names = self.input_metadata[name_slot] 2344 2345 if data_slot.lower() == "normalized": 2346 2347 features = list(self.normalized_data.index) 2348 mtx = sparse.csr_matrix(self.normalized_data) 2349 2350 elif data_slot.lower() == "count": 2351 2352 features = list(self.input_data.index) 2353 mtx = sparse.csr_matrix(self.input_data) 2354 2355 else: 2356 raise ValueError("'data_slot' must be included in 'normalized' or 'count'") 2357 2358 os.makedirs(path_to_save, exist_ok=True) 2359 2360 mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx) 2361 2362 pd.Series(names).to_csv( 2363 os.path.join(path_to_save, "barcodes.tsv"), 2364 index=False, 2365 header=False, 2366 sep="\t", 2367 ) 2368 2369 pd.Series(features).to_csv( 2370 os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t" 2371 )
Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).
Parameters
path_to_save : str, default current working directory Directory where files will be written.
name_slot : str, default 'cell_names' Metadata column providing cell names for barcodes.tsv.
data_slot : str, default 'normalized' Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).
Raises
ValueError
If data_slot is not 'normalized' or 'count'.
2373 def normalize_counts( 2374 self, normalize_factor: int = 100000, log_transform: bool = True 2375 ): 2376 """ 2377 Normalize raw counts to counts-per-(normalize_factor) 2378 (e.g., CPM, TPM - depending on normalize_factor). 2379 2380 Parameters 2381 ---------- 2382 normalize_factor : int, default 100000 2383 Scaling factor used after dividing by column sums. 2384 2385 log_transform : bool, default True 2386 If True, apply log2(x+1) transformation to normalized values. 2387 2388 Update 2389 ------- 2390 Sets `self.normalized_data` to normalized values (fills NaNs with 0). 2391 2392 Raises 2393 ------ 2394 ValueError 2395 If `self.input_data` is missing (cannot normalize). 2396 """ 2397 if self.input_data is None: 2398 raise ValueError("Input data is missing, cannot normalize.") 2399 2400 sum_col = self.input_data.sum() 2401 self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor 2402 2403 if log_transform: 2404 # log2(x + 1) to avoid -inf for zeros 2405 self.normalized_data = np.log2(self.normalized_data + 1)
Normalize raw counts to counts-per-(normalize_factor) (e.g., CPM, TPM - depending on normalize_factor).
Parameters
normalize_factor : int, default 100000 Scaling factor used after dividing by column sums.
log_transform : bool, default True If True, apply log2(x+1) transformation to normalized values.
Update
Sets `self.normalized_data` to normalized values (fills NaNs with 0).
Raises
ValueError
If self.input_data is missing (cannot normalize).
    def statistic(
        self,
        cells=None,
        sets=None,
        min_exp: float = 0.01,
        min_pct: float = 0.1,
        n_proc: int = 10,
    ):
        """
        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.

        This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data`
        and `self.input_metadata`. It returns per-feature statistics including p-values,
        adjusted p-values, means, variances, effect-size measures and fold-changes.

        Parameters
        ----------
        cells : list, 'All', dict, or None
            Defines the target cells or groups for comparison:
            - list: compare the listed cells ('name' or 'name # set') to the rest;
            - 'All': compare each unique cell label to the rest, one at a time;
            - dict with exactly two keys: pairwise comparison of two cell groups.

        sets : 'All', dict, or None
            Alternative grouping mode (operates on `self.input_metadata['sets']`):
            - 'All': compare each set to the rest;
            - dict with exactly two keys: pairwise set comparison.

        min_exp : float, default 0.01
            Minimum expression threshold used when filtering features.

        min_pct : float, default 0.1
            Minimum proportion of expressing cells in the target group required to test a feature.

        n_proc : int, default 10
            Number of parallel jobs to use.

        Returns
        -------
        pandas.DataFrame or dict
            Results DataFrame (or dict containing valid/control cells + DataFrame),
            similar to `calc_DEG` interface.

        Raises
        ------
        ValueError
            If neither `cells` nor `sets` is provided, or input metadata mismatch occurs.

        Notes
        -----
        Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
        """

        # Floor added to fold-change pseudo-counts so they stay finite/non-zero.
        offset = 1e-100

        def stat_calc(choose, feature_name):
            # Per-feature statistics for one column, given the 'DEG' labels
            # ('target' / 'rest') assigned by the calling branch.
            target_values = choose.loc[choose["DEG"] == "target", feature_name]
            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

            # Fraction of cells expressing the feature (> 0) in each group.
            pct_valid = (target_values > 0).sum() / len(target_values)
            pct_rest = (rest_values > 0).sum() / len(rest_values)

            avg_valid = np.mean(target_values)
            avg_ctrl = np.mean(rest_values)
            sd_valid = np.std(target_values, ddof=1)
            sd_ctrl = np.std(rest_values, ddof=1)
            # Effect size: mean difference over the quadratic mean of the two
            # sample SDs (Cohen's-d-like).
            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

            # NOTE(review): identical group sums short-circuit to p = 1.0;
            # sums can coincide for different distributions - confirm intended.
            if np.sum(target_values) == np.sum(rest_values):
                p_val = 1.0
            else:
                _, p_val = stats.mannwhitneyu(
                    target_values, rest_values, alternative="two-sided"
                )

            return {
                "feature": feature_name,
                "p_val": p_val,
                "pct_valid": pct_valid,
                "pct_ctrl": pct_rest,
                "avg_valid": avg_valid,
                "avg_ctrl": avg_ctrl,
                "sd_valid": sd_valid,
                "sd_ctrl": sd_ctrl,
                "esm": esm,
            }

        def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc):
            # Filter features by prevalence in the target group, run stat_calc
            # in parallel, then post-process (p-value adjustment, fold-change).

            def safe_min_half(series):
                # Half the smallest strictly-positive value, used as a
                # pseudo-count; (2**-1074)*2 is twice the smallest subnormal.
                filtered = series[(series > ((2**-1074) * 2)) & (series.notna())]
                return filtered.min() / 2 if not filtered.empty else 0

            tmp_dat = choose[choose["DEG"] == "target"]
            tmp_dat = tmp_dat.drop("DEG", axis=1)

            # Per-feature count of target cells above the expression cutoff.
            counts = (tmp_dat > min_exp).sum(axis=0)

            total_count = tmp_dat.shape[0]

            info = pd.DataFrame(
                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
            )

            del tmp_dat

            drop_col = info["feature"][info["pct"] <= min_pct]

            # If the cutoff would drop every feature ('+ 1' accounts for the
            # DEG column), fall back to dropping only never-expressed ones.
            if len(drop_col) + 1 == len(choose.columns):
                drop_col = info["feature"][info["pct"] == 0]

            del info

            choose = choose.drop(list(drop_col), axis=1)

            results = Parallel(n_jobs=n_proc)(
                delayed(stat_calc)(choose, feature)
                for feature in tqdm(choose.columns[choose.columns != "DEG"])
            )

            df = pd.DataFrame(results)
            # Keep features expressed in at least one of the two groups.
            df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)]

            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # Benjamini-Hochberg-style adjustment (p * m / rank, capped at 1).
            # NOTE(review): lacks the usual monotonicity-enforcement step.
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            # Pseudo-count for zero group means so fold-changes are defined.
            valid_factor = safe_min_half(df["avg_valid"])
            ctrl_factor = safe_min_half(df["avg_ctrl"])

            cv_factor = min(valid_factor, ctrl_factor)

            if cv_factor == 0:
                cv_factor = max(valid_factor, ctrl_factor)

            if not np.isfinite(cv_factor) or cv_factor == 0:
                cv_factor += offset

            valid = df["avg_valid"].where(
                df["avg_valid"] != 0, df["avg_valid"] + cv_factor
            )
            ctrl = df["avg_ctrl"].where(
                df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor
            )

            df["FC"] = valid / ctrl

            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]

            return df

        # Transpose so that cells are rows and features are columns.
        choose = self.normalized_data.copy().T

        final_results = []

        if isinstance(cells, list) and sets is None:
            # Mode 1: explicit list of target cells vs everything else.
            print("\nAnalysis started...\nComparing selected cells to the whole set...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            # Fall back to bare names when entries lack ' # set' information.
            if "#" not in cells[0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "Not include the set info (name # set) in the 'cells' list. "
                    "Only the names will be compared, without considering the set information."
                )

            labels = ["target" if idx in cells else "rest" for idx in choose.index]
            valid = list(
                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
            )

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=valid,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}

        elif cells == "All" and sets is None:
            # Mode 2: one-vs-rest for every unique 'name # set' label.
            print("\nAnalysis started...\nComparing each type of cell to others...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
            unique_labels = set(choose.index)

            for label in tqdm(unique_labels):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]
                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        elif cells is None and sets == "All":
            # Mode 3: one-vs-rest for every set/group.
            print("\nAnalysis started...\nComparing each set/group to others...")
            choose.index = self.input_metadata["sets"]
            unique_sets = set(choose.index)

            for label in tqdm(unique_sets):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]

                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        elif cells is None and isinstance(sets, dict):
            # Mode 4: pairwise comparison of two explicitly listed set groups.
            print("\nAnalysis started...\nComparing groups...")

            choose.index = self.input_metadata["sets"]

            group_list = list(sets.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            # Rows in neither group are labeled 'drop' and removed below.
            labels = [
                (
                    "target"
                    if idx in sets[group_list[0]]
                    else "rest" if idx in sets[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            return result_df

        elif isinstance(cells, dict) and sets is None:
            # Mode 5: pairwise comparison of two explicitly listed cell groups.
            print("\nAnalysis started...\nComparing groups...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            # Fall back to bare names when entries lack ' # set' information.
            if "#" not in cells[list(cells.keys())[0]][0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "Not include the set info (name # set) in the 'cells' dict. "
                    "Only the names will be compared, without considering the set information."
                )

            group_list = list(cells.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            labels = [
                (
                    "target"
                    if idx in cells[group_list[0]]
                    else "rest" if idx in cells[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )

            return result_df.reset_index(drop=True)

        else:
            raise ValueError(
                "You must specify either 'cells' or 'sets' (or both). None were provided, which is not allowed for this analysis."
            )
Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.
This is a wrapper similar to calc_DEG tailored to use self.normalized_data
and self.input_metadata. It returns per-feature statistics including p-values,
adjusted p-values, means, variances, effect-size measures and fold-changes.
Parameters
cells : list, 'All', dict, or None Defines the target cells or groups for comparison (several modes supported).
sets : 'All', dict, or None
Alternative grouping mode (operate on self.input_metadata['sets']).
min_exp : float, default 0.01 Minimum expression threshold used when filtering features.
min_pct : float, default 0.1 Minimum proportion of expressing cells in the target group required to test a feature.
n_proc : int, default 10 Number of parallel jobs to use.
Returns
pandas.DataFrame or dict
Results DataFrame (or dict containing valid/control cells + DataFrame),
similar to calc_DEG interface.
Raises
ValueError
If neither cells nor sets is provided, or input metadata mismatch occurs.
Notes
Multiple modes supported: single-list entities, 'All', pairwise dicts, etc.
2714 def calculate_difference_markers( 2715 self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False 2716 ): 2717 """ 2718 Compute differential markers (var_data) if not already present. 2719 2720 Parameters 2721 ---------- 2722 min_exp : float, default 0 2723 Minimum expression threshold passed to `statistic`. 2724 2725 min_pct : float, default 0.25 2726 Minimum percent expressed in target group. 2727 2728 n_proc : int, default 10 2729 Parallel jobs. 2730 2731 force : bool, default False 2732 If True, recompute even if `self.var_data` is present. 2733 2734 Update 2735 ------- 2736 Sets `self.var_data` to the result of `self.statistic(...)`. 2737 2738 Raise 2739 ------ 2740 ValueError if already computed and `force` is False. 2741 """ 2742 2743 if self.var_data is None or force: 2744 2745 self.var_data = self.statistic( 2746 cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc 2747 ) 2748 2749 else: 2750 raise ValueError( 2751 "self.calculate_difference_markers() has already been executed. " 2752 "The results are stored in self.var. " 2753 "If you want to recalculate with different statistics, please rerun the method with force=True." 2754 )
Compute differential markers (var_data) if not already present.
Parameters
min_exp : float, default 0
Minimum expression threshold passed to statistic.
min_pct : float, default 0.25 Minimum percent expressed in target group.
n_proc : int, default 10 Parallel jobs.
force : bool, default False
If True, recompute even if self.var_data is present.
Update
Sets self.var_data to the result of self.statistic(...).
Raise
ValueError if already computed and force is False.
2756 def clustering_features( 2757 self, 2758 features_list: list | None, 2759 name_slot: str = "cell_names", 2760 p_val: float = 0.05, 2761 top_n: int = 25, 2762 adj_mean: bool = True, 2763 beta: float = 0.2, 2764 ): 2765 """ 2766 Prepare clustering input by selecting marker features and optionally smoothing cell values 2767 toward group means. 2768 2769 Parameters 2770 ---------- 2771 features_list : list or None 2772 If provided, use this list of features. If None, features are selected 2773 from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group. 2774 2775 name_slot : str, default 'cell_names' 2776 Metadata column used for naming. 2777 2778 p_val : float, default 0.05 2779 Adjusted p-value cutoff when selecting features automatically. 2780 2781 top_n : int, default 25 2782 Number of top features per valid group to keep if `features_list` is None. 2783 2784 adj_mean : bool, default True 2785 If True, adjust cell values toward group means using `beta`. 2786 2787 beta : float, default 0.2 2788 Adjustment strength toward group mean. 2789 2790 Update 2791 ------ 2792 Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset, 2793 ready for PCA/UMAP/clustering. 2794 """ 2795 2796 if features_list is None or len(features_list) == 0: 2797 2798 if self.var_data is None: 2799 raise ValueError( 2800 "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first." 
2801 ) 2802 2803 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2804 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2805 df_tmp = ( 2806 df_tmp.sort_values( 2807 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2808 ) 2809 .groupby("valid_group") 2810 .head(top_n) 2811 ) 2812 2813 feaures_list = list(set(df_tmp["feature"])) 2814 2815 data = self.get_partial_data( 2816 names=None, features=feaures_list, name_slot=name_slot 2817 ) 2818 data_avg = average(data) 2819 2820 if adj_mean: 2821 data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta) 2822 2823 self.clustering_data = data 2824 2825 self.clustering_metadata = self.input_metadata
Prepare clustering input by selecting marker features and optionally smoothing cell values toward group means.
Parameters
features_list : list or None
If provided, use this list of features. If None, features are selected
from self.var_data (adj_pval <= p_val, positive logFC) picking top_n per group.
name_slot : str, default 'cell_names' Metadata column used for naming.
p_val : float, default 0.05 Adjusted p-value cutoff when selecting features automatically.
top_n : int, default 25
Number of top features per valid group to keep if features_list is None.
adj_mean : bool, default True
If True, adjust cell values toward group means using beta.
beta : float, default 0.2 Adjustment strength toward group mean.
Update
Sets self.clustering_data and self.clustering_metadata to the selected subset,
ready for PCA/UMAP/clustering.
2827 def average(self): 2828 """ 2829 Aggregate normalized data by (cell_name, set) pairs computing the mean per group. 2830 2831 The method constructs new column names as "cell_name # set", averages columns 2832 sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`. 2833 2834 Update 2835 ------ 2836 Sets `self.agg_normalized_data` (features x aggregated samples) and 2837 `self.agg_metadata` (DataFrame with 'cell_names' and 'sets'). 2838 """ 2839 2840 wide_data = self.normalized_data 2841 2842 wide_metadata = self.input_metadata 2843 2844 new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"] 2845 2846 wide_data.columns = list(new_names) 2847 2848 aggregated_df = wide_data.T.groupby(level=0).mean().T 2849 2850 sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns] 2851 names = [re.sub(" #.*", "", x) for x in aggregated_df.columns] 2852 2853 aggregated_df.columns = names 2854 aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets}) 2855 2856 self.agg_metadata = aggregated_metadata 2857 self.agg_normalized_data = aggregated_df
Aggregate normalized data by (cell_name, set) pairs computing the mean per group.
The method constructs new column names as "cell_name # set", averages columns
sharing identical labels, and populates self.agg_normalized_data and self.agg_metadata.
Update
Sets self.agg_normalized_data (features x aggregated samples) and
self.agg_metadata (DataFrame with 'cell_names' and 'sets').
2859 def estimating_similarity( 2860 self, method="pearson", p_val: float = 0.05, top_n: int = 25 2861 ): 2862 """ 2863 Estimate pairwise similarity and Euclidean distance between aggregated samples. 2864 2865 Parameters 2866 ---------- 2867 method : str, default 'pearson' 2868 Correlation method to use (passed to pandas.DataFrame.corr()). 2869 2870 p_val : float, default 0.05 2871 Adjusted p-value cutoff used to select marker features from `self.var_data`. 2872 2873 top_n : int, default 25 2874 Number of top features per valid group to include. 2875 2876 Update 2877 ------- 2878 Computes a combined table with per-pair correlation and euclidean distance 2879 and stores it in `self.similarity`. 2880 """ 2881 2882 if self.var_data is None: 2883 raise ValueError( 2884 "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first." 2885 ) 2886 2887 if self.agg_normalized_data is None: 2888 self.average() 2889 2890 metadata = self.agg_metadata 2891 data = self.agg_normalized_data 2892 2893 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2894 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2895 df_tmp = ( 2896 df_tmp.sort_values( 2897 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2898 ) 2899 .groupby("valid_group") 2900 .head(top_n) 2901 ) 2902 2903 data = data.loc[list(set(df_tmp["feature"]))] 2904 2905 if len(set(metadata["sets"])) > 1: 2906 data.columns = data.columns + " # " + [x for x in metadata["sets"]] 2907 else: 2908 data = data.copy() 2909 2910 scaler = StandardScaler() 2911 2912 scaled_data = scaler.fit_transform(data) 2913 2914 scaled_df = pd.DataFrame(scaled_data, columns=data.columns) 2915 2916 cor = scaled_df.corr(method=method) 2917 cor_df = cor.stack().reset_index() 2918 cor_df.columns = ["cell1", "cell2", "correlation"] 2919 2920 distances = pdist(scaled_df.T, metric="euclidean") 2921 dist_mat = pd.DataFrame( 2922 squareform(distances), index=scaled_df.columns, columns=scaled_df.columns 2923 ) 2924 dist_df = 
dist_mat.stack().reset_index() 2925 dist_df.columns = ["cell1", "cell2", "euclidean_dist"] 2926 2927 full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"]) 2928 2929 full = full[full["cell1"] != full["cell2"]] 2930 full = full.reset_index(drop=True) 2931 2932 self.similarity = full
Estimate pairwise similarity and Euclidean distance between aggregated samples.
Parameters
method : str, default 'pearson' Correlation method to use (passed to pandas.DataFrame.corr()).
p_val : float, default 0.05
Adjusted p-value cutoff used to select marker features from self.var_data.
top_n : int, default 25 Number of top features per valid group to include.
Update
Computes a combined table with per-pair correlation and euclidean distance
and stores it in self.similarity.
2934 def similarity_plot( 2935 self, 2936 split_sets=True, 2937 set_info: bool = True, 2938 cmap="seismic", 2939 width=12, 2940 height=10, 2941 ): 2942 """ 2943 Visualize pairwise similarity as a scatter plot. 2944 2945 Parameters 2946 ---------- 2947 split_sets : bool, default True 2948 If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs. 2949 2950 set_info : bool, default True 2951 If True, keep the ' # set' annotation in labels; otherwise strip it. 2952 2953 cmap : str, default 'seismic' 2954 Color map for correlation (hue). 2955 2956 width : int, default 12 2957 Figure width. 2958 2959 height : int, default 10 2960 Figure height. 2961 2962 Returns 2963 ------- 2964 matplotlib.figure.Figure 2965 2966 Raises 2967 ------ 2968 ValueError 2969 If `self.similarity` is None. 2970 2971 Notes 2972 ----- 2973 The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs. 2974 """ 2975 2976 if self.similarity is None: 2977 raise ValueError( 2978 "Similarity data is missing. Please calculate similarity using self.estimating_similarity." 
2979 ) 2980 2981 similarity_data = self.similarity 2982 2983 if " # " in similarity_data["cell1"][0]: 2984 similarity_data["set1"] = [ 2985 re.sub(".*# ", "", x) for x in similarity_data["cell1"] 2986 ] 2987 similarity_data["set2"] = [ 2988 re.sub(".*# ", "", x) for x in similarity_data["cell2"] 2989 ] 2990 2991 if split_sets and " # " in similarity_data["cell1"][0]: 2992 sets = list( 2993 set(list(similarity_data["set1"]) + list(similarity_data["set2"])) 2994 ) 2995 2996 mm = math.ceil(len(sets) / 2) 2997 2998 x_s = sets[0:mm] 2999 y_s = sets[mm : len(sets)] 3000 3001 similarity_data = similarity_data[similarity_data["set1"].isin(x_s)] 3002 similarity_data = similarity_data[similarity_data["set2"].isin(y_s)] 3003 3004 similarity_data = similarity_data.sort_values(["set1", "set2"]) 3005 3006 if set_info is False and " # " in similarity_data["cell1"][0]: 3007 similarity_data["cell1"] = [ 3008 re.sub(" #.*", "", x) for x in similarity_data["cell1"] 3009 ] 3010 similarity_data["cell2"] = [ 3011 re.sub(" #.*", "", x) for x in similarity_data["cell2"] 3012 ] 3013 3014 similarity_data["-euclidean_zscore"] = -zscore( 3015 similarity_data["euclidean_dist"] 3016 ) 3017 3018 similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0] 3019 3020 fig = plt.figure(figsize=(width, height)) 3021 sns.scatterplot( 3022 data=similarity_data, 3023 x="cell1", 3024 y="cell2", 3025 hue="correlation", 3026 size="-euclidean_zscore", 3027 sizes=(1, 100), 3028 palette=cmap, 3029 alpha=1, 3030 edgecolor="black", 3031 ) 3032 3033 plt.xticks(rotation=90) 3034 plt.yticks(rotation=0) 3035 plt.xlabel("Cell 1") 3036 plt.ylabel("Cell 2") 3037 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") 3038 3039 plt.grid(True, alpha=0.6) 3040 3041 plt.tight_layout() 3042 3043 return fig
Visualize pairwise similarity as a scatter plot.
Parameters
split_sets : bool, default True If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs.
set_info : bool, default True If True, keep the ' # set' annotation in labels; otherwise strip it.
cmap : str, default 'seismic' Color map for correlation (hue).
width : int, default 12 Figure width.
height : int, default 10 Figure height.
Returns
matplotlib.figure.Figure
Raises
ValueError
If self.similarity is None.
Notes
The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs.
3045 def spatial_similarity( 3046 self, 3047 set_info: bool = True, 3048 bandwidth=1, 3049 n_neighbors=5, 3050 min_dist=0.1, 3051 legend_split=2, 3052 point_size=100, 3053 spread=1.0, 3054 set_op_mix_ratio=1.0, 3055 local_connectivity=1, 3056 repulsion_strength=1.0, 3057 negative_sample_rate=5, 3058 threshold=0.1, 3059 width=12, 3060 height=10, 3061 ): 3062 """ 3063 Create a spatial UMAP-like visualization of similarity relationships between samples. 3064 3065 Parameters 3066 ---------- 3067 set_info : bool, default True 3068 If True, retain set information in labels. 3069 3070 bandwidth : float, default 1 3071 Bandwidth used by MeanShift for clustering polygons. 3072 3073 point_size : float, default 100 3074 Size of scatter points. 3075 3076 legend_split : int, default 2 3077 Number of columns in legend. 3078 3079 n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP. 3080 3081 threshold : float, default 0.1 3082 Minimum text distance for label adjustment to avoid overlap. 3083 3084 width : int, default 12 3085 Figure width. 3086 3087 height : int, default 10 3088 Figure height. 3089 3090 Returns 3091 ------- 3092 matplotlib.figure.Figure 3093 3094 Raises 3095 ------ 3096 ValueError 3097 If `self.similarity` is None. 3098 3099 Notes 3100 ----- 3101 Builds a precomputed distance matrix combining correlation and euclidean distance, 3102 runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull) 3103 and arrows to indicate nearest neighbors (minimal combined distance). 3104 """ 3105 3106 if self.similarity is None: 3107 raise ValueError( 3108 "Similarity data is missing. Please calculate similarity using self.estimating_similarity." 
3109 ) 3110 3111 similarity_data = self.similarity 3112 3113 sim = similarity_data["correlation"] 3114 sim_scaled = (sim - sim.min()) / (sim.max() - sim.min()) 3115 eu_dist = similarity_data["euclidean_dist"] 3116 eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min()) 3117 3118 similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled 3119 3120 # for nn target 3121 arrow_df = similarity_data.copy() 3122 arrow_df = similarity_data.loc[ 3123 similarity_data.groupby("cell1")["combo_dist"].idxmin() 3124 ] 3125 3126 cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"])) 3127 combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float) 3128 3129 for _, row in similarity_data.iterrows(): 3130 combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"] 3131 combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"] 3132 3133 umap_model = umap.UMAP( 3134 n_components=2, 3135 metric="precomputed", 3136 n_neighbors=n_neighbors, 3137 min_dist=min_dist, 3138 spread=spread, 3139 set_op_mix_ratio=set_op_mix_ratio, 3140 local_connectivity=set_op_mix_ratio, 3141 repulsion_strength=repulsion_strength, 3142 negative_sample_rate=negative_sample_rate, 3143 transform_seed=42, 3144 init="spectral", 3145 random_state=42, 3146 verbose=True, 3147 ) 3148 3149 coords = umap_model.fit_transform(combo_matrix.values) 3150 cell_names = list(combo_matrix.index) 3151 num_cells = len(cell_names) 3152 palette = sns.color_palette("tab20c", num_cells) 3153 3154 if "#" in cell_names[0]: 3155 avsets = set( 3156 [re.sub(".*# ", "", x) for x in similarity_data["cell1"]] 3157 + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]] 3158 ) 3159 num_sets = len(avsets) 3160 color_indices = [i * len(palette) // num_sets for i in range(num_sets)] 3161 color_mapping_sets = { 3162 set_name: palette[i] for i, set_name in zip(color_indices, avsets) 3163 } 3164 color_mapping = { 3165 name: color_mapping_sets[re.sub(".*# ", "", name)] 3166 
for i, name in enumerate(cell_names) 3167 } 3168 else: 3169 color_mapping = {name: palette[i] for i, name in enumerate(cell_names)} 3170 3171 meanshift = MeanShift(bandwidth=bandwidth) 3172 labels = meanshift.fit_predict(coords) 3173 3174 fig = plt.figure(figsize=(width, height)) 3175 ax = plt.gca() 3176 3177 unique_labels = set(labels) 3178 cluster_palette = sns.color_palette("hls", len(unique_labels)) 3179 3180 for label in unique_labels: 3181 if label == -1: 3182 continue 3183 cluster_coords = coords[labels == label] 3184 if len(cluster_coords) < 3: 3185 continue 3186 3187 hull = ConvexHull(cluster_coords) 3188 hull_points = cluster_coords[hull.vertices] 3189 3190 centroid = np.mean(hull_points, axis=0) 3191 expanded = hull_points + 0.05 * (hull_points - centroid) 3192 3193 poly = Polygon( 3194 expanded, 3195 closed=True, 3196 facecolor=cluster_palette[label], 3197 edgecolor="none", 3198 alpha=0.2, 3199 zorder=1, 3200 ) 3201 ax.add_patch(poly) 3202 3203 texts = [] 3204 for i, (x, y) in enumerate(coords): 3205 plt.scatter( 3206 x, 3207 y, 3208 s=point_size, 3209 color=color_mapping[cell_names[i]], 3210 edgecolors="black", 3211 linewidths=0.5, 3212 zorder=2, 3213 ) 3214 texts.append( 3215 ax.text( 3216 x, y, str(i), ha="center", va="center", fontsize=8, color="black" 3217 ) 3218 ) 3219 3220 def dist(p1, p2): 3221 return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) 3222 3223 texts_to_adjust = [] 3224 for i, t1 in enumerate(texts): 3225 for j, t2 in enumerate(texts): 3226 if i >= j: 3227 continue 3228 d = dist( 3229 (t1.get_position()[0], t1.get_position()[1]), 3230 (t2.get_position()[0], t2.get_position()[1]), 3231 ) 3232 if d < threshold: 3233 if t1 not in texts_to_adjust: 3234 texts_to_adjust.append(t1) 3235 if t2 not in texts_to_adjust: 3236 texts_to_adjust.append(t2) 3237 3238 adjust_text( 3239 texts_to_adjust, 3240 expand_text=(1.0, 1.0), 3241 force_text=0.9, 3242 arrowprops=dict(arrowstyle="-", color="gray", lw=0.1), 3243 ax=ax, 3244 ) 3245 3246 for 
_, row in arrow_df.iterrows(): 3247 try: 3248 idx1 = cell_names.index(row["cell1"]) 3249 idx2 = cell_names.index(row["cell2"]) 3250 except ValueError: 3251 continue 3252 x1, y1 = coords[idx1] 3253 x2, y2 = coords[idx2] 3254 arrow = FancyArrowPatch( 3255 (x1, y1), 3256 (x2, y2), 3257 arrowstyle="->", 3258 color="gray", 3259 linewidth=1.5, 3260 alpha=0.5, 3261 mutation_scale=12, 3262 zorder=0, 3263 ) 3264 ax.add_patch(arrow) 3265 3266 if set_info is False and " # " in cell_names[0]: 3267 3268 legend_elements = [ 3269 Patch( 3270 facecolor=color_mapping[name], 3271 edgecolor="black", 3272 label=f"{i} – {re.sub(' #.*', '', name)}", 3273 ) 3274 for i, name in enumerate(cell_names) 3275 ] 3276 3277 else: 3278 3279 legend_elements = [ 3280 Patch( 3281 facecolor=color_mapping[name], 3282 edgecolor="black", 3283 label=f"{i} – {name}", 3284 ) 3285 for i, name in enumerate(cell_names) 3286 ] 3287 3288 plt.legend( 3289 handles=legend_elements, 3290 title="Cells", 3291 bbox_to_anchor=(1.05, 1), 3292 loc="upper left", 3293 ncol=legend_split, 3294 ) 3295 3296 plt.xlabel("UMAP 1") 3297 plt.ylabel("UMAP 2") 3298 plt.grid(False) 3299 plt.show() 3300 3301 return fig
Create a spatial UMAP-like visualization of similarity relationships between samples.
Parameters
set_info : bool, default True If True, retain set information in labels.
bandwidth : float, default 1 Bandwidth used by MeanShift for clustering polygons.
point_size : float, default 100 Size of scatter points.
legend_split : int, default 2 Number of columns in legend.
n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP.
threshold : float, default 0.1 Minimum text distance for label adjustment to avoid overlap.
width : int, default 12 Figure width.
height : int, default 10 Figure height.
Returns
matplotlib.figure.Figure
Raises
ValueError
If self.similarity is None.
Notes
Builds a precomputed distance matrix combining correlation and euclidean distance, runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull) and arrows to indicate nearest neighbors (minimal combined distance).
3305 def subcluster_prepare(self, features: list, cluster: str): 3306 """ 3307 Prepare a `Clustering` object for subcluster analysis on a selected parent cluster. 3308 3309 Parameters 3310 ---------- 3311 features : list 3312 Features to include for subcluster analysis. 3313 3314 cluster : str 3315 Parent cluster name (used to select matching cells). 3316 3317 Update 3318 ------ 3319 Initializes `self.subclusters_` as a new `Clustering` instance containing the 3320 reduced data for the given cluster and stores `current_features` and `current_cluster`. 3321 """ 3322 3323 dat = self.normalized_data 3324 dat.columns = list(self.input_metadata["cell_names"]) 3325 3326 dat = reduce_data(self.normalized_data, features=features, names=[cluster]) 3327 3328 self.subclusters_ = Clustering(data=dat, metadata=None) 3329 3330 self.subclusters_.current_features = features 3331 self.subclusters_.current_cluster = cluster
Prepare a Clustering object for subcluster analysis on a selected parent cluster.
Parameters
features : list Features to include for subcluster analysis.
cluster : str Parent cluster name (used to select matching cells).
Update
Initializes self.subclusters_ as a new Clustering instance containing the
reduced data for the given cluster and stores current_features and current_cluster.
3333 def define_subclusters( 3334 self, 3335 umap_num: int = 2, 3336 eps: float = 0.5, 3337 min_samples: int = 10, 3338 n_neighbors: int = 5, 3339 min_dist: float = 0.1, 3340 spread: float = 1.0, 3341 set_op_mix_ratio: float = 1.0, 3342 local_connectivity: int = 1, 3343 repulsion_strength: float = 1.0, 3344 negative_sample_rate: int = 5, 3345 width=8, 3346 height=6, 3347 ): 3348 """ 3349 Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset. 3350 3351 Parameters 3352 ---------- 3353 umap_num : int, default 2 3354 Number of UMAP dimensions to compute. 3355 3356 eps : float, default 0.5 3357 DBSCAN eps parameter. 3358 3359 min_samples : int, default 10 3360 DBSCAN min_samples parameter. 3361 3362 n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height : 3363 Additional parameters passed to UMAP / plotting / MeanShift as appropriate. 3364 3365 Update 3366 ------- 3367 Stores cluster labels in `self.subclusters_.subclusters`. 3368 3369 Raises 3370 ------ 3371 RuntimeError 3372 If `self.subclusters_` has not been prepared. 3373 """ 3374 3375 if self.subclusters_ is None: 3376 raise RuntimeError( 3377 "Nothing to return. 'self.subcluster_prepare' was not conducted!" 
3378 ) 3379 3380 self.subclusters_.perform_UMAP( 3381 factorize=False, 3382 umap_num=umap_num, 3383 pc_num=0, 3384 harmonized=False, 3385 n_neighbors=n_neighbors, 3386 min_dist=min_dist, 3387 spread=spread, 3388 set_op_mix_ratio=set_op_mix_ratio, 3389 local_connectivity=local_connectivity, 3390 repulsion_strength=repulsion_strength, 3391 negative_sample_rate=negative_sample_rate, 3392 width=width, 3393 height=height, 3394 ) 3395 3396 fig = self.subclusters_.find_clusters_UMAP( 3397 umap_n=umap_num, 3398 eps=eps, 3399 min_samples=min_samples, 3400 width=width, 3401 height=height, 3402 ) 3403 3404 clusters = self.subclusters_.return_clusters(clusters="umap") 3405 3406 self.subclusters_.subclusters = [str(x) for x in list(clusters)] 3407 3408 return fig
Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.
Parameters
umap_num : int, default 2 Number of UMAP dimensions to compute.
eps : float, default 0.5 DBSCAN eps parameter.
min_samples : int, default 10 DBSCAN min_samples parameter.
n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height : Additional parameters passed to UMAP / plotting / MeanShift as appropriate.
Update
Stores cluster labels in self.subclusters_.subclusters.
Raises
RuntimeError
If self.subclusters_ has not been prepared.
3410 def subcluster_features_scatter( 3411 self, 3412 colors="viridis", 3413 hclust="complete", 3414 scale=False, 3415 img_width=3, 3416 img_high=5, 3417 label_size=6, 3418 size_scale=70, 3419 y_lab="Genes", 3420 legend_lab="normalized", 3421 bbox_to_anchor_scale: int = 25, 3422 bbox_to_anchor_perc: tuple = (0.91, 0.63), 3423 ): 3424 """ 3425 Create a features-scatter visualization for the subclusters (averaged and occurrence). 3426 3427 Parameters 3428 ---------- 3429 colors : str, default 'viridis' 3430 Colormap name passed to `features_scatter`. 3431 3432 hclust : str or None 3433 Hierarchical clustering linkage to order rows/columns. 3434 3435 scale: bool, default False 3436 If True, expression data will be scaled (0–1) across the rows (features). 3437 3438 img_width, img_high : float 3439 Figure size. 3440 3441 label_size : int 3442 Font size for labels. 3443 3444 size_scale : int 3445 Bubble size scaling. 3446 3447 y_lab : str 3448 X axis label. 3449 3450 legend_lab : str 3451 Colorbar label. 3452 3453 bbox_to_anchor_scale : int, default=25 3454 Vertical scale (percentage) for positioning the colorbar. 3455 3456 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3457 Anchor position for the size legend (percent bubble legend). 3458 3459 Returns 3460 ------- 3461 matplotlib.figure.Figure 3462 3463 Raises 3464 ------ 3465 RuntimeError 3466 If subcluster preparation/definition has not been run. 3467 """ 3468 3469 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3470 raise RuntimeError( 3471 "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!" 
3472 ) 3473 3474 dat = self.normalized_data 3475 dat.columns = list(self.input_metadata["cell_names"]) 3476 3477 dat = reduce_data( 3478 self.normalized_data, 3479 features=self.subclusters_.current_features, 3480 names=[self.subclusters_.current_cluster], 3481 ) 3482 3483 dat.columns = self.subclusters_.subclusters 3484 3485 avg = average(dat) 3486 occ = occurrence(dat) 3487 3488 scatter = features_scatter( 3489 expression_data=avg, 3490 occurence_data=occ, 3491 features=None, 3492 scale=scale, 3493 metadata_list=None, 3494 colors=colors, 3495 hclust=hclust, 3496 img_width=img_width, 3497 img_high=img_high, 3498 label_size=label_size, 3499 size_scale=size_scale, 3500 y_lab=y_lab, 3501 legend_lab=legend_lab, 3502 bbox_to_anchor_scale=bbox_to_anchor_scale, 3503 bbox_to_anchor_perc=bbox_to_anchor_perc, 3504 ) 3505 3506 return scatter
Create a features-scatter visualization for the subclusters (averaged and occurrence).
Parameters
colors : str, default 'viridis'
Colormap name passed to features_scatter.
hclust : str or None Hierarchical clustering linkage to order rows/columns.
scale: bool, default False If True, expression data will be scaled (0–1) across the rows (features).
img_width, img_high : float Figure size.
label_size : int Font size for labels.
size_scale : int Bubble size scaling.
y_lab : str X axis label.
legend_lab : str Colorbar label.
bbox_to_anchor_scale : int, default=25 Vertical scale (percentage) for positioning the colorbar.
bbox_to_anchor_perc : tuple, default=(0.91, 0.63) Anchor position for the size legend (percent bubble legend).
Returns
matplotlib.figure.Figure
Raises
RuntimeError If subcluster preparation/definition has not been run.
3508 def subcluster_DEG_scatter( 3509 self, 3510 top_n=3, 3511 min_exp=0, 3512 min_pct=0.25, 3513 p_val=0.05, 3514 colors="viridis", 3515 hclust="complete", 3516 scale=False, 3517 img_width=3, 3518 img_high=5, 3519 label_size=6, 3520 size_scale=70, 3521 y_lab="Genes", 3522 legend_lab="normalized", 3523 bbox_to_anchor_scale: int = 25, 3524 bbox_to_anchor_perc: tuple = (0.91, 0.63), 3525 n_proc=10, 3526 ): 3527 """ 3528 Plot top differential features (DEGs) for subclusters as a features-scatter. 3529 3530 Parameters 3531 ---------- 3532 top_n : int, default 3 3533 Number of top features per subcluster to show. 3534 3535 min_exp : float, default 0 3536 Minimum expression threshold passed to `statistic`. 3537 3538 min_pct : float, default 0.25 3539 Minimum percent expressed in target group. 3540 3541 p_val: float, default 0.05 3542 Maximum p-value for visualizing features. 3543 3544 n_proc : int, default 10 3545 Parallel jobs used for DEG calculation. 3546 3547 scale: bool, default False 3548 If True, expression_data will be scaled (0–1) across the rows (features). 3549 3550 colors : str, default='viridis' 3551 Colormap for expression values. 3552 3553 hclust : str or None, default='complete' 3554 Linkage method for hierarchical clustering. If None, no clustering 3555 is performed. 3556 3557 img_width : int or float, default=8 3558 Width of the plot in inches. 3559 3560 img_high : int or float, default=5 3561 Height of the plot in inches. 3562 3563 label_size : int, default=10 3564 Font size for axis labels and ticks. 3565 3566 size_scale : int or float, default=100 3567 Scaling factor for bubble sizes. 3568 3569 y_lab : str, default='Genes' 3570 Label for the x-axis. 3571 3572 legend_lab : str, default='normalized' 3573 Label for the colorbar legend. 3574 3575 bbox_to_anchor_scale : int, default=25 3576 Vertical scale (percentage) for positioning the colorbar. 
3577 3578 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3579 Anchor position for the size legend (percent bubble legend). 3580 3581 Returns 3582 ------- 3583 matplotlib.figure.Figure 3584 3585 Raises 3586 ------ 3587 RuntimeError 3588 If subcluster preparation/definition has not been run. 3589 3590 Notes 3591 ----- 3592 Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters 3593 by p-value and effect-size, selects top features per valid group and plots them. 3594 """ 3595 3596 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3597 raise RuntimeError( 3598 "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!" 3599 ) 3600 3601 dat = self.normalized_data 3602 dat.columns = list(self.input_metadata["cell_names"]) 3603 3604 dat = reduce_data( 3605 self.normalized_data, names=[self.subclusters_.current_cluster] 3606 ) 3607 3608 dat.columns = self.subclusters_.subclusters 3609 3610 deg_stats = calc_DEG( 3611 dat, 3612 metadata_list=None, 3613 entities="All", 3614 sets=None, 3615 min_exp=min_exp, 3616 min_pct=min_pct, 3617 n_proc=n_proc, 3618 ) 3619 3620 deg_stats = deg_stats[deg_stats["p_val"] <= p_val] 3621 deg_stats = deg_stats[deg_stats["log(FC)"] > 0] 3622 3623 deg_stats = ( 3624 deg_stats.sort_values( 3625 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 3626 ) 3627 .groupby("valid_group") 3628 .head(top_n) 3629 ) 3630 3631 dat = reduce_data(dat, features=list(set(deg_stats["feature"]))) 3632 3633 avg = average(dat) 3634 occ = occurrence(dat) 3635 3636 scatter = features_scatter( 3637 expression_data=avg, 3638 occurence_data=occ, 3639 features=None, 3640 metadata_list=None, 3641 colors=colors, 3642 hclust=hclust, 3643 img_width=img_width, 3644 img_high=img_high, 3645 label_size=label_size, 3646 size_scale=size_scale, 3647 y_lab=y_lab, 3648 legend_lab=legend_lab, 3649 bbox_to_anchor_scale=bbox_to_anchor_scale, 3650 bbox_to_anchor_perc=bbox_to_anchor_perc, 3651 ) 
3652 3653 return scatter
Plot top differential features (DEGs) for subclusters as a features-scatter.
Parameters
top_n : int, default 3 Number of top features per subcluster to show.
min_exp : float, default 0
Minimum expression threshold passed to statistic.
min_pct : float, default 0.25 Minimum percent expressed in target group.
p_val: float, default 0.05 Maximum p-value for visualizing features.
n_proc : int, default 10 Parallel jobs used for DEG calculation.
scale: bool, default False If True, expression_data will be scaled (0–1) across the rows (features).
colors : str, default='viridis' Colormap for expression values.
hclust : str or None, default='complete' Linkage method for hierarchical clustering. If None, no clustering is performed.
img_width : int or float, default=8 Width of the plot in inches.
img_high : int or float, default=5 Height of the plot in inches.
label_size : int, default=10 Font size for axis labels and ticks.
size_scale : int or float, default=100 Scaling factor for bubble sizes.
y_lab : str, default='Genes' Label for the x-axis.
legend_lab : str, default='normalized' Label for the colorbar legend.
bbox_to_anchor_scale : int, default=25 Vertical scale (percentage) for positioning the colorbar.
bbox_to_anchor_perc : tuple, default=(0.91, 0.63) Anchor position for the size legend (percent bubble legend).
Returns
matplotlib.figure.Figure
Raises
RuntimeError If subcluster preparation/definition has not been run.
Notes
Internally calls calc_DEG (or equivalent) to obtain statistics, filters
by p-value and effect-size, selects top features per valid group and plots them.
3655 def accept_subclusters(self): 3656 """ 3657 Commit subcluster labels into the main `input_metadata` by renaming cell names. 3658 3659 The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']` 3660 with the expanded names that include subcluster suffixes (via `add_subnames`), 3661 then clears `self.subclusters_`. 3662 3663 Update 3664 ------ 3665 Modifies `self.input_metadata['cell_names']`. 3666 3667 Resets `self.subclusters_` to None. 3668 3669 Raises 3670 ------ 3671 RuntimeError 3672 If `self.subclusters_` is not defined or subclusters were not computed. 3673 """ 3674 3675 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3676 raise RuntimeError( 3677 "Nothing to return. 'self.subcluster_prepare' -> 'self.define_subclusters' pip was not conducted!" 3678 ) 3679 3680 new_meta = add_subnames( 3681 list(self.input_metadata["cell_names"]), 3682 parent_name=self.subclusters_.current_cluster, 3683 new_clusters=self.subclusters_.subclusters, 3684 ) 3685 3686 self.input_metadata["cell_names"] = new_meta 3687 3688 self.subclusters_ = None
Commit subcluster labels into the main input_metadata by renaming cell names.
The method replaces occurrences of the parent cluster name in self.input_metadata['cell_names']
with the expanded names that include subcluster suffixes (via add_subnames),
then clears self.subclusters_.
Update
Modifies self.input_metadata['cell_names'].
Resets self.subclusters_ to None.
Raises
RuntimeError
If self.subclusters_ is not defined or subclusters were not computed.
3690 def scatter_plot( 3691 self, 3692 names: list | None = None, 3693 features: list | None = None, 3694 name_slot: str = "cell_names", 3695 scale=True, 3696 colors="viridis", 3697 hclust=None, 3698 img_width=15, 3699 img_high=1, 3700 label_size=10, 3701 size_scale=200, 3702 y_lab="Genes", 3703 legend_lab="log(CPM + 1)", 3704 set_box_size: float | int = 5, 3705 set_box_high: float | int = 5, 3706 bbox_to_anchor_scale=25, 3707 bbox_to_anchor_perc=(0.90, -0.24), 3708 bbox_to_anchor_group=(1.01, 0.4), 3709 ): 3710 """ 3711 Create a bubble scatter plot of selected features across samples inside project. 3712 3713 Each point represents a feature-sample pair, where the color encodes the 3714 expression value and the size encodes occurrence or relative abundance. 3715 Optionally, hierarchical clustering can be applied to order rows and columns. 3716 3717 Parameters 3718 ---------- 3719 names : list, str, or None 3720 Names of samples to include. If None, all samples are considered. 3721 3722 features : list, str, or None 3723 Names of features to include. If None, all features are considered. 3724 3725 name_slot : str 3726 Column in metadata to use as sample names. 3727 3728 scale: bool, default False 3729 If True, expression_data will be scaled (0–1) across the rows (features). 3730 3731 colors : str, default='viridis' 3732 Colormap for expression values. 3733 3734 hclust : str or None, default='complete' 3735 Linkage method for hierarchical clustering. If None, no clustering 3736 is performed. 3737 3738 img_width : int or float, default=8 3739 Width of the plot in inches. 3740 3741 img_high : int or float, default=5 3742 Height of the plot in inches. 3743 3744 label_size : int, default=10 3745 Font size for axis labels and ticks. 3746 3747 size_scale : int or float, default=100 3748 Scaling factor for bubble sizes. 3749 3750 y_lab : str, default='Genes' 3751 Label for the x-axis. 3752 3753 legend_lab : str, default='log(CPM + 1)' 3754 Label for the colorbar legend. 
3755 3756 bbox_to_anchor_scale : int, default=25 3757 Vertical scale (percentage) for positioning the colorbar. 3758 3759 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3760 Anchor position for the size legend (percent bubble legend). 3761 3762 bbox_to_anchor_group : tuple, default=(1.01, 0.4) 3763 Anchor position for the group legend. 3764 3765 Returns 3766 ------- 3767 matplotlib.figure.Figure 3768 The generated scatter plot figure. 3769 3770 Notes 3771 ----- 3772 Colors represent expression values normalized to the colormap. 3773 """ 3774 3775 prtd, met = self.get_partial_data( 3776 names=names, features=features, name_slot=name_slot, inc_metadata=True 3777 ) 3778 3779 prtd.columns = prtd.columns + "#" + met["sets"] 3780 3781 prtd_avg = average(prtd) 3782 3783 meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns] 3784 3785 prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns] 3786 3787 prtd_occ = occurrence(prtd) 3788 3789 prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns] 3790 3791 fig_scatter = features_scatter( 3792 expression_data=prtd_avg, 3793 occurence_data=prtd_occ, 3794 scale=scale, 3795 features=None, 3796 metadata_list=meta_sets, 3797 colors=colors, 3798 hclust=hclust, 3799 img_width=img_width, 3800 img_high=img_high, 3801 label_size=label_size, 3802 size_scale=size_scale, 3803 y_lab=y_lab, 3804 legend_lab=legend_lab, 3805 set_box_size=set_box_size, 3806 set_box_high=set_box_high, 3807 bbox_to_anchor_scale=bbox_to_anchor_scale, 3808 bbox_to_anchor_perc=bbox_to_anchor_perc, 3809 bbox_to_anchor_group=bbox_to_anchor_group, 3810 ) 3811 3812 return fig_scatter
Create a bubble scatter plot of selected features across samples inside project.
Each point represents a feature-sample pair, where the color encodes the expression value and the size encodes occurrence or relative abundance. Optionally, hierarchical clustering can be applied to order rows and columns.
Parameters
names : list, str, or None Names of samples to include. If None, all samples are considered.
features : list, str, or None Names of features to include. If None, all features are considered.
name_slot : str Column in metadata to use as sample names.
scale: bool, default False If True, expression_data will be scaled (0–1) across the rows (features).
colors : str, default='viridis' Colormap for expression values.
hclust : str or None, default='complete' Linkage method for hierarchical clustering. If None, no clustering is performed.
img_width : int or float, default=8 Width of the plot in inches.
img_high : int or float, default=5 Height of the plot in inches.
label_size : int, default=10 Font size for axis labels and ticks.
size_scale : int or float, default=100 Scaling factor for bubble sizes.
y_lab : str, default='Genes' Label for the x-axis.
legend_lab : str, default='log(CPM + 1)' Label for the colorbar legend.
bbox_to_anchor_scale : int, default=25 Vertical scale (percentage) for positioning the colorbar.
bbox_to_anchor_perc : tuple, default=(0.91, 0.63) Anchor position for the size legend (percent bubble legend).
bbox_to_anchor_group : tuple, default=(1.01, 0.4) Anchor position for the group legend.
Returns
matplotlib.figure.Figure The generated scatter plot figure.
Notes
Colors represent expression values normalized to the colormap.
3814 def data_composition( 3815 self, 3816 features_count: list | None, 3817 name_slot: str = "cell_names", 3818 set_sep: bool = True, 3819 ): 3820 """ 3821 Compute composition of cell types in data set. 3822 3823 This function counts the occurrences of specific cells (e.g., cell types, subtypes) 3824 within metadata entries, calculates their relative percentages, and stores 3825 the results in `self.composition_data`. 3826 3827 Parameters 3828 ---------- 3829 features_count : list or None 3830 List of features (part or full names) to be counted. 3831 If None, all unique elements from the specified `name_slot` metadata field are used. 3832 3833 name_slot : str, default 'cell_names' 3834 Metadata field containing sample identifiers or labels. 3835 3836 set_sep : bool, default True 3837 If True and multiple sets are present in metadata, compute composition 3838 separately for each set. 3839 3840 Update 3841 ------- 3842 Stores results in `self.composition_data` as a pandas DataFrame with: 3843 - 'name': feature name 3844 - 'n': number of occurrences 3845 - 'pct': percentage of occurrences 3846 - 'set' (if applicable): dataset identifier 3847 """ 3848 3849 validated_list = list(self.input_metadata[name_slot]) 3850 sets = list(self.input_metadata["sets"]) 3851 3852 if features_count is None: 3853 features_count = list(set(self.input_metadata[name_slot])) 3854 3855 if set_sep and len(set(sets)) > 1: 3856 3857 final_res = pd.DataFrame() 3858 3859 for s in set(sets): 3860 print(s) 3861 3862 mask = [True if s == x else False for x in sets] 3863 3864 tmp_val_list = np.array(validated_list) 3865 3866 tmp_val_list = list(tmp_val_list[mask]) 3867 3868 res_dict = {"name": [], "n": [], "set": []} 3869 3870 for f in tqdm(features_count): 3871 res_dict["n"].append( 3872 sum(1 for element in tmp_val_list if f in element) 3873 ) 3874 res_dict["name"].append(f) 3875 res_dict["set"].append(s) 3876 res = pd.DataFrame(res_dict) 3877 res["pct"] = res["n"] / sum(res["n"]) * 100 3878 
res["pct"] = res["pct"].round(2) 3879 3880 final_res = pd.concat([final_res, res]) 3881 3882 res = final_res.sort_values(["set", "pct"], ascending=[True, False]) 3883 3884 else: 3885 3886 res_dict = {"name": [], "n": []} 3887 3888 for f in tqdm(features_count): 3889 res_dict["n"].append( 3890 sum(1 for element in validated_list if f in element) 3891 ) 3892 res_dict["name"].append(f) 3893 3894 res = pd.DataFrame(res_dict) 3895 res["pct"] = res["n"] / sum(res["n"]) * 100 3896 res["pct"] = res["pct"].round(2) 3897 3898 res = res.sort_values("pct", ascending=False) 3899 3900 self.composition_data = res
Compute composition of cell types in data set.
This function counts the occurrences of specific cells (e.g., cell types, subtypes)
within metadata entries, calculates their relative percentages, and stores
the results in self.composition_data.
Parameters
features_count : list or None
List of features (part or full names) to be counted.
If None, all unique elements from the specified name_slot metadata field are used.
name_slot : str, default 'cell_names'
    Metadata field containing sample identifiers or labels.

set_sep : bool, default True
    If True and multiple sets are present in metadata, compute composition
    separately for each set.
Update
Stores results in self.composition_data as a pandas DataFrame with:
- 'name': feature name
- 'n': number of occurrences
- 'pct': percentage of occurrences
- 'set' (if applicable): dataset identifier
3902 def composition_pie( 3903 self, 3904 width=6, 3905 height=6, 3906 font_size=15, 3907 cmap: str = "tab20", 3908 legend_split_col: int = 1, 3909 offset_labels: float | int = 0.5, 3910 legend_bbox: tuple = (1.15, 0.95), 3911 ): 3912 """ 3913 Visualize the composition of cell lineages using pie charts. 3914 3915 Generates pie charts showing the relative proportions of features stored 3916 in `self.composition_data`. If multiple sets are present, a separate 3917 chart is drawn for each set. 3918 3919 Parameters 3920 ---------- 3921 width : int, default 6 3922 Width of the figure. 3923 3924 height : int, default 6 3925 Height of the figure (applied per set if multiple sets are plotted). 3926 3927 font_size : int, default 15 3928 Font size for labels and annotations. 3929 3930 cmap : str, default 'tab20' 3931 Colormap used for pie slices. 3932 3933 legend_split_col : int, default 1 3934 Number of columns in the legend. 3935 3936 offset_labels : float or int, default 0.5 3937 Spacing offset for label placement relative to pie slices. 3938 3939 legend_bbox : tuple, default (1.15, 0.95) 3940 Bounding box anchor position for the legend. 3941 3942 Returns 3943 ------- 3944 matplotlib.figure.Figure 3945 Pie chart visualization of composition data. 
3946 """ 3947 3948 df = self.composition_data 3949 3950 if "set" in df.columns and len(set(df["set"])) > 1: 3951 3952 sets = list(set(df["set"])) 3953 fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets))) 3954 3955 all_wedges = [] 3956 cmap = plt.get_cmap(cmap) 3957 3958 set_nam = len(set(df["name"])) 3959 3960 legend_labels = list(set(df["name"])) 3961 3962 colors = [cmap(i / set_nam) for i in range(set_nam)] 3963 3964 cmap_dict = dict(zip(legend_labels, colors)) 3965 3966 for idx, s in enumerate(sets): 3967 ax = axes[idx] 3968 tmp_df = df[df["set"] == s].reset_index(drop=True) 3969 3970 labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()] 3971 3972 wedges, _ = ax.pie( 3973 tmp_df["n"], 3974 startangle=90, 3975 labeldistance=1.05, 3976 colors=[cmap_dict[x] for x in tmp_df["name"]], 3977 wedgeprops={"linewidth": 0.5, "edgecolor": "black"}, 3978 ) 3979 3980 all_wedges.extend(wedges) 3981 3982 kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center") 3983 n = 0 3984 for i, p in enumerate(wedges): 3985 ang = (p.theta2 - p.theta1) / 2.0 + p.theta1 3986 y = np.sin(np.deg2rad(ang)) 3987 x = np.cos(np.deg2rad(ang)) 3988 horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))] 3989 connectionstyle = f"angle,angleA=0,angleB={ang}" 3990 kw["arrowprops"].update({"connectionstyle": connectionstyle}) 3991 if len(labels[i]) > 0: 3992 n += offset_labels 3993 ax.annotate( 3994 labels[i], 3995 xy=(x, y), 3996 xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)), 3997 horizontalalignment=horizontalalignment, 3998 fontsize=font_size, 3999 weight="bold", 4000 **kw, 4001 ) 4002 4003 circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black") 4004 ax.add_artist(circle2) 4005 4006 ax.text( 4007 0, 4008 0, 4009 f"{s}", 4010 ha="center", 4011 va="center", 4012 fontsize=font_size, 4013 weight="bold", 4014 ) 4015 4016 legend_handles = [ 4017 Patch(facecolor=cmap_dict[label], edgecolor="black", label=label) 4018 for label in legend_labels 
4019 ] 4020 4021 fig.legend( 4022 handles=legend_handles, 4023 loc="center right", 4024 bbox_to_anchor=legend_bbox, 4025 ncol=legend_split_col, 4026 title="", 4027 ) 4028 4029 plt.tight_layout() 4030 plt.show() 4031 4032 else: 4033 4034 labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()] 4035 4036 legend_labels = [f"{row['name']}" for _, row in df.iterrows()] 4037 4038 cmap = plt.get_cmap(cmap) 4039 colors = [cmap(i / len(df)) for i in range(len(df))] 4040 4041 fig, ax = plt.subplots( 4042 figsize=(width, height), subplot_kw=dict(aspect="equal") 4043 ) 4044 4045 wedges, _ = ax.pie( 4046 df["n"], 4047 startangle=90, 4048 labeldistance=1.05, 4049 colors=colors, 4050 wedgeprops={"linewidth": 0.5, "edgecolor": "black"}, 4051 ) 4052 4053 kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center") 4054 n = 0 4055 for i, p in enumerate(wedges): 4056 ang = (p.theta2 - p.theta1) / 2.0 + p.theta1 4057 y = np.sin(np.deg2rad(ang)) 4058 x = np.cos(np.deg2rad(ang)) 4059 horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))] 4060 connectionstyle = "angle,angleA=0,angleB={}".format(ang) 4061 kw["arrowprops"].update({"connectionstyle": connectionstyle}) 4062 if len(labels[i]) > 0: 4063 n += offset_labels 4064 4065 ax.annotate( 4066 labels[i], 4067 xy=(x, y), 4068 xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)), 4069 horizontalalignment=horizontalalignment, 4070 fontsize=font_size, 4071 weight="bold", 4072 **kw, 4073 ) 4074 4075 circle2 = plt.Circle((0, 0), 0.6, color="white") 4076 circle2.set_edgecolor("black") 4077 4078 p = plt.gcf() 4079 p.gca().add_artist(circle2) 4080 4081 ax.legend( 4082 wedges, 4083 legend_labels, 4084 title="", 4085 loc="center left", 4086 bbox_to_anchor=legend_bbox, 4087 ncol=legend_split_col, 4088 ) 4089 4090 plt.show() 4091 4092 return fig
Visualize the composition of cell lineages using pie charts.
Generates pie charts showing the relative proportions of features stored
in self.composition_data. If multiple sets are present, a separate
chart is drawn for each set.
Parameters
width : int, default 6 Width of the figure.
height : int, default 6 Height of the figure (applied per set if multiple sets are plotted).
font_size : int, default 15 Font size for labels and annotations.
cmap : str, default 'tab20' Colormap used for pie slices.
legend_split_col : int, default 1 Number of columns in the legend.
offset_labels : float or int, default 0.5 Spacing offset for label placement relative to pie slices.
legend_bbox : tuple, default (1.15, 0.95) Bounding box anchor position for the legend.
Returns
-------
matplotlib.figure.Figure
    Pie chart visualization of composition data.
4094 def bar_composition( 4095 self, 4096 cmap="tab20b", 4097 width=2, 4098 height=6, 4099 font_size=15, 4100 legend_split_col: int = 1, 4101 legend_bbox: tuple = (1.3, 1), 4102 ): 4103 """ 4104 Visualize the composition of cell lineages using bar plots. 4105 4106 Produces bar plots showing the distribution of features stored in 4107 `self.composition_data`. If multiple sets are present, a separate 4108 bar is drawn for each set. Percentages are annotated alongside the bars. 4109 4110 Parameters 4111 ---------- 4112 cmap : str, default 'tab20b' 4113 Colormap used for stacked bars. 4114 4115 width : int, default 2 4116 Width of each subplot (per set). 4117 4118 height : int, default 6 4119 Height of the figure. 4120 4121 font_size : int, default 15 4122 Font size for labels and annotations. 4123 4124 legend_split_col : int, default 1 4125 Number of columns in the legend. 4126 4127 legend_bbox : tuple, default (1.3, 1) 4128 Bounding box anchor position for the legend. 4129 4130 Returns 4131 ------- 4132 matplotlib.figure.Figure 4133 Stacked bar plot visualization of composition data. 
4134 """ 4135 4136 df = self.composition_data 4137 df["num"] = range(1, len(df) + 1) 4138 4139 if "set" in df.columns and len(set(df["set"])) > 1: 4140 4141 sets = list(set(df["set"])) 4142 fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height)) 4143 4144 cmap = plt.get_cmap(cmap) 4145 4146 set_nam = len(set(df["name"])) 4147 4148 legend_labels = list(set(df["name"])) 4149 4150 colors = [cmap(i / set_nam) for i in range(set_nam)] 4151 4152 cmap_dict = dict(zip(legend_labels, colors)) 4153 4154 for idx, s in enumerate(sets): 4155 ax = axes[idx] 4156 4157 tmp_df = df[df["set"] == s].reset_index(drop=True) 4158 4159 values = tmp_df["n"].values 4160 total = sum(values) 4161 values = [v / total * 100 for v in values] 4162 values = [round(v, 2) for v in values] 4163 4164 idx_max = np.argmax(values) 4165 correction = 100 - sum(values) 4166 values[idx_max] += correction 4167 4168 names = tmp_df["name"].values 4169 perc = tmp_df["pct"].values 4170 nums = tmp_df["num"].values 4171 4172 bottom = 0 4173 centers = [] 4174 for name, num, val, color in zip(names, nums, values, colors): 4175 ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name) 4176 centers.append(bottom + val / 2) 4177 bottom += val 4178 4179 y_positions = np.linspace(centers[0], centers[-1], len(centers)) 4180 x_text = -0.8 4181 4182 for y_label, y_center, pct, num in zip( 4183 y_positions, centers, perc, nums 4184 ): 4185 ax.annotate( 4186 f"{pct:.1f}%", 4187 xy=(0, y_center), 4188 xycoords="data", 4189 xytext=(x_text, y_label), 4190 textcoords="data", 4191 ha="right", 4192 va="center", 4193 fontsize=font_size, 4194 arrowprops=dict( 4195 arrowstyle="->", 4196 lw=1, 4197 color="black", 4198 connectionstyle="angle3,angleA=0,angleB=90", 4199 ), 4200 ) 4201 4202 ax.set_ylim(0, 100) 4203 ax.set_xlabel(s, fontsize=font_size) 4204 ax.xaxis.label.set_rotation(30) 4205 4206 ax.set_xticks([]) 4207 ax.set_yticks([]) 4208 for spine in ax.spines.values(): 4209 spine.set_visible(False) 4210 
4211 legend_handles = [ 4212 Patch(facecolor=cmap_dict[label], edgecolor="black", label=label) 4213 for label in legend_labels 4214 ] 4215 4216 fig.legend( 4217 handles=legend_handles, 4218 loc="center right", 4219 bbox_to_anchor=legend_bbox, 4220 ncol=legend_split_col, 4221 title="", 4222 ) 4223 4224 plt.tight_layout() 4225 plt.show() 4226 4227 else: 4228 4229 cmap = plt.get_cmap(cmap) 4230 4231 colors = [cmap(i / len(df)) for i in range(len(df))] 4232 4233 fig, ax = plt.subplots(figsize=(width, height)) 4234 4235 values = df["n"].values 4236 names = df["name"].values 4237 perc = df["pct"].values 4238 nums = df["num"].values 4239 4240 bottom = 0 4241 centers = [] 4242 for name, num, val, color in zip(names, nums, values, colors): 4243 ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}") 4244 centers.append(bottom + val / 2) 4245 bottom += val 4246 4247 y_positions = np.linspace(centers[0], centers[-1], len(centers)) 4248 x_text = -0.8 4249 4250 for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums): 4251 ax.annotate( 4252 f"{num}) {pct}", 4253 xy=(0, y_center), 4254 xycoords="data", 4255 xytext=(x_text, y_label), 4256 textcoords="data", 4257 ha="right", 4258 va="center", 4259 fontsize=9, 4260 arrowprops=dict( 4261 arrowstyle="->", 4262 lw=1, 4263 color="black", 4264 connectionstyle="angle3,angleA=0,angleB=90", 4265 ), 4266 ) 4267 4268 ax.set_xticks([]) 4269 ax.set_yticks([]) 4270 for spine in ax.spines.values(): 4271 spine.set_visible(False) 4272 4273 ax.legend( 4274 title="Legend", 4275 bbox_to_anchor=legend_bbox, 4276 loc="upper left", 4277 ncol=legend_split_col, 4278 ) 4279 4280 plt.tight_layout() 4281 plt.show() 4282 4283 return fig
Visualize the composition of cell lineages using bar plots.
Produces bar plots showing the distribution of features stored in
self.composition_data. If multiple sets are present, a separate
bar is drawn for each set. Percentages are annotated alongside the bars.
Parameters
cmap : str, default 'tab20b' Colormap used for stacked bars.
width : int, default 2 Width of each subplot (per set).
height : int, default 6 Height of the figure.
font_size : int, default 15 Font size for labels and annotations.
legend_split_col : int, default 1 Number of columns in the legend.
legend_bbox : tuple, default (1.3, 1) Bounding box anchor position for the legend.
Returns
matplotlib.figure.Figure Stacked bar plot visualization of composition data.
4285 def cell_regression( 4286 self, 4287 cell_x: str, 4288 cell_y: str, 4289 set_x: str | None, 4290 set_y: str | None, 4291 threshold=10, 4292 image_width=12, 4293 image_high=7, 4294 color="black", 4295 ): 4296 """ 4297 Perform regression analysis between two selected cells and visualize the relationship. 4298 4299 This function computes a linear regression between two specified cells from 4300 aggregated normalized data, plots the regression line with scatter points, 4301 annotates regression statistics, and highlights potential outliers. 4302 4303 Parameters 4304 ---------- 4305 cell_x : str 4306 Name of the first cell (X-axis). 4307 4308 cell_y : str 4309 Name of the second cell (Y-axis). 4310 4311 set_x : str or None 4312 Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name. 4313 4314 set_y : str or None 4315 Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name. 4316 4317 threshold : int or float, default 10 4318 Threshold for detecting outliers. Points deviating from the mean or diagonal by more 4319 than this value are annotated. 4320 4321 image_width : int, default 12 4322 Width of the regression plot (in inches). 4323 4324 image_high : int, default 7 4325 Height of the regression plot (in inches). 4326 4327 color : str, default 'black' 4328 Color of the regression scatter points and line. 4329 4330 Returns 4331 ------- 4332 matplotlib.figure.Figure 4333 Regression plot figure with annotated regression line, R², p-value, and outliers. 4334 4335 Raises 4336 ------ 4337 ValueError 4338 If `cell_x` or `cell_y` are not found in the dataset. 4339 If multiple matches are found for a cell name and `set_x`/`set_y` are not specified. 4340 4341 Notes 4342 ----- 4343 - The function automatically calls `jseq_object.average()` if aggregated data is not available. 4344 - Outliers are annotated with their corresponding index labels. 4345 - Regression is computed using `scipy.stats.linregress`. 
4346 4347 Examples 4348 -------- 4349 >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2") 4350 >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue") 4351 """ 4352 4353 if self.agg_normalized_data is None: 4354 self.average() 4355 4356 metadata = self.agg_metadata 4357 data = self.agg_normalized_data 4358 4359 if set_x is not None and set_y is not None: 4360 data.columns = metadata["cell_names"] + " # " + metadata["sets"] 4361 cell_x = cell_x + " # " + set_x 4362 cell_y = cell_y + " # " + set_y 4363 4364 else: 4365 data.columns = metadata["cell_names"] 4366 4367 if not cell_x in data.columns: 4368 raise ValueError("'cell_x' value not in cell names!") 4369 4370 if not cell_y in data.columns: 4371 raise ValueError("'cell_y' value not in cell names!") 4372 4373 if list(data.columns).count(cell_x) > 1: 4374 raise ValueError( 4375 f"'{cell_x}' occurs more than once. If you want to select a specific cell, " 4376 f"please also provide the corresponding 'set_x' and 'set_y' values." 4377 ) 4378 4379 if list(data.columns).count(cell_y) > 1: 4380 raise ValueError( 4381 f"'{cell_y}' occurs more than once. If you want to select a specific cell, " 4382 f"please also provide the corresponding 'set_x' and 'set_y' values." 
4383 ) 4384 4385 fig, ax = plt.subplots(figsize=(image_width, image_high)) 4386 ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color) 4387 4388 slope, intercept, r_value, p_value, _ = stats.linregress( 4389 data[cell_x], data[cell_y] 4390 ) 4391 equation = "y = {:.2f}x + {:.2f}".format(slope, intercept) 4392 4393 ax.annotate( 4394 "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format( 4395 r_value**2, p_value, equation 4396 ), 4397 xy=(0.05, 0.90), 4398 xycoords="axes fraction", 4399 fontsize=12, 4400 ) 4401 4402 ax.spines["top"].set_visible(False) 4403 ax.spines["right"].set_visible(False) 4404 4405 diff = [] 4406 x_mean, y_mean = data[cell_x].mean(), data[cell_y].mean() 4407 for i, (xi, yi) in enumerate(zip(data[cell_x], data[cell_y])): 4408 diff.append(abs(xi - x_mean)) 4409 diff.append(abs(yi - y_mean)) 4410 4411 def annotate_outliers(x, y, threshold): 4412 texts = [] 4413 x_mean, y_mean = x.mean(), y.mean() 4414 for i, (xi, yi) in enumerate(zip(x, y)): 4415 if ( 4416 abs(xi - x_mean) > threshold 4417 or abs(yi - y_mean) > threshold 4418 or abs(yi - xi) > threshold 4419 ): 4420 text = ax.text(xi, yi, data.index[i]) 4421 texts.append(text) 4422 4423 return texts 4424 4425 texts = annotate_outliers(data[cell_x], data[cell_y], threshold) 4426 4427 adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5)) 4428 4429 plt.show() 4430 4431 return fig
Perform regression analysis between two selected cells and visualize the relationship.
This function computes a linear regression between two specified cells from aggregated normalized data, plots the regression line with scatter points, annotates regression statistics, and highlights potential outliers.
Parameters
cell_x : str Name of the first cell (X-axis).
cell_y : str Name of the second cell (Y-axis).
set_x : str or None
Dataset identifier corresponding to cell_x. If None, cell is selected only by name.
set_y : str or None
Dataset identifier corresponding to cell_y. If None, cell is selected only by name.
threshold : int or float, default 10 Threshold for detecting outliers. Points deviating from the mean or diagonal by more than this value are annotated.
image_width : int, default 12 Width of the regression plot (in inches).
image_high : int, default 7 Height of the regression plot (in inches).
color : str, default 'black' Color of the regression scatter points and line.
Returns
matplotlib.figure.Figure Regression plot figure with annotated regression line, R², p-value, and outliers.
Raises
ValueError
If cell_x or cell_y are not found in the dataset.
If multiple matches are found for a cell name and set_x/set_y are not specified.
Notes
- The function automatically calls `jseq_object.average()` if aggregated data is not available.
- Outliers are annotated with their corresponding index labels.
- Regression is computed using `scipy.stats.linregress`.
Examples
>>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
>>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")