# jdti.utils — shared utility functions (sparse-data loading, DEG analysis, plotting)
def load_sparse(path: str, name: str):
    """
    Load a 10x-style sparse expression dataset and return it densified.

    The directory at `path` must contain three files in standard
    10x Genomics format:
    - "matrix.mtx"   : expression matrix in Matrix Market format
    - "genes.tsv"    : one gene identifier per line (no header)
    - "barcodes.tsv" : one cell barcode per line (no header)

    Parameters
    ----------
    path : str
        Directory containing the matrix and annotation files.

    name : str
        Dataset label assigned to every cell in the returned metadata.

    Returns
    -------
    data : pandas.DataFrame
        Dense expression matrix with genes as rows and cells as columns.
    metadata : pandas.DataFrame
        Two columns: "cell_names" (the barcodes) and "sets" (the `name`
        label, repeated for every cell).

    Notes
    -----
    The sparse matrix is fully densified; this can require a large amount
    of memory for datasets with many cells and genes.
    """
    matrix = mmread(os.path.join(path, "matrix.mtx"))
    gene_ids = pd.read_csv(os.path.join(path, "genes.tsv"), header=None, sep="\t")
    barcodes = pd.read_csv(os.path.join(path, "barcodes.tsv"), header=None, sep="\t")

    # Densify and attach row/column labels in a single constructor call.
    data = pd.DataFrame(
        matrix.todense(),
        index=list(gene_ids[0]),
        columns=[str(b) for b in barcodes[0]],
    )

    cell_names = list(data.columns)
    metadata = pd.DataFrame(
        {"cell_names": cell_names, "sets": [name] * len(cell_names)}
    )

    return data, metadata
52 """ 53 54 data = mmread(os.path.join(path, "matrix.mtx")) 55 data = pd.DataFrame(data.todense()) 56 genes = pd.read_csv(os.path.join(path, "genes.tsv"), header=None, sep="\t") 57 names = pd.read_csv(os.path.join(path, "barcodes.tsv"), header=None, sep="\t") 58 data.columns = [str(x) for x in names[0]] 59 data.index = list(genes[0]) 60 61 names = list(data.columns) 62 sets = [name] * len(names) 63 64 metadata = pd.DataFrame({"cell_names": names, "sets": sets}) 65 66 return data, metadata 67 68 69def volcano_plot( 70 deg_data: pd.DataFrame, 71 p_adj: bool = True, 72 top: int = 25, 73 top_rank: str = "p_value", 74 p_val: float | int = 0.05, 75 lfc: float | int = 0.25, 76 rescale_adj: bool = True, 77 image_width: int = 12, 78 image_high: int = 12, 79): 80 """ 81 Generate a volcano plot from differential expression results. 82 83 A volcano plot visualizes the relationship between statistical significance 84 (p-values or standarized p-value) and log(fold change) for each gene, highlighting 85 genes that pass significance thresholds. 86 87 Parameters 88 ---------- 89 deg_data : pandas.DataFrame 90 DataFrame containing differential expression results from calc_DEG() function. 91 92 p_adj : bool, default=True 93 If True, use adjusted p-values. If False, use raw p-values. 94 95 top : int, default=25 96 Number of top significant genes to highlight on the plot. 97 98 top_rank : str, default='p_value' 99 Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC'] 100 101 p_val : float | int, default=0.05 102 Significance threshold for p-values (or adjusted p-values). 103 104 lfc : float | int, default=0.25 105 Threshold for absolute log fold change. 106 107 rescale_adj : bool, default=True 108 If True, rescale p-values to avoid long breaks caused by outlier values. 109 110 image_width : int, default=12 111 Width of the generated plot in inches. 112 113 image_high : int, default=12 114 Height of the generated plot in inches. 
115 116 Returns 117 ------- 118 matplotlib.figure.Figure 119 The generated volcano plot figure. 120 121 """ 122 123 if top_rank.upper() not in ["FC", "P_VALUE"]: 124 raise ValueError("top_rank must be either 'FC' or 'p_value'") 125 126 if p_adj: 127 pv = "adj_pval" 128 else: 129 pv = "p_val" 130 131 deg_df = deg_data.copy() 132 133 shift = 0.25 134 135 p_val_scale = "-log(p_val)" 136 137 min_minus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)]) 138 min_plus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)]) 139 140 zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)] 141 zero_p_plus = zero_p_plus.sort_values(by="log(FC)", ascending=False).reset_index( 142 drop=True 143 ) 144 zero_p_plus[pv] = [ 145 (shift * x) * min_plus for x in range(1, len(zero_p_plus.index) + 1) 146 ] 147 148 zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)] 149 zero_p_minus = zero_p_minus.sort_values(by="log(FC)", ascending=True).reset_index( 150 drop=True 151 ) 152 zero_p_minus[pv] = [ 153 (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1) 154 ] 155 156 tmp_p = deg_df[ 157 ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)) 158 | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)) 159 ] 160 161 del deg_df 162 163 deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True) 164 165 deg_df[pv] = deg_df[pv].replace(0, 2**-1074) 166 167 deg_df[p_val_scale] = -np.log10(deg_df[pv]) 168 169 deg_df["top100"] = None 170 171 if rescale_adj: 172 173 deg_df = deg_df.sort_values(by=p_val_scale, ascending=False) 174 175 deg_df = deg_df.reset_index(drop=True) 176 177 eps = 1e-300 178 doubled = [] 179 ratio = [] 180 for n, i in enumerate(deg_df.index): 181 for j in range(1, 6): 182 if ( 183 n + j < len(deg_df.index) 184 and (deg_df[p_val_scale][n] + eps) 185 / (deg_df[p_val_scale][n + j] + eps) 186 >= 2 187 ): 188 doubled.append(n) 189 ratio.append( 190 (deg_df[p_val_scale][n + j] + eps) 191 / (deg_df[p_val_scale][n] + eps) 192 ) 
193 194 df = pd.DataFrame({"doubled": doubled, "ratio": ratio}) 195 df = df[df["doubled"] < 100] 196 197 df["ratio"] = (1 - df["ratio"]) / 5 198 df = df.reset_index(drop=True) 199 200 df = df.sort_values("doubled") 201 202 if len(df["doubled"]) == 1 and 0 in df["doubled"]: 203 df = df 204 else: 205 doubled2 = [] 206 207 for l in df["doubled"]: 208 if l + 1 != len(doubled) and l + 1 - l == 1: 209 doubled2.append(l) 210 doubled2.append(l + 1) 211 else: 212 break 213 214 doubled2 = sorted(set(doubled2), reverse=True) 215 216 if len(doubled2) > 1: 217 df = df[df["doubled"].isin(doubled2)] 218 df = df.sort_values("doubled", ascending=False) 219 df = df.reset_index(drop=True) 220 for c in df.index: 221 deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[ 222 df["doubled"][c] + 1, p_val_scale 223 ] * (1 + df["ratio"][c]) 224 225 deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df[pv] <= p_val), "top100"] = "red" 226 deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df[pv] <= p_val), "top100"] = "blue" 227 deg_df.loc[deg_df[pv] > p_val, "top100"] = "lightgray" 228 229 if lfc > 0: 230 deg_df.loc[ 231 (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100" 232 ] = "lightgray" 233 234 down_int = len( 235 deg_df["top100"][(deg_df["log(FC)"] <= lfc * -1) & (deg_df[pv] <= p_val)] 236 ) 237 up_int = len(deg_df["top100"][(deg_df["log(FC)"] > lfc) & (deg_df[pv] <= p_val)]) 238 239 deg_df_up = deg_df[deg_df["log(FC)"] > 0] 240 241 if top_rank.upper() == "P_VALUE": 242 deg_df_up = deg_df_up.sort_values([pv, "log(FC)"], ascending=[True, False]) 243 elif top_rank.upper() == "FC": 244 deg_df_up = deg_df_up.sort_values(["log(FC)", pv], ascending=[False, True]) 245 246 deg_df_up = deg_df_up.reset_index(drop=True) 247 248 n = -1 249 l = 0 250 while True: 251 n += 1 252 if deg_df_up["log(FC)"][n] > lfc and deg_df_up[pv][n] <= p_val: 253 deg_df_up.loc[n, "top100"] = "green" 254 l += 1 255 if l == top or deg_df_up[pv][n] > p_val: 256 break 257 258 deg_df_down = deg_df[deg_df["log(FC)"] <= 0] 
259 260 if top_rank.upper() == "P_VALUE": 261 deg_df_down = deg_df_down.sort_values([pv, "log(FC)"], ascending=[True, True]) 262 elif top_rank.upper() == "FC": 263 deg_df_down = deg_df_down.sort_values(["log(FC)", pv], ascending=[True, True]) 264 265 deg_df_down = deg_df_down.reset_index(drop=True) 266 267 n = -1 268 l = 0 269 while True: 270 n += 1 271 if deg_df_down["log(FC)"][n] < lfc * -1 and deg_df_down[pv][n] <= p_val: 272 deg_df_down.loc[n, "top100"] = "yellow" 273 274 l += 1 275 if l == top or deg_df_down[pv][n] > p_val: 276 break 277 278 deg_df = pd.concat([deg_df_up, deg_df_down]) 279 280 que = ["lightgray", "red", "blue", "yellow", "green"] 281 282 deg_df = deg_df.sort_values( 283 by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)}) 284 ) 285 286 deg_df = deg_df.reset_index(drop=True) 287 288 fig, ax = plt.subplots(figsize=(image_width, image_high)) 289 290 plt.scatter( 291 x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2 292 ) 293 294 tl = deg_df[p_val_scale][deg_df[pv] >= p_val] 295 296 if len(tl) > 0: 297 298 line_p = np.max(tl) 299 300 else: 301 line_p = np.min(deg_df[p_val_scale]) 302 303 plt.plot( 304 [max(deg_df["log(FC)"]) * -1.1, max(deg_df["log(FC)"]) * 1.1], 305 [line_p, line_p], 306 linestyle="--", 307 linewidth=3, 308 color="lightgray", 309 zorder=1, 310 ) 311 312 if lfc > 0: 313 plt.plot( 314 [lfc * -1, lfc * -1], 315 [-3, max(deg_df[p_val_scale]) * 1.1], 316 linestyle="--", 317 linewidth=3, 318 color="lightgray", 319 zorder=1, 320 ) 321 plt.plot( 322 [lfc, lfc], 323 [-3, max(deg_df[p_val_scale]) * 1.1], 324 linestyle="--", 325 linewidth=3, 326 color="lightgray", 327 zorder=1, 328 ) 329 330 plt.xlabel("log(FC)") 331 plt.ylabel(p_val_scale) 332 plt.title("Volcano plot") 333 334 plt.ylim(min(deg_df[p_val_scale]) - 5, max(deg_df[p_val_scale]) * 1.25) 335 336 texts = [ 337 ax.text(deg_df["log(FC)"][i], deg_df[p_val_scale][i], deg_df["feature"][i]) 338 for i in deg_df.index 339 if deg_df["top100"][i] in 
["green", "yellow"] 340 ] 341 342 adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5)) 343 344 legend_elements = [ 345 Line2D( 346 [0], 347 [0], 348 marker="o", 349 color="w", 350 label="top-upregulated", 351 markerfacecolor="green", 352 markersize=10, 353 ), 354 Line2D( 355 [0], 356 [0], 357 marker="o", 358 color="w", 359 label="top-downregulated", 360 markerfacecolor="yellow", 361 markersize=10, 362 ), 363 Line2D( 364 [0], 365 [0], 366 marker="o", 367 color="w", 368 label="upregulated", 369 markerfacecolor="blue", 370 markersize=10, 371 ), 372 Line2D( 373 [0], 374 [0], 375 marker="o", 376 color="w", 377 label="downregulated", 378 markerfacecolor="red", 379 markersize=10, 380 ), 381 Line2D( 382 [0], 383 [0], 384 marker="o", 385 color="w", 386 label="non-significant", 387 markerfacecolor="lightgray", 388 markersize=10, 389 ), 390 ] 391 392 ax.legend(handles=legend_elements, loc="upper right") 393 ax.grid(visible=False) 394 395 ax.annotate( 396 f"\nmin {pv} = " + str(p_val), 397 xy=(0.025, 0.975), 398 xycoords="axes fraction", 399 fontsize=12, 400 ) 401 402 if lfc > 0: 403 ax.annotate( 404 "\nmin log(FC) = " + str(lfc), 405 xy=(0.025, 0.95), 406 xycoords="axes fraction", 407 fontsize=12, 408 ) 409 410 ax.annotate( 411 "\nDownregulated: " + str(down_int), 412 xy=(0.025, 0.925), 413 xycoords="axes fraction", 414 fontsize=12, 415 color="red", 416 ) 417 418 ax.annotate( 419 "\nUpregulated: " + str(up_int), 420 xy=(0.025, 0.9), 421 xycoords="axes fraction", 422 fontsize=12, 423 color="blue", 424 ) 425 426 plt.show() 427 428 return fig 429 430 431def find_features(data: pd.DataFrame, features: list): 432 """ 433 Identify features (rows) from a DataFrame that match a given list of features, 434 ignoring case sensitivity. 435 436 Parameters 437 ---------- 438 data : pandas.DataFrame 439 DataFrame with features in the index (rows). 440 441 features : list 442 List of feature names to search for. 
def find_features(data: pd.DataFrame, features: list):
    """
    Identify features (rows) of `data` matching a requested list, ignoring case.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame whose index holds the feature names.

    features : list
        Feature names to look up.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": features found in the DataFrame index.
        - "not_included": requested features absent from the index.
        - "potential": index entries containing a missing feature name as a
          substring (possible near-matches).
    """

    wanted = [str(f).upper() for f in features]
    available = set(data.index)

    # Exact (case-insensitive) hits against the index.
    included = [row for row in available if row.upper() in wanted]

    # Requested names with no exact hit.
    matched = [row.upper() for row in included]
    missing_upper = [w for w in wanted if w not in matched]
    not_included = [f for f in features if f.upper() in missing_upper]

    # Index entries that contain a missing name as a substring.
    potential = [
        row for row in available if any(m in row.upper() for m in missing_upper)
    ]

    return {
        "included": included,
        "not_included": not_included,
        "potential": potential,
    }
492 """ 493 494 names_upper = [str(x).upper() for x in names] 495 496 columns = set(data.columns) 497 498 names_in = [x for x in columns if x.upper() in names_upper] 499 names_in_upper = [x.upper() for x in names_in] 500 names_out_upper = [x for x in names_upper if x not in names_in_upper] 501 names_out = [x for x in names if x.upper() in names_out_upper] 502 similar_names = [ 503 idx for idx in columns if any(x in idx.upper() for x in names_out_upper) 504 ] 505 506 return {"included": names_in, "not_included": names_out, "potential": similar_names} 507 508 509def reduce_data(data: pd.DataFrame, features: list = [], names: list = []): 510 """ 511 Subset a DataFrame based on selected features (rows) and/or names (columns). 512 513 Parameters 514 ---------- 515 data : pandas.DataFrame 516 Input DataFrame with features as rows and names as columns. 517 518 features : list 519 List of features to include (rows). Default is an empty list. 520 If empty, all rows are returned. 521 522 names : list 523 List of names to include (columns). Default is an empty list. 524 If empty, all columns are returned. 525 526 Returns 527 ------- 528 pandas.DataFrame 529 Subset of the input DataFrame containing only the selected rows 530 and/or columns. 531 532 Raises 533 ------ 534 ValueError 535 If both `features` and `names` are empty. 
536 """ 537 538 if len(features) > 0 and len(names) > 0: 539 fet = find_features(data=data, features=features) 540 541 nam = find_names(data=data, names=names) 542 543 data_to_return = data.loc[fet["included"], nam["included"]] 544 545 elif len(features) > 0 and len(names) == 0: 546 fet = find_features(data=data, features=features) 547 548 data_to_return = data.loc[fet["included"], :] 549 550 elif len(features) == 0 and len(names) > 0: 551 552 nam = find_names(data=data, names=names) 553 554 data_to_return = data.loc[:, nam["included"]] 555 556 else: 557 558 raise ValueError("features and names have zero length!") 559 560 return data_to_return 561 562 563def make_unique_list(lst): 564 """ 565 Generate a list where duplicate items are renamed to ensure uniqueness. 566 567 Each duplicate is appended with a suffix ".n", where n indicates the 568 occurrence count (starting from 1). 569 570 Parameters 571 ---------- 572 lst : list 573 Input list of items (strings or other hashable types). 574 575 Returns 576 ------- 577 list 578 List with unique values. 
def make_unique_list(lst):
    """
    Generate a list where duplicate items are renamed to ensure uniqueness.

    The first occurrence of an item is kept as-is; every subsequent duplicate
    gets a ".n" suffix, where n counts the repeats (starting from 1).

    Parameters
    ----------
    lst : list
        Input list of items (strings or other hashable types).

    Returns
    -------
    list
        List with unique values.

    Examples
    --------
    >>> make_unique_list(["A", "B", "A", "A"])
    ['A', 'B', 'A.1', 'A.2']
    """
    counts = {}
    unique_items = []
    for item in lst:
        previous = counts.get(item)
        if previous is None:
            # First time we see this item: keep it unchanged.
            counts[item] = 0
            unique_items.append(item)
        else:
            # Repeat: bump the counter and append a ".n" suffix.
            counts[item] = previous + 1
            unique_items.append(f"{item}.{previous + 1}")
    return unique_items


def get_color_palette(variable_list, palette_name="tab10"):
    """
    Map each entry of `variable_list` to a color from a matplotlib colormap.

    Colors are taken in order from the named colormap and cycled when there
    are more variables than colormap entries.
    """
    cmap = plt.get_cmap(palette_name)
    palette = {var: cmap(i % cmap.N) for i, var in enumerate(variable_list)}
    return palette
def features_scatter(
    expression_data: pd.DataFrame,
    occurence_data: pd.DataFrame | None = None,
    scale: bool = False,
    features: list | None = None,
    metadata_list: list | None = None,
    colors: str = "viridis",
    hclust: str | None = "complete",
    img_width: int = 8,
    img_high: int = 5,
    label_size: int = 10,
    size_scale: int = 100,
    y_lab: str = "Genes",
    legend_lab: str = "log(CPM + 1)",
    set_box_size: float | int = 5,
    set_box_high: float | int = 5,
    bbox_to_anchor_scale: int = 25,
    bbox_to_anchor_perc: tuple = (0.91, 0.63),
    bbox_to_anchor_group: tuple = (1.01, 0.4),
):
    """
    Create a bubble scatter plot of selected features across samples.

    Each point represents a feature-sample pair, where the color encodes the
    expression value and the size encodes occurrence or relative abundance.
    Optionally, hierarchical clustering can be applied to order rows and columns.

    Parameters
    ----------
    expression_data : pandas.DataFrame
        Expression values (mean) with features as rows and samples as columns
        derived from average() function.

    occurence_data : pandas.DataFrame or None
        DataFrame with occurrence/frequency values (same shape as
        `expression_data`) derived from occurrence() function.
        If None, bubble sizes are based on expression values.

    scale : bool, default=False
        If True, expression_data (features) will be scaled (0-1) across the
        columns (sample) by dividing each column by its maximum.

    features : list or None
        List of features (rows) to display. If None, all features are used.

    metadata_list : list or None, optional
        Metadata grouping for samples (same length as number of columns).
        Used to add group colors and separators in the plot.

    colors : str, default='viridis'
        Colormap for expression values.

    hclust : str or None, default='complete'
        Linkage method for hierarchical clustering. If None, no clustering
        is performed.

    img_width : int or float, default=8
        Width of the plot in inches.

    img_high : int or float, default=5
        Height of the plot in inches.

    label_size : int, default=10
        Font size for axis labels and ticks.

    size_scale : int or float, default=100
        Scaling factor for bubble sizes.

    y_lab : str, default='Genes'
        Label for the y-axis.

    legend_lab : str, default='log(CPM + 1)'
        Label for the colorbar legend.

    set_box_size : float | int, default=5
        Width scale of the per-sample group-color boxes drawn above the plot.

    set_box_high : float | int, default=5
        Vertical offset scale of the group-color boxes.

    bbox_to_anchor_scale : int, default=25
        Vertical scale (percentage) for positioning the colorbar.

    bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
        Anchor position for the size legend (percent bubble legend).

    bbox_to_anchor_group : tuple, default=(1.01, 0.4)
        Anchor position for the group legend.

    Returns
    -------
    matplotlib.figure.Figure
        The generated scatter plot figure.

    Raises
    ------
    ValueError
        If `metadata_list` is provided but its length does not match
        the number of columns in `expression_data`.

    Notes
    -----
    - Colors represent expression values normalized to the colormap.
    - Bubble sizes represent occurrence values (or expression values if
      `occurence_data` is None).
    - If `metadata_list` is given, groups are indicated with colors and
      dashed vertical separators.
    """

    scatter_df = expression_data.copy()

    if scale:

        legend_lab = "Scaled\n" + legend_lab

        # Per-column max scaling; inf (from zero maxima) becomes 0.
        column_max = scatter_df.max()
        scatter_df = scatter_df.div(column_max).replace([np.inf, -np.inf], np.nan).fillna(0)
        # NOTE(review): this re-wrap is a no-op (same data, index, columns).
        scatter_df = pd.DataFrame(scatter_df, index=scatter_df.index, columns=scatter_df.columns)


    metadata = {}

    metadata["primary_names"] = [str(x) for x in scatter_df.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)
    if features is not None:
        # Case-insensitive row subset via find_features().
        scatter_df = scatter_df.loc[
            find_features(data=scatter_df, features=features)["included"],
        ]
    # Encode the group into the column label ("name#set") so it survives the
    # clustering-based reordering below.
    scatter_df.columns = metadata["primary_names"] + "#" + metadata["sets"]

    if occurence_data is not None:
        if features is not None:
            occurence_data = occurence_data.loc[
                find_features(data=occurence_data, features=features)["included"],
            ]
        # NOTE(review): when `features` is None this renames the caller's
        # occurence_data in place — confirm that side effect is acceptable.
        occurence_data.columns = metadata["primary_names"] + "#" + metadata["sets"]

    # check duplicated names

    tmp_columns = scatter_df.columns

    new_cols = make_unique_list(list(tmp_columns))

    scatter_df.columns = new_cols

    if hclust is not None and len(expression_data.index) != 1:

        # Row order from hierarchical clustering of features.
        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_rows = []
        for n in order_of_features:
            sorted_list_rows.append(indexes_sort[n])

        # Column order: cluster the transposed matrix the same way.
        scatter_df = scatter_df.transpose()

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_columns = []
        for n in order_of_features:
            sorted_list_columns.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        scatter_df = scatter_df.loc[sorted_list_rows, sorted_list_columns]

        if occurence_data is not None:
            occurence_data = occurence_data.loc[sorted_list_rows, sorted_list_columns]

    # Split the "name#set" labels back into display names and group labels.
    metadata["sets"] = [re.sub(".*#", "", x) for x in scatter_df.columns]

    scatter_df.columns = [re.sub("#.*", "", x) for x in scatter_df.columns]

    if occurence_data is not None:
        occurence_data.columns = [re.sub("#.*", "", x) for x in occurence_data.columns]

    fig, ax = plt.subplots(figsize=(img_width, img_high))

    norm = plt.Normalize(0, np.max(scatter_df))

    cmap = plt.get_cmap(colors)

    # Bubble scatter
    for i, _ in enumerate(scatter_df.index):
        for j, _ in enumerate(scatter_df.columns):
            if occurence_data is not None:
                # Color from expression, size from occurrence.
                value_e = scatter_df.iloc[i, j]
                value_o = occurence_data.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value_o * size_scale,
                    c=[cmap(norm(value_e))],
                    edgecolors="k",
                    linewidths=0.3,
                )
            else:
                # Both color and size from expression.
                value = scatter_df.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value * size_scale,
                    c=[cmap(norm(value))],
                    edgecolors="k",
                    linewidths=0.3,
                )

    ax.set_yticks(range(len(scatter_df.index)))
    ax.set_yticklabels(scatter_df.index, fontsize=label_size * 0.8)
    ax.set_ylabel(y_lab, fontsize=label_size)
    ax.set_xticks(range(len(scatter_df.columns)))
    ax.set_xticklabels(scatter_df.columns, fontsize=label_size * 0.8, rotation=90)

    # Manually placed colorbar axis, anchored to the top-right of the plot.
    ax_pos = ax.get_position()

    width_fig = 0.01
    height_fig = ax_pos.height * (bbox_to_anchor_scale / 100)
    left_fig = ax_pos.x1 + 0.01
    bottom_fig = ax_pos.y1 - height_fig

    cax = fig.add_axes([left_fig, bottom_fig, width_fig, height_fig])
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb.set_label(legend_lab, fontsize=label_size * 0.65)
    cb.ax.tick_params(labelsize=label_size * 0.7)

    if metadata_list is not None:

        # Use the (possibly reordered) group labels recovered above.
        metadata_list = list(metadata["sets"])
        group_colors = get_color_palette(list(set(metadata_list)), palette_name="tab10")

        # One colored box per sample, drawn just above the top row.
        for i, group in enumerate(metadata_list):
            ax.add_patch(
                plt.Rectangle(
                    (i - 0.5, len(scatter_df.index) - 0.1 * set_box_high),
                    1,
                    0.1 * set_box_size,
                    color=group_colors[group],
                    transform=ax.transData,
                    clip_on=False,
                )
            )

        # Dashed separator wherever adjacent samples change group.
        for i in range(1, len(metadata_list)):
            if metadata_list[i] != metadata_list[i - 1]:
                ax.axvline(i - 0.5, color="black", linestyle="--", lw=1)

        group_patches = [
            mpatches.Patch(color=color, label=label)
            for label, color in group_colors.items()
        ]
        fig.legend(
            handles=group_patches,
            title="Group",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_group,
            frameon=False,
        )

    # second legend (size)
    if occurence_data is not None:
        size_values = [0.25, 0.5, 1]
        legend2_handles = [
            plt.Line2D(
                [],
                [],
                marker="o",
                linestyle="",
                markersize=np.sqrt(v * size_scale * 0.5),
                color="gray",
                alpha=0.6,
                label=f"{v * 100:.1f}",
            )
            for v in size_values
        ]

        fig.legend(
            handles=legend2_handles,
            title="Percent [%]",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_perc,
            frameon=False,
        )

    _, ymax = ax.get_ylim()

    ax.set_xlim(-0.5, len(scatter_df.columns) - 0.5)
    ax.set_ylim(-0.5, ymax + 0.5)

    return fig
min_pct: int | float = 0.1, 918 n_proc: int = 10, 919): 920 """ 921 Perform differential gene expression (DEG) analysis on gene expression data. 922 923 The function compares groups of cells or samples (defined by `entities` or 924 `sets`) using the Mann–Whitney U test. It computes p-values, adjusted 925 p-values, fold changes, standardized effect sizes, and other statistics. 926 927 Parameters 928 ---------- 929 data : pandas.DataFrame 930 Expression matrix with features (e.g., genes) as rows and samples/cells 931 as columns. 932 933 metadata_list : list or None, optional 934 Metadata grouping corresponding to the columns in `data`. Required for 935 comparisons based on sets. Default is None. 936 937 entities : list, str, dict, or None, optional 938 Defines the comparison strategy: 939 - list of sample names → compare selected cells to the rest. 940 - 'All' → compare each sample/cell to all others. 941 - dict → user-defined groups for pairwise comparison. 942 - None → must be combined with `sets`. 943 944 sets : str, dict, or None, optional 945 Defines group-based comparisons: 946 - 'All' → compare each set/group to all others. 947 - dict with two groups → perform pairwise set comparison. 948 - None → must be combined with `entities`. 949 950 min_exp : float | int, default=0 951 Minimum expression threshold for filtering features. 952 953 min_pct : float | int, default=0.1 954 Minimum proportion of samples within the target group that must express 955 a feature for it to be tested. 956 957 n_proc : int, default=10 958 Number of parallel processes to use for statistical testing. 959 960 Returns 961 ------- 962 pandas.DataFrame or dict 963 Results of the differential expression analysis: 964 - If `entities` is a list → dict with keys: 'valid_cells', 965 'control_cells', and 'DEG' (results DataFrame). 966 - If `entities == 'All'` or `sets == 'All'` → DataFrame with results 967 for all groups. 
968 - If pairwise comparison (dict for `entities` or `sets`) → DataFrame 969 with results for the specified groups. 970 971 The results DataFrame contains: 972 - 'feature': feature name 973 - 'p_val': raw p-value 974 - 'adj_pval': adjusted p-value (multiple testing correction) 975 - 'pct_valid': fraction of target group expressing the feature 976 - 'pct_ctrl': fraction of control group expressing the feature 977 - 'avg_valid': mean expression in target group 978 - 'avg_ctrl': mean expression in control group 979 - 'sd_valid': standard deviation in target group 980 - 'sd_ctrl': standard deviation in control group 981 - 'esm': effect size metric 982 - 'FC': fold change 983 - 'log(FC)': log2-transformed fold change 984 - 'norm_diff': difference in mean expression 985 986 Raises 987 ------ 988 ValueError 989 - If `metadata_list` is provided but its length does not match 990 the number of columns in `data`. 991 - If neither `entities` nor `sets` is provided. 992 993 Notes 994 ----- 995 - Mann–Whitney U test is used for group comparisons. 996 - Multiple testing correction is applied using a simple 997 Benjamini–Hochberg-like method. 998 - Features expressed below `min_exp` or in fewer than `min_pct` of target 999 samples are filtered out. 1000 - Parallelization is handled by `joblib.Parallel`. 
1001 1002 Examples 1003 -------- 1004 Compare a selected list of cells against all others: 1005 1006 >>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"]) 1007 1008 Compare each group to others (based on metadata): 1009 1010 >>> result = calc_DEG(data, metadata_list=group_labels, sets="All") 1011 1012 Perform pairwise comparison between two predefined sets: 1013 1014 >>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]} 1015 >>> result = calc_DEG(data, sets=sets) 1016 """ 1017 offset = 1e-100 1018 1019 metadata = {} 1020 1021 metadata["primary_names"] = [str(x) for x in data.columns] 1022 1023 if metadata_list is not None: 1024 metadata["sets"] = metadata_list 1025 1026 if len(metadata["primary_names"]) != len(metadata["sets"]): 1027 1028 raise ValueError( 1029 "Metadata list and DataFrame columns must have the same length." 1030 ) 1031 1032 else: 1033 1034 metadata["sets"] = [""] * len(metadata["primary_names"]) 1035 1036 metadata = pd.DataFrame(metadata) 1037 1038 def stat_calc(choose, feature_name): 1039 target_values = choose.loc[choose["DEG"] == "target", feature_name] 1040 rest_values = choose.loc[choose["DEG"] == "rest", feature_name] 1041 1042 pct_valid = (target_values > 0).sum() / len(target_values) 1043 pct_rest = (rest_values > 0).sum() / len(rest_values) 1044 1045 avg_valid = np.mean(target_values) 1046 avg_ctrl = np.mean(rest_values) 1047 sd_valid = np.std(target_values, ddof=1) 1048 sd_ctrl = np.std(rest_values, ddof=1) 1049 esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2)) 1050 1051 if np.sum(target_values) == np.sum(rest_values): 1052 p_val = 1.0 1053 else: 1054 _, p_val = stats.mannwhitneyu( 1055 target_values, rest_values, alternative="two-sided" 1056 ) 1057 1058 return { 1059 "feature": feature_name, 1060 "p_val": p_val, 1061 "pct_valid": pct_valid, 1062 "pct_ctrl": pct_rest, 1063 "avg_valid": avg_valid, 1064 "avg_ctrl": avg_ctrl, 1065 "sd_valid": sd_valid, 1066 "sd_ctrl": sd_ctrl, 1067 "esm": esm, 1068 } 
1069 1070 def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc): 1071 1072 def safe_min_half(series): 1073 filtered = series[(series > ((2**-1074)*2)) & (series.notna())] 1074 return filtered.min() / 2 if not filtered.empty else 0 1075 1076 tmp_dat = choose[choose["DEG"] == "target"] 1077 tmp_dat = tmp_dat.drop("DEG", axis=1) 1078 1079 counts = (tmp_dat > min_exp).sum(axis=0) 1080 1081 total_count = tmp_dat.shape[0] 1082 1083 info = pd.DataFrame( 1084 {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)} 1085 ) 1086 1087 del tmp_dat 1088 1089 drop_col = info["feature"][info["pct"] <= min_pct] 1090 1091 if len(drop_col) + 1 == len(choose.columns): 1092 drop_col = info["feature"][info["pct"] == 0] 1093 1094 del info 1095 1096 choose = choose.drop(list(drop_col), axis=1) 1097 1098 results = Parallel(n_jobs=n_proc)( 1099 delayed(stat_calc)(choose, feature) 1100 for feature in tqdm(choose.columns[choose.columns != "DEG"]) 1101 ) 1102 1103 if len(results) > 0: 1104 df = pd.DataFrame(results) 1105 1106 df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)] 1107 1108 df["valid_group"] = valid_group 1109 df.sort_values(by="p_val", inplace=True) 1110 1111 num_tests = len(df) 1112 df["adj_pval"] = np.minimum( 1113 1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1) 1114 ) 1115 1116 valid_factor = safe_min_half(df["avg_valid"]) 1117 ctrl_factor = safe_min_half(df["avg_ctrl"]) 1118 1119 cv_factor = min(valid_factor, ctrl_factor) 1120 1121 if cv_factor == 0: 1122 cv_factor = max(valid_factor, ctrl_factor) 1123 1124 if not np.isfinite(cv_factor) or cv_factor == 0: 1125 cv_factor += offset 1126 1127 valid = df["avg_valid"].where( 1128 df["avg_valid"] != 0, df["avg_valid"] + cv_factor 1129 ) 1130 ctrl = df["avg_ctrl"].where( 1131 df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor 1132 ) 1133 1134 df["FC"] = valid / ctrl 1135 1136 df["log(FC)"] = np.log2(df["FC"]) 1137 df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"] 1138 1139 else: 1140 
columns = [ 1141 "feature", 1142 "valid_group", 1143 "p_val", 1144 "adj_pval", 1145 "avg_valid", 1146 "avg_ctrl", 1147 "FC", 1148 "log(FC)", 1149 "norm_diff", 1150 ] 1151 df = pd.DataFrame(columns=columns) 1152 return df 1153 1154 choose = data.T 1155 1156 final_results = [] 1157 1158 if isinstance(entities, list) and sets is None: 1159 print("\nAnalysis started...\nComparing selected cells to the whole set...") 1160 1161 if metadata_list is None: 1162 choose.index = metadata["primary_names"] 1163 else: 1164 choose.index = metadata["primary_names"] + " # " + metadata["sets"] 1165 1166 if "#" not in entities[0]: 1167 choose.index = metadata["primary_names"] 1168 print( 1169 "You provided 'metadata_list', but did not include the set info (name # set) " 1170 "in the 'entities' list. " 1171 "Only the names will be compared, without considering the set information." 1172 ) 1173 1174 labels = ["target" if idx in entities else "rest" for idx in choose.index] 1175 valid = list( 1176 set(choose.index[[i for i, x in enumerate(labels) if x == "target"]]) 1177 ) 1178 1179 choose["DEG"] = labels 1180 choose = choose[choose["DEG"] != "drop"] 1181 1182 result_df = prepare_and_run_stat( 1183 choose.reset_index(drop=True), 1184 valid_group=valid, 1185 min_exp=min_exp, 1186 min_pct=min_pct, 1187 n_proc=n_proc, 1188 ) 1189 1190 return {"valid": valid, "control": "rest", "DEG": result_df} 1191 1192 elif entities == "All" and sets is None: 1193 print("\nAnalysis started...\nComparing each type of cell to others...") 1194 1195 if metadata_list is None: 1196 choose.index = metadata["primary_names"] 1197 else: 1198 choose.index = metadata["primary_names"] + " # " + metadata["sets"] 1199 1200 unique_labels = set(choose.index) 1201 1202 for label in tqdm(unique_labels): 1203 print(f"\nCalculating statistics for {label}") 1204 labels = ["target" if idx == label else "rest" for idx in choose.index] 1205 choose["DEG"] = labels 1206 choose = choose[choose["DEG"] != "drop"] 1207 result_df = 
prepare_and_run_stat( 1208 choose.copy(), 1209 valid_group=label, 1210 min_exp=min_exp, 1211 min_pct=min_pct, 1212 n_proc=n_proc, 1213 ) 1214 final_results.append(result_df) 1215 1216 final_results = pd.concat(final_results, ignore_index=True) 1217 1218 if metadata_list is None: 1219 final_results["valid_group"] = [ 1220 re.sub(" # ", "", x) for x in final_results["valid_group"] 1221 ] 1222 1223 return final_results 1224 1225 elif entities is None and sets == "All": 1226 print("\nAnalysis started...\nComparing each set/group to others...") 1227 choose.index = metadata["sets"] 1228 unique_sets = set(choose.index) 1229 1230 for label in tqdm(unique_sets): 1231 print(f"\nCalculating statistics for {label}") 1232 labels = ["target" if idx == label else "rest" for idx in choose.index] 1233 1234 choose["DEG"] = labels 1235 choose = choose[choose["DEG"] != "drop"] 1236 result_df = prepare_and_run_stat( 1237 choose.copy(), 1238 valid_group=label, 1239 min_exp=min_exp, 1240 min_pct=min_pct, 1241 n_proc=n_proc, 1242 ) 1243 final_results.append(result_df) 1244 1245 return pd.concat(final_results, ignore_index=True) 1246 1247 elif entities is None and isinstance(sets, dict): 1248 print("\nAnalysis started...\nComparing groups...") 1249 choose.index = metadata["sets"] 1250 1251 group_list = list(sets.keys()) 1252 if len(group_list) != 2: 1253 print("Only pairwise group comparison is supported.") 1254 return None 1255 1256 labels = [ 1257 ( 1258 "target" 1259 if idx in sets[group_list[0]] 1260 else "rest" if idx in sets[group_list[1]] else "drop" 1261 ) 1262 for idx in choose.index 1263 ] 1264 choose["DEG"] = labels 1265 choose = choose[choose["DEG"] != "drop"] 1266 1267 result_df = prepare_and_run_stat( 1268 choose.reset_index(drop=True), 1269 valid_group=group_list[0], 1270 min_exp=min_exp, 1271 min_pct=min_pct, 1272 n_proc=n_proc, 1273 ) 1274 return result_df 1275 1276 elif isinstance(entities, dict) and sets is None: 1277 print("\nAnalysis started...\nComparing groups...") 
        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]
            # The entity dict entries are expected as "name # set" when a
            # metadata_list was given; if the first entry lacks the marker,
            # fall back to bare names and warn the user.
            if "#" not in entities[list(entities.keys())[0]][0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' dict. "
                    "Only the names will be compared, without considering the set information."
                )

        group_list = list(entities.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        # Label each cell: first group -> "target", second group -> "rest",
        # anything else -> "drop" (removed just below).
        labels = [
            (
                "target"
                if idx in entities[group_list[0]]
                else "rest" if idx in entities[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        return result_df.reset_index(drop=True)

    else:
        raise ValueError(
            "You must specify either 'entities' or 'sets'. None were provided, which is not allowed for this analysis."
        )


def average(data):
    """
    Compute the column-wise average of a DataFrame, aggregating by column names.

    If multiple columns share the same name, their values are averaged.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Columns with identical names
        will be aggregated by their mean.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input but with unique columns,
        where duplicate columns have been replaced by their mean values.
    """

    wide_data = data

    # Transpose so duplicate column names become index labels, group on the
    # label (level=0), average, and transpose back.
    aggregated_df = wide_data.T.groupby(level=0).mean().T

    return aggregated_df


def occurrence(data):
    """
    Calculate the occurrence frequency of features in a DataFrame.

    Converts the input DataFrame to binary (presence/absence) and computes
    the proportion of non-zero entries for each feature, aggregating by
    column names if duplicates exist.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Each column represents a feature.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input, where each value represents
        the proportion of samples in which the feature is present (non-zero).
        Columns with identical names are aggregated.
    """

    # Presence/absence matrix: 1 where the value is strictly positive.
    binary_data = (data > 0).astype(int)

    # How many columns share each name (group sizes for the division below).
    counts = binary_data.columns.value_counts()

    # Sum presence counts per duplicated column name.
    binary_data = binary_data.T.groupby(level=0).sum().T
    binary_data = binary_data.astype(float)

    # Convert counts into proportions by dividing by the group size.
    for i in counts.index:
        binary_data.loc[:, i] = (binary_data.loc[:, i] / counts[i]).astype(float)

    return binary_data


def add_subnames(names_list: list, parent_name: str, new_clusters: list):
    """
    Append sub-cluster names to a parent name within a list of names.

    This function replaces occurrences of `parent_name` in `names_list` with
    a concatenation of the parent name and corresponding sub-cluster name
    from `new_clusters` (formatted as "parent.subcluster"). Non-matching names
    are left unchanged.

    Parameters
    ----------
    names_list : list
        Original list of names (e.g., column names or cluster labels).

    parent_name : str
        Name of the parent cluster to which sub-cluster names will be added.
        Must exist in `names_list`.

    new_clusters : list
        List of sub-cluster names. Its length must match the number of times
        `parent_name` occurs in `names_list`.

    Returns
    -------
    list
        Updated list of names with sub-cluster names appended to the parent name.

    Raises
    ------
    ValueError
        - If `parent_name` is not found in `names_list`.
        - If `new_clusters` length does not match the number of occurrences of
          `parent_name`.

    Examples
    --------
    >>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
    ['A.1', 'B', 'A.2']
    """

    # Comparison is done on str() of every entry so numeric labels also match.
    # NOTE(review): the message below contains a backtick typo ("dataset`s");
    # left untouched here because it is a runtime string.
    if str(parent_name) not in [str(x) for x in names_list]:
        raise ValueError(
            "Parent name is missing from the original dataset`s column names!"
        )

    if len(new_clusters) != len([x for x in names_list if str(x) == str(parent_name)]):
        raise ValueError(
            "New cluster names list has a different length than the number of clusters in the original dataset!"
        )

    new_names = []
    ixn = 0  # index of the next unused sub-cluster name
    for _, i in enumerate(names_list):
        if str(i) == str(parent_name):

            new_names.append(f"{parent_name}.{new_clusters[ixn]}")
            ixn += 1

        else:
            new_names.append(i)

    return new_names


def development_clust(
    data: pd.DataFrame, method: str = "ward", img_width: int = 5, img_high: int = 5
):
    """
    Perform hierarchical clustering on the columns of a DataFrame and plot a dendrogram.

    Uses Ward's method to cluster the transposed data (columns) and generates
    a dendrogram showing the relationships between features or samples.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and samples/columns to be clustered.

    method : str
        Method for hierarchical clustering. Options include:
        - 'ward' : minimizes the variance of clusters being merged.
        - 'single' : uses the minimum of the distances between all observations of the two sets.
        - 'complete' : uses the maximum of the distances between all observations of the two sets.
        - 'average' : uses the average of the distances between all observations of the two sets.

    img_width : int or float, default=5
        Width of the resulting figure in inches.

    img_high : int or float, default=5
        Height of the resulting figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The dendrogram figure.
    """

    # Transpose so columns (samples) are the clustered observations.
    z = linkage(data.T, method=method)

    figure, ax = plt.subplots(figsize=(img_width, img_high))

    dendrogram(z, labels=data.columns, orientation="left", ax=ax)

    return figure


def adjust_cells_to_group_mean(data, data_avg, beta=0.2):
    """
    Adjust each cell's values towards the mean of its group (centroid).

    This function moves each cell's values in `data` slightly towards the
    corresponding group mean in `data_avg`, controlled by the parameter `beta`.

    Parameters
    ----------
    data : pandas.DataFrame
        Original data with features as rows and cells/samples as columns.

    data_avg : pandas.DataFrame
        DataFrame of group averages (centroids) with features as rows and
        group names as columns.

    beta : float, default=0.2
        Weight for adjustment towards the group mean. 0 = no adjustment,
        1 = fully replaced by the group mean.

    Returns
    -------
    pandas.DataFrame
        Adjusted data with the same shape as the input `data`.
    """

    df_adjusted = data.copy()

    for group_name in data_avg.columns:
        # Cells are assigned to a group by column-name prefix.
        # NOTE(review): startswith() also matches groups whose name is a
        # prefix of another group's name (e.g. "T" vs "Treg") — confirm intended.
        col_idx = [
            i
            for i, c in enumerate(df_adjusted.columns)
            if str(c).startswith(group_name)
        ]
        if not col_idx:
            continue

        # Group centroid aligned to the data's row order, as a column vector
        # so it broadcasts across the selected cells.
        centroid = data_avg.loc[df_adjusted.index, group_name].to_numpy()[:, None]

        # Convex blend: (1 - beta) * cell + beta * centroid.
        df_adjusted.iloc[:, col_idx] = (1 - beta) * df_adjusted.iloc[
            :, col_idx
        ].to_numpy() + beta * centroid

    return df_adjusted
def load_sparse(path: str, name: str):
    """
    Load a 10x Genomics-style sparse dataset and densify it.

    The directory at `path` must contain three files:
    - "matrix.mtx": gene expression matrix in Matrix Market format
    - "genes.tsv": tab-separated gene identifiers
    - "barcodes.tsv": tab-separated cell barcodes / names

    Parameters
    ----------
    path : str
        Directory containing the matrix and annotation files.

    name : str
        Dataset label assigned to every cell in the returned metadata.

    Returns
    -------
    data : pandas.DataFrame
        Dense expression matrix (genes as rows, cells as columns).
    metadata : pandas.DataFrame
        Two columns: "cell_names" (the barcodes) and "sets" (the `name`
        label, repeated for each cell).

    Notes
    -----
    The sparse matrix is converted to a dense DataFrame, which can be
    memory-hungry for large datasets.
    """

    matrix = mmread(os.path.join(path, "matrix.mtx"))
    gene_table = pd.read_csv(os.path.join(path, "genes.tsv"), header=None, sep="\t")
    barcode_table = pd.read_csv(
        os.path.join(path, "barcodes.tsv"), header=None, sep="\t"
    )

    # Densify and attach row/column annotations from the sidecar files.
    data = pd.DataFrame(matrix.todense())
    data.columns = [str(barcode) for barcode in barcode_table[0]]
    data.index = list(gene_table[0])

    cell_names = list(data.columns)
    metadata = pd.DataFrame(
        {"cell_names": cell_names, "sets": [name] * len(cell_names)}
    )

    return data, metadata
Load a sparse matrix dataset along with associated gene and cell metadata, and return it as a dense DataFrame.
This function expects the input directory to contain three files in standard 10x Genomics format:
- "matrix.mtx": the gene expression matrix in Matrix Market format
- "genes.tsv": tab-separated file containing gene identifiers
- "barcodes.tsv": tab-separated file containing cell barcodes / names
Parameters
path : str Path to the directory containing the matrix and annotation files.
name : str Label or dataset identifier to be assigned to all cells in the metadata.
Returns
data : pandas.DataFrame
Dense expression matrix where rows correspond to genes and columns
correspond to cells.
metadata : pandas.DataFrame
Metadata DataFrame with two columns:
- "cell_names": the names of the cells (barcodes)
- "sets": the dataset label assigned to each cell (from name)
Notes
The function converts the sparse matrix into a dense DataFrame. This may require a large amount of memory for datasets with many cells and genes.
def volcano_plot(
    deg_data: pd.DataFrame,
    p_adj: bool = True,
    top: int = 25,
    top_rank: str = "p_value",
    p_val: float | int = 0.05,
    lfc: float | int = 0.25,
    rescale_adj: bool = True,
    image_width: int = 12,
    image_high: int = 12,
):
    """
    Generate a volcano plot from differential expression results.

    A volcano plot visualizes the relationship between statistical significance
    (p-values or standardized p-value) and log(fold change) for each gene, highlighting
    genes that pass significance thresholds.

    Parameters
    ----------
    deg_data : pandas.DataFrame
        DataFrame containing differential expression results from calc_DEG() function.

    p_adj : bool, default=True
        If True, use adjusted p-values. If False, use raw p-values.

    top : int, default=25
        Number of top significant genes to highlight on the plot.

    top_rank : str, default='p_value'
        Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC']

    p_val : float | int, default=0.05
        Significance threshold for p-values (or adjusted p-values).

    lfc : float | int, default=0.25
        Threshold for absolute log fold change.

    rescale_adj : bool, default=True
        If True, rescale p-values to avoid long breaks caused by outlier values.

    image_width : int, default=12
        Width of the generated plot in inches.

    image_high : int, default=12
        Height of the generated plot in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The generated volcano plot figure.
    """

    if top_rank.upper() not in ["FC", "P_VALUE"]:
        raise ValueError("top_rank must be either 'FC' or 'p_value'")

    # Column holding the significance statistic to plot.
    if p_adj:
        pv = "adj_pval"
    else:
        pv = "p_val"

    deg_df = deg_data.copy()

    # Spacing factor used to fabricate pseudo p-values for exact zeros below.
    shift = 0.25

    p_val_scale = "-log(p_val)"

    # Smallest non-zero p-value on each side of the fold-change axis; used as
    # the base for replacing exact-zero p-values so they stay plottable.
    min_minus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] < 0)])
    min_plus = min(deg_df[pv][(deg_df[pv] != 0) & (deg_df["log(FC)"] > 0)])

    # Upregulated genes with p == 0: assign increasing fractions of the
    # smallest observed positive-side p-value, ordered by fold change.
    zero_p_plus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] > 0)]
    zero_p_plus = zero_p_plus.sort_values(by="log(FC)", ascending=False).reset_index(
        drop=True
    )
    zero_p_plus[pv] = [
        (shift * x) * min_plus for x in range(1, len(zero_p_plus.index) + 1)
    ]

    # Same treatment for downregulated genes with p == 0.
    zero_p_minus = deg_df[(deg_df[pv] == 0) & (deg_df["log(FC)"] < 0)]
    zero_p_minus = zero_p_minus.sort_values(by="log(FC)", ascending=True).reset_index(
        drop=True
    )
    zero_p_minus[pv] = [
        (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1)
    ]

    # Genes that already have a non-zero p-value (either direction).
    tmp_p = deg_df[
        ((deg_df[pv] != 0) & (deg_df["log(FC)"] < 0))
        | ((deg_df[pv] != 0) & (deg_df["log(FC)"] > 0))
    ]

    del deg_df

    deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)

    # 2**-1074 is the smallest positive float; guards the log10 below.
    deg_df[pv] = deg_df[pv].replace(0, 2**-1074)

    deg_df[p_val_scale] = -np.log10(deg_df[pv])

    # Will hold the per-point color category assigned further down.
    deg_df["top100"] = None

    if rescale_adj:
        # Compress large vertical gaps: find points whose -log(p) is at least
        # twice the value of one of the next five points, then shrink them.

        deg_df = deg_df.sort_values(by=p_val_scale, ascending=False)

        deg_df = deg_df.reset_index(drop=True)

        eps = 1e-300
        doubled = []
        ratio = []
        for n, i in enumerate(deg_df.index):
            for j in range(1, 6):
                if (
                    n + j < len(deg_df.index)
                    and (deg_df[p_val_scale][n] + eps)
                    / (deg_df[p_val_scale][n + j] + eps)
                    >= 2
                ):
                    doubled.append(n)
                    ratio.append(
                        (deg_df[p_val_scale][n + j] + eps)
                        / (deg_df[p_val_scale][n] + eps)
                    )

        df = pd.DataFrame({"doubled": doubled, "ratio": ratio})
        # Only rescale within the first 100 most significant points.
        df = df[df["doubled"] < 100]

        df["ratio"] = (1 - df["ratio"]) / 5
        df = df.reset_index(drop=True)

        df = df.sort_values("doubled")

        # NOTE(review): `0 in df["doubled"]` tests the Series *index*, not the
        # values; confirm this is the intended check.
        if len(df["doubled"]) == 1 and 0 in df["doubled"]:
            df = df
        else:
            doubled2 = []

            # NOTE(review): `l + 1 - l == 1` is always True, so this collects
            # consecutive pairs until an index hits len(doubled) - 1.
            for l in df["doubled"]:
                if l + 1 != len(doubled) and l + 1 - l == 1:
                    doubled2.append(l)
                    doubled2.append(l + 1)
                else:
                    break

            doubled2 = sorted(set(doubled2), reverse=True)

            if len(doubled2) > 1:
                df = df[df["doubled"].isin(doubled2)]
                df = df.sort_values("doubled", ascending=False)
                df = df.reset_index(drop=True)
                # Pull each outlier down toward its neighbor's -log(p) value.
                for c in df.index:
                    deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[
                        df["doubled"][c] + 1, p_val_scale
                    ] * (1 + df["ratio"][c])

    # Base coloring: significant down = red, significant up = blue,
    # non-significant = lightgray.
    deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df[pv] <= p_val), "top100"] = "red"
    deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df[pv] <= p_val), "top100"] = "blue"
    deg_df.loc[deg_df[pv] > p_val, "top100"] = "lightgray"

    if lfc > 0:
        # Gray out genes inside the fold-change dead zone.
        deg_df.loc[
            (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100"
        ] = "lightgray"

    down_int = len(
        deg_df["top100"][(deg_df["log(FC)"] <= lfc * -1) & (deg_df[pv] <= p_val)]
    )
    up_int = len(deg_df["top100"][(deg_df["log(FC)"] > lfc) & (deg_df[pv] <= p_val)])

    # Mark the `top` strongest upregulated genes green.
    deg_df_up = deg_df[deg_df["log(FC)"] > 0]

    if top_rank.upper() == "P_VALUE":
        deg_df_up = deg_df_up.sort_values([pv, "log(FC)"], ascending=[True, False])
    elif top_rank.upper() == "FC":
        deg_df_up = deg_df_up.sort_values(["log(FC)", pv], ascending=[False, True])

    deg_df_up = deg_df_up.reset_index(drop=True)

    # NOTE(review): this walker has no explicit bounds check — if fewer than
    # `top` genes qualify and no row exceeds p_val, `n` can run past the end
    # and raise KeyError. Confirm inputs always guarantee a break.
    n = -1
    l = 0
    while True:
        n += 1
        if deg_df_up["log(FC)"][n] > lfc and deg_df_up[pv][n] <= p_val:
            deg_df_up.loc[n, "top100"] = "green"
            l += 1
        if l == top or deg_df_up[pv][n] > p_val:
            break

    # Mark the `top` strongest downregulated genes yellow (same caveat).
    deg_df_down = deg_df[deg_df["log(FC)"] <= 0]

    if top_rank.upper() == "P_VALUE":
        deg_df_down = deg_df_down.sort_values([pv, "log(FC)"], ascending=[True, True])
    elif top_rank.upper() == "FC":
        deg_df_down = deg_df_down.sort_values(["log(FC)", pv], ascending=[True, True])

    deg_df_down = deg_df_down.reset_index(drop=True)

    n = -1
    l = 0
    while True:
        n += 1
        if deg_df_down["log(FC)"][n] < lfc * -1 and deg_df_down[pv][n] <= p_val:
            deg_df_down.loc[n, "top100"] = "yellow"

            l += 1
        if l == top or deg_df_down[pv][n] > p_val:
            break

    deg_df = pd.concat([deg_df_up, deg_df_down])

    # Draw order: gray first (bottom), highlighted categories last (top).
    que = ["lightgray", "red", "blue", "yellow", "green"]

    deg_df = deg_df.sort_values(
        by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)})
    )

    deg_df = deg_df.reset_index(drop=True)

    fig, ax = plt.subplots(figsize=(image_width, image_high))

    plt.scatter(
        x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2
    )

    # Horizontal significance line at the largest non-significant -log(p).
    tl = deg_df[p_val_scale][deg_df[pv] >= p_val]

    if len(tl) > 0:

        line_p = np.max(tl)

    else:
        line_p = np.min(deg_df[p_val_scale])

    plt.plot(
        [max(deg_df["log(FC)"]) * -1.1, max(deg_df["log(FC)"]) * 1.1],
        [line_p, line_p],
        linestyle="--",
        linewidth=3,
        color="lightgray",
        zorder=1,
    )

    if lfc > 0:
        # Vertical lines at the +/- fold-change thresholds.
        plt.plot(
            [lfc * -1, lfc * -1],
            [-3, max(deg_df[p_val_scale]) * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )
        plt.plot(
            [lfc, lfc],
            [-3, max(deg_df[p_val_scale]) * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )

    plt.xlabel("log(FC)")
    plt.ylabel(p_val_scale)
    plt.title("Volcano plot")

    plt.ylim(min(deg_df[p_val_scale]) - 5, max(deg_df[p_val_scale]) * 1.25)

    # Label only the top up/down genes; adjust_text untangles overlaps.
    texts = [
        ax.text(deg_df["log(FC)"][i], deg_df[p_val_scale][i], deg_df["feature"][i])
        for i in deg_df.index
        if deg_df["top100"][i] in ["green", "yellow"]
    ]

    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

    legend_elements = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="top-upregulated",
            markerfacecolor="green",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="top-downregulated",
            markerfacecolor="yellow",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="upregulated",
            markerfacecolor="blue",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="downregulated",
            markerfacecolor="red",
            markersize=10,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="non-significant",
            markerfacecolor="lightgray",
            markersize=10,
        ),
    ]

    ax.legend(handles=legend_elements, loc="upper right")
    ax.grid(visible=False)

    # Threshold summary annotations in the upper-left corner.
    ax.annotate(
        f"\nmin {pv} = " + str(p_val),
        xy=(0.025, 0.975),
        xycoords="axes fraction",
        fontsize=12,
    )

    if lfc > 0:
        ax.annotate(
            "\nmin log(FC) = " + str(lfc),
            xy=(0.025, 0.95),
            xycoords="axes fraction",
            fontsize=12,
        )

    ax.annotate(
        "\nDownregulated: " + str(down_int),
        xy=(0.025, 0.925),
        xycoords="axes fraction",
        fontsize=12,
        color="red",
    )

    ax.annotate(
        "\nUpregulated: " + str(up_int),
        xy=(0.025, 0.9),
        xycoords="axes fraction",
        fontsize=12,
        color="blue",
    )

    plt.show()

    return fig
Generate a volcano plot from differential expression results.
A volcano plot visualizes the relationship between statistical significance (p-values or standardized p-value) and log(fold change) for each gene, highlighting genes that pass significance thresholds.
Parameters
deg_data : pandas.DataFrame DataFrame containing differential expression results from calc_DEG() function.
p_adj : bool, default=True If True, use adjusted p-values. If False, use raw p-values.
top : int, default=25 Number of top significant genes to highlight on the plot.
top_rank : str, default='p_value' Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC']
p_val : float | int, default=0.05 Significance threshold for p-values (or adjusted p-values).
lfc : float | int, default=0.25 Threshold for absolute log fold change.
rescale_adj : bool, default=True If True, rescale p-values to avoid long breaks caused by outlier values.
image_width : int, default=12 Width of the generated plot in inches.
image_high : int, default=12 Height of the generated plot in inches.
Returns
matplotlib.figure.Figure The generated volcano plot figure.
def find_features(data: pd.DataFrame, features: list):
    """
    Identify features (rows) from a DataFrame that match a given list of features,
    ignoring case sensitivity.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with features in the index (rows).

    features : list
        List of feature names to search for. Non-string entries are compared
        by their string representation.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of features found in the DataFrame index
          (order is arbitrary, as it is derived from a set).
        - "not_included": list of requested features not found in the DataFrame index.
        - "potential": list of features in the DataFrame that may be similar
          (index entries containing a missing query as a substring).
    """

    features_upper = [str(x).upper() for x in features]

    index_set = set(data.index)

    features_in = [x for x in index_set if x.upper() in features_upper]
    features_in_upper = [x.upper() for x in features_in]
    features_out_upper = [x for x in features_upper if x not in features_in_upper]
    # Bug fix: use str(x) here as well, mirroring how features_upper is built,
    # so non-string queries (e.g. ints) do not raise AttributeError.
    features_out = [x for x in features if str(x).upper() in features_out_upper]
    similar_features = [
        idx for idx in index_set if any(x in idx.upper() for x in features_out_upper)
    ]

    return {
        "included": features_in,
        "not_included": features_out,
        "potential": similar_features,
    }
Identify features (rows) from a DataFrame that match a given list of features, ignoring case sensitivity.
Parameters
data : pandas.DataFrame DataFrame with features in the index (rows).
features : list List of feature names to search for.
Returns
dict Dictionary with keys: - "included": list of features found in the DataFrame index. - "not_included": list of requested features not found in the DataFrame index. - "potential": list of features in the DataFrame that may be similar.
def find_names(data: pd.DataFrame, names: list):
    """
    Identify names (columns) from a DataFrame that match a given list of names,
    ignoring case sensitivity.

    Parameters
    ----------
    data : pandas.DataFrame
        DataFrame with names in the columns.

    names : list
        List of names to search for. Non-string entries are compared by their
        string representation.

    Returns
    -------
    dict
        Dictionary with keys:
        - "included": list of names found in the DataFrame columns
          (order is arbitrary, as it is derived from a set).
        - "not_included": list of requested names not found in the DataFrame columns.
        - "potential": list of names in the DataFrame that may be similar
          (column entries containing a missing query as a substring).
    """

    names_upper = [str(x).upper() for x in names]

    columns = set(data.columns)

    names_in = [x for x in columns if x.upper() in names_upper]
    names_in_upper = [x.upper() for x in names_in]
    names_out_upper = [x for x in names_upper if x not in names_in_upper]
    # Bug fix: use str(x) here as well, mirroring how names_upper is built,
    # so non-string queries (e.g. ints) do not raise AttributeError.
    names_out = [x for x in names if str(x).upper() in names_out_upper]
    similar_names = [
        idx for idx in columns if any(x in idx.upper() for x in names_out_upper)
    ]

    return {"included": names_in, "not_included": names_out, "potential": similar_names}
Identify names (columns) from a DataFrame that match a given list of names, ignoring case sensitivity.
Parameters
data : pandas.DataFrame DataFrame with names in the columns.
names : list List of names to search for.
Returns
dict Dictionary with keys: - "included": list of names found in the DataFrame columns. - "not_included": list of requested names not found in the DataFrame columns. - "potential": list of names in the DataFrame that may be similar.
def reduce_data(
    data: pd.DataFrame, features: list | None = None, names: list | None = None
):
    """
    Subset a DataFrame based on selected features (rows) and/or names (columns).

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and names as columns.

    features : list or None, default=None
        Features (rows) to include. If None or empty, all rows are returned
        (provided `names` is non-empty).

    names : list or None, default=None
        Names (columns) to include. If None or empty, all columns are returned
        (provided `features` is non-empty).

    Returns
    -------
    pandas.DataFrame
        Subset of the input DataFrame containing only the selected rows
        and/or columns.

    Raises
    ------
    ValueError
        If both `features` and `names` are empty (or None).
    """

    # None defaults replace the original mutable default arguments ([]),
    # a well-known Python pitfall; behavior is otherwise unchanged.
    features = list(features) if features else []
    names = list(names) if names else []

    if features and names:
        fet = find_features(data=data, features=features)
        nam = find_names(data=data, names=names)
        return data.loc[fet["included"], nam["included"]]

    if features:
        fet = find_features(data=data, features=features)
        return data.loc[fet["included"], :]

    if names:
        nam = find_names(data=data, names=names)
        return data.loc[:, nam["included"]]

    raise ValueError("features and names have zero length!")
Subset a DataFrame based on selected features (rows) and/or names (columns).
Parameters
data : pandas.DataFrame Input DataFrame with features as rows and names as columns.
features : list List of features to include (rows). Default is an empty list. If empty, all rows are returned.
names : list List of names to include (columns). Default is an empty list. If empty, all columns are returned.
Returns
pandas.DataFrame Subset of the input DataFrame containing only the selected rows and/or columns.
Raises
ValueError
If both features and names are empty.
def make_unique_list(lst):
    """
    Generate a list where duplicate items are renamed to ensure uniqueness.

    The first occurrence of an item is kept as-is; every further occurrence
    gets a ".n" suffix, where n counts repeats starting from 1.

    Parameters
    ----------
    lst : list
        Input list of items (strings or other hashable types).

    Returns
    -------
    list
        List with unique values.

    Examples
    --------
    >>> make_unique_list(["A", "B", "A", "A"])
    ['A', 'B', 'A.1', 'A.2']
    """
    occurrences = {}
    renamed = []
    for entry in lst:
        repeat = occurrences.get(entry)
        if repeat is None:
            # First time we see this item: keep it unchanged.
            occurrences[entry] = 0
            renamed.append(entry)
        else:
            # Repeat: bump the counter and append it as a suffix.
            occurrences[entry] = repeat + 1
            renamed.append(f"{entry}.{repeat + 1}")
    return renamed
Generate a list where duplicate items are renamed to ensure uniqueness.
Each duplicate is appended with a suffix ".n", where n indicates the occurrence count (starting from 1).
Parameters
lst : list Input list of items (strings or other hashable types).
Returns
list List with unique values.
Examples
>>> make_unique_list(["A", "B", "A", "A"])
['A', 'B', 'A.1', 'A.2']
def features_scatter(
    expression_data: pd.DataFrame,
    occurence_data: pd.DataFrame | None = None,
    scale: bool = False,
    features: list | None = None,
    metadata_list: list | None = None,
    colors: str = "viridis",
    hclust: str | None = "complete",
    img_width: int = 8,
    img_high: int = 5,
    label_size: int = 10,
    size_scale: int = 100,
    y_lab: str = "Genes",
    legend_lab: str = "log(CPM + 1)",
    set_box_size: float | int = 5,
    set_box_high: float | int = 5,
    bbox_to_anchor_scale: int = 25,
    bbox_to_anchor_perc: tuple = (0.91, 0.63),
    bbox_to_anchor_group: tuple = (1.01, 0.4),
):
    """
    Create a bubble scatter plot of selected features across samples.

    Each point represents a feature-sample pair, where the color encodes the
    expression value and the size encodes occurrence (or the expression value
    itself when no occurrence data is given). Optionally, hierarchical
    clustering can be applied to order rows and columns.

    Parameters
    ----------
    expression_data : pandas.DataFrame
        Expression values (mean) with features as rows and samples as columns
        derived from average() function.

    occurence_data : pandas.DataFrame or None
        DataFrame with occurrence/frequency values (same shape as
        `expression_data`) derived from occurrence() function.
        If None, bubble sizes are based on expression values.

    scale : bool, default=False
        If True, each column of `expression_data` is divided by its maximum,
        scaling values to 0-1 per sample.

    features : list or None
        List of features (rows) to display. If None, all features are used.

    metadata_list : list or None, optional
        Metadata grouping for samples (same length as number of columns).
        Used to add group colors and separators in the plot.

    colors : str, default='viridis'
        Colormap for expression values.

    hclust : str or None, default='complete'
        Linkage method for hierarchical clustering. If None, no clustering
        is performed.

    img_width : int or float, default=8
        Width of the plot in inches.

    img_high : int or float, default=5
        Height of the plot in inches.

    label_size : int, default=10
        Font size for axis labels and ticks.

    size_scale : int or float, default=100
        Scaling factor for bubble sizes.

    y_lab : str, default='Genes'
        Label for the y-axis.

    legend_lab : str, default='log(CPM + 1)'
        Label for the colorbar legend.

    set_box_size : float | int, default=5
        Height of the per-sample group boxes, in 0.1 data-unit steps.

    set_box_high : float | int, default=5
        Vertical offset (in 0.1 data-unit steps) of the group boxes above
        the top feature row.

    bbox_to_anchor_scale : int, default=25
        Colorbar height as a percentage of the axes height.

    bbox_to_anchor_perc : tuple, default=(0.91, 0.63)
        Anchor position for the size legend (percent bubble legend).

    bbox_to_anchor_group : tuple, default=(1.01, 0.4)
        Anchor position for the group legend.

    Returns
    -------
    matplotlib.figure.Figure
        The generated scatter plot figure.

    Raises
    ------
    ValueError
        If `metadata_list` is provided but its length does not match
        the number of columns in `expression_data`.

    Notes
    -----
    - Colors represent expression values normalized to the colormap.
    - Bubble sizes represent occurrence values (or expression values if
      `occurence_data` is None).
    - If `metadata_list` is given, groups are indicated with colors and
      dashed vertical separators.
    """

    scatter_df = expression_data.copy()

    if scale:

        legend_lab = "Scaled\n" + legend_lab

        # Per-column (per-sample) max-scaling; division by an all-zero column
        # produces inf/NaN, which is neutralized to 0.
        column_max = scatter_df.max()
        scatter_df = scatter_df.div(column_max).replace([np.inf, -np.inf], np.nan).fillna(0)
        # Re-wrap to make sure the result is a plain DataFrame with the same labels.
        scatter_df = pd.DataFrame(scatter_df, index=scatter_df.index, columns=scatter_df.columns)

    metadata = {}

    metadata["primary_names"] = [str(x) for x in scatter_df.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)
    if features is not None:
        # find_features() (defined elsewhere in this module) resolves which of
        # the requested features are present in the frame.
        scatter_df = scatter_df.loc[
            find_features(data=scatter_df, features=features)["included"],
        ]
    # Encode the group into the column label as "name#set"; it is recovered
    # by the re.sub() calls after clustering below.
    scatter_df.columns = metadata["primary_names"] + "#" + metadata["sets"]

    if occurence_data is not None:
        if features is not None:
            occurence_data = occurence_data.loc[
                find_features(data=occurence_data, features=features)["included"],
            ]
        occurence_data.columns = metadata["primary_names"] + "#" + metadata["sets"]

    # check duplicated names

    tmp_columns = scatter_df.columns

    # NOTE(review): make_unique_list() presumably suffixes duplicates; if the
    # suffix lands after the "#set" part it would alter the recovered set
    # labels below — confirm against its implementation.
    new_cols = make_unique_list(list(tmp_columns))

    scatter_df.columns = new_cols

    # NOTE(review): this guard checks the original expression_data, not the
    # feature-subset scatter_df — a one-row subset would still reach
    # linkage(); confirm this is intended.
    if hclust is not None and len(expression_data.index) != 1:

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram
        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_rows = []
        for n in order_of_features:
            sorted_list_rows.append(indexes_sort[n])

        # Transpose so the same procedure orders the columns (samples).
        scatter_df = scatter_df.transpose()

        Z = linkage(scatter_df, method=hclust)

        # Get the order of features based on the dendrogram

        order_of_features = dendrogram(Z, no_plot=True)["leaves"]

        indexes_sort = list(scatter_df.index)
        sorted_list_columns = []
        for n in order_of_features:
            sorted_list_columns.append(indexes_sort[n])

        scatter_df = scatter_df.transpose()

        scatter_df = scatter_df.loc[sorted_list_rows, sorted_list_columns]

        if occurence_data is not None:
            # Keep occurrence data aligned with the clustered ordering.
            occurence_data = occurence_data.loc[sorted_list_rows, sorted_list_columns]

    # Recover the "set" part of each "name#set" column label...
    metadata["sets"] = [re.sub(".*#", "", x) for x in scatter_df.columns]

    # ...and strip it back off, restoring plain sample names for display.
    scatter_df.columns = [re.sub("#.*", "", x) for x in scatter_df.columns]

    if occurence_data is not None:
        occurence_data.columns = [re.sub("#.*", "", x) for x in occurence_data.columns]

    fig, ax = plt.subplots(figsize=(img_width, img_high))

    # Color scale spans 0 .. global maximum of the (possibly scaled) values.
    norm = plt.Normalize(0, np.max(scatter_df))

    cmap = plt.get_cmap(colors)

    # Bubble scatter
    for i, _ in enumerate(scatter_df.index):
        for j, _ in enumerate(scatter_df.columns):
            if occurence_data is not None:
                # Color from expression, size from occurrence.
                value_e = scatter_df.iloc[i, j]
                value_o = occurence_data.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value_o * size_scale,
                    c=[cmap(norm(value_e))],
                    edgecolors="k",
                    linewidths=0.3,
                )
            else:
                # Expression drives both color and size.
                value = scatter_df.iloc[i, j]
                ax.scatter(
                    j,
                    i,
                    s=value * size_scale,
                    c=[cmap(norm(value))],
                    edgecolors="k",
                    linewidths=0.3,
                )

    ax.set_yticks(range(len(scatter_df.index)))
    ax.set_yticklabels(scatter_df.index, fontsize=label_size * 0.8)
    ax.set_ylabel(y_lab, fontsize=label_size)
    ax.set_xticks(range(len(scatter_df.columns)))
    ax.set_xticklabels(scatter_df.columns, fontsize=label_size * 0.8, rotation=90)

    # Place the colorbar just right of the axes, top-aligned, with a height
    # of bbox_to_anchor_scale percent of the axes height.
    ax_pos = ax.get_position()

    width_fig = 0.01
    height_fig = ax_pos.height * (bbox_to_anchor_scale / 100)
    left_fig = ax_pos.x1 + 0.01
    bottom_fig = ax_pos.y1 - height_fig

    cax = fig.add_axes([left_fig, bottom_fig, width_fig, height_fig])
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb.set_label(legend_lab, fontsize=label_size * 0.65)
    cb.ax.tick_params(labelsize=label_size * 0.7)

    if metadata_list is not None:

        # Use the (clustering-reordered) set labels recovered above.
        metadata_list = list(metadata["sets"])
        # get_color_palette() is defined elsewhere in this module.
        group_colors = get_color_palette(list(set(metadata_list)), palette_name="tab10")

        # One colored box per sample column, drawn above the top feature row.
        for i, group in enumerate(metadata_list):
            ax.add_patch(
                plt.Rectangle(
                    (i - 0.5, len(scatter_df.index) - 0.1 * set_box_high),
                    1,
                    0.1 * set_box_size,
                    color=group_colors[group],
                    transform=ax.transData,
                    clip_on=False,
                )
            )

        # Dashed separator wherever consecutive samples change group.
        for i in range(1, len(metadata_list)):
            if metadata_list[i] != metadata_list[i - 1]:
                ax.axvline(i - 0.5, color="black", linestyle="--", lw=1)

        group_patches = [
            mpatches.Patch(color=color, label=label)
            for label, color in group_colors.items()
        ]
        fig.legend(
            handles=group_patches,
            title="Group",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_group,
            frameon=False,
        )

    # second legend (size)
    if occurence_data is not None:
        # Reference bubbles for 25%, 50% and 100% occurrence.
        size_values = [0.25, 0.5, 1]
        legend2_handles = [
            plt.Line2D(
                [],
                [],
                marker="o",
                linestyle="",
                markersize=np.sqrt(v * size_scale * 0.5),
                color="gray",
                alpha=0.6,
                label=f"{v * 100:.1f}",
            )
            for v in size_values
        ]

        fig.legend(
            handles=legend2_handles,
            title="Percent [%]",
            fontsize=label_size * 0.7,
            title_fontsize=label_size * 0.7,
            loc="center left",
            bbox_to_anchor=bbox_to_anchor_perc,
            frameon=False,
        )

    _, ymax = ax.get_ylim()

    # Half-unit margins so edge bubbles are not clipped.
    ax.set_xlim(-0.5, len(scatter_df.columns) - 0.5)
    ax.set_ylim(-0.5, ymax + 0.5)

    return fig
Create a bubble scatter plot of selected features across samples.
Each point represents a feature-sample pair, where the color encodes the expression value and the size encodes occurrence or relative abundance. Optionally, hierarchical clustering can be applied to order rows and columns.
Parameters
expression_data : pandas.DataFrame Expression values (mean) with features as rows and samples as columns derived from average() function.
occurence_data : pandas.DataFrame or None
DataFrame with occurrence/frequency values (same shape as expression_data) derived from occurrence() function.
If None, bubble sizes are based on expression values.
scale: bool, default False If True, expression_data (features) will be scaled (0–1) across the columns (samples).
features : list or None List of features (rows) to display. If None, all features are used.
metadata_list : list or None, optional Metadata grouping for samples (same length as number of columns). Used to add group colors and separators in the plot.
colors : str, default='viridis' Colormap for expression values.
hclust : str or None, default='complete' Linkage method for hierarchical clustering. If None, no clustering is performed.
img_width : int or float, default=8 Width of the plot in inches.
img_high : int or float, default=5 Height of the plot in inches.
label_size : int, default=10 Font size for axis labels and ticks.
size_scale : int or float, default=100 Scaling factor for bubble sizes.
y_lab : str, default='Genes' Label for the y-axis.
legend_lab : str, default='log(CPM + 1)' Label for the colorbar legend.
bbox_to_anchor_scale : int, default=25 Vertical scale (percentage) for positioning the colorbar.
bbox_to_anchor_perc : tuple, default=(0.91, 0.63) Anchor position for the size legend (percent bubble legend).
bbox_to_anchor_group : tuple, default=(1.01, 0.4) Anchor position for the group legend.
Returns
matplotlib.figure.Figure The generated scatter plot figure.
Raises
ValueError
If metadata_list is provided but its length does not match
the number of columns in expression_data.
Notes
- Colors represent expression values normalized to the colormap.
- Bubble sizes represent occurrence values (or expression values if
`occurence_data` is None).
- If `metadata_list` is given, groups are indicated with colors and dashed vertical separators.
def calc_DEG(
    data,
    metadata_list: list | None = None,
    entities: str | list | dict | None = None,
    sets: str | list | dict | None = None,
    min_exp: int | float = 0,
    min_pct: int | float = 0.1,
    n_proc: int = 10,
):
    """
    Perform differential gene expression (DEG) analysis on gene expression data.

    The function compares groups of cells or samples (defined by `entities` or
    `sets`) using the Mann–Whitney U test. It computes p-values, adjusted
    p-values, fold changes, standardized effect sizes, and other statistics.

    Parameters
    ----------
    data : pandas.DataFrame
        Expression matrix with features (e.g., genes) as rows and samples/cells
        as columns.

    metadata_list : list or None, optional
        Metadata grouping corresponding to the columns in `data`. Required for
        comparisons based on sets. Default is None.

    entities : list, str, dict, or None, optional
        Defines the comparison strategy:
        - list of sample names → compare selected cells to the rest.
        - 'All' → compare each sample/cell to all others.
        - dict → user-defined groups for pairwise comparison.
        - None → must be combined with `sets`.

    sets : str, dict, or None, optional
        Defines group-based comparisons:
        - 'All' → compare each set/group to all others.
        - dict with two groups → perform pairwise set comparison.
        - None → must be combined with `entities`.

    min_exp : float | int, default=0
        Minimum expression threshold for filtering features.

    min_pct : float | int, default=0.1
        Minimum proportion of samples within the target group that must express
        a feature for it to be tested.

    n_proc : int, default=10
        Number of parallel processes to use for statistical testing.

    Returns
    -------
    pandas.DataFrame or dict
        Results of the differential expression analysis:
        - If `entities` is a list → dict with keys: 'valid' (target sample
          names), 'control' (always the string 'rest'), and 'DEG'
          (results DataFrame).
        - If `entities == 'All'` or `sets == 'All'` → DataFrame with results
          for all groups.
        - If pairwise comparison (dict for `entities` or `sets`) → DataFrame
          with results for the specified groups.

        The results DataFrame contains:
        - 'feature': feature name
        - 'p_val': raw p-value
        - 'adj_pval': adjusted p-value (multiple testing correction)
        - 'pct_valid': fraction of target group expressing the feature
        - 'pct_ctrl': fraction of control group expressing the feature
        - 'avg_valid': mean expression in target group
        - 'avg_ctrl': mean expression in control group
        - 'sd_valid': standard deviation in target group
        - 'sd_ctrl': standard deviation in control group
        - 'esm': effect size metric
        - 'FC': fold change
        - 'log(FC)': log2-transformed fold change
        - 'norm_diff': difference in mean expression

    Raises
    ------
    ValueError
        - If `metadata_list` is provided but its length does not match
          the number of columns in `data`.
        - If neither `entities` nor `sets` is provided.

    Notes
    -----
    - Mann–Whitney U test is used for group comparisons.
    - Multiple testing correction is a simple Benjamini–Hochberg-like
      step-up (p * n_tests / rank, capped at 1); monotonicity is not
      enforced afterwards.
    - Features expressed below `min_exp` or in fewer than `min_pct` of target
      samples are filtered out.
    - Parallelization is handled by `joblib.Parallel`.

    Examples
    --------
    Compare a selected list of cells against all others:

    >>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])

    Compare each group to others (based on metadata):

    >>> result = calc_DEG(data, metadata_list=group_labels, sets="All")

    Perform pairwise comparison between two predefined sets:

    >>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
    >>> result = calc_DEG(data, sets=sets)
    """
    # Tiny pseudocount used as a last-resort denominator guard for fold change.
    offset = 1e-100

    metadata = {}

    metadata["primary_names"] = [str(x) for x in data.columns]

    if metadata_list is not None:
        metadata["sets"] = metadata_list

        if len(metadata["primary_names"]) != len(metadata["sets"]):

            raise ValueError(
                "Metadata list and DataFrame columns must have the same length."
            )

    else:

        metadata["sets"] = [""] * len(metadata["primary_names"])

    metadata = pd.DataFrame(metadata)

    def stat_calc(choose, feature_name):
        # Per-feature statistics for one target-vs-rest split.
        # `choose` is samples-as-rows with a 'DEG' label column.
        target_values = choose.loc[choose["DEG"] == "target", feature_name]
        rest_values = choose.loc[choose["DEG"] == "rest", feature_name]

        # Fraction of samples expressing (> 0) the feature in each group.
        pct_valid = (target_values > 0).sum() / len(target_values)
        pct_rest = (rest_values > 0).sum() / len(rest_values)

        avg_valid = np.mean(target_values)
        avg_ctrl = np.mean(rest_values)
        sd_valid = np.std(target_values, ddof=1)
        sd_ctrl = np.std(rest_values, ddof=1)
        # Cohen's d-style effect size with a pooled (unweighted) SD.
        esm = (avg_valid - avg_ctrl) / np.sqrt(((sd_valid**2 + sd_ctrl**2) / 2))

        # Heuristic shortcut: equal group sums are treated as "no difference"
        # (p = 1) without running the test.
        if np.sum(target_values) == np.sum(rest_values):
            p_val = 1.0
        else:
            _, p_val = stats.mannwhitneyu(
                target_values, rest_values, alternative="two-sided"
            )

        return {
            "feature": feature_name,
            "p_val": p_val,
            "pct_valid": pct_valid,
            "pct_ctrl": pct_rest,
            "avg_valid": avg_valid,
            "avg_ctrl": avg_ctrl,
            "sd_valid": sd_valid,
            "sd_ctrl": sd_ctrl,
            "esm": esm,
        }

    def prepare_and_run_stat(choose, valid_group, min_exp, min_pct, n_proc):
        # Filters low-expressed features, runs stat_calc in parallel, and
        # post-processes the results (adjusted p-values, fold changes).

        def safe_min_half(series):
            # Half of the smallest value that is meaningfully above zero;
            # (2**-1074)*2 is twice the smallest subnormal double, so this
            # excludes values that are effectively zero.
            filtered = series[(series > ((2**-1074)*2)) & (series.notna())]
            return filtered.min() / 2 if not filtered.empty else 0

        tmp_dat = choose[choose["DEG"] == "target"]
        tmp_dat = tmp_dat.drop("DEG", axis=1)

        # Per-feature count of target samples above the expression threshold.
        counts = (tmp_dat > min_exp).sum(axis=0)

        total_count = tmp_dat.shape[0]

        info = pd.DataFrame(
            {"feature": list(tmp_dat.columns), "pct": list(counts / total_count)}
        )

        del tmp_dat

        drop_col = info["feature"][info["pct"] <= min_pct]

        # If the pct filter would drop every feature (all columns but 'DEG'),
        # fall back to dropping only entirely-unexpressed features.
        if len(drop_col) + 1 == len(choose.columns):
            drop_col = info["feature"][info["pct"] == 0]

        del info

        choose = choose.drop(list(drop_col), axis=1)

        results = Parallel(n_jobs=n_proc)(
            delayed(stat_calc)(choose, feature)
            for feature in tqdm(choose.columns[choose.columns != "DEG"])
        )

        if len(results) > 0:
            df = pd.DataFrame(results)

            # Keep features expressed in at least one of the two groups.
            df = df[(df["avg_valid"] > 0) | (df["avg_ctrl"] > 0)]

            df["valid_group"] = valid_group
            df.sort_values(by="p_val", inplace=True)

            # BH-like step-up: p * n / rank, capped at 1 (no monotonization).
            num_tests = len(df)
            df["adj_pval"] = np.minimum(
                1, (df["p_val"] * num_tests) / np.arange(1, num_tests + 1)
            )

            # Pseudocount for zero means so fold change stays finite:
            # half of the smallest non-zero mean observed in either group.
            valid_factor = safe_min_half(df["avg_valid"])
            ctrl_factor = safe_min_half(df["avg_ctrl"])

            cv_factor = min(valid_factor, ctrl_factor)

            if cv_factor == 0:
                cv_factor = max(valid_factor, ctrl_factor)

            if not np.isfinite(cv_factor) or cv_factor == 0:
                cv_factor += offset

            # Apply the pseudocount only where the mean is exactly zero.
            valid = df["avg_valid"].where(
                df["avg_valid"] != 0, df["avg_valid"] + cv_factor
            )
            ctrl = df["avg_ctrl"].where(
                df["avg_ctrl"] != 0, df["avg_ctrl"] + cv_factor
            )

            df["FC"] = valid / ctrl

            df["log(FC)"] = np.log2(df["FC"])
            df["norm_diff"] = df["avg_valid"] - df["avg_ctrl"]

        else:

            # No testable features: return an empty frame with the expected schema.
            columns = [
                "feature",
                "valid_group",
                "p_val",
                "adj_pval",
                "avg_valid",
                "avg_ctrl",
                "FC",
                "log(FC)",
                "norm_diff",
            ]
            df = pd.DataFrame(columns=columns)
        return df

    # Work with samples as rows from here on.
    choose = data.T

    final_results = []

    if isinstance(entities, list) and sets is None:
        print("\nAnalysis started...\nComparing selected cells to the whole set...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            # Index format "name # set" so entities can address name+set pairs.
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

            if "#" not in entities[0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' list. "
                    "Only the names will be compared, without considering the set information."
                )

        labels = ["target" if idx in entities else "rest" for idx in choose.index]
        valid = list(
            set(choose.index[[i for i, x in enumerate(labels) if x == "target"]])
        )

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=valid,
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        return {"valid": valid, "control": "rest", "DEG": result_df}

    elif entities == "All" and sets is None:
        print("\nAnalysis started...\nComparing each type of cell to others...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]

        unique_labels = set(choose.index)

        # One target-vs-rest run per unique sample label.
        for label in tqdm(unique_labels):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]
            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            final_results.append(result_df)

        final_results = pd.concat(final_results, ignore_index=True)

        if metadata_list is None:
            final_results["valid_group"] = [
                re.sub(" # ", "", x) for x in final_results["valid_group"]
            ]

        return final_results

    elif entities is None and sets == "All":
        print("\nAnalysis started...\nComparing each set/group to others...")
        # Index by set label only; each set is compared against the rest.
        choose.index = metadata["sets"]
        unique_sets = set(choose.index)

        for label in tqdm(unique_sets):
            print(f"\nCalculating statistics for {label}")
            labels = ["target" if idx == label else "rest" for idx in choose.index]

            choose["DEG"] = labels
            choose = choose[choose["DEG"] != "drop"]
            result_df = prepare_and_run_stat(
                choose.copy(),
                valid_group=label,
                min_exp=min_exp,
                min_pct=min_pct,
                n_proc=n_proc,
            )
            final_results.append(result_df)

        return pd.concat(final_results, ignore_index=True)

    elif entities is None and isinstance(sets, dict):
        print("\nAnalysis started...\nComparing groups...")
        choose.index = metadata["sets"]

        group_list = list(sets.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        # Samples outside both groups are labeled 'drop' and removed below.
        labels = [
            (
                "target"
                if idx in sets[group_list[0]]
                else "rest" if idx in sets[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]
        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )
        return result_df

    elif isinstance(entities, dict) and sets is None:
        print("\nAnalysis started...\nComparing groups...")

        if metadata_list is None:
            choose.index = metadata["primary_names"]
        else:
            choose.index = metadata["primary_names"] + " # " + metadata["sets"]
            if "#" not in entities[list(entities.keys())[0]][0]:
                choose.index = metadata["primary_names"]
                print(
                    "You provided 'metadata_list', but did not include the set info (name # set) "
                    "in the 'entities' dict. "
                    "Only the names will be compared, without considering the set information."
                )

        group_list = list(entities.keys())
        if len(group_list) != 2:
            print("Only pairwise group comparison is supported.")
            return None

        labels = [
            (
                "target"
                if idx in entities[group_list[0]]
                else "rest" if idx in entities[group_list[1]] else "drop"
            )
            for idx in choose.index
        ]

        choose["DEG"] = labels
        choose = choose[choose["DEG"] != "drop"]

        result_df = prepare_and_run_stat(
            choose.reset_index(drop=True),
            valid_group=group_list[0],
            min_exp=min_exp,
            min_pct=min_pct,
            n_proc=n_proc,
        )

        return result_df.reset_index(drop=True)

    else:
        raise ValueError(
            "You must specify either 'entities' or 'sets'. None were provided, which is not allowed for this analysis."
        )
Perform differential gene expression (DEG) analysis on gene expression data.
The function compares groups of cells or samples (defined by entities or
sets) using the Mann–Whitney U test. It computes p-values, adjusted
p-values, fold changes, standardized effect sizes, and other statistics.
Parameters
data : pandas.DataFrame Expression matrix with features (e.g., genes) as rows and samples/cells as columns.
metadata_list : list or None, optional
Metadata grouping corresponding to the columns in data. Required for
comparisons based on sets. Default is None.
entities : list, str, dict, or None, optional
Defines the comparison strategy:
- list of sample names → compare selected cells to the rest.
- 'All' → compare each sample/cell to all others.
- dict → user-defined groups for pairwise comparison.
- None → must be combined with sets.
sets : str, dict, or None, optional
Defines group-based comparisons:
- 'All' → compare each set/group to all others.
- dict with two groups → perform pairwise set comparison.
- None → must be combined with entities.
min_exp : float | int, default=0 Minimum expression threshold for filtering features.
min_pct : float | int, default=0.1 Minimum proportion of samples within the target group that must express a feature for it to be tested.
n_proc : int, default=10 Number of parallel processes to use for statistical testing.
Returns
pandas.DataFrame or dict
Results of the differential expression analysis:
- If entities is a list → dict with keys: 'valid',
'control', and 'DEG' (results DataFrame).
- If entities == 'All' or sets == 'All' → DataFrame with results
for all groups.
- If pairwise comparison (dict for entities or sets) → DataFrame
with results for the specified groups.
The results DataFrame contains:
- 'feature': feature name
- 'p_val': raw p-value
- 'adj_pval': adjusted p-value (multiple testing correction)
- 'pct_valid': fraction of target group expressing the feature
- 'pct_ctrl': fraction of control group expressing the feature
- 'avg_valid': mean expression in target group
- 'avg_ctrl': mean expression in control group
- 'sd_valid': standard deviation in target group
- 'sd_ctrl': standard deviation in control group
- 'esm': effect size metric
- 'FC': fold change
- 'log(FC)': log2-transformed fold change
- 'norm_diff': difference in mean expression
Raises
ValueError
- If metadata_list is provided but its length does not match
the number of columns in data.
- If neither entities nor sets is provided.
Notes
- Mann–Whitney U test is used for group comparisons.
- Multiple testing correction is applied using a simple Benjamini–Hochberg-like method.
- Features expressed below `min_exp` or in fewer than `min_pct` of target samples are filtered out.
- Parallelization is handled by `joblib.Parallel`.
Examples
Compare a selected list of cells against all others:
>>> result = calc_DEG(data, entities=["cell1", "cell2", "cell3"])
Compare each group to others (based on metadata):
>>> result = calc_DEG(data, metadata_list=group_labels, sets="All")
Perform pairwise comparison between two predefined sets:
>>> sets = {"GroupA": ["A1", "A2"], "GroupB": ["B1", "B2"]}
>>> result = calc_DEG(data, sets=sets)
def average(data):
    """
    Average DataFrame columns that share the same name.

    Columns are grouped by their label; each group of identically named
    columns is collapsed into a single column holding the row-wise mean.
    Unique column names pass through unchanged.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric DataFrame whose duplicate-named columns should be averaged.

    Returns
    -------
    pandas.DataFrame
        Same rows as the input, with one column per unique column name,
        containing the mean over all columns that carried that name.
    """

    # Transpose so column names become the index, group by that label,
    # take the mean, and transpose back.
    return data.T.groupby(level=0).mean().T
Compute the column-wise average of a DataFrame, aggregating by column names.
If multiple columns share the same name, their values are averaged.
Parameters
data : pandas.DataFrame Input DataFrame with numeric values. Columns with identical names will be aggregated by their mean.
Returns
pandas.DataFrame DataFrame with the same rows as the input but with unique columns, where duplicate columns have been replaced by their mean values.
def occurrence(data):
    """
    Calculate the occurrence frequency of features in a DataFrame.

    Converts the input DataFrame to binary (presence/absence) and computes
    the proportion of non-zero entries for each feature, aggregating by
    column names if duplicates exist.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with numeric values. Each column represents a feature.

    Returns
    -------
    pandas.DataFrame
        DataFrame with the same rows as the input, where each value represents
        the proportion of samples in which the feature is present (non-zero).
        Columns with identical names are aggregated.
    """

    # Presence/absence indicator per entry.
    binary_data = (data > 0).astype(int)

    # Number of original columns contributing to each unique column name.
    counts = binary_data.columns.value_counts()

    # Sum the presence indicators across duplicate-named columns.
    binary_data = binary_data.T.groupby(level=0).sum().T
    binary_data = binary_data.astype(float)

    # Divide each unique column by its duplicate count in a single
    # label-aligned operation (replaces a per-column Python loop).
    return binary_data.div(counts, axis=1)
Calculate the occurrence frequency of features in a DataFrame.
Converts the input DataFrame to binary (presence/absence) and computes the proportion of non-zero entries for each feature, aggregating by column names if duplicates exist.
Parameters
data : pandas.DataFrame Input DataFrame with numeric values. Each column represents a feature.
Returns
pandas.DataFrame DataFrame with the same rows as the input, where each value represents the proportion of samples in which the feature is present (non-zero). Columns with identical names are aggregated.
def add_subnames(names_list: list, parent_name: str, new_clusters: list):
    """
    Append sub-cluster names to a parent name within a list of names.

    This function replaces occurrences of `parent_name` in `names_list` with
    a concatenation of the parent name and corresponding sub-cluster name
    from `new_clusters` (formatted as "parent.subcluster"). Non-matching names
    are left unchanged.

    Parameters
    ----------
    names_list : list
        Original list of names (e.g., column names or cluster labels).

    parent_name : str
        Name of the parent cluster to which sub-cluster names will be added.
        Must exist in `names_list`.

    new_clusters : list
        List of sub-cluster names. Its length must match the number of times
        `parent_name` occurs in `names_list`.

    Returns
    -------
    list
        Updated list of names with sub-cluster names appended to the parent name.

    Raises
    ------
    ValueError
        - If `parent_name` is not found in `names_list`.
        - If `new_clusters` length does not match the number of occurrences of
          `parent_name`.

    Examples
    --------
    >>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
    ['A.1', 'B', 'A.2']
    """

    # Compare as strings so numeric labels (e.g. 1 vs '1') match consistently.
    parent = str(parent_name)
    occurrences = sum(1 for x in names_list if str(x) == parent)

    if occurrences == 0:
        raise ValueError(
            "Parent name is missing from the original dataset`s column names!"
        )

    if len(new_clusters) != occurrences:
        raise ValueError(
            "New cluster names list has a different length than the number of clusters in the original dataset!"
        )

    # Consume sub-cluster names in order, one per matching occurrence
    # (replaces the manual index counter of the original implementation).
    sub_iter = iter(new_clusters)
    return [
        f"{parent_name}.{next(sub_iter)}" if str(name) == parent else name
        for name in names_list
    ]
Append sub-cluster names to a parent name within a list of names.
This function replaces occurrences of parent_name in names_list with
a concatenation of the parent name and corresponding sub-cluster name
from new_clusters (formatted as "parent.subcluster"). Non-matching names
are left unchanged.
Parameters
names_list : list Original list of names (e.g., column names or cluster labels).
parent_name : str
Name of the parent cluster to which sub-cluster names will be added.
Must exist in names_list.
new_clusters : list
List of sub-cluster names. Its length must match the number of times
parent_name occurs in names_list.
Returns
list Updated list of names with sub-cluster names appended to the parent name.
Raises
ValueError
- If parent_name is not found in names_list.
- If new_clusters length does not match the number of occurrences of
parent_name.
Examples
>>> add_subnames(['A', 'B', 'A'], 'A', ['1', '2'])
['A.1', 'B', 'A.2']
def development_clust(
    data: pd.DataFrame, method: str = "ward", img_width: int = 5, img_high: int = 5
):
    """
    Hierarchically cluster the columns of a DataFrame and draw a dendrogram.

    The data is transposed so that columns become observations, clustered
    with the chosen linkage method, and rendered as a left-oriented
    dendrogram labeled with the column names.

    Parameters
    ----------
    data : pandas.DataFrame
        Input DataFrame with features as rows and samples/columns to be clustered.

    method : str, default='ward'
        Linkage method for hierarchical clustering:
        - 'ward' : minimizes the variance of clusters being merged.
        - 'single' : minimum distance between all observations of the two sets.
        - 'complete' : maximum distance between all observations of the two sets.
        - 'average' : average distance between all observations of the two sets.

    img_width : int or float, default=5
        Width of the resulting figure in inches.

    img_high : int or float, default=5
        Height of the resulting figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The dendrogram figure.
    """

    # Cluster the columns (transpose puts them in observation position).
    linkage_matrix = linkage(data.T, method=method)

    fig, axis = plt.subplots(figsize=(img_width, img_high))
    dendrogram(linkage_matrix, labels=data.columns, orientation="left", ax=axis)

    return fig
Perform hierarchical clustering on the columns of a DataFrame and plot a dendrogram.
Uses the selected linkage method (Ward's by default) to cluster the transposed data (columns) and generates a dendrogram showing the relationships between features or samples.
Parameters
data : pandas.DataFrame Input DataFrame with features as rows and samples/columns to be clustered.
method : str Method for hierarchical clustering. Options include:
- 'ward' : minimizes the variance of clusters being merged.
- 'single' : uses the minimum of the distances between all observations of the two sets.
- 'complete' : uses the maximum of the distances between all observations of the two sets.
- 'average' : uses the average of the distances between all observations of the two sets.
img_width : int or float, default=5 Width of the resulting figure in inches.
img_high : int or float, default=5 Height of the resulting figure in inches.
Returns
matplotlib.figure.Figure The dendrogram figure.
def adjust_cells_to_group_mean(data, data_avg, beta=0.2):
    """
    Pull each cell's values towards the centroid of its group.

    Every column of `data` whose (string) label starts with a group name
    found in `data_avg` is blended with that group's mean profile:
    ``(1 - beta) * cell + beta * centroid``.

    Parameters
    ----------
    data : pandas.DataFrame
        Original data with features as rows and cells/samples as columns.

    data_avg : pandas.DataFrame
        Group averages (centroids) with features as rows and group names
        as columns.

    beta : float, default=0.2
        Blending weight towards the group mean. 0 = no adjustment,
        1 = fully replaced by the group mean.

    Returns
    -------
    pandas.DataFrame
        Adjusted data with the same shape as the input `data`; the input
        itself is left untouched.

    Notes
    -----
    NOTE(review): group membership is decided by a string *prefix* match,
    so a cell is adjusted once per group whose name is a prefix of its
    label (e.g. groups 'A' and 'A.1' would both touch column 'A.1_cell') —
    confirm this matches the project's naming convention.
    """

    result = data.copy()
    # Column labels never change inside the loop, so stringify them once.
    column_labels = [str(label) for label in result.columns]

    for group in data_avg.columns:
        positions = []
        for pos, label in enumerate(column_labels):
            if label.startswith(group):
                positions.append(pos)
        if len(positions) == 0:
            continue

        # Centroid as a column vector, aligned to the result's row order.
        group_mean = data_avg.loc[result.index, group].to_numpy().reshape(-1, 1)

        blended = (1 - beta) * result.iloc[:, positions].to_numpy() + beta * group_mean
        result.iloc[:, positions] = blended

    return result
Adjust each cell's values towards the mean of its group (centroid).
This function moves each cell's values in data slightly towards the
corresponding group mean in data_avg, controlled by the parameter beta.
Parameters
data : pandas.DataFrame
    Original data with features as rows and cells/samples as columns.
data_avg : pandas.DataFrame
    DataFrame of group averages (centroids) with features as rows and group names as columns.
beta : float, default=0.2
    Weight for adjustment towards the group mean. 0 = no adjustment, 1 = fully replaced by the group mean.
Returns
pandas.DataFrame
Adjusted data with the same shape as the input data.