Supervised & self-supervised model performance#

import numpy as np
from tifffile import imread
import sys

import pyclesperanto_prototype as cle
from stardist.matching import matching_dataset
sys.path.append("../..")

from utils import *
from plots import *
print("Used GPU: ", cle.get_device())
show_params()
#################
SAVE_PLOTS_AS_PNG = False
SAVE_PLOTS_AS_SVG = True
Used GPU:  <NVIDIA GeForce RTX 4070 Ti on Platform: NVIDIA CUDA (1 refs)>
Plot parameters (set in plots.py) : 
- COLORMAP : ███████
- DPI : 200
- Data path : C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK
- Font size : 20
- Title font size : 25.0
- Label font size : 20.0
%load_ext autoreload
%autoreload 2

Data#

Supervised models#

data_path = DATA_PATH / "RESULTS/SUPERVISED_PERF_FIG/INFERENCE"
fold1_path = data_path / "fold1"
fold2_path = data_path / "fold2"
fold3_path = data_path / "fold3"

def get_fold_data(fold_path):
    images = [imread(f) for f in fold_path.glob("*.tif") if "label" not in f.name]
    labels = [imread(f) for f in fold_path.glob("*.tif") if "label" in f.name]
    return images, labels

images = []
GT_labels = []
for fold in [fold1_path, fold2_path, fold3_path]:
    im, GT_lab = get_fold_data(fold)
    images.append(im[0])
    GT_labels.append(GT_lab[0])
fold_paths = [fold1_path, fold2_path, fold3_path]

def get_predictions(path):
    return [imread(f) for f in path.glob("*.tif")]
#################
swin_paths = [f / "Swin" for f in fold_paths]
segresnet_paths = [f / "SegResNet" for f in fold_paths]
#################
cellpose_paths = [f / "cellpose/tuned" for f in fold_paths]
stardist_paths = [f / "stardist/tuned" for f in fold_paths]
#################
cellpose_default_paths = [f / "cellpose/default" for f in fold_paths]
stardist_default_paths = [f / "stardist/default" for f in fold_paths]

swin_predictions = [get_predictions(p)[0] for p in swin_paths]
segresnet_predictions = [get_predictions(p)[0] for p in segresnet_paths]
cellpose_predictions = [get_predictions(p)[0] for p in cellpose_paths]
stardist_predictions = [get_predictions(p)[0] for p in stardist_paths]
cellpose_default_predictions = [get_predictions(p)[0] for p in cellpose_default_paths]
stardist_default_predictions = [get_predictions(p)[0] for p in stardist_default_paths]

Threshold-based segmentation#

Check whether a simple threshold-based segmentation yields competitive results.

import pyclesperanto_prototype as cle

# Otsu threshold of the raw images, then Voronoi-Otsu labeling to split the
# binary foreground into instances.
threshold_only = [np.array(cle.threshold_otsu(p)) for p in images]
threshold_only = [np.array(cle.voronoi_otsu_labeling(p, spot_sigma=0.65, outline_sigma=0.65)) for p in threshold_only]

Unsupervised model#

wnet_path = DATA_PATH / "RESULTS/WNET_RETRAIN/inference/WNet"
wnet_instance = [imread(str(f)) for f in wnet_path.glob("*.tif")]
wnet_instance.reverse()  # reverse file order to match the fold ordering of GT_labels

Computations#

For the choice of the probability thresholds used below, see the notebooks in the thresholds_opti/ folder.
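The cutoffs applied below (0.4 for SwinUNetR, 0.3 for SegResNet) were selected in those notebooks. Purely as an illustration of the idea, a sweep like the following could be used to pick such a cutoff, scoring each candidate with the Dice coefficient; this is a sketch, not the actual optimisation code from thresholds_opti/.

# Sketch only: sweep candidate probability cutoffs on one fold and keep the one
# maximising the Dice coefficient against the ground truth.
candidate_cutoffs = np.arange(0.1, 0.95, 0.05)
best_cutoff, best_dice = None, -1.0
for t in candidate_cutoffs:
    thresholded = np.where(swin_predictions[0] > t, swin_predictions[0], 0)
    instance = np.array(cle.voronoi_otsu_labeling(thresholded, spot_sigma=0.65, outline_sigma=0.65))
    d = dice_coeff(np.where(GT_labels[0] > 0, 1, 0), np.where(instance > 0, 1, 0))
    if d > best_dice:
        best_cutoff, best_dice = t, d
print(f"Best cutoff: {best_cutoff:.2f} (Dice = {best_dice:.3f})")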

swin_thresholded = [np.where(swin > 0.4, swin, 0) for swin in swin_predictions]
segresnet_thresholded = [np.where(segresnet > 0.3, segresnet, 0) for segresnet in segresnet_predictions]
swin_instance = []
segresnet_instance = []

for i in range(len(fold_paths)):
    swin_instance.append(np.array(
        cle.voronoi_otsu_labeling(swin_thresholded[i], spot_sigma=0.65, outline_sigma=0.65))
                         )
    segresnet_instance.append(np.array(
        cle.voronoi_otsu_labeling(segresnet_thresholded[i], spot_sigma=0.65, outline_sigma=0.65))
                              )
# from tifffile import imwrite
# for i in range(len(fold_paths)):
#     save_folder = data_path / f"fold{i+1}" / "processed"
#     save_folder.mkdir(exist_ok=True, parents=False)
#     imwrite(save_folder / "swin_instance.tif", swin_instance[i])
#     imwrite(save_folder / "segresnet_instance.tif", segresnet_instance[i])

Plots#

taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

predictions = [
   stardist_predictions,
   stardist_default_predictions,
   cellpose_predictions,
   cellpose_default_predictions,
   segresnet_instance,
   swin_instance,
   wnet_instance,
   threshold_only,
   ]
names = [
    "StarDist",
    "StarDist - default",
    "Cellpose",
    "Cellpose - default",
    "SegResNet",
    "SwinUNetR",
    "WNet3D",
    "Otsu Threshold"
    ]
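The evaluation loop below relies on two helpers: matching_dataset from StarDist, which computes instance-matching metrics (precision, recall, F1, panoptic quality, ...) at a given IoU threshold, and dice_coeff from utils, whose implementation is not shown in this notebook. A minimal sketch of the voxel-wise Dice coefficient that dice_coeff is assumed to compute:

def binary_dice(y_true, y_pred):
    # Voxel-wise Dice of two binary masks (sketch of what utils.dice_coeff is assumed to do).
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    intersection = np.logical_and(y_true, y_pred).sum()
    return 2.0 * intersection / (y_true.sum() + y_pred.sum())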

model_stats = []
model_ref_name = []
fold_ref = []

for j, fold in enumerate(fold_paths):
    print("Fold ", j+1)
    for i, p in enumerate(predictions):
        print(f"Validating on {names[i]}")
        # print(p[j].shape, GT_labels[j].shape)
        pred = p[j]
        if len(pred.shape) > 3:
            pred = pred.squeeze().astype(np.uint32)
        stats = [matching_dataset(
            GT_labels[j], pred, thresh=t, show_progress=False
            ) for t in taus]
        model_stats.append(stats)
        dice = dice_coeff(
            np.where(GT_labels[j] > 0, 1, 0),
            np.where(pred > 0, 1, 0)
        )
        print(f"Dice Coefficient: {dice}")
        for t in taus:
            model_ref_name.append(names[i])
            fold_ref.append(j+1)
        # uncomment for ALL plots : 
        # plot_performance(taus, stats, name=names[i])
        print("*"*20)
Fold  1
Validating on StarDist
Dice Coefficient: 0.6993586695284486
********************
Validating on StarDist - default
Dice Coefficient: 0.39347722427914283
********************
Validating on Cellpose
Dice Coefficient: 0.7702901873444256
********************
Validating on Cellpose - default
Dice Coefficient: 0.5247862438123209
********************
Validating on SegResNet
Dice Coefficient: 0.7726461999842159
********************
Validating on SwinUNetR
Dice Coefficient: 0.8408113189509576
********************
Validating on WNet3D
Dice Coefficient: 0.5896556737502726
********************
Validating on Otsu Threshold
Dice Coefficient: 0.4892361595314461
********************
Fold  2
Validating on StarDist
Dice Coefficient: 0.7944590859420196
********************
Validating on StarDist - default
Dice Coefficient: 0.7251454473235356
********************
Validating on Cellpose
Dice Coefficient: 0.8287320218217887
********************
Validating on Cellpose - default
Dice Coefficient: 0.5633677060172622
********************
Validating on SegResNet
Dice Coefficient: 0.8013458950201884
********************
Validating on SwinUNetR
Dice Coefficient: 0.7887976385316811
********************
Validating on WNet3D
Dice Coefficient: 0.8087545264446371
********************
Validating on Otsu Threshold
Dice Coefficient: 0.665591153771321
********************
Fold  3
Validating on StarDist
Dice Coefficient: 0.7584470232902522
********************
Validating on StarDist - default
Dice Coefficient: 0.5047402543506129
********************
Validating on Cellpose
Dice Coefficient: 0.7932470010023605
********************
Validating on Cellpose - default
Dice Coefficient: 0.5596045059982444
********************
Validating on SegResNet
Dice Coefficient: 0.8472361720498317
********************
Validating on SwinUNetR
Dice Coefficient: 0.8554751436182162
********************
Validating on WNet3D
Dice Coefficient: 0.8113536341409202
********************
Validating on Otsu Threshold
Dice Coefficient: 0.7887478958953774
********************
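dataset_matching_stats_to_df is another helper from utils. Assuming matching_dataset returns StarDist DatasetMatching namedtuples, an equivalent conversion could be written as follows (sketch, not the actual helper):

import pandas as pd

def matching_stats_to_df(stats):
    # Each element of `stats` is a DatasetMatching namedtuple from stardist.matching;
    # _asdict() turns its fields (criterion, fp, tp, fn, precision, recall, f1, ...) into a dict.
    return pd.DataFrame([s._asdict() for s in stats]).set_index("thresh")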
dfs = [dataset_matching_stats_to_df(s) for s in model_stats]
df = pd.concat(dfs)
df["Model"] = model_ref_name
df["Fold"] = fold_ref
df["thresh"] = df.index
df.set_index(["Model", "Fold", "thresh"], inplace=True)
df
criterion fp tp fn precision recall accuracy f1 n_true n_pred mean_true_score mean_matched_score panoptic_quality by_image
Model Fold thresh
StarDist 1 0.1 iou 1395 2973 412 0.680632 0.878287 0.621967 0.766929 3385 4368 0.537485 0.611970 0.469338 False
0.2 iou 1405 2963 422 0.678342 0.875332 0.618580 0.764349 3385 4368 0.537034 0.613520 0.468944 False
0.3 iou 1482 2886 499 0.660714 0.852585 0.592973 0.744486 3385 4368 0.531009 0.622822 0.463682 False
0.4 iou 1727 2641 744 0.604625 0.780207 0.516628 0.681285 3385 4368 0.506039 0.648597 0.441879 False
0.5 iou 2127 2241 1144 0.513049 0.662038 0.406567 0.578099 3385 4368 0.454747 0.686888 0.397089 False
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Otsu Threshold 3 0.5 iou 1357 4209 2700 0.756198 0.609205 0.509194 0.674790 6909 5566 0.471521 0.773993 0.522282 False
0.6 iou 2077 3489 3420 0.626842 0.504993 0.388271 0.559359 6909 5566 0.416410 0.824585 0.461239 False
0.7 iou 2758 2808 4101 0.504492 0.406426 0.290473 0.450180 6909 5566 0.352801 0.868057 0.390782 False
0.8 iou 3521 2045 4864 0.367409 0.295991 0.196069 0.327856 6909 5566 0.270493 0.913856 0.299613 False
0.9 iou 4453 1113 5796 0.199964 0.161094 0.097958 0.178437 6909 5566 0.156777 0.973200 0.173655 False

216 rows × 14 columns

df_supervised = df.copy()
df_supervised = df.loc[["StarDist", "StarDist - default", "Cellpose", "Cellpose - default", "SegResNet", "SwinUNetR"]]
# Expand the palette so each tuned/default model pair uses related colours:
# duplicate the first two base colours, then swap in shades via get_shades.
temp_cmap = COLORMAP.copy()
temp_cmap.insert(2, COLORMAP[1])
temp_cmap.insert(1, COLORMAP[0])
temp_cmap[2], _ = get_shades(temp_cmap[3])
temp_cmap[0], _ = get_shades(temp_cmap[1])
EXPANDED_COLORMAP = temp_cmap
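Every metric plot below goes through plot_stat_comparison_fold from plots.py. Its implementation is not part of this notebook; assuming it draws one line per model as a function of the IoU threshold, with the spread across folds as an error band, a minimal seaborn equivalent might look like this (sketch only, not the actual helper):

import seaborn as sns
import matplotlib.pyplot as plt

def plot_stat_vs_threshold(df, stat="f1", colormap=None):
    # Sketch: mean of `stat` per model vs. IoU threshold, with the spread
    # across folds drawn as an error band.
    data = df.reset_index()
    sns.lineplot(data=data, x="thresh", y=stat, hue="Model", errorbar="sd", palette=colormap)
    plt.xlabel("IoU threshold")
    plt.ylabel(stat)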

Precision#

plot_stat_comparison_fold(df_supervised, stat="precision", colormap=EXPANDED_COLORMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("precision_comparison_fold.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("precision_comparison_fold.svg", bbox_inches = "tight")
[Figure: precision vs. IoU threshold for each supervised model, by fold]

Recall#

plot_stat_comparison_fold(df_supervised, stat="recall", colormap=EXPANDED_COLORMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("recall_comparison_fold.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("recall_comparison_fold.svg", bbox_inches = "tight")
[Figure: recall vs. IoU threshold for each supervised model, by fold]

F1 Score#

plot_stat_comparison_fold(df_supervised, stat="f1", colormap=EXPANDED_COLORMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison_fold.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison_fold.svg", bbox_inches = "tight")
[Figure: F1 score vs. IoU threshold for each supervised model, by fold]

Panoptic Quality#

plot_stat_comparison_fold(df_supervised, stat="panoptic_quality", colormap=EXPANDED_COLORMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("panoptic_quality_comparison_fold.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("panoptic_quality_comparison_fold.svg", bbox_inches = "tight")
[Figure: panoptic quality vs. IoU threshold for each supervised model, by fold]

WNet3D - No artifacts#

# evaluate stats for WNet3D on an extra fold: fold 1 restricted to the first 50 z-slices
# ([:50, :, :]), excluding the artifact-affected part of the volume
stats = [matching_dataset(
    GT_labels[0][:50,:,:], wnet_instance[0][:50,:,:], thresh=t, show_progress=False
    ) for t in taus]
partial_df = dataset_matching_stats_to_df(stats)
partial_df["Model"] = "WNet3D - No artifacts"
partial_df["Fold"] = 1
partial_df["thresh"] = partial_df.index
partial_df.set_index(["Model", "Fold", "thresh"], inplace=True)
df = pd.concat([df, partial_df])
df
criterion fp tp fn precision recall accuracy f1 n_true n_pred mean_true_score mean_matched_score panoptic_quality by_image
Model Fold thresh
StarDist 1 0.1 iou 1395 2973 412 0.680632 0.878287 0.621967 0.766929 3385 4368 0.537485 0.611970 0.469338 False
0.2 iou 1405 2963 422 0.678342 0.875332 0.618580 0.764349 3385 4368 0.537034 0.613520 0.468944 False
0.3 iou 1482 2886 499 0.660714 0.852585 0.592973 0.744486 3385 4368 0.531009 0.622822 0.463682 False
0.4 iou 1727 2641 744 0.604625 0.780207 0.516628 0.681285 3385 4368 0.506039 0.648597 0.441879 False
0.5 iou 2127 2241 1144 0.513049 0.662038 0.406567 0.578099 3385 4368 0.454747 0.686888 0.397089 False
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
WNet3D - No artifacts 1 0.5 iou 376 2190 854 0.853468 0.719448 0.640351 0.780749 3044 2566 0.560445 0.778992 0.608197 False
0.6 iou 660 1906 1138 0.742790 0.626150 0.514579 0.679501 3044 2566 0.510604 0.815466 0.554110 False
0.7 iou 1012 1554 1490 0.605612 0.510512 0.383136 0.554011 3044 2566 0.435460 0.852986 0.472564 False
0.8 iou 1463 1103 1941 0.429852 0.362352 0.244730 0.393226 3044 2566 0.324968 0.896830 0.352657 False
0.9 iou 2071 495 2549 0.192907 0.162615 0.096774 0.176471 3044 2566 0.156949 0.965155 0.170321 False

225 rows × 14 columns

plot_stat_comparison_fold(
    df.loc[["StarDist", "StarDist - default", "Cellpose", "Cellpose - default", "SegResNet", "SwinUNetR", "WNet3D", "WNet3D - No artifacts"]],
    stat="f1", colormap=EXPANDED_COLORMAP)
[Figure: F1 score vs. IoU threshold, all models including WNet3D - No artifacts]
df_no_learning = df.loc[["Otsu Threshold", "StarDist", "Cellpose", "SegResNet", "SwinUNetR", "WNet3D - No artifacts", "WNet3D"]]
VO_CMAP = ["black"] + COLORMAP
plot_stat_comparison_fold(df_no_learning, stat="f1", colormap=VO_CMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison_fold_no_learning.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison_fold_no_learning.svg", bbox_inches = "tight")
[Figure: F1 score vs. IoU threshold, Otsu threshold baseline, WNet3D variants, and tuned supervised models]

Statistical tests (outdated & unused)#

Boxplot of model performance across all thresholds#

sns.boxplot(data=df, x="Model", y="f1", palette=EXPANDED_COLORMAP)
plt.xticks(rotation=45)
plt.show()
27598755.py (1): Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.
27598755.py (1): The palette list has more values (10) than needed (9), which may not be intended.
[Figure: boxplot of F1 scores per model, pooled over folds and IoU thresholds]
from scipy.stats import f_oneway, kruskal

models_f1_means = df.groupby("Model").f1.mean()

f1_starDist = df.loc["StarDist"].f1
f1_cellpose = df.loc["Cellpose"].f1
f1_segresnet = df.loc["SegResNet"].f1
f1_swin = df.loc["SwinUNetR"].f1
f1_wnet = df.loc["WNet3D"].f1
f1_wnet_partial = df.loc["WNet3D - No artifacts"].f1

anova_test = f_oneway(
    # f1_starDist, 
    f1_cellpose, 
    f1_segresnet, 
    f1_swin, 
    f1_wnet, 
    f1_wnet_partial
    )
print("Kruskal test: ", anova_test)
Kruskal test:  F_onewayResult(statistic=0.4660165757160163, pvalue=0.7605465068261601)
# Significant difference between the models?
kruskal_test = kruskal(
    f1_starDist,
    f1_cellpose, 
    f1_segresnet, 
    f1_swin, 
    f1_wnet, 
    f1_wnet_partial
    )
print("Comparisons of F1 scores between all models:")
print("- StarDist\n- Cellpose\n- SegResNet\n- SwinUNetR\n- WNet3D\n- WNet3D - No artifacts")
print("Kruskal-Wallis test: ", kruskal_test)
Comparisons of F1 scores between all models:
- StarDist
- Cellpose
- SegResNet
- SwinUNetR
- WNet3D
- WNet3D - No artifacts
Kruskal-Wallis test:  KruskalResult(statistic=5.021370796083488, pvalue=0.41327764021418123)
# Best models - significant difference between them?
kruskal_test = kruskal(
    # f1_starDist, 
    f1_cellpose, 
    f1_segresnet, 
    f1_swin, 
    f1_wnet, 
    f1_wnet_partial
    )
print("Comparisons of F1 scores between the best models:")
print("- Cellpose\n- SegResNet\n- SwinUNetR\n- WNet3D\n- WNet3D - No artifacts")
print("Kruskal-Wallis test: ", kruskal_test)
Comparisons of F1 scores between the best models:
- Cellpose
- SegResNet
- SwinUNetR
- WNet3D
- WNet3D - No artifacts
Kruskal-Wallis test:  KruskalResult(statistic=3.710657202182631, pvalue=0.4465777918818482)
# Best supervised models - significant difference between them?
kruskal_test = kruskal(
    f1_cellpose, 
    f1_segresnet, 
    f1_swin
    )
print("Comparisons of F1 scores between the best supervised models:")
print("- Cellpose\n- SegResNet\n- SwinUNetR")
print("Kruskal-Wallis test: ", kruskal_test)
Comparisons of F1 scores between the best supervised models:
- Cellpose
- SegResNet
- SwinUNetR
Kruskal-Wallis test:  KruskalResult(statistic=0.6194921208470987, pvalue=0.7336332310839804)
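If a post-hoc analysis were wanted after a significant Kruskal-Wallis result, pairwise Mann-Whitney U tests with a multiple-comparison correction are one standard option. This was not part of the original analysis; a sketch:

from itertools import combinations
from scipy.stats import mannwhitneyu

# Sketch (not part of the original analysis): pairwise Mann-Whitney U tests with a
# simple Bonferroni correction, to identify which specific model pairs would differ.
samples = {"Cellpose": f1_cellpose, "SegResNet": f1_segresnet, "SwinUNetR": f1_swin}
pairs = list(combinations(samples.items(), 2))
for (name_a, a), (name_b, b) in pairs:
    _, p = mannwhitneyu(a, b)
    print(f"{name_a} vs {name_b}: p = {p:.3f}, Bonferroni-corrected p = {min(p * len(pairs), 1.0):.3f}")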