Figure 1.f/g : Training efficiency#
Goals :
Show that 3D models are more efficient in terms of training data needed to reach a certain accuracy, compared to 2D models.
(Also shows slight advantage of 3D models in terms of accuracy, and performance of self-supervised model.)
%load_ext autoreload
%autoreload 2
from pathlib import Path
from tifffile import imread
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from scipy.stats import kruskal
import pyclesperanto_prototype as cle
from skimage.morphology import remove_small_objects
sys.path.append("../..")
from utils import *
from plots import *
show_params()
Plot parameters (set in plots.py) :
- COLORMAP : ████████
- DPI : 200
- Data path : C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK
- Font size : 20
- Title font size : 25.0
- Label font size : 20.0
# Expanded palette: each base color is fanned out into 4 shades via get_n_shades.
# (See get_shades in utils.py; its intensity parameter controls how strong the shades are.)
# The StarDist color (COLORMAP[0]) and the Cellpose color (COLORMAP[1]) must each
# appear twice in the base list, so both are duplicated and one copy of each is
# replaced by the first value returned by get_shades.
temp_cmap = COLORMAP.copy()
temp_cmap.insert(2, COLORMAP[1])
temp_cmap.insert(1, COLORMAP[0])
temp_cmap[2], _ = get_shades(temp_cmap[3])
temp_cmap[0], _ = get_shades(temp_cmap[1])
# Only the first 7 base colors are expanded; the tail of COLORMAP is appended as-is.
EXPANDED_COLORMAP = [
    shade
    for base_color in temp_cmap[:7]
    for shade in get_n_shades(base_color, 4)
]
EXPANDED_COLORMAP.extend(COLORMAP[7:])
# Output formats used by the savefig calls further down.
SAVE_PLOTS_AS_PNG = False
SAVE_PLOTS_AS_SVG = True
Data loading#
# Ground-truth volumes, keyed by the name used in the "gt" column of the tables below.
# The key order (visual, c3, c5) matters: it is matched positionally with the
# per-volume prediction dictionaries when the DataFrame is built.
image_folder = DATA_PATH / "RESULTS/SPLITS/Analysis/dataset_splits"
gt_dict = {
    "visual": imread(image_folder / "visual.tif"),
    "c3": imread(image_folder / "c3.tif"),
    "c5": imread(image_folder / "c5.tif"),
}
# Keep the individual volumes available under their short names as well.
visual, c3, c5 = gt_dict["visual"], gt_dict["c3"], gt_dict["c5"]
Order data by model and split#
def find_images(path, split):
    """Return the ``.tif`` files directly inside ``path`` whose name contains ``split``.

    Args:
        path: ``pathlib.Path`` of the folder to search (not searched recursively).
        split: Substring that must appear in the file name, e.g. ``"1090"``.

    Returns:
        A list of matching paths, sorted so that the result (and therefore any
        ``[0]`` lookup or iteration over it) does not depend on the arbitrary
        order in which the file system yields directory entries. The list is
        empty when nothing matches.
    """
    return sorted(path.glob(f"*{split}*.tif"))
Supervised models#
Show code cell source
# Keys used for the train/eval splits in every per-model dictionary.
_SPLIT_KEYS = ("1090", "2080", "6040", "8020")
# (model name, sub-folder, glob token per split), in the order the models are listed.
# The tuned StarDist folder is searched with two-digit tokens instead of the
# four-digit split keys.
_SUPERVISED_LAYOUT = (
    ("Cellpose", "cp", _SPLIT_KEYS),
    ("Cellpose - default", "cp/default", _SPLIT_KEYS),
    ("StarDist - default", "sd", _SPLIT_KEYS),
    ("StarDist", "sd/tuned", ("10", "20", "60", "80")),
    ("SegResNet", "segres", _SPLIT_KEYS),
    ("SwinUNetR", "swin", _SPLIT_KEYS),
)


def _collect_supervised_preds(train_folder):
    """Map model name -> split key -> first matching prediction file under ``train_folder``."""
    return {
        model: {
            split_key: find_images(image_folder / f"{train_folder}/{subfolder}", token)[0]
            for split_key, token in zip(_SPLIT_KEYS, tokens)
        }
        for model, subfolder, tokens in _SUPERVISED_LAYOUT
    }


# One dictionary per evaluated volume; each reads from its own results folder.
visual_preds = _collect_supervised_preds("c1_5")
c3_preds = _collect_supervised_preds("c1245_v")
c5_preds = _collect_supervised_preds("c1-4_v")
# Organize as a DataFrame with one row per (ground truth, model, split) prediction file.
# `splits` is paired with gt_dict by position, so both must list visual, c3, c5
# in the same order.
splits = [visual_preds, c3_preds, c5_preds]
records = []
for gt_name, preds in zip(gt_dict, splits):
    for model_name, model_preds in preds.items():
        for split, pred_path in model_preds.items():
            records.append(
                {
                    "model": model_name,
                    # "1090" -> "10/90"
                    "split": split[:2] + "/" + split[2:],
                    "gt": gt_name,
                    "path": pred_path,
                }
            )
# Build the frame once from the collected rows instead of concatenating inside the loop.
df = pd.DataFrame(records, columns=["model", "split", "gt", "path"])
df
model | split | gt | path | |
---|---|---|---|---|
0 | Cellpose | 10/90 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
1 | Cellpose | 20/80 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
2 | Cellpose | 60/40 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
3 | Cellpose | 80/20 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
4 | Cellpose - default | 10/90 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
... | ... | ... | ... | ... |
67 | SegResNet | 80/20 | c5 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
68 | SwinUNetR | 10/90 | c5 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
69 | SwinUNetR | 20/80 | c5 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
70 | SwinUNetR | 60/40 | c5 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
71 | SwinUNetR | 80/20 | c5 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
72 rows × 4 columns
# Sanity check: display the first resolved prediction path.
df.iloc[0].path # check if the paths are correct
WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SPLITS/Analysis/dataset_splits/c1_5/cp/cellpose_1090_labels.tif')
Add pretrained WNet3D splits to the data#
# Append one row per ground-truth volume for the pretrained WNet3D predictions.
# These rows carry the "WNet - Artifacts" split label.
for gt_name, relative_path in (
    ("visual", "WNet/pretrained/visual_pred.tif"),
    ("c3", "WNet/pretrained/c3_pred.tif"),
    ("c5", "WNet/pretrained/c5_pred.tif"),
):
    df.loc[len(df)] = ["WNet3D - Pretrained", "WNet - Artifacts", gt_name, image_folder / relative_path]
df
model | split | gt | path | |
---|---|---|---|---|
0 | Cellpose | 10/90 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
1 | Cellpose | 20/80 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
2 | Cellpose | 60/40 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
3 | Cellpose | 80/20 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
4 | Cellpose - default | 10/90 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
... | ... | ... | ... | ... |
70 | SwinUNetR | 60/40 | c5 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
71 | SwinUNetR | 80/20 | c5 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
72 | WNet3D - Pretrained | WNet - Artifacts | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
73 | WNet3D - Pretrained | WNet - Artifacts | c3 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
74 | WNet3D - Pretrained | WNet - Artifacts | c5 | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
75 rows × 4 columns
WNet3D - Order by splits#
# Each value is the list of prediction files found for that split
# (3 per split according to the original notebook).
wnet_splits_preds = {
    "1090": find_images(image_folder / "WNet/10", "pred"),
    "2080": find_images(image_folder / "WNet/20", "pred"),
    "6040": find_images(image_folder / "WNet/60", "pred"),
    "8020": find_images(image_folder / "WNet/80", "pred"),
}
# Add one row per prediction file; all of them are recorded against the "visual" volume.
for split_key, image_paths in wnet_splits_preds.items():
    split_label = f"WNet - {split_key[:2]}/{split_key[2:]}"
    for path in image_paths:
        df.loc[len(df)] = ["WNet3D", split_label, "visual", path]
df
model | split | gt | path | |
---|---|---|---|---|
0 | Cellpose | 10/90 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
1 | Cellpose | 20/80 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
2 | Cellpose | 60/40 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
3 | Cellpose | 80/20 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
4 | Cellpose - default | 10/90 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
... | ... | ... | ... | ... |
82 | WNet3D | WNet - 60/40 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
83 | WNet3D | WNet - 60/40 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
84 | WNet3D | WNet - 80/20 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
85 | WNet3D | WNet - 80/20 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
86 | WNet3D | WNet - 80/20 | visual | C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\... |
87 rows × 4 columns
# Models to score, in the order model_dices_across_splits iterates over them
# (and therefore the order in which rows are added to the Dice table).
model_names = ["StarDist - default", "StarDist", "Cellpose - default", "Cellpose", "SegResNet", "SwinUNetR", "WNet3D", "WNet3D - Pretrained"]
Note
For WNet3D, we compute two sets of Dice scores:
One on c3, c5 and ALL of visual, which contains some artifacts. This is “All data”.
One on c3, c5 and PART of visual, which does not contain artifacts. This is “No artifacts”.
# Only the first `visual_slice` planes along axis 0 of the visual volume are kept
# for the "no artifacts" evaluation.
visual_slice = 50
no_artifact_region = np.s_[:visual_slice, :, :]
visual_gt_no_artifact = gt_dict["visual"][no_artifact_region]
# Pretrained WNet3D predictions: c3 and c5 are used whole, visual is cropped like its GT.
wnet_visual_pred_full = imread(image_folder / "WNet/pretrained/visual_pred.tif")
wnet_visual_no_artifact = wnet_visual_pred_full[no_artifact_region]
wnet_c3_pred = imread(image_folder / "WNet/pretrained/c3_pred.tif")
wnet_c5_pred = imread(image_folder / "WNet/pretrained/c5_pred.tif")
# Snapshot of the path table before it is re-indexed further down.
path_df = df.copy()
Dice computation#
Pre-processing for SwinUNetR and SegResNet#
# The same per-model threshold is used for every split; the values were estimated
# from the training data of the supervised benchmark figure.
# All three models share the same Voronoi-Otsu sigmas and differ only in threshold.
_SHARED_SIGMAS = {
    "spot_sigma": 0.65,
    "outline_sigma": 0.65,
}
swin_params = {"thresh": 0.4, **_SHARED_SIGMAS}
segres_params = {"thresh": 0.3, **_SHARED_SIGMAS}
wnet_params = {"thresh": 0.6, **_SHARED_SIGMAS}
def models_instance_preprocessing(volume, params):
    """Turn a raw model output volume into instance labels.

    The volume is binarized at ``params["thresh"]``, labeled with
    ``cle.voronoi_otsu_labeling`` using ``params["spot_sigma"]`` and
    ``params["outline_sigma"]``, and objects below ``min_size=5`` are removed
    with ``remove_small_objects``.

    Args:
        volume: Array-like model output that supports ``>`` comparison.
        params: Dict with the keys ``"thresh"``, ``"spot_sigma"`` and ``"outline_sigma"``.

    Returns:
        The filtered label array.
    """
    binary = np.where(volume > params["thresh"], 1, 0)
    instance_labels = cle.voronoi_otsu_labeling(
        binary,
        spot_sigma=params["spot_sigma"],
        outline_sigma=params["outline_sigma"],
    )
    # Convert to a NumPy array before handing the labels to scikit-image.
    return remove_small_objects(np.array(instance_labels), min_size=5)
def wnet_preprocessing(volume, channel_foreground, params):
    """Extract instance labels from a multi-channel WNet output.

    Args:
        volume: Prediction array. Arrays with fewer than 4 dimensions are
            returned untouched.
        channel_foreground: Index (along the first axis) of the channel to use.
        params: Dict with the keys ``"thresh"``, ``"spot_sigma"`` and ``"outline_sigma"``.

    Returns:
        ``volume`` itself when it has fewer than 4 dimensions; otherwise the
        result of ``cle.voronoi_otsu_labeling`` on the thresholded channel.
    """
    if len(volume.shape) >= 4:
        foreground = volume[channel_foreground]
        binary = np.where(foreground > params["thresh"], 1, 0)
        return cle.voronoi_otsu_labeling(
            binary,
            spot_sigma=params["spot_sigma"],
            outline_sigma=params["outline_sigma"],
        )
    # Nothing to do: processing only applies to volumes with at least 4 dimensions.
    return volume
# import napari
# viewer = napari.Viewer()
# # show visual gt and swin preprocessed for each split
# swin_path = df.iloc[9]
# swin_pred = imread(swin_path.path)
# viewer.add_labels(gt_dict["c3"], name="c3_gt")
# swin_pred_processed = models_instance_preprocessing(swin_pred, swin_params)
# viewer.add_labels(swin_pred_processed, name="swin_pred")
# viewer.add_image(np.swapaxes(swin_pred), name="swin_pred_raw", colormap="turbo")
# Index the path table by model name (in place) so rows can be selected with df.loc[model].
# "split" is deliberately left out of the index and stays a regular column.
index_columns = ["model"]
df.set_index(index_columns, inplace=True)
Dice score computation#
def model_dices_across_splits(df, verbose=False):
    """Compute one Dice score per prediction file, for every model and split.

    Relies on the notebook-level names ``model_names``, ``gt_dict``,
    ``visual_slice``, ``swin_params``, ``segres_params``, ``wnet_params``,
    ``imread`` and ``dice_coeff``.

    Args:
        df: DataFrame indexed by model name with the columns ``split``, ``gt``
            and ``path``.
        verbose: When True, print the model, split, ground-truth name, the
            shapes of the arrays being compared and the resulting Dice.

    Returns:
        DataFrame with the columns ``Dice``, ``Model``, ``Split`` and ``GT``,
        one row per prediction file.
    """
    dices_df = pd.DataFrame(columns=["Dice", "Model", "Split", "GT"])
    for model in model_names:
        model_rows = df.loc[model]
        for split in model_rows["split"].unique():
            for _, row in model_rows[model_rows["split"] == split].iterrows():
                gt = gt_dict[row["gt"]]
                pred = imread(row.path)
                if model == "SwinUNetR":
                    pred = models_instance_preprocessing(pred, swin_params)
                elif model == "SegResNet":
                    pred = models_instance_preprocessing(pred, segres_params)
                elif model == "WNet3D":
                    # look into path for foreground channel : c0 indicates first channel is foreground, c1 indicates second channel is foreground
                    channel = 1 if "c1" in str(row.path) else 0
                    pred = wnet_preprocessing(pred, channel, wnet_params)
                    if row["gt"] == "visual":
                        # Score only the first `visual_slice` planes of the visual volume.
                        pred = pred[:visual_slice, :, :]
                        gt = gt[:visual_slice, :, :]
                if verbose:
                    print(f"Model: {model}, Split: {split}, GT: {row['gt']}")
                    print(f"Image shape: {pred.shape}")
                    # Report the shape of the (possibly cropped) GT that is actually scored.
                    print(f"GT shape: {gt.shape}")
                # Dice is computed on binary foreground masks, not on instance labels.
                gt = np.where(gt > 0, 1, 0)
                pred = np.where(pred > 0, 1, 0)
                dice = dice_coeff(gt, pred)
                if verbose:
                    print(f"Dice: {dice}")
                    print("_" * 20)
                dices_df.loc[len(dices_df)] = [dice, model, split, row["gt"]]
    return dices_df
# Score every prediction listed in df; verbose=True prints shapes and Dice per image.
dices_df = model_dices_across_splits(df, verbose=True)
Model: StarDist - default, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.45813969740382265
____________________
Model: StarDist - default, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 1.7072712683318253e-05
____________________
Model: StarDist - default, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7047307334629646
____________________
Model: StarDist - default, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6900064226075787
____________________
Model: StarDist - default, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 1.7072712683318253e-05
____________________
Model: StarDist - default, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 2.441108263151471e-05
____________________
Model: StarDist - default, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6491599555678979
____________________
Model: StarDist - default, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.5219389414443077
____________________
Model: StarDist - default, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.725773126793911
____________________
Model: StarDist - default, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6641397223887263
____________________
Model: StarDist - default, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.5228543567663311
____________________
Model: StarDist - default, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.6306411323896752
____________________
Model: StarDist, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7079987191802753
____________________
Model: StarDist, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7066773429884418
____________________
Model: StarDist, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.666873158778668
____________________
Model: StarDist, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6870511156437357
____________________
Model: StarDist, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 1.7072712683318253e-05
____________________
Model: StarDist, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.6354648592898188
____________________
Model: StarDist, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.698111261522594
____________________
Model: StarDist, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7395951257585239
____________________
Model: StarDist, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7753721081706993
____________________
Model: StarDist, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6634502993726294
____________________
Model: StarDist, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.6909480066618091
____________________
Model: StarDist, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7941101315957583
____________________
Model: Cellpose - default, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6095447577151236
____________________
Model: Cellpose - default, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 1.7072712683318253e-05
____________________
Model: Cellpose - default, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 2.441108263151471e-05
____________________
Model: Cellpose - default, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.38262849366691093
____________________
Model: Cellpose - default, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.41333207922388543
____________________
Model: Cellpose - default, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.4346342300432785
____________________
Model: Cellpose - default, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.3216480826760946
____________________
Model: Cellpose - default, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.46105638941004795
____________________
Model: Cellpose - default, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.28778826871657753
____________________
Model: Cellpose - default, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.5324423418095802
____________________
Model: Cellpose - default, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.3586485504207554
____________________
Model: Cellpose - default, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.4212584018642991
____________________
Model: Cellpose, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7861751295015355
____________________
Model: Cellpose, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.08595350922131148
____________________
Model: Cellpose, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.08457272790551985
____________________
Model: Cellpose, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7072484308564487
____________________
Model: Cellpose, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7293200647090544
____________________
Model: Cellpose, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7463728863187176
____________________
Model: Cellpose, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6880578161562196
____________________
Model: Cellpose, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7393016007796444
____________________
Model: Cellpose, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.6729493389605614
____________________
Model: Cellpose, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7981322424499242
____________________
Model: Cellpose, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7286757739677883
____________________
Model: Cellpose, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.756078768391886
____________________
Model: SegResNet, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.4221998199819982
____________________
Model: SegResNet, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.28690334887141233
____________________
Model: SegResNet, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.5435601236412202
____________________
Model: SegResNet, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.5148477010728031
____________________
Model: SegResNet, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.4388548365803519
____________________
Model: SegResNet, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.751233228097869
____________________
Model: SegResNet, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7747215240932079
____________________
Model: SegResNet, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8371491990562523
____________________
Model: SegResNet, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.6906855807000525
____________________
Model: SegResNet, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7061559736831078
____________________
Model: SegResNet, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8270254509885067
____________________
Model: SegResNet, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7074302192549407
____________________
Model: SwinUNetR, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.8157865498468327
____________________
Model: SwinUNetR, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8174437462975013
____________________
Model: SwinUNetR, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7018095117791954
____________________
Model: SwinUNetR, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7901964762246602
____________________
Model: SwinUNetR, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7719215148797227
____________________
Model: SwinUNetR, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.760966675600822
____________________
Model: SwinUNetR, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.8177856932892985
____________________
Model: SwinUNetR, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8450172476830877
____________________
Model: SwinUNetR, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.8482937883231139
____________________
Model: SwinUNetR, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.8279958413580352
____________________
Model: SwinUNetR, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8281666840544286
____________________
Model: SwinUNetR, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.8328909913502661
____________________
Model: WNet3D, Split: WNet - 10/90, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7501949364268551
____________________
Model: WNet3D, Split: WNet - 10/90, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7432460865174213
____________________
Model: WNet3D, Split: WNet - 10/90, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7162874698995267
____________________
Model: WNet3D, Split: WNet - 20/80, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7421859482493997
____________________
Model: WNet3D, Split: WNet - 20/80, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7438786658383496
____________________
Model: WNet3D, Split: WNet - 20/80, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7254014598540146
____________________
Model: WNet3D, Split: WNet - 60/40, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7340520280550711
____________________
Model: WNet3D, Split: WNet - 60/40, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7233859587228788
____________________
Model: WNet3D, Split: WNet - 60/40, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.674949028745547
____________________
Model: WNet3D, Split: WNet - 80/20, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7368010056376657
____________________
Model: WNet3D, Split: WNet - 80/20, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7209019295450522
____________________
Model: WNet3D, Split: WNet - 80/20, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6764699294433867
____________________
Model: WNet3D - Pretrained, Split: WNet - Artifacts, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.5896556737502726
____________________
Model: WNet3D - Pretrained, Split: WNet - Artifacts, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8113536341409202
____________________
Model: WNet3D - Pretrained, Split: WNet - Artifacts, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.8087545264446371
____________________
# Display the per-image Dice table.
dices_df
Dice | Model | Split | GT | |
---|---|---|---|---|
0 | 0.458140 | StarDist - default | 10/90 | visual |
1 | 0.000017 | StarDist - default | 10/90 | c3 |
2 | 0.704731 | StarDist - default | 10/90 | c5 |
3 | 0.690006 | StarDist - default | 20/80 | visual |
4 | 0.000017 | StarDist - default | 20/80 | c3 |
... | ... | ... | ... | ... |
82 | 0.720902 | WNet3D | WNet - 80/20 | visual |
83 | 0.676470 | WNet3D | WNet - 80/20 | visual |
84 | 0.589656 | WNet3D - Pretrained | WNet - Artifacts | visual |
85 | 0.811354 | WNet3D - Pretrained | WNet - Artifacts | c3 |
86 | 0.808755 | WNet3D - Pretrained | WNet - Artifacts | c5 |
87 rows × 4 columns
WNet3D - No artifacts#
# Pretrained WNet3D predictions, with the visual volume restricted to its cropped part.
wnet_preds_no_artifacts = {  # No artifacts
    "WNet3D - No artifacts": {
        "visual": wnet_visual_no_artifact,
        "c3": wnet_c3_pred,
        "c5": wnet_c5_pred,
    }
}
# Matching ground truths (visual is cropped the same way as its prediction).
gt_dict_no_artifacts = {
    "visual": visual_gt_no_artifact,
    "c3": c3,
    "c5": c5,
}
# Score each volume on binary foreground masks and append the result to dices_df
# under the "WNet3D - Pretrained" model, using the dictionary key as split label.
for split, preds_by_image in wnet_preds_no_artifacts.items():
    for image, raw_pred in preds_by_image.items():
        gt = np.where(gt_dict_no_artifacts[image] > 0, 1, 0)
        pred = np.where(raw_pred > 0, 1, 0)
        dice = dice_coeff(gt, pred)
        print(f"Split: {split}, Image: {image}, Dice: {dice}")
        dices_df.loc[len(dices_df)] = [dice, "WNet3D - Pretrained", split, image]
Split: WNet3D - No artifacts, Image: visual, Dice: 0.8179572126452918
Split: WNet3D - No artifacts, Image: c3, Dice: 0.8113536341409202
Split: WNet3D - No artifacts, Image: c5, Dice: 0.8087545264446371
Detailed stats#
Shows the means and standard deviations of selected splits. This is the data that was used in the text of the manuscript.
# Per-model mean and standard deviation of the Dice for the 80/20 split, best mean first.
is_8020 = dices_df["Split"] == "80/20"
dice_df_8020 = dices_df[is_8020]
(
    dice_df_8020
    .groupby("Model", sort=False)
    .agg({"Dice": ["mean", "std"]})
    .sort_values(("Dice", "mean"), ascending=False)
)
Dice | ||
---|---|---|
mean | std | |
Model | ||
SwinUNetR | 0.829685 | 0.002778 |
Cellpose | 0.760962 | 0.034985 |
SegResNet | 0.746871 | 0.069419 |
StarDist | 0.716169 | 0.068885 |
StarDist - default | 0.605878 | 0.073826 |
Cellpose - default | 0.437450 | 0.088021 |
# Per-model mean and standard deviation of the Dice for the 10/90 split, best mean first.
is_1090 = dices_df["Split"] == "10/90"
dices_df_1090 = dices_df[is_1090]
(
    dices_df_1090
    .groupby("Model", sort=False)
    .agg({"Dice": ["mean", "std"]})
    .sort_values(("Dice", "mean"), ascending=False)
)
Dice | ||
---|---|---|
mean | std | |
Model | ||
SwinUNetR | 0.778347 | 0.066288 |
StarDist | 0.693850 | 0.023372 |
SegResNet | 0.417554 | 0.128391 |
StarDist - default | 0.387629 | 0.357609 |
Cellpose | 0.318900 | 0.404672 |
Cellpose - default | 0.203195 | 0.351909 |
# Pretrained WNet3D: mean and standard deviation of the Dice per split label.
is_pretrained_wnet = dices_df["Model"] == "WNet3D - Pretrained"
dice_df_wnet = dices_df[is_pretrained_wnet]
(
    dice_df_wnet
    .groupby("Split")
    .agg({"Dice": ["mean", "std"]})
)
Dice | ||
---|---|---|
mean | std | |
Split | ||
WNet - Artifacts | 0.736588 | 0.127254 |
WNet3D - No artifacts | 0.812688 | 0.004744 |
# WNet3D trained on data subsets: mean and standard deviation of the Dice per split.
is_subset_wnet = dices_df["Model"] == "WNet3D"
dice_df_wnet_trained_on_subsets = dices_df[is_subset_wnet]
(
    dice_df_wnet_trained_on_subsets
    .groupby("Split")
    .agg({"Dice": ["mean", "std"]})
)
Dice | ||
---|---|---|
mean | std | |
Split | ||
WNet - 10/90 | 0.736576 | 0.017911 |
WNet - 20/80 | 0.737155 | 0.010214 |
WNet - 60/40 | 0.710796 | 0.031499 |
WNet - 80/20 | 0.711391 | 0.031270 |
Plots#
# Combined "Model (Split)" label, used as the hue variable of the boxplots.
split_labels = dices_df["Split"].astype(str)
dices_df["Model_Split"] = dices_df["Model"] + " (" + split_labels + ")"
Due to the way seaborn boxplots interact with categories, the axes will be generated for all models/splits, and then separate boxplots are used to have readable boxes for each model.
The data is the same everywhere, this is only to have cleaner figures. The boxplots were then re-assembled separately (while being careful of the axes) to have the final figure.
~~In addition, since SwinUNetR and WNet have much smaller variance, we show a zoomed-in inset for these models.~~ Deprecated.
General plot#
All models (hard to read)#
# Overview figure: Dice per model on the x-axis, one box per "Model (Split)" hue level.
fig, ax = plt.subplots(figsize=(9, 6), dpi=DPI)
sns.boxplot(
    data=dices_df,
    x="Model",
    y="Dice",
    hue="Model_Split",
    ax=ax,
    palette=EXPANDED_COLORMAP,
    # dodge=False,
)
# Recolor the 6 lines belonging to each box with the box face color and move
# one of them (index 4 within each group of 6) to x_center +/- 0.2.
# NOTE(review): this iterates ax.artists; on recent matplotlib versions the box
# patches may be stored in ax.patches instead, in which case this loop does
# nothing — confirm against the installed seaborn/matplotlib versions.
for i, artist in enumerate(ax.artists): # try to center the boxplot on xticks
    for j in range(i*6,i*6+6):
        line = ax.lines[j]
        line.set_color(artist.get_facecolor())
        if j % 6 == 4:
            x, y = line.get_xydata()[0]
            x_center = i // 2
            if i % 2:
                line.set_xdata([x_center + 0.2, x_center + 0.2, x, x_center + 0.2])
            else:
                line.set_xdata([x_center - 0.2, x_center - 0.2, x, x_center - 0.2])
# Axes cosmetics: hide top/right spines, rotate the model names, fix the Dice range.
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.tick_params(axis="both", which="major", labelsize=LEGEND_FONT_SIZE)
ax.tick_params(axis = 'x', rotation = 65)
plt.ylim([-0.02,1])
ax.set_yticks(np.arange(0,1.1,0.1))
sns.despine(
    left=False,
    right=True,
    bottom=False,
    top=True,
    trim=True,
    offset={"bottom": 40, "left": 15},
)
# Keep a handle on this legend: the separate legend figure below reuses its
# handles and texts.
legend = ax.legend(fontsize=LEGEND_FONT_SIZE, bbox_to_anchor=BBOX_TO_ANCHOR, loc=LOC)
# Transparent figure, axes and legend backgrounds.
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
legend.get_frame().set_alpha(0)
ax.set_xlabel("", fontsize=LABEL_FONT_SIZE)
ax.set_ylabel("Dice coefficient", fontsize=LABEL_FONT_SIZE)
# Replace the axes legend with an empty one so it is not drawn on this figure.
plt.legend([],[], frameon=False)
plt.show()
# Export in the formats enabled by the SAVE_PLOTS_AS_* flags.
if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency.svg", bbox_inches="tight")
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\518262752.py:2: UserWarning:
The palette list has fewer values (29) than needed (30) and will cycle, which may produce an uninterpretable plot.
sns.boxplot(
Legend#
# Standalone legend figure, rebuilt from the handles and label texts of the
# `legend` object created by the previous plotting cell.
fig_leg = plt.figure(figsize=(3, 2))
ax_leg = fig_leg.add_subplot(111)
ax_leg.legend(
    handles=legend.legend_handles,
    labels=[entry.get_text() for entry in legend.get_texts()],
)
ax_leg.set_axis_off()  # only the legend should be visible
plt.show()

if SAVE_PLOTS_AS_PNG:
    fig_leg.savefig("Label_efficiency_legend.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig_leg.savefig("Label_efficiency_legend.svg", bbox_inches="tight")
Self-supervised models#
# Same layout as the overview figure, restricted to the two WNet3D variants.
fig, ax = plt.subplots(figsize=(6, 6), dpi=DPI)
wnet_models = ["WNet3D", "WNet3D - Pretrained"]
dices_df_selfsupervised = dices_df.loc[dices_df["Model"].isin(wnet_models)]
# NOTE(review): seaborn warns that this palette slice has 9 colors for 6 hue
# levels. The slice starts at index 20, whereas the per-model cell gives WNet3D
# (7th model) the slice 24:28 - check whether this offset is intended.
sns.boxplot(
    data=dices_df_selfsupervised,
    x="Model",
    y="Dice",
    hue="Model_Split",
    palette=EXPANDED_COLORMAP[20:],
    ax=ax,
)

# Try to center the boxplot on xticks: the six lines that follow each box are
# recolored with the box face color, and the fifth one (presumably the median
# line - TODO confirm) is moved 0.2 to the left or right of the tick position.
# NOTE(review): recent matplotlib/seaborn versions may store the boxes in
# ax.patches instead of ax.artists, in which case this loop does nothing - confirm.
for i, artist in enumerate(ax.artists):
    for j in range(i * 6, i * 6 + 6):
        line = ax.lines[j]
        line.set_color(artist.get_facecolor())
        if j % 6 != 4:
            continue
        x, y = line.get_xydata()[0]
        x_center = i // 2
        x_shifted = x_center + 0.2 if i % 2 else x_center - 0.2
        line.set_xdata([x_shifted, x_shifted, x, x_shifted])

# Axes cosmetics.
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.tick_params(axis="both", which="major", labelsize=LEGEND_FONT_SIZE)
ax.tick_params(axis="x", rotation=65)
ax.set_ylim([-0.02, 1])
ax.set_yticks(np.arange(0, 1.1, 0.1))
sns.despine(
    top=True,
    right=True,
    left=False,
    bottom=False,
    offset={"bottom": 40, "left": 15},
    trim=True,
)

# Build the legend once (kept in `legend`), then replace the one drawn on the
# axes by an empty legend.
legend = ax.legend(fontsize=LEGEND_FONT_SIZE, bbox_to_anchor=BBOX_TO_ANCHOR, loc=LOC)
legend.get_frame().set_alpha(0)

# Transparent figure and axes backgrounds.
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

ax.set_xlabel("", fontsize=LABEL_FONT_SIZE)
ax.set_ylabel("Dice coefficient", fontsize=LABEL_FONT_SIZE)
ax.legend([], [], frameon=False)
plt.show()

if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency_selfsup.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency_selfsup.svg", bbox_inches="tight")
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\2984398448.py:4: UserWarning: The palette list has more values (9) than needed (6), which may not be intended.
sns.boxplot(
Boxplots for each model#
# One narrow figure per model (boxes = splits of that model), written to the
# "label_efficiency" folder.
models_dfs = [dices_df[dices_df["Model"] == model].copy() for model in model_names]
save_path = Path("label_efficiency")
save_path.mkdir(exist_ok=True)

for i, (model, df) in enumerate(zip(model_names, models_dfs)):
    print(model)
    fig, ax = plt.subplots(figsize=(1, 6), dpi=DPI)
    # Model number i uses the i-th group of four shades of the expanded colormap.
    # NOTE(review): for the last model this slice holds a single color while two
    # hue levels are plotted (see the seaborn warning), so the color is reused.
    sns.boxplot(
        data=df,
        x="Model_Split",
        y="Dice",
        hue="Split",
        palette=EXPANDED_COLORMAP[4 * i : 4 * i + 4],
        ax=ax,
    )

    # Axes cosmetics; y range and ticks are identical for every model.
    ax.tick_params(axis="x", rotation=45)
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis="both", which="major", labelsize=LEGEND_FONT_SIZE)
    ax.set_yticks(np.arange(0, 1.1, 0.1))
    ax.set_ylim([-0.02, 1])
    sns.despine(
        ax=ax,
        top=True,
        right=True,
        left=False,
        bottom=False,
        offset={"bottom": 40, "left": 15},
        trim=True,
    )

    # No tick labels, no x label, no legend on the individual panels.
    ax.set_xticklabels("", fontsize=LABEL_FONT_SIZE)
    ax.set_xlabel("", fontsize=LABEL_FONT_SIZE)
    ax.set_ylabel("Dice coefficient", fontsize=LABEL_FONT_SIZE)
    ax.legend([], [], frameon=False)
    ax.patch.set_alpha(0)
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    fig.patch.set_alpha(0)
    plt.show()

    if SAVE_PLOTS_AS_PNG:
        fig.savefig(f"{save_path}/Label_efficiency_{model}.png", dpi=DPI, bbox_inches="tight")
    if SAVE_PLOTS_AS_SVG:
        fig.savefig(f"{save_path}/Label_efficiency_{model}.svg", bbox_inches="tight")
StarDist - default
StarDist
Cellpose - default
Cellpose
SegResNet
SwinUNetR
WNet3D
WNet3D - Pretrained
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\1066515323.py:9: UserWarning:
The palette list has fewer values (1) than needed (2) and will cycle, which may produce an uninterpretable plot.
sns.boxplot(
Zoom-in for rightmost models#
# Zoomed-in view (y range 0.55-0.95) restricted to SwinUNetR and WNet3D, with
# one box per split.
fig, ax = plt.subplots(figsize=(3, 6), dpi=DPI)
zoom_dices_df = dices_df.loc[dices_df["Model"].isin(["SwinUNetR", "WNet3D"])]
# NOTE(review): seaborn warns that this palette slice has 13 colors for 8 hue levels.
sns.boxplot(
    data=zoom_dices_df,
    x="Model",
    y="Dice",
    hue="Split",
    dodge=True,
    palette=EXPANDED_COLORMAP[16:],
    ax=ax,
)

# NOTE(review): the left spine is hidden here, but sns.despine(left=False) below
# presumably makes it visible again - confirm which one is intended.
for side in ("left", "top"):
    ax.spines[side].set_visible(False)
ax.tick_params(axis="both", which="major", labelsize=LEGEND_FONT_SIZE)
ax.tick_params(axis="x", rotation=45)

# A narrower alternative range (0.8-0.85, ticks every 0.01) was tried previously.
ax.set_ylim([0.55, 0.95])
ax.set_yticks(np.arange(0.55, 0.96, 0.05))
ax.set_xlabel("", fontsize=LABEL_FONT_SIZE)
ax.set_ylabel("Dice coefficient", fontsize=LABEL_FONT_SIZE)
ax.legend_.remove()
sns.despine(
    top=True,
    right=True,
    left=False,
    bottom=False,
    offset={"bottom": 40, "right": 15},
    trim=True,
)

# Transparent figure and axes backgrounds.
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency_SWIN_WNET_zoom.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency_SWIN_WNET_zoom.svg", bbox_inches="tight")
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\1916514661.py:4: UserWarning: The palette list has more values (13) than needed (8), which may not be intended.
sns.boxplot(data=zoom_dices_df, x="Model", y="Dice", hue="Split", ax=ax, palette=EXPANDED_COLORMAP[16:], dodge=True)
Statistical tests#
Here we run a Kruskal-Wallis test to assess whether at least one model's Dice scores differ from the others, with the scores of each model pooled across splits.
Performance across models#
# Display the Dice table (one row per model / split / ground-truth volume).
dices_df
Dice | Model | Split | GT | Model_Split | |
---|---|---|---|---|---|
0 | 0.458140 | StarDist - default | 10/90 | visual | StarDist - default (10/90) |
1 | 0.000017 | StarDist - default | 10/90 | c3 | StarDist - default (10/90) |
2 | 0.704731 | StarDist - default | 10/90 | c5 | StarDist - default (10/90) |
3 | 0.690006 | StarDist - default | 20/80 | visual | StarDist - default (20/80) |
4 | 0.000017 | StarDist - default | 20/80 | c3 | StarDist - default (20/80) |
... | ... | ... | ... | ... | ... |
85 | 0.811354 | WNet3D - Pretrained | WNet - Artifacts | c3 | WNet3D - Pretrained (WNet - Artifacts) |
86 | 0.808755 | WNet3D - Pretrained | WNet - Artifacts | c5 | WNet3D - Pretrained (WNet - Artifacts) |
87 | 0.817957 | WNet3D - Pretrained | WNet3D - No artifacts | visual | WNet3D - Pretrained (WNet3D - No artifacts) |
88 | 0.811354 | WNet3D - Pretrained | WNet3D - No artifacts | c3 | WNet3D - Pretrained (WNet3D - No artifacts) |
89 | 0.808755 | WNet3D - Pretrained | WNet3D - No artifacts | c5 | WNet3D - Pretrained (WNet3D - No artifacts) |
90 rows × 5 columns
Boxplot of model performance across all splits#
from plots import _format_plot

# One box per model, with the Dice scores of all splits pooled together.
fig, ax = plt.subplots(figsize=(8, 8), dpi=DPI)

# Per-model lists of Dice scores (rows of the "WNet - Artifacts" split excluded);
# this table feeds the Kruskal-Wallis test in the next cell.
# NOTE(review): the boxplot below is drawn from the unfiltered `dices_df`, not
# from this filtered selection - confirm that this is intended.
test_df = dices_df[dices_df["Split"] != "WNet - Artifacts"]
test_df = test_df.groupby("Model", sort=False).Dice.apply(list).reset_index()

# NOTE(review): seaborn warns that `temp_cmap` has 10 colors for 8 hue levels.
sns.boxplot(data=dices_df, hue="Model", y="Dice", palette=temp_cmap, ax=ax)
_format_plot(ax, xlabel="Model", ylabel="Dice coefficient", title="Dice coefficient across splits")

# Transparent figure and axes backgrounds.
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
plt.show()

# Both exports use a tight bounding box, like every other figure of this
# notebook (the PNG export previously omitted it).
if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency_all_models_pooled.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency_all_models_pooled.svg", bbox_inches="tight")
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\1582456374.py:6: UserWarning: The palette list has more values (10) than needed (8), which may not be intended.
sns.boxplot(data=dices_df, hue="Model", y="Dice", palette=temp_cmap, ax=ax)
# Kruskal-Wallis H-test across models: each positional argument is the list of
# Dice scores of one model (rows of `test_df`, built in the previous cell).
print(kruskal(*test_df.Dice.tolist()))
KruskalResult(statistic=49.212286777549934, pvalue=2.0617114022093387e-08)
p-values heatmap - Default notation#
from scikit_posthocs import posthoc_conover
import matplotlib.colors as colors
import matplotlib.patches as mpatches

# Pairwise Conover post-hoc test between models (Holm-adjusted p-values), shown
# as an annotated heatmap with fixed-point formatting.
# NOTE(review): this uses the full `dices_df`, whereas the Kruskal-Wallis test
# above excludes the "WNet - Artifacts" split - confirm that this is intended.
fig, ax = plt.subplots(figsize=(8, 8), dpi=DPI)
results = posthoc_conover(dices_df, val_col="Dice", group_col="Model", p_adjust="holm")

low_color = COLORMAP[0]
mid_color = COLORMAP[1]
high_color = COLORMAP[2]
equals_1_color = COLORMAP[3]

# Piecewise colormap over [0, 1]: low -> mid between 0 and 0.05, a sharp switch
# to `high_color` at 0.06, constant up to 0.99, then a blend to `equals_1_color`.
levels = [0, 0.05, 0.06, 0.99, 1]
colors_list = [low_color, mid_color, high_color, high_color, equals_1_color]
cmap = colors.LinearSegmentedColormap.from_list("", list(zip(levels, colors_list)))
norm = colors.Normalize(vmin=0, vmax=1)

# Blank the diagonal (each model compared with itself). This is done on an
# explicit copy: writing through `results.values` only works while pandas
# returns a writable view, which is not the case with Copy-on-Write enabled.
diag = np.diag_indices(results.shape[0])
pvals = results.to_numpy(dtype=float, copy=True)
pvals[diag] = np.nan
results = pd.DataFrame(pvals, index=results.index, columns=results.columns)

sns.heatmap(results, annot=True, cmap=cmap, norm=norm, cbar=False, fmt=".4f", linewidths=0.5, ax=ax)

# Manual legend explaining the color code (no colorbar is drawn).
patches = [
    mpatches.Patch(color=low_color, label='Above 0 (significant)'),
    mpatches.Patch(color=mid_color, label='Below 0.05 (significant)'),
    mpatches.Patch(color=high_color, label='Above 0.05 (not significant)'),
    mpatches.Patch(color=equals_1_color, label='Equals 1 (not significant)'),
]
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

# transparent background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

# Save through the figure handle, as the other cells do, rather than through
# pyplot's "current figure".
if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency_posthoc_conover.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency_posthoc_conover.svg", bbox_inches="tight")
p-values heatmap - Scientific notation#
from scikit_posthocs import posthoc_conover
import matplotlib.colors as colors
import matplotlib.patches as mpatches

# Pairwise Conover post-hoc test between models (Holm-adjusted p-values), shown
# as an annotated heatmap with p-values in scientific notation.
# NOTE(review): this uses the full `dices_df`, whereas the Kruskal-Wallis test
# above excludes the "WNet - Artifacts" split - confirm that this is intended.
fig, ax = plt.subplots(figsize=(8, 8), dpi=DPI)
results = posthoc_conover(dices_df, val_col="Dice", group_col="Model", p_adjust="holm")

low_color = COLORMAP[0]
mid_color = COLORMAP[1]
high_color = COLORMAP[2]
equals_1_color = COLORMAP[3]

# Piecewise colormap over [0, 1]: low -> mid between 0 and 0.05, a sharp switch
# to `high_color` at 0.06, constant up to 0.99, then a blend to `equals_1_color`.
levels = [0, 0.05, 0.06, 0.99, 1]
colors_list = [low_color, mid_color, high_color, high_color, equals_1_color]
cmap = colors.LinearSegmentedColormap.from_list("", list(zip(levels, colors_list)))
norm = colors.Normalize(vmin=0, vmax=1)

# Blank the diagonal (each model compared with itself). This is done on an
# explicit copy: writing through `results.values` only works while pandas
# returns a writable view, which is not the case with Copy-on-Write enabled.
diag = np.diag_indices(results.shape[0])
pvals = results.to_numpy(dtype=float, copy=True)
pvals[diag] = np.nan
results = pd.DataFrame(pvals, index=results.index, columns=results.columns)

sns.heatmap(results, annot=True, cmap=cmap, norm=norm, cbar=False, fmt=".2e", linewidths=0.5, ax=ax)

# Manual legend explaining the color code (no colorbar is drawn).
patches = [
    mpatches.Patch(color=low_color, label='Above 0 (significant)'),
    mpatches.Patch(color=mid_color, label='Below 0.05 (significant)'),
    mpatches.Patch(color=high_color, label='Above 0.05 (not significant)'),
    mpatches.Patch(color=equals_1_color, label='Equals 1 (not significant)'),
]
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

# transparent background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

# Save through the figure handle, as the other cells do, rather than through
# pyplot's "current figure".
if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency_posthoc_conover_sci.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency_posthoc_conover_sci.svg", bbox_inches="tight")
Unused#
# Scratch check: for every row, take the part of the Split label located before
# the first "/" and keep its last three characters.
test = dices_df.copy()
test["Split"].apply(lambda x: x.split("/")[0][-3:]).values
array(['10', '10', '10', '20', '20', '20', '60', '60', '60', '80', '80',
'80', '10', '10', '10', '20', '20', '20', '60', '60', '60', '80',
'80', '80', '10', '10', '10', '20', '20', '20', '60', '60', '60',
'80', '80', '80', '10', '10', '10', '20', '20', '20', '60', '60',
'60', '80', '80', '80', '10', '10', '10', '20', '20', '20', '60',
'60', '60', '80', '80', '80', '10', '10', '10', '20', '20', '20',
'60', '60', '60', '80', '80', '80', ' 10', ' 10', ' 10', ' 20',
' 20', ' 20', ' 60', ' 60', ' 60', ' 80', ' 80', ' 80', 'cts',
'cts', 'cts', 'cts', 'cts', 'cts'], dtype=object)
# Display the Dice table again, for reference before building the regression table.
dices_df
Dice | Model | Split | GT | Model_Split | |
---|---|---|---|---|---|
0 | 0.458140 | StarDist - default | 10/90 | visual | StarDist - default (10/90) |
1 | 0.000017 | StarDist - default | 10/90 | c3 | StarDist - default (10/90) |
2 | 0.704731 | StarDist - default | 10/90 | c5 | StarDist - default (10/90) |
3 | 0.690006 | StarDist - default | 20/80 | visual | StarDist - default (20/80) |
4 | 0.000017 | StarDist - default | 20/80 | c3 | StarDist - default (20/80) |
... | ... | ... | ... | ... | ... |
85 | 0.811354 | WNet3D - Pretrained | WNet - Artifacts | c3 | WNet3D - Pretrained (WNet - Artifacts) |
86 | 0.808755 | WNet3D - Pretrained | WNet - Artifacts | c5 | WNet3D - Pretrained (WNet - Artifacts) |
87 | 0.817957 | WNet3D - Pretrained | WNet3D - No artifacts | visual | WNet3D - Pretrained (WNet3D - No artifacts) |
88 | 0.811354 | WNet3D - Pretrained | WNet3D - No artifacts | c3 | WNet3D - Pretrained (WNet3D - No artifacts) |
89 | 0.808755 | WNet3D - Pretrained | WNet3D - No artifacts | c5 | WNet3D - Pretrained (WNet3D - No artifacts) |
90 rows × 5 columns
import statsmodels.api as sm
import statsmodels.formula.api as smf
# Input table for the linear mixed-effects model: a copy of the Dice table with
# an integer "Data_percentage" column derived from the Split label.
mlm_df = dices_df.copy()
mlm_df["Data_percentage"] = (
    mlm_df["Split"]
    # text before the first "/", last three characters (e.g. "10" or " 10")
    .map(lambda label: label.partition("/")[0][-3:])
    # labels without a percentage end in "cts"; they are counted as 100
    .replace("cts", 100)
    .astype(int)
)
mlm_df
Dice | Model | Split | GT | Model_Split | Data_percentage | |
---|---|---|---|---|---|---|
0 | 0.458140 | StarDist - default | 10/90 | visual | StarDist - default (10/90) | 10 |
1 | 0.000017 | StarDist - default | 10/90 | c3 | StarDist - default (10/90) | 10 |
2 | 0.704731 | StarDist - default | 10/90 | c5 | StarDist - default (10/90) | 10 |
3 | 0.690006 | StarDist - default | 20/80 | visual | StarDist - default (20/80) | 20 |
4 | 0.000017 | StarDist - default | 20/80 | c3 | StarDist - default (20/80) | 20 |
... | ... | ... | ... | ... | ... | ... |
85 | 0.811354 | WNet3D - Pretrained | WNet - Artifacts | c3 | WNet3D - Pretrained (WNet - Artifacts) | 100 |
86 | 0.808755 | WNet3D - Pretrained | WNet - Artifacts | c5 | WNet3D - Pretrained (WNet - Artifacts) | 100 |
87 | 0.817957 | WNet3D - Pretrained | WNet3D - No artifacts | visual | WNet3D - Pretrained (WNet3D - No artifacts) | 100 |
88 | 0.811354 | WNet3D - Pretrained | WNet3D - No artifacts | c3 | WNet3D - Pretrained (WNet3D - No artifacts) | 100 |
89 | 0.808755 | WNet3D - Pretrained | WNet3D - No artifacts | c5 | WNet3D - Pretrained (WNet3D - No artifacts) | 100 |
90 rows × 6 columns
# Dice versus training-data percentage; color encodes the model, marker style the ground truth.
sns.scatterplot(data=mlm_df, x="Data_percentage", y="Dice", hue="Model", style="GT")
<Axes: xlabel='Data_percentage', ylabel='Dice'>
# Same scatter, colored by ground truth only and drawn without a legend.
sns.scatterplot(data=mlm_df, x="Data_percentage", y="Dice", hue="GT", legend=False)
<Axes: xlabel='Data_percentage', ylabel='Dice'>
# Linear mixed-effects model of Dice: fixed effects for the data percentage, the
# model and the ground truth, with observations grouped by model.
# NOTE(review): "Model" is used both as a fixed effect (C(Model)) and as the
# grouping factor; this may be related to the ConvergenceWarning (Hessian not
# positive definite) recorded for this cell - confirm the intended specification.
model = smf.mixedlm("Dice ~ Data_percentage + C(Model) + C(GT)", mlm_df.copy(), groups="Model")
result = model.fit()
# Create a DataFrame to hold the p-values
# NOTE(review): `pvalues` is not used anywhere in the visible part of the notebook.
pvalues = pd.DataFrame(result.pvalues, columns=["p-value"])
result.summary()
c:\Users\Cyril\anaconda3\envs\cellseg3d-figures\lib\site-packages\statsmodels\regression\mixed_linear_model.py:2262: ConvergenceWarning: The Hessian matrix at the estimated parameter values is not positive definite.
warnings.warn(msg, ConvergenceWarning)
Model: | MixedLM | Dependent Variable: | Dice |
No. Observations: | 90 | Method: | REML |
No. Groups: | 8 | Scale: | 0.0285 |
Min. group size: | 6 | Log-Likelihood: | 10.5757 |
Max. group size: | 12 | Converged: | Yes |
Mean group size: | 11.2 |
Coef. | Std.Err. | z | P>|z| | [0.025 | 0.975] | |
---|---|---|---|---|---|---|
Intercept | 0.450 | 0.180 | 2.504 | 0.012 | 0.098 | 0.803 |
C(Model)[T.Cellpose - default] | -0.275 | 0.248 | -1.107 | 0.268 | -0.762 | 0.212 |
C(Model)[T.SegResNet] | -0.002 | 0.248 | -0.007 | 0.994 | -0.489 | 0.485 |
C(Model)[T.StarDist] | 0.020 | 0.248 | 0.081 | 0.935 | -0.467 | 0.507 |
C(Model)[T.StarDist - default] | -0.163 | 0.248 | -0.656 | 0.512 | -0.650 | 0.324 |
C(Model)[T.SwinUNetR] | 0.178 | 0.248 | 0.716 | 0.474 | -0.309 | 0.665 |
C(Model)[T.WNet3D] | 0.040 | 0.250 | 0.159 | 0.874 | -0.450 | 0.530 |
C(Model)[T.WNet3D - Pretrained] | -0.005 | 0.256 | -0.020 | 0.984 | -0.507 | 0.497 |
C(GT)[T.c5] | 0.070 | 0.047 | 1.493 | 0.135 | -0.022 | 0.162 |
C(GT)[T.visual] | 0.121 | 0.047 | 2.586 | 0.010 | 0.029 | 0.213 |
Data_percentage | 0.003 | 0.001 | 4.126 | 0.000 | 0.001 | 0.004 |
Model Var | 0.028 |
# Friedman test (non-parametric repeated-measures ANOVA) across models, every
# model except "WNet3D - Pretrained" contributing its list of Dice scores.
# NOTE(review): presumably the pretrained model is left out because it has fewer
# observations than the others - confirm.
from scipy.stats import friedmanchisquare

test_df = dices_df.loc[dices_df["Model"].ne("WNet3D - Pretrained")].copy()
test_df = test_df.groupby("Model", sort=False)["Dice"]
fried_stats, p_value = friedmanchisquare(*(list(scores) for _, scores in test_df))
print(f"Friedman test p-value: {p_value}, statistic: {fried_stats}")
Friedman test p-value: 8.575653390194426e-08, statistic: 43.67462686567162