Supervised & self-supervised model performance#

import sys

import numpy as np
# pandas / matplotlib / seaborn are used directly in later cells
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from tifffile import imread
import pyclesperanto_prototype as cle
from stardist.matching import matching_dataset

sys.path.append("../..")

from utils import *
from plots import *
print("Used GPU: ", cle.get_device())
show_params()
#################
SAVE_PLOTS_AS_PNG = False
SAVE_PLOTS_AS_SVG = True
Used GPU:  <NVIDIA GeForce RTX 4070 Ti on Platform: NVIDIA CUDA (1 refs)>
Plot parameters (set in plots.py) : 
- COLORMAP : ███████
- DPI : 200
- Data path : C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK
- Font size : 20
- Title font size : 25.0
- Label font size : 20.0
%load_ext autoreload
%autoreload 2

Data#

Supervised models#

data_path = DATA_PATH / "RESULTS/SUPERVISED_PERF_FIG/INFERENCE"
fold1_path = data_path / "fold1"
fold2_path = data_path / "fold2"
fold3_path = data_path / "fold3"

def get_fold_data(fold_path):
    """Return (images, GT labels) for a fold, split on 'label' in the file name."""
    files = list(fold_path.glob("*.tif"))
    images = [imread(f) for f in files if "label" not in f.name]
    labels = [imread(f) for f in files if "label" in f.name]
    return images, labels

images = []
GT_labels = []
for fold in [fold1_path, fold2_path, fold3_path]:
    im, GT_lab = get_fold_data(fold)
    images.append(im[0])
    GT_labels.append(GT_lab[0])
fold_paths = [fold1_path, fold2_path, fold3_path]

def get_predictions(path):
    return [imread(f) for f in path.glob("*.tif")]
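Note that Path.glob does not guarantee a platform-independent file order. If the image/label pairing ever needs to be made explicit, a name-sorted variant of the loader could be used; this is only a sketch with the same filtering logic, not used in the results below.

def get_fold_data_sorted(fold_path):
    """Like get_fold_data, but with a deterministic, name-sorted file order."""
    files = sorted(fold_path.glob("*.tif"), key=lambda f: f.name)
    images = [imread(f) for f in files if "label" not in f.name]
    labels = [imread(f) for f in files if "label" in f.name]
    return images, labels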
#################
swin_paths = [f / "Swin" for f in fold_paths]
segresnet_paths = [f / "SegResNet" for f in fold_paths]
#################
cellpose_paths = [f / "cellpose/tuned" for f in fold_paths]
stardist_paths = [f / "stardist/tuned" for f in fold_paths]
#################
cellpose_default_paths = [f / "cellpose/default" for f in fold_paths]
stardist_default_paths = [f / "stardist/default" for f in fold_paths]

swin_predictions = [get_predictions(p)[0] for p in swin_paths]
segresnet_predictions = [get_predictions(p)[0] for p in segresnet_paths]
cellpose_predictions = [get_predictions(p)[0] for p in cellpose_paths]
stardist_predictions = [get_predictions(p)[0] for p in stardist_paths]
cellpose_default_predictions = [get_predictions(p)[0] for p in cellpose_default_paths]
stardist_default_predictions = [get_predictions(p)[0] for p in stardist_default_paths]

Threshold-based segmentation#

Check whether a simple threshold-based segmentation (Otsu thresholding followed by Voronoi-Otsu labeling) yields competitive results.

import pyclesperanto_prototype as cle

# Otsu threshold, then split the foreground into instances with Voronoi-Otsu labeling
threshold_only = [np.array(cle.threshold_otsu(p)) for p in images]
threshold_only = [np.array(cle.voronoi_otsu_labeling(p, spot_sigma=0.65, outline_sigma=0.65)) for p in threshold_only]
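As a rough CPU-only reference for what voronoi_otsu_labeling does (Gaussian smoothing, Otsu thresholding, then seeded growth from local maxima), a scikit-image sketch under assumed parameters could look like the following; it approximates, but is not, the pyclesperanto implementation used above.

import numpy as np
from skimage.filters import gaussian, threshold_otsu
from skimage.feature import peak_local_max
from skimage.segmentation import watershed

def voronoi_otsu_like(image, spot_sigma=0.65, outline_sigma=0.65):
    """Rough CPU approximation of cle.voronoi_otsu_labeling (illustrative only)."""
    image = image.astype(np.float32)
    # foreground mask from a smoothed Otsu threshold
    smoothed = gaussian(image, sigma=outline_sigma)
    mask = smoothed > threshold_otsu(smoothed)
    # seeds from local maxima of the spot-smoothed image, restricted to the mask
    spots = gaussian(image, sigma=spot_sigma)
    coords = peak_local_max(spots, labels=mask)
    seeds = np.zeros(image.shape, dtype=np.uint32)
    seeds[tuple(coords.T)] = np.arange(1, len(coords) + 1)
    # grow the seeds within the mask (Voronoi-like partition of the foreground)
    return watershed(np.zeros_like(image), markers=seeds, mask=mask)

# example: labels = voronoi_otsu_like(images[0])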

Self-supervised model#

wnet_path = DATA_PATH / "RESULTS/WNET_RETRAIN/inference/WNet"
wnet_instance = [imread(str(f)) for f in wnet_path.glob("*.tif")]
wnet_instance.reverse()

Computations#

For the choice of thresholds, see notebooks in the thresholds_opti/ folder.

swin_thresholded = [np.where(swin > 0.4, swin, 0) for swin in swin_predictions]
segresnet_thresholded = [np.where(segresnet > 0.3, segresnet, 0) for segresnet in segresnet_predictions]
swin_instance = []
segresnet_instance = []

for i in range(len(fold_paths)):
    swin_instance.append(np.array(
        cle.voronoi_otsu_labeling(swin_thresholded[i], spot_sigma=0.65, outline_sigma=0.65))
                         )
    segresnet_instance.append(np.array(
        cle.voronoi_otsu_labeling(segresnet_thresholded[i], spot_sigma=0.65, outline_sigma=0.65))
                              )
# from tifffile import imwrite
# for i in range(len(fold_paths)):
#     save_folder = data_path / f"fold{i+1}" / "processed"
#     save_folder.mkdir(exist_ok=True, parents=False)
#     imwrite(save_folder / "swin_instance.tif", swin_instance[i])
#     imwrite(save_folder / "segresnet_instance.tif", segresnet_instance[i])

Plots#

taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

predictions = [
    stardist_predictions,
    stardist_default_predictions,
    cellpose_predictions,
    cellpose_default_predictions,
    segresnet_instance,
    swin_instance,
    wnet_instance,
    threshold_only,
    ]
names = [
    "StarDist",
    "StarDist - default",
    "Cellpose",
    "Cellpose - default",
    "SegResNet",
    "SwinUNetR",
    "WNet3D",
    "Otsu Threshold"
    ]
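Each value in taus is an IoU threshold passed to StarDist's matching_dataset below: a predicted instance counts as a true positive only if it overlaps a ground-truth instance with IoU above the threshold. A minimal toy illustration (hypothetical 8×8 masks, not part of the benchmark data):

import numpy as np
from stardist.matching import matching_dataset

# one ground-truth square and one shifted prediction, IoU = 9/23 ≈ 0.39
gt = np.zeros((8, 8), dtype=np.uint16)
gt[2:6, 2:6] = 1
pred = np.zeros((8, 8), dtype=np.uint16)
pred[3:7, 3:7] = 1

for tau in (0.3, 0.5):
    m = matching_dataset([gt], [pred], thresh=tau, show_progress=False)
    # at tau=0.3 the pair is matched (tp=1), at tau=0.5 it is not (fp=1, fn=1)
    print(f"tau={tau}: tp={m.tp}, fp={m.fp}, fn={m.fn}, f1={m.f1:.2f}")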

model_stats = []
model_ref_name = []
fold_ref = []

for j, fold in enumerate(fold_paths):
    print("Fold ", j+1)
    for i, p in enumerate(predictions):
        print(f"Validating on {names[i]}")
        # print(p[j].shape, GT_labels[j].shape)
        pred = p[j]
        if len(pred.shape) > 3:
            pred = pred.squeeze().astype(np.uint32)
        stats = [matching_dataset(
            GT_labels[j], pred, thresh=t, show_progress=False
            ) for t in taus]
        model_stats.append(stats)
        dice = dice_coeff(
            np.where(GT_labels[j] > 0, 1, 0),
            np.where(pred > 0, 1, 0)
        )
        print(f"Dice Coefficient: {dice}")
        for t in taus:
            model_ref_name.append(names[i])
            fold_ref.append(j+1)
        # uncomment for ALL plots : 
        # plot_performance(taus, stats, name=names[i])
        print("*"*20)
Fold  1
Validating on StarDist
Dice Coefficient: 0.6993586695284486
********************
Validating on StarDist - default
Dice Coefficient: 0.39347722427914283
********************
Validating on Cellpose
Dice Coefficient: 0.7702901873444256
********************
Validating on Cellpose - default
Dice Coefficient: 0.5247862438123209
********************
Validating on SegResNet
Dice Coefficient: 0.7726461999842159
********************
Validating on SwinUNetR
Dice Coefficient: 0.8408113189509576
********************
Validating on WNet3D
Dice Coefficient: 0.5896556737502726
********************
Validating on Otsu Threshold
Dice Coefficient: 0.4892361595314461
********************
Fold  2
Validating on StarDist
Dice Coefficient: 0.7944590859420196
********************
Validating on StarDist - default
Dice Coefficient: 0.7251454473235356
********************
Validating on Cellpose
Dice Coefficient: 0.8287320218217887
********************
Validating on Cellpose - default
Dice Coefficient: 0.5633677060172622
********************
Validating on SegResNet
Dice Coefficient: 0.8013458950201884
********************
Validating on SwinUNetR
Dice Coefficient: 0.7887976385316811
********************
Validating on WNet3D
Dice Coefficient: 0.8087545264446371
********************
Validating on Otsu Threshold
Dice Coefficient: 0.665591153771321
********************
Fold  3
Validating on StarDist
Dice Coefficient: 0.7584470232902522
********************
Validating on StarDist - default
Dice Coefficient: 0.5047402543506129
********************
Validating on Cellpose
Dice Coefficient: 0.7932470010023605
********************
Validating on Cellpose - default
Dice Coefficient: 0.5596045059982444
********************
Validating on SegResNet
Dice Coefficient: 0.8472361720498317
********************
Validating on SwinUNetR
Dice Coefficient: 0.8554751436182162
********************
Validating on WNet3D
Dice Coefficient: 0.8113536341409202
********************
Validating on Otsu Threshold
Dice Coefficient: 0.7887478958953774
********************
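dice_coeff comes from the local helper modules imported via the wildcard imports above; as a reference, a minimal sketch of the standard Dice coefficient on binary masks (not necessarily identical to the helper implementation) is:

import numpy as np

def dice_coefficient(mask_a, mask_b):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) for binary masks."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0  # both masks empty: treat as perfect overlap
    return 2.0 * np.logical_and(a, b).sum() / denom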
dfs = [dataset_matching_stats_to_df(s) for s in model_stats]
df = pd.concat(dfs)
df["Model"] = model_ref_name
df["Fold"] = fold_ref
df["thresh"] = df.index
df.set_index(["Model", "Fold", "thresh"], inplace=True)
df
criterion fp tp fn precision recall accuracy f1 n_true n_pred mean_true_score mean_matched_score panoptic_quality by_image
Model Fold thresh
StarDist 1 0.1 iou 1395 2973 412 0.680632 0.878287 0.621967 0.766929 3385 4368 0.537485 0.611970 0.469338 False
0.2 iou 1405 2963 422 0.678342 0.875332 0.618580 0.764349 3385 4368 0.537034 0.613520 0.468944 False
0.3 iou 1482 2886 499 0.660714 0.852585 0.592973 0.744486 3385 4368 0.531009 0.622822 0.463682 False
0.4 iou 1727 2641 744 0.604625 0.780207 0.516628 0.681285 3385 4368 0.506039 0.648597 0.441879 False
0.5 iou 2127 2241 1144 0.513049 0.662038 0.406567 0.578099 3385 4368 0.454747 0.686888 0.397089 False
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Otsu Threshold 3 0.5 iou 1357 4209 2700 0.756198 0.609205 0.509194 0.674790 6909 5566 0.471521 0.773993 0.522282 False
0.6 iou 2077 3489 3420 0.626842 0.504993 0.388271 0.559359 6909 5566 0.416410 0.824585 0.461239 False
0.7 iou 2758 2808 4101 0.504492 0.406426 0.290473 0.450180 6909 5566 0.352801 0.868057 0.390782 False
0.8 iou 3521 2045 4864 0.367409 0.295991 0.196069 0.327856 6909 5566 0.270493 0.913856 0.299613 False
0.9 iou 4453 1113 5796 0.199964 0.161094 0.097958 0.178437 6909 5566 0.156777 0.973200 0.173655 False

216 rows × 14 columns
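With the (Model, Fold, thresh) index in place, per-model summaries are straightforward to pull; for example, mean F1 across folds at a fixed IoU threshold (a usage sketch, not part of the original analysis):

# mean F1 per model at IoU threshold 0.5, averaged over the three folds
f1_at_05 = df.xs(0.5, level="thresh")["f1"].groupby(level="Model").mean()
print(f1_at_05.sort_values(ascending=False))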

df_supervised = df.loc[["StarDist", "StarDist - default", "Cellpose", "Cellpose - default", "SegResNet", "SwinUNetR"]].copy()

# Duplicate the first two base colors and replace one copy of each with a shade from get_shades,
# so the tuned and default variants of the same model are drawn in related colors.
temp_cmap = COLORMAP.copy()
temp_cmap.insert(2, COLORMAP[1])
temp_cmap.insert(1, COLORMAP[0])
temp_cmap[2], _ = get_shades(temp_cmap[3])
temp_cmap[0], _ = get_shades(temp_cmap[1])
EXPANDED_COLORMAP = temp_cmap

Precision#

plot_stat_comparison_fold(df_supervised, stat="precision", colormap=EXPANDED_COLORMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("precision_comparison_fold.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("precision_comparison_fold.svg", bbox_inches = "tight")
[Figure: per-fold precision comparison across supervised models]

Recall#

plot_stat_comparison_fold(df_supervised, stat="recall", colormap=EXPANDED_COLORMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("recall_comparison_fold.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("recall_comparison_fold.svg", bbox_inches = "tight")
[Figure: per-fold recall comparison across supervised models]

F1 Score#

plot_stat_comparison_fold(df_supervised, stat="f1", colormap=EXPANDED_COLORMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison_fold.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison_fold.svg", bbox_inches = "tight")
[Figure: per-fold F1 score comparison across supervised models]

Panoptic Quality#

plot_stat_comparison_fold(df_supervised, stat="panoptic_quality", colormap=EXPANDED_COLORMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("panoptic_quality_comparison_fold.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("panoptic_quality_comparison_fold.svg", bbox_inches = "tight")
[Figure: per-fold panoptic quality comparison across supervised models]

WNet3D - No artifacts#

# evaluate stats for WNet3D on an extra fold: fold 1 restricted to the first 50 z-slices ([:50, :, :]), excluding the artifact-containing region
stats = [matching_dataset(
    GT_labels[0][:50,:,:], wnet_instance[0][:50,:,:], thresh=t, show_progress=False
    ) for t in taus]
partial_df = dataset_matching_stats_to_df(stats)
partial_df["Model"] = "WNet3D - No artifacts"
partial_df["Fold"] = 1
# copy the WNet3D rows for folds 2 and 3 and relabel them
rows_fold2 = df.loc[("WNet3D", 2)].copy()
rows_fold3 = df.loc[("WNet3D", 3)].copy()
rows_fold2["Model"] = "WNet3D - No artifacts"
rows_fold2["Fold"] = 2
rows_fold3["Fold"] = 3
rows_fold3["Model"] = "WNet3D - No artifacts"
partial_df = pd.concat([partial_df, rows_fold2, rows_fold3])
partial_df.set_index(["Model", "Fold", partial_df.index], inplace=True)
df = pd.concat([df, partial_df])
1321758596.py (9): indexing past lexsort depth may impact performance.
1321758596.py (10): indexing past lexsort depth may impact performance.
df
criterion fp tp fn precision recall accuracy f1 n_true n_pred mean_true_score mean_matched_score panoptic_quality by_image
Model Fold thresh
StarDist 1 0.1 iou 1395 2973 412 0.680632 0.878287 0.621967 0.766929 3385 4368 0.537485 0.611970 0.469338 False
0.2 iou 1405 2963 422 0.678342 0.875332 0.618580 0.764349 3385 4368 0.537034 0.613520 0.468944 False
0.3 iou 1482 2886 499 0.660714 0.852585 0.592973 0.744486 3385 4368 0.531009 0.622822 0.463682 False
0.4 iou 1727 2641 744 0.604625 0.780207 0.516628 0.681285 3385 4368 0.506039 0.648597 0.441879 False
0.5 iou 2127 2241 1144 0.513049 0.662038 0.406567 0.578099 3385 4368 0.454747 0.686888 0.397089 False
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
WNet3D - No artifacts 3 0.5 iou 974 4656 2253 0.826998 0.673904 0.590638 0.742643 6909 5630 0.518749 0.769767 0.571662 False
0.6 iou 1640 3990 2919 0.708703 0.577508 0.466721 0.636414 6909 5630 0.467688 0.809838 0.515393 False
0.7 iou 2400 3230 3679 0.573712 0.467506 0.346976 0.515193 6909 5630 0.396466 0.848045 0.436907 False
0.8 iou 3397 2233 4676 0.396625 0.323202 0.216670 0.356169 6909 5630 0.288912 0.893907 0.318382 False
0.9 iou 4704 926 5983 0.164476 0.134028 0.079738 0.147699 6909 5630 0.129802 0.968467 0.143042 False

243 rows × 14 columns

Otsu Threshold - No artifacts#

# same evaluation for the Otsu Threshold baseline: fold 1 restricted to the first 50 z-slices ([:50, :, :]), excluding the artifact-containing region
stats = [matching_dataset(
    GT_labels[0][:50,:,:], threshold_only[0][:50,:,:], thresh=t, show_progress=False
    ) for t in taus]
partial_df = dataset_matching_stats_to_df(stats)
partial_df["Model"] = "Otsu Threshold - No artifacts"
partial_df["Fold"] = 1

rows_fold2 = df.loc[("Otsu Threshold", 2)].copy()
rows_fold3 = df.loc[("Otsu Threshold", 3)].copy()
rows_fold2["Model"] = "Otsu Threshold - No artifacts"
rows_fold2["Fold"] = 2
rows_fold3["Fold"] = 3
rows_fold3["Model"] = "Otsu Threshold - No artifacts"
partial_df = pd.concat([partial_df, rows_fold2, rows_fold3])
partial_df.set_index(["Model", "Fold", partial_df.index], inplace=True)
df = pd.concat([df, partial_df])
2306530423.py (9): indexing past lexsort depth may impact performance.
2306530423.py (10): indexing past lexsort depth may impact performance.
plot_stat_comparison_fold(
    df.loc[["StarDist", "StarDist - default", "Cellpose", "Cellpose - default", "SegResNet", "SwinUNetR", "WNet3D", "WNet3D - No artifacts"]],
    stat="f1", colormap=EXPANDED_COLORMAP)
[Figure: per-fold F1 comparison including WNet3D and WNet3D - No artifacts]
df_no_learning = df.loc[["Otsu Threshold", "Otsu Threshold - No artifacts", "WNet3D - No artifacts", "WNet3D"]]
VO_CMAP = COLORMAP.copy()
VO_CMAP = ["black", "gray"] + VO_CMAP[4:]
plot_stat_comparison_fold(df_no_learning, stat="f1", colormap=VO_CMAP)
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison_fold_no_learning.png", bbox_inches = "tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison_fold_no_learning.svg", bbox_inches = "tight")
[Figure: per-fold F1 comparison of the Otsu Threshold baselines and WNet3D, with and without artifacts]

Statistical tests (outdated & unused)#

Boxplot of model performance across all thresholds#

sns.boxplot(data=df, x="Model", y="f1", palette=EXPANDED_COLORMAP)
plt.xticks(rotation=45)
plt.show()
27598755.py (1): 

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.
[Figure: boxplot of F1 scores per model, pooled over folds and IoU thresholds]
from scipy.stats import f_oneway, kruskal

models_f1_means = df.groupby("Model").f1.mean()

f1_starDist = df.loc["StarDist"].f1
f1_cellpose = df.loc["Cellpose"].f1
f1_segresnet = df.loc["SegResNet"].f1
f1_swin = df.loc["SwinUNetR"].f1
f1_wnet = df.loc["WNet3D"].f1
f1_wnet_partial = df.loc["WNet3D - No artifacts"].f1

anova_test = f_oneway(
    # f1_starDist, 
    f1_cellpose, 
    f1_segresnet, 
    f1_swin, 
    f1_wnet, 
    f1_wnet_partial
    )
print("One-way ANOVA: ", anova_test)
One-way ANOVA:  F_onewayResult(statistic=0.4431995081645192, pvalue=0.777181391641984)
# Significant difference between the models?
kruskal_test = kruskal(
    f1_starDist,
    f1_cellpose, 
    f1_segresnet, 
    f1_swin, 
    f1_wnet, 
    f1_wnet_partial
    )
print("Comparisons of F1 scores between all models:")
print("- StarDist\n- Cellpose\n- SegResNet\n- SwinUNetR\n- WNet3D\n- WNet3D - No artifacts")
print("Kruskal-Wallis test: ", kruskal_test)
Comparisons of F1 scores between all models:
- StarDist
- Cellpose
- SegResNet
- SwinUNetR
- WNet3D
- WNet3D - No artifacts
Kruskal-Wallis test:  KruskalResult(statistic=4.824842539418781, pvalue=0.4376286539608427)
# Best models - significant difference between them?
kruskal_test = kruskal(
    # f1_starDist, 
    f1_cellpose, 
    f1_segresnet, 
    f1_swin, 
    f1_wnet, 
    f1_wnet_partial
    )
print("Comparisons of F1 scores between the best models:")
print("- Cellpose\n- SegResNet\n- SwinUNetR\n- WNet3D\n- WNet3D - No artifacts")
print("Kruskal-Wallis test: ", kruskal_test)
Comparisons of F1 scores between the best models:
- Cellpose
- SegResNet
- SwinUNetR
- WNet3D
- WNet3D - No artifacts
Kruskal-Wallis test:  KruskalResult(statistic=3.5792905060017897, pvalue=0.46592490089849237)
# Best supervised models - significant difference between them?
kruskal_test = kruskal(
    f1_cellpose, 
    f1_segresnet, 
    f1_swin
    )
print("Comparisons of F1 scores between the best supervised models:")
print("- Cellpose\n- SegResNet\n- SwinUNetR")
print("Kruskal-Wallis test: ", kruskal_test)
Comparisons of F1 scores between the best supervised models:
- Cellpose
- SegResNet
- SwinUNetR
Kruskal-Wallis test:  KruskalResult(statistic=0.6194921208470987, pvalue=0.7336332310839804)