cfi_toolkit.CellGraph

  1import os
  2import sys
  3from collections import Counter
  4
  5import matplotlib.pyplot as plt
  6import networkx as nx
  7import numpy as np
  8import pandas as pd
  9from adjustText import adjust_text
 10from matplotlib.lines import Line2D
 11
 12_old_stdout = sys.stdout
 13sys.stdout = open(os.devnull, "w")
 14
 15from gedspy import enrichment_heatmap
 16
 17sys.stdout.close()
 18sys.stdout = _old_stdout
 19
 20
 21def gene_interaction_network(idata: pd.DataFrame, min_con: int = 2):
 22    """
 23    Creates a gene or protein interaction network graph.
 24
 25    The network is built from gene/protein interaction data. Nodes represent genes or proteins,
 26    edges represent interactions, and edge colors indicate the type of interaction.
 27
 28    Parameters
 29    ----------
 30    idata : pd.DataFrame
 31      A DataFrame containing the interaction data with columns:
 32        - "A" (str): first gene/protein in the interaction
 33        - "B" (str): second gene/protein in the interaction
 34        - "connection_type" (str): interaction type, e.g., "gene -> protein"
 35
 36    min_con : int, optional
 37        Minimum number of connections (node degree) required
 38        for a gene/protein to be included in the network. Default is 2.
 39
 40    Returns
 41    -------
 42    nx.Graph: A NetworkX graph representing the interaction network.
 43
 44        - Nodes have attributes:
 45            - "size": node size based on connection count (log-scaled)
 46            - "color": node color (default is 'khaki')
 47
 48        - Edges have attributes:
 49            - "color": edge color based on interaction type
 50
 51    Example
 52    -------
 53    >>> G = gene_interaction_network(interactions_df, min_con=3)
 54
 55    >>> nx.draw(G, with_labels=True, node_size=[G.nodes[n]['size'] for n in G.nodes()])
 56    """
 57
 58    inter = idata
 59    inter = inter[["A", "B", "connection_type"]]
 60
 61    dict_meta = pd.DataFrame(
 62        {
 63            "interactions": [
 64                ["gene -> gene"],
 65                ["protein -> protein"],
 66                ["gene -> protein"],
 67                ["protein -> gene"],
 68                ["gene -> gene", "protein -> protein"],
 69                ["gene -> gene", "gene -> protein"],
 70                ["gene -> gene", "protein -> gene"],
 71                ["protein -> protein", "gene -> protein"],
 72                ["protein -> protein", "protein -> gene"],
 73                ["gene -> protein", "protein -> gene"],
 74                ["gene -> gene", "protein -> protein", "gene -> protein"],
 75                ["gene -> gene", "protein -> protein", "protein -> gene"],
 76                ["gene -> gene", "gene -> protein", "protein -> gene"],
 77                ["protein -> protein", "gene -> protein", "protein -> gene"],
 78                [
 79                    "gene -> gene",
 80                    "protein -> protein",
 81                    "gene -> protein",
 82                    "protein -> gene",
 83                ],
 84            ],
 85            "color": [
 86                "#f67089",
 87                "#f47832",
 88                "#ca9213",
 89                "#ad9d31",
 90                "#8eb041",
 91                "#4fb14f",
 92                "#33b07a",
 93                "#35ae99",
 94                "#36acae",
 95                "#38a9c5",
 96                "#3aa3ec",
 97                "#957cf4",
 98                "#cd79f4",
 99                "#f35fb5",
100                "#f669b7",
101            ],
102        }
103    )
104
105    genes_list = list(inter["A"]) + list(inter["B"])
106
107    genes_list = Counter(genes_list)
108
109    genes_list = pd.DataFrame(genes_list.items(), columns=["features", "n"])
110
111    genes_list = genes_list.sort_values("n", ascending=False)
112
113    genes_list = genes_list[genes_list["n"] >= min_con]
114
115    inter = inter[inter["A"].isin(list(genes_list["features"]))]
116    inter = inter[inter["B"].isin(list(genes_list["features"]))]
117
118    inter = inter.groupby(["A", "B"]).agg({"connection_type": list}).reset_index()
119
120    inter["color"] = "black"
121
122    for inx in inter.index:
123        for inx2 in dict_meta.index:
124            if set(inter["connection_type"][inx]) == set(
125                dict_meta["interactions"][inx2]
126            ):
127                inter["color"][inx] = dict_meta["color"][inx2]
128                break
129
130    G = nx.Graph()
131
132    for _, row in genes_list.iterrows():
133        node = row["features"]
134        color = "khaki"
135        weight = np.log2(row["n"] * 500)
136        G.add_node(node, size=weight, color=color)
137
138    for _, row in inter.iterrows():
139        source = row["A"]
140        target = row["B"]
141        color = row["color"]
142        G.add_edge(source, target, color=color)
143
144    return G
145
146
147def encrichment_cell_heatmap(
148    data: pd.DataFrame,
149    fig_size=(35, 25),
150    sets=None,
151    top_n=2,
152    test="FISH",
153    adj="BH",
154    parent_inc=False,
155    font_size=16,
156    clustering: str | None = "ward",
157    scale: bool = True,
158):
159    """
160    Creates a functional enrichment heatmap for cell types.
161
162    This function visualizes the most significant functional terms (GO, KEGG, REACTOME, specificity)
163    across different cell types.
164
165    Parameters
166    ----------
167    data : pd.DataFrame
168        Input data containing columns dependent on the source (GO-TERM, KEGG, REACTOME, specificity).
169
170    fig_size : tuple, optional
171        Figure size (width, height). Default is (35, 25).
172
173    sets : list, optional
174        List of specific cell sets to include. Default is None (all sets).
175
176    top_n : int, optional
177        Number of top terms to include per cell. Default is 2.
178
179    test : str, optional
180        Name of the statistical test column. Default is 'FISH'.
181
182    adj : str, optional
183        P-value adjustment method. Default is 'BH'.
184
185    parent_inc : bool, optional
186        Whether to include parent terms in labels. Default is False.
187
188    font_size : int, optional
189        Font size for the heatmap. Default is 16.
190
191    clustering : str | None, optional
192        Clustering method for rows/columns ('ward', 'single', None). Default is 'ward'.
193
194    scale : bool, optional
195        Whether to scale values before plotting. Default is True.
196
197    Returns
198    -------
199    matplotlib.figure.Figure
200        A heatmap figure of functional enrichment per cell type.
201
202    Raises
203    ------
204    ValueError: If the 'source' column in the data is not one of ['GO-TERM', 'KEGG', 'REACTOME', 'specificity'].
205
206    Example
207    -------
208    >>> fig = encrichment_cell_heatmap(data_df, top_n=3, parent_inc=True)
209
210    >>> fig.savefig('cell_heatmap.svg', bbox_inches='tight')
211    """
212
213    if not any(
214        x in data["source"].iloc[0]
215        for x in ("GO-TERM", "KEGG", "REACTOME", "specificity")
216    ):
217        raise ValueError(
218            "Invalid value for 'source' in data. Expected: 'GO-TERM', 'KEGG', 'REACTOME' or 'specificity'."
219        )
220
221    set_col = "cell"
222    if data["source"].iloc[0] == "GO-TERM":
223        term_col = "child_name"
224        parent_col = "parent_name"
225
226    elif data["source"].iloc[0] == "KEGG":
227        term_col = "3rd"
228        parent_col = "2nd"
229
230    elif data["source"].iloc[0] == "REACTOME":
231        term_col = "pathway"
232        parent_col = "top_level_pathway"
233
234    elif data["source"].iloc[0] == "specificity":
235        term_col = "specificity"
236        parent_col = "None"
237
238    title = f"Cells - {data['source'].iloc[0]}"
239
240    if isinstance(sets, list):
241        data[data["cell"].isin(sets)]
242
243    stat_col = [
244        x
245        for x in data.columns
246        if test in x and adj in x and parent_col.upper() not in x.upper()
247    ][0]
248
249    if parent_inc and data["source"].iloc[0] != "specificity":
250        data[term_col] = data.apply(
251            lambda row: f"{row[parent_col]} -> {row[term_col]}", axis=1
252        )
253
254    data = data.loc[data.groupby([set_col, term_col])[stat_col].idxmin()].reset_index(
255        drop=True
256    )
257
258    data = (
259        data.sort_values(stat_col, ascending=True).groupby(set_col).head(top_n)
260    ).reset_index(drop=True)
261
262    if sets is None and len(list(set(data["cell"]))) < 2:
263        if clustering is not None:
264            clustering = None
265            print(
266                "Clustering could not be conducted, because only one group is available in this analysis data."
267            )
268
269    figure = enrichment_heatmap(
270        data=data,
271        stat_col=stat_col,
272        term_col=term_col,
273        set_col=set_col,
274        sets=sets,
275        title=title,
276        fig_size=fig_size,
277        font_size=font_size,
278        scale=scale,
279        clustering=clustering,
280    )
281
282    return figure
283
284
def draw_cell_conections(
    data: pd.DataFrame, top_n: int = 5, weight_percentile_threshold: int | float = 0.75
):
    """
    Creates a cell-cell interaction network graph based on co-occurrence frequency.

    The function generates a NetworkX graph where nodes represent cell types,
    and edges represent the frequency of interactions between cells. No layout
    is computed here; positions are left to the caller (e.g.
    ``nx.spring_layout(G, weight="weight")``).

    Parameters
    ----------
    data : pd.DataFrame
        A DataFrame containing columns:
            - "cell1" (str): source cell type
            - "cell2" (str): target cell type

        Every row counts as one interaction between the two cells.

    top_n : int, optional
        Maximum number of interactions kept per source cell. Default is 5.

    weight_percentile_threshold : float, optional
        Quantile (0-1) of the pair counts used as the minimum interaction
        weight. Interactions with counts below it are filtered out.
        If no interaction for a given source cell meets this threshold,
        the top-1 interaction (highest count) is retained. Default is 0.75.

    Returns
    -------
    nx.Graph
        A NetworkX graph with attributes:

            - Nodes (every cell seen in "cell1" or "cell2"):
                - "size": node size (always 10)
                - "color": node color (always "#FFA07A")

            - Edges (only the retained interactions):
                - "weight": ``log1p`` of the interaction count
                - "color": edge color (always '#DCDCDC')
                - "alpha": edge transparency (always 0.05)

    Example
    -------
    >>> G = draw_cell_conections(cell_interactions_df, top_n=10)

    >>> nx.draw(G, with_labels=True, node_size=[G.nodes[n]['size'] for n in G.nodes()])
    """

    G = nx.Graph()

    # Number of co-occurrences of every ordered (cell1, cell2) pair.
    cell_cell_df = data.groupby(["cell1", "cell2"]).size().reset_index(name="weight")
    if cell_cell_df.empty:
        return G

    min_weight = cell_cell_df["weight"].quantile(weight_percentile_threshold)

    # Strongest interactions first within every source cell. The selection is
    # expressed with masks rather than groupby.apply, whose callback no longer
    # receives the grouping column ("cell1") in newer pandas releases.
    ranked = cell_cell_df.sort_values(["cell1", "weight"], ascending=[True, False])
    rank_in_group = ranked.groupby("cell1").cumcount()
    strong = ranked["weight"] >= min_weight
    group_has_strong = strong.groupby(ranked["cell1"]).transform("any")

    keep = (strong & (rank_in_group < top_n)) | (
        ~group_has_strong & (rank_in_group == 0)
    )
    df_top = ranked[keep]

    # dict.fromkeys de-duplicates while keeping a deterministic node order.
    cell_list = list(
        dict.fromkeys(list(cell_cell_df["cell1"]) + list(cell_cell_df["cell2"]))
    )

    for cell in cell_list:
        G.add_node(cell, size=10, color="#FFA07A")

    for source, target, count in zip(
        df_top["cell1"], df_top["cell2"], df_top["weight"]
    ):
        # The threshold above works on raw counts; the stored weight is the
        # log-transformed count, as documented.
        G.add_edge(
            source,
            target,
            weight=float(np.log1p(count)),
            color="#DCDCDC",
            alpha=0.05,
        )

    return G
374
375
def _smallest_nonzero_p(deg_df, pv, side):
    """Smallest non-zero p-value among the rows selected by `side`, with fallbacks."""
    on_side = deg_df.loc[(deg_df[pv] != 0) & side, pv].dropna()
    if len(on_side):
        return on_side.min()
    # No usable value on this side: fall back to the other side, then to the
    # smallest positive normal float so that -log10 stays finite.
    anywhere = deg_df.loc[deg_df[pv] > 0, pv]
    if len(anywhere):
        return anywhere.min()
    return np.finfo(float).tiny


def _replace_zero_pvalues(deg_df, pv, shift):
    """Give zero p-values small positive stand-ins; rows with log(FC) == 0 are dropped."""
    fc = deg_df["log(FC)"]
    zero = deg_df[pv] == 0

    def spread(side, ascending):
        block = (
            deg_df[zero & side]
            .sort_values(by="log(FC)", ascending=ascending)
            .reset_index(drop=True)
        )
        if len(block):
            floor = _smallest_nonzero_p(deg_df, pv, side)
            # The most extreme fold change gets the smallest stand-in value.
            block[pv] = [(shift * k) * floor for k in range(1, len(block) + 1)]
        return block

    zero_up = spread(fc > 0, False)
    zero_down = spread(fc < 0, True)
    nonzero = deg_df[~zero & ((fc < 0) | (fc > 0))]

    return pd.concat([zero_up, nonzero, zero_down], ignore_index=True)


def _shrink_leading_outliers(deg_df, p_val_scale):
    """Pull the largest -log(p) values towards their successors to avoid long axis gaps."""
    deg_df = deg_df.sort_values(by=p_val_scale, ascending=False).reset_index(drop=True)

    values = deg_df[p_val_scale].to_numpy()
    total = len(values)
    eps = 1e-300  # keeps the ratios defined when a value is exactly 0

    # Record every position whose value is at least twice the value found
    # 1..5 rows further down, together with the inverse ratio.
    doubled = []
    ratio = []
    for n in range(total):
        for j in range(1, 6):
            if n + j >= total:
                break
            upper = values[n] + eps
            lower = values[n + j] + eps
            if upper / lower >= 2:
                doubled.append(n)
                ratio.append(lower / upper)

    gaps = pd.DataFrame({"doubled": doubled, "ratio": ratio})
    gaps = gaps[gaps["doubled"] < 100].copy()  # only the first 100 ranks are adjusted
    gaps["ratio"] = (1 - gaps["ratio"]) / 5
    gaps = gaps.reset_index(drop=True)
    gaps = gaps.sort_values("doubled")

    # A single recorded gap is left as it is (no rescaling); previously this
    # case left the selection undefined and raised UnboundLocalError.
    selected = []
    if len(gaps) != 1:
        for position in gaps["doubled"]:
            # NOTE(review): the stop condition compares a row position with
            # the number of recorded gaps; kept exactly as in the original
            # algorithm -- confirm the intent.
            if position + 1 == len(doubled):
                break
            selected.append(position)
            selected.append(position + 1)
    selected = sorted(set(selected), reverse=True)

    if len(selected) > 1:
        gaps = gaps[gaps["doubled"].isin(selected)]
        gaps = gaps.sort_values("doubled", ascending=False)
        gaps = gaps.reset_index(drop=True)
        # Bottom-up, so every value is derived from its already adjusted successor.
        for position, extra in zip(gaps["doubled"], gaps["ratio"]):
            deg_df.loc[position, p_val_scale] = deg_df.loc[
                position + 1, p_val_scale
            ] * (1 + extra)

    return deg_df


def _mark_top(side_df, pv, rank_by, fc_ascending, passes, top, color):
    """Order one side of the plot by `rank_by` and color its first `top` passing rows."""
    if rank_by == "P_VALUE":
        ordered = side_df.sort_values(
            [pv, "log(FC)"], ascending=[True, fc_ascending]
        )
    else:
        ordered = side_df.sort_values(
            ["log(FC)", pv], ascending=[fc_ascending, True]
        )
    ordered = ordered.reset_index(drop=True)

    hits = ordered.index[passes(ordered).to_numpy(dtype=bool)][: max(top, 0)]
    ordered.loc[hits, "top100"] = color

    return ordered


def volcano_plot_conections(
    deg_data: pd.DataFrame,
    p_adj: bool = True,
    top: int = 25,
    top_rank: str = "p_value",
    p_val: float | int = 0.05,
    lfc: float | int = 0.25,
    rescale_adj: bool = True,
    image_width: int = 12,
    image_high: int = 12,
):
    """
    Generate a volcano plot from differential expression results.

    A volcano plot visualizes the relationship between statistical significance
    (p-values or standardized p-values) and log(fold change) for each gene, highlighting
    genes that pass significance thresholds. The figure is also displayed via
    ``plt.show()`` before it is returned.

    Parameters
    ----------
    deg_data : pandas.DataFrame
        DataFrame containing differential expression results from calc_DEG() function.
        The columns read here are "feature", "log(FC)" and, depending on
        `p_adj`, "adj_pval" or "p_val". The frame is not modified. Rows whose
        "log(FC)" is exactly 0 are not plotted.

    p_adj : bool, default=True
        If True, use adjusted p-values ("adj_pval"). If False, use raw
        p-values ("p_val").

    top : int, default=25
        Maximum number of genes to highlight (and label) on each side of the plot.

    top_rank : str, default='p_value'
        Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC']

    p_val : float | int, default=0.05
        Significance threshold for p-values (or adjusted p-values).

    lfc : float | int, default=0.25
        Threshold for absolute log fold change.

    rescale_adj : bool, default=True
        If True, rescale p-values to avoid long breaks caused by outlier values.

    image_width : int, default=12
        Width of the generated plot in inches.

    image_high : int, default=12
        Height of the generated plot in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The generated volcano plot figure.

    Raises
    ------
    ValueError
        If `top_rank` is neither 'FC' nor 'p_value', or if no row with a
        non-zero log(FC) is available to plot.
    """

    rank_by = top_rank.upper()
    if rank_by not in ("FC", "P_VALUE"):
        raise ValueError("top_rank must be either 'FC' or 'p_value'")

    pv = "adj_pval" if p_adj else "p_val"
    p_val_scale = "-log(p_val)"

    # Zero p-values cannot be log-transformed; replace them by small multiples
    # of the smallest non-zero p-value observed on the same side.
    deg_df = _replace_zero_pvalues(deg_data, pv, shift=0.25)
    if deg_df.empty:
        raise ValueError("deg_data contains no rows with a non-zero log(FC) to plot")

    deg_df[p_val_scale] = -np.log10(deg_df[pv])
    deg_df["top100"] = None

    if rescale_adj:
        deg_df = _shrink_leading_outliers(deg_df, p_val_scale)

    # Base colors: significant down / up, everything else gray.
    fc = deg_df["log(FC)"]
    significant = deg_df[pv] <= p_val

    deg_df.loc[(fc <= 0) & significant, "top100"] = "bisque"
    deg_df.loc[(fc > 0) & significant, "top100"] = "skyblue"
    deg_df.loc[deg_df[pv] > p_val, "top100"] = "lightgray"

    if lfc > 0:
        deg_df.loc[(fc <= lfc) & (fc >= -lfc), "top100"] = "lightgray"

    # Strict comparisons on both sides, matching the coloring above and the
    # highlighting below.
    down_int = int(((fc < -lfc) & significant).sum())
    up_int = int(((fc > lfc) & significant).sum())

    # Highlight the first `top` genes per side that pass both thresholds.
    deg_df_up = _mark_top(
        deg_df[fc > 0],
        pv,
        rank_by,
        False,
        lambda d: (d["log(FC)"] > lfc) & (d[pv] <= p_val),
        top,
        "dodgerblue",
    )
    deg_df_down = _mark_top(
        deg_df[fc <= 0],
        pv,
        rank_by,
        True,
        lambda d: (d["log(FC)"] < -lfc) & (d[pv] <= p_val),
        top,
        "tomato",
    )

    deg_df = pd.concat([deg_df_up, deg_df_down])

    # Draw order: background points first, highlighted points last (on top).
    que = ["lightgray", "bisque", "skyblue", "tomato", "dodgerblue"]

    deg_df = deg_df.sort_values(
        by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)})
    )

    deg_df = deg_df.reset_index(drop=True)

    fig, ax = plt.subplots(figsize=(image_width, image_high))

    ax.scatter(
        x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2
    )

    # Significance line: just above the highest non-significant point, read
    # from the plotted scale because rescaling may have changed the values.
    not_significant = deg_df[p_val_scale][deg_df[pv] >= p_val]
    if len(not_significant) > 0:
        line_p = not_significant.max()
    else:
        line_p = deg_df[p_val_scale].min()

    guide_style = dict(linestyle="--", linewidth=3, color="lightgray", zorder=1)

    # Span the line over the largest absolute fold change so that it also
    # covers data whose strongest changes are negative.
    x_extent = deg_df["log(FC)"].abs().max() * 1.1
    y_top = deg_df[p_val_scale].max()
    y_bottom = deg_df[p_val_scale].min()

    ax.plot([-x_extent, x_extent], [line_p, line_p], **guide_style)

    if lfc > 0:
        ax.plot([lfc * -1, lfc * -1], [-3, y_top * 1.1], **guide_style)
        ax.plot([lfc, lfc], [-3, y_top * 1.1], **guide_style)

    ax.set_xlabel("log(FC)")
    ax.set_ylabel(p_val_scale)
    ax.set_title("Volcano plot")

    ax.set_ylim(y_bottom - 5, y_top * 1.3)

    texts = [
        ax.text(x, y, label)
        for x, y, label, color in zip(
            deg_df["log(FC)"],
            deg_df[p_val_scale],
            deg_df["feature"],
            deg_df["top100"],
        )
        if color in ("dodgerblue", "tomato")
    ]

    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

    legend_elements = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label=label,
            markerfacecolor=face,
            markersize=10,
        )
        for label, face in (
            ("top-upregulated", "dodgerblue"),
            ("top-downregulated", "tomato"),
            ("upregulated", "skyblue"),
            ("downregulated", "bisque"),
            ("non-significant", "lightgray"),
        )
    ]

    ax.legend(handles=legend_elements, loc="upper right")
    ax.grid(visible=False)

    ax.annotate(
        f"\nmin {pv} = " + str(p_val),
        xy=(0.025, 0.975),
        xycoords="axes fraction",
        fontsize=12,
    )

    if lfc > 0:
        ax.annotate(
            "\nmin log(FC) = " + str(lfc),
            xy=(0.025, 0.95),
            xycoords="axes fraction",
            fontsize=12,
        )

    ax.annotate(
        "\nDownregulated: " + str(down_int),
        xy=(0.025, 0.925),
        xycoords="axes fraction",
        fontsize=12,
        color="black",
    )

    ax.annotate(
        "\nUpregulated: " + str(up_int),
        xy=(0.025, 0.9),
        xycoords="axes fraction",
        fontsize=12,
        color="black",
    )

    plt.show()

    return fig
def gene_interaction_network(idata: pandas.core.frame.DataFrame, min_con: int = 2):
 22def gene_interaction_network(idata: pd.DataFrame, min_con: int = 2):
 23    """
 24    Creates a gene or protein interaction network graph.
 25
 26    The network is built from gene/protein interaction data. Nodes represent genes or proteins,
 27    edges represent interactions, and edge colors indicate the type of interaction.
 28
 29    Parameters
 30    ----------
 31    idata : pd.DataFrame
 32      A DataFrame containing the interaction data with columns:
 33        - "A" (str): first gene/protein in the interaction
 34        - "B" (str): second gene/protein in the interaction
 35        - "connection_type" (str): interaction type, e.g., "gene -> protein"
 36
 37    min_con : int, optional
 38        Minimum number of connections (node degree) required
 39        for a gene/protein to be included in the network. Default is 2.
 40
 41    Returns
 42    -------
 43    nx.Graph: A NetworkX graph representing the interaction network.
 44
 45        - Nodes have attributes:
 46            - "size": node size based on connection count (log-scaled)
 47            - "color": node color (default is 'khaki')
 48
 49        - Edges have attributes:
 50            - "color": edge color based on interaction type
 51
 52    Example
 53    -------
 54    >>> G = gene_interaction_network(interactions_df, min_con=3)
 55
 56    >>> nx.draw(G, with_labels=True, node_size=[G.nodes[n]['size'] for n in G.nodes()])
 57    """
 58
 59    inter = idata
 60    inter = inter[["A", "B", "connection_type"]]
 61
 62    dict_meta = pd.DataFrame(
 63        {
 64            "interactions": [
 65                ["gene -> gene"],
 66                ["protein -> protein"],
 67                ["gene -> protein"],
 68                ["protein -> gene"],
 69                ["gene -> gene", "protein -> protein"],
 70                ["gene -> gene", "gene -> protein"],
 71                ["gene -> gene", "protein -> gene"],
 72                ["protein -> protein", "gene -> protein"],
 73                ["protein -> protein", "protein -> gene"],
 74                ["gene -> protein", "protein -> gene"],
 75                ["gene -> gene", "protein -> protein", "gene -> protein"],
 76                ["gene -> gene", "protein -> protein", "protein -> gene"],
 77                ["gene -> gene", "gene -> protein", "protein -> gene"],
 78                ["protein -> protein", "gene -> protein", "protein -> gene"],
 79                [
 80                    "gene -> gene",
 81                    "protein -> protein",
 82                    "gene -> protein",
 83                    "protein -> gene",
 84                ],
 85            ],
 86            "color": [
 87                "#f67089",
 88                "#f47832",
 89                "#ca9213",
 90                "#ad9d31",
 91                "#8eb041",
 92                "#4fb14f",
 93                "#33b07a",
 94                "#35ae99",
 95                "#36acae",
 96                "#38a9c5",
 97                "#3aa3ec",
 98                "#957cf4",
 99                "#cd79f4",
100                "#f35fb5",
101                "#f669b7",
102            ],
103        }
104    )
105
106    genes_list = list(inter["A"]) + list(inter["B"])
107
108    genes_list = Counter(genes_list)
109
110    genes_list = pd.DataFrame(genes_list.items(), columns=["features", "n"])
111
112    genes_list = genes_list.sort_values("n", ascending=False)
113
114    genes_list = genes_list[genes_list["n"] >= min_con]
115
116    inter = inter[inter["A"].isin(list(genes_list["features"]))]
117    inter = inter[inter["B"].isin(list(genes_list["features"]))]
118
119    inter = inter.groupby(["A", "B"]).agg({"connection_type": list}).reset_index()
120
121    inter["color"] = "black"
122
123    for inx in inter.index:
124        for inx2 in dict_meta.index:
125            if set(inter["connection_type"][inx]) == set(
126                dict_meta["interactions"][inx2]
127            ):
128                inter["color"][inx] = dict_meta["color"][inx2]
129                break
130
131    G = nx.Graph()
132
133    for _, row in genes_list.iterrows():
134        node = row["features"]
135        color = "khaki"
136        weight = np.log2(row["n"] * 500)
137        G.add_node(node, size=weight, color=color)
138
139    for _, row in inter.iterrows():
140        source = row["A"]
141        target = row["B"]
142        color = row["color"]
143        G.add_edge(source, target, color=color)
144
145    return G

Creates a gene or protein interaction network graph.

The network is built from gene/protein interaction data. Nodes represent genes or proteins, edges represent interactions, and edge colors indicate the type of interaction.

Parameters

idata : pd.DataFrame
    A DataFrame containing the interaction data with columns:
    - "A" (str): first gene/protein in the interaction
    - "B" (str): second gene/protein in the interaction
    - "connection_type" (str): interaction type, e.g., "gene -> protein"

min_con : int, optional
    Minimum number of connections (node degree) required for a gene/protein to be included in the network. Default is 2.

Returns

nx.Graph: A NetworkX graph representing the interaction network.

- Nodes have attributes:
    - "size": node size based on connection count (log-scaled)
    - "color": node color (default is 'khaki')

- Edges have attributes:
    - "color": edge color based on interaction type

Example

>>> G = gene_interaction_network(interactions_df, min_con=3)
>>> nx.draw(G, with_labels=True, node_size=[G.nodes[n]['size'] for n in G.nodes()])
def encrichment_cell_heatmap( data: pandas.core.frame.DataFrame, fig_size=(35, 25), sets=None, top_n=2, test='FISH', adj='BH', parent_inc=False, font_size=16, clustering: str | None = 'ward', scale: bool = True):
def encrichment_cell_heatmap(
    data: pd.DataFrame,
    fig_size=(35, 25),
    sets=None,
    top_n=2,
    test="FISH",
    adj="BH",
    parent_inc=False,
    font_size=16,
    clustering: str | None = "ward",
    scale: bool = True,
):
    """
    Creates a functional enrichment heatmap for cell types.

    This function selects the most significant functional terms (GO, KEGG, REACTOME, specificity)
    for each cell type and passes them to ``gedspy.enrichment_heatmap`` for plotting.

    Parameters
    ----------
    data : pd.DataFrame
        Input data. It must contain the columns "source" and "cell", the term/parent
        columns matching the source ("child_name"/"parent_name" for GO-TERM,
        "3rd"/"2nd" for KEGG, "pathway"/"top_level_pathway" for REACTOME,
        "specificity" for specificity) and a statistic column whose name contains
        both `test` and `adj`. The input frame is not modified.

    fig_size : tuple, optional
        Figure size (width, height). Default is (35, 25).

    sets : list, optional
        List of specific cell sets to include. Default is None (all sets).

    top_n : int, optional
        Number of top terms to include per cell. Default is 2.

    test : str, optional
        Name of the statistical test column. Default is 'FISH'.

    adj : str, optional
        P-value adjustment method. Default is 'BH'.

    parent_inc : bool, optional
        Whether to include parent terms in labels ("parent -> term").
        Ignored for the 'specificity' source. Default is False.

    font_size : int, optional
        Font size for the heatmap. Default is 16.

    clustering : str | None, optional
        Clustering method for rows/columns ('ward', 'single', None). Default is 'ward'.
        Clustering is switched off (with a printed notice) when fewer than two
        cell groups remain after filtering.

    scale : bool, optional
        Whether to scale values before plotting. Default is True.

    Returns
    -------
    matplotlib.figure.Figure
        A heatmap figure of functional enrichment per cell type
        (the object returned by ``enrichment_heatmap``).

    Raises
    ------
    ValueError
        If `data` is empty, if the first value of the 'source' column is not one of
        ['GO-TERM', 'KEGG', 'REACTOME', 'specificity'], if none of the requested
        `sets` occur in the data, or if no statistic column matches `test` and `adj`.

    Example
    -------
    >>> fig = encrichment_cell_heatmap(data_df, top_n=3, parent_inc=True)

    >>> fig.savefig('cell_heatmap.svg', bbox_inches='tight')
    """

    # (term column, parent column) for every supported annotation source.
    # "None" is a placeholder: the specificity source has no parent column.
    source_columns = {
        "GO-TERM": ("child_name", "parent_name"),
        "KEGG": ("3rd", "2nd"),
        "REACTOME": ("pathway", "top_level_pathway"),
        "specificity": ("specificity", "None"),
    }

    if len(data.index) == 0:
        raise ValueError("Input data is empty.")

    source = data["source"].iloc[0]

    # Exact match is required: a substring match would pass validation but
    # leave the term/parent columns undefined.
    if not isinstance(source, str) or source not in source_columns:
        raise ValueError(
            "Invalid value for 'source' in data. Expected: 'GO-TERM', 'KEGG', 'REACTOME' or 'specificity'."
        )

    set_col = "cell"
    term_col, parent_col = source_columns[source]
    title = f"Cells - {source}"

    if isinstance(sets, list):
        # The filter result must be kept; otherwise `sets` has no effect here.
        data = data[data[set_col].isin(sets)]
        if len(data.index) == 0:
            raise ValueError("None of the requested sets are present in data.")

    # Work on a copy so the caller's DataFrame is never modified.
    data = data.copy()

    # Statistic column: contains both the test and the adjustment name and
    # does not belong to the parent term.
    stat_candidates = [
        x
        for x in data.columns
        if isinstance(x, str)
        and test in x
        and adj in x
        and parent_col.upper() not in x.upper()
    ]
    if not stat_candidates:
        raise ValueError(
            f"No statistic column containing '{test}' and '{adj}' was found in data."
        )
    stat_col = stat_candidates[0]

    if parent_inc and source != "specificity":
        data[term_col] = (
            data[parent_col].astype(str) + " -> " + data[term_col].astype(str)
        )

    # Keep the best (lowest) statistic for every (cell, term) pair ...
    data = data.loc[data.groupby([set_col, term_col])[stat_col].idxmin()].reset_index(
        drop=True
    )

    # ... then the `top_n` most significant terms of every cell.
    data = (
        data.sort_values(stat_col, ascending=True).groupby(set_col).head(top_n)
    ).reset_index(drop=True)

    # Clustering needs at least two groups, whether or not `sets` was given.
    if clustering is not None and data[set_col].nunique() < 2:
        clustering = None
        print(
            "Clustering could not be conducted, because only one group is available in this analysis data."
        )

    figure = enrichment_heatmap(
        data=data,
        stat_col=stat_col,
        term_col=term_col,
        set_col=set_col,
        sets=sets,
        title=title,
        fig_size=fig_size,
        font_size=font_size,
        scale=scale,
        clustering=clustering,
    )

    return figure

Creates a functional enrichment heatmap for cell types.

This function visualizes the most significant functional terms (GO, KEGG, REACTOME, specificity) across different cell types.

Parameters

data : pd.DataFrame Input data containing columns dependent on the source (GO-TERM, KEGG, REACTOME, specificity).

fig_size : tuple, optional Figure size (width, height). Default is (35, 25).

sets : list, optional List of specific cell sets to include. Default is None (all sets).

top_n : int, optional Number of top terms to include per cell. Default is 2.

test : str, optional Name of the statistical test column. Default is 'FISH'.

adj : str, optional P-value adjustment method. Default is 'BH'.

parent_inc : bool, optional Whether to include parent terms in labels. Default is False.

font_size : int, optional Font size for the heatmap. Default is 16.

clustering : str | None, optional Clustering method for rows/columns ('ward', 'single', None). Default is 'ward'.

scale : bool, optional Whether to scale values before plotting. Default is True.

Returns

matplotlib.figure.Figure A heatmap figure of functional enrichment per cell type.

Raises

ValueError: If the 'source' column in the data is not one of ['GO-TERM', 'KEGG', 'REACTOME', 'specificity'].

Example

>>> fig = encrichment_cell_heatmap(data_df, top_n=3, parent_inc=True)
>>> fig.savefig('cell_heatmap.svg', bbox_inches='tight')
def draw_cell_conections( data: pandas.core.frame.DataFrame, top_n: int = 5, weight_percentile_threshold: int | float = 0.75):
def draw_cell_conections(
    data: pd.DataFrame, top_n: int = 5, weight_percentile_threshold: int | float = 0.75
):
    """
    Creates a cell-cell interaction network graph based on co-occurrence frequency.

    The function generates a NetworkX graph where nodes represent cell types,
    and edges represent the frequency of interactions between cells.

    Parameters
    ----------
    data : pd.DataFrame
        A DataFrame containing columns:
            - "cell1" (str): source cell type
            - "cell2" (str): target cell type

    top_n : int, optional
        Maximum number of interactions kept per source cell. Default is 5.

    weight_percentile_threshold : float, optional
        Quantile (passed to ``Series.quantile``) used to compute the minimum
        interaction weight threshold. Interactions with weights below it are
        filtered out. If no interaction for a given source cell meets this
        threshold, the top-1 interaction (highest weight) is retained.
        Default is 0.75.

    Returns
    -------
    nx.Graph
        A NetworkX graph with attributes:

            - Nodes:
                - "size": node size (default 10)
                - "color": node color (default "#FFA07A")

            - Edges:
                - "weight": edge weight (log1p-transformed frequency)
                - "color": edge color (default '#DCDCDC')
                - "alpha": edge transparency (default 0.05)

    Notes
    -----
    Every cell type present in `data` becomes a node, even if none of its
    interactions is retained. The graph is undirected: when both (A, B) and
    (B, A) are retained, the attributes of the pair added last are kept.
    No layout is computed here; positions are left to the caller.

    Example
    -------
    >>> G = draw_cell_conections(cell_interactions_df, top_n=10)

    >>> nx.draw(G, with_labels=True, node_size=[G.nodes[n]['size'] for n in G.nodes()])
    """

    # Count how often each (cell1, cell2) pair occurs, strongest pairs first.
    cell_cell_df = (
        data.groupby(["cell1", "cell2"])
        .size()
        .reset_index(name="weight")
        .sort_values("weight", ascending=False, kind="stable")
    )

    min_weight = cell_cell_df["weight"].quantile(weight_percentile_threshold)

    # Per source cell: up to `top_n` pairs at/above the threshold, or the
    # single strongest pair when none reaches it. Done with masks rather than
    # groupby.apply, which no longer passes the grouping column in new pandas.
    above = cell_cell_df["weight"] >= min_weight
    strong = cell_cell_df[above].groupby("cell1", sort=False).head(top_n)
    strong_sources = set(cell_cell_df.loc[above, "cell1"])
    fallback = (
        cell_cell_df[~cell_cell_df["cell1"].isin(strong_sources)]
        .groupby("cell1", sort=False)
        .head(1)
    )
    kept = strong.index.union(fallback.index)

    # Stable sort by source cell: grouped by cell1, weight-descending within.
    df_top = (
        cell_cell_df[cell_cell_df.index.isin(kept)]
        .sort_values("cell1", kind="stable")
        .copy()
    )

    # Log-transform the weights of the edges that are actually drawn.
    df_top["weight"] = np.log1p(df_top["weight"])

    # Unique cell names in first-seen order (deterministic node order).
    cell_list = list(
        dict.fromkeys(list(cell_cell_df["cell1"]) + list(cell_cell_df["cell2"]))
    )

    G = nx.Graph()

    for c in cell_list:
        G.add_node(c, size=10, color="#FFA07A")

    for source, target, weight in zip(
        df_top["cell1"], df_top["cell2"], df_top["weight"]
    ):
        G.add_edge(source, target, weight=weight, color="#DCDCDC", alpha=0.05)

    return G

Creates a cell-cell interaction network graph based on co-occurrence frequency.

The function generates a NetworkX graph where nodes represent cell types, and edges represent the frequency of interactions between cells.

Parameters

data : pd.DataFrame A DataFrame containing columns: - "cell1" (str): source cell type - "cell2" (str): target cell type

top_n : int, optional Maximum number of neighbouring interactions kept per source cell. Default is 5.

weight_percentile_threshold : float, optional Percentile used to compute the minimum interaction weight threshold. Interactions with weights below this percentile are filtered out. If no interaction for a given source cell meets this threshold, the top-1 interaction (highest weight) is retained. Default is 0.75.

Returns

nx.Graph A NetworkX graph with attributes:

    - Nodes:
        - "size": node size (default 10)
        - "color": node color (default "#FFA07A")

    - Edges:
        - "weight": edge weight (log-transformed from frequency)
        - "color": edge color (default '#DCDCDC')
        - "alpha": edge transparency (default 0.05)

Example

>>> G = draw_cell_conections(cell_interactions_df, top_n=10)
>>> nx.draw(G, with_labels=True, node_size=[G.nodes[n]['size'] for n in G.nodes()])
def volcano_plot_conections( deg_data: pandas.core.frame.DataFrame, p_adj: bool = True, top: int = 25, top_rank: str = 'p_value', p_val: float | int = 0.05, lfc: float | int = 0.25, rescale_adj: bool = True, image_width: int = 12, image_high: int = 12):
def volcano_plot_conections(
    deg_data: pd.DataFrame,
    p_adj: bool = True,
    top: int = 25,
    top_rank: str = "p_value",
    p_val: float | int = 0.05,
    lfc: float | int = 0.25,
    rescale_adj: bool = True,
    image_width: int = 12,
    image_high: int = 12,
):
    """
    Generate a volcano plot from differential expression results.

    A volcano plot visualizes the relationship between statistical significance
    (p-values or standardized p-value) and log(fold change) for each gene, highlighting
    genes that pass significance thresholds.

    Parameters
    ----------
    deg_data : pandas.DataFrame
        DataFrame containing differential expression results from calc_DEG() function.
        The columns "feature", "log(FC)" and "adj_pval" (or "p_val" when
        `p_adj` is False) are read. The input frame is not modified.

    p_adj : bool, default=True
        If True, use adjusted p-values. If False, use raw p-values.

    top : int, default=25
        Number of top significant genes to highlight on the plot
        (per direction). Values <= 0 highlight nothing.

    top_rank : str, default='p_value'
        Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC']

    p_val : float | int, default=0.05
        Significance threshold for p-values (or adjusted p-values).

    lfc : float | int, default=0.25
        Threshold for absolute log fold change.

    rescale_adj : bool, default=True
        If True, rescale p-values to avoid long breaks caused by outlier values.

    image_width : int, default=12
        Width of the generated plot in inches.

    image_high : int, default=12
        Height of the generated plot in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The generated volcano plot figure.

    Raises
    ------
    ValueError
        If `top_rank` is not 'FC' or 'p_value', or if `deg_data` contains no
        rows with a non-zero log(FC).

    Notes
    -----
    Rows whose log(FC) is exactly 0 are not plotted. P-values equal to 0 are
    replaced by small multiples of the smallest non-zero p-value of the same
    direction so that they can be log-transformed.
    """

    rank_mode = top_rank.upper()
    if rank_mode not in ["FC", "P_VALUE"]:
        raise ValueError("top_rank must be either 'FC' or 'p_value'")

    pv = "adj_pval" if p_adj else "p_val"

    shift = 0.25

    p_val_scale = "-log(p_val)"

    source = deg_data.copy()

    is_up = source["log(FC)"] > 0
    is_down = source["log(FC)"] < 0
    is_zero_p = source[pv] == 0

    # Smallest non-zero p-value per direction. Series.min() returns NaN for an
    # empty selection instead of raising, so one-sided data is handled.
    overall_min = source.loc[~is_zero_p, pv].min()
    if pd.isna(overall_min):
        overall_min = np.finfo(float).tiny

    min_plus = source.loc[~is_zero_p & is_up, pv].min()
    if pd.isna(min_plus):
        min_plus = overall_min

    min_minus = source.loc[~is_zero_p & is_down, pv].min()
    if pd.isna(min_minus):
        min_minus = overall_min

    # Zero p-values cannot be log-transformed: replace them by growing
    # fractions of the direction's minimum, strongest fold change first.
    zero_p_plus = (
        source[is_zero_p & is_up]
        .sort_values(by="log(FC)", ascending=False)
        .reset_index(drop=True)
    )
    zero_p_plus[pv] = [
        (shift * x) * min_plus for x in range(1, len(zero_p_plus.index) + 1)
    ]

    zero_p_minus = (
        source[is_zero_p & is_down]
        .sort_values(by="log(FC)", ascending=True)
        .reset_index(drop=True)
    )
    zero_p_minus[pv] = [
        (shift * x) * min_minus for x in range(1, len(zero_p_minus.index) + 1)
    ]

    tmp_p = source[~is_zero_p & (is_down | is_up)]

    deg_df = pd.concat([zero_p_plus, tmp_p, zero_p_minus], ignore_index=True)

    if len(deg_df.index) == 0:
        raise ValueError(
            "deg_data contains no rows with a non-zero log(FC); nothing to plot."
        )

    deg_df[p_val_scale] = -np.log10(deg_df[pv].astype(float))

    # Default colour; rows with a usable p-value are recoloured below.
    deg_df["top100"] = "lightgray"

    if rescale_adj:

        deg_df = deg_df.sort_values(by=p_val_scale, ascending=False)

        deg_df = deg_df.reset_index(drop=True)

        # After reset_index the label n equals the position n, so a plain
        # array gives the same values as label-based Series access.
        scaled = deg_df[p_val_scale].to_numpy()
        n_rows = len(scaled)

        eps = 1e-300
        doubled = []
        ratio = []
        # Record every row whose value is at least twice that of one of the
        # next five rows (a "jump" in the sorted significance values).
        for n in range(n_rows):
            for j in range(1, 6):
                if n + j < n_rows and (scaled[n] + eps) / (scaled[n + j] + eps) >= 2:
                    doubled.append(n)
                    ratio.append((scaled[n + j] + eps) / (scaled[n] + eps))

        df = pd.DataFrame({"doubled": doubled, "ratio": ratio})
        df = df[df["doubled"] < 100].copy()

        df["ratio"] = (1 - df["ratio"]) / 5
        df = df.reset_index(drop=True)

        df = df.sort_values("doubled")

        # Always bound, so the check below cannot hit an undefined name.
        # A single remaining jump row is left unscaled.
        doubled2 = []

        if len(df["doubled"]) != 1:
            for l in df["doubled"]:
                if l + 1 != len(doubled):
                    doubled2.append(l)
                    doubled2.append(l + 1)
                else:
                    break

            doubled2 = sorted(set(doubled2), reverse=True)

        if len(doubled2) > 1:
            df = df[df["doubled"].isin(doubled2)]
            df = df.sort_values("doubled", ascending=False)
            df = df.reset_index(drop=True)
            # Pull each outlier down to just above its successor.
            for c in df.index:
                deg_df.loc[df["doubled"][c], p_val_scale] = deg_df.loc[
                    df["doubled"][c] + 1, p_val_scale
                ] * (1 + df["ratio"][c])

    deg_df.loc[(deg_df["log(FC)"] <= 0) & (deg_df[pv] <= p_val), "top100"] = "bisque"
    deg_df.loc[(deg_df["log(FC)"] > 0) & (deg_df[pv] <= p_val), "top100"] = "skyblue"
    deg_df.loc[deg_df[pv] > p_val, "top100"] = "lightgray"

    if lfc > 0:
        deg_df.loc[
            (deg_df["log(FC)"] <= lfc) & (deg_df["log(FC)"] >= -lfc), "top100"
        ] = "lightgray"

    # Counts use the same strict comparisons as the highlighting below.
    down_int = int(((deg_df["log(FC)"] < lfc * -1) & (deg_df[pv] <= p_val)).sum())
    up_int = int(((deg_df["log(FC)"] > lfc) & (deg_df[pv] <= p_val)).sum())

    n_top = max(int(top), 0)

    deg_df_up = deg_df[deg_df["log(FC)"] > 0]

    if rank_mode == "P_VALUE":
        deg_df_up = deg_df_up.sort_values([pv, "log(FC)"], ascending=[True, False])
    else:
        deg_df_up = deg_df_up.sort_values(["log(FC)", pv], ascending=[False, True])

    deg_df_up = deg_df_up.reset_index(drop=True)

    # First `top` rows (in ranking order) that pass both thresholds. A mask
    # cannot run past the end of the frame and does not stop at the first
    # non-significant gene in FC ranking.
    up_mask = ((deg_df_up["log(FC)"] > lfc) & (deg_df_up[pv] <= p_val)).to_numpy()
    deg_df_up.loc[deg_df_up.index[up_mask][:n_top], "top100"] = "dodgerblue"

    deg_df_down = deg_df[deg_df["log(FC)"] <= 0]

    if rank_mode == "P_VALUE":
        deg_df_down = deg_df_down.sort_values([pv, "log(FC)"], ascending=[True, True])
    else:
        deg_df_down = deg_df_down.sort_values(["log(FC)", pv], ascending=[True, True])

    deg_df_down = deg_df_down.reset_index(drop=True)

    down_mask = (
        (deg_df_down["log(FC)"] < lfc * -1) & (deg_df_down[pv] <= p_val)
    ).to_numpy()
    deg_df_down.loc[deg_df_down.index[down_mask][:n_top], "top100"] = "tomato"

    deg_df = pd.concat([deg_df_up, deg_df_down])

    # Drawing order: background colours first, highlighted genes last.
    que = ["lightgray", "bisque", "skyblue", "tomato", "dodgerblue"]

    deg_df = deg_df.sort_values(
        by="top100", key=lambda x: x.map({v: i for i, v in enumerate(que)})
    )

    deg_df = deg_df.reset_index(drop=True)

    fig, ax = plt.subplots(figsize=(image_width, image_high))

    ax.scatter(
        x=deg_df["log(FC)"], y=deg_df[p_val_scale], color=deg_df["top100"], zorder=2
    )

    # Horizontal significance line: highest plotted value among the
    # non-significant genes (or the lowest value if all are significant).
    tl = deg_df[p_val_scale][deg_df[pv] >= p_val]

    if len(tl) > 0:
        line_p = np.max(tl)
    else:
        line_p = np.min(deg_df[p_val_scale])

    y_max = deg_df[p_val_scale].max()
    y_min = deg_df[p_val_scale].min()

    # Max absolute fold change so the line spans both sides of the plot.
    x_extent = deg_df["log(FC)"].abs().max() * 1.1

    ax.plot(
        [-x_extent, x_extent],
        [line_p, line_p],
        linestyle="--",
        linewidth=3,
        color="lightgray",
        zorder=1,
    )

    if lfc > 0:
        ax.plot(
            [lfc * -1, lfc * -1],
            [-3, y_max * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )
        ax.plot(
            [lfc, lfc],
            [-3, y_max * 1.1],
            linestyle="--",
            linewidth=3,
            color="lightgray",
            zorder=1,
        )

    ax.set_xlabel("log(FC)")
    ax.set_ylabel(p_val_scale)
    ax.set_title("Volcano plot")

    ax.set_ylim(y_min - 5, y_max * 1.3)

    labelled = deg_df[deg_df["top100"].isin(["dodgerblue", "tomato"])]

    texts = [
        ax.text(x, y, name)
        for x, y, name in zip(
            labelled["log(FC)"], labelled[p_val_scale], labelled["feature"]
        )
    ]

    adjust_text(texts, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5))

    legend_spec = [
        ("top-upregulated", "dodgerblue"),
        ("top-downregulated", "tomato"),
        ("upregulated", "skyblue"),
        ("downregulated", "bisque"),
        ("non-significant", "lightgray"),
    ]

    legend_elements = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label=label,
            markerfacecolor=face,
            markersize=10,
        )
        for label, face in legend_spec
    ]

    ax.legend(handles=legend_elements, loc="upper right")
    ax.grid(visible=False)

    ax.annotate(
        f"\nmin {pv} = " + str(p_val),
        xy=(0.025, 0.975),
        xycoords="axes fraction",
        fontsize=12,
    )

    if lfc > 0:
        ax.annotate(
            "\nmin log(FC) = " + str(lfc),
            xy=(0.025, 0.95),
            xycoords="axes fraction",
            fontsize=12,
        )

    ax.annotate(
        "\nDownregulated: " + str(down_int),
        xy=(0.025, 0.925),
        xycoords="axes fraction",
        fontsize=12,
        color="black",
    )

    ax.annotate(
        "\nUpregulated: " + str(up_int),
        xy=(0.025, 0.9),
        xycoords="axes fraction",
        fontsize=12,
        color="black",
    )

    plt.show()

    return fig

Generate a volcano plot from differential expression results.

A volcano plot visualizes the relationship between statistical significance (p-values or standardized p-values) and log(fold change) for each gene, highlighting genes that pass significance thresholds.

Parameters

deg_data : pandas.DataFrame DataFrame containing differential expression results from calc_DEG() function.

p_adj : bool, default=True If True, use adjusted p-values. If False, use raw p-values.

top : int, default=25 Number of top significant genes to highlight on the plot.

top_rank : str, default='p_value' Statistic used primarily to determine the top significant genes to highlight on the plot. ['p_value' or 'FC']

p_val : float | int, default=0.05 Significance threshold for p-values (or adjusted p-values).

lfc : float | int, default=0.25 Threshold for absolute log fold change.

rescale_adj : bool, default=True If True, rescale p-values to avoid long breaks caused by outlier values.

image_width : int, default=12 Width of the generated plot in inches.

image_high : int, default=12 Height of the generated plot in inches.

Returns

matplotlib.figure.Figure The generated volcano plot figure.