jdti.utils
import os
import re

import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
from adjustText import adjust_text
from joblib import Parallel, delayed
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.io import mmread
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm


def load_sparse(path: str, name: str):
    """
    Load a sparse matrix dataset along with associated gene
    and cell metadata, and return it as a dense DataFrame.

    This function expects the input directory to contain three files in standard
    10x Genomics format:
    - "matrix.mtx": the gene expression matrix in Matrix Market format
    - "genes.tsv": tab-separated file containing gene identifiers
    - "barcodes.tsv": tab-separated file containing cell barcodes / names

    Parameters
    ----------
    path : str
        Path to the directory containing the matrix and annotation files.

    name : str
        Label or dataset identifier to be assigned to all cells in the metadata.

    Returns
    -------
    data : pandas.DataFrame
        Dense expression matrix where rows correspond to genes and columns
        correspond to cells.
    metadata : pandas.DataFrame
        Metadata DataFrame with two columns:
        - "cell_names": the names of the cells (barcodes)
        - "sets": the dataset label assigned to each cell (from `name`)

    Notes
    -----
    The function converts the sparse matrix into a dense DataFrame. This may
    require a large amount of memory for datasets with many cells and genes.
    """

    data = mmread(os.path.join(path, "matrix.mtx"))
    data = pd.DataFrame(data.todense())
    genes = pd.read_csv(os.path.join(path, "genes.tsv"), header=None, sep="\t")
    names = pd.read_csv(os.path.join(path, "barcodes.tsv"), header=None, sep="\t")
    data.columns = [str(x) for x in names[0]]
    data.index = list(genes[0])

    names = list(data.columns)
    sets = [name] * len(names)

    metadata = pd.DataFrame({"cell_names": names, "sets": sets})

    return data, metadata
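
# Usage sketch (not part of the original source; "filtered_matrices" is a
# hypothetical directory containing matrix.mtx, genes.tsv, and barcodes.tsv):
#
#     data, metadata = load_sparse("filtered_matrices", name="sample_1")
#     data.shape       # (n_genes, n_cells); dense, so watch memory on large sets
#     metadata.head()  # columns: "cell_names", "sets"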

def volcano_plot(
    deg_data: pd.DataFrame,
    p_adj: bool = True,
    top: int = 25,
    p_val: float | int = 0.05,
    lfc: float | int = 0.25,
    standard_scale: bool = False,
    rescale_adj: bool = True,
    image_width: int = 12,
    image_high: int = 12,
):
    """
    Generate a volcano plot from differential expression results.

    A volcano plot visualizes the relationship between statistical significance
    (p-values or standardized p-values) and log(fold change) for each gene,
    highlighting genes that pass significance thresholds.

    Parameters
    ----------
    deg_data : pandas.DataFrame
        DataFrame containing differential expression results from the
        calc_DEG() function.

    p_adj : bool, default=True
        If True, use adjusted p-values. If False, use raw p-values.

    top : int, default=25
        Number of top significant genes to highlight on the plot.

    p_val : float | int, default=0.05
        Significance threshold for p-values (or adjusted p-values).

    lfc : float | int, default=0.25
        Threshold for absolute log fold change.

    standard_scale : bool, default=False
        If True, use standardized p-values for visualization instead of
        -log10(p-values).

    rescale_adj : bool, default=True
        If True, rescale p-values to avoid long breaks caused by outlier values.

    image_width : int, default=12
        Width of the generated plot in inches.

    image_high : int, default=12
        Height of the generated plot in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The generated volcano plot figure.
    """

    if p_adj:
        pv = "adj_pval"
    else:
        pv = "p_val"

    deg_df = deg_data.copy()

    if standard_scale:

        p_val_scale = "scale(p-val)"

        scaler = MinMaxScaler(feature_range=(0, 1000))

        # copy to avoid pandas SettingWithCopy issues on the column
        # assignments below
        tmp_p = deg_df[
            ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
            | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
        ].copy()

        tmp_p[p_val_scale] = (-np.log10(deg_df[pv])).astype(float)

        tmp_p[p_val_scale] = scaler.fit_transform(
            tmp_p[p_val_scale].values.reshape(-1, 1)
        ).astype(float)

        median_shift = abs(tmp_p[p_val_scale].diff().max()) * 25
        cur_max = tmp_p[p_val_scale].max()

        zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
        zero_p_plus = zero_p_plus.sort_values(by="log(FC)", ascending=True).reset_index(
            drop=True
        )

        zero_p_plus[p_val_scale] = [
            (cur_max + (x * median_shift)) for x in range(1, len(zero_p_plus.index) + 1)
        ]
        zero_p_plus[p_val_scale] = zero_p_plus[p_val_scale].astype(float)

        zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
        zero_p_minus = zero_p_minus.sort_values(
            by="log(FC)", ascending=False
        ).reset_index(drop=True)

        zero_p_minus[p_val_scale] = [
            (cur_max + (x * median_shift))
            for x in range(1, len(zero_p_minus.index) + 1)
        ]
        zero_p_minus[p_val_scale] = zero_p_minus[p_val_scale].astype(float)

        deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)

    else:

        shift = 0.25

        p_val_scale = "-log(p_val)"

        min_minus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)])
        min_plus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)])

        zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
        zero_p_plus = zero_p_plus.sort_values(
            by="log(FC)", ascending=False
        ).reset_index(drop=True)
        zero_p_plus[pv] = [
            (shift * x) * min_plus for x in range(1, len(zero_p_plus.index) + 1)
        ]

        zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
        zero_p_minus = zero_p_minus.sort_values(
            by="log(FC)", ascending=True
        ).reset_index(drop=True)
        zero_p_minus[pv] = [
            (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1)
        ]

        tmp_p = deg_df[
            ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
            | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
        ]

        del deg_df

        deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)

        deg_df[p_val_scale] = -np.log10(deg_df[pv])

    deg_df["top100"] = None

    if rescale_adj:

        deg_df = deg_df.sort_values(by=p_val_scale, ascending=False)

        deg_df = deg_df.reset_index(drop=True)

        doubled = []
        ratio = []
        for n, i in enumerate(deg_df.index):
            for j in range(1, 6):
                if (
                    n + j < len(deg_df.index)
                    and deg_df[p_val_scale][n] / deg_df[p_val_scale][n + j] >= 2
                ):
                    doubled.append(n)
                    ratio.append(deg_df[p_val_scale][n + j] / deg_df[p_val_scale][n])

        df = pd.DataFrame({"doubled": doubled, "ratio": ratio})
        df = df[df["doubled"] < 100]

        df["ratio"] = (1 - df["ratio"]) / 5
        df = df.reset_index(drop=True)

        df = df.sort_values("doubled")

        if len(df["doubled"]) == 1 and 0 in df["doubled"]:
            df = df
        else:
            doubled2 = []

            for l in df["doubled"]:
                if l + 1 != len(doubled) and l + 1 - l == 1:
                    doubled2.append(l)
                    doubled2.append(l + 1)
                else:
                    break

            doubled2 = sorted(set(doubled2), reverse=True)

            if len(doubled2) > 1:
                df = df[df["doubled"].isin(doubled2)]
                df = df.sort_values("doubled", ascending=False)
                df = df.reset_index(drop=True)
                for c in df.index:
                    deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[
                        df["doubled"][c] + 1, p_val_scale
                    ] * (1 + df["ratio"][c])

    deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df["p_val"] <= p_val), "top100"] = "red"
    deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df["p_val"] <= p_val), "top100"] = "blue"
    deg_df.loc[deg_df["p_val"] > p_val, "top100"] = "lightgray"

    if lfc > 0:
        deg_df.loc[
            (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100"
        ] = "lightgray"

    down_int = len(
        deg_df["top100"][(deg_df["log(FC)"] <= lfc * -1) & (deg_df["p_val"] <= p_val)]
    )
    up_int = len(
        deg_df["top100"][(deg_df["log(FC)"] > lfc) & (deg_df["p_val"] <= p_val)]
    )

    deg_df_up = deg_df[deg_df["log(FC)"] > 0]
    deg_df_up = deg_df_up.sort_values(["p_val", "log(FC)"], ascending=[True, False])
    deg_df_up = deg_df_up.reset_index(drop=True)

    n = -1
    l = 0
    while True:
        n += 1
        if deg_df_up["log(FC)"][n] > lfc:
            deg_df_up.loc[n, "top100"] = "green"
            l += 1
        if l == top or deg_df_up["p_val"][n] > p_val:
            break

    deg_df_down = deg_df[deg_df["log(FC)"] <= 0]
    deg_df_down = deg_df_down.sort_values(["p_val", "log(FC)"], ascending=[True, True])
    deg_df_down = deg_df_down.reset_index(drop=True)

    n = -1
    l = 0
    while True:
        n += 1
        if deg_df_down["log(FC)"][n] < lfc * -1:
            deg_df_down.loc[n, "top100"] = "yellow"
            l += 1
        if l == top or deg_df_down["p_val"][n] > p_val:
            break

    deg_df = pd.concat([deg_df_up, deg_df_down])

    que = ["lightgray", "red", "blue", "yellow", "green"]

    deg_df = deg_df.sort_values(
        by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)})
    )

    deg_df = deg_df.reset_index(drop=True)

    fig, ax = plt.subplots(figsize=(image_width, image_high))

    plt.scatter(
        x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2
    )

    plt.plot(
        [max(deg_df["log(FC)"]) * -1.1, max(deg_df["log(FC)"]) * 1.1],
        [-np.log10(p_val), -np.log10(p_val)],
        linestyle="--",
        linewidth=3,
        color="lightgray",
        zorder=1,
    )

    if lfc > 0:
        plt.plot(
            [lfc * -1, lfc * -1],
            [-3, max(deg_df[p_val_scale]) * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )
        plt.plot(
            [lfc, lfc],
            [-3, max(deg_df[p_val_scale]) * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )

    plt.xlabel("log(FC)")
    plt.ylabel(p_val_scale)
    plt.title("Volcano plot")

    plt.ylim(min(deg_df[p_val_scale]) - 5, max(deg_df[p_val_scale]) * 1.25)

    texts = [
        ax.text(deg_df["log(FC)"][i], deg_df[p_val_scale][i], deg_df["feature"][i])
        for i in deg_df.index
        if deg_df["top100"][i] in ["green", "yellow"]
    ]

    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

    legend_elements = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="top-upregulated",
            markerfacecolor="green",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="top-downregulated",
            markerfacecolor="yellow",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="upregulated",
            markerfacecolor="blue",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="downregulated",
            markerfacecolor="red",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="non-significant",
            markerfacecolor="lightgray",
            markersize=10,
        ),
    ]

    ax.legend(handles=legend_elements, loc="upper right")
    ax.grid(visible=False)

    ax.annotate(
        "\nmin p-value = " + str(p_val),
        xy=(0.025, 0.975),
        xycoords="axes fraction",
        fontsize=12,
    )

    if lfc > 0:
        ax.annotate(
            "\nmin log(FC) = " + str(lfc),
            xy=(0.025, 0.95),
            xycoords="axes fraction",
            fontsize=12,
        )

    ax.annotate(
        "\nDownregulated: " + str(down_int),
        xy=(0.025, 0.925),
        xycoords="axes fraction",
        fontsize=12,
        color="red",
    )

    ax.annotate(
        "\nUpregulated: " + str(up_int),
        xy=(0.025, 0.9),
        xycoords="axes fraction",
        fontsize=12,
        color="blue",
    )

    plt.show()

    return fig
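
# Usage sketch (not part of the original source); assumes `deg` is a
# calc_DEG() result frame with "feature", "p_val", "adj_pval", and "log(FC)"
# columns, and "volcano.png" is a hypothetical output path:
#
#     fig = volcano_plot(deg, p_adj=True, top=10, p_val=0.05, lfc=0.5)
#     fig.savefig("volcano.png", dpi=300)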

def find_features(data: pd.DataFrame, features: list):
    """
    Identify features (rows) from a DataFrame that match a given list of
    features, ignoring case.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with features in the index (rows).

    features : list
        List of feature names to search for.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of features found in the DataFrame index.
        - "not_included": list of requested features not found in the DataFrame index.
        - "potential": list of features in the DataFrame that may be similar.
    """

    features_upper = [str(x).upper() for x in features]

    index_set = set(data.index)

    features_in = [x for x in index_set if x.upper() in features_upper]
    features_in_upper = [x.upper() for x in features_in]
    features_out_upper = [x for x in features_upper if x not in features_in_upper]
    features_out = [x for x in features if x.upper() in features_out_upper]
    similar_features = [
        idx for idx in index_set if any(x in idx.upper() for x in features_out_upper)
    ]

    return {
        "included": features_in,
        "not_included": features_out,
        "potential": similar_features,
    }


def find_names(data: pd.DataFrame, names: list):
    """
    Identify names (columns) from a DataFrame that match a given list of
    names, ignoring case.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with names in the columns.

    names : list
        List of names to search for.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of names found in the DataFrame columns.
        - "not_included": list of requested names not found in the DataFrame columns.
        - "potential": list of names in the DataFrame that may be similar.
    """

    names_upper = [str(x).upper() for x in names]

    columns = set(data.columns)

    names_in = [x for x in columns if x.upper() in names_upper]
    names_in_upper = [x.upper() for x in names_in]
    names_out_upper = [x for x in names_upper if x not in names_in_upper]
    names_out = [x for x in names if x.upper() in names_out_upper]
    similar_names = [
        idx for idx in columns if any(x in idx.upper() for x in names_out_upper)
    ]

    return {"included": names_in, "not_included": names_out, "potential": similar_names}
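
# Usage sketch for the case-insensitive lookups (not part of the original
# source; gene symbols are illustrative):
#
#     hits = find_features(data, features=["Cd8a", "GZMB", "FakeGene1"])
#     hits["included"]      # matches spelled as they appear in data.index
#     hits["not_included"]  # e.g. ["FakeGene1"]
#     hits["potential"]     # index entries containing a missing query as substring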

def reduce_data(data: pd.DataFrame, features: list = [], names: list = []):
    """
    Subset a DataFrame based on selected features (rows) and/or names (columns).

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and names as columns.

    features : list
        List of features to include (rows). Default is an empty list.
        If empty, all rows are returned.

    names : list
        List of names to include (columns). Default is an empty list.
        If empty, all columns are returned.

    Returns
    -------
    pandas.DataFrame
        Subset of the input DataFrame containing only the selected rows
        and/or columns.

    Raises
    ------
    ValueError
        If both `features` and `names` are empty.
    """

    if len(features) > 0 and len(names) > 0:
        fet = find_features(data=data, features=features)

        nam = find_names(data=data, names=names)

        data_to_return = data.loc[fet["included"], nam["included"]]

    elif len(features) > 0 and len(names) == 0:
        fet = find_features(data=data, features=features)

        data_to_return = data.loc[fet["included"], :]

    elif len(features) == 0 and len(names) > 0:

        nam = find_names(data=data, names=names)

        data_to_return = data.loc[:, nam["included"]]

    else:

        raise ValueError("features and names have zero length!")

    return data_to_return


def make_unique_list(lst):
    """
    Generate a list where duplicate items are renamed to ensure uniqueness.

    Each duplicate is appended with a suffix ".n", where n indicates the
    occurrence count (starting from 1).

    Parameters
    ----------
    lst : list
        Input list of items (strings or other hashable types).

    Returns
    -------
    list
        List with unique values.

    Examples
    --------
    >>> make_unique_list(["A", "B", "A", "A"])
    ['A', 'B', 'A.1', 'A.2']
    """
    seen = {}
    result = []
    for item in lst:
        if item not in seen:
            seen[item] = 0
            result.append(item)
        else:
            seen[item] += 1
            result.append(f"{item}.{seen[item]}")
    return result
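
# Usage sketch combining reduce_data and make_unique_list (not part of the
# original source; feature and column names are illustrative):
#
#     subset = reduce_data(data, features=["Cd8a", "Gzmb"], names=["cell_1", "cell_2"])
#     make_unique_list(["T_cell", "B_cell", "T_cell"])
#     # -> ['T_cell', 'B_cell', 'T_cell.1']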

def get_color_palette(variable_list, palette_name="tab10"):
    """Map each item of `variable_list` to a color from a matplotlib palette."""
    n = len(variable_list)
    cmap = plt.get_cmap(palette_name)
    colors = [cmap(i % cmap.N) for i in range(n)]
    return dict(zip(variable_list, colors))


def features_scatter(
    expression_data: pd.DataFrame,
    occurence_data: pd.DataFrame | None = None,
    scale: bool = False,
    features: list | None = None,
    metadata_list: list | None = None,
    colors: str = "viridis",
    hclust: str | None = "complete",
    img_width: int = 8,
    img_high: int = 5,
    label_size: int = 10,
    size_scale: int = 100,
    y_lab: str = "Genes",
    legend_lab: str = "log(CPM + 1)",
    set_box_size: float | int = 5,
    set_box_high: float | int = 5,
    bbox_to_anchor_scale: int = 25,
    bbox_to_anchor_perc: tuple = (0.91, 0.63),
    bbox_to_anchor_group: tuple = (1.01, 0.4),
):
    """
    Create a bubble scatter plot of selected features across samples.

    Each point represents a feature-sample pair, where the color encodes the
    expression value and the size encodes occurrence or relative abundance.
    Optionally, hierarchical clustering can be applied to order rows and columns.

    Parameters
    ----------
    expression_data : pandas.DataFrame
        Expression values (mean) with features as rows and samples as columns,
        as produced by the average() function.

    occurence_data : pandas.DataFrame or None
        DataFrame with occurrence/frequency values (same shape as
        `expression_data`), as produced by the occurrence() function.
        If None, bubble sizes are based on expression values.

    scale : bool, default=False
        If True, `expression_data` is scaled (0-1) across the rows (features).

    features : list or None
        List of features (rows) to display. If None, all features are used.

    metadata_list : list or None, optional
        Metadata grouping for samples (same length as number of columns).
        Used to add group colors and separators in the plot.

    colors : str, default='viridis'
        Colormap for expression values.

    hclust : str or None, default='complete'
        Linkage method for hierarchical clustering. If None, no clustering
        is performed.

    img_width : int or float, default=8
        Width of the plot in inches.

    img_high : int or float, default=5
        Height of the plot in inches.

    label_size : int, default=10
        Font size for axis labels and ticks.

    size_scale : int or float, default=100
        Scaling factor for bubble sizes.

    y_lab : str, default='Genes'
        Label for the y-axis.

    legend_lab : str, default='log(CPM + 1)'
        Label for the colorbar legend.

    set_box_size : float or int, default=5
        Scaling factor for the height of the group color boxes drawn above
        the plot.

    set_box_high : float or int, default=5
        Scaling factor for the vertical position of the group color boxes.

    bbox_to_anchor_scale : int, default=25
        Vertical scale (percentage) for positioning the colorbar.

    bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
        Anchor position for the size legend (percent bubble legend).

    bbox_to_anchor_group : tuple, default=(1.01, 0.4)
        Anchor position for the group legend.

    Returns
    -------
    matplotlib.figure.Figure
        The generated scatter plot figure.

    Raises
    ------
    ValueError
        If `metadata_list` is provided but its length does not match
        the number of columns in `expression_data`.

    Notes
    -----
    - Colors represent expression values normalized to the colormap.
    - Bubble sizes represent occurrence values (or expression values if
      `occurence_data` is None).
    - If `metadata_list` is given, groups are indicated with colors and
      dashed vertical separators.
    """

    scatter_df = expression_data.copy()

    if scale:

        legend_lab = "Scaled\n" + legend_lab

        scaler = MinMaxScaler(feature_range=(0, 1))
        scatter_df = pd.DataFrame(
            scaler.fit_transform(scatter_df.T).T,
            index=scatter_df.index,
            columns=scatter_df.columns,
        )

    metadata = {}

    metadata["primary_names"] = [str(x) for x in scatter_df.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)
    if features is not None:
        scatter_df = scatter_df.loc[
            find_features(data=scatter_df, features=features)["included"],
        ]
    scatter_df.columns = metadata["primary_names"] + "#" + metadata["sets"]

    if occurence_data is not None:
        if features is not None:
            occurence_data = occurence_data.loc[
                find_features(data=occurence_data, features=features)["included"],
            ]
        occurence_data.columns = metadata["primary_names"] + "#" + metadata["sets"]

    # check duplicated names

    tmp_columns = scatter_df.columns

    new_cols = make_unique_list(list(tmp_columns))

    scatter_df.columns = new_cols

    if hclust is not None and len(expression_data.index) != 1:

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_rows = []
        for n in order_of_features:
            sorted_list_rows.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_columns = []
        for n in order_of_features:
            sorted_list_columns.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        scatter_df = scatter_df.loc[sorted_list_rows, sorted_list_columns]

        if occurence_data is not None:
            occurence_data = occurence_data.loc[sorted_list_rows, sorted_list_columns]

    metadata["sets"] = [re.sub(".*#", "", x) for x in scatter_df.columns]

    scatter_df.columns = [re.sub("#.*", "", x) for x in scatter_df.columns]

    if occurence_data is not None:
        occurence_data.columns = [re.sub("#.*", "", x) for x in occurence_data.columns]

    fig, ax = plt.subplots(figsize=(img_width, img_high))

    norm = plt.Normalize(0, np.max(scatter_df))

    cmap = plt.get_cmap(colors)

    # Bubble scatter
    for i, _ in enumerate(scatter_df.index):
        for j, _ in enumerate(scatter_df.columns):
            if occurence_data is not None:
                value_e = scatter_df.iloc[i, j]
                value_o = occurence_data.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value_o * size_scale,
                    c=[cmap(norm(value_e))],
                    edgecolors="k",
                    linewidths=0.3,
                )
            else:
                value = scatter_df.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value * size_scale,
                    c=[cmap(norm(value))],
                    edgecolors="k",
                    linewidths=0.3,
                )

    ax.set_yticks(range(len(scatter_df.index)))
    ax.set_yticklabels(scatter_df.index, fontsize=label_size * 0.8)
    ax.set_ylabel(y_lab, fontsize=label_size)
    ax.set_xticks(range(len(scatter_df.columns)))
    ax.set_xticklabels(scatter_df.columns, fontsize=label_size * 0.8, rotation=90)

    # color bar
    cax = inset_axes(
        ax,
        width="1%",
        height=f"{bbox_to_anchor_scale}%",
        bbox_to_anchor=(1, 0, 1, 1),
        bbox_transform=ax.transAxes,
        loc="upper left",
    )
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb.set_label(legend_lab, fontsize=label_size * 0.65)
    cb.ax.tick_params(labelsize=label_size * 0.7)

    if metadata_list is not None:

        metadata_list = list(metadata["sets"])
        group_colors = get_color_palette(list(set(metadata_list)), palette_name="tab10")

        for i, group in enumerate(metadata_list):
            ax.add_patch(
                plt.Rectangle(
                    (i - 0.5, len(scatter_df.index) - 0.1 * set_box_high),
                    1,
                    0.1 * set_box_size,
                    color=group_colors[group],
                    transform=ax.transData,
                    clip_on=False,
                )
            )

        for i in range(1, len(metadata_list)):
            if metadata_list[i] != metadata_list[i - 1]:
                ax.axvline(i - 0.5, color="black", linestyle="--", lw=1)

        group_patches = [
            mpatches.Patch(color=color, label=label)
            for label, color in group_colors.items()
        ]
        fig.legend(
            handles=group_patches,
            title="Group",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_group,
            frameon=False,
        )

    # second legend (size)
    if occurence_data is not None:
        size_values = [0.25, 0.5, 1]
        legend2_handles = [
            plt.Line2D(
                [],
                [],
                marker="o",
                linestyle="",
                markersize=np.sqrt(v * size_scale * 0.5),
                color="gray",
                alpha=0.6,
                label=f"{v * 100:.1f}",
            )
            for v in size_values
        ]

        fig.legend(
            handles=legend2_handles,
            title="Percent [%]",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_perc,
            frameon=False,
        )

    _, ymax = ax.get_ylim()

    ax.set_xlim(-0.5, len(scatter_df.columns) - 0.5)
    ax.set_ylim(-0.5, ymax + 0.5)

    return fig
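
# Usage sketch for features_scatter (not part of the original source);
# assumes `avg` and `occ` come from the average() and occurrence() functions
# below and share shape and labels, and `groups` is one label per column:
#
#     fig = features_scatter(
#         expression_data=avg,
#         occurence_data=occ,
#         features=["Cd8a", "Gzmb"],   # illustrative gene symbols
#         metadata_list=list(groups),
#         scale=True,
#     )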

def calc_DEG(
    data,
    metadata_list: list | None = None,
    entities: str | list | dict | None = None,
    sets: str | list | dict | None = None,
    min_exp: int | float = 0,
    min_pct: int | float = 0.1,
    n_proc: int = 10,
):
    """
    Perform differential gene expression (DEG) analysis on gene expression data.

    The function compares groups of cells or samples (defined by `entities` or
    `sets`) using the Mann–Whitney U test. It computes p-values, adjusted
    p-values, fold changes, standardized effect sizes, and other statistics.

    Parameters
    ----------
    data : pandas.DataFrame
        Expression matrix with features (e.g., genes) as rows and samples/cells
        as columns.

    metadata_list : list or None, optional
        Metadata grouping corresponding to the columns in `data`. Required for
        comparisons based on sets. Default is None.

    entities : list, str, dict, or None, optional
        Defines the comparison strategy:
        - list of sample names → compare selected cells to the rest.
        - 'All' → compare each sample/cell to all others.
        - dict → user-defined groups for pairwise comparison.
        - None → must be combined with `sets`.

    sets : str, dict, or None, optional
        Defines group-based comparisons:
        - 'All' → compare each set/group to all others.
        - dict with two groups → perform pairwise set comparison.
        - None → must be combined with `entities`.

    min_exp : float | int, default=0
        Minimum expression threshold for filtering features.

    min_pct : float | int, default=0.1
        Minimum proportion of samples within the target group that must express
        a feature for it to be tested.

    n_proc : int, default=10
        Number of parallel processes to use for statistical testing.

    Returns
    -------
    pandas.DataFrame or dict
        Results of the differential expression analysis:
        - If `entities` is a list → dict with keys: 'valid_cells',
          'control_cells', and 'DEG' (results DataFrame).
        - If `entities == 'All'` or `sets == 'All'` → DataFrame with results
          for all groups.
        - If pairwise comparison (dict for `entities` or `sets`) → DataFrame
          with results for the specified groups.

        The results DataFrame contains:
        - 'feature': feature name
        - 'p_val': raw p-value
        - 'adj_pval': adjusted p-value (multiple testing correction)
        - 'pct_valid': fraction of target group expressing the feature
        - 'pct_ctrl': fraction of control group expressing the feature
        - 'avg_valid': mean expression in target group
        - 'avg_ctrl': mean expression in control group
        - 'sd_valid': standard deviation in target group
        - 'sd_ctrl': standard deviation in control group
        - 'esm': effect size metric
        - 'FC': fold change
        - 'log(FC)': log2-transformed fold change
        - 'norm_diff': difference in mean expression

    Raises
    ------
    ValueError
        - If `metadata_list` is provided but its length does not match
          the number of columns in `data`.
        - If neither `entities` nor `sets` is provided.

    Notes
    -----
    - Mann–Whitney U test is used for group comparisons.
    - Multiple testing correction is applied using a simple
      Benjamini–Hochberg-like method.
    - Features expressed below `min_exp` or in fewer than `min_pct` of target
      samples are filtered out.
    - Parallelization is handled by `joblib.Parallel`.

    Examples
    --------
    Compare a selected list of cells against all others:

    >>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])

    Compare each group to others (based on metadata):

    >>> result = calc_DEG(data, metadata_list=group_labels, sets="All")

    Perform pairwise comparison between two predefined sets:

    >>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
    >>> result = calc_DEG(data, sets=sets)
    """

    metadata = {}

    metadata["primary_names"] = [str(x) for x in data.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)

    def stat_calc(choose, feature_name):
        target_values = choose.loc[choose["DEG"] == "target", feature_name]
        rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

        pct_valid = (target_values > 0).sum() / len(target_values)
        pct_rest = (rest_values > 0).sum() / len(rest_values)

        avg_valid = np.mean(target_values)
        avg_ctrl = np.mean(rest_values)
        sd_valid = np.std(target_values, ddof=1)
        sd_ctrl = np.std(rest_values, ddof=1)
        esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

        if np.sum(target_values) == np.sum(rest_values):
            p_val = 1.0
        else:
            _, p_val = stats.mannwhitneyu(
                target_values, rest_values, alternative="two-sided"
            )

        return {
            "feature": feature_name,
            "p_val": p_val,
            "pct_valid": pct_valid,
            "pct_ctrl": pct_rest,
            "avg_valid": avg_valid,
            "avg_ctrl": avg_ctrl,
            "sd_valid": sd_valid,
            "sd_ctrl": sd_ctrl,
            "esm": esm,
        }

    def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc, factors):

        tmp_dat = choose[choose["DEG"] == "target"]
        tmp_dat = tmp_dat.drop("DEG", axis=1)

        counts = (tmp_dat > min_exp).sum(axis=0)

        total_count = tmp_dat.shape[0]

        info = pd.DataFrame(
            {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
        )

        del tmp_dat

        drop_col = info["feature"][info["pct"] <= min_pct]

        if len(drop_col) + 1 == len(choose.columns):
            drop_col = info["feature"][info["pct"] == 0]

        del info

        choose = choose.drop(list(drop_col), axis=1)

        results = Parallel(n_jobs=n_proc)(
            delayed(stat_calc)(choose, feature)
            for feature in tqdm(choose.columns[choose.columns != "DEG"])
        )

        if len(results) > 0:
            df = pd.DataFrame(results)
            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            factors_df = factors.to_frame(name="factor")

            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")

            df[["avg_valid", "avg_ctrl", "factor"]] = df[
                ["avg_valid", "avg_ctrl", "factor"]
            ].astype(float)

            df["FC"] = (df["avg_valid"] + df["factor"]) / (
                df["avg_ctrl"] + df["factor"]
            )
            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
            df = df.drop(columns=["factor"])
        else:
            columns = [
                "feature",
                "valid_group",
                "p_val",
                "adj_pval",
                "avg_valid",
                "avg_ctrl",
                "FC",
                "log(FC)",
                "norm_diff",
            ]
            df = pd.DataFrame(columns=columns)
        return df

    choose = data.T

    factors = data.copy().replace(0, np.nan).min(axis=1)

    nonzero_factors = factors[factors != 0]
    if len(nonzero_factors) > 0:
        factors[factors == 0] = nonzero_factors.min()
    else:
        factors[factors == 0] = 0.01

    final_results = []

    if isinstance(entities, list) and sets is None:
        print("\nAnalysis started...\nComparing selected cells to the whole set...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

            if "#" not in entities[0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' list. "
                    "Only the names will be compared, without considering the set information."
                )

        labels = ["target" if idx in entities else "rest" for idx in choose.index]
        valid = list(
            set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
        )

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=valid,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )

        return {"valid": valid, "control": "rest", "DEG": result_df}

    elif entities == "All" and sets is None:
        print("\nAnalysis started...\nComparing each type of cell to others...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

        unique_labels = set(choose.index)

        for label in tqdm(unique_labels):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            final_results.append(result_df)

        final_results = pd.concat(final_results, ignore_index=True)

        if metadata_list is None:
            final_results["valid_group"] = [
                re.sub(" # ", "", x) for x in final_results["valid_group"]
            ]

        return final_results

    elif entities is None and sets == "All":
        print("\nAnalysis started...\nComparing each set/group to others...")
        choose.index = metadata["sets"]
        unique_sets = set(choose.index)

        for label in tqdm(unique_sets):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            final_results.append(result_df)

        return pd.concat(final_results, ignore_index=True)

    elif entities is None and isinstance(sets, dict):
        print("\nAnalysis started...\nComparing groups...")
        choose.index = metadata["sets"]

        group_list = list(sets.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in sets[group_list[0]]
                else "rest" if idx in sets[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]
        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )
        return result_df

    elif isinstance(entities, dict) and sets is None:
        print("\nAnalysis started...\nComparing groups...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]
            if "#" not in entities[list(entities.keys())[0]][0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' dict. "
                    "Only the names will be compared, without considering the set information."
                )

        group_list = list(entities.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in entities[group_list[0]]
                else "rest" if idx in entities[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )

        return result_df.reset_index(drop=True)

    else:
        raise ValueError(
            "You must specify either 'entities' or 'sets'. None were provided, "
            "which is not allowed for this analysis."
        )
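
# Usage sketches for calc_DEG, adapted from its docstring examples
# (`group_labels` is a hypothetical per-column label list):
#
#     # each group vs. the rest, groups taken from a metadata list
#     res_all = calc_DEG(data, metadata_list=group_labels, sets="All")
#
#     # pairwise comparison of two predefined groups of columns
#     res_pair = calc_DEG(
#         data, entities={"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
#     )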

def average(data):
    """
    Compute the column-wise average of a DataFrame, aggregating by column names.

    If multiple columns share the same name, their values are averaged.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Columns with identical names
        will be aggregated by their mean.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input but with unique columns,
        where duplicate columns have been replaced by their mean values.
    """

    wide_data = data

    aggregated_df = wide_data.T.groupby(level=0).mean().T

    return aggregated_df


def occurrence(data):
    """
    Calculate the occurrence frequency of features in a DataFrame.

    Converts the input DataFrame to binary (presence/absence) and computes
    the proportion of non-zero entries for each feature, aggregating by
    column names if duplicates exist.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Each column represents a feature.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input, where each value represents
        the proportion of samples in which the feature is present (non-zero).
        Columns with identical names are aggregated.
    """

    binary_data = (data > 0).astype(int)

    counts = binary_data.columns.value_counts()

    binary_data = binary_data.T.groupby(level=0).sum().T
    binary_data = binary_data.astype(float)

    for i in counts.index:
        binary_data.loc[:, i] = (binary_data.loc[:, i] / counts[i]).astype(float)

    return binary_data
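
# Usage sketch for average/occurrence on a frame whose columns repeat group
# names (not part of the original source):
#
#     avg = average(data)      # mean expression per unique column name
#     occ = occurrence(data)   # fraction of samples with non-zero expression
#     avg.shape == occ.shape   # same rows, one column per unique name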

def add_subnames(names_list: list, parent_name: str, new_clusters: list):
    """
    Append sub-cluster names to a parent name within a list of names.

    This function replaces occurrences of `parent_name` in `names_list` with
    a concatenation of the parent name and corresponding sub-cluster name
    from `new_clusters` (formatted as "parent.subcluster"). Non-matching names
    are left unchanged.

    Parameters
    ----------
    names_list : list
        Original list of names (e.g., column names or cluster labels).

    parent_name : str
        Name of the parent cluster to which sub-cluster names will be added.
        Must exist in `names_list`.

    new_clusters : list
        List of sub-cluster names. Its length must match the number of times
        `parent_name` occurs in `names_list`.

    Returns
    -------
    list
        Updated list of names with sub-cluster names appended to the parent name.

    Raises
    ------
    ValueError
        - If `parent_name` is not found in `names_list`.
        - If `new_clusters` length does not match the number of occurrences of
          `parent_name`.

    Examples
    --------
    >>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
    ['A.1', 'B', 'A.2']
    """

    if str(parent_name) not in [str(x) for x in names_list]:
        raise ValueError(
            "Parent name is missing from the original dataset's column names!"
        )

    if len(new_clusters) != len([x for x in names_list if str(x) == str(parent_name)]):
        raise ValueError(
            "New cluster names list has a different length than the number of clusters in the original dataset!"
        )

    new_names = []
    ixn = 0
    for _, i in enumerate(names_list):
        if str(i) == str(parent_name):

            new_names.append(f"{parent_name}.{new_clusters[ixn]}")
            ixn += 1

        else:
            new_names.append(i)

    return new_names


def development_clust(
    data: pd.DataFrame, method: str = "ward", img_width: int = 5, img_high: int = 5
):
    """
    Perform hierarchical clustering on the columns of a DataFrame and plot a dendrogram.

    Clusters the transposed data (columns) with the selected linkage method
    and generates a dendrogram showing the relationships between features
    or samples.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and samples/columns to be clustered.

    method : str, default='ward'
        Method for hierarchical clustering. Options include:
        - 'ward' : minimizes the variance of clusters being merged.
        - 'single' : uses the minimum of the distances between all observations of the two sets.
        - 'complete' : uses the maximum of the distances between all observations of the two sets.
        - 'average' : uses the average of the distances between all observations of the two sets.

    img_width : int or float, default=5
        Width of the resulting figure in inches.

    img_high : int or float, default=5
        Height of the resulting figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The dendrogram figure.
    """

    z = linkage(data.T, method=method)

    figure, ax = plt.subplots(figsize=(img_width, img_high))

    dendrogram(z, labels=data.columns, orientation="left", ax=ax)

    return figure
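
# Usage sketch for add_subnames and development_clust (not part of the
# original source; cluster labels are illustrative):
#
#     add_subnames(["T", "B", "T"], parent_name="T", new_clusters=["1", "2"])
#     # -> ['T.1', 'B', 'T.2']
#     fig = development_clust(avg, method="ward")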

def adjust_cells_to_group_mean(data, data_avg, beta=0.2):
    """
    Adjust each cell's values towards the mean of its group (centroid).

    This function moves each cell's values in `data` slightly towards the
    corresponding group mean in `data_avg`, controlled by the parameter `beta`.

    Parameters
    ----------
    data : pandas.DataFrame
        Original data with features as rows and cells/samples as columns.

    data_avg : pandas.DataFrame
        DataFrame of group averages (centroids) with features as rows and
        group names as columns.

    beta : float, default=0.2
        Weight for adjustment towards the group mean. 0 = no adjustment,
        1 = fully replaced by the group mean.

    Returns
    -------
    pandas.DataFrame
        Adjusted data with the same shape as the input `data`.
    """

    df_adjusted = data.copy()

    for group_name in data_avg.columns:
        col_idx = [
            i
            for i, c in enumerate(df_adjusted.columns)
            if str(c).startswith(group_name)
        ]
        if not col_idx:
            continue

        centroid = data_avg.loc[df_adjusted.index, group_name].to_numpy()[:, None]

        df_adjusted.iloc[:, col_idx] = (1 - beta) * df_adjusted.iloc[
            :, col_idx
        ].to_numpy() + beta * centroid

    return df_adjusted
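
# Usage sketch for adjust_cells_to_group_mean (not part of the original
# source); assumes each column name in `data` starts with its group name, so
# that it is matched by the corresponding column of `data_avg`:
#
#     smoothed = adjust_cells_to_group_mean(data, data_avg=average(data), beta=0.2)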
21def load_sparse(path: str, name: str): 22 """ 23 Load a sparse matrix dataset along with associated gene 24 and cell metadata, and return it as a dense DataFrame. 25 26 This function expects the input directory to contain three files in standard 27 10x Genomics format: 28 - "matrix.mtx": the gene expression matrix in Matrix Market format 29 - "genes.tsv": tab-separated file containing gene identifiers 30 - "barcodes.tsv": tab-separated file containing cell barcodes / names 31 32 Parameters 33 ---------- 34 path : str 35 Path to the directory containing the matrix and annotation files. 36 37 name : str 38 Label or dataset identifier to be assigned to all cells in the metadata. 39 40 Returns 41 ------- 42 data : pandas.DataFrame 43 Dense expression matrix where rows correspond to genes and columns 44 correspond to cells. 45 metadata : pandas.DataFrame 46 Metadata DataFrame with two columns: 47 - "cell_names": the names of the cells (barcodes) 48 - "sets": the dataset label assigned to each cell (from `name`) 49 50 Notes 51 ----- 52 The function converts the sparse matrix into a dense DataFrame. This may 53 require a large amount of memory for datasets with many cells and genes. 54 """ 55 56 data = mmread(os.path.join(path, "matrix.mtx")) 57 data = pd.DataFrame(data.todense()) 58 genes = pd.read_csv(os.path.join(path, "genes.tsv"), header=None, sep="\t") 59 names = pd.read_csv(os.path.join(path, "barcodes.tsv"), header=None, sep="\t") 60 data.columns = [str(x) for x in names[0]] 61 data.index = list(genes[0]) 62 63 names = list(data.columns) 64 sets = [name] * len(names) 65 66 metadata = pd.DataFrame({"cell_names": names, "sets": sets}) 67 68 return data, metadata
Load a sparse matrix dataset along with associated gene and cell metadata, and return it as a dense DataFrame.
This function expects the input directory to contain three files in standard 10x Genomics format:
- "matrix.mtx": the gene expression matrix in Matrix Market format
- "genes.tsv": tab-separated file containing gene identifiers
- "barcodes.tsv": tab-separated file containing cell barcodes / names
Parameters
path : str Path to the directory containing the matrix and annotation files.
name : str Label or dataset identifier to be assigned to all cells in the metadata.
Returns
data : pandas.DataFrame
Dense expression matrix where rows correspond to genes and columns
correspond to cells.
metadata : pandas.DataFrame
Metadata DataFrame with two columns:
- "cell_names": the names of the cells (barcodes)
- "sets": the dataset label assigned to each cell (from name
)
Notes
The function converts the sparse matrix into a dense DataFrame. This may require a large amount of memory for datasets with many cells and genes.
71def volcano_plot( 72 deg_data: pd.DataFrame, 73 p_adj: bool = True, 74 top: int = 25, 75 p_val: float | int = 0.05, 76 lfc: float | int = 0.25, 77 standard_scale: bool = False, 78 rescale_adj: bool = True, 79 image_width: int = 12, 80 image_high: int = 12, 81): 82 """ 83 Generate a volcano plot from differential expression results. 84 85 A volcano plot visualizes the relationship between statistical significance 86 (p-values or standarized p-value) and log(fold change) for each gene, highlighting 87 genes that pass significance thresholds. 88 89 Parameters 90 ---------- 91 deg_data : pandas.DataFrame 92 DataFrame containing differential expression results from calc_DEG() function. 93 94 p_adj : bool, default=True 95 If True, use adjusted p-values. If False, use raw p-values. 96 97 top : int, default=25 98 Number of top significant genes to highlight on the plot. 99 100 p_val : float | int, default=0.05 101 Significance threshold for p-values (or adjusted p-values). 102 103 lfc : float | int, default=0.25 104 Threshold for absolute log fold change. 105 106 standard_scale : bool, default=False 107 If True, use standardized p-values for visualization instead of -log10(p-values). 108 109 rescale_adj : bool, default=True 110 If True, rescale p-values to avoid long breaks caused by outlier values. 111 112 image_width : int, default=12 113 Width of the generated plot in inches. 114 115 image_high : int, default=12 116 Height of the generated plot in inches. 117 118 Returns 119 ------- 120 matplotlib.figure.Figure 121 The generated volcano plot figure. 122 123 """ 124 125 if p_adj: 126 pv = "adj_pval" 127 else: 128 pv = "p_val" 129 130 deg_df = deg_data.copy() 131 132 if standard_scale: 133 134 p_val_scale = "scale(p-val)" 135 136 scaler = MinMaxScaler(feature_range=(0, 1000)) 137 138 tmp_p = deg_df[ 139 ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)) 140 | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)) 141 ] 142 143 tmp_p[p_val_scale] = (-np.log10(deg_df[pv])).astype(float) 144 145 tmp_p[p_val_scale] = scaler.fit_transform( 146 tmp_p[p_val_scale].values.reshape(-1, 1) 147 ).astype(float) 148 149 median_shift = abs(tmp_p[p_val_scale].diff().max()) * 25 150 cur_max = tmp_p[p_val_scale].max() 151 152 zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)] 153 zero_p_plus = zero_p_plus.sort_values(by="log(FC)", ascending=True).reset_index( 154 drop=True 155 ) 156 157 zero_p_plus[p_val_scale] = [ 158 (cur_max + (x * median_shift)) for x in range(1, len(zero_p_plus.index) + 1) 159 ] 160 zero_p_plus[p_val_scale] = zero_p_plus[p_val_scale].astype(float) 161 162 zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)] 163 zero_p_minus = zero_p_minus.sort_values( 164 by="log(FC)", ascending=False 165 ).reset_index(drop=True) 166 167 zero_p_minus[p_val_scale] = [ 168 (cur_max + (x * median_shift)) 169 for x in range(1, len(zero_p_minus.index) + 1) 170 ] 171 zero_p_minus[p_val_scale] = zero_p_minus[p_val_scale].astype(float) 172 173 deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True) 174 175 else: 176 177 shift = 0.25 178 179 p_val_scale = "-log(p_val)" 180 181 min_minus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)]) 182 min_plus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)]) 183 184 zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)] 185 zero_p_plus = zero_p_plus.sort_values( 186 by="log(FC)", ascending=False 187 ).reset_index(drop=True) 188 zero_p_plus[pv] = [ 189 (shift * x) * min_plus for x in range(1, 
len(zero_p_plus.index) + 1) 190 ] 191 192 zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)] 193 zero_p_minus = zero_p_minus.sort_values( 194 by="log(FC)", ascending=True 195 ).reset_index(drop=True) 196 zero_p_minus[pv] = [ 197 (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1) 198 ] 199 200 tmp_p = deg_df[ 201 ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)) 202 | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)) 203 ] 204 205 del deg_df 206 207 deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True) 208 209 deg_df[p_val_scale] = -np.log10(deg_df[pv]) 210 211 deg_df["top100"] = None 212 213 if rescale_adj: 214 215 deg_df = deg_df.sort_values(by=p_val_scale, ascending=False) 216 217 deg_df = deg_df.reset_index(drop=True) 218 219 doubled = [] 220 ratio = [] 221 for n, i in enumerate(deg_df.index): 222 for j in range(1, 6): 223 if ( 224 n + j < len(deg_df.index) 225 and deg_df[p_val_scale][n] / deg_df[p_val_scale][n + j] >= 2 226 ): 227 doubled.append(n) 228 ratio.append(deg_df[p_val_scale][n + j] / deg_df[p_val_scale][n]) 229 230 df = pd.DataFrame({"doubled": doubled, "ratio": ratio}) 231 df = df[df["doubled"] < 100] 232 233 df["ratio"] = (1 - df["ratio"]) / 5 234 df = df.reset_index(drop=True) 235 236 df = df.sort_values("doubled") 237 238 if len(df["doubled"]) == 1 and 0 in df["doubled"]: 239 df = df 240 else: 241 doubled2 = [] 242 243 for l in df["doubled"]: 244 if l + 1 != len(doubled) and l + 1 - l == 1: 245 doubled2.append(l) 246 doubled2.append(l + 1) 247 else: 248 break 249 250 doubled2 = sorted(set(doubled2), reverse=True) 251 252 if len(doubled2) > 1: 253 df = df[df["doubled"].isin(doubled2)] 254 df = df.sort_values("doubled", ascending=False) 255 df = df.reset_index(drop=True) 256 for c in df.index: 257 deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[ 258 df["doubled"][c] + 1, p_val_scale 259 ] * (1 + df["ratio"][c]) 260 261 deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df["p_val"] <= p_val), "top100"] = "red" 262 deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df["p_val"] <= p_val), "top100"] = "blue" 263 deg_df.loc[deg_df["p_val"] > p_val, "top100"] = "lightgray" 264 265 if lfc > 0: 266 deg_df.loc[ 267 (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100" 268 ] = "lightgray" 269 270 down_int = len( 271 deg_df["top100"][(deg_df["log(FC)"] <= lfc * -1) & (deg_df["p_val"] <= p_val)] 272 ) 273 up_int = len( 274 deg_df["top100"][(deg_df["log(FC)"] > lfc) & (deg_df["p_val"] <= p_val)] 275 ) 276 277 deg_df_up = deg_df[deg_df["log(FC)"] > 0] 278 deg_df_up = deg_df_up.sort_values(["p_val", "log(FC)"], ascending=[True, False]) 279 deg_df_up = deg_df_up.reset_index(drop=True) 280 281 n = -1 282 l = 0 283 while True: 284 n += 1 285 if deg_df_up["log(FC)"][n] > lfc: 286 deg_df_up.loc[n, "top100"] = "green" 287 l += 1 288 if l == top or deg_df_up["p_val"][n] > p_val: 289 break 290 291 deg_df_down = deg_df[deg_df["log(FC)"] <= 0] 292 deg_df_down = deg_df_down.sort_values(["p_val", "log(FC)"], ascending=[True, True]) 293 deg_df_down = deg_df_down.reset_index(drop=True) 294 295 n = -1 296 l = 0 297 while True: 298 n += 1 299 if deg_df_down["log(FC)"][n] < lfc * -1: 300 deg_df_down.loc[n, "top100"] = "yellow" 301 302 l += 1 303 if l == top or deg_df_down["p_val"][n] > p_val: 304 break 305 306 deg_df = pd.concat([deg_df_up, deg_df_down]) 307 308 que = ["lightgray", "red", "blue", "yellow", "green"] 309 310 deg_df = deg_df.sort_values( 311 by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)}) 312 ) 313 314 deg_df = 
deg_df.reset_index(drop=True) 315 316 fig, ax = plt.subplots(figsize=(image_width, image_high)) 317 318 plt.scatter( 319 x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2 320 ) 321 322 plt.plot( 323 [max(deg_df["log(FC)"]) * -1.1, max(deg_df["log(FC)"]) * 1.1], 324 [-np.log10(p_val), -np.log10(p_val)], 325 linestyle="--", 326 linewidth=3, 327 color="lightgray", 328 zorder=1, 329 ) 330 331 if lfc > 0: 332 plt.plot( 333 [lfc * -1, lfc * -1], 334 [-3, max(deg_df[p_val_scale]) * 1.1], 335 linestyle="--", 336 linewidth=3, 337 color="lightgray", 338 zorder=1, 339 ) 340 plt.plot( 341 [lfc, lfc], 342 [-3, max(deg_df[p_val_scale]) * 1.1], 343 linestyle="--", 344 linewidth=3, 345 color="lightgray", 346 zorder=1, 347 ) 348 349 plt.xlabel("log(FC)") 350 plt.ylabel(p_val_scale) 351 plt.title("Volcano plot") 352 353 plt.ylim(min(deg_df[p_val_scale]) - 5, max(deg_df[p_val_scale]) * 1.25) 354 355 texts = [ 356 ax.text(deg_df["log(FC)"][i], deg_df[p_val_scale][i], deg_df["feature"][i]) 357 for i in deg_df.index 358 if deg_df["top100"][i] in ["green", "yellow"] 359 ] 360 361 adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5)) 362 363 legend_elements = [ 364 Line2D( 365 [0], 366 [0], 367 marker="o", 368 color="w", 369 label="top-upregulated", 370 markerfacecolor="green", 371 markersize=10, 372 ), 373 Line2D( 374 [0], 375 [0], 376 marker="o", 377 color="w", 378 label="top-downregulated", 379 markerfacecolor="yellow", 380 markersize=10, 381 ), 382 Line2D( 383 [0], 384 [0], 385 marker="o", 386 color="w", 387 label="upregulated", 388 markerfacecolor="blue", 389 markersize=10, 390 ), 391 Line2D( 392 [0], 393 [0], 394 marker="o", 395 color="w", 396 label="downregulated", 397 markerfacecolor="red", 398 markersize=10, 399 ), 400 Line2D( 401 [0], 402 [0], 403 marker="o", 404 color="w", 405 label="non-significant", 406 markerfacecolor="lightgray", 407 markersize=10, 408 ), 409 ] 410 411 ax.legend(handles=legend_elements, loc="upper right") 412 ax.grid(visible=False) 413 414 ax.annotate( 415 "\nmin p-value = " + str(p_val), 416 xy=(0.025, 0.975), 417 xycoords="axes fraction", 418 fontsize=12, 419 ) 420 421 if lfc > 0: 422 ax.annotate( 423 "\nmin log(FC) = " + str(lfc), 424 xy=(0.025, 0.95), 425 xycoords="axes fraction", 426 fontsize=12, 427 ) 428 429 ax.annotate( 430 "\nDownregulated: " + str(down_int), 431 xy=(0.025, 0.925), 432 xycoords="axes fraction", 433 fontsize=12, 434 color="red", 435 ) 436 437 ax.annotate( 438 "\nUpregulated: " + str(up_int), 439 xy=(0.025, 0.9), 440 xycoords="axes fraction", 441 fontsize=12, 442 color="blue", 443 ) 444 445 plt.show() 446 447 return fig
def find_features(data: pd.DataFrame, features: list):
    """
    Identify features (rows) from a DataFrame that match a given list of features,
    ignoring case sensitivity.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with features in the index (rows).

    features : list
        List of feature names to search for.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of features found in the DataFrame index.
        - "not_included": list of requested features not found in the DataFrame index.
        - "potential": list of features in the DataFrame that may be similar.
    """

    features_upper = [str(x).upper() for x in features]

    index_set = set(data.index)

    features_in = [x for x in index_set if x.upper() in features_upper]
    features_in_upper = [x.upper() for x in features_in]
    features_out_upper = [x for x in features_upper if x not in features_in_upper]
    features_out = [x for x in features if x.upper() in features_out_upper]
    similar_features = [
        idx for idx in index_set if any(x in idx.upper() for x in features_out_upper)
    ]

    return {
        "included": features_in,
        "not_included": features_out,
        "potential": similar_features,
    }
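A small sketch of the matching behavior (toy gene symbols; `pd` comes from the module imports):

df = pd.DataFrame(
    {"cell_1": [1, 0, 3], "cell_2": [0, 2, 1]},
    index=["ACTB", "GAPDH", "CD3E"],
)
hits = find_features(df, features=["actb", "cd3e", "cd19"])
# hits["included"] contains "ACTB" and "CD3E" (case-insensitive match),
# hits["not_included"] is ["cd19"], and hits["potential"] lists index
# entries containing "CD19" as a substring (none here).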
def find_names(data: pd.DataFrame, names: list):
    """
    Identify names (columns) from a DataFrame that match a given list of names,
    ignoring case sensitivity.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with names in the columns.

    names : list
        List of names to search for.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of names found in the DataFrame columns.
        - "not_included": list of requested names not found in the DataFrame columns.
        - "potential": list of names in the DataFrame that may be similar.
    """

    names_upper = [str(x).upper() for x in names]

    columns = set(data.columns)

    names_in = [x for x in columns if x.upper() in names_upper]
    names_in_upper = [x.upper() for x in names_in]
    names_out_upper = [x for x in names_upper if x not in names_in_upper]
    names_out = [x for x in names if x.upper() in names_out_upper]
    similar_names = [
        idx for idx in columns if any(x in idx.upper() for x in names_out_upper)
    ]

    return {"included": names_in, "not_included": names_out, "potential": similar_names}
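The column-wise counterpart behaves the same way, e.g. on the toy `df` from the previous sketch:

cols = find_names(df, names=["CELL_1", "cell_9"])
# cols["included"] -> ["cell_1"]; cols["not_included"] -> ["cell_9"];
# cols["potential"] -> [] (no column contains "CELL_9" as a substring)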
def reduce_data(data: pd.DataFrame, features: list = [], names: list = []):
    """
    Subset a DataFrame based on selected features (rows) and/or names (columns).

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and names as columns.

    features : list
        List of features to include (rows). Default is an empty list.
        If empty, all rows are returned.

    names : list
        List of names to include (columns). Default is an empty list.
        If empty, all columns are returned.

    Returns
    -------
    pandas.DataFrame
        Subset of the input DataFrame containing only the selected rows
        and/or columns.

    Raises
    ------
    ValueError
        If both `features` and `names` are empty.
    """

    if len(features) > 0 and len(names) > 0:
        fet = find_features(data=data, features=features)
        nam = find_names(data=data, names=names)

        data_to_return = data.loc[fet["included"], nam["included"]]

    elif len(features) > 0 and len(names) == 0:
        fet = find_features(data=data, features=features)

        data_to_return = data.loc[fet["included"], :]

    elif len(features) == 0 and len(names) > 0:
        nam = find_names(data=data, names=names)

        data_to_return = data.loc[:, nam["included"]]

    else:
        raise ValueError("features and names have zero length!")

    return data_to_return
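A short sketch (toy labels; matching is case-insensitive because `reduce_data` delegates to `find_features`/`find_names`, and `np`/`pd` come from the module imports):

expr = pd.DataFrame(
    np.arange(12).reshape(3, 4),
    index=["ACTB", "GAPDH", "CD3E"],
    columns=["c1", "c2", "c3", "c4"],
)
by_genes = reduce_data(expr, features=["actb", "cd3e"])             # 2 x 4
by_cells = reduce_data(expr, names=["C1", "C4"])                    # 3 x 2
subset = reduce_data(expr, features=["GAPDH"], names=["c2", "c3"])  # 1 x 2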
def make_unique_list(lst):
    """
    Generate a list where duplicate items are renamed to ensure uniqueness.

    Each duplicate is appended with a suffix ".n", where n indicates the
    occurrence count (starting from 1).

    Parameters
    ----------
    lst : list
        Input list of items (strings or other hashable types).

    Returns
    -------
    list
        List with unique values.

    Examples
    --------
    >>> make_unique_list(["A", "B", "A", "A"])
    ['A', 'B', 'A.1', 'A.2']
    """
    seen = {}
    result = []
    for item in lst:
        if item not in seen:
            seen[item] = 0
            result.append(item)
        else:
            seen[item] += 1
            result.append(f"{item}.{seen[item]}")
    return result
def features_scatter(
    expression_data: pd.DataFrame,
    occurence_data: pd.DataFrame | None = None,
    scale: bool = False,
    features: list | None = None,
    metadata_list: list | None = None,
    colors: str = "viridis",
    hclust: str | None = "complete",
    img_width: int = 8,
    img_high: int = 5,
    label_size: int = 10,
    size_scale: int = 100,
    y_lab: str = "Genes",
    legend_lab: str = "log(CPM + 1)",
    set_box_size: float | int = 5,
    set_box_high: float | int = 5,
    bbox_to_anchor_scale: int = 25,
    bbox_to_anchor_perc: tuple = (0.91, 0.63),
    bbox_to_anchor_group: tuple = (1.01, 0.4),
):
    """
    Create a bubble scatter plot of selected features across samples.

    Each point represents a feature-sample pair, where the color encodes the
    expression value and the size encodes occurrence or relative abundance.
    Optionally, hierarchical clustering can be applied to order rows and columns.

    Parameters
    ----------
    expression_data : pandas.DataFrame
        Expression values (mean) with features as rows and samples as columns,
        derived from the average() function.

    occurence_data : pandas.DataFrame or None
        DataFrame with occurrence/frequency values (same shape as
        `expression_data`) derived from the occurrence() function.
        If None, bubble sizes are based on expression values.

    scale : bool, default=False
        If True, expression_data will be scaled (0–1) across the rows (features).

    features : list or None
        List of features (rows) to display. If None, all features are used.

    metadata_list : list or None, optional
        Metadata grouping for samples (same length as number of columns).
        Used to add group colors and separators in the plot.

    colors : str, default='viridis'
        Colormap for expression values.

    hclust : str or None, default='complete'
        Linkage method for hierarchical clustering. If None, no clustering
        is performed.

    img_width : int or float, default=8
        Width of the plot in inches.

    img_high : int or float, default=5
        Height of the plot in inches.

    label_size : int, default=10
        Font size for axis labels and ticks.

    size_scale : int or float, default=100
        Scaling factor for bubble sizes.

    y_lab : str, default='Genes'
        Label for the y-axis.

    legend_lab : str, default='log(CPM + 1)'
        Label for the colorbar legend.

    set_box_size : float or int, default=5
        Height of the group color boxes drawn above the plot (in 0.1-row units).

    set_box_high : float or int, default=5
        Vertical offset of the group color boxes (in 0.1-row units).

    bbox_to_anchor_scale : int, default=25
        Vertical scale (percentage) for positioning the colorbar.

    bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
        Anchor position for the size legend (percent bubble legend).

    bbox_to_anchor_group : tuple, default=(1.01, 0.4)
        Anchor position for the group legend.

    Returns
    -------
    matplotlib.figure.Figure
        The generated scatter plot figure.

    Raises
    ------
    ValueError
        If `metadata_list` is provided but its length does not match
        the number of columns in `expression_data`.

    Notes
    -----
    - Colors represent expression values normalized to the colormap.
    - Bubble sizes represent occurrence values (or expression values if
      `occurence_data` is None).
    - If `metadata_list` is given, groups are indicated with colors and
      dashed vertical separators.
    """

    scatter_df = expression_data.copy()

    if scale:
        legend_lab = "Scaled\n" + legend_lab

        scaler = MinMaxScaler(feature_range=(0, 1))
        scatter_df = pd.DataFrame(
            scaler.fit_transform(scatter_df.T).T,
            index=scatter_df.index,
            columns=scatter_df.columns,
        )

    metadata = {}
    metadata["primary_names"] = [str(x) for x in scatter_df.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):
            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:
        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)

    if features is not None:
        scatter_df = scatter_df.loc[
            find_features(data=scatter_df, features=features)["included"],
        ]
    scatter_df.columns = metadata["primary_names"] + "#" + metadata["sets"]

    if occurence_data is not None:
        if features is not None:
            occurence_data = occurence_data.loc[
                find_features(data=occurence_data, features=features)["included"],
            ]
        occurence_data.columns = metadata["primary_names"] + "#" + metadata["sets"]

    # check duplicated names
    tmp_columns = scatter_df.columns
    new_cols = make_unique_list(list(tmp_columns))
    scatter_df.columns = new_cols

    if hclust is not None and len(scatter_df.index) != 1:

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_rows = []
        for n in order_of_features:
            sorted_list_rows.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        Z = linkage(scatter_df, method=hclust)

        # Get the order of samples based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_columns = []
        for n in order_of_features:
            sorted_list_columns.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        scatter_df = scatter_df.loc[sorted_list_rows, sorted_list_columns]

        if occurence_data is not None:
            occurence_data = occurence_data.loc[sorted_list_rows, sorted_list_columns]

    metadata["sets"] = [re.sub(".*#", "", x) for x in scatter_df.columns]
    scatter_df.columns = [re.sub("#.*", "", x) for x in scatter_df.columns]

    if occurence_data is not None:
        occurence_data.columns = [re.sub("#.*", "", x) for x in occurence_data.columns]

    fig, ax = plt.subplots(figsize=(img_width, img_high))

    # scalar max over the whole matrix defines the top of the color scale
    norm = plt.Normalize(0, np.max(scatter_df.values))
    cmap = plt.get_cmap(colors)

    # Bubble scatter
    for i, _ in enumerate(scatter_df.index):
        for j, _ in enumerate(scatter_df.columns):
            if occurence_data is not None:
                value_e = scatter_df.iloc[i, j]
                value_o = occurence_data.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value_o * size_scale,
                    c=[cmap(norm(value_e))],
                    edgecolors="k",
                    linewidths=0.3,
                )
            else:
                value = scatter_df.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value * size_scale,
                    c=[cmap(norm(value))],
                    edgecolors="k",
                    linewidths=0.3,
                )

    ax.set_yticks(range(len(scatter_df.index)))
    ax.set_yticklabels(scatter_df.index, fontsize=label_size * 0.8)
    ax.set_ylabel(y_lab, fontsize=label_size)

    ax.set_xticks(range(len(scatter_df.columns)))
    ax.set_xticklabels(scatter_df.columns, fontsize=label_size * 0.8, rotation=90)

    # color bar
    cax = inset_axes(
        ax,
        width="1%",
        height=f"{bbox_to_anchor_scale}%",
        bbox_to_anchor=(1, 0, 1, 1),
        bbox_transform=ax.transAxes,
        loc="upper left",
    )
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb.set_label(legend_lab, fontsize=label_size * 0.65)
    cb.ax.tick_params(labelsize=label_size * 0.7)

    if metadata_list is not None:

        metadata_list = list(metadata["sets"])
        group_colors = get_color_palette(list(set(metadata_list)), palette_name="tab10")

        for i, group in enumerate(metadata_list):
            ax.add_patch(
                plt.Rectangle(
                    (i - 0.5, len(scatter_df.index) - 0.1 * set_box_high),
                    1,
                    0.1 * set_box_size,
                    color=group_colors[group],
                    transform=ax.transData,
                    clip_on=False,
                )
            )

        for i in range(1, len(metadata_list)):
            if metadata_list[i] != metadata_list[i - 1]:
                ax.axvline(i - 0.5, color="black", linestyle="--", lw=1)

        group_patches = [
            mpatches.Patch(color=color, label=label)
            for label, color in group_colors.items()
        ]
        fig.legend(
            handles=group_patches,
            title="Group",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_group,
            frameon=False,
        )

    # second legend (size)
    if occurence_data is not None:
        size_values = [0.25, 0.5, 1]
        legend2_handles = [
            plt.Line2D(
                [],
                [],
                marker="o",
                linestyle="",
                markersize=np.sqrt(v * size_scale * 0.5),
                color="gray",
                alpha=0.6,
                label=f"{v * 100:.1f}",
            )
            for v in size_values
        ]

        fig.legend(
            handles=legend2_handles,
            title="Percent [%]",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_perc,
            frameon=False,
        )

    _, ymax = ax.get_ylim()

    ax.set_xlim(-0.5, len(scatter_df.columns) - 0.5)
    ax.set_ylim(-0.5, ymax + 0.5)

    return fig
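A usage sketch under toy data (gene and cluster names are hypothetical; average() and occurrence(), defined later in this module, build the two inputs; metadata_list is omitted, so the group boxes and the get_color_palette helper are not exercised):

expr = pd.DataFrame(
    np.random.default_rng(0).poisson(2, size=(4, 6)).astype(float),
    index=["ACTB", "GAPDH", "CD3E", "MS4A1"],
    columns=["T", "T", "T", "B", "B", "B"],  # cells named by cluster
)
mean_expr = average(expr)     # per-cluster mean expression
pct_expr = occurrence(expr)   # per-cluster detection rate (0-1)
fig = features_scatter(
    expression_data=mean_expr,
    occurence_data=pct_expr,
    scale=True,
    hclust="complete",
)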
def calc_DEG(
    data,
    metadata_list: list | None = None,
    entities: str | list | dict | None = None,
    sets: str | list | dict | None = None,
    min_exp: int | float = 0,
    min_pct: int | float = 0.1,
    n_proc: int = 10,
):
    """
    Perform differential gene expression (DEG) analysis on gene expression data.

    The function compares groups of cells or samples (defined by `entities` or
    `sets`) using the Mann–Whitney U test. It computes p-values, adjusted
    p-values, fold changes, standardized effect sizes, and other statistics.

    Parameters
    ----------
    data : pandas.DataFrame
        Expression matrix with features (e.g., genes) as rows and samples/cells
        as columns.

    metadata_list : list or None, optional
        Metadata grouping corresponding to the columns in `data`. Required for
        comparisons based on sets. Default is None.

    entities : list, str, dict, or None, optional
        Defines the comparison strategy:
        - list of sample names → compare selected cells to the rest.
        - 'All' → compare each sample/cell to all others.
        - dict → user-defined groups for pairwise comparison.
        - None → must be combined with `sets`.

    sets : str, dict, or None, optional
        Defines group-based comparisons:
        - 'All' → compare each set/group to all others.
        - dict with two groups → perform pairwise set comparison.
        - None → must be combined with `entities`.

    min_exp : float | int, default=0
        Minimum expression threshold for filtering features.

    min_pct : float | int, default=0.1
        Minimum proportion of samples within the target group that must express
        a feature for it to be tested.

    n_proc : int, default=10
        Number of parallel processes to use for statistical testing.

    Returns
    -------
    pandas.DataFrame or dict
        Results of the differential expression analysis:
        - If `entities` is a list → dict with keys: 'valid', 'control',
          and 'DEG' (results DataFrame).
        - If `entities == 'All'` or `sets == 'All'` → DataFrame with results
          for all groups.
        - If pairwise comparison (dict for `entities` or `sets`) → DataFrame
          with results for the specified groups.

        The results DataFrame contains:
        - 'feature': feature name
        - 'p_val': raw p-value
        - 'adj_pval': adjusted p-value (multiple testing correction)
        - 'pct_valid': fraction of target group expressing the feature
        - 'pct_ctrl': fraction of control group expressing the feature
        - 'avg_valid': mean expression in target group
        - 'avg_ctrl': mean expression in control group
        - 'sd_valid': standard deviation in target group
        - 'sd_ctrl': standard deviation in control group
        - 'esm': effect size metric
        - 'FC': fold change
        - 'log(FC)': log2-transformed fold change
        - 'norm_diff': difference in mean expression

    Raises
    ------
    ValueError
        - If `metadata_list` is provided but its length does not match
          the number of columns in `data`.
        - If neither `entities` nor `sets` is provided.

    Notes
    -----
    - Mann–Whitney U test is used for group comparisons.
    - Multiple testing correction is applied using a simple
      Benjamini–Hochberg-like method.
    - Features expressed below `min_exp` or in fewer than `min_pct` of target
      samples are filtered out.
    - Parallelization is handled by `joblib.Parallel`.

    Examples
    --------
    Compare a selected list of cells against all others:

    >>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])

    Compare each group to others (based on metadata):

    >>> result = calc_DEG(data, metadata_list=group_labels, sets="All")

    Perform pairwise comparison between two predefined sets:

    >>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
    >>> result = calc_DEG(data, sets=sets)
    """

    metadata = {}
    metadata["primary_names"] = [str(x) for x in data.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):
            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:
        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)

    def stat_calc(choose, feature_name):
        target_values = choose.loc[choose["DEG"] == "target", feature_name]
        rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

        pct_valid = (target_values > 0).sum() / len(target_values)
        pct_rest = (rest_values > 0).sum() / len(rest_values)

        avg_valid = np.mean(target_values)
        avg_ctrl = np.mean(rest_values)
        sd_valid = np.std(target_values, ddof=1)
        sd_ctrl = np.std(rest_values, ddof=1)
        esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

        if np.sum(target_values) == np.sum(rest_values):
            p_val = 1.0
        else:
            _, p_val = stats.mannwhitneyu(
                target_values, rest_values, alternative="two-sided"
            )

        return {
            "feature": feature_name,
            "p_val": p_val,
            "pct_valid": pct_valid,
            "pct_ctrl": pct_rest,
            "avg_valid": avg_valid,
            "avg_ctrl": avg_ctrl,
            "sd_valid": sd_valid,
            "sd_ctrl": sd_ctrl,
            "esm": esm,
        }

    def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc, factors):

        tmp_dat = choose[choose["DEG"] == "target"]
        tmp_dat = tmp_dat.drop("DEG", axis=1)

        counts = (tmp_dat > min_exp).sum(axis=0)
        total_count = tmp_dat.shape[0]

        info = pd.DataFrame(
            {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
        )

        del tmp_dat

        drop_col = info["feature"][info["pct"] <= min_pct]

        # if the threshold would remove every feature, fall back to dropping
        # only the completely unexpressed ones
        if len(drop_col) + 1 == len(choose.columns):
            drop_col = info["feature"][info["pct"] == 0]

        del info

        choose = choose.drop(list(drop_col), axis=1)

        results = Parallel(n_jobs=n_proc)(
            delayed(stat_calc)(choose, feature)
            for feature in tqdm(choose.columns[choose.columns != "DEG"])
        )

        if len(results) > 0:
            df = pd.DataFrame(results)
            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            factors_df = factors.to_frame(name="factor")

            df = df.merge(factors_df, left_on="feature", right_index=True, how="left")

            df[["avg_valid", "avg_ctrl", "factor"]] = df[
                ["avg_valid", "avg_ctrl", "factor"]
            ].astype(float)

            df["FC"] = (df["avg_valid"] + df["factor"]) / (
                df["avg_ctrl"] + df["factor"]
            )
            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]
            df = df.drop(columns=["factor"])
        else:
            columns = [
                "feature",
                "valid_group",
                "p_val",
                "adj_pval",
                "avg_valid",
                "avg_ctrl",
                "FC",
                "log(FC)",
                "norm_diff",
            ]
            df = pd.DataFrame(columns=columns)
        return df

    choose = data.T

    # per-feature pseudocount: the smallest non-zero value of each feature
    factors = data.copy().replace(0, np.nan).min(axis=1)

    nonzero_factors = factors[factors != 0]
    if len(nonzero_factors) > 0:
        factors[factors == 0] = nonzero_factors.min()
    else:
        factors[factors == 0] = 0.01

    final_results = []

    if isinstance(entities, list) and sets is None:
        print("\nAnalysis started...\nComparing selected cells to the whole set...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

            if "#" not in entities[0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' list. "
                    "Only the names will be compared, without considering the set information."
                )

        labels = ["target" if idx in entities else "rest" for idx in choose.index]
        valid = list(
            set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
        )

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=valid,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )

        return {"valid": valid, "control": "rest", "DEG": result_df}

    elif entities == "All" and sets is None:
        print("\nAnalysis started...\nComparing each type of cell to others...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

        unique_labels = set(choose.index)

        for label in tqdm(unique_labels):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            final_results.append(result_df)

        final_results = pd.concat(final_results, ignore_index=True)

        if metadata_list is None:
            final_results["valid_group"] = [
                re.sub(" # ", "", x) for x in final_results["valid_group"]
            ]

        return final_results

    elif entities is None and sets == "All":
        print("\nAnalysis started...\nComparing each set/group to others...")
        choose.index = metadata["sets"]
        unique_sets = set(choose.index)

        for label in tqdm(unique_sets):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
                factors=factors,
            )
            final_results.append(result_df)

        return pd.concat(final_results, ignore_index=True)

    elif entities is None and isinstance(sets, dict):
        print("\nAnalysis started...\nComparing groups...")
        choose.index = metadata["sets"]

        group_list = list(sets.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in sets[group_list[0]]
                else "rest" if idx in sets[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]
        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )
        return result_df

    elif isinstance(entities, dict) and sets is None:
        print("\nAnalysis started...\nComparing groups...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]
            if "#" not in entities[list(entities.keys())[0]][0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' dict. "
                    "Only the names will be compared, without considering the set information."
                )

        group_list = list(entities.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in entities[group_list[0]]
                else "rest" if idx in entities[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
            factors=factors,
        )

        return result_df.reset_index(drop=True)

    else:
        raise ValueError(
            "You must specify either 'entities' or 'sets'. None were provided, "
            "which is not allowed for this analysis."
        )
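For reference, the adjusted p-values computed in prepare_and_run_stat follow a simple Benjamini–Hochberg-style scaling (p * m / rank over ascending p-values, capped at 1, without the usual monotonicity pass). A toy illustration of the arithmetic, using `pd`/`np` from the module imports:

p_sorted = pd.Series([0.001, 0.010, 0.030, 0.800])  # already ascending
m = len(p_sorted)
adj = np.minimum(1, (p_sorted * m) / np.arange(1, m + 1))
# rank 1: 0.001 * 4 / 1 = 0.004    rank 2: 0.010 * 4 / 2 = 0.020
# rank 3: 0.030 * 4 / 3 = 0.040    rank 4: 0.800 * 4 / 4 = 0.800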
def average(data):
    """
    Compute the column-wise average of a DataFrame, aggregating by column names.

    If multiple columns share the same name, their values are averaged.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Columns with identical names
        will be aggregated by their mean.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input but with unique columns,
        where duplicate columns have been replaced by their mean values.
    """

    wide_data = data

    aggregated_df = wide_data.T.groupby(level=0).mean().T

    return aggregated_df
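A toy sketch of the aggregation (duplicate column names denote cells from the same group):

expr = pd.DataFrame(
    [[1.0, 3.0, 10.0], [0.0, 2.0, 4.0]],
    index=["g1", "g2"],
    columns=["A", "A", "B"],
)
avg = average(expr)
# avg["A"] -> [2.0, 1.0] (mean of the two "A" columns); avg["B"] -> [10.0, 4.0]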
def occurrence(data):
    """
    Calculate the occurrence frequency of features in a DataFrame.

    Converts the input DataFrame to binary (presence/absence) and computes
    the proportion of non-zero entries for each feature, aggregating by
    column names if duplicates exist.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Each column represents a feature.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input, where each value represents
        the proportion of samples in which the feature is present (non-zero).
        Columns with identical names are aggregated.
    """

    binary_data = (data > 0).astype(int)

    counts = binary_data.columns.value_counts()

    binary_data = binary_data.T.groupby(level=0).sum().T
    binary_data = binary_data.astype(float)

    for i in counts.index:
        binary_data.loc[:, i] = (binary_data.loc[:, i] / counts[i]).astype(float)

    return binary_data
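The same toy layout, but for detection rates:

expr = pd.DataFrame(
    [[1.0, 0.0, 5.0], [0.0, 0.0, 2.0]],
    index=["g1", "g2"],
    columns=["A", "A", "B"],
)
occ = occurrence(expr)
# occ["A"] -> [0.5, 0.0]  (g1 is non-zero in 1 of the 2 "A" cells)
# occ["B"] -> [1.0, 1.0]  (both genes non-zero in the single "B" cell)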
def add_subnames(names_list: list, parent_name: str, new_clusters: list):
    """
    Append sub-cluster names to a parent name within a list of names.

    This function replaces occurrences of `parent_name` in `names_list` with
    a concatenation of the parent name and corresponding sub-cluster name
    from `new_clusters` (formatted as "parent.subcluster"). Non-matching names
    are left unchanged.

    Parameters
    ----------
    names_list : list
        Original list of names (e.g., column names or cluster labels).

    parent_name : str
        Name of the parent cluster to which sub-cluster names will be added.
        Must exist in `names_list`.

    new_clusters : list
        List of sub-cluster names. Its length must match the number of times
        `parent_name` occurs in `names_list`.

    Returns
    -------
    list
        Updated list of names with sub-cluster names appended to the parent name.

    Raises
    ------
    ValueError
        - If `parent_name` is not found in `names_list`.
        - If `new_clusters` length does not match the number of occurrences of
          `parent_name`.

    Examples
    --------
    >>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
    ['A.1', 'B', 'A.2']
    """

    if str(parent_name) not in [str(x) for x in names_list]:
        raise ValueError(
            "Parent name is missing from the original dataset's column names!"
        )

    if len(new_clusters) != len([x for x in names_list if str(x) == str(parent_name)]):
        raise ValueError(
            "New cluster names list has a different length than the number of "
            "clusters in the original dataset!"
        )

    new_names = []
    ixn = 0
    for i in names_list:
        if str(i) == str(parent_name):
            new_names.append(f"{parent_name}.{new_clusters[ixn]}")
            ixn += 1
        else:
            new_names.append(i)

    return new_names
def development_clust(
    data: pd.DataFrame, method: str = "ward", img_width: int = 5, img_high: int = 5
):
    """
    Perform hierarchical clustering on the columns of a DataFrame and plot a dendrogram.

    Clusters the transposed data (columns) with the chosen linkage method and
    generates a dendrogram showing the relationships between features or samples.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and samples/columns to be clustered.

    method : str, default='ward'
        Method for hierarchical clustering. Options include:
        - 'ward' : minimizes the variance of clusters being merged.
        - 'single' : uses the minimum of the distances between all observations of the two sets.
        - 'complete' : uses the maximum of the distances between all observations of the two sets.
        - 'average' : uses the average of the distances between all observations of the two sets.

    img_width : int or float, default=5
        Width of the resulting figure in inches.

    img_high : int or float, default=5
        Height of the resulting figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The dendrogram figure.
    """

    z = linkage(data.T, method=method)

    figure, ax = plt.subplots(figsize=(img_width, img_high))

    dendrogram(z, labels=data.columns, orientation="left", ax=ax)

    return figure
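A minimal call (random data and cluster names for illustration only; `np` comes from the module imports):

avg = pd.DataFrame(
    np.random.default_rng(1).normal(size=(50, 4)),
    columns=["cl0", "cl1", "cl2", "cl3"],
)
fig = development_clust(avg, method="ward", img_width=4, img_high=4)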
def adjust_cells_to_group_mean(data, data_avg, beta=0.2):
    """
    Adjust each cell's values towards the mean of its group (centroid).

    This function moves each cell's values in `data` slightly towards the
    corresponding group mean in `data_avg`, controlled by the parameter `beta`.

    Parameters
    ----------
    data : pandas.DataFrame
        Original data with features as rows and cells/samples as columns.

    data_avg : pandas.DataFrame
        DataFrame of group averages (centroids) with features as rows and
        group names as columns.

    beta : float, default=0.2
        Weight for adjustment towards the group mean. 0 = no adjustment,
        1 = fully replaced by the group mean.

    Returns
    -------
    pandas.DataFrame
        Adjusted data with the same shape as the input `data`.
    """

    df_adjusted = data.copy()

    for group_name in data_avg.columns:
        # columns are matched to their centroid by name prefix
        col_idx = [
            i
            for i, c in enumerate(df_adjusted.columns)
            if str(c).startswith(group_name)
        ]
        if not col_idx:
            continue

        centroid = data_avg.loc[df_adjusted.index, group_name].to_numpy()[:, None]

        df_adjusted.iloc[:, col_idx] = (1 - beta) * df_adjusted.iloc[
            :, col_idx
        ].to_numpy() + beta * centroid

    return df_adjusted
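A toy sketch of the shrinkage arithmetic (hypothetical names; because columns are matched by prefix, "T.1" and "T.2" both move toward centroid "T"):

cells = pd.DataFrame(
    {"T.1": [0.0, 4.0], "T.2": [2.0, 0.0], "B.1": [8.0, 1.0]},
    index=["g1", "g2"],
)
centroids = pd.DataFrame({"T": [1.0, 2.0], "B": [8.0, 1.0]}, index=["g1", "g2"])
smoothed = adjust_cells_to_group_mean(cells, centroids, beta=0.5)
# "T.1" for g1: (1 - 0.5) * 0.0 + 0.5 * 1.0 = 0.5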