"""jdti.jdti: dimensionality reduction, clustering, and single-cell dataset comparison."""
import math
import os
import pickle
import re

import harmonypy as harmonize
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import umap
from adjustText import adjust_text
from joblib import Parallel, delayed
from matplotlib.patches import FancyArrowPatch, Patch, Polygon
from scipy import sparse
from scipy.io import mmwrite
from scipy.spatial import ConvexHull
from scipy.spatial.distance import pdist, squareform
from scipy.stats import norm, stats, zscore
from sklearn.cluster import DBSCAN, MeanShift
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances, silhouette_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from tqdm import tqdm

from .utils import *


class Clustering:
    """
    A class for performing dimensionality reduction, clustering, and visualization
    on high-dimensional data (e.g., single-cell gene expression).

    The class provides methods for:
    - Normalizing and extracting subsets of data
    - Principal Component Analysis (PCA) and related clustering
    - Uniform Manifold Approximation and Projection (UMAP) and clustering
    - Visualization of PCA and UMAP embeddings
    - Harmonization of batch effects
    - Accessing processed data and cluster labels

    Methods
    -------
    add_data_frame(data, metadata)
        Class method to create a Clustering instance from a DataFrame and metadata.

    harmonize_sets()
        Perform batch effect harmonization on PCA data.

    perform_PCA(pc_num=100, width=8, height=6)
        Perform PCA on the dataset and visualize the first two PCs.

    knee_plot_PCA(width=8, height=6)
        Plot the cumulative variance explained by PCs to determine optimal dimensionality.

    find_clusters_PCA(pc_num=2, eps=0.5, min_samples=10, width=8, height=6, harmonized=False)
        Apply DBSCAN clustering to PCA embeddings and visualize results.

    perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...)
        Compute UMAP embeddings with optional parameter tuning.

    knee_plot_umap(eps=0.5, min_samples=10)
        Determine optimal UMAP dimensionality using silhouette scores.

    find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6)
        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

    UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...)
        Visualize UMAP embeddings with labels and optional cluster numbering.

    UMAP_feature(feature_name, features_data=None, point_size=0.6, ...)
        Plot a single feature over UMAP coordinates with customizable colormap.

    get_umap_data()
        Return the UMAP embeddings along with cluster labels if available.

    get_pca_data()
        Return the PCA results along with cluster labels if available.

    return_clusters(clusters='umap')
        Return the cluster labels for UMAP or PCA embeddings.

    Raises
    ------
    ValueError
        For invalid parameters, mismatched dimensions, or missing metadata.
    """

    def __init__(self, data, metadata):
        """
        Initialize the clustering class with data and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data for clustering. Columns are considered as samples.

        metadata : pandas.DataFrame, optional
            Metadata for the samples. If None, a default DataFrame with column
            names as 'cell_names' is created.
        Attributes
        ----------
        clustering_data : pandas.DataFrame
        clustering_metadata : pandas.DataFrame
        subclusters : None or dict
        explained_var : None or numpy.ndarray
        cumulative_var : None or numpy.ndarray
        pca : None or pandas.DataFrame
        harmonized_pca : None or pandas.DataFrame
        umap : None or pandas.DataFrame
        """

        self.clustering_data = data
        """The input data used for clustering."""

        if metadata is None:
            metadata = pd.DataFrame({"cell_names": list(data.columns)})

        self.clustering_metadata = metadata
        """Metadata associated with the samples."""

        self.subclusters = None
        """Placeholder for storing subcluster information."""

        self.explained_var = None
        """Explained variance from PCA, initialized as None."""

        self.cumulative_var = None
        """Cumulative explained variance from PCA, initialized as None."""

        self.pca = None
        """PCA-transformed data, initialized as None."""

        self.harmonized_pca = None
        """PCA data after batch effect harmonization, initialized as None."""

        self.umap = None
        """UMAP embeddings, initialized as None."""

    @classmethod
    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
        """
        Create a Clustering instance from a DataFrame and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data with features as rows and samples/cells as columns.

        metadata : pandas.DataFrame or None
            Optional metadata for the samples.
            Each row corresponds to a sample/cell and should match the order of
            the sample/cell columns in `data`. Columns can contain additional
            information such as cell type, experimental condition, batch, sets, etc.

        Returns
        -------
        Clustering
            A new instance of the Clustering class.
        """

        return cls(data, metadata)

    def harmonize_sets(self, batch_col: str = "sets"):
        """
        Perform batch effect harmonization on PCA embeddings.

        Parameters
        ----------
        batch_col : str, default 'sets'
            Name of the column in `metadata` that contains batch information for the samples/cells.

        Returns
        -------
        None
            Updates the `harmonized_pca` attribute with harmonized data.
        """

        data_mat = np.array(self.pca)

        metadata = self.clustering_metadata

        self.harmonized_pca = pd.DataFrame(
            harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr
        ).T

        self.harmonized_pca.columns = self.pca.columns
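    # A minimal usage sketch, assuming `df` (features x cells) and `meta`
    # (one row per cell, including a 'sets' batch column) are hypothetical
    # DataFrames; harmonize_sets() requires perform_PCA() to have run first:
    #
    #     cl = Clustering.add_data_frame(df, meta)
    #     cl.perform_PCA(pc_num=50)
    #     cl.harmonize_sets(batch_col="sets")   # fills cl.harmonized_pca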
    def perform_PCA(self, pc_num: int = 100, width=8, height=6):
        """
        Perform Principal Component Analysis (PCA) on the dataset.

        This method standardizes the data, applies PCA, stores results as attributes,
        and generates a scatter plot of the first two principal components.

        Parameters
        ----------
        pc_num : int, default 100
            Number of principal components to compute.
            If 0, computes all available components.

        width : int or float, default 8
            Width of the PCA figure.

        height : int or float, default 6
            Height of the PCA figure.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot showing the first two principal components.

        Updates
        -------
        self.pca : pandas.DataFrame
            DataFrame with principal component scores for each sample.

        self.explained_var : numpy.ndarray
            Percentage of variance explained by each principal component.

        self.cumulative_var : numpy.ndarray
            Cumulative explained variance.
        """

        scaler = StandardScaler()
        data_scaled = scaler.fit_transform(self.clustering_data.T)

        # n_components cannot exceed min(n_samples, n_features)
        if pc_num == 0 or pc_num > min(data_scaled.shape):
            pc_num = min(data_scaled.shape)

        pca = PCA(n_components=pc_num, random_state=42)

        principal_components = pca.fit_transform(data_scaled)

        pca_df = pd.DataFrame(
            data=principal_components,
            columns=["PC" + str(x + 1) for x in range(pc_num)],
        )

        self.explained_var = pca.explained_variance_ratio_ * 100
        self.cumulative_var = np.cumsum(self.explained_var)

        self.pca = pca_df

        fig = plt.figure(figsize=(width, height))
        plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7)
        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.grid(True)
        plt.show()

        return fig

    def knee_plot_PCA(self, width: int = 8, height: int = 6):
        """
        Plot cumulative explained variance to determine the optimal number of PCs.

        Parameters
        ----------
        width : int, default 8
            Width of the figure.

        height : int, default 6
            Height of the figure.

        Returns
        -------
        matplotlib.figure.Figure
            Line plot showing cumulative variance explained by each PC.
        """

        fig_knee = plt.figure(figsize=(width, height))
        plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o")
        plt.xlabel("PC (n components)")
        plt.ylabel("Cumulative explained variance (%)")
        plt.grid(True)

        xticks = [1] + list(range(5, len(self.explained_var) + 1, 5))
        plt.xticks(xticks, rotation=60)

        plt.show()

        return fig_knee
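    # Sketch: choose the PCA dimensionality from the knee plot; the 90%
    # cumulative-variance cutoff below is illustrative, not a module default:
    #
    #     cl.knee_plot_PCA()
    #     n_pc = int(np.argmax(cl.cumulative_var >= 90) + 1)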
    def find_clusters_PCA(
        self,
        pc_num: int = 2,
        eps: float = 0.5,
        min_samples: int = 10,
        width: int = 8,
        height: int = 6,
        harmonized: bool = False,
    ):
        """
        Apply DBSCAN clustering to PCA embeddings and visualize the results.

        This method performs density-based clustering (DBSCAN) on the PCA-reduced
        dataset. Cluster labels are stored in the object's metadata, and a scatter
        plot of the first two principal components with cluster annotations is returned.

        Parameters
        ----------
        pc_num : int, default 2
            Number of principal components to use for clustering.
            If 0, uses all available components.

        eps : float, default 0.5
            Maximum distance between two points for them to be considered
            as neighbors (DBSCAN parameter).

        min_samples : int, default 10
            Minimum number of samples required to form a cluster (DBSCAN parameter).

        width : int, default 8
            Width of the output scatter plot.

        height : int, default 6
            Height of the output scatter plot.

        harmonized : bool, default False
            If True, use harmonized PCA data (`self.harmonized_pca`).
            If False, use standard PCA results (`self.pca`).

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of the first two principal components colored by
            cluster assignments.

        Updates
        -------
        self.clustering_metadata['PCA_clusters'] : list
            Cluster labels assigned to each cell/sample.

        self.input_metadata['PCA_clusters'] : list, optional
            Cluster labels stored in input metadata (if available).
        """

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)

        # avoid shadowing the sklearn PCA class with a local variable
        pca_data = self.harmonized_pca if harmonized else self.pca

        if pc_num != 0:
            pca_data = pca_data.iloc[:, 0:pc_num]

        dbscan_labels = dbscan.fit_predict(pca_data)

        pca_df = pd.DataFrame(pca_data)
        pca_df["Cluster"] = dbscan_labels

        fig = plt.figure(figsize=(width, height))

        for cluster_id in sorted(pca_df["Cluster"].unique()):
            cluster_data = pca_df[pca_df["Cluster"] == cluster_id]
            plt.scatter(
                cluster_data["PC1"],
                cluster_data["PC2"],
                label=f"Cluster {cluster_id}",
                alpha=0.7,
            )

        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))

        plt.grid(True)
        plt.show()

        self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]

        # input_metadata exists only on subclasses such as COMPsc
        try:
            self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
        except (AttributeError, TypeError):
            pass

        return fig
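    # Sketch: DBSCAN on the leading (optionally harmonized) PCs; parameter
    # values are illustrative only:
    #
    #     fig = cl.find_clusters_PCA(pc_num=20, eps=0.5, min_samples=10,
    #                                harmonized=True)
    #     labels = cl.return_clusters(clusters="pca")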
    def perform_UMAP(
        self,
        factorize: bool = False,
        umap_num: int = 100,
        pc_num: int = 0,
        harmonized: bool = False,
        n_neighbors: int = 5,
        min_dist: float | int = 0.1,
        spread: float | int = 1.0,
        set_op_mix_ratio: float | int = 1.0,
        local_connectivity: int = 1,
        repulsion_strength: float | int = 1.0,
        negative_sample_rate: int = 5,
        width: int = 8,
        height: int = 6,
    ):
        """
        Compute and visualize UMAP embeddings of the dataset.

        This method applies Uniform Manifold Approximation and Projection (UMAP)
        for dimensionality reduction on either raw, PCA, or harmonized PCA data.
        Results are stored as a DataFrame (`self.umap`) and a scatter plot of
        the first two UMAP components is displayed.

        Parameters
        ----------
        factorize : bool, default False
            If True, categorical sample labels (from column names) are factorized
            and used as supervision in UMAP fitting.

        umap_num : int, default 100
            Number of UMAP dimensions to compute. If 0, matches the input dimension.

        pc_num : int, default 0
            Number of principal components to use as UMAP input.
            If 0, use all available components or raw data.

        harmonized : bool, default False
            If True, use harmonized PCA embeddings (`self.harmonized_pca`).
            If False, use standard PCA or raw scaled data.

        n_neighbors : int, default 5
            UMAP parameter controlling the size of the local neighborhood.

        min_dist : float, default 0.1
            UMAP parameter controlling minimum allowed distance between embedded points.

        spread : int | float, default 1.0
            Effective scale of embedded space (UMAP parameter).

        set_op_mix_ratio : int | float, default 1.0
            Interpolation parameter between union and intersection in fuzzy sets.

        local_connectivity : int, default 1
            Number of nearest neighbors assumed for each point.

        repulsion_strength : int | float, default 1.0
            Weighting applied to negative samples during optimization.

        negative_sample_rate : int, default 5
            Number of negative samples per positive sample in optimization.

        width : int, default 8
            Width of the output scatter plot.

        height : int, default 6
            Height of the output scatter plot.

        Updates
        -------
        self.umap : pandas.DataFrame
            Table of UMAP embeddings with columns `UMAP1 ... UMAPn`.

        Notes
        -----
        For supervised UMAP (`factorize=True`), categorical codes from column
        names of the dataset are used as labels.
        """

        scaler = StandardScaler()

        if pc_num == 0 and harmonized:
            data_scaled = self.harmonized_pca

        elif pc_num == 0:
            data_scaled = scaler.fit_transform(self.clustering_data.T)

        else:
            if harmonized:
                data_scaled = self.harmonized_pca.iloc[:, 0:pc_num]
            else:
                data_scaled = self.pca.iloc[:, 0:pc_num]

        # clamp the number of UMAP components to the input dimensionality
        if umap_num == 0 or umap_num > data_scaled.shape[1]:
            umap_num = data_scaled.shape[1]

        reducer = umap.UMAP(
            n_components=umap_num,
            random_state=42,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            spread=spread,
            set_op_mix_ratio=set_op_mix_ratio,
            local_connectivity=local_connectivity,
            repulsion_strength=repulsion_strength,
            negative_sample_rate=negative_sample_rate,
            n_jobs=1,
        )

        if factorize:
            embedding = reducer.fit_transform(
                X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes
            )
        else:
            embedding = reducer.fit_transform(X=data_scaled)

        umap_df = pd.DataFrame(
            embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)]
        )

        plt.figure(figsize=(width, height))
        plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.grid(True)

        plt.show()

        self.umap = umap_df

    def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10):
        """
        Plot silhouette scores for different UMAP dimensions to determine optimal n_components.

        Parameters
        ----------
        eps : float, default 0.5
            DBSCAN eps parameter for clustering each UMAP dimension.

        min_samples : int, default 10
            Minimum number of samples to form a cluster in DBSCAN.

        Returns
        -------
        matplotlib.figure.Figure
            Silhouette score plot across UMAP dimensions.
        """

        umap_range = range(2, len(self.umap.T) + 1)

        silhouette_scores = []
        component = []
        for n in umap_range:

            db = DBSCAN(eps=eps, min_samples=min_samples)
            labels = db.fit_predict(np.array(self.umap)[:, :n])

            # score only non-noise points; DBSCAN labels noise as -1
            mask = labels != -1
            if len(set(labels[mask])) > 1:
                score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask])
            else:
                score = -1

            silhouette_scores.append(score)
            component.append(n)

        fig = plt.figure(figsize=(10, 5))
        plt.plot(component, silhouette_scores, marker="o")
        plt.xlabel("UMAP (n_components)")
        plt.ylabel("Silhouette Score")
        plt.grid(True)
        plt.xticks(range(int(min(component)), int(max(component)) + 1, 1))

        plt.show()

        return fig
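    # Sketch: embed on top of harmonized PCs, then scan embedding
    # dimensionalities with the silhouette knee plot (values illustrative):
    #
    #     cl.perform_UMAP(umap_num=10, pc_num=20, harmonized=True,
    #                     n_neighbors=15, min_dist=0.3)
    #     cl.knee_plot_umap(eps=0.5, min_samples=10)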
    def find_clusters_UMAP(
        self,
        umap_n: int = 5,
        eps: float | int = 0.5,
        min_samples: int = 10,
        width: int = 8,
        height: int = 6,
    ):
        """
        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

        This method performs density-based clustering (DBSCAN) on the UMAP-reduced
        dataset. Cluster labels are stored in the object's metadata, and a scatter
        plot of the first two UMAP components with cluster annotations is returned.

        Parameters
        ----------
        umap_n : int, default 5
            Number of UMAP dimensions to use for DBSCAN clustering.
            Must be <= number of columns in `self.umap`.

        eps : float | int, default 0.5
            Maximum neighborhood distance between two samples for them to be considered
            as in the same cluster (DBSCAN parameter).

        min_samples : int, default 10
            Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of the first two UMAP components colored by
            cluster assignments.

        Updates
        -------
        self.clustering_metadata['UMAP_clusters'] : list
            Cluster labels assigned to each cell/sample.

        self.input_metadata['UMAP_clusters'] : list, optional
            Cluster labels stored in input metadata (if available).
        """

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n])

        # work on a copy so the stored embedding keeps only UMAP columns
        umap_df = self.umap.copy()
        umap_df["Cluster"] = dbscan_labels

        fig = plt.figure(figsize=(width, height))

        for cluster_id in sorted(umap_df["Cluster"].unique()):
            cluster_data = umap_df[umap_df["Cluster"] == cluster_id]
            plt.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"Cluster {cluster_id}",
                alpha=0.7,
            )

        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
        plt.grid(True)

        self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]

        # input_metadata exists only on subclasses such as COMPsc
        try:
            self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
        except (AttributeError, TypeError):
            pass

        return fig
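    # Sketch: cluster the embedding and keep the labels; umap_n must not
    # exceed the number of computed UMAP dimensions:
    #
    #     fig = cl.find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10)
    #     labels = cl.return_clusters(clusters="umap")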
    def UMAP_vis(
        self,
        names_slot: str = "cell_names",
        set_sep: bool = True,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        legend_split_col: int = 2,
        width: int = 8,
        height: int = 6,
        inc_num: bool = True,
    ):
        """
        Visualize UMAP embeddings with sample labels based on a specific metadata slot.

        Parameters
        ----------
        names_slot : str, default 'cell_names'
            Column in metadata to use as sample labels.

        set_sep : bool, default True
            If True, separate points by dataset.

        point_size : float, default 0.6
            Size of scatter points.

        font_size : int, default 6
            Font size for numbers on points.

        legend_split_col : int, default 2
            Number of columns in legend.

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        inc_num : bool, default True
            If True, annotate points with numeric labels.

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot figure.
        """

        umap_df = self.umap.iloc[:, 0:2].copy()
        umap_df.loc[:, "names"] = self.clustering_metadata[names_slot]

        if set_sep:
            if "sets" in self.clustering_metadata.columns:
                umap_df.loc[:, "dataset"] = self.clustering_metadata["sets"]
            else:
                umap_df["dataset"] = "default"
        else:
            umap_df["dataset"] = "default"

        umap_df.loc[:, "tmp_nam"] = umap_df["names"] + umap_df["dataset"]

        umap_df.loc[:, "count"] = umap_df["tmp_nam"].map(
            umap_df["tmp_nam"].value_counts()
        )

        numeric_df = (
            umap_df[["count", "tmp_nam", "names"]]
            .drop_duplicates()
            .sort_values("count", ascending=False)
        )
        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])

        umap_df = umap_df.merge(
            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
        )

        fig, ax = plt.subplots(figsize=(width, height))

        markers = ["o", "s", "^", "D", "P", "*", "X"]
        marker_map = {
            ds: markers[i % len(markers)]
            for i, ds in enumerate(umap_df["dataset"].unique())
        }

        cord_list = []

        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):

            cluster_data = umap_df[umap_df["numeric_values"] == num]

            ax.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"{num} - {nam}",
                marker=marker_map[cluster_data["dataset"].iloc[0]],
                alpha=0.6,
                s=point_size,
            )

            # annotate each group at its medoid (point with the smallest
            # summed distance to all other points in the group)
            coords = cluster_data[["UMAP1", "UMAP2"]].values

            dists = pairwise_distances(coords)

            sum_dists = dists.sum(axis=1)

            center_idx = np.argmin(sum_dists)
            center_point = coords[center_idx]

            cord_list.append(center_point)

        if inc_num:
            texts = []
            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
                texts.append(
                    ax.text(
                        x,
                        y,
                        str(num),
                        ha="center",
                        va="center",
                        fontsize=font_size,
                        color="black",
                    )
                )

            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.legend(
            title="Clusters",
            loc="center left",
            bbox_to_anchor=(1.05, 0.5),
            ncol=legend_split_col,
            markerscale=5,
        )

        ax.grid(True)

        plt.tight_layout()

        return fig
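    # Sketch: label the embedding by any metadata column, e.g. the
    # 'UMAP_clusters' labels stored by find_clusters_UMAP():
    #
    #     fig = cl.UMAP_vis(names_slot="UMAP_clusters", legend_split_col=2)
    #     fig.savefig("umap_clusters.png", dpi=300)   # hypothetical output path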
    def UMAP_feature(
        self,
        feature_name: str,
        features_data: pd.DataFrame | None = None,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        width: int = 8,
        height: int = 6,
        palette="light",
    ):
        """
        Visualize UMAP embedding with expression levels of a selected feature.

        Each point (cell) in the UMAP plot is colored according to the expression
        value of the chosen feature, enabling interpretation of spatial patterns
        of gene activity or metadata distribution in low-dimensional space.

        Parameters
        ----------
        feature_name : str
            Name of the feature to plot.

        features_data : pandas.DataFrame or None, default None
            If None, the function uses the DataFrame containing the clustering data.
            To plot features not used in clustering, provide a wider DataFrame
            containing the original feature values.

        point_size : float, default 0.6
            Size of scatter points in the plot.

        font_size : int, default 6
            Font size for axis labels and annotations.

        width : int, default 8
            Width of the matplotlib figure.

        height : int, default 6
            Height of the matplotlib figure.

        palette : str, default 'light'
            Color palette for expression visualization. Options are:
            - 'light'
            - 'dark'
            - 'green'
            - 'gray'

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot colored by feature values.
        """

        umap_df = self.umap.iloc[:, 0:2].copy()

        if features_data is None:
            features_data = self.clustering_data

        if features_data.shape[1] != umap_df.shape[0]:
            raise ValueError(
                "Provided 'features_data' shape does not match the number of UMAP cells"
            )

        blist = [x.upper() == feature_name.upper() for x in features_data.index]

        if not any(blist):
            raise ValueError("Provided 'feature_name' is not included in the data")

        umap_df.loc[:, "feature"] = features_data.loc[blist, :].iloc[0].tolist()

        umap_df = umap_df.sort_values("feature", ascending=True)

        if palette == "light":
            palette = px.colors.sequential.Sunsetdark

        elif palette == "dark":
            palette = px.colors.sequential.thermal

        elif palette == "green":
            palette = px.colors.sequential.Aggrnyl

        elif palette == "gray":
            palette = px.colors.sequential.gray
            palette = palette[::-1]

        else:
            raise ValueError(
                'Palette not found. Use: "light", "dark", "gray", or "green"'
            )

        # convert plotly 'rgb(r, g, b)' strings to 0-1 tuples for matplotlib
        converted = []
        for c in palette:
            rgb_255 = px.colors.unlabel_rgb(c)
            rgb_01 = tuple(v / 255.0 for v in rgb_255)
            converted.append(rgb_01)

        my_cmap = mcolors.ListedColormap(converted, name="custom")

        fig, ax = plt.subplots(figsize=(width, height))
        sc = ax.scatter(
            umap_df["UMAP1"],
            umap_df["UMAP2"],
            c=umap_df["feature"],
            s=point_size,
            cmap=my_cmap,
            alpha=1.0,
            edgecolors="black",
            linewidths=0.1,
        )

        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label(f"{feature_name}")

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.grid(True)

        plt.tight_layout()

        return fig

    def get_umap_data(self):
        """
        Retrieve UMAP embedding data with optional cluster labels.

        Returns the UMAP coordinates stored in `self.umap`. If clustering
        metadata is available (specifically `UMAP_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
            If available, includes an extra column 'clusters' with cluster labels.

        Notes
        -----
        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        # return a copy so the stored embedding is not modified
        umap_data = self.umap.copy()

        try:
            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
        except KeyError:
            pass

        return umap_data
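    # Sketch: export the embedding with cluster labels for downstream use
    # (output path hypothetical):
    #
    #     emb = cl.get_umap_data()            # UMAP1..UMAPn (+ 'clusters' if set)
    #     emb.to_csv("umap_embedding.csv")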
    def get_pca_data(self):
        """
        Retrieve PCA embedding data with optional cluster labels.

        Returns the principal component scores stored in `self.pca`. If clustering
        metadata is available (specifically `PCA_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
            If available, includes an extra column 'clusters' with cluster labels.

        Notes
        -----
        - PCA must be computed beforehand (e.g., using `perform_PCA`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        # return a copy so the stored scores are not modified
        pca_data = self.pca.copy()

        try:
            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
        except KeyError:
            pass

        return pca_data

    def return_clusters(self, clusters="umap"):
        """
        Retrieve cluster labels from UMAP or PCA clustering results.

        Parameters
        ----------
        clusters : str, default 'umap'
            Source of cluster labels to return. Must be one of:
            - 'umap': return cluster labels from UMAP embeddings.
            - 'pca' : return cluster labels from PCA embeddings.

        Returns
        -------
        pandas.Series
            Cluster labels corresponding to the selected embedding method.

        Raises
        ------
        ValueError
            If `clusters` is not 'umap' or 'pca'.

        Notes
        -----
        Requires that clustering has already been performed
        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
        """

        if clusters.lower() == "umap":
            clusters_vector = self.clustering_metadata["UMAP_clusters"]
        elif clusters.lower() == "pca":
            clusters_vector = self.clustering_metadata["PCA_clusters"]
        else:
            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")

        return clusters_vector
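# End-to-end sketch of the Clustering workflow defined above; `df` and `meta`
# are hypothetical inputs and the parameter values are illustrative:
#
#     cl = Clustering.add_data_frame(df, meta)
#     cl.perform_PCA(pc_num=50)
#     cl.harmonize_sets()
#     cl.perform_UMAP(umap_num=10, pc_num=20, harmonized=True)
#     cl.find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10)
#     cl.UMAP_vis(names_slot="UMAP_clusters")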
class COMPsc(Clustering):
    """
    COMPsc (Comparison of single-cell data) is a class designed for the integration,
    analysis, and visualization of single-cell datasets.
    The class supports independent dataset integration, subclustering of existing clusters,
    marker detection, and multiple visualization strategies.

    The COMPsc class provides methods for:

    - Normalizing and filtering single-cell data
    - Loading and saving sparse 10x-style datasets
    - Computing differential expression and marker genes
    - Clustering and subclustering analysis
    - Visualizing similarity and spatial relationships
    - Aggregating data by cell and set annotations
    - Managing metadata and renaming labels
    - Plotting gene detection histograms and feature scatters

    Methods
    -------
    project_dir(path_to_directory, project_list)
        Scans a directory to create a COMPsc instance mapping project names to their paths.

    save_project(name, path=os.getcwd())
        Saves the COMPsc object to a pickle file on disk.

    load_project(path)
        Loads a previously saved COMPsc object from a pickle file.

    reduce_cols(reg=None, full=None, name_slot='cell_names', inc_set=False)
        Removes columns from data tables where column names contain a specified name or partial substring.

    reduce_rows(features_list)
        Removes rows from data tables whose row names match specified feature (gene) names.

    get_data(set_info=False)
        Returns normalized data with optional set annotations in column names.

    get_metadata()
        Returns the stored input metadata.

    get_partial_data(names=None, features=None, name_slot='cell_names')
        Return a subset of the data by sample names and/or features.

    gene_calculation()
        Calculates and stores per-cell gene detection counts as a pandas Series.

    gene_histograme(bins=100)
        Plots a histogram of genes detected per cell with an overlaid normal distribution.

    gene_threshold(min_n=None, max_n=None)
        Filters cells based on minimum and/or maximum gene detection thresholds.

    load_sparse_from_projects(normalized_data=False)
        Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.

    rename_names(mapping, slot='cell_names')
        Renames entries in a specified metadata column using a provided mapping dictionary.

    rename_subclusters(mapping)
        Renames subcluster labels using a provided mapping dictionary.

    save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized')
        Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

    normalize_data(normalize=True, normalize_factor=100000)
        Normalizes raw counts to counts-per-specified factor (e.g., CPM-like).

    statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10)
        Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.

    calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False)
        Computes and caches differential markers using the statistic method.

    clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4)
        Prepares clustering input by selecting marker features and optionally smoothing cell values.

    average()
        Aggregates normalized data by averaging across (cell_name, set) pairs.

    estimating_similarity(method='pearson', p_val=0.05, top_n=25)
        Computes pairwise correlation and Euclidean distance between aggregated samples.

    similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10)
        Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.

    spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...)
        Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.

    subcluster_prepare(features, cluster)
        Initializes a Clustering object for subcluster analysis on a selected parent cluster.

    define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...)
        Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels.

    subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...)
        Visualizes averaged expression and occurrence of features for subclusters as a scatter plot.

    subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...)
        Plots top differential features for subclusters as a features-scatter visualization.

    accept_subclusters()
        Commits subcluster labels to main metadata by renaming cell names and clears subcluster data.

    Raises
    ------
    ValueError
        For invalid parameters, mismatched dimensions, or missing metadata.
    """

    def __init__(
        self,
        objects=None,
    ):
        """
        Initialize the COMPsc class for single-cell data integration and analysis.

        Parameters
        ----------
        objects : list or None, optional
            Optional list of data objects to initialize the instance with.
        Attributes
        ----------
        objects : list or None
        input_data : pandas.DataFrame or None
        input_metadata : pandas.DataFrame or None
        normalized_data : pandas.DataFrame or None
        agg_metadata : pandas.DataFrame or None
        agg_normalized_data : pandas.DataFrame or None
        similarity : pandas.DataFrame or None
        var_data : pandas.DataFrame or None
        subclusters_ : instance of Clustering class or None
        cells_calc : pandas.Series or None
        gene_calc : pandas.Series or None
        composition_data : pandas.DataFrame or None
        """

        self.objects = objects
        """Stores the input data objects."""

        self.input_data = None
        """Raw input data for clustering or integration analysis."""

        self.input_metadata = None
        """Metadata associated with the input data."""

        self.normalized_data = None
        """Normalized version of the input data."""

        self.agg_metadata = None
        '''Aggregated metadata for all sets in the object, related to "agg_normalized_data".'''

        self.agg_normalized_data = None
        """Aggregated and normalized data across multiple sets."""

        self.similarity = None
        """Similarity data between cells across all samples and sets."""

        self.var_data = None
        """DEG analysis results summarizing variance across all samples in the object."""

        self.subclusters_ = None
        """Placeholder for subcluster analysis results, if computed."""

        self.cells_calc = None
        """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition."""

        self.gene_calc = None
        """Number of genes detected per sample (cell), reflecting the sequencing depth."""

        self.composition_data = None
        """Data describing composition of cells across clusters or sets."""

    @classmethod
    def project_dir(cls, path_to_directory, project_list):
        """
        Scan a directory and build a COMPsc instance mapping provided project names
        to their paths.

        Parameters
        ----------
        path_to_directory : str
            Path containing project subfolders.

        project_list : list[str]
            List of filenames (folder names) to include in the returned object map.

        Returns
        -------
        COMPsc
            New COMPsc instance with `objects` populated.

        Raises
        ------
        Exception
            A generic exception is caught and a message printed if scanning fails.

        Notes
        -----
        Function attempts to match entries in `project_list` to directory
        names and constructs a simplified object key from the folder name.
        """
        try:
            objects = {}
            for filename in tqdm(os.listdir(path_to_directory)):
                for c in project_list:
                    f = os.path.join(path_to_directory, filename)
                    if c == filename and os.path.isdir(f):
                        # derive the object key by stripping the leading
                        # '<prefix>_' token and the '_fq' / '_count' suffixes
                        objects[
                            str(
                                re.sub(
                                    "_count",
                                    "",
                                    re.sub(
                                        "_fq",
                                        "",
                                        re.sub(
                                            re.sub("_.*", "", filename) + "_",
                                            "",
                                            filename,
                                        ),
                                    ),
                                )
                            )
                        ] = f

            return cls(objects)

        except Exception:
            print("Something went wrong. Check the function input data and try again!")
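    # Sketch: map 10x-style project folders into a COMPsc object; the
    # directory and folder names are hypothetical:
    #
    #     comp = COMPsc.project_dir("/data/projects",
    #                               ["exp1_sampleA_count", "exp1_sampleB_count"])
    #     comp.load_sparse_from_projects()    # defined further below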
    def save_project(self, name, path: str = os.getcwd()):
        """
        Save the COMPsc object to disk using pickle.

        Parameters
        ----------
        name : str
            Base filename (without extension) to use when saving.

        path : str, default os.getcwd()
            Directory in which to save the project file.

        Returns
        -------
        None

        Side Effects
        ------------
        - Writes a file `<path>/<name>.jpkl` containing the pickled object.
        - Prints a confirmation message with the saved path.
        """

        full = os.path.join(path, f"{name}.jpkl")

        with open(full, "wb") as f:
            pickle.dump(self, f)

        print(f"Project saved as {full}")

    @classmethod
    def load_project(cls, path):
        """
        Load a previously saved COMPsc project from a pickle file.

        Parameters
        ----------
        path : str
            Full path to the pickled project file.

        Returns
        -------
        COMPsc
            The unpickled COMPsc object.

        Raises
        ------
        FileNotFoundError
            If the provided path does not exist.
        """

        if not os.path.exists(path):
            raise FileNotFoundError("File does not exist!")
        with open(path, "rb") as f:
            obj = pickle.load(f)
        return obj
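    # Sketch: persist and restore an analysis session (paths hypothetical):
    #
    #     comp.save_project("my_analysis", path="/tmp")  # -> /tmp/my_analysis.jpkl
    #     comp = COMPsc.load_project("/tmp/my_analysis.jpkl")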
    def reduce_cols(
        self,
        reg: str | None = None,
        full: str | None = None,
        name_slot: str = "cell_names",
        inc_set: bool = False,
    ):
        """
        Remove columns (cells) whose names contain a substring `reg` or
        match a full name `full` from available tables.

        Parameters
        ----------
        reg : str | None
            Substring to search for in column/cell names; matching columns will be removed.
            If not None, `full` must be None.

        full : str | None
            Full name to search for in column/cell names; matching columns will be removed.
            If not None, `reg` must be None.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        inc_set : bool, default False
            If True, column names are interpreted as 'cell_name # set' when matching.

        Updates
        -------
        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
        removing columns/rows that match `reg` or `full`.

        Raises
        ------
        ValueError
            If nothing matches the reduction mask, or if both (or neither) of
            `reg` and `full` are provided.
        """

        if reg is None and full is None:
            raise ValueError(
                "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!"
            )

        if reg is not None and full is not None:
            raise ValueError(
                "Both 'reg' and 'full' arguments are provided. "
                "Please provide only one of them!\n"
                "'reg' is used when only part of the name must be detected.\n"
                "'full' is used if the full name must be detected."
            )

        if reg is not None:

            if self.input_data is not None:

                if inc_set:
                    self.input_data.columns = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )
                else:
                    self.input_data.columns = self.input_metadata[name_slot]

                mask = [reg.upper() not in x.upper() for x in self.input_data.columns]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.input_data = self.input_data.loc[:, mask]

            if self.normalized_data is not None:

                if inc_set:
                    self.normalized_data.columns = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )
                else:
                    self.normalized_data.columns = self.input_metadata[name_slot]

                mask = [
                    reg.upper() not in x.upper() for x in self.normalized_data.columns
                ]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.normalized_data = self.normalized_data.loc[:, mask]

            if self.input_metadata is not None:

                if inc_set:
                    self.input_metadata["drop"] = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )
                else:
                    self.input_metadata["drop"] = self.input_metadata[name_slot]

                mask = [
                    reg.upper() not in x.upper() for x in self.input_metadata["drop"]
                ]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
                    drop=True
                )

                self.input_metadata = self.input_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            if self.agg_normalized_data is not None:

                if inc_set:
                    self.agg_normalized_data.columns = (
                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
                    )
                else:
                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                mask = [
                    reg.upper() not in x.upper()
                    for x in self.agg_normalized_data.columns
                ]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]

            if self.agg_metadata is not None:

                if inc_set:
                    self.agg_metadata["drop"] = (
                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
                    )
                else:
                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
                    drop=True
                )

                self.agg_metadata = self.agg_metadata.drop(
                    columns=["drop"], errors="ignore"
                )
        elif full is not None:

            if self.input_data is not None:

                if inc_set:
                    self.input_data.columns = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                    if "#" not in full:
                        self.input_data.columns = self.input_metadata[name_slot]

                        print(
                            "The 'full' argument does not include set info "
                            "('name # set') although 'inc_set' is True. "
                            "Only names will be compared, ignoring set information."
                        )
                else:
                    self.input_data.columns = self.input_metadata[name_slot]

                mask = [full.upper() != x.upper() for x in self.input_data.columns]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.input_data = self.input_data.loc[:, mask]

            if self.normalized_data is not None:

                if inc_set:
                    self.normalized_data.columns = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                    if "#" not in full:
                        self.normalized_data.columns = self.input_metadata[name_slot]

                        print(
                            "The 'full' argument does not include set info "
                            "('name # set') although 'inc_set' is True. "
                            "Only names will be compared, ignoring set information."
                        )
                else:
                    self.normalized_data.columns = self.input_metadata[name_slot]

                mask = [full.upper() != x.upper() for x in self.normalized_data.columns]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.normalized_data = self.normalized_data.loc[:, mask]

            if self.input_metadata is not None:

                if inc_set:
                    self.input_metadata["drop"] = (
                        self.input_metadata[name_slot]
                        + " # "
                        + self.input_metadata["sets"]
                    )

                    if "#" not in full:
                        self.input_metadata["drop"] = self.input_metadata[name_slot]

                        print(
                            "The 'full' argument does not include set info "
                            "('name # set') although 'inc_set' is True. "
                            "Only names will be compared, ignoring set information."
                        )
                else:
                    self.input_metadata["drop"] = self.input_metadata[name_slot]

                mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
                    drop=True
                )

                self.input_metadata = self.input_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            if self.agg_normalized_data is not None:

                if inc_set:
                    self.agg_normalized_data.columns = (
                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
                    )

                    if "#" not in full:
                        self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                        print(
                            "The 'full' argument does not include set info "
                            "('name # set') although 'inc_set' is True. "
                            "Only names will be compared, ignoring set information."
                        )
                else:
                    self.agg_normalized_data.columns = self.agg_metadata[name_slot]

                mask = [
                    full.upper() != x.upper() for x in self.agg_normalized_data.columns
                ]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]

            if self.agg_metadata is not None:

                if inc_set:
                    self.agg_metadata["drop"] = (
                        self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"]
                    )

                    if "#" not in full:
                        self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                        print(
                            "The 'full' argument does not include set info "
                            "('name # set') although 'inc_set' is True. "
                            "Only names will be compared, ignoring set information."
                        )
                else:
                    self.agg_metadata["drop"] = self.agg_metadata[name_slot]

                mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]]

                if all(mask):
                    raise ValueError("Nothing found to reduce")

                self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
                    drop=True
                )

                self.agg_metadata = self.agg_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

        self.gene_calculation()
        self.cells_calculation()
    def reduce_rows(self, features_list: list):
        """
        Remove rows (features) whose names are included in `features_list`.

        Parameters
        ----------
        features_list : list
            List of features to search for in index/gene names; matching entries will be removed.

        Updates
        -------
        Mutates `self.input_data`, `self.normalized_data`, and
        `self.agg_normalized_data` (if they exist), removing rows whose
        feature names match entries in `features_list`.

        Raises
        ------
        ValueError
            If nothing is found to reduce.

        Notes
        -----
        Prints a message listing features that are not found in the data.
        """

        if self.input_data is not None:

            res = find_features(self.input_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.input_data.index]

            if all(mask):
                raise ValueError("Nothing found to reduce")

            self.input_data = self.input_data.loc[mask, :]

        if self.normalized_data is not None:

            res = find_features(self.normalized_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.normalized_data.index]

            if all(mask):
                raise ValueError("Nothing found to reduce")

            self.normalized_data = self.normalized_data.loc[mask, :]

        if self.agg_normalized_data is not None:

            res = find_features(self.agg_normalized_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]

            if all(mask):
                raise ValueError("Nothing found to reduce")

            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]

        if len(res["not_included"]) > 0:
            print("\nFeatures not found:")
            for i in res["not_included"]:
                print(i)

        self.gene_calculation()
        self.cells_calculation()
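    # Sketch: drop unwanted cells and features; names are hypothetical.
    # `reg` matches substrings case-insensitively, `full` matches whole names:
    #
    #     comp.reduce_cols(reg="doublet")
    #     comp.reduce_cols(full="Tcell # set1", inc_set=True)
    #     comp.reduce_rows(features_list=["MALAT1", "MT-CO1"])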
    def get_data(self, set_info: bool = False):
        """
        Return normalized data with optional set annotation appended to column names.

        Parameters
        ----------
        set_info : bool, default False
            If True, column names are returned as "cell_name # set"; otherwise
            only the `cell_name` is used.

        Returns
        -------
        pandas.DataFrame
            The `self.normalized_data` table with columns renamed according to `set_info`.

        Raises
        ------
        AttributeError
            If `self.normalized_data` or `self.input_metadata` is missing.
        """

        # work on a copy so renaming columns does not mutate the stored table
        to_return = self.normalized_data.copy()

        if set_info:
            to_return.columns = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
        else:
            to_return.columns = self.input_metadata["cell_names"]

        return to_return

    def get_partial_data(
        self,
        names: list | str | None = None,
        features: list | str | None = None,
        name_slot: str = "cell_names",
        inc_metadata: bool = False,
    ):
        """
        Return a subset of the data filtered by sample names and/or feature names.

        Parameters
        ----------
        names : list, str, or None
            Names of samples to include. If None, all samples are considered.

        features : list, str, or None
            Names of features to include. If None, all features are considered.

        name_slot : str
            Column in metadata to use as sample names.

        inc_metadata : bool
            If True, return a tuple (data, metadata).

        Returns
        -------
        pandas.DataFrame
            Subset of the normalized data based on the specified names and features.
        """

        data = self.normalized_data.copy()
        metadata = self.input_metadata

        if name_slot in self.input_metadata.columns:
            data.columns = self.input_metadata[name_slot]
        else:
            raise ValueError("'name_slot' does not occur in the metadata!")

        if isinstance(features, str):
            features = [features]
        elif features is None:
            features = []

        if isinstance(names, str):
            names = [names]
        elif names is None:
            names = []

        features = [x.upper() for x in features]
        names = [x.upper() for x in names]

        columns_names = [x.upper() for x in data.columns]
        features_names = [x.upper() for x in data.index]

        columns_bool = [x in names for x in columns_names]
        features_bool = [x in features for x in features_names]

        if True not in columns_bool and True not in features_bool:
            print("Missing 'names' and/or 'features'. Returning full dataset instead.")
            if inc_metadata:
                return data, metadata
            else:
                return data

        # apply both filters before returning, so that providing `names`
        # and `features` together subsets on both axes
        if True in columns_bool:
            data = data.loc[:, columns_bool]
            metadata = metadata.loc[columns_bool, :]

        if True in features_bool:
            data = data.loc[features_bool, :]

        if inc_metadata:
            return data, metadata
        else:
            return data

    def get_metadata(self):
        """
        Return the stored input metadata.

        Returns
        -------
        pandas.DataFrame
            `self.input_metadata` (may be None if not set).
        """

        to_return = self.input_metadata

        return to_return
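    # Sketch: pull a named subset together with its metadata (names and
    # features hypothetical):
    #
    #     sub, sub_meta = comp.get_partial_data(names=["Tcell", "Bcell"],
    #                                           features=["CD3E", "MS4A1"],
    #                                           inc_metadata=True)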
1867 """ 1868 1869 if self.input_data is not None: 1870 1871 bin_col = self.input_data.columns.copy() 1872 1873 bin_col = bin_col.where(bin_col <= 0, 1) 1874 1875 sum_data = bin_col.sum(axis=0) 1876 1877 self.gene_calc = sum_data 1878 1879 elif self.normalized_data is not None: 1880 1881 bin_col = self.normalized_data.copy() 1882 1883 bin_col = bin_col.where(bin_col <= 0, 1) 1884 1885 sum_data = bin_col.sum(axis=0) 1886 1887 self.gene_calc = sum_data 1888 1889 def gene_histograme(self, bins=100): 1890 """ 1891 Plot a histogram of the number of genes detected per cell. 1892 1893 Parameters 1894 ---------- 1895 bins : int, default 100 1896 Number of histogram bins. 1897 1898 Returns 1899 ------- 1900 matplotlib.figure.Figure 1901 Figure containing the histogram of gene contents. 1902 1903 Notes 1904 ----- 1905 Requires `self.gene_calc` to be computed prior to calling. 1906 """ 1907 1908 fig, ax = plt.subplots(figsize=(8, 5)) 1909 1910 _, bin_edges, _ = ax.hist( 1911 self.gene_calc, bins=bins, edgecolor="black", alpha=0.6 1912 ) 1913 1914 mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc) 1915 1916 x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000) 1917 y = norm.pdf(x, mu, sigma) 1918 1919 y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0]) 1920 1921 ax.plot( 1922 x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})" 1923 ) 1924 1925 ax.set_xlabel("Value") 1926 ax.set_ylabel("Count") 1927 ax.set_title("Histogram of genes detected per cell") 1928 1929 ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20)) 1930 ax.tick_params(axis="x", rotation=90) 1931 1932 ax.legend() 1933 1934 return fig 1935 1936 def gene_threshold(self, min_n: int | None, max_n: int | None): 1937 """ 1938 Filter cells by gene-detection thresholds (min and/or max). 1939 1940 Parameters 1941 ---------- 1942 min_n : int or None 1943 Minimum number of detected genes required to keep a cell. 1944 1945 max_n : int or None 1946 Maximum number of detected genes allowed to keep a cell. 1947 1948 Update 1949 ------- 1950 Filters `self.input_data`, `self.normalized_data`, `self.input_metadata` 1951 (and calls `average()` if `self.agg_normalized_data` exists). 1952 1953 Side Effects 1954 ------------ 1955 Raises ValueError if both bounds are None or if filtering removes all cells. 
1956 """ 1957 1958 if min_n is not None and max_n is not None: 1959 mask = (self.gene_calc > min_n) & (self.gene_calc < max_n) 1960 elif min_n is None and max_n is not None: 1961 mask = self.gene_calc < max_n 1962 elif min_n is not None and max_n is None: 1963 mask = self.gene_calc > min_n 1964 else: 1965 raise ValueError("Lack of both min_n and max_n values") 1966 1967 if self.input_data is not None: 1968 1969 if len([y for y in mask if y is False]) == 0: 1970 raise ValueError("Nothing to reduce") 1971 1972 self.input_data = self.input_data.loc[:, mask.values] 1973 1974 if self.normalized_data is not None: 1975 1976 if len([y for y in mask if y is False]) == 0: 1977 raise ValueError("Nothing to reduce") 1978 1979 self.normalized_data = self.normalized_data.loc[:, mask.values] 1980 1981 if self.input_metadata is not None: 1982 1983 if len([y for y in mask if y is False]) == 0: 1984 raise ValueError("Nothing to reduce") 1985 1986 self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index( 1987 drop=True 1988 ) 1989 1990 self.input_metadata = self.input_metadata.drop( 1991 columns=["drop"], errors="ignore" 1992 ) 1993 1994 if self.agg_normalized_data is not None: 1995 self.average() 1996 1997 self.gene_calculation() 1998 self.cells_calculation() 1999 2000 def cells_calculation(self, name_slot="cell_names"): 2001 """ 2002 Calculate number of cells per call name / cluster. 2003 2004 The method computes a binary (presence/absence) per cell name / cluster and sums across 2005 cells. 2006 2007 Parameters 2008 ---------- 2009 name_slot : str, default 'cell_names' 2010 Column in metadata to use as sample names. 2011 2012 Update 2013 ------ 2014 Sets `self.cells_calc` as a pd.DataFrame. 2015 """ 2016 2017 ls = list(self.input_metadata[name_slot]) 2018 2019 df = pd.DataFrame( 2020 { 2021 "cluster": pd.Series(ls).value_counts().index, 2022 "n": pd.Series(ls).value_counts().values, 2023 } 2024 ) 2025 2026 self.cells_calc = df 2027 2028 def cell_histograme(self, name_slot: str = "cell_names"): 2029 """ 2030 Plot a histogram of the number of cells detected per cell name (cluster). 2031 2032 Parameters 2033 ---------- 2034 name_slot : str, default 'cell_names' 2035 Column in metadata to use as sample names. 2036 2037 Returns 2038 ------- 2039 matplotlib.figure.Figure 2040 Figure containing the histogram of cell contents. 2041 2042 Notes 2043 ----- 2044 Requires `self.cells_calc` to be computed prior to calling. 
2045 """ 2046 2047 if name_slot != "cell_names": 2048 self.cells_calculation(name_slot=name_slot) 2049 2050 fig, ax = plt.subplots(figsize=(8, 5)) 2051 2052 _, bin_edges, _ = ax.hist( 2053 list(self.cells_calc["n"]), 2054 bins=len(set(self.cells_calc["cluster"])), 2055 edgecolor="black", 2056 color="orange", 2057 alpha=0.6, 2058 ) 2059 2060 mu, sigma = np.mean(list(self.cells_calc["n"])), np.std( 2061 list(self.cells_calc["n"]) 2062 ) 2063 2064 x = np.linspace( 2065 min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000 2066 ) 2067 y = norm.pdf(x, mu, sigma) 2068 2069 y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0]) 2070 2071 ax.plot( 2072 x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})" 2073 ) 2074 2075 ax.set_xlabel("Value") 2076 ax.set_ylabel("Count") 2077 ax.set_title("Histogram of cells detected per cell name / cluster") 2078 2079 ax.set_xticks( 2080 np.linspace( 2081 min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20 2082 ) 2083 ) 2084 ax.tick_params(axis="x", rotation=90) 2085 2086 ax.legend() 2087 2088 return fig 2089 2090 def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"): 2091 """ 2092 Filter cell names / clusters by cell-detection threshold. 2093 2094 Parameters 2095 ---------- 2096 min_n : int or None 2097 Minimum number of detected genes required to keep a cell. 2098 2099 name_slot : str, default 'cell_names' 2100 Column in metadata to use as sample names. 2101 2102 2103 Update 2104 ------- 2105 Filters `self.input_data`, `self.normalized_data`, `self.input_metadata` 2106 (and calls `average()` if `self.agg_normalized_data` exists). 2107 """ 2108 2109 if name_slot != "cell_names": 2110 self.cells_calculation(name_slot=name_slot) 2111 2112 if min_n is not None: 2113 names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n] 2114 else: 2115 raise ValueError("Lack of min_n value") 2116 2117 if len(names) > 0: 2118 2119 if self.input_data is not None: 2120 2121 self.input_data.columns = self.input_metadata[name_slot] 2122 2123 mask = [not any(r in x for r in names) for x in self.input_data.columns] 2124 2125 if len([y for y in mask if y is False]) > 0: 2126 2127 self.input_data = self.input_data.loc[:, mask] 2128 2129 if self.normalized_data is not None: 2130 2131 self.normalized_data.columns = self.input_metadata[name_slot] 2132 2133 mask = [ 2134 not any(r in x for r in names) for x in self.normalized_data.columns 2135 ] 2136 2137 if len([y for y in mask if y is False]) > 0: 2138 2139 self.normalized_data = self.normalized_data.loc[:, mask] 2140 2141 if self.input_metadata is not None: 2142 2143 self.input_metadata["drop"] = self.input_metadata[name_slot] 2144 2145 mask = [ 2146 not any(r in x for r in names) for x in self.input_metadata["drop"] 2147 ] 2148 2149 if len([y for y in mask if y is False]) > 0: 2150 2151 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 2152 drop=True 2153 ) 2154 2155 self.input_metadata = self.input_metadata.drop( 2156 columns=["drop"], errors="ignore" 2157 ) 2158 2159 if self.agg_normalized_data is not None: 2160 2161 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 2162 2163 mask = [ 2164 not any(r in x for r in names) 2165 for x in self.agg_normalized_data.columns 2166 ] 2167 2168 if len([y for y in mask if y is False]) > 0: 2169 2170 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 2171 2172 if self.agg_metadata is not None: 2173 2174 self.agg_metadata["drop"] = 
self.agg_metadata[name_slot] 2175 2176 mask = [ 2177 not any(r in x for r in names) for x in self.agg_metadata["drop"] 2178 ] 2179 2180 if not all(mask): 2181 2182 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 2183 drop=True 2184 ) 2185 2186 self.agg_metadata = self.agg_metadata.drop( 2187 columns=["drop"], errors="ignore" 2188 ) 2189 2190 self.gene_calculation() 2191 self.cells_calculation() 2192 2193 def load_sparse_from_projects(self, normalized_data: bool = False): 2194 """ 2195 Load sparse 10x-style datasets from stored project paths, concatenate them, 2196 and populate `input_data` / `normalized_data` and `input_metadata`. 2197 2198 Parameters 2199 ---------- 2200 normalized_data : bool, default False 2201 If True, store concatenated tables in `self.normalized_data`. 2202 If False, store them in `self.input_data`; normalization can then 2203 be performed with the normalize_counts() method. 2204 2205 Side Effects 2206 ------------ 2207 - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv). 2208 - Concatenates all projects column-wise and sets `self.input_metadata`. 2209 - Replaces NaNs with zeros and updates `self.gene_calc`. 2210 """ 2211 2212 obj = self.objects 2213 2214 full_data = pd.DataFrame() 2215 full_metadata = pd.DataFrame() 2216 2217 for ke in obj.keys(): 2218 print(ke) 2219 2220 dt, met = load_sparse(path=obj[ke], name=ke) 2221 2222 full_data = pd.concat([full_data, dt], axis=1) 2223 full_metadata = pd.concat([full_metadata, met], axis=0) 2224 2225 full_data[np.isnan(full_data)] = 0 2226 2227 if normalized_data: 2228 self.normalized_data = full_data 2229 self.input_metadata = full_metadata 2230 else: 2231 2232 self.input_data = full_data 2233 self.input_metadata = full_metadata 2234 2235 self.gene_calculation() 2236 self.cells_calculation() 2237 2238 def rename_names(self, mapping: dict, slot: str = "cell_names"): 2239 """ 2240 Rename entries in `self.input_metadata[slot]` according to a provided mapping. 2241 2242 Parameters 2243 ---------- 2244 mapping : dict 2245 Dictionary with keys 'old_name' and 'new_name', each mapping to a list 2246 of equal length describing replacements. 2247 2248 slot : str, default 'cell_names' 2249 Metadata column to operate on. 2250 2251 Update 2252 ------- 2253 Updates `self.input_metadata[slot]` in-place with renamed values. 2254 2255 Raises 2256 ------ 2257 ValueError 2258 If mapping keys are incorrect, lengths differ, or some 'old_name' values 2259 are not present in the metadata column. 2260 """ 2261 2262 if set(["old_name", "new_name"]) != set(mapping.keys()): 2263 raise ValueError( 2264 "Mapping dictionary must contain keys 'old_name' and 'new_name', " 2265 "each with a list of names to change." 2266 ) 2267 2268 if len(mapping["old_name"]) != len(mapping["new_name"]): 2269 raise ValueError( 2270 "Mapping dictionary lists 'old_name' and 'new_name' " 2271 "must have the same length!" 2272 ) 2273 2274 names_vector = list(self.input_metadata[slot]) 2275 2276 if not all(elem in names_vector for elem in list(mapping["old_name"])): 2277 raise ValueError( 2278 f"Some entries from 'old_name' do not exist in the names of slot {slot}."
2279 ) 2280 2281 replace_dict = dict(zip(mapping["old_name"], mapping["new_name"])) 2282 2283 names_vector_ret = [replace_dict.get(item, item) for item in names_vector] 2284 2285 self.input_metadata[slot] = names_vector_ret 2286 2287 def rename_subclusters(self, mapping): 2288 """ 2289 Rename labels stored in `self.subclusters_.subclusters` according to mapping. 2290 2291 Parameters 2292 ---------- 2293 mapping : dict 2294 Mapping with keys 'old_name' and 'new_name' (lists of equal length). 2295 2296 Update 2297 ------- 2298 Updates `self.subclusters_.subclusters` with renamed labels. 2299 2300 Raises 2301 ------ 2302 ValueError 2303 If mapping is invalid or old names are not present. 2304 """ 2305 2306 if set(["old_name", "new_name"]) != set(mapping.keys()): 2307 raise ValueError( 2308 "Mapping dictionary must contain keys 'old_name' and 'new_name', " 2309 "each with a list of names to change." 2310 ) 2311 2312 if len(mapping["old_name"]) != len(mapping["new_name"]): 2313 raise ValueError( 2314 "Mapping dictionary lists 'old_name' and 'new_name' " 2315 "must have the same length!" 2316 ) 2317 2318 names_vector = list(self.subclusters_.subclusters) 2319 2320 if not all(elem in names_vector for elem in list(mapping["old_name"])): 2321 raise ValueError( 2322 "Some entries from 'old_name' do not exist in the subcluster names." 2323 ) 2324 2325 replace_dict = dict(zip(mapping["old_name"], mapping["new_name"])) 2326 2327 names_vector_ret = [replace_dict.get(item, item) for item in names_vector] 2328 2329 self.subclusters_.subclusters = names_vector_ret 2330 2331 def save_sparse( 2332 self, 2333 path_to_save: str = os.getcwd(), 2334 name_slot: str = "cell_names", 2335 data_slot: str = "normalized", 2336 ): 2337 """ 2338 Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv). 2339 2340 Parameters 2341 ---------- 2342 path_to_save : str, default current working directory 2343 Directory where files will be written. 2344 2345 name_slot : str, default 'cell_names' 2346 Metadata column providing cell names for barcodes.tsv. 2347 2348 data_slot : str, default 'normalized' 2349 Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data). 2350 2351 Raises 2352 ------ 2353 ValueError 2354 If `data_slot` is not 'normalized' or 'count'. 2355 """ 2356 2357 names = self.input_metadata[name_slot] 2358 2359 if data_slot.lower() == "normalized": 2360 2361 features = list(self.normalized_data.index) 2362 mtx = sparse.csr_matrix(self.normalized_data) 2363 2364 elif data_slot.lower() == "count": 2365 2366 features = list(self.input_data.index) 2367 mtx = sparse.csr_matrix(self.input_data) 2368 2369 else: 2370 raise ValueError("'data_slot' must be included in 'normalized' or 'count'") 2371 2372 os.makedirs(path_to_save, exist_ok=True) 2373 2374 mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx) 2375 2376 pd.Series(names).to_csv( 2377 os.path.join(path_to_save, "barcodes.tsv"), 2378 index=False, 2379 header=False, 2380 sep="\t", 2381 ) 2382 2383 pd.Series(features).to_csv( 2384 os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t" 2385 ) 2386 2387 def normalize_counts( 2388 self, normalize_factor: int = 100000, log_transform: bool = True 2389 ): 2390 """ 2391 Normalize raw counts to counts-per-(normalize_factor) 2392 (e.g., CPM, TPM - depending on normalize_factor). 2393 2394 Parameters 2395 ---------- 2396 normalize_factor : int, default 100000 2397 Scaling factor used after dividing by column sums. 
2398 2399 log_transform : bool, default True 2400 If True, apply log2(x+1) transformation to normalized values. 2401 2402 Update 2403 ------- 2404 Sets `self.normalized_data` to normalized values (fills NaNs with 0). 2405 2406 Raises 2407 ------ 2408 ValueError 2409 If `self.input_data` is missing (cannot normalize). 2410 """ 2411 if self.input_data is None: 2412 raise ValueError("Input data is missing, cannot normalize.") 2413 2414 sum_col = self.input_data.sum() 2415 self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor 2416 2417 if log_transform: 2418 # log2(x + 1) to avoid -inf for zeros 2419 self.normalized_data = np.log2(self.normalized_data + 1) 2420 2421 def statistic( 2422 self, 2423 cells=None, 2424 sets=None, 2425 min_exp: float = 0.01, 2426 min_pct: float = 0.1, 2427 n_proc: int = 10, 2428 ): 2429 """ 2430 Compute per-feature statistics (Mann–Whitney U) comparing target vs rest. 2431 2432 This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data` 2433 and `self.input_metadata`. It returns per-feature statistics including p-values, 2434 adjusted p-values, means, variances, effect-size measures and fold-changes. 2435 2436 Parameters 2437 ---------- 2438 cells : list, 'All', dict, or None 2439 Defines the target cells or groups for comparison (several modes supported). 2440 2441 sets : 'All', dict, or None 2442 Alternative grouping mode (operate on `self.input_metadata['sets']`). 2443 2444 min_exp : float, default 0.01 2445 Minimum expression threshold used when filtering features. 2446 2447 min_pct : float, default 0.1 2448 Minimum proportion of expressing cells in the target group required to test a feature. 2449 2450 n_proc : int, default 10 2451 Number of parallel jobs to use. 2452 2453 Returns 2454 ------- 2455 pandas.DataFrame or dict 2456 Results DataFrame (or dict containing valid/control cells + DataFrame), 2457 similar to `calc_DEG` interface. 2458 2459 Raises 2460 ------ 2461 ValueError 2462 If neither `cells` nor `sets` is provided, or input metadata mismatch occurs. 2463 2464 Notes 2465 ----- 2466 Multiple modes supported: single-list entities, 'All', pairwise dicts, etc. 
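Examples
--------
Hedged sketches of the supported modes (cell and set names below are
illustrative placeholders, not values from any real dataset):

>>> res = obj.statistic(cells=["T-cell # set1"])                # selected cells vs rest
>>> res = obj.statistic(cells="All")                            # each cell type vs the rest
>>> res = obj.statistic(sets="All")                             # each set vs the rest
>>> res = obj.statistic(sets={"g1": ["set1"], "g2": ["set2"]})  # pairwise set comparison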
2467 """ 2468 2469 def stat_calc(choose, feature_name): 2470 target_values = choose.loc[choose["DEG"] == "target", feature_name] 2471 rest_values = choose.loc[choose["DEG"] == "rest", feature_name] 2472 2473 pct_valid = (target_values > 0).sum() / len(target_values) 2474 pct_rest = (rest_values > 0).sum() / len(rest_values) 2475 2476 avg_valid = np.mean(target_values) 2477 avg_ctrl = np.mean(rest_values) 2478 sd_valid = np.std(target_values, ddof=1) 2479 sd_ctrl = np.std(rest_values, ddof=1) 2480 esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2)) 2481 2482 if np.sum(target_values) == np.sum(rest_values): 2483 p_val = 1.0 2484 else: 2485 _, p_val = stats.mannwhitneyu( 2486 target_values, rest_values, alternative="two-sided" 2487 ) 2488 2489 return { 2490 "feature": feature_name, 2491 "p_val": p_val, 2492 "pct_valid": pct_valid, 2493 "pct_ctrl": pct_rest, 2494 "avg_valid": avg_valid, 2495 "avg_ctrl": avg_ctrl, 2496 "sd_valid": sd_valid, 2497 "sd_ctrl": sd_ctrl, 2498 "esm": esm, 2499 } 2500 2501 def prepare_and_run_stat( 2502 choose, valid_group, min_exp, min_pct, n_proc, factors 2503 ): 2504 2505 tmp_dat = choose[choose["DEG"] == "target"] 2506 tmp_dat = tmp_dat.drop("DEG", axis=1) 2507 2508 counts = (tmp_dat > min_exp).sum(axis=0) 2509 2510 total_count = tmp_dat.shape[0] 2511 2512 info = pd.DataFrame( 2513 {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)} 2514 ) 2515 2516 del tmp_dat 2517 2518 drop_col = info["feature"][info["pct"] <= min_pct] 2519 2520 if len(drop_col) + 1 == len(choose.columns): 2521 drop_col = info["feature"][info["pct"] == 0] 2522 2523 del info 2524 2525 choose = choose.drop(list(drop_col), axis=1) 2526 2527 results = Parallel(n_jobs=n_proc)( 2528 delayed(stat_calc)(choose, feature) 2529 for feature in tqdm(choose.columns[choose.columns != "DEG"]) 2530 ) 2531 2532 df = pd.DataFrame(results) 2533 df["valid_group"] = valid_group 2534 df.sort_values(by="p_val", inplace=True) 2535 2536 num_tests = len(df) 2537 df["adj_pval"] = np.minimum( 2538 1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1) 2539 ) 2540 2541 factors_df = factors.to_frame(name="factor") 2542 2543 df = df.merge(factors_df, left_on="feature", right_index=True, how="left") 2544 2545 df["FC"] = (df["avg_valid"] + df["factor"]) / ( 2546 df["avg_ctrl"] + df["factor"] 2547 ) 2548 df["log(FC)"] = np.log2(df["FC"]) 2549 df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"] 2550 df = df.drop(columns=["factor"]) 2551 2552 return df 2553 2554 choose = self.normalized_data.copy().T 2555 factors = self.normalized_data.copy().replace(0, np.nan).min(axis=1) 2556 factors[factors == 0] = min(factors[factors != 0]) 2557 2558 final_results = [] 2559 2560 if isinstance(cells, list) and sets is None: 2561 print("\nAnalysis started...\nComparing selected cells to the whole set...") 2562 choose.index = ( 2563 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 2564 ) 2565 2566 if "#" not in cells[0]: 2567 choose.index = self.input_metadata["cell_names"] 2568 2569 print( 2570 "Set info (name # set) not included in the 'cells' list. " 2571 "Only names will be compared, without considering set information."
2572 ) 2573 2574 labels = ["target" if idx in cells else "rest" for idx in choose.index] 2575 valid = list( 2576 set(choose.index[[i for i, x in enumerate(labels) if x == "target"]]) 2577 ) 2578 2579 choose["DEG"] = labels 2580 choose = choose[choose["DEG"] != "drop"] 2581 2582 result_df = prepare_and_run_stat( 2583 choose.reset_index(drop=True), 2584 valid_group=valid, 2585 min_exp=min_exp, 2586 min_pct=min_pct, 2587 n_proc=n_proc, 2588 factors=factors, 2589 ) 2590 return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df} 2591 2592 elif cells == "All" and sets is None: 2593 print("\nAnalysis started...\nComparing each type of cell to others...") 2594 choose.index = ( 2595 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 2596 ) 2597 unique_labels = set(choose.index) 2598 2599 for label in tqdm(unique_labels): 2600 print(f"\nCalculating statistics for {label}") 2601 labels = ["target" if idx == label else "rest" for idx in choose.index] 2602 choose["DEG"] = labels 2603 choose = choose[choose["DEG"] != "drop"] 2604 result_df = prepare_and_run_stat( 2605 choose.copy(), 2606 valid_group=label, 2607 min_exp=min_exp, 2608 min_pct=min_pct, 2609 n_proc=n_proc, 2610 factors=factors, 2611 ) 2612 final_results.append(result_df) 2613 2614 return pd.concat(final_results, ignore_index=True) 2615 2616 elif cells is None and sets == "All": 2617 print("\nAnalysis started...\nComparing each set/group to others...") 2618 choose.index = self.input_metadata["sets"] 2619 unique_sets = set(choose.index) 2620 2621 for label in tqdm(unique_sets): 2622 print(f"\nCalculating statistics for {label}") 2623 labels = ["target" if idx == label else "rest" for idx in choose.index] 2624 2625 choose["DEG"] = labels 2626 choose = choose[choose["DEG"] != "drop"] 2627 result_df = prepare_and_run_stat( 2628 choose.copy(), 2629 valid_group=label, 2630 min_exp=min_exp, 2631 min_pct=min_pct, 2632 n_proc=n_proc, 2633 factors=factors, 2634 ) 2635 final_results.append(result_df) 2636 2637 return pd.concat(final_results, ignore_index=True) 2638 2639 elif cells is None and isinstance(sets, dict): 2640 print("\nAnalysis started...\nComparing groups...") 2641 2642 choose.index = self.input_metadata["sets"] 2643 2644 group_list = list(sets.keys()) 2645 if len(group_list) != 2: 2646 print("Only pairwise group comparison is supported.") 2647 return None 2648 2649 labels = [ 2650 ( 2651 "target" 2652 if idx in sets[group_list[0]] 2653 else "rest" if idx in sets[group_list[1]] else "drop" 2654 ) 2655 for idx in choose.index 2656 ] 2657 choose["DEG"] = labels 2658 choose = choose[choose["DEG"] != "drop"] 2659 2660 result_df = prepare_and_run_stat( 2661 choose.reset_index(drop=True), 2662 valid_group=group_list[0], 2663 min_exp=min_exp, 2664 min_pct=min_pct, 2665 n_proc=n_proc, 2666 factors=factors, 2667 ) 2668 return result_df 2669 2670 elif isinstance(cells, dict) and sets is None: 2671 print("\nAnalysis started...\nComparing groups...") 2672 choose.index = ( 2673 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 2674 ) 2675 2676 if "#" not in cells[list(cells.keys())[0]][0]: 2677 choose.index = self.input_metadata["cell_names"] 2678 2679 print( 2680 "Set info (name # set) not included in the 'cells' dict. " 2681 "Only names will be compared, without considering set information."
2682 ) 2683 2684 group_list = list(cells.keys()) 2685 if len(group_list) != 2: 2686 print("Only pairwise group comparison is supported.") 2687 return None 2688 2689 labels = [ 2690 ( 2691 "target" 2692 if idx in cells[group_list[0]] 2693 else "rest" if idx in cells[group_list[1]] else "drop" 2694 ) 2695 for idx in choose.index 2696 ] 2697 2698 choose["DEG"] = labels 2699 choose = choose[choose["DEG"] != "drop"] 2700 2701 result_df = prepare_and_run_stat( 2702 choose.reset_index(drop=True), 2703 valid_group=group_list[0], 2704 min_exp=min_exp, 2705 min_pct=min_pct, 2706 n_proc=n_proc, 2707 factors=factors, 2708 ) 2709 2710 return result_df.reset_index(drop=True) 2711 2712 else: 2713 raise ValueError( 2714 "Unsupported combination of 'cells' and 'sets'. Use cells as a list or dict, cells='All', sets='All', or sets as a dict." 2715 ) 2716 2717 def calculate_difference_markers( 2718 self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False 2719 ): 2720 """ 2721 Compute differential markers (var_data) if not already present. 2722 2723 Parameters 2724 ---------- 2725 min_exp : float, default 0 2726 Minimum expression threshold passed to `statistic`. 2727 2728 min_pct : float, default 0.25 2729 Minimum percent expressed in target group. 2730 2731 n_proc : int, default 10 2732 Parallel jobs. 2733 2734 force : bool, default False 2735 If True, recompute even if `self.var_data` is present. 2736 2737 Update 2738 ------- 2739 Sets `self.var_data` to the result of `self.statistic(...)`. 2740 2741 Raises 2742 ------ 2743 ValueError if already computed and `force` is False. 2744 """ 2745 2746 if self.var_data is None or force: 2747 2748 self.var_data = self.statistic( 2749 cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc 2750 ) 2751 2752 else: 2753 raise ValueError( 2754 "self.calculate_difference_markers() has already been executed. " 2755 "The results are stored in self.var_data. " 2756 "If you want to recalculate with different statistics, please rerun the method with force=True." 2757 ) 2758 2759 def clustering_features( 2760 self, 2761 features_list: list | None, 2762 name_slot: str = "cell_names", 2763 p_val: float = 0.05, 2764 top_n: int = 25, 2765 adj_mean: bool = True, 2766 beta: float = 0.2, 2767 ): 2768 """ 2769 Prepare clustering input by selecting marker features and optionally smoothing cell values 2770 toward group means. 2771 2772 Parameters 2773 ---------- 2774 features_list : list or None 2775 If provided, use this list of features. If None, features are selected 2776 from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group. 2777 2778 name_slot : str, default 'cell_names' 2779 Metadata column used for naming. 2780 2781 p_val : float, default 0.05 2782 Adjusted p-value cutoff when selecting features automatically. 2783 2784 top_n : int, default 25 2785 Number of top features per valid group to keep if `features_list` is None. 2786 2787 adj_mean : bool, default True 2788 If True, adjust cell values toward group means using `beta`. 2789 2790 beta : float, default 0.2 2791 Adjustment strength toward group mean. 2792 2793 Update 2794 ------ 2795 Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset, 2796 ready for PCA/UMAP/clustering. 2797 """ 2798 2799 if features_list is None or len(features_list) == 0: 2800 2801 if self.var_data is None: 2802 raise ValueError( 2803 "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2804 ) 2805 2806 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2807 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2808 df_tmp = ( 2809 df_tmp.sort_values( 2810 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2811 ) 2812 .groupby("valid_group") 2813 .head(top_n) 2814 ) 2815 2816 features_list = list(set(df_tmp["feature"])) 2817 2818 data = self.get_partial_data( 2819 names=None, features=features_list, name_slot=name_slot 2820 ) 2821 data_avg = average(data) 2822 2823 if adj_mean: 2824 data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta) 2825 2826 self.clustering_data = data 2827 2828 self.clustering_metadata = self.input_metadata 2829 2830 def average(self): 2831 """ 2832 Aggregate normalized data by (cell_name, set) pairs computing the mean per group. 2833 2834 The method constructs new column names as "cell_name # set", averages columns 2835 sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`. 2836 2837 Update 2838 ------ 2839 Sets `self.agg_normalized_data` (features x aggregated samples) and 2840 `self.agg_metadata` (DataFrame with 'cell_names' and 'sets'). 2841 """ 2842 2843 wide_data = self.normalized_data.copy() 2844 2845 wide_metadata = self.input_metadata 2846 2847 new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"] 2848 2849 wide_data.columns = list(new_names) 2850 2851 aggregated_df = wide_data.T.groupby(level=0).mean().T 2852 2853 sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns] 2854 names = [re.sub(" #.*", "", x) for x in aggregated_df.columns] 2855 2856 aggregated_df.columns = names 2857 aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets}) 2858 2859 self.agg_metadata = aggregated_metadata 2860 self.agg_normalized_data = aggregated_df 2861 2862 def estimating_similarity( 2863 self, method="pearson", p_val: float = 0.05, top_n: int = 25 2864 ): 2865 """ 2866 Estimate pairwise similarity and Euclidean distance between aggregated samples. 2867 2868 Parameters 2869 ---------- 2870 method : str, default 'pearson' 2871 Correlation method to use (passed to pandas.DataFrame.corr()). 2872 2873 p_val : float, default 0.05 2874 Adjusted p-value cutoff used to select marker features from `self.var_data`. 2875 2876 top_n : int, default 25 2877 Number of top features per valid group to include. 2878 2879 Update 2880 ------- 2881 Computes a combined table with per-pair correlation and euclidean distance 2882 and stores it in `self.similarity`. 2883 """ 2884 2885 if self.var_data is None: 2886 raise ValueError( 2887 "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2888 ) 2889 2890 if self.agg_normalized_data is None: 2891 self.average() 2892 2893 metadata = self.agg_metadata 2894 data = self.agg_normalized_data 2895 2896 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2897 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2898 df_tmp = ( 2899 df_tmp.sort_values( 2900 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2901 ) 2902 .groupby("valid_group") 2903 .head(top_n) 2904 ) 2905 2906 data = data.loc[list(set(df_tmp["feature"]))] 2907 2908 if len(set(metadata["sets"])) > 1: 2909 data.columns = data.columns + " # " + [x for x in metadata["sets"]] 2910 else: 2911 data = data.copy() 2912 2913 scaler = StandardScaler() 2914 2915 scaled_data = scaler.fit_transform(data) 2916 2917 scaled_df = pd.DataFrame(scaled_data, columns=data.columns) 2918 2919 cor = scaled_df.corr(method=method) 2920 cor_df = cor.stack().reset_index() 2921 cor_df.columns = ["cell1", "cell2", "correlation"] 2922 2923 distances = pdist(scaled_df.T, metric="euclidean") 2924 dist_mat = pd.DataFrame( 2925 squareform(distances), index=scaled_df.columns, columns=scaled_df.columns 2926 ) 2927 dist_df = dist_mat.stack().reset_index() 2928 dist_df.columns = ["cell1", "cell2", "euclidean_dist"] 2929 2930 full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"]) 2931 2932 full = full[full["cell1"] != full["cell2"]] 2933 full = full.reset_index(drop=True) 2934 2935 self.similarity = full 2936 2937 def similarity_plot( 2938 self, 2939 split_sets=True, 2940 set_info: bool = True, 2941 cmap="seismic", 2942 width=12, 2943 height=10, 2944 ): 2945 """ 2946 Visualize pairwise similarity as a scatter plot. 2947 2948 Parameters 2949 ---------- 2950 split_sets : bool, default True 2951 If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs. 2952 2953 set_info : bool, default True 2954 If True, keep the ' # set' annotation in labels; otherwise strip it. 2955 2956 cmap : str, default 'seismic' 2957 Color map for correlation (hue). 2958 2959 width : int, default 12 2960 Figure width. 2961 2962 height : int, default 10 2963 Figure height. 2964 2965 Returns 2966 ------- 2967 matplotlib.figure.Figure 2968 2969 Raises 2970 ------ 2971 ValueError 2972 If `self.similarity` is None. 2973 2974 Notes 2975 ----- 2976 The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs. 2977 """ 2978 2979 if self.similarity is None: 2980 raise ValueError( 2981 "Similarity data is missing. Please calculate similarity using self.estimating_similarity." 
2982 ) 2983 2984 similarity_data = self.similarity 2985 2986 if " # " in similarity_data["cell1"][0]: 2987 similarity_data["set1"] = [ 2988 re.sub(".*# ", "", x) for x in similarity_data["cell1"] 2989 ] 2990 similarity_data["set2"] = [ 2991 re.sub(".*# ", "", x) for x in similarity_data["cell2"] 2992 ] 2993 2994 if split_sets and " # " in similarity_data["cell1"][0]: 2995 sets = list( 2996 set(list(similarity_data["set1"]) + list(similarity_data["set2"])) 2997 ) 2998 2999 mm = math.ceil(len(sets) / 2) 3000 3001 x_s = sets[0:mm] 3002 y_s = sets[mm : len(sets)] 3003 3004 similarity_data = similarity_data[similarity_data["set1"].isin(x_s)] 3005 similarity_data = similarity_data[similarity_data["set2"].isin(y_s)] 3006 3007 similarity_data = similarity_data.sort_values(["set1", "set2"]) 3008 3009 if set_info is False and " # " in similarity_data["cell1"][0]: 3010 similarity_data["cell1"] = [ 3011 re.sub(" #.*", "", x) for x in similarity_data["cell1"] 3012 ] 3013 similarity_data["cell2"] = [ 3014 re.sub(" #.*", "", x) for x in similarity_data["cell2"] 3015 ] 3016 3017 similarity_data["-euclidean_zscore"] = -zscore( 3018 similarity_data["euclidean_dist"] 3019 ) 3020 3021 similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0] 3022 3023 fig = plt.figure(figsize=(width, height)) 3024 sns.scatterplot( 3025 data=similarity_data, 3026 x="cell1", 3027 y="cell2", 3028 hue="correlation", 3029 size="-euclidean_zscore", 3030 sizes=(1, 100), 3031 palette=cmap, 3032 alpha=1, 3033 edgecolor="black", 3034 ) 3035 3036 plt.xticks(rotation=90) 3037 plt.yticks(rotation=0) 3038 plt.xlabel("Cell 1") 3039 plt.ylabel("Cell 2") 3040 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") 3041 3042 plt.grid(True, alpha=0.6) 3043 3044 plt.tight_layout() 3045 3046 return fig 3047 3048 def spatial_similarity( 3049 self, 3050 set_info: bool = True, 3051 bandwidth=1, 3052 n_neighbors=5, 3053 min_dist=0.1, 3054 legend_split=2, 3055 point_size=100, 3056 spread=1.0, 3057 set_op_mix_ratio=1.0, 3058 local_connectivity=1, 3059 repulsion_strength=1.0, 3060 negative_sample_rate=5, 3061 threshold=0.1, 3062 width=12, 3063 height=10, 3064 ): 3065 """ 3066 Create a spatial UMAP-like visualization of similarity relationships between samples. 3067 3068 Parameters 3069 ---------- 3070 set_info : bool, default True 3071 If True, retain set information in labels. 3072 3073 bandwidth : float, default 1 3074 Bandwidth used by MeanShift for clustering polygons. 3075 3076 point_size : float, default 100 3077 Size of scatter points. 3078 3079 legend_split : int, default 2 3080 Number of columns in legend. 3081 3082 n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP. 3083 3084 threshold : float, default 0.1 3085 Minimum text distance for label adjustment to avoid overlap. 3086 3087 width : int, default 12 3088 Figure width. 3089 3090 height : int, default 10 3091 Figure height. 3092 3093 Returns 3094 ------- 3095 matplotlib.figure.Figure 3096 3097 Raises 3098 ------ 3099 ValueError 3100 If `self.similarity` is None. 3101 3102 Notes 3103 ----- 3104 Builds a precomputed distance matrix combining correlation and euclidean distance, 3105 runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull) 3106 and arrows to indicate nearest neighbors (minimal combined distance). 3107 """ 3108 3109 if self.similarity is None: 3110 raise ValueError( 3111 "Similarity data is missing. 
Please calculate similarity using self.estimating_similarity." 3112 ) 3113 3114 similarity_data = self.similarity 3115 3116 sim = similarity_data["correlation"] 3117 sim_scaled = (sim - sim.min()) / (sim.max() - sim.min()) 3118 eu_dist = similarity_data["euclidean_dist"] 3119 eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min()) 3120 3121 similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled 3122 3123 # nearest-neighbour target per cell (minimal combined distance) 3124 arrow_df = similarity_data.copy() 3125 arrow_df = similarity_data.loc[ 3126 similarity_data.groupby("cell1")["combo_dist"].idxmin() 3127 ] 3128 3129 cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"])) 3130 combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float) 3131 3132 for _, row in similarity_data.iterrows(): 3133 combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"] 3134 combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"] 3135 3136 umap_model = umap.UMAP( 3137 n_components=2, 3138 metric="precomputed", 3139 n_neighbors=n_neighbors, 3140 min_dist=min_dist, 3141 spread=spread, 3142 set_op_mix_ratio=set_op_mix_ratio, 3143 local_connectivity=local_connectivity, 3144 repulsion_strength=repulsion_strength, 3145 negative_sample_rate=negative_sample_rate, 3146 transform_seed=42, 3147 init="spectral", 3148 random_state=42, 3149 verbose=True, 3150 ) 3151 3152 coords = umap_model.fit_transform(combo_matrix.values) 3153 cell_names = list(combo_matrix.index) 3154 num_cells = len(cell_names) 3155 palette = sns.color_palette("tab20c", num_cells) 3156 3157 if "#" in cell_names[0]: 3158 avsets = set( 3159 [re.sub(".*# ", "", x) for x in similarity_data["cell1"]] 3160 + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]] 3161 ) 3162 num_sets = len(avsets) 3163 color_indices = [i * len(palette) // num_sets for i in range(num_sets)] 3164 color_mapping_sets = { 3165 set_name: palette[i] for i, set_name in zip(color_indices, avsets) 3166 } 3167 color_mapping = { 3168 name: color_mapping_sets[re.sub(".*# ", "", name)] 3169 for i, name in enumerate(cell_names) 3170 } 3171 else: 3172 color_mapping = {name: palette[i] for i, name in enumerate(cell_names)} 3173 3174 meanshift = MeanShift(bandwidth=bandwidth) 3175 labels = meanshift.fit_predict(coords) 3176 3177 fig = plt.figure(figsize=(width, height)) 3178 ax = plt.gca() 3179 3180 unique_labels = set(labels) 3181 cluster_palette = sns.color_palette("hls", len(unique_labels)) 3182 3183 for label in unique_labels: 3184 if label == -1: 3185 continue 3186 cluster_coords = coords[labels == label] 3187 if len(cluster_coords) < 3: 3188 continue 3189 3190 hull = ConvexHull(cluster_coords) 3191 hull_points = cluster_coords[hull.vertices] 3192 3193 centroid = np.mean(hull_points, axis=0) 3194 expanded = hull_points + 0.05 * (hull_points - centroid) 3195 3196 poly = Polygon( 3197 expanded, 3198 closed=True, 3199 facecolor=cluster_palette[label], 3200 edgecolor="none", 3201 alpha=0.2, 3202 zorder=1, 3203 ) 3204 ax.add_patch(poly) 3205 3206 texts = [] 3207 for i, (x, y) in enumerate(coords): 3208 plt.scatter( 3209 x, 3210 y, 3211 s=point_size, 3212 color=color_mapping[cell_names[i]], 3213 edgecolors="black", 3214 linewidths=0.5, 3215 zorder=2, 3216 ) 3217 texts.append( 3218 ax.text( 3219 x, y, str(i), ha="center", va="center", fontsize=8, color="black" 3220 ) 3221 ) 3222 3223 def dist(p1, p2): 3224 return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) 3225 3226 texts_to_adjust = [] 3227 for i, t1 in enumerate(texts): 3228 for j, t2 in
enumerate(texts): 3229 if i >= j: 3230 continue 3231 d = dist( 3232 (t1.get_position()[0], t1.get_position()[1]), 3233 (t2.get_position()[0], t2.get_position()[1]), 3234 ) 3235 if d < threshold: 3236 if t1 not in texts_to_adjust: 3237 texts_to_adjust.append(t1) 3238 if t2 not in texts_to_adjust: 3239 texts_to_adjust.append(t2) 3240 3241 adjust_text( 3242 texts_to_adjust, 3243 expand_text=(1.0, 1.0), 3244 force_text=0.9, 3245 arrowprops=dict(arrowstyle="-", color="gray", lw=0.1), 3246 ax=ax, 3247 ) 3248 3249 for _, row in arrow_df.iterrows(): 3250 try: 3251 idx1 = cell_names.index(row["cell1"]) 3252 idx2 = cell_names.index(row["cell2"]) 3253 except ValueError: 3254 continue 3255 x1, y1 = coords[idx1] 3256 x2, y2 = coords[idx2] 3257 arrow = FancyArrowPatch( 3258 (x1, y1), 3259 (x2, y2), 3260 arrowstyle="->", 3261 color="gray", 3262 linewidth=1.5, 3263 alpha=0.5, 3264 mutation_scale=12, 3265 zorder=0, 3266 ) 3267 ax.add_patch(arrow) 3268 3269 if set_info is False and " # " in cell_names[0]: 3270 3271 legend_elements = [ 3272 Patch( 3273 facecolor=color_mapping[name], 3274 edgecolor="black", 3275 label=f"{i} – {re.sub(' #.*', '', name)}", 3276 ) 3277 for i, name in enumerate(cell_names) 3278 ] 3279 3280 else: 3281 3282 legend_elements = [ 3283 Patch( 3284 facecolor=color_mapping[name], 3285 edgecolor="black", 3286 label=f"{i} – {name}", 3287 ) 3288 for i, name in enumerate(cell_names) 3289 ] 3290 3291 plt.legend( 3292 handles=legend_elements, 3293 title="Cells", 3294 bbox_to_anchor=(1.05, 1), 3295 loc="upper left", 3296 ncol=legend_split, 3297 ) 3298 3299 plt.xlabel("UMAP 1") 3300 plt.ylabel("UMAP 2") 3301 plt.grid(False) 3302 plt.show() 3303 3304 return fig 3305 3306 # subclusters part 3307 3308 def subcluster_prepare(self, features: list, cluster: str): 3309 """ 3310 Prepare a `Clustering` object for subcluster analysis on a selected parent cluster. 3311 3312 Parameters 3313 ---------- 3314 features : list 3315 Features to include for subcluster analysis. 3316 3317 cluster : str 3318 Parent cluster name (used to select matching cells). 3319 3320 Update 3321 ------ 3322 Initializes `self.subclusters_` as a new `Clustering` instance containing the 3323 reduced data for the given cluster and stores `current_features` and `current_cluster`. 3324 """ 3325 3326 dat = self.normalized_data 3327 dat.columns = list(self.input_metadata["cell_names"]) 3328 3329 dat = reduce_data(self.normalized_data, features=features, names=[cluster]) 3330 3331 self.subclusters_ = Clustering(data=dat, metadata=None) 3332 3333 self.subclusters_.current_features = features 3334 self.subclusters_.current_cluster = cluster 3335 3336 def define_subclusters( 3337 self, 3338 umap_num: int = 2, 3339 eps: float = 0.5, 3340 min_samples: int = 10, 3341 n_neighbors: int = 5, 3342 min_dist: float = 0.1, 3343 spread: float = 1.0, 3344 set_op_mix_ratio: float = 1.0, 3345 local_connectivity: int = 1, 3346 repulsion_strength: float = 1.0, 3347 negative_sample_rate: int = 5, 3348 width=8, 3349 height=6, 3350 ): 3351 """ 3352 Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset. 3353 3354 Parameters 3355 ---------- 3356 umap_num : int, default 2 3357 Number of UMAP dimensions to compute. 3358 3359 eps : float, default 0.5 3360 DBSCAN eps parameter. 3361 3362 min_samples : int, default 10 3363 DBSCAN min_samples parameter. 
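In DBSCAN this is the minimum neighbourhood size for a point to be treated as a core point; larger values yield more conservative clusters.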
3364 3365 n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height : 3366 Additional parameters passed to UMAP and the plotting routine. 3367 3368 Update 3369 ------- 3370 Stores cluster labels in `self.subclusters_.subclusters`. 3371 3372 Raises 3373 ------ 3374 RuntimeError 3375 If `self.subclusters_` has not been prepared. 3376 """ 3377 3378 if self.subclusters_ is None: 3379 raise RuntimeError( 3380 "Nothing to compute. 'self.subcluster_prepare' has not been run!" 3381 ) 3382 3383 self.subclusters_.perform_UMAP( 3384 factorize=False, 3385 umap_num=umap_num, 3386 pc_num=0, 3387 harmonized=False, 3388 n_neighbors=n_neighbors, 3389 min_dist=min_dist, 3390 spread=spread, 3391 set_op_mix_ratio=set_op_mix_ratio, 3392 local_connectivity=local_connectivity, 3393 repulsion_strength=repulsion_strength, 3394 negative_sample_rate=negative_sample_rate, 3395 width=width, 3396 height=height, 3397 ) 3398 3399 fig = self.subclusters_.find_clusters_UMAP( 3400 umap_n=umap_num, 3401 eps=eps, 3402 min_samples=min_samples, 3403 width=width, 3404 height=height, 3405 ) 3406 3407 clusters = self.subclusters_.return_clusters(clusters="umap") 3408 3409 self.subclusters_.subclusters = [str(x) for x in list(clusters)] 3410 3411 return fig 3412 3413 def subcluster_features_scatter( 3414 self, 3415 colors="viridis", 3416 hclust="complete", 3417 scale=False, 3418 img_width=3, 3419 img_high=5, 3420 label_size=6, 3421 size_scale=70, 3422 y_lab="Genes", 3423 legend_lab="normalized", 3424 bbox_to_anchor_scale: int = 25, 3425 bbox_to_anchor_perc: tuple = (0.91, 0.63), 3426 ): 3427 """ 3428 Create a features-scatter visualization for the subclusters (averaged and occurrence). 3429 3430 Parameters 3431 ---------- 3432 colors : str, default 'viridis' 3433 Colormap name passed to `features_scatter`. 3434 3435 hclust : str or None 3436 Hierarchical clustering linkage to order rows/columns. 3437 3438 scale: bool, default False 3439 If True, expression data will be scaled (0–1) across the rows (features). 3440 3441 img_width, img_high : float 3442 Figure size. 3443 3444 label_size : int 3445 Font size for labels. 3446 3447 size_scale : int 3448 Bubble size scaling. 3449 3450 y_lab : str 3451 X axis label. 3452 3453 legend_lab : str 3454 Colorbar label. 3455 3456 bbox_to_anchor_scale : int, default=25 3457 Vertical scale (percentage) for positioning the colorbar. 3458 3459 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3460 Anchor position for the size legend (percent bubble legend). 3461 3462 Returns 3463 ------- 3464 matplotlib.figure.Figure 3465 3466 Raises 3467 ------ 3468 RuntimeError 3469 If subcluster preparation/definition has not been run. 3470 """ 3471 3472 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3473 raise RuntimeError( 3474 "Nothing to plot. Run the 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline first!"
3475 ) 3476 3477 dat = self.normalized_data 3478 dat.columns = list(self.input_metadata["cell_names"]) 3479 3480 dat = reduce_data( 3481 self.normalized_data, 3482 features=self.subclusters_.current_features, 3483 names=[self.subclusters_.current_cluster], 3484 ) 3485 3486 dat.columns = self.subclusters_.subclusters 3487 3488 avg = average(dat) 3489 occ = occurrence(dat) 3490 3491 scatter = features_scatter( 3492 expression_data=avg, 3493 occurence_data=occ, 3494 features=None, 3495 scale=scale, 3496 metadata_list=None, 3497 colors=colors, 3498 hclust=hclust, 3499 img_width=img_width, 3500 img_high=img_high, 3501 label_size=label_size, 3502 size_scale=size_scale, 3503 y_lab=y_lab, 3504 legend_lab=legend_lab, 3505 bbox_to_anchor_scale=bbox_to_anchor_scale, 3506 bbox_to_anchor_perc=bbox_to_anchor_perc, 3507 ) 3508 3509 return scatter 3510 3511 def subcluster_DEG_scatter( 3512 self, 3513 top_n=3, 3514 min_exp=0, 3515 min_pct=0.25, 3516 p_val=0.05, 3517 colors="viridis", 3518 hclust="complete", 3519 scale=False, 3520 img_width=3, 3521 img_high=5, 3522 label_size=6, 3523 size_scale=70, 3524 y_lab="Genes", 3525 legend_lab="normalized", 3526 bbox_to_anchor_scale: int = 25, 3527 bbox_to_anchor_perc: tuple = (0.91, 0.63), 3528 n_proc=10, 3529 ): 3530 """ 3531 Plot top differential features (DEGs) for subclusters as a features-scatter. 3532 3533 Parameters 3534 ---------- 3535 top_n : int, default 3 3536 Number of top features per subcluster to show. 3537 3538 min_exp : float, default 0 3539 Minimum expression threshold passed to `statistic`. 3540 3541 min_pct : float, default 0.25 3542 Minimum percent expressed in target group. 3543 3544 p_val: float, default 0.05 3545 Maximum p-value for visualizing features. 3546 3547 n_proc : int, default 10 3548 Parallel jobs used for DEG calculation. 3549 3550 scale: bool, default False 3551 If True, expression_data will be scaled (0–1) across the rows (features). 3552 3553 colors : str, default='viridis' 3554 Colormap for expression values. 3555 3556 hclust : str or None, default='complete' 3557 Linkage method for hierarchical clustering. If None, no clustering 3558 is performed. 3559 3560 img_width : int or float, default=3 3561 Width of the plot in inches. 3562 3563 img_high : int or float, default=5 3564 Height of the plot in inches. 3565 3566 label_size : int, default=6 3567 Font size for axis labels and ticks. 3568 3569 size_scale : int or float, default=70 3570 Scaling factor for bubble sizes. 3571 3572 y_lab : str, default='Genes' 3573 Label for the x-axis. 3574 3575 legend_lab : str, default='normalized' 3576 Label for the colorbar legend. 3577 3578 bbox_to_anchor_scale : int, default=25 3579 Vertical scale (percentage) for positioning the colorbar. 3580 3581 bbox_to_anchor_perc : tuple, default=(0.91, 0.63) 3582 Anchor position for the size legend (percent bubble legend). 3583 3584 Returns 3585 ------- 3586 matplotlib.figure.Figure 3587 3588 Raises 3589 ------ 3590 RuntimeError 3591 If subcluster preparation/definition has not been run. 3592 3593 Notes 3594 ----- 3595 Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters 3596 by p-value and effect-size, selects top features per valid group and plots them. 3597 """ 3598 3599 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3600 raise RuntimeError( 3601 "Nothing to plot. Run the 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline first!"
3602 ) 3603 3604 dat = self.normalized_data 3605 dat.columns = list(self.input_metadata["cell_names"]) 3606 3607 dat = reduce_data( 3608 self.normalized_data, names=[self.subclusters_.current_cluster] 3609 ) 3610 3611 dat.columns = self.subclusters_.subclusters 3612 3613 deg_stats = calc_DEG( 3614 dat, 3615 metadata_list=None, 3616 entities="All", 3617 sets=None, 3618 min_exp=min_exp, 3619 min_pct=min_pct, 3620 n_proc=n_proc, 3621 ) 3622 3623 deg_stats = deg_stats[deg_stats["p_val"] <= p_val] 3624 deg_stats = deg_stats[deg_stats["log(FC)"] > 0] 3625 3626 deg_stats = ( 3627 deg_stats.sort_values( 3628 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 3629 ) 3630 .groupby("valid_group") 3631 .head(top_n) 3632 ) 3633 3634 dat = reduce_data(dat, features=list(set(deg_stats["feature"]))) 3635 3636 avg = average(dat) 3637 occ = occurrence(dat) 3638 3639 scatter = features_scatter( 3640 expression_data=avg, 3641 occurence_data=occ, 3642 features=None, 3643 metadata_list=None, 3644 colors=colors, 3645 hclust=hclust, 3646 img_width=img_width, 3647 img_high=img_high, 3648 label_size=label_size, 3649 size_scale=size_scale, 3650 y_lab=y_lab, 3651 legend_lab=legend_lab, 3652 bbox_to_anchor_scale=bbox_to_anchor_scale, 3653 bbox_to_anchor_perc=bbox_to_anchor_perc, 3654 ) 3655 3656 return scatter 3657 3658 def accept_subclusters(self): 3659 """ 3660 Commit subcluster labels into the main `input_metadata` by renaming cell names. 3661 3662 The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']` 3663 with the expanded names that include subcluster suffixes (via `add_subnames`), 3664 then clears `self.subclusters_`. 3665 3666 Update 3667 ------ 3668 Modifies `self.input_metadata['cell_names']`. 3669 3670 Resets `self.subclusters_` to None. 3671 3672 Raises 3673 ------ 3674 RuntimeError 3675 If `self.subclusters_` is not defined or subclusters were not computed. 3676 """ 3677 3678 if self.subclusters_ is None or self.subclusters_.subclusters is None: 3679 raise RuntimeError( 3680 "Nothing to commit. Run the 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline first!" 3681 ) 3682 3683 new_meta = add_subnames( 3684 list(self.input_metadata["cell_names"]), 3685 parent_name=self.subclusters_.current_cluster, 3686 new_clusters=self.subclusters_.subclusters, 3687 ) 3688 3689 self.input_metadata["cell_names"] = new_meta 3690 3691 self.subclusters_ = None 3692 3693 def scatter_plot( 3694 self, 3695 names: list | None = None, 3696 features: list | None = None, 3697 name_slot: str = "cell_names", 3698 scale=True, 3699 colors="viridis", 3700 hclust=None, 3701 img_width=15, 3702 img_high=1, 3703 label_size=10, 3704 size_scale=200, 3705 y_lab="Genes", 3706 legend_lab="log(CPM + 1)", 3707 set_box_size: float | int = 5, 3708 set_box_high: float | int = 5, 3709 bbox_to_anchor_scale=25, 3710 bbox_to_anchor_perc=(0.90, -0.24), 3711 bbox_to_anchor_group=(1.01, 0.4), 3712 ): 3713 """ 3714 Create a bubble scatter plot of selected features across samples inside the project. 3715 3716 Each point represents a feature-sample pair, where the color encodes the 3717 expression value and the size encodes occurrence or relative abundance. 3718 Optionally, hierarchical clustering can be applied to order rows and columns. 3719 3720 Parameters 3721 ---------- 3722 names : list, str, or None 3723 Names of samples to include. If None, all samples are considered. 3724 3725 features : list, str, or None 3726 Names of features to include. If None, all features are considered.
3727 3728 name_slot : str 3729 Column in metadata to use as sample names. 3730 3731 scale: bool, default True 3732 If True, expression_data will be scaled (0–1) across the rows (features). 3733 3734 colors : str, default='viridis' 3735 Colormap for expression values. 3736 3737 hclust : str or None, default=None 3738 Linkage method for hierarchical clustering. If None, no clustering 3739 is performed. 3740 3741 img_width : int or float, default=15 3742 Width of the plot in inches. 3743 3744 img_high : int or float, default=1 3745 Height of the plot in inches. 3746 3747 label_size : int, default=10 3748 Font size for axis labels and ticks. 3749 3750 size_scale : int or float, default=200 3751 Scaling factor for bubble sizes. 3752 3753 y_lab : str, default='Genes' 3754 Label for the x-axis. 3755 3756 legend_lab : str, default='log(CPM + 1)' 3757 Label for the colorbar legend. 3758 3759 bbox_to_anchor_scale : int, default=25 3760 Vertical scale (percentage) for positioning the colorbar. 3761 3762 bbox_to_anchor_perc : tuple, default=(0.90, -0.24) 3763 Anchor position for the size legend (percent bubble legend). 3764 3765 bbox_to_anchor_group : tuple, default=(1.01, 0.4) 3766 Anchor position for the group legend. 3767 3768 Returns 3769 ------- 3770 matplotlib.figure.Figure 3771 The generated scatter plot figure. 3772 3773 Notes 3774 ----- 3775 Colors represent expression values normalized to the colormap. 3776 """ 3777 3778 prtd, met = self.get_partial_data( 3779 names=names, features=features, name_slot=name_slot, inc_metadata=True 3780 ) 3781 3782 if scale: 3783 3784 legend_lab = "Scaled\n" + legend_lab 3785 3786 scaler = MinMaxScaler(feature_range=(0, 1)) 3787 prtd = pd.DataFrame( 3788 scaler.fit_transform(prtd.T).T, 3789 index=prtd.index, 3790 columns=prtd.columns, 3791 ) 3792 3793 prtd.columns = prtd.columns + "#" + met["sets"] 3794 3795 prtd_avg = average(prtd) 3796 3797 meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns] 3798 3799 prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns] 3800 3801 prtd_occ = occurrence(prtd) 3802 3803 prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns] 3804 3805 fig_scatter = features_scatter( 3806 expression_data=prtd_avg, 3807 occurence_data=prtd_occ, 3808 scale=scale, 3809 features=None, 3810 metadata_list=meta_sets, 3811 colors=colors, 3812 hclust=hclust, 3813 img_width=img_width, 3814 img_high=img_high, 3815 label_size=label_size, 3816 size_scale=size_scale, 3817 y_lab=y_lab, 3818 legend_lab=legend_lab, 3819 set_box_size=set_box_size, 3820 set_box_high=set_box_high, 3821 bbox_to_anchor_scale=bbox_to_anchor_scale, 3822 bbox_to_anchor_perc=bbox_to_anchor_perc, 3823 bbox_to_anchor_group=bbox_to_anchor_group, 3824 ) 3825 3826 return fig_scatter 3827 3828 def data_composition( 3829 self, 3830 features_count: list | None, 3831 name_slot: str = "cell_names", 3832 set_sep: bool = True, 3833 ): 3834 """ 3835 Compute the composition of cell types in the data set. 3836 3837 This function counts the occurrences of specific cells (e.g., cell types, subtypes) 3838 within metadata entries, calculates their relative percentages, and stores 3839 the results in `self.composition_data`. 3840 3841 Parameters 3842 ---------- 3843 features_count : list or None 3844 List of features (part or full names) to be counted. 3845 If None, all unique elements from the specified `name_slot` metadata field are used. 3846 3847 name_slot : str, default 'cell_names' 3848 Metadata field containing sample identifiers or labels.
3849 3850 set_sep : bool, default True 3851 If True and multiple sets are present in metadata, compute composition 3852 separately for each set. 3853 3854 Update 3855 ------- 3856 Stores results in `self.composition_data` as a pandas DataFrame with: 3857 - 'name': feature name 3858 - 'n': number of occurrences 3859 - 'pct': percentage of occurrences 3860 - 'set' (if applicable): dataset identifier 3861 """ 3862 3863 validated_list = list(self.input_metadata[name_slot]) 3864 sets = list(self.input_metadata["sets"]) 3865 3866 if features_count is None: 3867 features_count = list(set(self.input_metadata[name_slot])) 3868 3869 if set_sep and len(set(sets)) > 1: 3870 3871 final_res = pd.DataFrame() 3872 3873 for s in set(sets): 3874 print(s) 3875 3876 mask = [True if s == x else False for x in sets] 3877 3878 tmp_val_list = np.array(validated_list) 3879 3880 tmp_val_list = list(tmp_val_list[mask]) 3881 3882 res_dict = {"name": [], "n": [], "set": []} 3883 3884 for f in tqdm(features_count): 3885 res_dict["n"].append( 3886 sum(1 for element in tmp_val_list if f in element) 3887 ) 3888 res_dict["name"].append(f) 3889 res_dict["set"].append(s) 3890 res = pd.DataFrame(res_dict) 3891 res["pct"] = res["n"] / sum(res["n"]) * 100 3892 res["pct"] = res["pct"].round(2) 3893 3894 final_res = pd.concat([final_res, res]) 3895 3896 res = final_res.sort_values(["set", "pct"], ascending=[True, False]) 3897 3898 else: 3899 3900 res_dict = {"name": [], "n": []} 3901 3902 for f in tqdm(features_count): 3903 res_dict["n"].append( 3904 sum(1 for element in validated_list if f in element) 3905 ) 3906 res_dict["name"].append(f) 3907 3908 res = pd.DataFrame(res_dict) 3909 res["pct"] = res["n"] / sum(res["n"]) * 100 3910 res["pct"] = res["pct"].round(2) 3911 3912 res = res.sort_values("pct", ascending=False) 3913 3914 self.composition_data = res 3915 3916 def composition_pie( 3917 self, 3918 width=6, 3919 height=6, 3920 font_size=15, 3921 cmap: str = "tab20", 3922 legend_split_col: int = 1, 3923 offset_labels: float | int = 0.5, 3924 legend_bbox: tuple = (1.15, 0.95), 3925 ): 3926 """ 3927 Visualize the composition of cell lineages using pie charts. 3928 3929 Generates pie charts showing the relative proportions of features stored 3930 in `self.composition_data`. If multiple sets are present, a separate 3931 chart is drawn for each set. 3932 3933 Parameters 3934 ---------- 3935 width : int, default 6 3936 Width of the figure. 3937 3938 height : int, default 6 3939 Height of the figure (applied per set if multiple sets are plotted). 3940 3941 font_size : int, default 15 3942 Font size for labels and annotations. 3943 3944 cmap : str, default 'tab20' 3945 Colormap used for pie slices. 3946 3947 legend_split_col : int, default 1 3948 Number of columns in the legend. 3949 3950 offset_labels : float or int, default 0.5 3951 Spacing offset for label placement relative to pie slices. 3952 3953 legend_bbox : tuple, default (1.15, 0.95) 3954 Bounding box anchor position for the legend. 3955 3956 Returns 3957 ------- 3958 matplotlib.figure.Figure 3959 Pie chart visualization of composition data. 
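Examples
--------
A minimal usage sketch (assuming `data_composition()` was called first on an instance `obj`):

>>> obj.data_composition(features_count=None)
>>> fig = obj.composition_pie(cmap="tab20", legend_split_col=2)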
3960 """ 3961 3962 df = self.composition_data 3963 3964 if "set" in df.columns and len(set(df["set"])) > 1: 3965 3966 sets = list(set(df["set"])) 3967 fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets))) 3968 3969 all_wedges = [] 3970 cmap = plt.get_cmap(cmap) 3971 3972 set_nam = len(set(df["name"])) 3973 3974 legend_labels = list(set(df["name"])) 3975 3976 colors = [cmap(i / set_nam) for i in range(set_nam)] 3977 3978 cmap_dict = dict(zip(legend_labels, colors)) 3979 3980 for idx, s in enumerate(sets): 3981 ax = axes[idx] 3982 tmp_df = df[df["set"] == s].reset_index(drop=True) 3983 3984 labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()] 3985 3986 wedges, _ = ax.pie( 3987 tmp_df["n"], 3988 startangle=90, 3989 labeldistance=1.05, 3990 colors=[cmap_dict[x] for x in tmp_df["name"]], 3991 wedgeprops={"linewidth": 0.5, "edgecolor": "black"}, 3992 ) 3993 3994 all_wedges.extend(wedges) 3995 3996 kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center") 3997 n = 0 3998 for i, p in enumerate(wedges): 3999 ang = (p.theta2 - p.theta1) / 2.0 + p.theta1 4000 y = np.sin(np.deg2rad(ang)) 4001 x = np.cos(np.deg2rad(ang)) 4002 horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))] 4003 connectionstyle = f"angle,angleA=0,angleB={ang}" 4004 kw["arrowprops"].update({"connectionstyle": connectionstyle}) 4005 if len(labels[i]) > 0: 4006 n += offset_labels 4007 ax.annotate( 4008 labels[i], 4009 xy=(x, y), 4010 xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)), 4011 horizontalalignment=horizontalalignment, 4012 fontsize=font_size, 4013 weight="bold", 4014 **kw, 4015 ) 4016 4017 circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black") 4018 ax.add_artist(circle2) 4019 4020 ax.text( 4021 0, 4022 0, 4023 f"{s}", 4024 ha="center", 4025 va="center", 4026 fontsize=font_size, 4027 weight="bold", 4028 ) 4029 4030 legend_handles = [ 4031 Patch(facecolor=cmap_dict[label], edgecolor="black", label=label) 4032 for label in legend_labels 4033 ] 4034 4035 fig.legend( 4036 handles=legend_handles, 4037 loc="center right", 4038 bbox_to_anchor=legend_bbox, 4039 ncol=legend_split_col, 4040 title="", 4041 ) 4042 4043 plt.tight_layout() 4044 plt.show() 4045 4046 else: 4047 4048 labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()] 4049 4050 legend_labels = [f"{row['name']}" for _, row in df.iterrows()] 4051 4052 cmap = plt.get_cmap(cmap) 4053 colors = [cmap(i / len(df)) for i in range(len(df))] 4054 4055 fig, ax = plt.subplots( 4056 figsize=(width, height), subplot_kw=dict(aspect="equal") 4057 ) 4058 4059 wedges, _ = ax.pie( 4060 df["n"], 4061 startangle=90, 4062 labeldistance=1.05, 4063 colors=colors, 4064 wedgeprops={"linewidth": 0.5, "edgecolor": "black"}, 4065 ) 4066 4067 kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center") 4068 n = 0 4069 for i, p in enumerate(wedges): 4070 ang = (p.theta2 - p.theta1) / 2.0 + p.theta1 4071 y = np.sin(np.deg2rad(ang)) 4072 x = np.cos(np.deg2rad(ang)) 4073 horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))] 4074 connectionstyle = "angle,angleA=0,angleB={}".format(ang) 4075 kw["arrowprops"].update({"connectionstyle": connectionstyle}) 4076 if len(labels[i]) > 0: 4077 n += offset_labels 4078 4079 ax.annotate( 4080 labels[i], 4081 xy=(x, y), 4082 xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)), 4083 horizontalalignment=horizontalalignment, 4084 fontsize=font_size, 4085 weight="bold", 4086 **kw, 4087 ) 4088 4089 circle2 = plt.Circle((0, 0), 0.6, color="white") 4090 circle2.set_edgecolor("black") 4091 
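# the white inner circle overlays the wedges to turn the pie into a donut;
# the current figure is re-grabbed below so the artist is attached to it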
            p = plt.gcf()
            p.gca().add_artist(circle2)

            ax.legend(
                wedges,
                legend_labels,
                title="",
                loc="center left",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
            )

            plt.show()

        return fig

    def bar_composition(
        self,
        cmap="tab20b",
        width=2,
        height=6,
        font_size=15,
        legend_split_col: int = 1,
        legend_bbox: tuple = (1.3, 1),
    ):
        """
        Visualize the composition of cell lineages using bar plots.

        Produces bar plots showing the distribution of features stored in
        `self.composition_data`. If multiple sets are present, a separate
        bar is drawn for each set. Percentages are annotated alongside the bars.

        Parameters
        ----------
        cmap : str, default 'tab20b'
            Colormap used for stacked bars.

        width : int, default 2
            Width of each subplot (per set).

        height : int, default 6
            Height of the figure.

        font_size : int, default 15
            Font size for labels and annotations.

        legend_split_col : int, default 1
            Number of columns in the legend.

        legend_bbox : tuple, default (1.3, 1)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Stacked bar plot visualization of composition data.
        """

        df = self.composition_data
        df["num"] = range(1, len(df) + 1)

        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))

            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))

            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]

            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                ax = axes[idx]

                tmp_df = df[df["set"] == s].reset_index(drop=True)

                values = tmp_df["n"].values
                total = sum(values)
                values = [v / total * 100 for v in values]
                values = [round(v, 2) for v in values]

                idx_max = np.argmax(values)
                correction = 100 - sum(values)
                values[idx_max] += correction

                names = tmp_df["name"].values
                perc = tmp_df["pct"].values
                nums = tmp_df["num"].values

                bottom = 0
                centers = []
                for name, num, val, color in zip(names, nums, values, colors):
                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
                    centers.append(bottom + val / 2)
                    bottom += val

                y_positions = np.linspace(centers[0], centers[-1], len(centers))
                x_text = -0.8

                for y_label, y_center, pct, num in zip(
                    y_positions, centers, perc, nums
                ):
                    ax.annotate(
                        f"{pct:.1f}%",
                        xy=(0, y_center),
                        xycoords="data",
                        xytext=(x_text, y_label),
                        textcoords="data",
                        ha="right",
                        va="center",
                        fontsize=font_size,
                        arrowprops=dict(
                            arrowstyle="->",
                            lw=1,
                            color="black",
                            connectionstyle="angle3,angleA=0,angleB=90",
                        ),
                    )

                ax.set_ylim(0, 100)
                ax.set_xlabel(s, fontsize=font_size)
                ax.xaxis.label.set_rotation(30)

                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            cmap = plt.get_cmap(cmap)

            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(figsize=(width, height))

            values = df["n"].values
            names = df["name"].values
            perc = df["pct"].values
            nums = df["num"].values

            bottom = 0
            centers = []
            for name, num, val, color in zip(names, nums, values, colors):
                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
                centers.append(bottom + val / 2)
                bottom += val

            y_positions = np.linspace(centers[0], centers[-1], len(centers))
            x_text = -0.8

            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
                ax.annotate(
                    f"{num}) {pct}",
                    xy=(0, y_center),
                    xycoords="data",
                    xytext=(x_text, y_label),
                    textcoords="data",
                    ha="right",
                    va="center",
                    fontsize=9,
                    arrowprops=dict(
                        arrowstyle="->",
                        lw=1,
                        color="black",
                        connectionstyle="angle3,angleA=0,angleB=90",
                    ),
                )

            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            ax.legend(
                title="Legend",
                bbox_to_anchor=legend_bbox,
                loc="upper left",
                ncol=legend_split_col,
            )

            plt.tight_layout()
            plt.show()

        return fig

    def cell_regression(
        self,
        cell_x: str,
        cell_y: str,
        set_x: str | None,
        set_y: str | None,
        threshold=10,
        image_width=12,
        image_high=7,
        color="black",
    ):
        """
        Perform regression analysis between two selected cells and visualize the relationship.

        This function computes a linear regression between two specified cells from
        aggregated normalized data, plots the regression line with scatter points,
        annotates regression statistics, and highlights potential outliers.

        Parameters
        ----------
        cell_x : str
            Name of the first cell (X-axis).

        cell_y : str
            Name of the second cell (Y-axis).

        set_x : str or None
            Dataset identifier corresponding to `cell_x`. If None, the cell is selected only by name.

        set_y : str or None
            Dataset identifier corresponding to `cell_y`. If None, the cell is selected only by name.

        threshold : int or float, default 10
            Threshold for detecting outliers. Points deviating from the mean or diagonal by more
            than this value are annotated.

        image_width : int, default 12
            Width of the regression plot (in inches).

        image_high : int, default 7
            Height of the regression plot (in inches).

        color : str, default 'black'
            Color of the regression scatter points and line.

        Returns
        -------
        matplotlib.figure.Figure
            Regression plot figure with annotated regression line, R², p-value, and outliers.

        Raises
        ------
        ValueError
            If `cell_x` or `cell_y` are not found in the dataset.
            If multiple matches are found for a cell name and `set_x`/`set_y` are not specified.

        Notes
        -----
        - The function automatically calls `jseq_object.average()` if aggregated data is not available.
        - Outliers are annotated with their corresponding index labels.
        - Regression is computed using `scipy.stats.linregress`.

        Examples
        --------
        >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
        >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")
        """

        if self.agg_normalized_data is None:
            self.average()

        metadata = self.agg_metadata
        data = self.agg_normalized_data

        if set_x is not None and set_y is not None:
            data.columns = metadata["cell_names"] + " # " + metadata["sets"]
            cell_x = cell_x + " # " + set_x
            cell_y = cell_y + " # " + set_y

        else:
            data.columns = metadata["cell_names"]

        if cell_x not in data.columns:
            raise ValueError("'cell_x' value not in cell names!")

        if cell_y not in data.columns:
            raise ValueError("'cell_y' value not in cell names!")

        if list(data.columns).count(cell_x) > 1:
            raise ValueError(
                f"'{cell_x}' occurs more than once. If you want to select a specific cell, "
                f"please also provide the corresponding 'set_x' and 'set_y' values."
            )

        if list(data.columns).count(cell_y) > 1:
            raise ValueError(
                f"'{cell_y}' occurs more than once. If you want to select a specific cell, "
                f"please also provide the corresponding 'set_x' and 'set_y' values."
            )

        fig, ax = plt.subplots(figsize=(image_width, image_high))
        ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color)

        slope, intercept, r_value, p_value, _ = stats.linregress(
            data[cell_x], data[cell_y]
        )
        equation = "y = {:.2f}x + {:.2f}".format(slope, intercept)

        ax.annotate(
            "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format(
                r_value**2, p_value, equation
            ),
            xy=(0.05, 0.90),
            xycoords="axes fraction",
            fontsize=12,
        )

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        def annotate_outliers(x, y, threshold):
            texts = []
            x_mean, y_mean = x.mean(), y.mean()
            for i, (xi, yi) in enumerate(zip(x, y)):
                if (
                    abs(xi - x_mean) > threshold
                    or abs(yi - y_mean) > threshold
                    or abs(yi - xi) > threshold
                ):
                    text = ax.text(xi, yi, data.index[i])
                    texts.append(text)

            return texts

        texts = annotate_outliers(data[cell_x], data[cell_y], threshold)

        adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

        plt.show()

        return fig
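For orientation, a minimal usage sketch for the three plotting helpers above. `obj` is a placeholder for an already-populated analysis object carrying `composition_data` and aggregated expression data; the output file name is illustrative:

>>> obj.composition_pie(width=6, height=6, cmap="tab20")
>>> obj.bar_composition(cmap="tab20b", legend_bbox=(1.3, 1))
>>> fig = obj.cell_regression(cell_x="Purkinje", cell_y="Granule",
...                           set_x="Exp1", set_y="Exp2", threshold=10)
>>> fig.savefig("purkinje_vs_granule.png", dpi=300)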
class Clustering:
    def UMAP_vis(
        self,
        names_slot: str = "cell_names",
        set_sep: bool = True,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        legend_split_col: int = 2,
        width: int = 8,
        height: int = 6,
        inc_num: bool = True,
    ):
        """
        Visualize UMAP embeddings with sample labels based on a specific metadata slot.

        Parameters
        ----------
        names_slot : str, default 'cell_names'
            Column in metadata to use as sample labels.

        set_sep : bool, default True
            If True, separate points by dataset.

        point_size : float, default 0.6
            Size of scatter points.

        font_size : int, default 6
            Font size for numbers on points.

        legend_split_col : int, default 2
            Number of columns in legend.

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        inc_num : bool, default True
            If True, annotate points with numeric labels.

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot figure.
        """

        umap_df = self.umap.iloc[:, 0:2].copy()
        umap_df.loc[:, "names"] = self.clustering_metadata[names_slot]

        if set_sep:

            if "sets" in self.clustering_metadata.columns:
                umap_df.loc[:, "dataset"] = self.clustering_metadata["sets"]
            else:
                umap_df["dataset"] = "default"

        else:
            umap_df["dataset"] = "default"

        umap_df.loc[:, "tmp_nam"] = umap_df["names"] + umap_df["dataset"]

        umap_df.loc[:, "count"] = umap_df["tmp_nam"].map(
            umap_df["tmp_nam"].value_counts()
        )

        numeric_df = (
            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
            .drop_duplicates()
            .sort_values("count", ascending=False)
        )
        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])

        umap_df = umap_df.merge(
            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
        )

        fig, ax = plt.subplots(figsize=(width, height))

        markers = ["o", "s", "^", "D", "P", "*", "X"]
        marker_map = {
            ds: markers[i % len(markers)]
            for i, ds in enumerate(umap_df["dataset"].unique())
        }

        cord_list = []

        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):

            cluster_data = umap_df[umap_df["numeric_values"] == num]

            ax.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"{num} - {nam}",
                marker=marker_map[cluster_data["dataset"].iloc[0]],
                alpha=0.6,
                s=point_size,
            )

            coords = cluster_data[["UMAP1", "UMAP2"]].values

            dists = pairwise_distances(coords)

            sum_dists = dists.sum(axis=1)

            center_idx = np.argmin(sum_dists)
            center_point = coords[center_idx]

            cord_list.append(center_point)

        if inc_num:
            texts = []
            for (x, y), num in zip(cord_list, numeric_df["numeric_values"]):
                texts.append(
                    ax.text(
                        x,
                        y,
                        str(num),
                        ha="center",
                        va="center",
                        fontsize=font_size,
                        color="black",
                    )
                )

            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.legend(
            title="Clusters",
            loc="center left",
            bbox_to_anchor=(1.05, 0.5),
            ncol=legend_split_col,
            markerscale=5,
        )

        ax.grid(True)

        plt.tight_layout()

        return fig

    def UMAP_feature(
        self,
        feature_name: str,
        features_data: pd.DataFrame | None,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        width: int = 8,
        height: int = 6,
        palette="light",
    ):
        """
        Visualize UMAP embedding with expression levels of a selected feature.

        Each point (cell) in the UMAP plot is colored according to the expression
        value of the chosen feature, enabling interpretation of spatial patterns
        of gene activity or metadata distribution in low-dimensional space.

        Parameters
        ----------
        feature_name : str
            Name of the feature to plot.

        features_data : pandas.DataFrame or None, default None
            If None, the function uses the DataFrame containing the clustering data.
            To plot features not used in clustering, provide a wider DataFrame
            containing the original feature values.

        point_size : float, default 0.6
            Size of scatter points in the plot.

        font_size : int, default 6
            Font size for axis labels and annotations.

        width : int, default 8
            Width of the matplotlib figure.

        height : int, default 6
            Height of the matplotlib figure.

        palette : str, default 'light'
            Color palette for expression visualization. Options are:
            - 'light'
            - 'dark'
            - 'green'
            - 'gray'

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot colored by feature values.
        """

        umap_df = self.umap.iloc[:, 0:2].copy()

        if features_data is None:

            features_data = self.clustering_data

        if features_data.shape[1] != umap_df.shape[0]:
            raise ValueError(
                "Provided 'features_data' shape does not match the number of UMAP cells"
            )

        blist = [x.upper() == feature_name.upper() for x in features_data.index]

        if not any(blist):
            raise ValueError("Provided feature_name is not included in the data")

        umap_df.loc[:, "feature"] = (
            features_data.loc[blist, :]
            .apply(lambda row: row.tolist(), axis=1)
            .values[0]
        )

        umap_df = umap_df.sort_values("feature", ascending=True)

        import matplotlib.colors as mcolors

        if palette == "light":
            palette = px.colors.sequential.Sunsetdark

        elif palette == "dark":
            palette = px.colors.sequential.thermal

        elif palette == "green":
            palette = px.colors.sequential.Aggrnyl

        elif palette == "gray":
            palette = px.colors.sequential.gray
            palette = palette[::-1]

        else:
            raise ValueError(
                'Palette not found. Use: "light", "dark", "gray", or "green"'
            )

        converted = []
        for c in palette:
            rgb_255 = px.colors.unlabel_rgb(c)
            rgb_01 = tuple(v / 255.0 for v in rgb_255)
            converted.append(rgb_01)

        my_cmap = mcolors.ListedColormap(converted, name="custom")

        fig, ax = plt.subplots(figsize=(width, height))
        sc = ax.scatter(
            umap_df["UMAP1"],
            umap_df["UMAP2"],
            c=umap_df["feature"],
            s=point_size,
            cmap=my_cmap,
            alpha=1.0,
            edgecolors="black",
            linewidths=0.1,
        )

        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label(f"{feature_name}")

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.grid(True)

        plt.tight_layout()

        return fig

    def get_umap_data(self):
        """
        Retrieve UMAP embedding data with optional cluster labels.

        Returns the UMAP coordinates stored in `self.umap`. If clustering
        metadata is available (specifically `UMAP_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
            If available, includes an extra column 'clusters' with cluster labels.

        Notes
        -----
        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        umap_data = self.umap

        try:
            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
        except:
            pass

        return umap_data

    def get_pca_data(self):
        """
        Retrieve PCA embedding data with optional cluster labels.

        Returns the principal component scores stored in `self.pca`. If clustering
        metadata is available (specifically `PCA_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
            If available, includes an extra column 'clusters' with cluster labels.

        Notes
        -----
        - PCA must be computed beforehand (e.g., using `perform_PCA`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        pca_data = self.pca

        try:
            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
        except:
            pass

        return pca_data

    def return_clusters(self, clusters="umap"):
        """
        Retrieve cluster labels from UMAP or PCA clustering results.

        Parameters
        ----------
        clusters : str, default 'umap'
            Source of cluster labels to return. Must be one of:
            - 'umap': return cluster labels from UMAP embeddings.
            - 'pca' : return cluster labels from PCA embeddings.

        Returns
        -------
        list
            Cluster labels corresponding to the selected embedding method.

        Raises
        ------
        ValueError
            If `clusters` is not 'umap' or 'pca'.

        Notes
        -----
        Requires that clustering has already been performed
        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
        """

        if clusters.lower() == "umap":
            clusters_vector = self.clustering_metadata["UMAP_clusters"]
        elif clusters.lower() == "pca":
            clusters_vector = self.clustering_metadata["PCA_clusters"]
        else:
            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")

        return clusters_vector
A class for performing dimensionality reduction, clustering, and visualization on high-dimensional data (e.g., single-cell gene expression).
The class provides methods for:
- Normalizing and extracting subsets of data
- Principal Component Analysis (PCA) and related clustering
- Uniform Manifold Approximation and Projection (UMAP) and clustering
- Visualization of PCA and UMAP embeddings
- Harmonization of batch effects
- Accessing processed data and cluster labels
Methods
add_data_frame(data, metadata) Class method to create a Clustering instance from a DataFrame and metadata.
harmonize_sets() Perform batch effect harmonization on PCA data.
perform_PCA(pc_num=100, width=8, height=6) Perform PCA on the dataset and visualize the first two PCs.
knee_plot_PCA(width=8, height=6) Plot the cumulative variance explained by PCs to determine optimal dimensionality.
find_clusters_PCA(pc_num=0, eps=0.5, min_samples=10, width=8, height=6, harmonized=False) Apply DBSCAN clustering to PCA embeddings and visualize results.
perform_UMAP(factorize=False, umap_num=100, pc_num=0, harmonized=False, ...) Compute UMAP embeddings with optional parameter tuning.
knee_plot_umap(eps=0.5, min_samples=10) Determine optimal UMAP dimensionality using silhouette scores.
find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10, width=8, height=6) Apply DBSCAN clustering on UMAP embeddings and visualize clusters.
UMAP_vis(names_slot='cell_names', set_sep=True, point_size=0.6, ...) Visualize UMAP embeddings with labels and optional cluster numbering.
UMAP_feature(feature_name, features_data=None, point_size=0.6, ...) Plot a single feature over UMAP coordinates with customizable colormap.
get_umap_data() Return the UMAP embeddings along with cluster labels if available.
get_pca_data() Return the PCA results along with cluster labels if available.
return_clusters(clusters='umap') Return the cluster labels for UMAP or PCA embeddings.
Raises
ValueError For invalid parameters, mismatched dimensions, or missing metadata.
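A minimal end-to-end sketch of the documented workflow on synthetic data. The import path follows this module's name, the data sizes are arbitrary, and the parameter values are illustrative rather than tuned:

>>> import numpy as np
>>> import pandas as pd
>>> from jdti.jdti import Clustering
>>> rng = np.random.default_rng(0)
>>> data = pd.DataFrame(
...     rng.poisson(2, size=(100, 60)).astype(float),
...     index=[f"gene_{i}" for i in range(100)],
...     columns=[f"cell_{j}" for j in range(60)],
... )
>>> meta = pd.DataFrame({"cell_names": list(data.columns),
...                      "sets": ["setA"] * 30 + ["setB"] * 30})
>>> cl = Clustering.add_data_frame(data, meta)
>>> _ = cl.perform_PCA(pc_num=20)
>>> cl.perform_UMAP(umap_num=2, pc_num=10)
>>> _ = cl.find_clusters_UMAP(umap_n=2, eps=0.8, min_samples=5)
>>> cl.get_umap_data().head()

The later examples in this section reuse `data`, `meta`, and `cl` from this sketch.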
    def __init__(self, data, metadata):
        """
        Initialize the clustering class with data and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data for clustering. Columns are considered as samples.

        metadata : pandas.DataFrame, optional
            Metadata for the samples. If None, a default DataFrame with column
            names as 'cell_names' is created.

        Attributes
        ----------
        clustering_data : pandas.DataFrame
        clustering_metadata : pandas.DataFrame
        subclusters : None or dict
        explained_var : None or numpy.ndarray
        cumulative_var : None or numpy.ndarray
        pca : None or pandas.DataFrame
        harmonized_pca : None or pandas.DataFrame
        umap : None or pandas.DataFrame
        """

        self.clustering_data = data
        """The input data used for clustering."""

        if metadata is None:
            metadata = pd.DataFrame({"cell_names": list(data.columns)})

        self.clustering_metadata = metadata
        """Metadata associated with the samples."""

        self.subclusters = None
        """Placeholder for storing subcluster information."""

        self.explained_var = None
        """Explained variance from PCA, initialized as None."""

        self.cumulative_var = None
        """Cumulative explained variance from PCA, initialized as None."""

        self.pca = None
        """PCA-transformed data, initialized as None."""

        self.harmonized_pca = None
        """PCA data after batch effect harmonization, initialized as None."""

        self.umap = None
        """UMAP embeddings, initialized as None."""
Initialize the clustering class with data and optional metadata.
Parameters
data : pandas.DataFrame Input data for clustering. Columns are considered as samples.
metadata : pandas.DataFrame, optional Metadata for the samples. If None, a default DataFrame with column names as 'cell_names' is created.
Attributes
clustering_data : pandas.DataFrame
clustering_metadata : pandas.DataFrame
subclusters : None or dict
explained_var : None or numpy.ndarray
cumulative_var : None or numpy.ndarray
pca : None or pandas.DataFrame
harmonized_pca : None or pandas.DataFrame
umap : None or pandas.DataFrame
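When `metadata` is None, the constructor derives a one-column metadata frame from the data's column names. A small sketch continuing the synthetic `data` above:

>>> cl0 = Clustering(data, None)
>>> list(cl0.clustering_metadata.columns)
['cell_names']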
    @classmethod
    def add_data_frame(cls, data: pd.DataFrame, metadata: pd.DataFrame | None):
        """
        Create a Clustering instance from a DataFrame and optional metadata.

        Parameters
        ----------
        data : pandas.DataFrame
            Input data with features as rows and samples/cells as columns.

        metadata : pandas.DataFrame or None
            Optional metadata for the samples.
            Each row corresponds to a sample/cell, and column names in this DataFrame
            should match the sample/cell names in `data`. Columns can contain additional
            information such as cell type, experimental condition, batch, sets, etc.

        Returns
        -------
        Clustering
            A new instance of the Clustering class.
        """

        return cls(data, metadata)
Create a Clustering instance from a DataFrame and optional metadata.
Parameters
data : pandas.DataFrame Input data with features as rows and samples/cells as columns.
metadata : pandas.DataFrame or None Optional metadata for the samples. Each row corresponds to a sample/cell, and column names in this DataFrame should match the sample/cell names in `data`. Columns can contain additional information such as cell type, experimental condition, batch, sets, etc.
Returns
Clustering A new instance of the Clustering class.
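`add_data_frame` is a thin alternative constructor, so the two calls below build equivalent objects (sketch, reusing `data` and `meta` from the class-level example):

>>> cl_a = Clustering(data, meta)
>>> cl_b = Clustering.add_data_frame(data, meta)
>>> cl_a.clustering_data.equals(cl_b.clustering_data)
True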
    def harmonize_sets(self, batch_col: str = "sets"):
        """
        Perform batch effect harmonization on PCA embeddings.

        Parameters
        ----------
        batch_col : str, default 'sets'
            Name of the column in `metadata` that contains batch information for the samples/cells.

        Returns
        -------
        None
            Updates the `harmonized_pca` attribute with harmonized data.
        """

        data_mat = np.array(self.pca)

        metadata = self.clustering_metadata

        self.harmonized_pca = pd.DataFrame(
            harmonize.run_harmony(data_mat, metadata, vars_use=batch_col).Z_corr
        ).T

        self.harmonized_pca.columns = self.pca.columns
Perform batch effect harmonization on PCA embeddings.
Parameters
batch_col : str, default 'sets' Name of the column in `metadata` that contains batch information for the samples/cells.
Returns
None Updates the `harmonized_pca` attribute with harmonized data.
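Harmonization runs on the PCA scores, so `perform_PCA` must be called first, and the metadata needs the batch column (here the default 'sets', as in the synthetic `meta` above). The corrected embedding keeps the PCA shape:

>>> _ = cl.perform_PCA(pc_num=20)
>>> cl.harmonize_sets(batch_col="sets")
>>> cl.harmonized_pca.shape == cl.pca.shape
True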
    def perform_PCA(self, pc_num: int = 100, width=8, height=6):
        """
        Perform Principal Component Analysis (PCA) on the dataset.

        This method standardizes the data, applies PCA, stores results as attributes,
        and generates a scatter plot of the first two principal components.

        Parameters
        ----------
        pc_num : int, default 100
            Number of principal components to compute.
            If 0, computes all available components.

        width : int or float, default 8
            Width of the PCA figure.

        height : int or float, default 6
            Height of the PCA figure.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot showing the first two principal components.

        Updates
        -------
        self.pca : pandas.DataFrame
            DataFrame with principal component scores for each sample.

        self.explained_var : numpy.ndarray
            Percentage of variance explained by each principal component.

        self.cumulative_var : numpy.ndarray
            Cumulative explained variance.
        """

        scaler = StandardScaler()
        data_scaled = scaler.fit_transform(self.clustering_data.T)

        if pc_num == 0 or pc_num > data_scaled.shape[0]:
            pc_num = data_scaled.shape[0]

        pca = PCA(n_components=pc_num, random_state=42)

        principal_components = pca.fit_transform(data_scaled)

        pca_df = pd.DataFrame(
            data=principal_components,
            columns=["PC" + str(x + 1) for x in range(pc_num)],
        )

        self.explained_var = pca.explained_variance_ratio_ * 100
        self.cumulative_var = np.cumsum(self.explained_var)

        self.pca = pca_df

        fig = plt.figure(figsize=(width, height))
        plt.scatter(pca_df["PC1"], pca_df["PC2"], alpha=0.7)
        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.grid(True)
        plt.show()

        return fig
Perform Principal Component Analysis (PCA) on the dataset.
This method standardizes the data, applies PCA, stores results as attributes, and generates a scatter plot of the first two principal components.
Parameters
pc_num : int, default 100 Number of principal components to compute. If 0, computes all available components.
width : int or float, default 8 Width of the PCA figure.
height : int or float, default 6 Height of the PCA figure.
Returns
matplotlib.figure.Figure Scatter plot showing the first two principal components.
Updates
self.pca : pandas.DataFrame DataFrame with principal component scores for each sample.
self.explained_var : numpy.ndarray Percentage of variance explained by each principal component.
self.cumulative_var : numpy.ndarray Cumulative explained variance.
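A short sketch of the call and the attributes it fills, continuing the example object `cl` (the component count is arbitrary):

>>> fig = cl.perform_PCA(pc_num=20, width=8, height=6)
>>> list(cl.pca.columns[:3])
['PC1', 'PC2', 'PC3']
>>> cl.explained_var.shape
(20,)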
    def knee_plot_PCA(self, width: int = 8, height: int = 6):
        """
        Plot cumulative explained variance to determine the optimal number of PCs.

        Parameters
        ----------
        width : int, default 8
            Width of the figure.

        height : int, default 6
            Height of the figure.

        Returns
        -------
        matplotlib.figure.Figure
            Line plot showing cumulative variance explained by each PC.
        """

        fig_knee = plt.figure(figsize=(width, height))
        plt.plot(range(1, len(self.explained_var) + 1), self.cumulative_var, marker="o")
        plt.xlabel("PC (n components)")
        plt.ylabel("Cumulative explained variance (%)")
        plt.grid(True)

        xticks = [1] + list(range(5, len(self.explained_var) + 1, 5))
        plt.xticks(xticks, rotation=60)

        plt.show()

        return fig_knee
Plot cumulative explained variance to determine the optimal number of PCs.
Parameters
width : int, default 8 Width of the figure.
height : int, default 6 Height of the figure.
Returns
matplotlib.figure.Figure Line plot showing cumulative variance explained by each PC.
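The same knee can be read programmatically from `cumulative_var`; the 80% cut-off below is an arbitrary illustration, not a package default:

>>> fig = cl.knee_plot_PCA(width=8, height=6)
>>> n_pcs = int(np.argmax(cl.cumulative_var >= 80) + 1)
>>> # caveat: np.argmax returns 0 if the threshold is never reached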
    def find_clusters_PCA(
        self,
        pc_num: int = 2,
        eps: float = 0.5,
        min_samples: int = 10,
        width: int = 8,
        height: int = 6,
        harmonized: bool = False,
    ):
        """
        Apply DBSCAN clustering to PCA embeddings and visualize the results.

        This method performs density-based clustering (DBSCAN) on the PCA-reduced
        dataset. Cluster labels are stored in the object's metadata, and a scatter
        plot of the first two principal components with cluster annotations is returned.

        Parameters
        ----------
        pc_num : int, default 2
            Number of principal components to use for clustering.
            If 0, uses all available components.

        eps : float, default 0.5
            Maximum distance between two points for them to be considered
            as neighbors (DBSCAN parameter).

        min_samples : int, default 10
            Minimum number of samples required to form a cluster (DBSCAN parameter).

        width : int, default 8
            Width of the output scatter plot.

        height : int, default 6
            Height of the output scatter plot.

        harmonized : bool, default False
            If True, use harmonized PCA data (`self.harmonized_pca`).
            If False, use standard PCA results (`self.pca`).

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of the first two principal components colored by
            cluster assignments.

        Updates
        -------
        self.clustering_metadata['PCA_clusters'] : list
            Cluster labels assigned to each cell/sample.

        self.input_metadata['PCA_clusters'] : list, optional
            Cluster labels stored in input metadata (if available).
        """

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)

        if pc_num == 0 and harmonized:
            PCA = self.harmonized_pca

        elif pc_num == 0:
            PCA = self.pca

        else:
            if harmonized:
                PCA = self.harmonized_pca.iloc[:, 0:pc_num]
            else:
                PCA = self.pca.iloc[:, 0:pc_num]

        dbscan_labels = dbscan.fit_predict(PCA)

        pca_df = pd.DataFrame(PCA)
        pca_df["Cluster"] = dbscan_labels

        fig = plt.figure(figsize=(width, height))

        for cluster_id in sorted(pca_df["Cluster"].unique()):
            cluster_data = pca_df[pca_df["Cluster"] == cluster_id]
            plt.scatter(
                cluster_data["PC1"],
                cluster_data["PC2"],
                label=f"Cluster {cluster_id}",
                alpha=0.7,
            )

        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))

        plt.grid(True)
        plt.show()

        self.clustering_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]

        try:
            self.input_metadata["PCA_clusters"] = [str(x) for x in dbscan_labels]
        except:
            pass

        return fig
Apply DBSCAN clustering to PCA embeddings and visualize the results.
This method performs density-based clustering (DBSCAN) on the PCA-reduced dataset. Cluster labels are stored in the object's metadata, and a scatter plot of the first two principal components with cluster annotations is returned.
Parameters
pc_num : int, default 2 Number of principal components to use for clustering. If 0, uses all available components.
eps : float, default 0.5 Maximum distance between two points for them to be considered as neighbors (DBSCAN parameter).
min_samples : int, default 10 Minimum number of samples required to form a cluster (DBSCAN parameter).
width : int, default 8 Width of the output scatter plot.
height : int, default 6 Height of the output scatter plot.
harmonized : bool, default False If True, use harmonized PCA data (`self.harmonized_pca`). If False, use standard PCA results (`self.pca`).
Returns
matplotlib.figure.Figure Scatter plot of the first two principal components colored by cluster assignments.
Updates
self.clustering_metadata['PCA_clusters'] : list Cluster labels assigned to each cell/sample.
self.input_metadata['PCA_clusters'] : list, optional Cluster labels stored in input metadata (if available).
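DBSCAN results depend strongly on `eps`, so a small sweep is often useful; a hedged sketch on the example object (label '-1' is DBSCAN's noise bucket):

>>> for eps in (0.3, 0.5, 0.8):
...     _ = cl.find_clusters_PCA(pc_num=2, eps=eps, min_samples=10)
>>> sorted(set(cl.clustering_metadata["PCA_clusters"]))  # labels from the last sweep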
    def perform_UMAP(
        self,
        factorize: bool = False,
        umap_num: int = 100,
        pc_num: int = 0,
        harmonized: bool = False,
        n_neighbors: int = 5,
        min_dist: float | int = 0.1,
        spread: float | int = 1.0,
        set_op_mix_ratio: float | int = 1.0,
        local_connectivity: int = 1,
        repulsion_strength: float | int = 1.0,
        negative_sample_rate: int = 5,
        width: int = 8,
        height: int = 6,
    ):
        """
        Compute and visualize UMAP embeddings of the dataset.

        This method applies Uniform Manifold Approximation and Projection (UMAP)
        for dimensionality reduction on either raw, PCA, or harmonized PCA data.
        Results are stored as a DataFrame (`self.umap`), and a scatter plot of
        the first two UMAP components is displayed.

        Parameters
        ----------
        factorize : bool, default False
            If True, categorical sample labels (from column names) are factorized
            and used as supervision in UMAP fitting.

        umap_num : int, default 100
            Number of UMAP dimensions to compute. If 0, matches the input dimension.

        pc_num : int, default 0
            Number of principal components to use as UMAP input.
            If 0, use all available components or raw data.

        harmonized : bool, default False
            If True, use harmonized PCA embeddings (`self.harmonized_pca`).
            If False, use standard PCA or raw scaled data.

        n_neighbors : int, default 5
            UMAP parameter controlling the size of the local neighborhood.

        min_dist : float, default 0.1
            UMAP parameter controlling minimum allowed distance between embedded points.

        spread : int | float, default 1.0
            Effective scale of embedded space (UMAP parameter).

        set_op_mix_ratio : int | float, default 1.0
            Interpolation parameter between union and intersection in fuzzy sets.

        local_connectivity : int, default 1
            Number of nearest neighbors assumed for each point.

        repulsion_strength : int | float, default 1.0
            Weighting applied to negative samples during optimization.

        negative_sample_rate : int, default 5
            Number of negative samples per positive sample in optimization.

        width : int, default 8
            Width of the output scatter plot.

        height : int, default 6
            Height of the output scatter plot.

        Updates
        -------
        self.umap : pandas.DataFrame
            Table of UMAP embeddings with columns `UMAP1 ... UMAPn`.

        Notes
        -----
        For supervised UMAP (`factorize=True`), categorical codes from column
        names of the dataset are used as labels.
        """

        scaler = StandardScaler()

        if pc_num == 0 and harmonized:
            data_scaled = self.harmonized_pca

        elif pc_num == 0:
            data_scaled = scaler.fit_transform(self.clustering_data.T)

        else:
            if harmonized:
                data_scaled = self.harmonized_pca.iloc[:, 0:pc_num]
            else:
                data_scaled = self.pca.iloc[:, 0:pc_num]

        if umap_num == 0 or umap_num > data_scaled.shape[1]:

            umap_num = data_scaled.shape[1]

            reducer = umap.UMAP(
                n_components=len(data_scaled.T),
                random_state=42,
                n_neighbors=n_neighbors,
                min_dist=min_dist,
                spread=spread,
                set_op_mix_ratio=set_op_mix_ratio,
                local_connectivity=local_connectivity,
                repulsion_strength=repulsion_strength,
                negative_sample_rate=negative_sample_rate,
                n_jobs=1,
            )

        else:

            reducer = umap.UMAP(
                n_components=umap_num,
                random_state=42,
                n_neighbors=n_neighbors,
                min_dist=min_dist,
                spread=spread,
                set_op_mix_ratio=set_op_mix_ratio,
                local_connectivity=local_connectivity,
                repulsion_strength=repulsion_strength,
                negative_sample_rate=negative_sample_rate,
                n_jobs=1,
            )

        if factorize:
            embedding = reducer.fit_transform(
                X=data_scaled, y=pd.Categorical(self.clustering_data.columns).codes
            )
        else:
            embedding = reducer.fit_transform(X=data_scaled)

        umap_df = pd.DataFrame(
            embedding, columns=["UMAP" + str(x + 1) for x in range(umap_num)]
        )

        plt.figure(figsize=(width, height))
        plt.scatter(umap_df["UMAP1"], umap_df["UMAP2"], alpha=0.7)
        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.grid(True)

        plt.show()

        self.umap = umap_df
Compute and visualize UMAP embeddings of the dataset.
This method applies Uniform Manifold Approximation and Projection (UMAP) for dimensionality reduction on either raw, PCA, or harmonized PCA data. Results are stored as a DataFrame (`self.umap`), and a scatter plot of the first two UMAP components is displayed.
Parameters
factorize : bool, default False If True, categorical sample labels (from column names) are factorized and used as supervision in UMAP fitting.
umap_num : int, default 100 Number of UMAP dimensions to compute. If 0, matches the input dimension.
pc_num : int, default 0 Number of principal components to use as UMAP input. If 0, use all available components or raw data.
harmonized : bool, default False If True, use harmonized PCA embeddings (`self.harmonized_pca`). If False, use standard PCA or raw scaled data.
n_neighbors : int, default 5 UMAP parameter controlling the size of the local neighborhood.
min_dist : float, default 0.1 UMAP parameter controlling minimum allowed distance between embedded points.
spread : int | float, default 1.0 Effective scale of embedded space (UMAP parameter).
set_op_mix_ratio : int | float, default 1.0 Interpolation parameter between union and intersection in fuzzy sets.
local_connectivity : int, default 1 Number of nearest neighbors assumed for each point.
repulsion_strength : int | float, default 1.0 Weighting applied to negative samples during optimization.
negative_sample_rate : int, default 5 Number of negative samples per positive sample in optimization.
width : int, default 8 Width of the output scatter plot.
height : int, default 6 Height of the output scatter plot.
Updates
self.umap : pandas.DataFrame
Table of UMAP embeddings with columns UMAP1 ... UMAPn
.
Notes
For supervised UMAP (factorize=True
), categorical codes from column
names of the dataset are used as labels.
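A minimal end-to-end sketch of the embedding step. `expr_df` and `meta_df` are hypothetical placeholders for a features-by-cells expression table and its per-cell metadata; a 'sets' column is only needed when harmonization is used:

>>> cl = Clustering.add_data_frame(expr_df, meta_df)   # features x cells
>>> cl.perform_PCA(pc_num=50)                          # fills cl.pca
>>> cl.harmonize_sets(batch_col="sets")                # optional batch correction
>>> cl.perform_UMAP(umap_num=10, pc_num=30, harmonized=True, n_neighbors=15)
>>> cl.umap.head()                                     # columns UMAP1 ... UMAP10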
538 def knee_plot_umap(self, eps: int | float = 0.5, min_samples: int = 10): 539 """ 540 Plot silhouette scores for different UMAP dimensions to determine optimal n_components. 541 542 Parameters 543 ---------- 544 eps : float, default 0.5 545 DBSCAN eps parameter for clustering each UMAP dimension. 546 547 min_samples : int, default 10 548 Minimum number of samples to form a cluster in DBSCAN. 549 550 Returns 551 ------- 552 matplotlib.figure.Figure 553 Silhouette score plot across UMAP dimensions. 554 """ 555 556 umap_range = range(2, len(self.umap.T) + 1) 557 558 silhouette_scores = [] 559 component = [] 560 for n in umap_range: 561 562 db = DBSCAN(eps=eps, min_samples=min_samples) 563 labels = db.fit_predict(np.array(self.umap)[:, :n]) 564 565 mask = labels != -1 566 if len(set(labels[mask])) > 1: 567 score = silhouette_score(np.array(self.umap)[:, :n][mask], labels[mask]) 568 else: 569 score = -1 570 571 silhouette_scores.append(score) 572 component.append(n) 573 574 fig = plt.figure(figsize=(10, 5)) 575 plt.plot(component, silhouette_scores, marker="o") 576 plt.xlabel("UMAP (n_components)") 577 plt.ylabel("Silhouette Score") 578 plt.grid(True) 579 plt.xticks(range(int(min(component)), int(max(component)) + 1, 1)) 580 581 plt.show() 582 583 return fig
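A short sketch of how the knee plot is typically read; the parameters here mirror the DBSCAN settings intended for the downstream clustering call:

>>> fig = cl.knee_plot_umap(eps=0.5, min_samples=10)
>>> # pick the smallest n_components at (or near) the silhouette peak
>>> # and pass it as `umap_n` to find_clusters_UMAP() below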
    def find_clusters_UMAP(
        self,
        umap_n: int = 5,
        eps: float | int = 0.5,
        min_samples: int = 10,
        width: int = 8,
        height: int = 6,
    ):
        """
        Apply DBSCAN clustering on UMAP embeddings and visualize clusters.

        This method performs density-based clustering (DBSCAN) on the UMAP-reduced
        dataset. Cluster labels are stored in the object's metadata, and a scatter
        plot of the first two UMAP components with cluster annotations is returned.

        Parameters
        ----------
        umap_n : int, default 5
            Number of UMAP dimensions to use for DBSCAN clustering.
            Must be <= number of columns in `self.umap`.

        eps : float | int, default 0.5
            Maximum neighborhood distance between two samples for them to be considered
            as in the same cluster (DBSCAN parameter).

        min_samples : int, default 10
            Minimum number of samples in a neighborhood to form a cluster (DBSCAN parameter).

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        Returns
        -------
        matplotlib.figure.Figure
            Scatter plot of the first two UMAP components colored by
            cluster assignments.

        Updates
        -------
        self.clustering_metadata['UMAP_clusters'] : list
            Cluster labels assigned to each cell/sample.

        self.input_metadata['UMAP_clusters'] : list, optional
            Cluster labels stored in input metadata (if available).
        """

        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        dbscan_labels = dbscan.fit_predict(np.array(self.umap)[:, :umap_n])

        umap_df = self.umap
        umap_df["Cluster"] = dbscan_labels

        fig = plt.figure(figsize=(width, height))

        for cluster_id in sorted(umap_df["Cluster"].unique()):
            cluster_data = umap_df[umap_df["Cluster"] == cluster_id]
            plt.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"Cluster {cluster_id}",
                alpha=0.7,
            )

        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.legend(title="Clusters", loc="center left", bbox_to_anchor=(1.0, 0.5))
        plt.grid(True)

        self.clustering_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]

        # `input_metadata` exists only on subclasses (e.g., COMPsc); skip otherwise
        try:
            self.input_metadata["UMAP_clusters"] = [str(x) for x in dbscan_labels]
        except (AttributeError, TypeError):
            pass

        return fig
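Usage sketch for the clustering step; labels are stored as strings, and DBSCAN marks noise points as '-1':

>>> fig = cl.find_clusters_UMAP(umap_n=5, eps=0.5, min_samples=10)
>>> labels = cl.return_clusters(clusters="umap")
>>> labels.value_counts()   # '-1' counts DBSCAN noise points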
    def UMAP_vis(
        self,
        names_slot: str = "cell_names",
        set_sep: bool = True,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        legend_split_col: int = 2,
        width: int = 8,
        height: int = 6,
        inc_num: bool = True,
    ):
        """
        Visualize UMAP embeddings with sample labels based on a specific metadata slot.

        Parameters
        ----------
        names_slot : str, default 'cell_names'
            Column in metadata to use as sample labels.

        set_sep : bool, default True
            If True, separate points by dataset.

        point_size : float, default 0.6
            Size of scatter points.

        font_size : int, default 6
            Font size for numbers on points.

        legend_split_col : int, default 2
            Number of columns in legend.

        width : int, default 8
            Figure width.

        height : int, default 6
            Figure height.

        inc_num : bool, default True
            If True, annotate points with numeric labels.

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot figure.
        """

        umap_df = self.umap.iloc[:, 0:2].copy()
        umap_df.loc[:, "names"] = self.clustering_metadata[names_slot]

        if set_sep:

            if "sets" in self.clustering_metadata.columns:
                umap_df.loc[:, "dataset"] = self.clustering_metadata["sets"]
            else:
                umap_df["dataset"] = "default"

        else:
            umap_df["dataset"] = "default"

        umap_df.loc[:, "tmp_nam"] = umap_df["names"] + umap_df["dataset"]

        umap_df.loc[:, "count"] = umap_df["tmp_nam"].map(
            umap_df["tmp_nam"].value_counts()
        )

        numeric_df = (
            pd.DataFrame(umap_df[["count", "tmp_nam", "names"]].copy())
            .drop_duplicates()
            .sort_values("count", ascending=False)
        )
        numeric_df["numeric_values"] = range(0, numeric_df.shape[0])

        umap_df = umap_df.merge(
            numeric_df[["tmp_nam", "numeric_values"]], on="tmp_nam", how="left"
        )

        fig, ax = plt.subplots(figsize=(width, height))

        markers = ["o", "s", "^", "D", "P", "*", "X"]
        marker_map = {
            ds: markers[i % len(markers)]
            for i, ds in enumerate(umap_df["dataset"].unique())
        }

        coord_list = []

        for num, nam in zip(numeric_df["numeric_values"], numeric_df["names"]):

            cluster_data = umap_df[umap_df["numeric_values"] == num]

            ax.scatter(
                cluster_data["UMAP1"],
                cluster_data["UMAP2"],
                label=f"{num} - {nam}",
                marker=marker_map[cluster_data["dataset"].iloc[0]],
                alpha=0.6,
                s=point_size,
            )

            coords = cluster_data[["UMAP1", "UMAP2"]].values

            # place the group label at the most central point (medoid)
            dists = pairwise_distances(coords)

            sum_dists = dists.sum(axis=1)

            center_idx = np.argmin(sum_dists)
            center_point = coords[center_idx]

            coord_list.append(center_point)

        if inc_num:
            texts = []
            for (x, y), num in zip(coord_list, numeric_df["numeric_values"]):
                texts.append(
                    ax.text(
                        x,
                        y,
                        str(num),
                        ha="center",
                        va="center",
                        fontsize=font_size,
                        color="black",
                    )
                )

            adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", lw=0.5))

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.legend(
            title="Clusters",
            loc="center left",
            bbox_to_anchor=(1.05, 0.5),
            ncol=legend_split_col,
            markerscale=5,
        )

        ax.grid(True)

        plt.tight_layout()

        return fig
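A minimal call sketch; `umap_labels.png` is an illustrative output path, not a fixed convention of this module:

>>> fig = cl.UMAP_vis(names_slot="cell_names", set_sep=True, inc_num=True)
>>> fig.savefig("umap_labels.png", dpi=300, bbox_inches="tight")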
    def UMAP_feature(
        self,
        feature_name: str,
        features_data: pd.DataFrame | None = None,
        point_size: int | float = 0.6,
        font_size: int | float = 6,
        width: int = 8,
        height: int = 6,
        palette="light",
    ):
        """
        Visualize UMAP embedding with expression levels of a selected feature.

        Each point (cell) in the UMAP plot is colored according to the expression
        value of the chosen feature, enabling interpretation of spatial patterns
        of gene activity or metadata distribution in low-dimensional space.

        Parameters
        ----------
        feature_name : str
            Name of the feature to plot (matched case-insensitively).

        features_data : pandas.DataFrame or None, default None
            If None, the function uses the DataFrame containing the clustering data.
            To plot features not used in clustering, provide a wider DataFrame
            containing the original feature values.

        point_size : float, default 0.6
            Size of scatter points in the plot.

        font_size : int, default 6
            Font size for axis labels and annotations.

        width : int, default 8
            Width of the matplotlib figure.

        height : int, default 6
            Height of the matplotlib figure.

        palette : str, default 'light'
            Color palette for expression visualization. Options are:
            - 'light'
            - 'dark'
            - 'green'
            - 'gray'

        Returns
        -------
        matplotlib.figure.Figure
            UMAP scatter plot colored by feature values.
        """

        umap_df = self.umap.iloc[:, 0:2].copy()

        if features_data is None:

            features_data = self.clustering_data

        if features_data.shape[1] != umap_df.shape[0]:
            raise ValueError(
                "Provided 'features_data' shape does not match the number of UMAP cells"
            )

        blist = [x.upper() == feature_name.upper() for x in features_data.index]

        if not any(blist):
            raise ValueError("Provided 'feature_name' is not included in the data")

        umap_df.loc[:, "feature"] = (
            features_data.loc[blist, :]
            .apply(lambda row: row.tolist(), axis=1)
            .values[0]
        )

        # plot low values first so highly expressing cells stay visible on top
        umap_df = umap_df.sort_values("feature", ascending=True)

        import matplotlib.colors as mcolors

        if palette == "light":
            palette = px.colors.sequential.Sunsetdark

        elif palette == "dark":
            palette = px.colors.sequential.thermal

        elif palette == "green":
            palette = px.colors.sequential.Aggrnyl

        elif palette == "gray":
            palette = px.colors.sequential.gray
            palette = palette[::-1]

        else:
            raise ValueError(
                'Palette not found. Use: "light", "dark", "gray", or "green"'
            )

        # convert plotly 'rgb(r, g, b)' strings (0-255) to matplotlib tuples (0-1)
        converted = []
        for c in palette:
            rgb_255 = px.colors.unlabel_rgb(c)
            rgb_01 = tuple(v / 255.0 for v in rgb_255)
            converted.append(rgb_01)

        my_cmap = mcolors.ListedColormap(converted, name="custom")

        fig, ax = plt.subplots(figsize=(width, height))
        sc = ax.scatter(
            umap_df["UMAP1"],
            umap_df["UMAP2"],
            c=umap_df["feature"],
            s=point_size,
            cmap=my_cmap,
            alpha=1.0,
            edgecolors="black",
            linewidths=0.1,
        )

        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label(f"{feature_name}")

        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

        ax.grid(True)

        plt.tight_layout()

        return fig
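A usage sketch; 'GAPDH' is an illustrative gene name and `full_expr_df` a hypothetical wider table. An external table must have exactly one column per embedded cell, in the same order as the clustering input:

>>> fig = cl.UMAP_feature("GAPDH", features_data=None, palette="light")
>>> fig = cl.UMAP_feature("GAPDH", features_data=full_expr_df, palette="gray")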
    def get_umap_data(self):
        """
        Retrieve UMAP embedding data with optional cluster labels.

        Returns the UMAP coordinates stored in `self.umap`. If clustering
        metadata is available (specifically `UMAP_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing UMAP coordinates (columns: 'UMAP1', 'UMAP2', ...).
            If available, includes an extra column 'clusters' with cluster labels.

        Notes
        -----
        - UMAP embeddings must be computed beforehand (e.g., using `perform_UMAP`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        umap_data = self.umap

        try:
            umap_data["clusters"] = self.clustering_metadata["UMAP_clusters"]
        except KeyError:
            # no UMAP clustering has been run yet; return coordinates only
            pass

        return umap_data
    def get_pca_data(self):
        """
        Retrieve PCA embedding data with optional cluster labels.

        Returns the principal component scores stored in `self.pca`. If clustering
        metadata is available (specifically `PCA_clusters`), the corresponding
        cluster assignments are appended as an additional column.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing PCA coordinates (columns: 'PC1', 'PC2', ...).
            If available, includes an extra column 'clusters' with cluster labels.

        Notes
        -----
        - PCA must be computed beforehand (e.g., using `perform_PCA`).
        - Cluster labels are added only if present in `self.clustering_metadata`.
        """

        pca_data = self.pca

        try:
            pca_data["clusters"] = self.clustering_metadata["PCA_clusters"]
        except KeyError:
            # no PCA clustering has been run yet; return scores only
            pass

        return pca_data
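Both getters can be used to export results for downstream tools; the file names below are illustrative. When cluster labels exist they are appended as a 'clusters' column:

>>> cl.get_umap_data().to_csv("umap_embedding.csv")
>>> cl.get_pca_data().to_csv("pca_scores.csv")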
    def return_clusters(self, clusters="umap"):
        """
        Retrieve cluster labels from UMAP or PCA clustering results.

        Parameters
        ----------
        clusters : str, default 'umap'
            Source of cluster labels to return. Must be one of:
            - 'umap': return cluster labels from UMAP embeddings.
            - 'pca' : return cluster labels from PCA embeddings.

        Returns
        -------
        pandas.Series
            Cluster labels corresponding to the selected embedding method.

        Raises
        ------
        ValueError
            If `clusters` is not 'umap' or 'pca'.

        Notes
        -----
        Requires that clustering has already been performed
        (e.g., using `find_clusters_UMAP` or `find_clusters_PCA`).
        """

        if clusters.lower() == "umap":
            clusters_vector = self.clustering_metadata["UMAP_clusters"]
        elif clusters.lower() == "pca":
            clusters_vector = self.clustering_metadata["PCA_clusters"]
        else:
            raise ValueError("Parameter 'clusters' must be either 'umap' or 'pca'.")

        return clusters_vector
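A quick composition check built on the returned labels, assuming the metadata carries a 'sets' column:

>>> pd.crosstab(cl.return_clusters("umap"), cl.clustering_metadata["sets"])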
class COMPsc(Clustering):
    """
    `COMPsc` (Comparison of single-cell data) is a class for the integration,
    analysis, and visualization of single-cell datasets.
    It supports independent dataset integration, subclustering of existing clusters,
    marker detection, and multiple visualization strategies.

    The COMPsc class provides methods for:

    - Normalizing and filtering single-cell data
    - Loading and saving sparse 10x-style datasets
    - Computing differential expression and marker genes
    - Clustering and subclustering analysis
    - Visualizing similarity and spatial relationships
    - Aggregating data by cell and set annotations
    - Managing metadata and renaming labels
    - Plotting gene detection histograms and feature scatters

    Methods
    -------
    project_dir(path_to_directory, project_list)
        Scans a directory to create a COMPsc instance mapping project names to their paths.

    save_project(name, path=os.getcwd())
        Saves the COMPsc object to a pickle file on disk.

    load_project(path)
        Loads a previously saved COMPsc object from a pickle file.

    reduce_cols(reg=None, full=None, name_slot='cell_names', inc_set=False)
        Removes columns (cells) whose names match a partial substring (`reg`) or a full name (`full`).

    reduce_rows(features_list)
        Removes rows (features/genes) whose names appear in `features_list`.

    get_data(set_info=False)
        Returns normalized data with optional set annotations in column names.

    get_metadata()
        Returns the stored input metadata.

    get_partial_data(names=None, features=None, name_slot='cell_names')
        Return a subset of the data by sample names and/or features.

    gene_calculation()
        Calculates and stores per-cell gene detection counts as a pandas Series.

    gene_histograme(bins=100)
        Plots a histogram of genes detected per cell with an overlaid normal distribution.

    gene_threshold(min_n=None, max_n=None)
        Filters cells based on minimum and/or maximum gene detection thresholds.

    load_sparse_from_projects(normalized_data=False)
        Loads and concatenates sparse 10x-style datasets from project paths into count or normalized data.

    rename_names(mapping, slot='cell_names')
        Renames entries in a specified metadata column using a provided mapping dictionary.

    rename_subclusters(mapping)
        Renames subcluster labels using a provided mapping dictionary.

    save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized')
        Exports data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

    normalize_counts(normalize_factor=100000, log_transform=True)
        Normalizes raw counts to counts-per-specified factor (e.g., CPM-like), with optional log2 transform.

    statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10)
        Computes per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.

    calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False)
        Computes and caches differential markers using the statistic method.

    clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.4)
        Prepares clustering input by selecting marker features and optionally smoothing cell values.

    average()
        Aggregates normalized data by averaging across (cell_name, set) pairs.
1114 1115 estimating_similarity(method='pearson', p_val=0.05, top_n=25) 1116 Computes pairwise correlation and Euclidean distance between aggregated samples. 1117 1118 similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10) 1119 Visualizes pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size. 1120 1121 spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...) 1122 Creates a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows. 1123 1124 subcluster_prepare(features, cluster) 1125 Initializes a Clustering object for subcluster analysis on a selected parent cluster. 1126 1127 define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...) 1128 Performs UMAP and DBSCAN clustering on prepared subcluster data and stores cluster labels. 1129 1130 subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...) 1131 Visualizes averaged expression and occurrence of features for subclusters as a scatter plot. 1132 1133 subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...) 1134 Plots top differential features for subclusters as a features-scatter visualization. 1135 1136 accept_subclusters() 1137 Commits subcluster labels to main metadata by renaming cell names and clears subcluster data. 1138 1139 Raises 1140 ------ 1141 ValueError 1142 For invalid parameters, mismatched dimensions, or missing metadata. 1143 1144 """ 1145 1146 def __init__( 1147 self, 1148 objects=None, 1149 ): 1150 """ 1151 Initialize the COMPsc class for single-cell data integration and analysis. 1152 1153 Parameters 1154 ---------- 1155 objects : list or None, optional 1156 Optional list of data objects to initialize the instance with. 1157 1158 Attributes 1159 ---------- 1160 -objects : list or None 1161 -input_data : pandas.DataFrame or None 1162 -input_metadata : pandas.DataFrame or None 1163 -normalized_data : pandas.DataFrame or None 1164 -agg_metadata : pandas.DataFrame or None 1165 -agg_normalized_data : pandas.DataFrame or None 1166 -similarity : pandas.DataFrame or None 1167 -var_data : pandas.DataFrame or None 1168 -subclusters_ : instance of Clustering class or None 1169 -cells_calc : pandas.Series or None 1170 -gene_calc : pandas.Series or None 1171 -composition_data : pandas.DataFrame or None 1172 """ 1173 1174 self.objects = objects 1175 """ Stores the input data objects.""" 1176 1177 self.input_data = None 1178 """Raw input data for clustering or integration analysis.""" 1179 1180 self.input_metadata = None 1181 """Metadata associated with the input data.""" 1182 1183 self.normalized_data = None 1184 """Normalized version of the input data.""" 1185 1186 self.agg_metadata = None 1187 '''Aggregated metadata for all sets in object related to "agg_normalized_data"''' 1188 1189 self.agg_normalized_data = None 1190 """Aggregated and normalized data across multiple sets.""" 1191 1192 self.similarity = None 1193 """Similarity data between cells across all samples. 
and sets""" 1194 1195 self.var_data = None 1196 """DEG analysis results summarizing variance across all samples in the object.""" 1197 1198 self.subclusters_ = None 1199 """Placeholder for information about subclusters analysis; if computed.""" 1200 1201 self.cells_calc = None 1202 """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition.""" 1203 1204 self.gene_calc = None 1205 """Number of genes detected per sample (cell), reflecting the sequencing depth.""" 1206 1207 self.composition_data = None 1208 """Data describing composition of cells across clusters or sets.""" 1209 1210 @classmethod 1211 def project_dir(cls, path_to_directory, project_list): 1212 """ 1213 Scan a directory and build a COMPsc instance mapping provided project names 1214 to their paths. 1215 1216 Parameters 1217 ---------- 1218 path_to_directory : str 1219 Path containing project subfolders. 1220 1221 project_list : list[str] 1222 List of filenames (folder names) to include in the returned object map. 1223 1224 Returns 1225 ------- 1226 COMPsc 1227 New COMPsc instance with `objects` populated. 1228 1229 Raises 1230 ------ 1231 Exception 1232 A generic exception is caught and a message printed if scanning fails. 1233 1234 Notes 1235 ----- 1236 Function attempts to match entries in `project_list` to directory 1237 names and constructs a simplified object key from the folder name. 1238 """ 1239 try: 1240 objects = {} 1241 for filename in tqdm(os.listdir(path_to_directory)): 1242 for c in project_list: 1243 f = os.path.join(path_to_directory, filename) 1244 if c == filename and os.path.isdir(f): 1245 objects[ 1246 str( 1247 re.sub( 1248 "_count", 1249 "", 1250 re.sub( 1251 "_fq", 1252 "", 1253 re.sub( 1254 re.sub("_.*", "", filename) + "_", 1255 "", 1256 filename, 1257 ), 1258 ), 1259 ) 1260 ) 1261 ] = f 1262 1263 return cls(objects) 1264 1265 except: 1266 print("Something went wrong. Check the function input data and try again!") 1267 1268 def save_project(self, name, path: str = os.getcwd()): 1269 """ 1270 Save the COMPsc object to disk using pickle. 1271 1272 Parameters 1273 ---------- 1274 name : str 1275 Base filename (without extension) to use when saving. 1276 1277 path : str, default os.getcwd() 1278 Directory in which to save the project file. 1279 1280 Returns 1281 ------- 1282 None 1283 1284 Side Effects 1285 ------------ 1286 - Writes a file `<path>/<name>.jpkl` containing the pickled object. 1287 - Prints a confirmation message with saved path. 1288 """ 1289 1290 full = os.path.join(path, f"{name}.jpkl") 1291 1292 with open(full, "wb") as f: 1293 pickle.dump(self, f) 1294 1295 print(f"Project saved as {full}") 1296 1297 @classmethod 1298 def load_project(cls, path): 1299 """ 1300 Load a previously saved COMPsc project from a pickle file. 1301 1302 Parameters 1303 ---------- 1304 path : str 1305 Full path to the pickled project file. 1306 1307 Returns 1308 ------- 1309 COMPsc 1310 The unpickled COMPsc object. 1311 1312 Raises 1313 ------ 1314 FileNotFoundError 1315 If the provided path does not exist. 
1316 """ 1317 1318 if not os.path.exists(path): 1319 raise FileNotFoundError("File does not exist!") 1320 with open(path, "rb") as f: 1321 obj = pickle.load(f) 1322 return obj 1323 1324 def reduce_cols( 1325 self, 1326 reg: str | None = None, 1327 full: str | None = None, 1328 name_slot: str = "cell_names", 1329 inc_set: bool = False, 1330 ): 1331 """ 1332 Remove columns (cells) whose names contain a substring `reg` or 1333 full name `full` from available tables. 1334 1335 Parameters 1336 ---------- 1337 reg : str | None 1338 Substring to search for in column/cell names; matching columns will be removed. 1339 If not None, `full` must be None. 1340 1341 full : str | None 1342 Full name to search for in column/cell names; matching columns will be removed. 1343 If not None, `reg` must be None. 1344 1345 name_slot : str, default 'cell_names' 1346 Column in metadata to use as sample names. 1347 1348 inc_set : bool, default False 1349 If True, column names are interpreted as 'cell_name # set' when matching. 1350 1351 Update 1352 ------------ 1353 Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`, 1354 `self.agg_normalized_data`, and `self.agg_metadata` (if they exist), 1355 removing columns/rows that match `reg`. 1356 1357 Raises 1358 ------ 1359 Raises ValueError if nothing matches the reduction mask. 1360 """ 1361 1362 if reg is None and full is None: 1363 raise ValueError( 1364 "Both 'reg' and 'full' arguments not provided. Please provide at least one of them!" 1365 ) 1366 1367 if reg is not None and full is not None: 1368 raise ValueError( 1369 "Both 'reg' and 'full' arguments are provided. " 1370 "Please provide only one of them!\n" 1371 "'reg' is used when only part of the name must be detected.\n" 1372 "'full' is used if the full name must be detected." 
1373 ) 1374 1375 if reg is not None: 1376 1377 if self.input_data is not None: 1378 1379 if inc_set: 1380 1381 self.input_data.columns = ( 1382 self.input_metadata[name_slot] 1383 + " # " 1384 + self.input_metadata["sets"] 1385 ) 1386 1387 else: 1388 1389 self.input_data.columns = self.input_metadata[name_slot] 1390 1391 mask = [reg.upper() not in x.upper() for x in self.input_data.columns] 1392 1393 if len([y for y in mask if y is False]) == 0: 1394 raise ValueError("Nothing found to reduce") 1395 1396 self.input_data = self.input_data.loc[:, mask] 1397 1398 if self.normalized_data is not None: 1399 1400 if inc_set: 1401 1402 self.normalized_data.columns = ( 1403 self.input_metadata[name_slot] 1404 + " # " 1405 + self.input_metadata["sets"] 1406 ) 1407 1408 else: 1409 1410 self.normalized_data.columns = self.input_metadata[name_slot] 1411 1412 mask = [ 1413 reg.upper() not in x.upper() for x in self.normalized_data.columns 1414 ] 1415 1416 if len([y for y in mask if y is False]) == 0: 1417 raise ValueError("Nothing found to reduce") 1418 1419 self.normalized_data = self.normalized_data.loc[:, mask] 1420 1421 if self.input_metadata is not None: 1422 1423 if inc_set: 1424 1425 self.input_metadata["drop"] = ( 1426 self.input_metadata[name_slot] 1427 + " # " 1428 + self.input_metadata["sets"] 1429 ) 1430 1431 else: 1432 1433 self.input_metadata["drop"] = self.input_metadata[name_slot] 1434 1435 mask = [ 1436 reg.upper() not in x.upper() for x in self.input_metadata["drop"] 1437 ] 1438 1439 if len([y for y in mask if y is False]) == 0: 1440 raise ValueError("Nothing found to reduce") 1441 1442 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 1443 drop=True 1444 ) 1445 1446 self.input_metadata = self.input_metadata.drop( 1447 columns=["drop"], errors="ignore" 1448 ) 1449 1450 if self.agg_normalized_data is not None: 1451 1452 if inc_set: 1453 1454 self.agg_normalized_data.columns = ( 1455 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1456 ) 1457 1458 else: 1459 1460 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1461 1462 mask = [ 1463 reg.upper() not in x.upper() 1464 for x in self.agg_normalized_data.columns 1465 ] 1466 1467 if len([y for y in mask if y is False]) == 0: 1468 raise ValueError("Nothing found to reduce") 1469 1470 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 1471 1472 if self.agg_metadata is not None: 1473 1474 if inc_set: 1475 1476 self.agg_metadata["drop"] = ( 1477 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1478 ) 1479 1480 else: 1481 1482 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1483 1484 mask = [reg.upper() not in x.upper() for x in self.agg_metadata["drop"]] 1485 1486 if len([y for y in mask if y is False]) == 0: 1487 raise ValueError("Nothing found to reduce") 1488 1489 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 1490 drop=True 1491 ) 1492 1493 self.agg_metadata = self.agg_metadata.drop( 1494 columns=["drop"], errors="ignore" 1495 ) 1496 1497 elif full is not None: 1498 1499 if self.input_data is not None: 1500 1501 if inc_set: 1502 1503 self.input_data.columns = ( 1504 self.input_metadata[name_slot] 1505 + " # " 1506 + self.input_metadata["sets"] 1507 ) 1508 1509 if "#" not in full: 1510 1511 self.input_data.columns = self.input_metadata[name_slot] 1512 1513 print( 1514 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1515 "Only the names will be compared, without considering the set information." 
1516 ) 1517 1518 else: 1519 1520 self.input_data.columns = self.input_metadata[name_slot] 1521 1522 mask = [full.upper() != x.upper() for x in self.input_data.columns] 1523 1524 if len([y for y in mask if y is False]) == 0: 1525 raise ValueError("Nothing found to reduce") 1526 1527 self.input_data = self.input_data.loc[:, mask] 1528 1529 if self.normalized_data is not None: 1530 1531 if inc_set: 1532 1533 self.normalized_data.columns = ( 1534 self.input_metadata[name_slot] 1535 + " # " 1536 + self.input_metadata["sets"] 1537 ) 1538 1539 if "#" not in full: 1540 1541 self.normalized_data.columns = self.input_metadata[name_slot] 1542 1543 print( 1544 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1545 "Only the names will be compared, without considering the set information." 1546 ) 1547 1548 else: 1549 1550 self.normalized_data.columns = self.input_metadata[name_slot] 1551 1552 mask = [full.upper() != x.upper() for x in self.normalized_data.columns] 1553 1554 if len([y for y in mask if y is False]) == 0: 1555 raise ValueError("Nothing found to reduce") 1556 1557 self.normalized_data = self.normalized_data.loc[:, mask] 1558 1559 if self.input_metadata is not None: 1560 1561 if inc_set: 1562 1563 self.input_metadata["drop"] = ( 1564 self.input_metadata[name_slot] 1565 + " # " 1566 + self.input_metadata["sets"] 1567 ) 1568 1569 if "#" not in full: 1570 1571 self.input_metadata["drop"] = self.input_metadata[name_slot] 1572 1573 print( 1574 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1575 "Only the names will be compared, without considering the set information." 1576 ) 1577 1578 else: 1579 1580 self.input_metadata["drop"] = self.input_metadata[name_slot] 1581 1582 mask = [full.upper() != x.upper() for x in self.input_metadata["drop"]] 1583 1584 if len([y for y in mask if y is False]) == 0: 1585 raise ValueError("Nothing found to reduce") 1586 1587 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 1588 drop=True 1589 ) 1590 1591 self.input_metadata = self.input_metadata.drop( 1592 columns=["drop"], errors="ignore" 1593 ) 1594 1595 if self.agg_normalized_data is not None: 1596 1597 if inc_set: 1598 1599 self.agg_normalized_data.columns = ( 1600 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1601 ) 1602 1603 if "#" not in full: 1604 1605 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1606 1607 print( 1608 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1609 "Only the names will be compared, without considering the set information." 1610 ) 1611 else: 1612 1613 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 1614 1615 mask = [ 1616 full.upper() != x.upper() for x in self.agg_normalized_data.columns 1617 ] 1618 1619 if len([y for y in mask if y is False]) == 0: 1620 raise ValueError("Nothing found to reduce") 1621 1622 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 1623 1624 if self.agg_metadata is not None: 1625 1626 if inc_set: 1627 1628 self.agg_metadata["drop"] = ( 1629 self.agg_metadata[name_slot] + " # " + self.agg_metadata["sets"] 1630 ) 1631 1632 if "#" not in full: 1633 1634 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1635 1636 print( 1637 "Not include the set info (name # set) in the 'full' argument, where 'inc_set' is True" 1638 "Only the names will be compared, without considering the set information." 
1639 ) 1640 else: 1641 1642 self.agg_metadata["drop"] = self.agg_metadata[name_slot] 1643 1644 mask = [full.upper() != x.upper() for x in self.agg_metadata["drop"]] 1645 1646 if len([y for y in mask if y is False]) == 0: 1647 raise ValueError("Nothing found to reduce") 1648 1649 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 1650 drop=True 1651 ) 1652 1653 self.agg_metadata = self.agg_metadata.drop( 1654 columns=["drop"], errors="ignore" 1655 ) 1656 1657 self.gene_calculation() 1658 self.cells_calculation() 1659 1660 def reduce_rows(self, features_list: list): 1661 """ 1662 Remove rows (features) whose names are included in features_list. 1663 1664 Parameters 1665 ---------- 1666 features_list : list 1667 List of features to search for in index/gene names; matching entries will be removed. 1668 1669 Update 1670 ------------ 1671 Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`, 1672 `self.agg_normalized_data`, and `self.agg_metadata` (if they exist), 1673 removing columns/rows that match `reg`. 1674 1675 Raises 1676 ------ 1677 Prints a message listing features that are not found in the data. 1678 """ 1679 1680 if self.input_data is not None: 1681 1682 res = find_features(self.input_data, features=features_list) 1683 1684 res_list = [x.upper() for x in res["included"]] 1685 1686 mask = [x.upper() not in res_list for x in self.input_data.index] 1687 1688 if len([y for y in mask if y is False]) == 0: 1689 raise ValueError("Nothing found to reduce") 1690 1691 self.input_data = self.input_data.loc[mask, :] 1692 1693 if self.normalized_data is not None: 1694 1695 res = find_features(self.normalized_data, features=features_list) 1696 1697 res_list = [x.upper() for x in res["included"]] 1698 1699 mask = [x.upper() not in res_list for x in self.normalized_data.index] 1700 1701 if len([y for y in mask if y is False]) == 0: 1702 raise ValueError("Nothing found to reduce") 1703 1704 self.normalized_data = self.normalized_data.loc[mask, :] 1705 1706 if self.agg_normalized_data is not None: 1707 1708 res = find_features(self.agg_normalized_data, features=features_list) 1709 1710 res_list = [x.upper() for x in res["included"]] 1711 1712 mask = [x.upper() not in res_list for x in self.agg_normalized_data.index] 1713 1714 if len([y for y in mask if y is False]) == 0: 1715 raise ValueError("Nothing found to reduce") 1716 1717 self.agg_normalized_data = self.agg_normalized_data.loc[mask, :] 1718 1719 if len(res["not_included"]) > 0: 1720 print("\nFeatures not found:") 1721 for i in res["not_included"]: 1722 print(i) 1723 1724 self.gene_calculation() 1725 self.cells_calculation() 1726 1727 def get_data(self, set_info: bool = False): 1728 """ 1729 Return normalized data with optional set annotation appended to column names. 1730 1731 Parameters 1732 ---------- 1733 set_info : bool, default False 1734 If True, column names are returned as "cell_name # set"; otherwise 1735 only the `cell_name` is used. 1736 1737 Returns 1738 ------- 1739 pandas.DataFrame 1740 The `self.normalized_data` table with columns renamed according to `set_info`. 1741 1742 Raises 1743 ------ 1744 AttributeError 1745 If `self.normalized_data` or `self.input_metadata` is missing. 
1746 """ 1747 1748 to_return = self.normalized_data 1749 1750 if set_info: 1751 to_return.columns = ( 1752 self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"] 1753 ) 1754 else: 1755 to_return.columns = self.input_metadata["cell_names"] 1756 1757 return to_return 1758 1759 def get_partial_data( 1760 self, 1761 names: list | str | None = None, 1762 features: list | str | None = None, 1763 name_slot: str = "cell_names", 1764 inc_metadata: bool = False, 1765 ): 1766 """ 1767 Return a subset of the data filtered by sample names and/or feature names. 1768 1769 Parameters 1770 ---------- 1771 names : list, str, or None 1772 Names of samples to include. If None, all samples are considered. 1773 1774 features : list, str, or None 1775 Names of features to include. If None, all features are considered. 1776 1777 name_slot : str 1778 Column in metadata to use as sample names. 1779 1780 inc_metadata : bool 1781 If True return tuple (data, metadata) 1782 1783 Returns 1784 ------- 1785 pandas.DataFrame 1786 Subset of the normalized data based on the specified names and features. 1787 """ 1788 1789 data = self.normalized_data.copy() 1790 metadata = self.input_metadata 1791 1792 if name_slot in self.input_metadata.columns: 1793 data.columns = self.input_metadata[name_slot] 1794 else: 1795 raise ValueError("'name_slot' not occured in data!'") 1796 1797 if isinstance(features, str): 1798 features = [features] 1799 elif features is None: 1800 features = [] 1801 1802 if isinstance(names, str): 1803 names = [names] 1804 elif names is None: 1805 names = [] 1806 1807 features = [x.upper() for x in features] 1808 names = [x.upper() for x in names] 1809 1810 columns_names = [x.upper() for x in data.columns] 1811 features_names = [x.upper() for x in data.index] 1812 1813 columns_bool = [True if x in names else False for x in columns_names] 1814 features_bool = [True if x in features else False for x in features_names] 1815 1816 if True not in columns_bool and True not in features_bool: 1817 print("Missing 'names' and/or 'features'. Returning full dataset instead.") 1818 if inc_metadata: 1819 return data, metadata 1820 else: 1821 return data 1822 1823 if True in columns_bool: 1824 data = data.loc[:, columns_bool] 1825 metadata = metadata.loc[columns_bool, :] 1826 1827 if inc_metadata: 1828 return data, metadata 1829 else: 1830 return data 1831 1832 if True in features_bool: 1833 data = data.loc[features_bool, :] 1834 1835 if inc_metadata: 1836 return data, metadata 1837 else: 1838 return data 1839 1840 def get_metadata(self): 1841 """ 1842 Return the stored input metadata. 1843 1844 Returns 1845 ------- 1846 pandas.DataFrame 1847 `self.input_metadata` (may be None if not set). 1848 """ 1849 1850 to_return = self.input_metadata 1851 1852 return to_return 1853 1854 def gene_calculation(self): 1855 """ 1856 Calculate and store per-cell counts (e.g., number of detected genes). 1857 1858 The method computes a binary (presence/absence) per cell and sums across 1859 features to produce `self.gene_calc`. 1860 1861 Update 1862 ------ 1863 Sets `self.gene_calc` as a pandas.Series. 1864 1865 Side Effects 1866 ------------ 1867 Uses `self.input_data` when available, otherwise `self.normalized_data`. 
1868 """ 1869 1870 if self.input_data is not None: 1871 1872 bin_col = self.input_data.columns.copy() 1873 1874 bin_col = bin_col.where(bin_col <= 0, 1) 1875 1876 sum_data = bin_col.sum(axis=0) 1877 1878 self.gene_calc = sum_data 1879 1880 elif self.normalized_data is not None: 1881 1882 bin_col = self.normalized_data.copy() 1883 1884 bin_col = bin_col.where(bin_col <= 0, 1) 1885 1886 sum_data = bin_col.sum(axis=0) 1887 1888 self.gene_calc = sum_data 1889 1890 def gene_histograme(self, bins=100): 1891 """ 1892 Plot a histogram of the number of genes detected per cell. 1893 1894 Parameters 1895 ---------- 1896 bins : int, default 100 1897 Number of histogram bins. 1898 1899 Returns 1900 ------- 1901 matplotlib.figure.Figure 1902 Figure containing the histogram of gene contents. 1903 1904 Notes 1905 ----- 1906 Requires `self.gene_calc` to be computed prior to calling. 1907 """ 1908 1909 fig, ax = plt.subplots(figsize=(8, 5)) 1910 1911 _, bin_edges, _ = ax.hist( 1912 self.gene_calc, bins=bins, edgecolor="black", alpha=0.6 1913 ) 1914 1915 mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc) 1916 1917 x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000) 1918 y = norm.pdf(x, mu, sigma) 1919 1920 y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0]) 1921 1922 ax.plot( 1923 x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})" 1924 ) 1925 1926 ax.set_xlabel("Value") 1927 ax.set_ylabel("Count") 1928 ax.set_title("Histogram of genes detected per cell") 1929 1930 ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20)) 1931 ax.tick_params(axis="x", rotation=90) 1932 1933 ax.legend() 1934 1935 return fig 1936 1937 def gene_threshold(self, min_n: int | None, max_n: int | None): 1938 """ 1939 Filter cells by gene-detection thresholds (min and/or max). 1940 1941 Parameters 1942 ---------- 1943 min_n : int or None 1944 Minimum number of detected genes required to keep a cell. 1945 1946 max_n : int or None 1947 Maximum number of detected genes allowed to keep a cell. 1948 1949 Update 1950 ------- 1951 Filters `self.input_data`, `self.normalized_data`, `self.input_metadata` 1952 (and calls `average()` if `self.agg_normalized_data` exists). 1953 1954 Side Effects 1955 ------------ 1956 Raises ValueError if both bounds are None or if filtering removes all cells. 
1957 """ 1958 1959 if min_n is not None and max_n is not None: 1960 mask = (self.gene_calc > min_n) & (self.gene_calc < max_n) 1961 elif min_n is None and max_n is not None: 1962 mask = self.gene_calc < max_n 1963 elif min_n is not None and max_n is None: 1964 mask = self.gene_calc > min_n 1965 else: 1966 raise ValueError("Lack of both min_n and max_n values") 1967 1968 if self.input_data is not None: 1969 1970 if len([y for y in mask if y is False]) == 0: 1971 raise ValueError("Nothing to reduce") 1972 1973 self.input_data = self.input_data.loc[:, mask.values] 1974 1975 if self.normalized_data is not None: 1976 1977 if len([y for y in mask if y is False]) == 0: 1978 raise ValueError("Nothing to reduce") 1979 1980 self.normalized_data = self.normalized_data.loc[:, mask.values] 1981 1982 if self.input_metadata is not None: 1983 1984 if len([y for y in mask if y is False]) == 0: 1985 raise ValueError("Nothing to reduce") 1986 1987 self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index( 1988 drop=True 1989 ) 1990 1991 self.input_metadata = self.input_metadata.drop( 1992 columns=["drop"], errors="ignore" 1993 ) 1994 1995 if self.agg_normalized_data is not None: 1996 self.average() 1997 1998 self.gene_calculation() 1999 self.cells_calculation() 2000 2001 def cells_calculation(self, name_slot="cell_names"): 2002 """ 2003 Calculate number of cells per call name / cluster. 2004 2005 The method computes a binary (presence/absence) per cell name / cluster and sums across 2006 cells. 2007 2008 Parameters 2009 ---------- 2010 name_slot : str, default 'cell_names' 2011 Column in metadata to use as sample names. 2012 2013 Update 2014 ------ 2015 Sets `self.cells_calc` as a pd.DataFrame. 2016 """ 2017 2018 ls = list(self.input_metadata[name_slot]) 2019 2020 df = pd.DataFrame( 2021 { 2022 "cluster": pd.Series(ls).value_counts().index, 2023 "n": pd.Series(ls).value_counts().values, 2024 } 2025 ) 2026 2027 self.cells_calc = df 2028 2029 def cell_histograme(self, name_slot: str = "cell_names"): 2030 """ 2031 Plot a histogram of the number of cells detected per cell name (cluster). 2032 2033 Parameters 2034 ---------- 2035 name_slot : str, default 'cell_names' 2036 Column in metadata to use as sample names. 2037 2038 Returns 2039 ------- 2040 matplotlib.figure.Figure 2041 Figure containing the histogram of cell contents. 2042 2043 Notes 2044 ----- 2045 Requires `self.cells_calc` to be computed prior to calling. 
2046 """ 2047 2048 if name_slot != "cell_names": 2049 self.cells_calculation(name_slot=name_slot) 2050 2051 fig, ax = plt.subplots(figsize=(8, 5)) 2052 2053 _, bin_edges, _ = ax.hist( 2054 list(self.cells_calc["n"]), 2055 bins=len(set(self.cells_calc["cluster"])), 2056 edgecolor="black", 2057 color="orange", 2058 alpha=0.6, 2059 ) 2060 2061 mu, sigma = np.mean(list(self.cells_calc["n"])), np.std( 2062 list(self.cells_calc["n"]) 2063 ) 2064 2065 x = np.linspace( 2066 min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 1000 2067 ) 2068 y = norm.pdf(x, mu, sigma) 2069 2070 y_scaled = y * len(list(self.cells_calc["n"])) * (bin_edges[1] - bin_edges[0]) 2071 2072 ax.plot( 2073 x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})" 2074 ) 2075 2076 ax.set_xlabel("Value") 2077 ax.set_ylabel("Count") 2078 ax.set_title("Histogram of cells detected per cell name / cluster") 2079 2080 ax.set_xticks( 2081 np.linspace( 2082 min(list(self.cells_calc["n"])), max(list(self.cells_calc["n"])), 20 2083 ) 2084 ) 2085 ax.tick_params(axis="x", rotation=90) 2086 2087 ax.legend() 2088 2089 return fig 2090 2091 def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"): 2092 """ 2093 Filter cell names / clusters by cell-detection threshold. 2094 2095 Parameters 2096 ---------- 2097 min_n : int or None 2098 Minimum number of detected genes required to keep a cell. 2099 2100 name_slot : str, default 'cell_names' 2101 Column in metadata to use as sample names. 2102 2103 2104 Update 2105 ------- 2106 Filters `self.input_data`, `self.normalized_data`, `self.input_metadata` 2107 (and calls `average()` if `self.agg_normalized_data` exists). 2108 """ 2109 2110 if name_slot != "cell_names": 2111 self.cells_calculation(name_slot=name_slot) 2112 2113 if min_n is not None: 2114 names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n] 2115 else: 2116 raise ValueError("Lack of min_n value") 2117 2118 if len(names) > 0: 2119 2120 if self.input_data is not None: 2121 2122 self.input_data.columns = self.input_metadata[name_slot] 2123 2124 mask = [not any(r in x for r in names) for x in self.input_data.columns] 2125 2126 if len([y for y in mask if y is False]) > 0: 2127 2128 self.input_data = self.input_data.loc[:, mask] 2129 2130 if self.normalized_data is not None: 2131 2132 self.normalized_data.columns = self.input_metadata[name_slot] 2133 2134 mask = [ 2135 not any(r in x for r in names) for x in self.normalized_data.columns 2136 ] 2137 2138 if len([y for y in mask if y is False]) > 0: 2139 2140 self.normalized_data = self.normalized_data.loc[:, mask] 2141 2142 if self.input_metadata is not None: 2143 2144 self.input_metadata["drop"] = self.input_metadata[name_slot] 2145 2146 mask = [ 2147 not any(r in x for r in names) for x in self.input_metadata["drop"] 2148 ] 2149 2150 if len([y for y in mask if y is False]) > 0: 2151 2152 self.input_metadata = self.input_metadata.loc[mask, :].reset_index( 2153 drop=True 2154 ) 2155 2156 self.input_metadata = self.input_metadata.drop( 2157 columns=["drop"], errors="ignore" 2158 ) 2159 2160 if self.agg_normalized_data is not None: 2161 2162 self.agg_normalized_data.columns = self.agg_metadata[name_slot] 2163 2164 mask = [ 2165 not any(r in x for r in names) 2166 for x in self.agg_normalized_data.columns 2167 ] 2168 2169 if len([y for y in mask if y is False]) > 0: 2170 2171 self.agg_normalized_data = self.agg_normalized_data.loc[:, mask] 2172 2173 if self.agg_metadata is not None: 2174 2175 self.agg_metadata["drop"] = 
self.agg_metadata[name_slot] 2176 2177 mask = [ 2178 not any(r in x for r in names) for x in self.agg_metadata["drop"] 2179 ] 2180 2181 if len([y for y in mask if y is False]) > 0: 2182 2183 self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index( 2184 drop=True 2185 ) 2186 2187 self.agg_metadata = self.agg_metadata.drop( 2188 columns=["drop"], errors="ignore" 2189 ) 2190 2191 self.gene_calculation() 2192 self.cells_calculation() 2193 2194 def load_sparse_from_projects(self, normalized_data: bool = False): 2195 """ 2196 Load sparse 10x-style datasets from stored project paths, concatenate them, 2197 and populate `input_data` / `normalized_data` and `input_metadata`. 2198 2199 Parameters 2200 ---------- 2201 normalized_data : bool, default False 2202 If True, store concatenated tables in `self.normalized_data`. 2203 If False, store them in `self.input_data` and normalization 2204 is needed using normalize_data() method. 2205 2206 Side Effects 2207 ------------ 2208 - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv). 2209 - Concatenates all projects column-wise and sets `self.input_metadata`. 2210 - Replaces NaNs with zeros and updates `self.gene_calc`. 2211 """ 2212 2213 obj = self.objects 2214 2215 full_data = pd.DataFrame() 2216 full_metadata = pd.DataFrame() 2217 2218 for ke in obj.keys(): 2219 print(ke) 2220 2221 dt, met = load_sparse(path=obj[ke], name=ke) 2222 2223 full_data = pd.concat([full_data, dt], axis=1) 2224 full_metadata = pd.concat([full_metadata, met], axis=0) 2225 2226 full_data[np.isnan(full_data)] = 0 2227 2228 if normalized_data: 2229 self.normalized_data = full_data 2230 self.input_metadata = full_metadata 2231 else: 2232 2233 self.input_data = full_data 2234 self.input_metadata = full_metadata 2235 2236 self.gene_calculation() 2237 self.cells_calculation() 2238 2239 def rename_names(self, mapping: dict, slot: str = "cell_names"): 2240 """ 2241 Rename entries in `self.input_metadata[slot]` according to a provided mapping. 2242 2243 Parameters 2244 ---------- 2245 mapping : dict 2246 Dictionary with keys 'old_name' and 'new_name', each mapping to a list 2247 of equal length describing replacements. 2248 2249 slot : str, default 'cell_names' 2250 Metadata column to operate on. 2251 2252 Update 2253 ------- 2254 Updates `self.input_metadata[slot]` in-place with renamed values. 2255 2256 Raises 2257 ------ 2258 ValueError 2259 If mapping keys are incorrect, lengths differ, or some 'old_name' values 2260 are not present in the metadata column. 2261 """ 2262 2263 if set(["old_name", "new_name"]) != set(mapping.keys()): 2264 raise ValueError( 2265 "Mapping dictionary must contain keys 'old_name' and 'new_name', " 2266 "each with a list of names to change." 2267 ) 2268 2269 if len(mapping["old_name"]) != len(mapping["new_name"]): 2270 raise ValueError( 2271 "Mapping dictionary lists 'old_name' and 'new_name' " 2272 "must have the same length!" 2273 ) 2274 2275 names_vector = list(self.input_metadata[slot]) 2276 2277 if not all(elem in names_vector for elem in list(mapping["old_name"])): 2278 raise ValueError( 2279 f"Some entries from 'old_name' do not exist in the names of slot {slot}." 
2280 ) 2281 2282 replace_dict = dict(zip(mapping["old_name"], mapping["new_name"])) 2283 2284 names_vector_ret = [replace_dict.get(item, item) for item in names_vector] 2285 2286 self.input_metadata[slot] = names_vector_ret 2287 2288 def rename_subclusters(self, mapping): 2289 """ 2290 Rename labels stored in `self.subclusters_.subclusters` according to mapping. 2291 2292 Parameters 2293 ---------- 2294 mapping : dict 2295 Mapping with keys 'old_name' and 'new_name' (lists of equal length). 2296 2297 Update 2298 ------- 2299 Updates `self.subclusters_.subclusters` with renamed labels. 2300 2301 Raises 2302 ------ 2303 ValueError 2304 If mapping is invalid or old names are not present. 2305 """ 2306 2307 if set(["old_name", "new_name"]) != set(mapping.keys()): 2308 raise ValueError( 2309 "Mapping dictionary must contain keys 'old_name' and 'new_name', " 2310 "each with a list of names to change." 2311 ) 2312 2313 if len(mapping["old_name"]) != len(mapping["new_name"]): 2314 raise ValueError( 2315 "Mapping dictionary lists 'old_name' and 'new_name' " 2316 "must have the same length!" 2317 ) 2318 2319 names_vector = list(self.subclusters_.subclusters) 2320 2321 if not all(elem in names_vector for elem in list(mapping["old_name"])): 2322 raise ValueError( 2323 "Some entries from 'old_name' do not exist in the subcluster names." 2324 ) 2325 2326 replace_dict = dict(zip(mapping["old_name"], mapping["new_name"])) 2327 2328 names_vector_ret = [replace_dict.get(item, item) for item in names_vector] 2329 2330 self.subclusters_.subclusters = names_vector_ret 2331 2332 def save_sparse( 2333 self, 2334 path_to_save: str = os.getcwd(), 2335 name_slot: str = "cell_names", 2336 data_slot: str = "normalized", 2337 ): 2338 """ 2339 Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv). 2340 2341 Parameters 2342 ---------- 2343 path_to_save : str, default current working directory 2344 Directory where files will be written. 2345 2346 name_slot : str, default 'cell_names' 2347 Metadata column providing cell names for barcodes.tsv. 2348 2349 data_slot : str, default 'normalized' 2350 Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data). 2351 2352 Raises 2353 ------ 2354 ValueError 2355 If `data_slot` is not 'normalized' or 'count'. 2356 """ 2357 2358 names = self.input_metadata[name_slot] 2359 2360 if data_slot.lower() == "normalized": 2361 2362 features = list(self.normalized_data.index) 2363 mtx = sparse.csr_matrix(self.normalized_data) 2364 2365 elif data_slot.lower() == "count": 2366 2367 features = list(self.input_data.index) 2368 mtx = sparse.csr_matrix(self.input_data) 2369 2370 else: 2371 raise ValueError("'data_slot' must be included in 'normalized' or 'count'") 2372 2373 os.makedirs(path_to_save, exist_ok=True) 2374 2375 mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx) 2376 2377 pd.Series(names).to_csv( 2378 os.path.join(path_to_save, "barcodes.tsv"), 2379 index=False, 2380 header=False, 2381 sep="\t", 2382 ) 2383 2384 pd.Series(features).to_csv( 2385 os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t" 2386 ) 2387 2388 def normalize_counts( 2389 self, normalize_factor: int = 100000, log_transform: bool = True 2390 ): 2391 """ 2392 Normalize raw counts to counts-per-(normalize_factor) 2393 (e.g., CPM, TPM - depending on normalize_factor). 2394 2395 Parameters 2396 ---------- 2397 normalize_factor : int, default 100000 2398 Scaling factor used after dividing by column sums. 
2399 2400 log_transform : bool, default True 2401 If True, apply log2(x+1) transformation to normalized values. 2402 2403 Update 2404 ------- 2405 Sets `self.normalized_data` to normalized values (fills NaNs with 0). 2406 2407 Raises 2408 ------ 2409 ValueError 2410 If `self.input_data` is missing (cannot normalize). 2411 """ 2412 if self.input_data is None: 2413 raise ValueError("Input data is missing, cannot normalize.") 2414 2415 sum_col = self.input_data.sum() 2416 self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor 2417 2418 if log_transform: 2419 # log2(x + 1) to avoid -inf for zeros 2420 self.normalized_data = np.log2(self.normalized_data + 1) 2421 2422 def statistic( 2423 self, 2424 cells=None, 2425 sets=None, 2426 min_exp: float = 0.01, 2427 min_pct: float = 0.1, 2428 n_proc: int = 10, 2429 ): 2430 """ 2431 Compute per-feature statistics (Mann–Whitney U) comparing target vs rest. 2432 2433 This is a wrapper similar to `calc_DEG` tailored to use `self.normalized_data` 2434 and `self.input_metadata`. It returns per-feature statistics including p-values, 2435 adjusted p-values, means, variances, effect-size measures and fold-changes. 2436 2437 Parameters 2438 ---------- 2439 cells : list, 'All', dict, or None 2440 Defines the target cells or groups for comparison (several modes supported). 2441 2442 sets : 'All', dict, or None 2443 Alternative grouping mode (operate on `self.input_metadata['sets']`). 2444 2445 min_exp : float, default 0.01 2446 Minimum expression threshold used when filtering features. 2447 2448 min_pct : float, default 0.1 2449 Minimum proportion of expressing cells in the target group required to test a feature. 2450 2451 n_proc : int, default 10 2452 Number of parallel jobs to use. 2453 2454 Returns 2455 ------- 2456 pandas.DataFrame or dict 2457 Results DataFrame (or dict containing valid/control cells + DataFrame), 2458 similar to `calc_DEG` interface. 2459 2460 Raises 2461 ------ 2462 ValueError 2463 If neither `cells` nor `sets` is provided, or input metadata mismatch occurs. 2464 2465 Notes 2466 ----- 2467 Multiple modes supported: single-list entities, 'All', pairwise dicts, etc. 
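        Examples
        --------
        A minimal sketch; `obj` is an instance with `normalized_data` and
        `input_metadata` populated, and the cell/set names used below
        ('Purkinje', 'Exp1', 'Exp2') are placeholders:

        >>> res = obj.statistic(cells=["Purkinje # Exp1"])  # one group vs rest
        >>> res["DEG"].head()
        >>> all_markers = obj.statistic(cells="All")  # every label vs the rest
        >>> pairwise = obj.statistic(sets={"ctrl": ["Exp1"], "case": ["Exp2"]})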
        """

        def stat_calc(choose, feature_name):
            target_values = choose.loc[choose["DEG"] == "target", feature_name]
            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

            pct_valid = (target_values > 0).sum() / len(target_values)
            pct_rest = (rest_values > 0).sum() / len(rest_values)

            avg_valid = np.mean(target_values)
            avg_ctrl = np.mean(rest_values)
            sd_valid = np.std(target_values, ddof=1)
            sd_ctrl = np.std(rest_values, ddof=1)

            # effect-size measure: mean difference divided by the pooled SD
            esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

            if np.sum(target_values) == np.sum(rest_values):
                p_val = 1.0
            else:
                _, p_val = stats.mannwhitneyu(
                    target_values, rest_values, alternative="two-sided"
                )

            return {
                "feature": feature_name,
                "p_val": p_val,
                "pct_valid": pct_valid,
                "pct_ctrl": pct_rest,
                "avg_valid": avg_valid,
                "avg_ctrl": avg_ctrl,
                "sd_valid": sd_valid,
                "sd_ctrl": sd_ctrl,
                "esm": esm,
            }

        def prepare_and_run_stat(
            choose, valid_group, min_exp, min_pct, n_proc, factors
        ):

            tmp_dat = choose[choose["DEG"] == "target"]
            tmp_dat = tmp_dat.drop("DEG", axis=1)

            # fraction of target cells expressing each feature above min_exp
            counts = (tmp_dat > min_exp).sum(axis=0)

            total_count = tmp_dat.shape[0]

            info = pd.DataFrame(
                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
            )

            del tmp_dat

            drop_col = info["feature"][info["pct"] <= min_pct]

            # if the min_pct filter would drop every feature, fall back to
            # dropping only features with no expressing target cells
            if len(drop_col) + 1 == len(choose.columns):
                drop_col = info["feature"][info["pct"] == 0]

            del info

            choose = choose.drop(list(drop_col), axis=1)

            results = Parallel(n_jobs=n_proc)(
                delayed(stat_calc)(choose, feature)
                for feature in tqdm(choose.columns[choose.columns != "DEG"])
            )

            df = pd.DataFrame(results)
            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # Benjamini-Hochberg-style adjustment of the rank-ordered p-values
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            factors_df = factors.to_frame(name="factor")

            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")

            # pseudocount-stabilized fold-change
            df["FC"] = (df["avg_valid"] + df["factor"]) / (
                df["avg_ctrl"] + df["factor"]
            )
            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
            df = df.drop(columns=["factor"])

            return df

        choose = self.normalized_data.copy().T

        # per-feature pseudocount: the smallest non-zero value of each feature
        factors = self.normalized_data.copy().replace(0, np.nan).min(axis=1)
        factors[factors == 0] = min(factors[factors != 0])

        final_results = []

        if isinstance(cells, list) and sets is None:
            print("\nAnalysis started...\nComparing selected cells to the whole set...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            if "#" not in cells[0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "The 'cells' list does not include set information (name # set). "
                    "Only the names will be compared, without considering the set information."
                )

            labels = ["target" if idx in cells else "rest" for idx in choose.index]
            valid = list(
                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
            )

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=valid,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}

        elif cells == "All" and sets is None:
            print("\nAnalysis started...\nComparing each type of cell to others...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
            unique_labels = set(choose.index)

            for label in tqdm(unique_labels):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]
                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                    factors=factors,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        elif cells is None and sets == "All":
            print("\nAnalysis started...\nComparing each set/group to others...")
            choose.index = self.input_metadata["sets"]
            unique_sets = set(choose.index)

            for label in tqdm(unique_sets):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]

                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                    factors=factors,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        elif cells is None and isinstance(sets, dict):
            print("\nAnalysis started...\nComparing groups...")

            choose.index = self.input_metadata["sets"]

            group_list = list(sets.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            labels = [
                (
                    "target"
                    if idx in sets[group_list[0]]
                    else "rest" if idx in sets[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            return result_df

        elif isinstance(cells, dict) and sets is None:
            print("\nAnalysis started...\nComparing groups...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            if "#" not in cells[list(cells.keys())[0]][0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "The 'cells' dict does not include set information (name # set). "
                    "Only the names will be compared, without considering the set information."
                )

            group_list = list(cells.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            labels = [
                (
                    "target"
                    if idx in cells[group_list[0]]
                    else "rest" if idx in cells[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )

            return result_df.reset_index(drop=True)

        else:
            raise ValueError(
                "You must specify exactly one of 'cells' or 'sets' "
                "(as a list, dict, or 'All'). The provided combination is not "
                "supported for this analysis."
            )

    def calculate_difference_markers(
        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
    ):
        """
        Compute differential markers (var_data) if not already present.

        Parameters
        ----------
        min_exp : float, default 0
            Minimum expression threshold passed to `statistic`.

        min_pct : float, default 0.25
            Minimum percent expressed in target group.

        n_proc : int, default 10
            Parallel jobs.

        force : bool, default False
            If True, recompute even if `self.var_data` is present.

        Update
        ------
        Sets `self.var_data` to the result of `self.statistic(...)`.

        Raises
        ------
        ValueError
            If markers were already computed and `force` is False.
        """

        if self.var_data is None or force:

            self.var_data = self.statistic(
                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
            )

        else:
            raise ValueError(
                "self.calculate_difference_markers() has already been executed. "
                "The results are stored in self.var_data. "
                "To recalculate with different statistics, rerun the method with force=True."
            )

    def clustering_features(
        self,
        features_list: list | None,
        name_slot: str = "cell_names",
        p_val: float = 0.05,
        top_n: int = 25,
        adj_mean: bool = True,
        beta: float = 0.2,
    ):
        """
        Prepare clustering input by selecting marker features and optionally smoothing cell values
        toward group means.

        Parameters
        ----------
        features_list : list or None
            If provided, use this list of features. If None, features are selected
            from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group.

        name_slot : str, default 'cell_names'
            Metadata column used for naming.

        p_val : float, default 0.05
            Adjusted p-value cutoff when selecting features automatically.

        top_n : int, default 25
            Number of top features per valid group to keep if `features_list` is None.

        adj_mean : bool, default True
            If True, adjust cell values toward group means using `beta`.

        beta : float, default 0.2
            Adjustment strength toward group mean.

        Update
        ------
        Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset,
        ready for PCA/UMAP/clustering.
        """

        if features_list is None or len(features_list) == 0:

            if self.var_data is None:
                raise ValueError(
                    "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
                )

            df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
            df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
            df_tmp = (
                df_tmp.sort_values(
                    ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
                )
                .groupby("valid_group")
                .head(top_n)
            )

            features_list = list(set(df_tmp["feature"]))

        data = self.get_partial_data(
            names=None, features=features_list, name_slot=name_slot
        )
        data_avg = average(data)

        if adj_mean:
            data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta)

        self.clustering_data = data

        self.clustering_metadata = self.input_metadata

    def average(self):
        """
        Aggregate normalized data by (cell_name, set) pairs computing the mean per group.

        The method constructs new column names as "cell_name # set", averages columns
        sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`.

        Update
        ------
        Sets `self.agg_normalized_data` (features x aggregated samples) and
        `self.agg_metadata` (DataFrame with 'cell_names' and 'sets').
        """

        wide_data = self.normalized_data

        wide_metadata = self.input_metadata

        new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"]

        wide_data.columns = list(new_names)

        # columns sharing an identical "name # set" label are averaged together
        aggregated_df = wide_data.T.groupby(wide_data.columns, axis=0).mean().T

        sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns]
        names = [re.sub(" #.*", "", x) for x in aggregated_df.columns]

        aggregated_df.columns = names
        aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets})

        self.agg_metadata = aggregated_metadata
        self.agg_normalized_data = aggregated_df

    def estimating_similarity(
        self, method="pearson", p_val: float = 0.05, top_n: int = 25
    ):
        """
        Estimate pairwise similarity and Euclidean distance between aggregated samples.

        Parameters
        ----------
        method : str, default 'pearson'
            Correlation method to use (passed to pandas.DataFrame.corr()).

        p_val : float, default 0.05
            Adjusted p-value cutoff used to select marker features from `self.var_data`.

        top_n : int, default 25
            Number of top features per valid group to include.

        Update
        ------
        Computes a combined table with per-pair correlation and euclidean distance
        and stores it in `self.similarity`.
        """

        if self.var_data is None:
            raise ValueError(
                "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
2889 ) 2890 2891 if self.agg_normalized_data is None: 2892 self.average() 2893 2894 metadata = self.agg_metadata 2895 data = self.agg_normalized_data 2896 2897 df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val] 2898 df_tmp = df_tmp[df_tmp["log(FC)"] > 0] 2899 df_tmp = ( 2900 df_tmp.sort_values( 2901 ["valid_group", "esm", "log(FC)"], ascending=[True, False, False] 2902 ) 2903 .groupby("valid_group") 2904 .head(top_n) 2905 ) 2906 2907 data = data.loc[list(set(df_tmp["feature"]))] 2908 2909 if len(set(metadata["sets"])) > 1: 2910 data.columns = data.columns + " # " + [x for x in metadata["sets"]] 2911 else: 2912 data = data.copy() 2913 2914 scaler = StandardScaler() 2915 2916 scaled_data = scaler.fit_transform(data) 2917 2918 scaled_df = pd.DataFrame(scaled_data, columns=data.columns) 2919 2920 cor = scaled_df.corr(method=method) 2921 cor_df = cor.stack().reset_index() 2922 cor_df.columns = ["cell1", "cell2", "correlation"] 2923 2924 distances = pdist(scaled_df.T, metric="euclidean") 2925 dist_mat = pd.DataFrame( 2926 squareform(distances), index=scaled_df.columns, columns=scaled_df.columns 2927 ) 2928 dist_df = dist_mat.stack().reset_index() 2929 dist_df.columns = ["cell1", "cell2", "euclidean_dist"] 2930 2931 full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"]) 2932 2933 full = full[full["cell1"] != full["cell2"]] 2934 full = full.reset_index(drop=True) 2935 2936 self.similarity = full 2937 2938 def similarity_plot( 2939 self, 2940 split_sets=True, 2941 set_info: bool = True, 2942 cmap="seismic", 2943 width=12, 2944 height=10, 2945 ): 2946 """ 2947 Visualize pairwise similarity as a scatter plot. 2948 2949 Parameters 2950 ---------- 2951 split_sets : bool, default True 2952 If True and set information is present, split plotting area roughly into two halves to visualize cross-set pairs. 2953 2954 set_info : bool, default True 2955 If True, keep the ' # set' annotation in labels; otherwise strip it. 2956 2957 cmap : str, default 'seismic' 2958 Color map for correlation (hue). 2959 2960 width : int, default 12 2961 Figure width. 2962 2963 height : int, default 10 2964 Figure height. 2965 2966 Returns 2967 ------- 2968 matplotlib.figure.Figure 2969 2970 Raises 2971 ------ 2972 ValueError 2973 If `self.similarity` is None. 2974 2975 Notes 2976 ----- 2977 The function filters pairs by z-scored euclidean distance > 0 to focus on closer pairs. 2978 """ 2979 2980 if self.similarity is None: 2981 raise ValueError( 2982 "Similarity data is missing. Please calculate similarity using self.estimating_similarity." 
2983 ) 2984 2985 similarity_data = self.similarity 2986 2987 if " # " in similarity_data["cell1"][0]: 2988 similarity_data["set1"] = [ 2989 re.sub(".*# ", "", x) for x in similarity_data["cell1"] 2990 ] 2991 similarity_data["set2"] = [ 2992 re.sub(".*# ", "", x) for x in similarity_data["cell2"] 2993 ] 2994 2995 if split_sets and " # " in similarity_data["cell1"][0]: 2996 sets = list( 2997 set(list(similarity_data["set1"]) + list(similarity_data["set2"])) 2998 ) 2999 3000 mm = math.ceil(len(sets) / 2) 3001 3002 x_s = sets[0:mm] 3003 y_s = sets[mm : len(sets)] 3004 3005 similarity_data = similarity_data[similarity_data["set1"].isin(x_s)] 3006 similarity_data = similarity_data[similarity_data["set2"].isin(y_s)] 3007 3008 similarity_data = similarity_data.sort_values(["set1", "set2"]) 3009 3010 if set_info is False and " # " in similarity_data["cell1"][0]: 3011 similarity_data["cell1"] = [ 3012 re.sub(" #.*", "", x) for x in similarity_data["cell1"] 3013 ] 3014 similarity_data["cell2"] = [ 3015 re.sub(" #.*", "", x) for x in similarity_data["cell2"] 3016 ] 3017 3018 similarity_data["-euclidean_zscore"] = -zscore( 3019 similarity_data["euclidean_dist"] 3020 ) 3021 3022 similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0] 3023 3024 fig = plt.figure(figsize=(width, height)) 3025 sns.scatterplot( 3026 data=similarity_data, 3027 x="cell1", 3028 y="cell2", 3029 hue="correlation", 3030 size="-euclidean_zscore", 3031 sizes=(1, 100), 3032 palette=cmap, 3033 alpha=1, 3034 edgecolor="black", 3035 ) 3036 3037 plt.xticks(rotation=90) 3038 plt.yticks(rotation=0) 3039 plt.xlabel("Cell 1") 3040 plt.ylabel("Cell 2") 3041 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") 3042 3043 plt.grid(True, alpha=0.6) 3044 3045 plt.tight_layout() 3046 3047 return fig 3048 3049 def spatial_similarity( 3050 self, 3051 set_info: bool = True, 3052 bandwidth=1, 3053 n_neighbors=5, 3054 min_dist=0.1, 3055 legend_split=2, 3056 point_size=100, 3057 spread=1.0, 3058 set_op_mix_ratio=1.0, 3059 local_connectivity=1, 3060 repulsion_strength=1.0, 3061 negative_sample_rate=5, 3062 threshold=0.1, 3063 width=12, 3064 height=10, 3065 ): 3066 """ 3067 Create a spatial UMAP-like visualization of similarity relationships between samples. 3068 3069 Parameters 3070 ---------- 3071 set_info : bool, default True 3072 If True, retain set information in labels. 3073 3074 bandwidth : float, default 1 3075 Bandwidth used by MeanShift for clustering polygons. 3076 3077 point_size : float, default 100 3078 Size of scatter points. 3079 3080 legend_split : int, default 2 3081 Number of columns in legend. 3082 3083 n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate : parameters passed to UMAP. 3084 3085 threshold : float, default 0.1 3086 Minimum text distance for label adjustment to avoid overlap. 3087 3088 width : int, default 12 3089 Figure width. 3090 3091 height : int, default 10 3092 Figure height. 3093 3094 Returns 3095 ------- 3096 matplotlib.figure.Figure 3097 3098 Raises 3099 ------ 3100 ValueError 3101 If `self.similarity` is None. 3102 3103 Notes 3104 ----- 3105 Builds a precomputed distance matrix combining correlation and euclidean distance, 3106 runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull) 3107 and arrows to indicate nearest neighbors (minimal combined distance). 3108 """ 3109 3110 if self.similarity is None: 3111 raise ValueError( 3112 "Similarity data is missing. 
Please calculate similarity using self.estimating_similarity."
            )

        similarity_data = self.similarity

        sim = similarity_data["correlation"]
        sim_scaled = (sim - sim.min()) / (sim.max() - sim.min())
        eu_dist = similarity_data["euclidean_dist"]
        eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min())

        # combined distance: small when a pair is both highly correlated and close
        similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled

        # for nn target: keep, per cell, the pair with the minimal combined distance
        arrow_df = similarity_data.copy()
        arrow_df = similarity_data.loc[
            similarity_data.groupby("cell1")["combo_dist"].idxmin()
        ]

        cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"]))
        combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float)

        for _, row in similarity_data.iterrows():
            combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"]
            combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"]

        umap_model = umap.UMAP(
            n_components=2,
            metric="precomputed",
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            spread=spread,
            set_op_mix_ratio=set_op_mix_ratio,
            local_connectivity=local_connectivity,
            repulsion_strength=repulsion_strength,
            negative_sample_rate=negative_sample_rate,
            transform_seed=42,
            init="spectral",
            random_state=42,
            verbose=True,
        )

        coords = umap_model.fit_transform(combo_matrix.values)
        cell_names = list(combo_matrix.index)
        num_cells = len(cell_names)
        palette = sns.color_palette("tab20c", num_cells)

        if "#" in cell_names[0]:
            avsets = set(
                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
            )
            num_sets = len(avsets)
            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
            color_mapping_sets = {
                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
            }
            color_mapping = {
                name: color_mapping_sets[re.sub(".*# ", "", name)]
                for i, name in enumerate(cell_names)
            }
        else:
            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}

        meanshift = MeanShift(bandwidth=bandwidth)
        labels = meanshift.fit_predict(coords)

        fig = plt.figure(figsize=(width, height))
        ax = plt.gca()

        unique_labels = set(labels)
        cluster_palette = sns.color_palette("hls", len(unique_labels))

        for label in unique_labels:
            if label == -1:
                continue
            cluster_coords = coords[labels == label]
            if len(cluster_coords) < 3:
                continue

            hull = ConvexHull(cluster_coords)
            hull_points = cluster_coords[hull.vertices]

            # expand the hull slightly outward from its centroid
            centroid = np.mean(hull_points, axis=0)
            expanded = hull_points + 0.05 * (hull_points - centroid)

            poly = Polygon(
                expanded,
                closed=True,
                facecolor=cluster_palette[label],
                edgecolor="none",
                alpha=0.2,
                zorder=1,
            )
            ax.add_patch(poly)

        texts = []
        for i, (x, y) in enumerate(coords):
            plt.scatter(
                x,
                y,
                s=point_size,
                color=color_mapping[cell_names[i]],
                edgecolors="black",
                linewidths=0.5,
                zorder=2,
            )
            texts.append(
                ax.text(
                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
                )
            )

        def dist(p1, p2):
            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)

        texts_to_adjust = []
        for i, t1 in enumerate(texts):
            for j, t2 in
enumerate(texts): 3230 if i >= j: 3231 continue 3232 d = dist( 3233 (t1.get_position()[0], t1.get_position()[1]), 3234 (t2.get_position()[0], t2.get_position()[1]), 3235 ) 3236 if d < threshold: 3237 if t1 not in texts_to_adjust: 3238 texts_to_adjust.append(t1) 3239 if t2 not in texts_to_adjust: 3240 texts_to_adjust.append(t2) 3241 3242 adjust_text( 3243 texts_to_adjust, 3244 expand_text=(1.0, 1.0), 3245 force_text=0.9, 3246 arrowprops=dict(arrowstyle="-", color="gray", lw=0.1), 3247 ax=ax, 3248 ) 3249 3250 for _, row in arrow_df.iterrows(): 3251 try: 3252 idx1 = cell_names.index(row["cell1"]) 3253 idx2 = cell_names.index(row["cell2"]) 3254 except ValueError: 3255 continue 3256 x1, y1 = coords[idx1] 3257 x2, y2 = coords[idx2] 3258 arrow = FancyArrowPatch( 3259 (x1, y1), 3260 (x2, y2), 3261 arrowstyle="->", 3262 color="gray", 3263 linewidth=1.5, 3264 alpha=0.5, 3265 mutation_scale=12, 3266 zorder=0, 3267 ) 3268 ax.add_patch(arrow) 3269 3270 if set_info is False and " # " in cell_names[0]: 3271 3272 legend_elements = [ 3273 Patch( 3274 facecolor=color_mapping[name], 3275 edgecolor="black", 3276 label=f"{i} – {re.sub(' #.*', '', name)}", 3277 ) 3278 for i, name in enumerate(cell_names) 3279 ] 3280 3281 else: 3282 3283 legend_elements = [ 3284 Patch( 3285 facecolor=color_mapping[name], 3286 edgecolor="black", 3287 label=f"{i} – {name}", 3288 ) 3289 for i, name in enumerate(cell_names) 3290 ] 3291 3292 plt.legend( 3293 handles=legend_elements, 3294 title="Cells", 3295 bbox_to_anchor=(1.05, 1), 3296 loc="upper left", 3297 ncol=legend_split, 3298 ) 3299 3300 plt.xlabel("UMAP 1") 3301 plt.ylabel("UMAP 2") 3302 plt.grid(False) 3303 plt.show() 3304 3305 return fig 3306 3307 # subclusters part 3308 3309 def subcluster_prepare(self, features: list, cluster: str): 3310 """ 3311 Prepare a `Clustering` object for subcluster analysis on a selected parent cluster. 3312 3313 Parameters 3314 ---------- 3315 features : list 3316 Features to include for subcluster analysis. 3317 3318 cluster : str 3319 Parent cluster name (used to select matching cells). 3320 3321 Update 3322 ------ 3323 Initializes `self.subclusters_` as a new `Clustering` instance containing the 3324 reduced data for the given cluster and stores `current_features` and `current_cluster`. 3325 """ 3326 3327 dat = self.normalized_data 3328 dat.columns = list(self.input_metadata["cell_names"]) 3329 3330 dat = reduce_data(self.normalized_data, features=features, names=[cluster]) 3331 3332 self.subclusters_ = Clustering(data=dat, metadata=None) 3333 3334 self.subclusters_.current_features = features 3335 self.subclusters_.current_cluster = cluster 3336 3337 def define_subclusters( 3338 self, 3339 umap_num: int = 2, 3340 eps: float = 0.5, 3341 min_samples: int = 10, 3342 n_neighbors: int = 5, 3343 min_dist: float = 0.1, 3344 spread: float = 1.0, 3345 set_op_mix_ratio: float = 1.0, 3346 local_connectivity: int = 1, 3347 repulsion_strength: float = 1.0, 3348 negative_sample_rate: int = 5, 3349 width=8, 3350 height=6, 3351 ): 3352 """ 3353 Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset. 3354 3355 Parameters 3356 ---------- 3357 umap_num : int, default 2 3358 Number of UMAP dimensions to compute. 3359 3360 eps : float, default 0.5 3361 DBSCAN eps parameter. 3362 3363 min_samples : int, default 10 3364 DBSCAN min_samples parameter. 
        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
            Additional parameters passed to UMAP; `width` and `height` control the plot size.

        Update
        ------
        Stores cluster labels in `self.subclusters_.subclusters`.

        Raises
        ------
        RuntimeError
            If `self.subclusters_` has not been prepared.
        """

        if self.subclusters_ is None:
            raise RuntimeError(
                "Nothing to return. 'self.subcluster_prepare' was not run!"
            )

        self.subclusters_.perform_UMAP(
            factorize=False,
            umap_num=umap_num,
            pc_num=0,
            harmonized=False,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            spread=spread,
            set_op_mix_ratio=set_op_mix_ratio,
            local_connectivity=local_connectivity,
            repulsion_strength=repulsion_strength,
            negative_sample_rate=negative_sample_rate,
            width=width,
            height=height,
        )

        fig = self.subclusters_.find_clusters_UMAP(
            umap_n=umap_num,
            eps=eps,
            min_samples=min_samples,
            width=width,
            height=height,
        )

        clusters = self.subclusters_.return_clusters(clusters="umap")

        self.subclusters_.subclusters = [str(x) for x in list(clusters)]

        return fig

    def subcluster_features_scatter(
        self,
        colors="viridis",
        hclust="complete",
        scale=False,
        img_width=3,
        img_high=5,
        label_size=6,
        size_scale=70,
        y_lab="Genes",
        legend_lab="normalized",
        bbox_to_anchor_scale: int = 25,
        bbox_to_anchor_perc: tuple = (0.91, 0.63),
    ):
        """
        Create a features-scatter visualization for the subclusters (averaged and occurrence).

        Parameters
        ----------
        colors : str, default 'viridis'
            Colormap name passed to `features_scatter`.

        hclust : str or None
            Hierarchical clustering linkage to order rows/columns.

        scale : bool, default False
            If True, expression data will be scaled (0–1) across the rows (features).

        img_width, img_high : float
            Figure size.

        label_size : int
            Font size for labels.

        size_scale : int
            Bubble size scaling.

        y_lab : str
            X axis label.

        legend_lab : str
            Colorbar label.

        bbox_to_anchor_scale : int, default=25
            Vertical scale (percentage) for positioning the colorbar.

        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
            Anchor position for the size legend (percent bubble legend).

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        RuntimeError
            If subcluster preparation/definition has not been run.
        """

        if self.subclusters_ is None or self.subclusters_.subclusters is None:
            raise RuntimeError(
                "Nothing to return. The 'self.subcluster_prepare' -> "
                "'self.define_subclusters' pipeline was not run!"
            )

        dat = self.normalized_data
        dat.columns = list(self.input_metadata["cell_names"])

        dat = reduce_data(
            self.normalized_data,
            features=self.subclusters_.current_features,
            names=[self.subclusters_.current_cluster],
        )

        dat.columns = self.subclusters_.subclusters

        avg = average(dat)
        occ = occurrence(dat)

        scatter = features_scatter(
            expression_data=avg,
            occurence_data=occ,
            features=None,
            scale=scale,
            metadata_list=None,
            colors=colors,
            hclust=hclust,
            img_width=img_width,
            img_high=img_high,
            label_size=label_size,
            size_scale=size_scale,
            y_lab=y_lab,
            legend_lab=legend_lab,
            bbox_to_anchor_scale=bbox_to_anchor_scale,
            bbox_to_anchor_perc=bbox_to_anchor_perc,
        )

        return scatter

    def subcluster_DEG_scatter(
        self,
        top_n=3,
        min_exp=0,
        min_pct=0.25,
        p_val=0.05,
        colors="viridis",
        hclust="complete",
        scale=False,
        img_width=3,
        img_high=5,
        label_size=6,
        size_scale=70,
        y_lab="Genes",
        legend_lab="normalized",
        bbox_to_anchor_scale: int = 25,
        bbox_to_anchor_perc: tuple = (0.91, 0.63),
        n_proc=10,
    ):
        """
        Plot top differential features (DEGs) for subclusters as a features-scatter.

        Parameters
        ----------
        top_n : int, default 3
            Number of top features per subcluster to show.

        min_exp : float, default 0
            Minimum expression threshold passed to `statistic`.

        min_pct : float, default 0.25
            Minimum percent expressed in target group.

        p_val : float, default 0.05
            Maximum p-value for visualizing features.

        n_proc : int, default 10
            Parallel jobs used for DEG calculation.

        scale : bool, default False
            If True, expression_data will be scaled (0–1) across the rows (features).

        colors : str, default='viridis'
            Colormap for expression values.

        hclust : str or None, default='complete'
            Linkage method for hierarchical clustering. If None, no clustering
            is performed.

        img_width : int or float, default=3
            Width of the plot in inches.

        img_high : int or float, default=5
            Height of the plot in inches.

        label_size : int, default=6
            Font size for axis labels and ticks.

        size_scale : int or float, default=70
            Scaling factor for bubble sizes.

        y_lab : str, default='Genes'
            Label for the x-axis.

        legend_lab : str, default='normalized'
            Label for the colorbar legend.

        bbox_to_anchor_scale : int, default=25
            Vertical scale (percentage) for positioning the colorbar.

        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
            Anchor position for the size legend (percent bubble legend).

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        RuntimeError
            If subcluster preparation/definition has not been run.

        Notes
        -----
        Internally calls `calc_DEG` (or equivalent) to obtain statistics, filters
        by p-value and effect-size, selects top features per valid group and plots them.
        """

        if self.subclusters_ is None or self.subclusters_.subclusters is None:
            raise RuntimeError(
                "Nothing to return. The 'self.subcluster_prepare' -> "
                "'self.define_subclusters' pipeline was not run!"
            )

        dat = self.normalized_data
        dat.columns = list(self.input_metadata["cell_names"])

        dat = reduce_data(
            self.normalized_data, names=[self.subclusters_.current_cluster]
        )

        dat.columns = self.subclusters_.subclusters

        deg_stats = calc_DEG(
            dat,
            metadata_list=None,
            entities="All",
            sets=None,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        deg_stats = deg_stats[deg_stats["p_val"] <= p_val]
        deg_stats = deg_stats[deg_stats["log(FC)"] > 0]

        deg_stats = (
            deg_stats.sort_values(
                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
            )
            .groupby("valid_group")
            .head(top_n)
        )

        dat = reduce_data(dat, features=list(set(deg_stats["feature"])))

        avg = average(dat)
        occ = occurrence(dat)

        scatter = features_scatter(
            expression_data=avg,
            occurence_data=occ,
            features=None,
            scale=scale,
            metadata_list=None,
            colors=colors,
            hclust=hclust,
            img_width=img_width,
            img_high=img_high,
            label_size=label_size,
            size_scale=size_scale,
            y_lab=y_lab,
            legend_lab=legend_lab,
            bbox_to_anchor_scale=bbox_to_anchor_scale,
            bbox_to_anchor_perc=bbox_to_anchor_perc,
        )

        return scatter

    def accept_subclusters(self):
        """
        Commit subcluster labels into the main `input_metadata` by renaming cell names.

        The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']`
        with the expanded names that include subcluster suffixes (via `add_subnames`),
        then clears `self.subclusters_`.

        Update
        ------
        Modifies `self.input_metadata['cell_names']`.

        Resets `self.subclusters_` to None.

        Raises
        ------
        RuntimeError
            If `self.subclusters_` is not defined or subclusters were not computed.
        """

        if self.subclusters_ is None or self.subclusters_.subclusters is None:
            raise RuntimeError(
                "Nothing to return. The 'self.subcluster_prepare' -> "
                "'self.define_subclusters' pipeline was not run!"
            )

        new_meta = add_subnames(
            list(self.input_metadata["cell_names"]),
            parent_name=self.subclusters_.current_cluster,
            new_clusters=self.subclusters_.subclusters,
        )

        self.input_metadata["cell_names"] = new_meta

        self.subclusters_ = None

    def scatter_plot(
        self,
        names: list | None = None,
        features: list | None = None,
        name_slot: str = "cell_names",
        scale=True,
        colors="viridis",
        hclust=None,
        img_width=15,
        img_high=1,
        label_size=10,
        size_scale=200,
        y_lab="Genes",
        legend_lab="log(CPM + 1)",
        set_box_size: float | int = 5,
        set_box_high: float | int = 5,
        bbox_to_anchor_scale=25,
        bbox_to_anchor_perc=(0.90, -0.24),
        bbox_to_anchor_group=(1.01, 0.4),
    ):
        """
        Create a bubble scatter plot of selected features across samples in the project.

        Each point represents a feature-sample pair, where the color encodes the
        expression value and the size encodes occurrence or relative abundance.
        Optionally, hierarchical clustering can be applied to order rows and columns.

        Parameters
        ----------
        names : list, str, or None
            Names of samples to include. If None, all samples are considered.

        features : list, str, or None
            Names of features to include. If None, all features are considered.
        name_slot : str
            Column in metadata to use as sample names.

        scale : bool, default True
            If True, expression_data will be scaled (0–1) across the rows (features).

        colors : str, default='viridis'
            Colormap for expression values.

        hclust : str or None, default=None
            Linkage method for hierarchical clustering. If None, no clustering
            is performed.

        img_width : int or float, default=15
            Width of the plot in inches.

        img_high : int or float, default=1
            Height of the plot in inches.

        label_size : int, default=10
            Font size for axis labels and ticks.

        size_scale : int or float, default=200
            Scaling factor for bubble sizes.

        y_lab : str, default='Genes'
            Label for the x-axis.

        legend_lab : str, default='log(CPM + 1)'
            Label for the colorbar legend.

        bbox_to_anchor_scale : int, default=25
            Vertical scale (percentage) for positioning the colorbar.

        bbox_to_anchor_perc : tuple, default=(0.90, -0.24)
            Anchor position for the size legend (percent bubble legend).

        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
            Anchor position for the group legend.

        Returns
        -------
        matplotlib.figure.Figure
            The generated scatter plot figure.

        Notes
        -----
        Colors represent expression values normalized to the colormap.
        """

        prtd, met = self.get_partial_data(
            names=names, features=features, name_slot=name_slot, inc_metadata=True
        )

        if scale:

            legend_lab = "Scaled\n" + legend_lab

            # scale each feature (row) to the 0-1 range
            scaler = MinMaxScaler(feature_range=(0, 1))
            prtd = pd.DataFrame(
                scaler.fit_transform(prtd.T).T,
                index=prtd.index,
                columns=prtd.columns,
            )

        prtd.columns = prtd.columns + "#" + met["sets"]

        prtd_avg = average(prtd)

        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]

        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]

        prtd_occ = occurrence(prtd)

        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]

        fig_scatter = features_scatter(
            expression_data=prtd_avg,
            occurence_data=prtd_occ,
            scale=scale,
            features=None,
            metadata_list=meta_sets,
            colors=colors,
            hclust=hclust,
            img_width=img_width,
            img_high=img_high,
            label_size=label_size,
            size_scale=size_scale,
            y_lab=y_lab,
            legend_lab=legend_lab,
            set_box_size=set_box_size,
            set_box_high=set_box_high,
            bbox_to_anchor_scale=bbox_to_anchor_scale,
            bbox_to_anchor_perc=bbox_to_anchor_perc,
            bbox_to_anchor_group=bbox_to_anchor_group,
        )

        return fig_scatter

    def data_composition(
        self,
        features_count: list | None,
        name_slot: str = "cell_names",
        set_sep: bool = True,
    ):
        """
        Compute the composition of cell types in the data set.

        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
        within metadata entries, calculates their relative percentages, and stores
        the results in `self.composition_data`.

        Parameters
        ----------
        features_count : list or None
            List of features (part or full names) to be counted.
            If None, all unique elements from the specified `name_slot` metadata field are used.

        name_slot : str, default 'cell_names'
            Metadata field containing sample identifiers or labels.
3850 3851 set_sep : bool, default True 3852 If True and multiple sets are present in metadata, compute composition 3853 separately for each set. 3854 3855 Update 3856 ------- 3857 Stores results in `self.composition_data` as a pandas DataFrame with: 3858 - 'name': feature name 3859 - 'n': number of occurrences 3860 - 'pct': percentage of occurrences 3861 - 'set' (if applicable): dataset identifier 3862 """ 3863 3864 validated_list = list(self.input_metadata[name_slot]) 3865 sets = list(self.input_metadata["sets"]) 3866 3867 if features_count is None: 3868 features_count = list(set(self.input_metadata[name_slot])) 3869 3870 if set_sep and len(set(sets)) > 1: 3871 3872 final_res = pd.DataFrame() 3873 3874 for s in set(sets): 3875 print(s) 3876 3877 mask = [True if s == x else False for x in sets] 3878 3879 tmp_val_list = np.array(validated_list) 3880 3881 tmp_val_list = list(tmp_val_list[mask]) 3882 3883 res_dict = {"name": [], "n": [], "set": []} 3884 3885 for f in tqdm(features_count): 3886 res_dict["n"].append( 3887 sum(1 for element in tmp_val_list if f in element) 3888 ) 3889 res_dict["name"].append(f) 3890 res_dict["set"].append(s) 3891 res = pd.DataFrame(res_dict) 3892 res["pct"] = res["n"] / sum(res["n"]) * 100 3893 res["pct"] = res["pct"].round(2) 3894 3895 final_res = pd.concat([final_res, res]) 3896 3897 res = final_res.sort_values(["set", "pct"], ascending=[True, False]) 3898 3899 else: 3900 3901 res_dict = {"name": [], "n": []} 3902 3903 for f in tqdm(features_count): 3904 res_dict["n"].append( 3905 sum(1 for element in validated_list if f in element) 3906 ) 3907 res_dict["name"].append(f) 3908 3909 res = pd.DataFrame(res_dict) 3910 res["pct"] = res["n"] / sum(res["n"]) * 100 3911 res["pct"] = res["pct"].round(2) 3912 3913 res = res.sort_values("pct", ascending=False) 3914 3915 self.composition_data = res 3916 3917 def composition_pie( 3918 self, 3919 width=6, 3920 height=6, 3921 font_size=15, 3922 cmap: str = "tab20", 3923 legend_split_col: int = 1, 3924 offset_labels: float | int = 0.5, 3925 legend_bbox: tuple = (1.15, 0.95), 3926 ): 3927 """ 3928 Visualize the composition of cell lineages using pie charts. 3929 3930 Generates pie charts showing the relative proportions of features stored 3931 in `self.composition_data`. If multiple sets are present, a separate 3932 chart is drawn for each set. 3933 3934 Parameters 3935 ---------- 3936 width : int, default 6 3937 Width of the figure. 3938 3939 height : int, default 6 3940 Height of the figure (applied per set if multiple sets are plotted). 3941 3942 font_size : int, default 15 3943 Font size for labels and annotations. 3944 3945 cmap : str, default 'tab20' 3946 Colormap used for pie slices. 3947 3948 legend_split_col : int, default 1 3949 Number of columns in the legend. 3950 3951 offset_labels : float or int, default 0.5 3952 Spacing offset for label placement relative to pie slices. 3953 3954 legend_bbox : tuple, default (1.15, 0.95) 3955 Bounding box anchor position for the legend. 3956 3957 Returns 3958 ------- 3959 matplotlib.figure.Figure 3960 Pie chart visualization of composition data. 
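        Examples
        --------
        A minimal sketch; assumes `self.composition_data` has already been
        populated by `data_composition`:

        >>> obj.data_composition(features_count=None)
        >>> fig = obj.composition_pie(cmap="tab20", legend_split_col=1)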
3961 """ 3962 3963 df = self.composition_data 3964 3965 if "set" in df.columns and len(set(df["set"])) > 1: 3966 3967 sets = list(set(df["set"])) 3968 fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets))) 3969 3970 all_wedges = [] 3971 cmap = plt.get_cmap(cmap) 3972 3973 set_nam = len(set(df["name"])) 3974 3975 legend_labels = list(set(df["name"])) 3976 3977 colors = [cmap(i / set_nam) for i in range(set_nam)] 3978 3979 cmap_dict = dict(zip(legend_labels, colors)) 3980 3981 for idx, s in enumerate(sets): 3982 ax = axes[idx] 3983 tmp_df = df[df["set"] == s].reset_index(drop=True) 3984 3985 labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()] 3986 3987 wedges, _ = ax.pie( 3988 tmp_df["n"], 3989 startangle=90, 3990 labeldistance=1.05, 3991 colors=[cmap_dict[x] for x in tmp_df["name"]], 3992 wedgeprops={"linewidth": 0.5, "edgecolor": "black"}, 3993 ) 3994 3995 all_wedges.extend(wedges) 3996 3997 kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center") 3998 n = 0 3999 for i, p in enumerate(wedges): 4000 ang = (p.theta2 - p.theta1) / 2.0 + p.theta1 4001 y = np.sin(np.deg2rad(ang)) 4002 x = np.cos(np.deg2rad(ang)) 4003 horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))] 4004 connectionstyle = f"angle,angleA=0,angleB={ang}" 4005 kw["arrowprops"].update({"connectionstyle": connectionstyle}) 4006 if len(labels[i]) > 0: 4007 n += offset_labels 4008 ax.annotate( 4009 labels[i], 4010 xy=(x, y), 4011 xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)), 4012 horizontalalignment=horizontalalignment, 4013 fontsize=font_size, 4014 weight="bold", 4015 **kw, 4016 ) 4017 4018 circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black") 4019 ax.add_artist(circle2) 4020 4021 ax.text( 4022 0, 4023 0, 4024 f"{s}", 4025 ha="center", 4026 va="center", 4027 fontsize=font_size, 4028 weight="bold", 4029 ) 4030 4031 legend_handles = [ 4032 Patch(facecolor=cmap_dict[label], edgecolor="black", label=label) 4033 for label in legend_labels 4034 ] 4035 4036 fig.legend( 4037 handles=legend_handles, 4038 loc="center right", 4039 bbox_to_anchor=legend_bbox, 4040 ncol=legend_split_col, 4041 title="", 4042 ) 4043 4044 plt.tight_layout() 4045 plt.show() 4046 4047 else: 4048 4049 labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()] 4050 4051 legend_labels = [f"{row['name']}" for _, row in df.iterrows()] 4052 4053 cmap = plt.get_cmap(cmap) 4054 colors = [cmap(i / len(df)) for i in range(len(df))] 4055 4056 fig, ax = plt.subplots( 4057 figsize=(width, height), subplot_kw=dict(aspect="equal") 4058 ) 4059 4060 wedges, _ = ax.pie( 4061 df["n"], 4062 startangle=90, 4063 labeldistance=1.05, 4064 colors=colors, 4065 wedgeprops={"linewidth": 0.5, "edgecolor": "black"}, 4066 ) 4067 4068 kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center") 4069 n = 0 4070 for i, p in enumerate(wedges): 4071 ang = (p.theta2 - p.theta1) / 2.0 + p.theta1 4072 y = np.sin(np.deg2rad(ang)) 4073 x = np.cos(np.deg2rad(ang)) 4074 horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))] 4075 connectionstyle = "angle,angleA=0,angleB={}".format(ang) 4076 kw["arrowprops"].update({"connectionstyle": connectionstyle}) 4077 if len(labels[i]) > 0: 4078 n += offset_labels 4079 4080 ax.annotate( 4081 labels[i], 4082 xy=(x, y), 4083 xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)), 4084 horizontalalignment=horizontalalignment, 4085 fontsize=font_size, 4086 weight="bold", 4087 **kw, 4088 ) 4089 4090 circle2 = plt.Circle((0, 0), 0.6, color="white") 4091 circle2.set_edgecolor("black") 4092 
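            # the white circle drawn over the wedges turns the pie into a
            # donut; it is attached to the current axes just below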
4093 p = plt.gcf() 4094 p.gca().add_artist(circle2) 4095 4096 ax.legend( 4097 wedges, 4098 legend_labels, 4099 title="", 4100 loc="center left", 4101 bbox_to_anchor=legend_bbox, 4102 ncol=legend_split_col, 4103 ) 4104 4105 plt.show() 4106 4107 return fig 4108 4109 def bar_composition( 4110 self, 4111 cmap="tab20b", 4112 width=2, 4113 height=6, 4114 font_size=15, 4115 legend_split_col: int = 1, 4116 legend_bbox: tuple = (1.3, 1), 4117 ): 4118 """ 4119 Visualize the composition of cell lineages using bar plots. 4120 4121 Produces bar plots showing the distribution of features stored in 4122 `self.composition_data`. If multiple sets are present, a separate 4123 bar is drawn for each set. Percentages are annotated alongside the bars. 4124 4125 Parameters 4126 ---------- 4127 cmap : str, default 'tab20b' 4128 Colormap used for stacked bars. 4129 4130 width : int, default 2 4131 Width of each subplot (per set). 4132 4133 height : int, default 6 4134 Height of the figure. 4135 4136 font_size : int, default 15 4137 Font size for labels and annotations. 4138 4139 legend_split_col : int, default 1 4140 Number of columns in the legend. 4141 4142 legend_bbox : tuple, default (1.3, 1) 4143 Bounding box anchor position for the legend. 4144 4145 Returns 4146 ------- 4147 matplotlib.figure.Figure 4148 Stacked bar plot visualization of composition data. 4149 """ 4150 4151 df = self.composition_data 4152 df["num"] = range(1, len(df) + 1) 4153 4154 if "set" in df.columns and len(set(df["set"])) > 1: 4155 4156 sets = list(set(df["set"])) 4157 fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height)) 4158 4159 cmap = plt.get_cmap(cmap) 4160 4161 set_nam = len(set(df["name"])) 4162 4163 legend_labels = list(set(df["name"])) 4164 4165 colors = [cmap(i / set_nam) for i in range(set_nam)] 4166 4167 cmap_dict = dict(zip(legend_labels, colors)) 4168 4169 for idx, s in enumerate(sets): 4170 ax = axes[idx] 4171 4172 tmp_df = df[df["set"] == s].reset_index(drop=True) 4173 4174 values = tmp_df["n"].values 4175 total = sum(values) 4176 values = [v / total * 100 for v in values] 4177 values = [round(v, 2) for v in values] 4178 4179 idx_max = np.argmax(values) 4180 correction = 100 - sum(values) 4181 values[idx_max] += correction 4182 4183 names = tmp_df["name"].values 4184 perc = tmp_df["pct"].values 4185 nums = tmp_df["num"].values 4186 4187 bottom = 0 4188 centers = [] 4189 for name, num, val, color in zip(names, nums, values, colors): 4190 ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name) 4191 centers.append(bottom + val / 2) 4192 bottom += val 4193 4194 y_positions = np.linspace(centers[0], centers[-1], len(centers)) 4195 x_text = -0.8 4196 4197 for y_label, y_center, pct, num in zip( 4198 y_positions, centers, perc, nums 4199 ): 4200 ax.annotate( 4201 f"{pct:.1f}%", 4202 xy=(0, y_center), 4203 xycoords="data", 4204 xytext=(x_text, y_label), 4205 textcoords="data", 4206 ha="right", 4207 va="center", 4208 fontsize=font_size, 4209 arrowprops=dict( 4210 arrowstyle="->", 4211 lw=1, 4212 color="black", 4213 connectionstyle="angle3,angleA=0,angleB=90", 4214 ), 4215 ) 4216 4217 ax.set_ylim(0, 100) 4218 ax.set_xlabel(s, fontsize=font_size) 4219 ax.xaxis.label.set_rotation(30) 4220 4221 ax.set_xticks([]) 4222 ax.set_yticks([]) 4223 for spine in ax.spines.values(): 4224 spine.set_visible(False) 4225 4226 legend_handles = [ 4227 Patch(facecolor=cmap_dict[label], edgecolor="black", label=label) 4228 for label in legend_labels 4229 ] 4230 4231 fig.legend( 4232 handles=legend_handles, 4233 
loc="center right", 4234 bbox_to_anchor=legend_bbox, 4235 ncol=legend_split_col, 4236 title="", 4237 ) 4238 4239 plt.tight_layout() 4240 plt.show() 4241 4242 else: 4243 4244 cmap = plt.get_cmap(cmap) 4245 4246 colors = [cmap(i / len(df)) for i in range(len(df))] 4247 4248 fig, ax = plt.subplots(figsize=(width, height)) 4249 4250 values = df["n"].values 4251 names = df["name"].values 4252 perc = df["pct"].values 4253 nums = df["num"].values 4254 4255 bottom = 0 4256 centers = [] 4257 for name, num, val, color in zip(names, nums, values, colors): 4258 ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}") 4259 centers.append(bottom + val / 2) 4260 bottom += val 4261 4262 y_positions = np.linspace(centers[0], centers[-1], len(centers)) 4263 x_text = -0.8 4264 4265 for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums): 4266 ax.annotate( 4267 f"{num}) {pct}", 4268 xy=(0, y_center), 4269 xycoords="data", 4270 xytext=(x_text, y_label), 4271 textcoords="data", 4272 ha="right", 4273 va="center", 4274 fontsize=9, 4275 arrowprops=dict( 4276 arrowstyle="->", 4277 lw=1, 4278 color="black", 4279 connectionstyle="angle3,angleA=0,angleB=90", 4280 ), 4281 ) 4282 4283 ax.set_xticks([]) 4284 ax.set_yticks([]) 4285 for spine in ax.spines.values(): 4286 spine.set_visible(False) 4287 4288 ax.legend( 4289 title="Legend", 4290 bbox_to_anchor=legend_bbox, 4291 loc="upper left", 4292 ncol=legend_split_col, 4293 ) 4294 4295 plt.tight_layout() 4296 plt.show() 4297 4298 return fig 4299 4300 def cell_regression( 4301 self, 4302 cell_x: str, 4303 cell_y: str, 4304 set_x: str | None, 4305 set_y: str | None, 4306 threshold=10, 4307 image_width=12, 4308 image_high=7, 4309 color="black", 4310 ): 4311 """ 4312 Perform regression analysis between two selected cells and visualize the relationship. 4313 4314 This function computes a linear regression between two specified cells from 4315 aggregated normalized data, plots the regression line with scatter points, 4316 annotates regression statistics, and highlights potential outliers. 4317 4318 Parameters 4319 ---------- 4320 cell_x : str 4321 Name of the first cell (X-axis). 4322 4323 cell_y : str 4324 Name of the second cell (Y-axis). 4325 4326 set_x : str or None 4327 Dataset identifier corresponding to `cell_x`. If None, cell is selected only by name. 4328 4329 set_y : str or None 4330 Dataset identifier corresponding to `cell_y`. If None, cell is selected only by name. 4331 4332 threshold : int or float, default 10 4333 Threshold for detecting outliers. Points deviating from the mean or diagonal by more 4334 than this value are annotated. 4335 4336 image_width : int, default 12 4337 Width of the regression plot (in inches). 4338 4339 image_high : int, default 7 4340 Height of the regression plot (in inches). 4341 4342 color : str, default 'black' 4343 Color of the regression scatter points and line. 4344 4345 Returns 4346 ------- 4347 matplotlib.figure.Figure 4348 Regression plot figure with annotated regression line, R², p-value, and outliers. 4349 4350 Raises 4351 ------ 4352 ValueError 4353 If `cell_x` or `cell_y` are not found in the dataset. 4354 If multiple matches are found for a cell name and `set_x`/`set_y` are not specified. 4355 4356 Notes 4357 ----- 4358 - The function automatically calls `jseq_object.average()` if aggregated data is not available. 4359 - Outliers are annotated with their corresponding index labels. 4360 - Regression is computed using `scipy.stats.linregress`. 
4361 4362 Examples 4363 -------- 4364 >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2") 4365 >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue") 4366 """ 4367 4368 if self.agg_normalized_data is None: 4369 self.average() 4370 4371 metadata = self.agg_metadata 4372 data = self.agg_normalized_data 4373 4374 if set_x is not None and set_y is not None: 4375 data.columns = metadata["cell_names"] + " # " + metadata["sets"] 4376 cell_x = cell_x + " # " + set_x 4377 cell_y = cell_y + " # " + set_y 4378 4379 else: 4380 data.columns = metadata["cell_names"] 4381 4382 if not cell_x in data.columns: 4383 raise ValueError("'cell_x' value not in cell names!") 4384 4385 if not cell_y in data.columns: 4386 raise ValueError("'cell_y' value not in cell names!") 4387 4388 if list(data.columns).count(cell_x) > 1: 4389 raise ValueError( 4390 f"'{cell_x}' occurs more than once. If you want to select a specific cell, " 4391 f"please also provide the corresponding 'set_x' and 'set_y' values." 4392 ) 4393 4394 if list(data.columns).count(cell_y) > 1: 4395 raise ValueError( 4396 f"'{cell_y}' occurs more than once. If you want to select a specific cell, " 4397 f"please also provide the corresponding 'set_x' and 'set_y' values." 4398 ) 4399 4400 fig, ax = plt.subplots(figsize=(image_width, image_high)) 4401 ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color) 4402 4403 slope, intercept, r_value, p_value, _ = stats.linregress( 4404 data[cell_x], data[cell_y] 4405 ) 4406 equation = "y = {:.2f}x + {:.2f}".format(slope, intercept) 4407 4408 ax.annotate( 4409 "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format( 4410 r_value**2, p_value, equation 4411 ), 4412 xy=(0.05, 0.90), 4413 xycoords="axes fraction", 4414 fontsize=12, 4415 ) 4416 4417 ax.spines["top"].set_visible(False) 4418 ax.spines["right"].set_visible(False) 4419 4420 diff = [] 4421 x_mean, y_mean = data[cell_x].mean(), data[cell_y].mean() 4422 for i, (xi, yi) in enumerate(zip(data[cell_x], data[cell_y])): 4423 diff.append(abs(xi - x_mean)) 4424 diff.append(abs(yi - y_mean)) 4425 4426 def annotate_outliers(x, y, threshold): 4427 texts = [] 4428 x_mean, y_mean = x.mean(), y.mean() 4429 for i, (xi, yi) in enumerate(zip(x, y)): 4430 if ( 4431 abs(xi - x_mean) > threshold 4432 or abs(yi - y_mean) > threshold 4433 or abs(yi - xi) > threshold 4434 ): 4435 text = ax.text(xi, yi, data.index[i]) 4436 texts.append(text) 4437 4438 return texts 4439 4440 texts = annotate_outliers(data[cell_x], data[cell_y], threshold) 4441 4442 adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5)) 4443 4444 plt.show() 4445 4446 return fig
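    # A minimal subclustering sketch (the feature and cluster names below are
    # placeholders, not values guaranteed to exist in a given project):
    #
    #   obj.subcluster_prepare(features=["GeneA", "GeneB"], cluster="Purkinje")
    #   obj.define_subclusters(umap_num=2, eps=0.5, min_samples=10)
    #   obj.subcluster_DEG_scatter(top_n=3)
    #   obj.accept_subclusters()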
A class, COMPsc (Comparison of single-cell data), designed for the integration,
analysis, and visualization of single-cell datasets.
The class supports independent dataset integration, subclustering of existing clusters,
marker detection, and multiple visualization strategies.

The COMPsc class provides methods for:
- Normalizing and filtering single-cell data
- Loading and saving sparse 10x-style datasets
- Computing differential expression and marker genes
- Clustering and subclustering analysis
- Visualizing similarity and spatial relationships
- Aggregating data by cell and set annotations
- Managing metadata and renaming labels
- Plotting gene detection histograms and feature scatters

Methods
-------
project_dir(path_to_directory, project_list)
    Scan a directory to create a COMPsc instance mapping project names to their paths.

save_project(name, path=os.getcwd())
    Save the COMPsc object to a pickle file on disk.

load_project(path)
    Load a previously saved COMPsc object from a pickle file.

reduce_cols(reg=None, full=None, name_slot='cell_names', inc_set=False)
    Remove columns from data tables whose names contain a given substring or match a full name.

reduce_rows(features_list)
    Remove rows from data tables whose feature (gene) names appear in a provided list.

get_data(set_info=False)
    Return normalized data with optional set annotations in column names.

get_metadata()
    Return the stored input metadata.

get_partial_data(names=None, features=None, name_slot='cell_names', inc_metadata=False)
    Return a subset of the data by sample names and/or features.

gene_calculation()
    Calculate and store per-cell gene detection counts as a pandas Series.

gene_histograme(bins=100)
    Plot a histogram of genes detected per cell with an overlaid normal distribution.

gene_threshold(min_n=None, max_n=None)
    Filter cells based on minimum and/or maximum gene detection thresholds.

cells_calculation(name_slot='cell_names')
    Calculate the number of cells per cell name / cluster.

cell_histograme(name_slot='cell_names')
    Plot a histogram of cells detected per cell name / cluster.

cluster_threshold(min_n, name_slot='cell_names')
    Filter out cell names / clusters with fewer cells than a minimum threshold.

load_sparse_from_projects(normalized_data=False)
    Load and concatenate sparse 10x-style datasets from project paths into count or normalized data.

rename_names(mapping, slot='cell_names')
    Rename entries in a specified metadata column using a provided mapping dictionary.

rename_subclusters(mapping)
    Rename subcluster labels using a provided mapping dictionary.

save_sparse(path_to_save=os.getcwd(), name_slot='cell_names', data_slot='normalized')
    Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

normalize_counts(normalize_factor=100000, log_transform=True)
    Normalize raw counts to counts-per-specified factor (e.g., CPM-like), optionally log2-transformed.

statistic(cells=None, sets=None, min_exp=0.01, min_pct=0.1, n_proc=10)
    Compute per-feature differential expression statistics (Mann-Whitney U) comparing target vs. rest groups.

calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=10, force=False)
    Compute and cache differential markers using the statistic method.

clustering_features(features_list=None, name_slot='cell_names', p_val=0.05, top_n=25, adj_mean=True, beta=0.2)
    Prepare clustering input by selecting marker features and optionally smoothing cell values.

average()
    Aggregate normalized data by averaging across (cell_name, set) pairs.

estimating_similarity(method='pearson', p_val=0.05, top_n=25)
    Compute pairwise correlation and Euclidean distance between aggregated samples.

similarity_plot(split_sets=True, set_info=True, cmap='seismic', width=12, height=10)
    Visualize pairwise similarity as a scatter plot with correlation as hue and scaled distance as point size.

spatial_similarity(set_info=True, bandwidth=1, n_neighbors=5, min_dist=0.1, legend_split=2, point_size=20, ...)
    Create a UMAP-like visualization of similarity relationships with cluster hulls and nearest-neighbor arrows.

cell_regression(cell_x, cell_y, set_x=None, set_y=None, threshold=10, ...)
    Perform and plot a linear regression between two selected cells from aggregated data.

subcluster_prepare(features, cluster)
    Initialize a Clustering object for subcluster analysis on a selected parent cluster.

define_subclusters(umap_num=2, eps=0.5, min_samples=10, bandwidth=1, n_neighbors=5, min_dist=0.1, ...)
    Perform UMAP and DBSCAN clustering on prepared subcluster data and store cluster labels.

subcluster_features_scatter(colors='viridis', hclust='complete', img_width=3, img_high=5, label_size=6, ...)
    Visualize averaged expression and occurrence of features for subclusters as a scatter plot.

subcluster_DEG_scatter(top_n=3, min_exp=0, min_pct=0.25, p_val=0.05, colors='viridis', ...)
    Plot top differential features for subclusters as a features-scatter visualization.

accept_subclusters()
    Commit subcluster labels to main metadata by renaming cell names and clear subcluster data.

Raises
------
ValueError
    For invalid parameters, mismatched dimensions, or missing metadata.
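Taken together, these methods compose into a pipeline. A minimal end-to-end sketch (the paths and project folder names below are hypothetical placeholders):

comp = COMPsc.project_dir("/data/runs", ["exp1_liver_fq_count", "exp1_brain_fq_count"])
comp.load_sparse_from_projects()                   # fills input_data / input_metadata
comp.normalize_counts(normalize_factor=100000)     # counts -> log2 CPM-like values
comp.calculate_difference_markers(min_pct=0.25, n_proc=4)
comp.clustering_features(features_list=None, top_n=25)
comp.save_project("my_analysis")                   # -> ./my_analysis.jpkl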
    def __init__(
        self,
        objects=None,
    ):
        """
        Initialize the COMPsc class for single-cell data integration and analysis.

        Parameters
        ----------
        objects : list or None, optional
            Optional list of data objects to initialize the instance with.

        Attributes
        ----------
        -objects : list or None
        -input_data : pandas.DataFrame or None
        -input_metadata : pandas.DataFrame or None
        -normalized_data : pandas.DataFrame or None
        -agg_metadata : pandas.DataFrame or None
        -agg_normalized_data : pandas.DataFrame or None
        -similarity : pandas.DataFrame or None
        -var_data : pandas.DataFrame or None
        -subclusters_ : instance of the Clustering class or None
        -cells_calc : pandas.DataFrame or None
        -gene_calc : pandas.Series or None
        -composition_data : pandas.DataFrame or None
        """

        self.objects = objects
        """Stores the input data objects."""

        self.input_data = None
        """Raw input data for clustering or integration analysis."""

        self.input_metadata = None
        """Metadata associated with the input data."""

        self.normalized_data = None
        """Normalized version of the input data."""

        self.agg_metadata = None
        '''Aggregated metadata for all sets in the object, related to "agg_normalized_data".'''

        self.agg_normalized_data = None
        """Aggregated and normalized data across multiple sets."""

        self.similarity = None
        """Similarity data between cells across all samples and sets."""

        self.var_data = None
        """DEG analysis results summarizing variance across all samples in the object."""

        self.subclusters_ = None
        """Placeholder for subcluster analysis results, if computed."""

        self.cells_calc = None
        """Number of cells detected per sample (grouped by lineage, e.g., cluster or name), reflecting data composition."""

        self.gene_calc = None
        """Number of genes detected per sample (cell), reflecting the sequencing depth."""

        self.composition_data = None
        """Data describing the composition of cells across clusters or sets."""
    @classmethod
    def project_dir(cls, path_to_directory, project_list):
        """
        Scan a directory and build a COMPsc instance mapping provided project names
        to their paths.

        Parameters
        ----------
        path_to_directory : str
            Path containing project subfolders.

        project_list : list[str]
            List of filenames (folder names) to include in the returned object map.

        Returns
        -------
        COMPsc
            New COMPsc instance with `objects` populated.

        Raises
        ------
        Exception
            A generic exception is caught and a message printed if scanning fails.

        Notes
        -----
        The function matches entries in `project_list` to directory names and
        constructs a simplified object key from each folder name.
        """
        try:
            objects = {}
            for filename in tqdm(os.listdir(path_to_directory)):
                for c in project_list:
                    f = os.path.join(path_to_directory, filename)
                    if c == filename and os.path.isdir(f):
                        # Derive a simplified key: strip the '<prefix>_' part
                        # (text up to the first underscore), then the '_fq'
                        # and '_count' suffixes.
                        key = re.sub(
                            "_count",
                            "",
                            re.sub(
                                "_fq",
                                "",
                                re.sub(re.sub("_.*", "", filename) + "_", "", filename),
                            ),
                        )
                        objects[str(key)] = f

            return cls(objects)

        except Exception:
            print("Something went wrong. Check the function input data and try again!")
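The nested `re.sub` calls that derive the object key are easier to follow on a concrete folder name (the name below is hypothetical):

filename = "exp1_liver_fq_count"
prefix = re.sub("_.*", "", filename)        # "exp1"
key = re.sub("_count", "", re.sub("_fq", "", re.sub(prefix + "_", "", filename)))
# key == "liver"; the object map becomes {"liver": "<path>/exp1_liver_fq_count"}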
    def save_project(self, name, path: str = os.getcwd()):
        """
        Save the COMPsc object to disk using pickle.

        Parameters
        ----------
        name : str
            Base filename (without extension) to use when saving.

        path : str, default os.getcwd()
            Directory in which to save the project file.

        Returns
        -------
        None

        Side Effects
        ------------
        - Writes a file `<path>/<name>.jpkl` containing the pickled object.
        - Prints a confirmation message with the saved path.
        """

        full = os.path.join(path, f"{name}.jpkl")

        with open(full, "wb") as f:
            pickle.dump(self, f)

        print(f"Project saved as {full}")
    @classmethod
    def load_project(cls, path):
        """
        Load a previously saved COMPsc project from a pickle file.

        Parameters
        ----------
        path : str
            Full path to the pickled project file.

        Returns
        -------
        COMPsc
            The unpickled COMPsc object.

        Raises
        ------
        FileNotFoundError
            If the provided path does not exist.
        """

        if not os.path.exists(path):
            raise FileNotFoundError("File does not exist!")

        with open(path, "rb") as f:
            obj = pickle.load(f)

        return obj
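A save/load round trip, assuming the default working directory (the project name is hypothetical):

comp.save_project("my_analysis")                 # writes ./my_analysis.jpkl
comp = COMPsc.load_project("my_analysis.jpkl")   # restores the full object state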
    def reduce_cols(
        self,
        reg: str | None = None,
        full: str | None = None,
        name_slot: str = "cell_names",
        inc_set: bool = False,
    ):
        """
        Remove columns (cells) whose names contain the substring `reg` or
        exactly match the full name `full` from all available tables.

        Parameters
        ----------
        reg : str | None
            Substring to search for in column/cell names; matching columns will be removed.
            If not None, `full` must be None.

        full : str | None
            Full name to search for in column/cell names; matching columns will be removed.
            If not None, `reg` must be None.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        inc_set : bool, default False
            If True, column names are interpreted as 'cell_name # set' when matching.

        Update
        ------
        Mutates `self.input_data`, `self.normalized_data`, `self.input_metadata`,
        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist),
        removing columns/rows that match `reg` or `full`.

        Raises
        ------
        ValueError
            If neither or both of `reg` and `full` are provided, or if nothing
            matches the reduction mask.
        """

        if reg is None and full is None:
            raise ValueError(
                "Neither 'reg' nor 'full' was provided. Please provide exactly one of them!"
            )

        if reg is not None and full is not None:
            raise ValueError(
                "Both 'reg' and 'full' arguments were provided. "
                "Please provide only one of them!\n"
                "'reg' is used when a partial name should be matched.\n"
                "'full' is used when the full name should be matched."
            )

        use_set = inc_set
        if full is not None and inc_set and "#" not in full:
            # Fall back to plain names when 'full' carries no set information.
            print(
                "The set info (name # set) is missing from the 'full' argument although "
                "'inc_set' is True. Only the names will be compared, without considering "
                "the set information."
            )
            use_set = False

        def make_names(metadata):
            if use_set:
                return metadata[name_slot] + " # " + metadata["sets"]
            return metadata[name_slot]

        def keep_mask(names):
            # True marks columns/rows to keep.
            if reg is not None:
                mask = [reg.upper() not in str(x).upper() for x in names]
            else:
                mask = [full.upper() != str(x).upper() for x in names]
            if all(mask):
                raise ValueError("Nothing found to reduce")
            return mask

        if self.input_data is not None:
            self.input_data.columns = make_names(self.input_metadata)
            self.input_data = self.input_data.loc[:, keep_mask(self.input_data.columns)]

        if self.normalized_data is not None:
            self.normalized_data.columns = make_names(self.input_metadata)
            self.normalized_data = self.normalized_data.loc[
                :, keep_mask(self.normalized_data.columns)
            ]

        if self.input_metadata is not None:
            mask = keep_mask(make_names(self.input_metadata))
            self.input_metadata = self.input_metadata.loc[mask, :].reset_index(drop=True)

        if self.agg_normalized_data is not None:
            self.agg_normalized_data.columns = make_names(self.agg_metadata)
            self.agg_normalized_data = self.agg_normalized_data.loc[
                :, keep_mask(self.agg_normalized_data.columns)
            ]

        if self.agg_metadata is not None:
            mask = keep_mask(make_names(self.agg_metadata))
            self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(drop=True)

        self.gene_calculation()
        self.cells_calculation()
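Typical calls, with hypothetical labels: `reg` for substring matching, `full` (optionally as 'name # set') for exact matching:

comp.reduce_cols(reg="doublet")                         # drop every column containing "doublet"
comp.reduce_cols(full="Purkinje # Exp1", inc_set=True)  # drop one exact cell within one set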
    def reduce_rows(self, features_list: list):
        """
        Remove rows (features) whose names are included in `features_list`.

        Parameters
        ----------
        features_list : list
            List of features to search for in index/gene names; matching entries will be removed.

        Update
        ------
        Mutates `self.input_data`, `self.normalized_data`, and `self.agg_normalized_data`
        (if they exist), removing rows that match `features_list`.

        Raises
        ------
        ValueError
            If nothing matches the reduction mask.

        Notes
        -----
        Prints a message listing features that were not found in the data.
        """

        res = None

        if self.input_data is not None:

            res = find_features(self.input_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.input_data.index]

            if all(mask):
                raise ValueError("Nothing found to reduce")

            self.input_data = self.input_data.loc[mask, :]

        if self.normalized_data is not None:

            res = find_features(self.normalized_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.normalized_data.index]

            if all(mask):
                raise ValueError("Nothing found to reduce")

            self.normalized_data = self.normalized_data.loc[mask, :]

        if self.agg_normalized_data is not None:

            res = find_features(self.agg_normalized_data, features=features_list)

            res_list = [x.upper() for x in res["included"]]

            mask = [x.upper() not in res_list for x in self.agg_normalized_data.index]

            if all(mask):
                raise ValueError("Nothing found to reduce")

            self.agg_normalized_data = self.agg_normalized_data.loc[mask, :]

        if res is not None and len(res["not_included"]) > 0:
            print("\nFeatures not found:")
            for i in res["not_included"]:
                print(i)

        self.gene_calculation()
        self.cells_calculation()
    def get_data(self, set_info: bool = False):
        """
        Return normalized data with optional set annotation appended to column names.

        Parameters
        ----------
        set_info : bool, default False
            If True, column names are returned as "cell_name # set"; otherwise
            only the `cell_name` is used.

        Returns
        -------
        pandas.DataFrame
            The `self.normalized_data` table with columns renamed according to `set_info`.

        Raises
        ------
        AttributeError
            If `self.normalized_data` or `self.input_metadata` is missing.
        """

        # Work on a copy so the stored table keeps its original column labels.
        to_return = self.normalized_data.copy()

        if set_info:
            to_return.columns = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
        else:
            to_return.columns = self.input_metadata["cell_names"]

        return to_return
    def get_partial_data(
        self,
        names: list | str | None = None,
        features: list | str | None = None,
        name_slot: str = "cell_names",
        inc_metadata: bool = False,
    ):
        """
        Return a subset of the data filtered by sample names and/or feature names.

        Parameters
        ----------
        names : list, str, or None
            Names of samples to include. If None, all samples are considered.

        features : list, str, or None
            Names of features to include. If None, all features are considered.

        name_slot : str
            Column in metadata to use as sample names.

        inc_metadata : bool
            If True, return a tuple (data, metadata).

        Returns
        -------
        pandas.DataFrame
            Subset of the normalized data based on the specified names and features.
        """

        data = self.normalized_data.copy()
        metadata = self.input_metadata

        if name_slot in self.input_metadata.columns:
            data.columns = self.input_metadata[name_slot]
        else:
            raise ValueError("'name_slot' does not occur in the metadata columns!")

        if isinstance(features, str):
            features = [features]
        elif features is None:
            features = []

        if isinstance(names, str):
            names = [names]
        elif names is None:
            names = []

        features = [x.upper() for x in features]
        names = [x.upper() for x in names]

        columns_names = [x.upper() for x in data.columns]
        features_names = [x.upper() for x in data.index]

        columns_bool = [x in names for x in columns_names]
        features_bool = [x in features for x in features_names]

        if True not in columns_bool and True not in features_bool:
            print("Missing 'names' and/or 'features'. Returning full dataset instead.")
            if inc_metadata:
                return data, metadata
            return data

        # Apply both filters so names and features can be combined in one call.
        if True in columns_bool:
            data = data.loc[:, columns_bool]
            metadata = metadata.loc[columns_bool, :]

        if True in features_bool:
            data = data.loc[features_bool, :]

        if inc_metadata:
            return data, metadata
        return data
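Example subsets (the sample and gene names are hypothetical); both filters can be combined in one call:

sub = comp.get_partial_data(names=["Purkinje", "Granule"], features=["GAPDH", "ACTB"])
sub, meta = comp.get_partial_data(names=["Purkinje"], inc_metadata=True)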
    def get_metadata(self):
        """
        Return the stored input metadata.

        Returns
        -------
        pandas.DataFrame
            `self.input_metadata` (may be None if not set).
        """

        return self.input_metadata
    def gene_calculation(self):
        """
        Calculate and store per-cell counts (e.g., number of detected genes).

        The method binarizes the data (presence/absence) per cell and sums across
        features to produce `self.gene_calc`.

        Update
        ------
        Sets `self.gene_calc` as a pandas.Series.

        Side Effects
        ------------
        Uses `self.input_data` when available, otherwise `self.normalized_data`.
        """

        if self.input_data is not None:
            source = self.input_data
        elif self.normalized_data is not None:
            source = self.normalized_data
        else:
            return

        bin_col = source.copy()

        # Binarize: values > 0 become 1; zeros are kept as-is.
        bin_col = bin_col.where(bin_col <= 0, 1)

        self.gene_calc = bin_col.sum(axis=0)
    def gene_histograme(self, bins=100):
        """
        Plot a histogram of the number of genes detected per cell.

        Parameters
        ----------
        bins : int, default 100
            Number of histogram bins.

        Returns
        -------
        matplotlib.figure.Figure
            Figure containing the histogram of gene contents.

        Notes
        -----
        Requires `self.gene_calc` to be computed prior to calling.
        """

        fig, ax = plt.subplots(figsize=(8, 5))

        _, bin_edges, _ = ax.hist(
            self.gene_calc, bins=bins, edgecolor="black", alpha=0.6
        )

        mu, sigma = np.mean(self.gene_calc), np.std(self.gene_calc)

        # Overlay a normal density scaled to the histogram's counts.
        x = np.linspace(min(self.gene_calc), max(self.gene_calc), 1000)
        y = norm.pdf(x, mu, sigma)
        y_scaled = y * len(self.gene_calc) * (bin_edges[1] - bin_edges[0])

        ax.plot(
            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
        )

        ax.set_xlabel("Value")
        ax.set_ylabel("Count")
        ax.set_title("Histogram of genes detected per cell")

        ax.set_xticks(np.linspace(min(self.gene_calc), max(self.gene_calc), 20))
        ax.tick_params(axis="x", rotation=90)

        ax.legend()

        return fig
    def gene_threshold(self, min_n: int | None, max_n: int | None):
        """
        Filter cells by gene-detection thresholds (min and/or max).

        Parameters
        ----------
        min_n : int or None
            Minimum number of detected genes required to keep a cell.

        max_n : int or None
            Maximum number of detected genes allowed to keep a cell.

        Update
        ------
        Filters `self.input_data`, `self.normalized_data`, and `self.input_metadata`
        (and calls `average()` if `self.agg_normalized_data` exists).

        Raises
        ------
        ValueError
            If both bounds are None, or if the thresholds would not remove any cell.
        """

        if min_n is not None and max_n is not None:
            mask = (self.gene_calc > min_n) & (self.gene_calc < max_n)
        elif min_n is None and max_n is not None:
            mask = self.gene_calc < max_n
        elif min_n is not None and max_n is None:
            mask = self.gene_calc > min_n
        else:
            raise ValueError("Lack of both min_n and max_n values")

        # The mask marks cells to keep; if every cell passes, there is nothing to remove.
        if bool(mask.all()):
            raise ValueError("Nothing to reduce")

        if self.input_data is not None:
            self.input_data = self.input_data.loc[:, mask.values]

        if self.normalized_data is not None:
            self.normalized_data = self.normalized_data.loc[:, mask.values]

        if self.input_metadata is not None:
            self.input_metadata = self.input_metadata.loc[mask.values, :].reset_index(
                drop=True
            )
            self.input_metadata = self.input_metadata.drop(
                columns=["drop"], errors="ignore"
            )

        if self.agg_normalized_data is not None:
            self.average()

        self.gene_calculation()
        self.cells_calculation()
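A typical QC pass inspects the distribution before choosing bounds (the thresholds below are hypothetical and data-dependent):

comp.gene_calculation()                      # per-cell detected-gene counts -> comp.gene_calc
comp.gene_histograme(bins=50)                # eyeball the distribution first
comp.gene_threshold(min_n=200, max_n=6000)   # then trim low- and high-count cells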
    def cells_calculation(self, name_slot="cell_names"):
        """
        Calculate the number of cells per cell name / cluster.

        The method counts the occurrences of each cell name / cluster label
        across all cells.

        Parameters
        ----------
        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        Update
        ------
        Sets `self.cells_calc` as a pandas.DataFrame with 'cluster' and 'n' columns.
        """

        counts = pd.Series(list(self.input_metadata[name_slot])).value_counts()

        self.cells_calc = pd.DataFrame(
            {"cluster": counts.index, "n": counts.values}
        )
    def cell_histograme(self, name_slot: str = "cell_names"):
        """
        Plot a histogram of the number of cells detected per cell name (cluster).

        Parameters
        ----------
        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        Returns
        -------
        matplotlib.figure.Figure
            Figure containing the histogram of cell contents.

        Notes
        -----
        Requires `self.cells_calc` to be computed prior to calling.
        """

        if name_slot != "cell_names":
            self.cells_calculation(name_slot=name_slot)

        n_vals = list(self.cells_calc["n"])

        fig, ax = plt.subplots(figsize=(8, 5))

        _, bin_edges, _ = ax.hist(
            n_vals,
            bins=len(set(self.cells_calc["cluster"])),
            edgecolor="black",
            color="orange",
            alpha=0.6,
        )

        mu, sigma = np.mean(n_vals), np.std(n_vals)

        # Overlay a normal density scaled to the histogram's counts.
        x = np.linspace(min(n_vals), max(n_vals), 1000)
        y = norm.pdf(x, mu, sigma)
        y_scaled = y * len(n_vals) * (bin_edges[1] - bin_edges[0])

        ax.plot(
            x, y_scaled, "r-", linewidth=2, label=f"Normal(μ={mu:.2f}, σ={sigma:.2f})"
        )

        ax.set_xlabel("Value")
        ax.set_ylabel("Count")
        ax.set_title("Histogram of cells detected per cell name / cluster")

        ax.set_xticks(np.linspace(min(n_vals), max(n_vals), 20))
        ax.tick_params(axis="x", rotation=90)

        ax.legend()

        return fig
    def cluster_threshold(self, min_n: int | None, name_slot: str = "cell_names"):
        """
        Filter out cell names / clusters with fewer cells than a minimum threshold.

        Parameters
        ----------
        min_n : int or None
            Minimum number of cells required to keep a cell name / cluster.

        name_slot : str, default 'cell_names'
            Column in metadata to use as sample names.

        Update
        ------
        Filters `self.input_data`, `self.normalized_data`, `self.input_metadata`,
        `self.agg_normalized_data`, and `self.agg_metadata` (if they exist).
        """

        if name_slot != "cell_names":
            self.cells_calculation(name_slot=name_slot)

        if min_n is None:
            raise ValueError("Lack of min_n value")

        names = self.cells_calc["cluster"][self.cells_calc["n"] < min_n]

        if len(names) > 0:

            def keep_mask(labels):
                # Substring matching: a label is dropped when any undersized
                # cluster name occurs within it.
                return [not any(r in x for r in names) for x in labels]

            if self.input_data is not None:
                self.input_data.columns = self.input_metadata[name_slot]
                mask = keep_mask(self.input_data.columns)
                if not all(mask):
                    self.input_data = self.input_data.loc[:, mask]

            if self.normalized_data is not None:
                self.normalized_data.columns = self.input_metadata[name_slot]
                mask = keep_mask(self.normalized_data.columns)
                if not all(mask):
                    self.normalized_data = self.normalized_data.loc[:, mask]

            if self.input_metadata is not None:
                mask = keep_mask(self.input_metadata[name_slot])
                if not all(mask):
                    self.input_metadata = self.input_metadata.loc[mask, :].reset_index(
                        drop=True
                    )
                self.input_metadata = self.input_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

            if self.agg_normalized_data is not None:
                self.agg_normalized_data.columns = self.agg_metadata[name_slot]
                mask = keep_mask(self.agg_normalized_data.columns)
                if not all(mask):
                    self.agg_normalized_data = self.agg_normalized_data.loc[:, mask]

            if self.agg_metadata is not None:
                mask = keep_mask(self.agg_metadata[name_slot])
                if not all(mask):
                    self.agg_metadata = self.agg_metadata.loc[mask, :].reset_index(
                        drop=True
                    )
                self.agg_metadata = self.agg_metadata.drop(
                    columns=["drop"], errors="ignore"
                )

        self.gene_calculation()
        self.cells_calculation()
    def load_sparse_from_projects(self, normalized_data: bool = False):
        """
        Load sparse 10x-style datasets from stored project paths, concatenate them,
        and populate `input_data` / `normalized_data` and `input_metadata`.

        Parameters
        ----------
        normalized_data : bool, default False
            If True, store the concatenated tables in `self.normalized_data`.
            If False, store them in `self.input_data`; normalization is then
            required via the normalize_counts() method.

        Side Effects
        ------------
        - Reads each project using `load_sparse(...)` (expects matrix.mtx, genes.tsv, barcodes.tsv).
        - Concatenates all projects column-wise and sets `self.input_metadata`.
        - Replaces NaNs with zeros and updates `self.gene_calc`.
        """

        obj = self.objects

        full_data = pd.DataFrame()
        full_metadata = pd.DataFrame()

        for ke in obj.keys():
            print(ke)

            dt, met = load_sparse(path=obj[ke], name=ke)

            full_data = pd.concat([full_data, dt], axis=1)
            full_metadata = pd.concat([full_metadata, met], axis=0)

        full_data[np.isnan(full_data)] = 0

        if normalized_data:
            self.normalized_data = full_data
        else:
            self.input_data = full_data

        self.input_metadata = full_metadata

        self.gene_calculation()
        self.cells_calculation()
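Each project path is expected to hold the usual 10x triplet consumed by `load_sparse` (the directory name below is hypothetical):

# /data/runs/exp1_liver_fq_count/
#   matrix.mtx
#   genes.tsv
#   barcodes.tsv
comp.load_sparse_from_projects(normalized_data=False)  # raw counts -> comp.input_data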
    def rename_names(self, mapping: dict, slot: str = "cell_names"):
        """
        Rename entries in `self.input_metadata[slot]` according to a provided mapping.

        Parameters
        ----------
        mapping : dict
            Dictionary with keys 'old_name' and 'new_name', each mapping to a list
            of equal length describing replacements.

        slot : str, default 'cell_names'
            Metadata column to operate on.

        Update
        ------
        Updates `self.input_metadata[slot]` in place with the renamed values.

        Raises
        ------
        ValueError
            If mapping keys are incorrect, lengths differ, or some 'old_name' values
            are not present in the metadata column.
        """

        if set(["old_name", "new_name"]) != set(mapping.keys()):
            raise ValueError(
                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
                "each with a list of names to change."
            )

        if len(mapping["old_name"]) != len(mapping["new_name"]):
            raise ValueError(
                "Mapping dictionary lists 'old_name' and 'new_name' "
                "must have the same length!"
            )

        names_vector = list(self.input_metadata[slot])

        if not all(elem in names_vector for elem in list(mapping["old_name"])):
            raise ValueError(
                f"Some entries from 'old_name' do not exist in the names of slot {slot}."
            )

        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))

        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]

        self.input_metadata[slot] = names_vector_ret
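The mapping is positional: old_name[i] is replaced by new_name[i]. For example, with hypothetical cluster labels:

comp.rename_names(
    mapping={"old_name": ["cl_0", "cl_1"], "new_name": ["Purkinje", "Granule"]},
    slot="cell_names",
)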
    def rename_subclusters(self, mapping):
        """
        Rename labels stored in `self.subclusters_.subclusters` according to mapping.

        Parameters
        ----------
        mapping : dict
            Mapping with keys 'old_name' and 'new_name' (lists of equal length).

        Update
        ------
        Updates `self.subclusters_.subclusters` with the renamed labels.

        Raises
        ------
        ValueError
            If the mapping is invalid or old names are not present.
        """

        if set(["old_name", "new_name"]) != set(mapping.keys()):
            raise ValueError(
                "Mapping dictionary must contain keys 'old_name' and 'new_name', "
                "each with a list of names to change."
            )

        if len(mapping["old_name"]) != len(mapping["new_name"]):
            raise ValueError(
                "Mapping dictionary lists 'old_name' and 'new_name' "
                "must have the same length!"
            )

        names_vector = list(self.subclusters_.subclusters)

        if not all(elem in names_vector for elem in list(mapping["old_name"])):
            raise ValueError(
                "Some entries from 'old_name' do not exist in the subcluster names."
            )

        replace_dict = dict(zip(mapping["old_name"], mapping["new_name"]))

        names_vector_ret = [replace_dict.get(item, item) for item in names_vector]

        self.subclusters_.subclusters = names_vector_ret
    def save_sparse(
        self,
        path_to_save: str = os.getcwd(),
        name_slot: str = "cell_names",
        data_slot: str = "normalized",
    ):
        """
        Export data as 10x-compatible sparse files (matrix.mtx, barcodes.tsv, genes.tsv).

        Parameters
        ----------
        path_to_save : str, default current working directory
            Directory where files will be written.

        name_slot : str, default 'cell_names'
            Metadata column providing cell names for barcodes.tsv.

        data_slot : str, default 'normalized'
            Either 'normalized' (uses self.normalized_data) or 'count' (uses self.input_data).

        Raises
        ------
        ValueError
            If `data_slot` is neither 'normalized' nor 'count'.
        """

        names = self.input_metadata[name_slot]

        if data_slot.lower() == "normalized":
            features = list(self.normalized_data.index)
            mtx = sparse.csr_matrix(self.normalized_data)
        elif data_slot.lower() == "count":
            features = list(self.input_data.index)
            mtx = sparse.csr_matrix(self.input_data)
        else:
            raise ValueError("'data_slot' must be either 'normalized' or 'count'")

        os.makedirs(path_to_save, exist_ok=True)

        mmwrite(os.path.join(path_to_save, "matrix.mtx"), mtx)

        pd.Series(names).to_csv(
            os.path.join(path_to_save, "barcodes.tsv"),
            index=False,
            header=False,
            sep="\t",
        )

        pd.Series(features).to_csv(
            os.path.join(path_to_save, "genes.tsv"), index=False, header=False, sep="\t"
        )
    def normalize_counts(
        self, normalize_factor: int = 100000, log_transform: bool = True
    ):
        """
        Normalize raw counts to counts-per-`normalize_factor`
        (e.g., a CPM-like scaling, depending on `normalize_factor`).

        Parameters
        ----------
        normalize_factor : int, default 100000
            Scaling factor applied after dividing by column sums.

        log_transform : bool, default True
            If True, apply a log2(x + 1) transformation to the normalized values.

        Update
        ------
        Sets `self.normalized_data` to the normalized values (fills NaNs with 0).

        Raises
        ------
        ValueError
            If `self.input_data` is missing (cannot normalize).
        """
        if self.input_data is None:
            raise ValueError("Input data is missing, cannot normalize.")

        sum_col = self.input_data.sum()
        self.normalized_data = self.input_data.div(sum_col).fillna(0) * normalize_factor

        if log_transform:
            # log2(x + 1) to avoid -inf for zeros
            self.normalized_data = np.log2(self.normalized_data + 1)
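The arithmetic on a toy two-cell table (values chosen so the column sums are 100):

counts = pd.DataFrame({"cellA": [90, 10], "cellB": [30, 70]}, index=["g1", "g2"])
scaled = counts.div(counts.sum()) * 100000   # g1 in cellA: 90 / 100 * 1e5 = 90000.0
logged = np.log2(scaled + 1)                 # log2(90001) ≈ 16.46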
    def statistic(
        self,
        cells=None,
        sets=None,
        min_exp: float = 0.01,
        min_pct: float = 0.1,
        n_proc: int = 10,
    ):
        """
        Compute per-feature statistics (Mann–Whitney U) comparing target vs rest.

        This is a wrapper similar to `calc_DEG`, tailored to use `self.normalized_data`
        and `self.input_metadata`. It returns per-feature statistics including p-values,
        adjusted p-values, means, variances, effect-size measures, and fold-changes.

        Parameters
        ----------
        cells : list, 'All', dict, or None
            Defines the target cells or groups for comparison (several modes supported).

        sets : 'All', dict, or None
            Alternative grouping mode (operates on `self.input_metadata['sets']`).

        min_exp : float, default 0.01
            Minimum expression threshold used when filtering features.

        min_pct : float, default 0.1
            Minimum proportion of expressing cells in the target group required to test a feature.

        n_proc : int, default 10
            Number of parallel jobs to use.

        Returns
        -------
        pandas.DataFrame or dict
            Results DataFrame (or a dict containing valid/control cells plus the DataFrame),
            similar to the `calc_DEG` interface.

        Raises
        ------
        ValueError
            If neither `cells` nor `sets` is provided, or an input metadata mismatch occurs.

        Notes
        -----
        Multiple modes are supported: single-list entities, 'All', pairwise dicts, etc.
        """

        def stat_calc(choose, feature_name):
            target_values = choose.loc[choose["DEG"] == "target", feature_name]
            rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

            pct_valid = (target_values > 0).sum() / len(target_values)
            pct_rest = (rest_values > 0).sum() / len(rest_values)

            avg_valid = np.mean(target_values)
            avg_ctrl = np.mean(rest_values)
            sd_valid = np.std(target_values, ddof=1)
            sd_ctrl = np.std(rest_values, ddof=1)
            # Effect-size measure: standardized mean difference with pooled SD.
            esm = (avg_valid - avg_ctrl) / np.sqrt((sd_valid**2 + sd_ctrl**2) / 2)

            if np.sum(target_values) == np.sum(rest_values):
                p_val = 1.0
            else:
                _, p_val = stats.mannwhitneyu(
                    target_values, rest_values, alternative="two-sided"
                )

            return {
                "feature": feature_name,
                "p_val": p_val,
                "pct_valid": pct_valid,
                "pct_ctrl": pct_rest,
                "avg_valid": avg_valid,
                "avg_ctrl": avg_ctrl,
                "sd_valid": sd_valid,
                "sd_ctrl": sd_ctrl,
                "esm": esm,
            }

        def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc, factors):

            tmp_dat = choose[choose["DEG"] == "target"]
            tmp_dat = tmp_dat.drop("DEG", axis=1)

            counts = (tmp_dat > min_exp).sum(axis=0)
            total_count = tmp_dat.shape[0]

            info = pd.DataFrame(
                {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
            )

            del tmp_dat

            # Drop features expressed in too few target cells; if that would drop
            # everything, fall back to dropping only all-zero features.
            drop_col = info["feature"][info["pct"] <= min_pct]
            if len(drop_col) + 1 == len(choose.columns):
                drop_col = info["feature"][info["pct"] == 0]

            del info

            choose = choose.drop(list(drop_col), axis=1)

            results = Parallel(n_jobs=n_proc)(
                delayed(stat_calc)(choose, feature)
                for feature in tqdm(choose.columns[choose.columns != "DEG"])
            )

            df = pd.DataFrame(results)
            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # Benjamini-Hochberg-style adjustment on the rank-ordered p-values.
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            factors_df = factors.to_frame(name="factor")
            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")

            # Pseudocount-stabilized fold-change per feature.
            df["FC"] = (df["avg_valid"] + df["factor"]) / (df["avg_ctrl"] + df["factor"])
            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
            df = df.drop(columns=["factor"])

            return df

        choose = self.normalized_data.copy().T
        factors = self.normalized_data.copy().replace(0, np.nan).min(axis=1)
        factors[factors == 0] = min(factors[factors != 0])

        final_results = []

        if isinstance(cells, list) and sets is None:
            print("\nAnalysis started...\nComparing selected cells to the whole set...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            if "#" not in cells[0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "The set info (name # set) was not included in the 'cells' list. "
                    "Only the names will be compared, without considering the set information."
                )

            labels = ["target" if idx in cells else "rest" for idx in choose.index]
            valid = list(
                set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
            )

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=valid,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            return {"valid_cells": valid, "control_cells": "rest", "DEG": result_df}

        elif cells == "All" and sets is None:
            print("\nAnalysis started...\nComparing each type of cell to the others...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )
            unique_labels = set(choose.index)

            for label in tqdm(unique_labels):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]
                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                    factors=factors,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        elif cells is None and sets == "All":
            print("\nAnalysis started...\nComparing each set/group to the others...")
            choose.index = self.input_metadata["sets"]
            unique_sets = set(choose.index)

            for label in tqdm(unique_sets):
                print(f"\nCalculating statistics for {label}")
                labels = ["target" if idx == label else "rest" for idx in choose.index]

                choose["DEG"] = labels
                choose = choose[choose["DEG"] != "drop"]
                result_df = prepare_and_run_stat(
                    choose.copy(),
                    valid_group=label,
                    min_exp=min_exp,
                    min_pct=min_pct,
                    n_proc=n_proc,
                    factors=factors,
                )
                final_results.append(result_df)

            return pd.concat(final_results, ignore_index=True)

        elif cells is None and isinstance(sets, dict):
            print("\nAnalysis started...\nComparing groups...")

            choose.index = self.input_metadata["sets"]

            group_list = list(sets.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            labels = [
                (
                    "target"
                    if idx in sets[group_list[0]]
                    else "rest" if idx in sets[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            return result_df

        elif isinstance(cells, dict) and sets is None:
            print("\nAnalysis started...\nComparing groups...")
            choose.index = (
                self.input_metadata["cell_names"] + " # " + self.input_metadata["sets"]
            )

            if "#" not in cells[list(cells.keys())[0]][0]:
                choose.index = self.input_metadata["cell_names"]

                print(
                    "The set info (name # set) was not included in the 'cells' dict. "
                    "Only the names will be compared, without considering the set information."
                )

            group_list = list(cells.keys())
            if len(group_list) != 2:
                print("Only pairwise group comparison is supported.")
                return None

            labels = [
                (
                    "target"
                    if idx in cells[group_list[0]]
                    else "rest" if idx in cells[group_list[1]] else "drop"
                )
                for idx in choose.index
            ]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]

            result_df = prepare_and_run_stat(
                choose.reset_index(drop=True),
                valid_group=group_list[0],
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )

            return result_df.reset_index(drop=True)

        else:
            raise ValueError(
                "You must specify either 'cells' or 'sets' (or both). "
                "None were provided, which is not allowed for this analysis."
            )
    def calculate_difference_markers(
        self, min_exp=0, min_pct=0.25, n_proc=10, force: bool = False
    ):
        """
        Compute differential markers (var_data) if not already present.

        Parameters
        ----------
        min_exp : float, default 0
            Minimum expression threshold passed to `statistic`.

        min_pct : float, default 0.25
            Minimum percent expressed in target group.

        n_proc : int, default 10
            Parallel jobs.

        force : bool, default False
            If True, recompute even if `self.var_data` is present.

        Updates
        -------
        Sets `self.var_data` to the result of `self.statistic(...)`.

        Raises
        ------
        ValueError
            If already computed and `force` is False.
        """

        if self.var_data is None or force:
            self.var_data = self.statistic(
                cells="All", sets=None, min_exp=min_exp, min_pct=min_pct, n_proc=n_proc
            )
        else:
            raise ValueError(
                "self.calculate_difference_markers() has already been executed. "
                "The results are stored in self.var_data. "
                "If you want to recalculate with different statistics, please rerun the method with force=True."
            )
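
    # Usage sketch (hypothetical `project` object):
    #
    # >>> project.calculate_difference_markers(min_exp=0, min_pct=0.25, n_proc=4)
    # >>> project.var_data.head()  # one row per (valid_group, feature)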
    def clustering_features(
        self,
        features_list: list | None,
        name_slot: str = "cell_names",
        p_val: float = 0.05,
        top_n: int = 25,
        adj_mean: bool = True,
        beta: float = 0.2,
    ):
        """
        Prepare clustering input by selecting marker features and optionally smoothing cell values
        toward group means.

        Parameters
        ----------
        features_list : list or None
            If provided, use this list of features. If None, features are selected
            from `self.var_data` (adj_pval <= p_val, positive logFC) picking `top_n` per group.

        name_slot : str, default 'cell_names'
            Metadata column used for naming.

        p_val : float, default 0.05
            Adjusted p-value cutoff when selecting features automatically.

        top_n : int, default 25
            Number of top features per valid group to keep if `features_list` is None.

        adj_mean : bool, default True
            If True, adjust cell values toward group means using `beta`.

        beta : float, default 0.2
            Adjustment strength toward group mean.

        Updates
        -------
        Sets `self.clustering_data` and `self.clustering_metadata` to the selected subset,
        ready for PCA/UMAP/clustering.
        """

        if features_list is None or len(features_list) == 0:

            if self.var_data is None:
                raise ValueError(
                    "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
                )

            df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
            df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
            df_tmp = (
                df_tmp.sort_values(
                    ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
                )
                .groupby("valid_group")
                .head(top_n)
            )

            features_list = list(set(df_tmp["feature"]))

        data = self.get_partial_data(
            names=None, features=features_list, name_slot=name_slot
        )
        data_avg = average(data)

        if adj_mean:
            data = adjust_cells_to_group_mean(data=data, data_avg=data_avg, beta=beta)

        self.clustering_data = data
        self.clustering_metadata = self.input_metadata
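
    # Usage sketch (hypothetical; run calculate_difference_markers first so that
    # self.var_data exists, or pass an explicit features_list):
    #
    # >>> project.clustering_features(features_list=None, p_val=0.05, top_n=25)
    # >>> project.clustering_data.shape  # features x cells, ready for PCA/UMAP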
    def average(self):
        """
        Aggregate normalized data by (cell_name, set) pairs computing the mean per group.

        The method constructs new column names as "cell_name # set", averages columns
        sharing identical labels, and populates `self.agg_normalized_data` and `self.agg_metadata`.

        Updates
        -------
        Sets `self.agg_normalized_data` (features x aggregated samples) and
        `self.agg_metadata` (DataFrame with 'cell_names' and 'sets').
        """

        wide_data = self.normalized_data
        wide_metadata = self.input_metadata

        new_names = wide_metadata["cell_names"] + " # " + wide_metadata["sets"]
        wide_data.columns = list(new_names)

        aggregated_df = wide_data.T.groupby(wide_data.columns, axis=0).mean().T

        sets = [re.sub(".*# ", "", x) for x in aggregated_df.columns]
        names = [re.sub(" #.*", "", x) for x in aggregated_df.columns]

        aggregated_df.columns = names
        aggregated_metadata = pd.DataFrame({"cell_names": names, "sets": sets})

        self.agg_metadata = aggregated_metadata
        self.agg_normalized_data = aggregated_df
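
    # Usage sketch (hypothetical): collapse per-cell columns into per-(name, set)
    # mean profiles.
    #
    # >>> project.average()
    # >>> project.agg_normalized_data.iloc[:3, :3]
    # >>> project.agg_metadata.head()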
    def estimating_similarity(
        self, method="pearson", p_val: float = 0.05, top_n: int = 25
    ):
        """
        Estimate pairwise similarity and Euclidean distance between aggregated samples.

        Parameters
        ----------
        method : str, default 'pearson'
            Correlation method to use (passed to pandas.DataFrame.corr()).

        p_val : float, default 0.05
            Adjusted p-value cutoff used to select marker features from `self.var_data`.

        top_n : int, default 25
            Number of top features per valid group to include.

        Updates
        -------
        Computes a combined table with per-pair correlation and Euclidean distance
        and stores it in `self.similarity`.
        """

        if self.var_data is None:
            raise ValueError(
                "Lack of 'self.var_data'. Use self.calculate_difference_markers() method first."
            )

        if self.agg_normalized_data is None:
            self.average()

        metadata = self.agg_metadata
        data = self.agg_normalized_data

        df_tmp = self.var_data[self.var_data["adj_pval"] <= p_val]
        df_tmp = df_tmp[df_tmp["log(FC)"] > 0]
        df_tmp = (
            df_tmp.sort_values(
                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
            )
            .groupby("valid_group")
            .head(top_n)
        )

        data = data.loc[list(set(df_tmp["feature"]))]

        if len(set(metadata["sets"])) > 1:
            data.columns = data.columns + " # " + [x for x in metadata["sets"]]
        else:
            data = data.copy()

        scaler = StandardScaler()
        scaled_data = scaler.fit_transform(data)
        scaled_df = pd.DataFrame(scaled_data, columns=data.columns)

        cor = scaled_df.corr(method=method)
        cor_df = cor.stack().reset_index()
        cor_df.columns = ["cell1", "cell2", "correlation"]

        distances = pdist(scaled_df.T, metric="euclidean")
        dist_mat = pd.DataFrame(
            squareform(distances), index=scaled_df.columns, columns=scaled_df.columns
        )
        dist_df = dist_mat.stack().reset_index()
        dist_df.columns = ["cell1", "cell2", "euclidean_dist"]

        full = pd.merge(cor_df, dist_df, on=["cell1", "cell2"])

        full = full[full["cell1"] != full["cell2"]]
        full = full.reset_index(drop=True)

        self.similarity = full
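
    # Usage sketch (hypothetical; requires var_data; aggregation is run on demand):
    #
    # >>> project.estimating_similarity(method="pearson", p_val=0.05, top_n=25)
    # >>> project.similarity.sort_values("correlation", ascending=False).head()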
    def similarity_plot(
        self,
        split_sets=True,
        set_info: bool = True,
        cmap="seismic",
        width=12,
        height=10,
    ):
        """
        Visualize pairwise similarity as a scatter plot.

        Parameters
        ----------
        split_sets : bool, default True
            If True and set information is present, split the plotting area roughly
            into two halves to visualize cross-set pairs.

        set_info : bool, default True
            If True, keep the ' # set' annotation in labels; otherwise strip it.

        cmap : str, default 'seismic'
            Color map for correlation (hue).

        width : int, default 12
            Figure width.

        height : int, default 10
            Figure height.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If `self.similarity` is None.

        Notes
        -----
        The function filters pairs by z-scored Euclidean distance > 0 to focus on closer pairs.
        """

        if self.similarity is None:
            raise ValueError(
                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
            )

        similarity_data = self.similarity

        if " # " in similarity_data["cell1"][0]:
            similarity_data["set1"] = [
                re.sub(".*# ", "", x) for x in similarity_data["cell1"]
            ]
            similarity_data["set2"] = [
                re.sub(".*# ", "", x) for x in similarity_data["cell2"]
            ]

        if split_sets and " # " in similarity_data["cell1"][0]:
            sets = list(
                set(list(similarity_data["set1"]) + list(similarity_data["set2"]))
            )

            mm = math.ceil(len(sets) / 2)

            x_s = sets[0:mm]
            y_s = sets[mm : len(sets)]

            similarity_data = similarity_data[similarity_data["set1"].isin(x_s)]
            similarity_data = similarity_data[similarity_data["set2"].isin(y_s)]

            similarity_data = similarity_data.sort_values(["set1", "set2"])

        if set_info is False and " # " in similarity_data["cell1"][0]:
            similarity_data["cell1"] = [
                re.sub(" #.*", "", x) for x in similarity_data["cell1"]
            ]
            similarity_data["cell2"] = [
                re.sub(" #.*", "", x) for x in similarity_data["cell2"]
            ]

        similarity_data["-euclidean_zscore"] = -zscore(
            similarity_data["euclidean_dist"]
        )

        similarity_data = similarity_data[similarity_data["-euclidean_zscore"] > 0]

        fig = plt.figure(figsize=(width, height))
        sns.scatterplot(
            data=similarity_data,
            x="cell1",
            y="cell2",
            hue="correlation",
            size="-euclidean_zscore",
            sizes=(1, 100),
            palette=cmap,
            alpha=1,
            edgecolor="black",
        )

        plt.xticks(rotation=90)
        plt.yticks(rotation=0)
        plt.xlabel("Cell 1")
        plt.ylabel("Cell 2")
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

        plt.grid(True, alpha=0.6)
        plt.tight_layout()

        return fig
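
    # Usage sketch (hypothetical):
    #
    # >>> fig = project.similarity_plot(split_sets=True, set_info=True)
    # >>> fig.savefig("similarity.png", dpi=300, bbox_inches="tight")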
    def spatial_similarity(
        self,
        set_info: bool = True,
        bandwidth=1,
        n_neighbors=5,
        min_dist=0.1,
        legend_split=2,
        point_size=100,
        spread=1.0,
        set_op_mix_ratio=1.0,
        local_connectivity=1,
        repulsion_strength=1.0,
        negative_sample_rate=5,
        threshold=0.1,
        width=12,
        height=10,
    ):
        """
        Create a spatial UMAP-like visualization of similarity relationships between samples.

        Parameters
        ----------
        set_info : bool, default True
            If True, retain set information in labels.

        bandwidth : float, default 1
            Bandwidth used by MeanShift for clustering polygons.

        point_size : float, default 100
            Size of scatter points.

        legend_split : int, default 2
            Number of columns in the legend.

        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate :
            Parameters passed to UMAP.

        threshold : float, default 0.1
            Minimum text distance for label adjustment to avoid overlap.

        width : int, default 12
            Figure width.

        height : int, default 10
            Figure height.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If `self.similarity` is None.

        Notes
        -----
        Builds a precomputed distance matrix combining correlation and Euclidean distance,
        runs UMAP with metric='precomputed', then overlays cluster hulls (MeanShift + convex hull)
        and arrows to indicate nearest neighbors (minimal combined distance).
        """

        if self.similarity is None:
            raise ValueError(
                "Similarity data is missing. Please calculate similarity using self.estimating_similarity."
            )

        similarity_data = self.similarity

        sim = similarity_data["correlation"]
        sim_scaled = (sim - sim.min()) / (sim.max() - sim.min())
        eu_dist = similarity_data["euclidean_dist"]
        eu_dist_scaled = (eu_dist - eu_dist.min()) / (eu_dist.max() - eu_dist.min())

        similarity_data["combo_dist"] = (1 - sim_scaled) * eu_dist_scaled

        # nearest-neighbor targets for the arrows
        arrow_df = similarity_data.copy()
        arrow_df = similarity_data.loc[
            similarity_data.groupby("cell1")["combo_dist"].idxmin()
        ]

        cells = sorted(set(similarity_data["cell1"]) | set(similarity_data["cell2"]))
        combo_matrix = pd.DataFrame(0, index=cells, columns=cells, dtype=float)

        for _, row in similarity_data.iterrows():
            combo_matrix.loc[row["cell1"], row["cell2"]] = row["combo_dist"]
            combo_matrix.loc[row["cell2"], row["cell1"]] = row["combo_dist"]

        umap_model = umap.UMAP(
            n_components=2,
            metric="precomputed",
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            spread=spread,
            set_op_mix_ratio=set_op_mix_ratio,
            local_connectivity=local_connectivity,
            repulsion_strength=repulsion_strength,
            negative_sample_rate=negative_sample_rate,
            transform_seed=42,
            init="spectral",
            random_state=42,
            verbose=True,
        )

        coords = umap_model.fit_transform(combo_matrix.values)
        cell_names = list(combo_matrix.index)
        num_cells = len(cell_names)
        palette = sns.color_palette("tab20c", num_cells)

        if "#" in cell_names[0]:
            avsets = set(
                [re.sub(".*# ", "", x) for x in similarity_data["cell1"]]
                + [re.sub(".*# ", "", x) for x in similarity_data["cell2"]]
            )
            num_sets = len(avsets)
            color_indices = [i * len(palette) // num_sets for i in range(num_sets)]
            color_mapping_sets = {
                set_name: palette[i] for i, set_name in zip(color_indices, avsets)
            }
            color_mapping = {
                name: color_mapping_sets[re.sub(".*# ", "", name)]
                for i, name in enumerate(cell_names)
            }
        else:
            color_mapping = {name: palette[i] for i, name in enumerate(cell_names)}

        meanshift = MeanShift(bandwidth=bandwidth)
        labels = meanshift.fit_predict(coords)

        fig = plt.figure(figsize=(width, height))
        ax = plt.gca()

        unique_labels = set(labels)
        cluster_palette = sns.color_palette("hls", len(unique_labels))

        for label in unique_labels:
            if label == -1:
                continue
            cluster_coords = coords[labels == label]
            if len(cluster_coords) < 3:
                continue

            hull = ConvexHull(cluster_coords)
            hull_points = cluster_coords[hull.vertices]

            centroid = np.mean(hull_points, axis=0)
            expanded = hull_points + 0.05 * (hull_points - centroid)

            poly = Polygon(
                expanded,
                closed=True,
                facecolor=cluster_palette[label],
                edgecolor="none",
                alpha=0.2,
                zorder=1,
            )
            ax.add_patch(poly)

        texts = []
        for i, (x, y) in enumerate(coords):
            plt.scatter(
                x,
                y,
                s=point_size,
                color=color_mapping[cell_names[i]],
                edgecolors="black",
                linewidths=0.5,
                zorder=2,
            )
            texts.append(
                ax.text(
                    x, y, str(i), ha="center", va="center", fontsize=8, color="black"
                )
            )

        def dist(p1, p2):
            return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)

        texts_to_adjust = []
        for i, t1 in enumerate(texts):
            for j, t2 in enumerate(texts):
                if i >= j:
                    continue
                d = dist(
                    (t1.get_position()[0], t1.get_position()[1]),
                    (t2.get_position()[0], t2.get_position()[1]),
                )
                if d < threshold:
                    if t1 not in texts_to_adjust:
                        texts_to_adjust.append(t1)
                    if t2 not in texts_to_adjust:
                        texts_to_adjust.append(t2)

        adjust_text(
            texts_to_adjust,
            expand_text=(1.0, 1.0),
            force_text=0.9,
            arrowprops=dict(arrowstyle="-", color="gray", lw=0.1),
            ax=ax,
        )

        for _, row in arrow_df.iterrows():
            try:
                idx1 = cell_names.index(row["cell1"])
                idx2 = cell_names.index(row["cell2"])
            except ValueError:
                continue
            x1, y1 = coords[idx1]
            x2, y2 = coords[idx2]
            arrow = FancyArrowPatch(
                (x1, y1),
                (x2, y2),
                arrowstyle="->",
                color="gray",
                linewidth=1.5,
                alpha=0.5,
                mutation_scale=12,
                zorder=0,
            )
            ax.add_patch(arrow)

        if set_info is False and " # " in cell_names[0]:
            legend_elements = [
                Patch(
                    facecolor=color_mapping[name],
                    edgecolor="black",
                    label=f"{i} – {re.sub(' #.*', '', name)}",
                )
                for i, name in enumerate(cell_names)
            ]
        else:
            legend_elements = [
                Patch(
                    facecolor=color_mapping[name],
                    edgecolor="black",
                    label=f"{i} – {name}",
                )
                for i, name in enumerate(cell_names)
            ]

        plt.legend(
            handles=legend_elements,
            title="Cells",
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            ncol=legend_split,
        )

        plt.xlabel("UMAP 1")
        plt.ylabel("UMAP 2")
        plt.grid(False)
        plt.show()

        return fig
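
    # Usage sketch (hypothetical; embeds the combined correlation/distance matrix
    # with UMAP metric='precomputed' and overlays MeanShift hulls):
    #
    # >>> fig = project.spatial_similarity(bandwidth=1.5, n_neighbors=5,
    # ...                                  min_dist=0.1, set_info=False)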
    def subcluster_prepare(self, features: list, cluster: str):
        """
        Prepare a `Clustering` object for subcluster analysis on a selected parent cluster.

        Parameters
        ----------
        features : list
            Features to include for subcluster analysis.

        cluster : str
            Parent cluster name (used to select matching cells).

        Updates
        -------
        Initializes `self.subclusters_` as a new `Clustering` instance containing the
        reduced data for the given cluster and stores `current_features` and `current_cluster`.
        """

        dat = self.normalized_data
        dat.columns = list(self.input_metadata["cell_names"])

        dat = reduce_data(self.normalized_data, features=features, names=[cluster])

        self.subclusters_ = Clustering(data=dat, metadata=None)

        self.subclusters_.current_features = features
        self.subclusters_.current_cluster = cluster
    def define_subclusters(
        self,
        umap_num: int = 2,
        eps: float = 0.5,
        min_samples: int = 10,
        n_neighbors: int = 5,
        min_dist: float = 0.1,
        spread: float = 1.0,
        set_op_mix_ratio: float = 1.0,
        local_connectivity: int = 1,
        repulsion_strength: float = 1.0,
        negative_sample_rate: int = 5,
        width=8,
        height=6,
    ):
        """
        Compute UMAP and DBSCAN clustering within a previously prepared subcluster dataset.

        Parameters
        ----------
        umap_num : int, default 2
            Number of UMAP dimensions to compute.

        eps : float, default 0.5
            DBSCAN eps parameter.

        min_samples : int, default 10
            DBSCAN min_samples parameter.

        n_neighbors, min_dist, spread, set_op_mix_ratio, local_connectivity, repulsion_strength, negative_sample_rate, width, height :
            Additional parameters passed to UMAP / plotting as appropriate.

        Updates
        -------
        Stores cluster labels in `self.subclusters_.subclusters`.

        Raises
        ------
        RuntimeError
            If `self.subclusters_` has not been prepared.
        """

        if self.subclusters_ is None:
            raise RuntimeError(
                "Nothing to return. 'self.subcluster_prepare' has not been run!"
            )

        self.subclusters_.perform_UMAP(
            factorize=False,
            umap_num=umap_num,
            pc_num=0,
            harmonized=False,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            spread=spread,
            set_op_mix_ratio=set_op_mix_ratio,
            local_connectivity=local_connectivity,
            repulsion_strength=repulsion_strength,
            negative_sample_rate=negative_sample_rate,
            width=width,
            height=height,
        )

        fig = self.subclusters_.find_clusters_UMAP(
            umap_n=umap_num,
            eps=eps,
            min_samples=min_samples,
            width=width,
            height=height,
        )

        clusters = self.subclusters_.return_clusters(clusters="umap")

        self.subclusters_.subclusters = [str(x) for x in list(clusters)]

        return fig
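
    # Usage sketch (hypothetical two-step subcluster pipeline; marker names are
    # made up):
    #
    # >>> project.subcluster_prepare(features=["GeneA", "GeneB"], cluster="Purkinje")
    # >>> fig = project.define_subclusters(umap_num=2, eps=0.5, min_samples=10)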
    def subcluster_features_scatter(
        self,
        colors="viridis",
        hclust="complete",
        scale=False,
        img_width=3,
        img_high=5,
        label_size=6,
        size_scale=70,
        y_lab="Genes",
        legend_lab="normalized",
        bbox_to_anchor_scale: int = 25,
        bbox_to_anchor_perc: tuple = (0.91, 0.63),
    ):
        """
        Create a features-scatter visualization for the subclusters (averaged and occurrence).

        Parameters
        ----------
        colors : str, default 'viridis'
            Colormap name passed to `features_scatter`.

        hclust : str or None
            Hierarchical clustering linkage to order rows/columns.

        scale : bool, default False
            If True, expression data will be scaled (0–1) across the rows (features).

        img_width, img_high : float
            Figure size.

        label_size : int
            Font size for labels.

        size_scale : int
            Bubble size scaling.

        y_lab : str
            X-axis label.

        legend_lab : str
            Colorbar label.

        bbox_to_anchor_scale : int, default=25
            Vertical scale (percentage) for positioning the colorbar.

        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
            Anchor position for the size legend (percent bubble legend).

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        RuntimeError
            If subcluster preparation/definition has not been run.
        """

        if self.subclusters_ is None or self.subclusters_.subclusters is None:
            raise RuntimeError(
                "Nothing to return. The 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not conducted!"
            )

        dat = self.normalized_data
        dat.columns = list(self.input_metadata["cell_names"])

        dat = reduce_data(
            self.normalized_data,
            features=self.subclusters_.current_features,
            names=[self.subclusters_.current_cluster],
        )

        dat.columns = self.subclusters_.subclusters

        avg = average(dat)
        occ = occurrence(dat)

        scatter = features_scatter(
            expression_data=avg,
            occurence_data=occ,
            features=None,
            scale=scale,
            metadata_list=None,
            colors=colors,
            hclust=hclust,
            img_width=img_width,
            img_high=img_high,
            label_size=label_size,
            size_scale=size_scale,
            y_lab=y_lab,
            legend_lab=legend_lab,
            bbox_to_anchor_scale=bbox_to_anchor_scale,
            bbox_to_anchor_perc=bbox_to_anchor_perc,
        )

        return scatter
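
    # Usage sketch (hypothetical; requires the subcluster pipeline above):
    #
    # >>> fig = project.subcluster_features_scatter(colors="viridis", scale=True)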
    def subcluster_DEG_scatter(
        self,
        top_n=3,
        min_exp=0,
        min_pct=0.25,
        p_val=0.05,
        colors="viridis",
        hclust="complete",
        scale=False,
        img_width=3,
        img_high=5,
        label_size=6,
        size_scale=70,
        y_lab="Genes",
        legend_lab="normalized",
        bbox_to_anchor_scale: int = 25,
        bbox_to_anchor_perc: tuple = (0.91, 0.63),
        n_proc=10,
    ):
        """
        Plot top differential features (DEGs) for subclusters as a features-scatter.

        Parameters
        ----------
        top_n : int, default 3
            Number of top features per subcluster to show.

        min_exp : float, default 0
            Minimum expression threshold passed to `calc_DEG`.

        min_pct : float, default 0.25
            Minimum percent expressed in target group.

        p_val : float, default 0.05
            Maximum p-value for visualizing features.

        n_proc : int, default 10
            Parallel jobs used for DEG calculation.

        scale : bool, default False
            If True, expression_data will be scaled (0–1) across the rows (features).

        colors : str, default='viridis'
            Colormap for expression values.

        hclust : str or None, default='complete'
            Linkage method for hierarchical clustering. If None, no clustering
            is performed.

        img_width : int or float, default=3
            Width of the plot in inches.

        img_high : int or float, default=5
            Height of the plot in inches.

        label_size : int, default=6
            Font size for axis labels and ticks.

        size_scale : int or float, default=70
            Scaling factor for bubble sizes.

        y_lab : str, default='Genes'
            Label for the x-axis.

        legend_lab : str, default='normalized'
            Label for the colorbar legend.

        bbox_to_anchor_scale : int, default=25
            Vertical scale (percentage) for positioning the colorbar.

        bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
            Anchor position for the size legend (percent bubble legend).

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        RuntimeError
            If subcluster preparation/definition has not been run.

        Notes
        -----
        Internally calls `calc_DEG` to obtain statistics, filters by p-value and
        effect size, selects top features per valid group, and plots them.
        """

        if self.subclusters_ is None or self.subclusters_.subclusters is None:
            raise RuntimeError(
                "Nothing to return. The 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not conducted!"
            )

        dat = self.normalized_data
        dat.columns = list(self.input_metadata["cell_names"])

        dat = reduce_data(
            self.normalized_data, names=[self.subclusters_.current_cluster]
        )

        dat.columns = self.subclusters_.subclusters

        deg_stats = calc_DEG(
            dat,
            metadata_list=None,
            entities="All",
            sets=None,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        deg_stats = deg_stats[deg_stats["p_val"] <= p_val]
        deg_stats = deg_stats[deg_stats["log(FC)"] > 0]

        deg_stats = (
            deg_stats.sort_values(
                ["valid_group", "esm", "log(FC)"], ascending=[True, False, False]
            )
            .groupby("valid_group")
            .head(top_n)
        )

        dat = reduce_data(dat, features=list(set(deg_stats["feature"])))

        avg = average(dat)
        occ = occurrence(dat)

        scatter = features_scatter(
            expression_data=avg,
            occurence_data=occ,
            features=None,
            scale=scale,
            metadata_list=None,
            colors=colors,
            hclust=hclust,
            img_width=img_width,
            img_high=img_high,
            label_size=label_size,
            size_scale=size_scale,
            y_lab=y_lab,
            legend_lab=legend_lab,
            bbox_to_anchor_scale=bbox_to_anchor_scale,
            bbox_to_anchor_perc=bbox_to_anchor_perc,
        )

        return scatter
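
    # Usage sketch (hypothetical; ranks subcluster markers with calc_DEG and plots
    # the top hits per subcluster):
    #
    # >>> fig = project.subcluster_DEG_scatter(top_n=3, p_val=0.05, n_proc=4)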
    def accept_subclusters(self):
        """
        Commit subcluster labels into the main `input_metadata` by renaming cell names.

        The method replaces occurrences of the parent cluster name in `self.input_metadata['cell_names']`
        with the expanded names that include subcluster suffixes (via `add_subnames`),
        then clears `self.subclusters_`.

        Updates
        -------
        Modifies `self.input_metadata['cell_names']`.

        Resets `self.subclusters_` to None.

        Raises
        ------
        RuntimeError
            If `self.subclusters_` is not defined or subclusters were not computed.
        """

        if self.subclusters_ is None or self.subclusters_.subclusters is None:
            raise RuntimeError(
                "Nothing to return. The 'self.subcluster_prepare' -> 'self.define_subclusters' pipeline was not conducted!"
            )

        new_meta = add_subnames(
            list(self.input_metadata["cell_names"]),
            parent_name=self.subclusters_.current_cluster,
            new_clusters=self.subclusters_.subclusters,
        )

        self.input_metadata["cell_names"] = new_meta

        self.subclusters_ = None
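
    # Usage sketch (hypothetical end of the subcluster workflow): writes the
    # subcluster-suffixed names (via add_subnames) back into
    # input_metadata['cell_names'] and clears the helper object.
    #
    # >>> project.accept_subclusters()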
    def scatter_plot(
        self,
        names: list | None = None,
        features: list | None = None,
        name_slot: str = "cell_names",
        scale=True,
        colors="viridis",
        hclust=None,
        img_width=15,
        img_high=1,
        label_size=10,
        size_scale=200,
        y_lab="Genes",
        legend_lab="log(CPM + 1)",
        set_box_size: float | int = 5,
        set_box_high: float | int = 5,
        bbox_to_anchor_scale=25,
        bbox_to_anchor_perc=(0.90, -0.24),
        bbox_to_anchor_group=(1.01, 0.4),
    ):
        """
        Create a bubble scatter plot of selected features across samples inside the project.

        Each point represents a feature-sample pair, where the color encodes the
        expression value and the size encodes occurrence or relative abundance.
        Optionally, hierarchical clustering can be applied to order rows and columns.

        Parameters
        ----------
        names : list, str, or None
            Names of samples to include. If None, all samples are considered.

        features : list, str, or None
            Names of features to include. If None, all features are considered.

        name_slot : str
            Column in metadata to use as sample names.

        scale : bool, default True
            If True, expression_data will be scaled (0–1) across the rows (features).

        colors : str, default='viridis'
            Colormap for expression values.

        hclust : str or None, default=None
            Linkage method for hierarchical clustering. If None, no clustering
            is performed.

        img_width : int or float, default=15
            Width of the plot in inches.

        img_high : int or float, default=1
            Height of the plot in inches.

        label_size : int, default=10
            Font size for axis labels and ticks.

        size_scale : int or float, default=200
            Scaling factor for bubble sizes.

        y_lab : str, default='Genes'
            Label for the x-axis.

        legend_lab : str, default='log(CPM + 1)'
            Label for the colorbar legend.

        set_box_size : float or int, default=5
            Passed to `features_scatter`; size of the set annotation boxes.

        set_box_high : float or int, default=5
            Passed to `features_scatter`; height of the set annotation boxes.

        bbox_to_anchor_scale : int, default=25
            Vertical scale (percentage) for positioning the colorbar.

        bbox_to_anchor_perc : tuple, default=(0.90, -0.24)
            Anchor position for the size legend (percent bubble legend).

        bbox_to_anchor_group : tuple, default=(1.01, 0.4)
            Anchor position for the group legend.

        Returns
        -------
        matplotlib.figure.Figure
            The generated scatter plot figure.

        Notes
        -----
        Colors represent expression values normalized to the colormap.
        """

        prtd, met = self.get_partial_data(
            names=names, features=features, name_slot=name_slot, inc_metadata=True
        )

        if scale:
            legend_lab = "Scaled\n" + legend_lab

            scaler = MinMaxScaler(feature_range=(0, 1))
            prtd = pd.DataFrame(
                scaler.fit_transform(prtd.T).T,
                index=prtd.index,
                columns=prtd.columns,
            )

        prtd.columns = prtd.columns + "#" + met["sets"]

        prtd_avg = average(prtd)

        meta_sets = [re.sub(".*#", "", x) for x in prtd_avg.columns]

        prtd_avg.columns = [re.sub("#.*", "", x) for x in prtd_avg.columns]

        prtd_occ = occurrence(prtd)

        prtd_occ.columns = [re.sub("#.*", "", x) for x in prtd_occ.columns]

        fig_scatter = features_scatter(
            expression_data=prtd_avg,
            occurence_data=prtd_occ,
            scale=scale,
            features=None,
            metadata_list=meta_sets,
            colors=colors,
            hclust=hclust,
            img_width=img_width,
            img_high=img_high,
            label_size=label_size,
            size_scale=size_scale,
            y_lab=y_lab,
            legend_lab=legend_lab,
            set_box_size=set_box_size,
            set_box_high=set_box_high,
            bbox_to_anchor_scale=bbox_to_anchor_scale,
            bbox_to_anchor_perc=bbox_to_anchor_perc,
            bbox_to_anchor_group=bbox_to_anchor_group,
        )

        return fig_scatter
    def data_composition(
        self,
        features_count: list | None,
        name_slot: str = "cell_names",
        set_sep: bool = True,
    ):
        """
        Compute the composition of cell types in the data set.

        This function counts the occurrences of specific cells (e.g., cell types, subtypes)
        within metadata entries, calculates their relative percentages, and stores
        the results in `self.composition_data`.

        Parameters
        ----------
        features_count : list or None
            List of features (partial or full names) to be counted.
            If None, all unique elements from the specified `name_slot` metadata field are used.

        name_slot : str, default 'cell_names'
            Metadata field containing sample identifiers or labels.

        set_sep : bool, default True
            If True and multiple sets are present in metadata, compute composition
            separately for each set.

        Updates
        -------
        Stores results in `self.composition_data` as a pandas DataFrame with:
        - 'name': feature name
        - 'n': number of occurrences
        - 'pct': percentage of occurrences
        - 'set' (if applicable): dataset identifier
        """

        validated_list = list(self.input_metadata[name_slot])
        sets = list(self.input_metadata["sets"])

        if features_count is None:
            features_count = list(set(self.input_metadata[name_slot]))

        if set_sep and len(set(sets)) > 1:

            final_res = pd.DataFrame()

            for s in set(sets):
                print(s)

                mask = [s == x for x in sets]

                tmp_val_list = np.array(validated_list)
                tmp_val_list = list(tmp_val_list[mask])

                res_dict = {"name": [], "n": [], "set": []}

                for f in tqdm(features_count):
                    res_dict["n"].append(
                        sum(1 for element in tmp_val_list if f in element)
                    )
                    res_dict["name"].append(f)
                    res_dict["set"].append(s)
                res = pd.DataFrame(res_dict)
                res["pct"] = res["n"] / sum(res["n"]) * 100
                res["pct"] = res["pct"].round(2)

                final_res = pd.concat([final_res, res])

            res = final_res.sort_values(["set", "pct"], ascending=[True, False])

        else:

            res_dict = {"name": [], "n": []}

            for f in tqdm(features_count):
                res_dict["n"].append(
                    sum(1 for element in validated_list if f in element)
                )
                res_dict["name"].append(f)

            res = pd.DataFrame(res_dict)
            res["pct"] = res["n"] / sum(res["n"]) * 100
            res["pct"] = res["pct"].round(2)

            res = res.sort_values("pct", ascending=False)

        self.composition_data = res
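
    # Usage sketch (hypothetical labels; matching is by substring, so 'Purkinje'
    # also counts subcluster names such as 'Purkinje_1'):
    #
    # >>> project.data_composition(features_count=["Purkinje", "Granule"])
    # >>> project.composition_data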
    def composition_pie(
        self,
        width=6,
        height=6,
        font_size=15,
        cmap: str = "tab20",
        legend_split_col: int = 1,
        offset_labels: float | int = 0.5,
        legend_bbox: tuple = (1.15, 0.95),
    ):
        """
        Visualize the composition of cell lineages using pie charts.

        Generates pie charts showing the relative proportions of features stored
        in `self.composition_data`. If multiple sets are present, a separate
        chart is drawn for each set.

        Parameters
        ----------
        width : int, default 6
            Width of the figure.

        height : int, default 6
            Height of the figure (applied per set if multiple sets are plotted).

        font_size : int, default 15
            Font size for labels and annotations.

        cmap : str, default 'tab20'
            Colormap used for pie slices.

        legend_split_col : int, default 1
            Number of columns in the legend.

        offset_labels : float or int, default 0.5
            Spacing offset for label placement relative to pie slices.

        legend_bbox : tuple, default (1.15, 0.95)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Pie chart visualization of composition data.
        """

        df = self.composition_data

        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(len(sets), 1, figsize=(width, height * len(sets)))

            all_wedges = []
            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))
            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]
            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                ax = axes[idx]
                tmp_df = df[df["set"] == s].reset_index(drop=True)

                labels = [f"{row['pct']:.1f}%" for _, row in tmp_df.iterrows()]

                wedges, _ = ax.pie(
                    tmp_df["n"],
                    startangle=90,
                    labeldistance=1.05,
                    colors=[cmap_dict[x] for x in tmp_df["name"]],
                    wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
                )

                all_wedges.extend(wedges)

                kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
                n = 0
                for i, p in enumerate(wedges):
                    ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
                    y = np.sin(np.deg2rad(ang))
                    x = np.cos(np.deg2rad(ang))
                    horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                    connectionstyle = f"angle,angleA=0,angleB={ang}"
                    kw["arrowprops"].update({"connectionstyle": connectionstyle})
                    if len(labels[i]) > 0:
                        n += offset_labels
                        ax.annotate(
                            labels[i],
                            xy=(x, y),
                            xytext=(1.01 * x + (n * x / 4), 1.01 * y + (n * y / 4)),
                            horizontalalignment=horizontalalignment,
                            fontsize=font_size,
                            weight="bold",
                            **kw,
                        )

                circle2 = plt.Circle((0, 0), 0.6, color="white", ec="black")
                ax.add_artist(circle2)

                ax.text(
                    0,
                    0,
                    f"{s}",
                    ha="center",
                    va="center",
                    fontsize=font_size,
                    weight="bold",
                )

            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            labels = [f"{row['pct']:.1f}%" for _, row in df.iterrows()]
            legend_labels = [f"{row['name']}" for _, row in df.iterrows()]

            cmap = plt.get_cmap(cmap)
            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(
                figsize=(width, height), subplot_kw=dict(aspect="equal")
            )

            wedges, _ = ax.pie(
                df["n"],
                startangle=90,
                labeldistance=1.05,
                colors=colors,
                wedgeprops={"linewidth": 0.5, "edgecolor": "black"},
            )

            kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center")
            n = 0
            for i, p in enumerate(wedges):
                ang = (p.theta2 - p.theta1) / 2.0 + p.theta1
                y = np.sin(np.deg2rad(ang))
                x = np.cos(np.deg2rad(ang))
                horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))]
                connectionstyle = "angle,angleA=0,angleB={}".format(ang)
                kw["arrowprops"].update({"connectionstyle": connectionstyle})
                if len(labels[i]) > 0:
                    n += offset_labels

                    ax.annotate(
                        labels[i],
                        xy=(x, y),
                        xytext=(1.01 * x + (n * x / 4), y * 1.01 + (n * y / 4)),
                        horizontalalignment=horizontalalignment,
                        fontsize=font_size,
                        weight="bold",
                        **kw,
                    )

            circle2 = plt.Circle((0, 0), 0.6, color="white")
            circle2.set_edgecolor("black")

            p = plt.gcf()
            p.gca().add_artist(circle2)

            ax.legend(
                wedges,
                legend_labels,
                title="",
                loc="center left",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
            )

            plt.show()

        return fig
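
    # Usage sketch (hypothetical; run data_composition first):
    #
    # >>> project.data_composition(features_count=None)
    # >>> fig = project.composition_pie(cmap="tab20", legend_split_col=2)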
    def bar_composition(
        self,
        cmap="tab20b",
        width=2,
        height=6,
        font_size=15,
        legend_split_col: int = 1,
        legend_bbox: tuple = (1.3, 1),
    ):
        """
        Visualize the composition of cell lineages using bar plots.

        Produces bar plots showing the distribution of features stored in
        `self.composition_data`. If multiple sets are present, a separate
        bar is drawn for each set. Percentages are annotated alongside the bars.

        Parameters
        ----------
        cmap : str, default 'tab20b'
            Colormap used for stacked bars.

        width : int, default 2
            Width of each subplot (per set).

        height : int, default 6
            Height of the figure.

        font_size : int, default 15
            Font size for labels and annotations.

        legend_split_col : int, default 1
            Number of columns in the legend.

        legend_bbox : tuple, default (1.3, 1)
            Bounding box anchor position for the legend.

        Returns
        -------
        matplotlib.figure.Figure
            Stacked bar plot visualization of composition data.
        """

        df = self.composition_data
        df["num"] = range(1, len(df) + 1)

        if "set" in df.columns and len(set(df["set"])) > 1:

            sets = list(set(df["set"]))
            fig, axes = plt.subplots(1, len(sets), figsize=(width * len(sets), height))

            cmap = plt.get_cmap(cmap)

            set_nam = len(set(df["name"]))
            legend_labels = list(set(df["name"]))

            colors = [cmap(i / set_nam) for i in range(set_nam)]
            cmap_dict = dict(zip(legend_labels, colors))

            for idx, s in enumerate(sets):
                ax = axes[idx]

                tmp_df = df[df["set"] == s].reset_index(drop=True)

                values = tmp_df["n"].values
                total = sum(values)
                values = [v / total * 100 for v in values]
                values = [round(v, 2) for v in values]

                idx_max = np.argmax(values)
                correction = 100 - sum(values)
                values[idx_max] += correction

                names = tmp_df["name"].values
                perc = tmp_df["pct"].values
                nums = tmp_df["num"].values

                bottom = 0
                centers = []
                for name, num, val, color in zip(names, nums, values, colors):
                    ax.bar(s, val, bottom=bottom, color=cmap_dict[name], label=name)
                    centers.append(bottom + val / 2)
                    bottom += val

                y_positions = np.linspace(centers[0], centers[-1], len(centers))
                x_text = -0.8

                for y_label, y_center, pct, num in zip(
                    y_positions, centers, perc, nums
                ):
                    ax.annotate(
                        f"{pct:.1f}%",
                        xy=(0, y_center),
                        xycoords="data",
                        xytext=(x_text, y_label),
                        textcoords="data",
                        ha="right",
                        va="center",
                        fontsize=font_size,
                        arrowprops=dict(
                            arrowstyle="->",
                            lw=1,
                            color="black",
                            connectionstyle="angle3,angleA=0,angleB=90",
                        ),
                    )

                ax.set_ylim(0, 100)
                ax.set_xlabel(s, fontsize=font_size)
                ax.xaxis.label.set_rotation(30)

                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

            legend_handles = [
                Patch(facecolor=cmap_dict[label], edgecolor="black", label=label)
                for label in legend_labels
            ]

            fig.legend(
                handles=legend_handles,
                loc="center right",
                bbox_to_anchor=legend_bbox,
                ncol=legend_split_col,
                title="",
            )

            plt.tight_layout()
            plt.show()

        else:

            cmap = plt.get_cmap(cmap)
            colors = [cmap(i / len(df)) for i in range(len(df))]

            fig, ax = plt.subplots(figsize=(width, height))

            values = df["n"].values
            names = df["name"].values
            perc = df["pct"].values
            nums = df["num"].values

            bottom = 0
            centers = []
            for name, num, val, color in zip(names, nums, values, colors):
                ax.bar(0, val, bottom=bottom, color=color, label=f"{num}) {name}")
                centers.append(bottom + val / 2)
                bottom += val

            y_positions = np.linspace(centers[0], centers[-1], len(centers))
            x_text = -0.8

            for y_label, y_center, pct, num in zip(y_positions, centers, perc, nums):
                ax.annotate(
                    f"{num}) {pct}",
                    xy=(0, y_center),
                    xycoords="data",
                    xytext=(x_text, y_label),
                    textcoords="data",
                    ha="right",
                    va="center",
                    fontsize=9,
                    arrowprops=dict(
                        arrowstyle="->",
                        lw=1,
                        color="black",
                        connectionstyle="angle3,angleA=0,angleB=90",
                    ),
                )

            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            ax.legend(
                title="Legend",
                bbox_to_anchor=legend_bbox,
                loc="upper left",
                ncol=legend_split_col,
            )

            plt.tight_layout()
            plt.show()

        return fig
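
    # Usage sketch (hypothetical; stacked-bar alternative to composition_pie):
    #
    # >>> fig = project.bar_composition(cmap="tab20b", legend_bbox=(1.3, 1))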
    def cell_regression(
        self,
        cell_x: str,
        cell_y: str,
        set_x: str | None,
        set_y: str | None,
        threshold=10,
        image_width=12,
        image_high=7,
        color="black",
    ):
        """
        Perform regression analysis between two selected cells and visualize the relationship.

        This function computes a linear regression between two specified cells from
        aggregated normalized data, plots the regression line with scatter points,
        annotates regression statistics, and highlights potential outliers.

        Parameters
        ----------
        cell_x : str
            Name of the first cell (X-axis).

        cell_y : str
            Name of the second cell (Y-axis).

        set_x : str or None
            Dataset identifier corresponding to `cell_x`. If None, the cell is selected only by name.

        set_y : str or None
            Dataset identifier corresponding to `cell_y`. If None, the cell is selected only by name.

        threshold : int or float, default 10
            Threshold for detecting outliers. Points deviating from the mean or diagonal by more
            than this value are annotated.

        image_width : int, default 12
            Width of the regression plot (in inches).

        image_high : int, default 7
            Height of the regression plot (in inches).

        color : str, default 'black'
            Color of the regression scatter points and line.

        Returns
        -------
        matplotlib.figure.Figure
            Regression plot figure with annotated regression line, R², p-value, and outliers.

        Raises
        ------
        ValueError
            If `cell_x` or `cell_y` are not found in the dataset.
            If multiple matches are found for a cell name and `set_x`/`set_y` are not specified.

        Notes
        -----
        - The function automatically calls `self.average()` if aggregated data is not available.
        - Outliers are annotated with their corresponding index labels.
        - Regression is computed using `scipy.stats.linregress`.

        Examples
        --------
        >>> obj.cell_regression(cell_x="Purkinje", cell_y="Granule", set_x="Exp1", set_y="Exp2")
        >>> obj.cell_regression(cell_x="NeuronA", cell_y="NeuronB", threshold=5, color="blue")
        """

        if self.agg_normalized_data is None:
            self.average()

        metadata = self.agg_metadata
        data = self.agg_normalized_data

        if set_x is not None and set_y is not None:
            data.columns = metadata["cell_names"] + " # " + metadata["sets"]
            cell_x = cell_x + " # " + set_x
            cell_y = cell_y + " # " + set_y
        else:
            data.columns = metadata["cell_names"]

        if cell_x not in data.columns:
            raise ValueError("'cell_x' value not in cell names!")

        if cell_y not in data.columns:
            raise ValueError("'cell_y' value not in cell names!")

        if list(data.columns).count(cell_x) > 1:
            raise ValueError(
                f"'{cell_x}' occurs more than once. If you want to select a specific cell, "
                f"please also provide the corresponding 'set_x' and 'set_y' values."
            )

        if list(data.columns).count(cell_y) > 1:
            raise ValueError(
                f"'{cell_y}' occurs more than once. If you want to select a specific cell, "
                f"please also provide the corresponding 'set_x' and 'set_y' values."
            )

        fig, ax = plt.subplots(figsize=(image_width, image_high))
        ax = sns.regplot(x=cell_x, y=cell_y, data=data, color=color)

        slope, intercept, r_value, p_value, _ = stats.linregress(
            data[cell_x], data[cell_y]
        )
        equation = "y = {:.2f}x + {:.2f}".format(slope, intercept)

        ax.annotate(
            "R-squared = {:.2f}\nP-value = {:.2f}\n{}".format(
                r_value**2, p_value, equation
            ),
            xy=(0.05, 0.90),
            xycoords="axes fraction",
            fontsize=12,
        )

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        def annotate_outliers(x, y, threshold):
            texts = []
            x_mean, y_mean = x.mean(), y.mean()
            for i, (xi, yi) in enumerate(zip(x, y)):
                if (
                    abs(xi - x_mean) > threshold
                    or abs(yi - y_mean) > threshold
                    or abs(yi - xi) > threshold
                ):
                    text = ax.text(xi, yi, data.index[i])
                    texts.append(text)

            return texts

        texts = annotate_outliers(data[cell_x], data[cell_y], threshold)

        adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

        plt.show()

        return fig