Figure 1.f/g : Training efficiency#

Goals :

  • Show that 3D models are more efficient in terms of training data needed to reach a certain accuracy, compared to 2D models.

(Also shows slight advantage of 3D models in terms of accuracy, and performance of self-supervised model.)

%load_ext autoreload
%autoreload 2
from pathlib import Path
from tifffile import imread
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from scipy.stats import kruskal
import pyclesperanto_prototype as cle
from skimage.morphology import remove_small_objects
sys.path.append("../..")

from utils import *
from plots import *

show_params()
Plot parameters (set in plots.py) : 
- COLORMAP : ███████
- DPI : 200
- Data path : C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK
- Font size : 20
- Title font size : 25.0
- Label font size : 20.0
# expanded colormap has darker and lighter shades for each original color (see get_shades in utils.py)
# See intensity parameter in get_shades to adjust the intensity of the shades
# The Cellpose color needs to be there twice, we insert it between color 1 and 2
# Same for StarDist, default color is COLORMAP[0]
temp_cmap = COLORMAP.copy()
temp_cmap.insert(2, COLORMAP[1])
temp_cmap.insert(1, COLORMAP[0])
temp_cmap[2], _ = get_shades(temp_cmap[3])
temp_cmap[0], _ = get_shades(temp_cmap[1])
EXPANDED_COLORMAP = []
for color in temp_cmap[:7]:
    colors = get_n_shades(color, 4)
    EXPANDED_COLORMAP.extend(colors)
EXPANDED_COLORMAP.extend(COLORMAP[7:])
SAVE_PLOTS_AS_PNG = False
SAVE_PLOTS_AS_SVG = True

Data loading#

image_folder = DATA_PATH / "RESULTS/SPLITS/Analysis/dataset_splits"
c5 = imread(image_folder / "c5.tif")
c3 = imread(image_folder / "c3.tif")
visual = imread(image_folder / "visual.tif")

gt_dict = {
    "visual": visual,
    "c3": c3,
    "c5": c5,
}

Order data by model and split#

def find_images(path, split):
    return list(path.glob(f"*{split}*.tif"))

Supervised models#

Hide code cell source
visual_preds = {
    "Cellpose": {
        "1090": find_images(image_folder / "c1_5/cp", "1090")[0],
        "2080": find_images(image_folder / "c1_5/cp", "2080")[0],
        "6040": find_images(image_folder / "c1_5/cp", "6040")[0],
        "8020": find_images(image_folder / "c1_5/cp", "8020")[0],
    },
    "Cellpose - default": {
        "1090": find_images(image_folder / "c1_5/cp/default", "1090")[0],
        "2080": find_images(image_folder / "c1_5/cp/default", "2080")[0],
        "6040": find_images(image_folder / "c1_5/cp/default", "6040")[0],
        "8020": find_images(image_folder / "c1_5/cp/default", "8020")[0],
    },
    "StarDist - default": {
        "1090": find_images(image_folder / "c1_5/sd", "1090")[0],
        "2080": find_images(image_folder / "c1_5/sd", "2080")[0],
        "6040": find_images(image_folder / "c1_5/sd", "6040")[0],
        "8020": find_images(image_folder / "c1_5/sd", "8020")[0],
    },
    "StarDist": {
        "1090": find_images(image_folder / "c1_5/sd/tuned", "10")[0],
        "2080": find_images(image_folder / "c1_5/sd/tuned", "20")[0],
        "6040": find_images(image_folder / "c1_5/sd/tuned", "60")[0],
        "8020": find_images(image_folder / "c1_5/sd/tuned", "80")[0],
    },
    "SegResNet": {
        "1090": find_images(image_folder / "c1_5/segres", "1090")[0],
        "2080": find_images(image_folder / "c1_5/segres", "2080")[0],
        "6040": find_images(image_folder / "c1_5/segres", "6040")[0],
        "8020": find_images(image_folder / "c1_5/segres", "8020")[0],
    },
    "SwinUNetR": {
        "1090": find_images(image_folder / "c1_5/swin", "1090")[0],
        "2080": find_images(image_folder / "c1_5/swin", "2080")[0],
        "6040": find_images(image_folder / "c1_5/swin", "6040")[0],
        "8020": find_images(image_folder / "c1_5/swin", "8020")[0],
    },
}
c3_preds = {
    "Cellpose": {
        "1090": find_images(image_folder / "c1245_v/cp", "1090")[0],
        "2080": find_images(image_folder / "c1245_v/cp", "2080")[0],
        "6040": find_images(image_folder / "c1245_v/cp", "6040")[0],
        "8020" : find_images(image_folder / "c1245_v/cp", "8020")[0],
    },
    "Cellpose - default": {
        "1090": find_images(image_folder / "c1245_v/cp/default", "1090")[0],
        "2080": find_images(image_folder / "c1245_v/cp/default", "2080")[0],
        "6040": find_images(image_folder / "c1245_v/cp/default", "6040")[0],
        "8020": find_images(image_folder / "c1245_v/cp/default", "8020")[0],
    },
    "StarDist - default": {
        "1090": find_images(image_folder / "c1245_v/sd", "1090")[0],
        "2080": find_images(image_folder / "c1245_v/sd", "2080")[0],
        "6040": find_images(image_folder / "c1245_v/sd", "6040")[0],
        "8020": find_images(image_folder / "c1245_v/sd", "8020")[0],
    },
    "StarDist": {
        "1090": find_images(image_folder / "c1245_v/sd/tuned", "10")[0],
        "2080": find_images(image_folder / "c1245_v/sd/tuned", "20")[0],
        "6040": find_images(image_folder / "c1245_v/sd/tuned", "60")[0],
        "8020": find_images(image_folder / "c1245_v/sd/tuned", "80")[0],
    },
    "SegResNet": {
        "1090": find_images(image_folder / "c1245_v/segres", "1090")[0],
        "2080": find_images(image_folder / "c1245_v/segres", "2080")[0],
        "6040": find_images(image_folder / "c1245_v/segres", "6040")[0],
        "8020": find_images(image_folder / "c1245_v/segres", "8020")[0],
    },
    "SwinUNetR": {
        "1090": find_images(image_folder / "c1245_v/swin", "1090")[0],
        "2080": find_images(image_folder / "c1245_v/swin", "2080")[0],
        "6040": find_images(image_folder / "c1245_v/swin", "6040")[0],
        "8020": find_images(image_folder / "c1245_v/swin", "8020")[0],
    },
}
c5_preds = {
    "Cellpose": {
        "1090": find_images(image_folder / "c1-4_v/cp", "1090")[0],
        "2080": find_images(image_folder / "c1-4_v/cp", "2080")[0],
        "6040": find_images(image_folder / "c1-4_v/cp", "6040")[0],
        "8020" : find_images(image_folder / "c1-4_v/cp", "8020")[0],
    },
    "Cellpose - default": {
        "1090": find_images(image_folder / "c1-4_v/cp/default", "1090")[0],
        "2080": find_images(image_folder / "c1-4_v/cp/default", "2080")[0],
        "6040": find_images(image_folder / "c1-4_v/cp/default", "6040")[0],
        "8020": find_images(image_folder / "c1-4_v/cp/default", "8020")[0],
    },
    "StarDist - default": {
        "1090": find_images(image_folder / "c1-4_v/sd", "1090")[0],
        "2080": find_images(image_folder / "c1-4_v/sd", "2080")[0],
        "6040": find_images(image_folder / "c1-4_v/sd", "6040")[0],
        "8020": find_images(image_folder / "c1-4_v/sd", "8020")[0],
    },
    "StarDist": {
        "1090": find_images(image_folder / "c1-4_v/sd/tuned", "10")[0],
        "2080": find_images(image_folder / "c1-4_v/sd/tuned", "20")[0],
        "6040": find_images(image_folder / "c1-4_v/sd/tuned", "60")[0],
        "8020": find_images(image_folder / "c1-4_v/sd/tuned", "80")[0],
    },
    "SegResNet": {
        "1090": find_images(image_folder / "c1-4_v/segres", "1090")[0],
        "2080": find_images(image_folder / "c1-4_v/segres", "2080")[0],
        "6040": find_images(image_folder / "c1-4_v/segres", "6040")[0],
        "8020": find_images(image_folder / "c1-4_v/segres", "8020")[0],
    },
    "SwinUNetR": {
        "1090": find_images(image_folder / "c1-4_v/swin", "1090")[0],
        "2080": find_images(image_folder / "c1-4_v/swin", "2080")[0],
        "6040": find_images(image_folder / "c1-4_v/swin", "6040")[0],
        "8020": find_images(image_folder / "c1-4_v/swin", "8020")[0],
    },
}
# organize as DataFrame
df = pd.DataFrame()
splits = [visual_preds, c3_preds, c5_preds]
for i, gt_name in enumerate(gt_dict):
    preds = splits[i]
    for model_name in preds:
        for split in preds[model_name]:
            gt = gt_dict[gt_name]
            df = pd.concat([df, pd.DataFrame({
                "model": model_name,
                "split": split[:2] + "/" + split[2:],
                "gt": gt_name,
                "path": [preds[model_name][split]],
            })])
df.reset_index(inplace=True, drop=True)
df
model split gt path
0 Cellpose 10/90 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
1 Cellpose 20/80 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
2 Cellpose 60/40 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
3 Cellpose 80/20 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
4 Cellpose - default 10/90 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
... ... ... ... ...
67 SegResNet 80/20 c5 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
68 SwinUNetR 10/90 c5 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
69 SwinUNetR 20/80 c5 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
70 SwinUNetR 60/40 c5 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
71 SwinUNetR 80/20 c5 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...

72 rows × 4 columns

df.iloc[0].path # check if the paths are correct
WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/SPLITS/Analysis/dataset_splits/c1_5/cp/cellpose_1090_labels.tif')

Add pretrained WNet3D splits to the data#

df.loc[len(df)] = ["WNet3D - Pretrained", "WNet - Artifacts", "visual", image_folder / "WNet/pretrained/visual_pred.tif"]
df.loc[len(df)] = ["WNet3D - Pretrained", "WNet - Artifacts", "c3", image_folder / "WNet/pretrained/c3_pred.tif"]
df.loc[len(df)] = ["WNet3D - Pretrained", "WNet - Artifacts", "c5", image_folder / "WNet/pretrained/c5_pred.tif"]
df
model split gt path
0 Cellpose 10/90 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
1 Cellpose 20/80 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
2 Cellpose 60/40 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
3 Cellpose 80/20 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
4 Cellpose - default 10/90 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
... ... ... ... ...
70 SwinUNetR 60/40 c5 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
71 SwinUNetR 80/20 c5 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
72 WNet3D - Pretrained WNet - Artifacts visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
73 WNet3D - Pretrained WNet - Artifacts c3 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
74 WNet3D - Pretrained WNet - Artifacts c5 C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...

75 rows × 4 columns

WNet3D - Order by splits#

wnet_splits_preds = {
    "1090" : find_images(image_folder / "WNet/10", "pred"),
    "2080" : find_images(image_folder / "WNet/20", "pred"),
    "6040" : find_images(image_folder / "WNet/60", "pred"),
    "8020" : find_images(image_folder / "WNet/80", "pred"),
} # these are lists where each element is an image for a split (3 per split)
for split, image_paths in wnet_splits_preds.items():
    split = split[:2] + "/" + split[2:]
    for i, path in enumerate(image_paths):
        df.loc[len(df)] = ["WNet3D", f"WNet - {split}", "visual", path]
df
model split gt path
0 Cellpose 10/90 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
1 Cellpose 20/80 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
2 Cellpose 60/40 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
3 Cellpose 80/20 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
4 Cellpose - default 10/90 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
... ... ... ... ...
82 WNet3D WNet - 60/40 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
83 WNet3D WNet - 60/40 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
84 WNet3D WNet - 80/20 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
85 WNet3D WNet - 80/20 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...
86 WNet3D WNet - 80/20 visual C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK\...

87 rows × 4 columns

model_names = ["StarDist - default", "StarDist", "Cellpose - default", "Cellpose", "SegResNet", "SwinUNetR", "WNet3D", "WNet3D - Pretrained"]

Note

For WNet3D, we compute two sets of Dices :

  • One on c3, c5 and ALL of visual, which contains some artifacts. This is “All data”.

  • One on c3, c5 and PART of visual, which does not contain artifacts. This is “No artifacts”.

visual_slice = 50
visual_gt_no_artifact = gt_dict["visual"][:visual_slice,:,:]
wnet_c3_pred = imread(image_folder / "WNet/pretrained/c3_pred.tif")
wnet_c5_pred = imread(image_folder / "WNet/pretrained/c5_pred.tif")
wnet_visual_no_artifact = imread(image_folder / "WNet/pretrained/visual_pred.tif")[:visual_slice,:,:]
path_df = df.copy()

Dice computation#

Pre-processing for SwinUNetR and SegResNet#

# We use the same threshold for all splits, estimated from the training data from the supervised benchmark figure 
swin_params = {
    "thresh": 0.4,
    "spot_sigma": 0.65,
    "outline_sigma": 0.65,
}
segres_params = {
    "thresh": 0.3,
    "spot_sigma": 0.65,
    "outline_sigma": 0.65,
}
wnet_params = {
    "thresh": 0.6,
    "spot_sigma": 0.65,
    "outline_sigma": 0.65,
}

def models_instance_preprocessing(volume, params):
    volume = np.where(volume > params["thresh"], 1, 0)
    labels = cle.voronoi_otsu_labeling(volume, spot_sigma=params["spot_sigma"], outline_sigma=params["outline_sigma"])
    labels = remove_small_objects(np.array(labels), min_size=5)
    return labels

def wnet_preprocessing(volume, channel_foreground, params):
    # this should only be done on volumes with more than 3 dimensions
    if len(volume.shape) < 4:
        return volume
    volume = volume[channel_foreground]
    volume = np.where(volume > params["thresh"], 1, 0)
    labels = cle.voronoi_otsu_labeling(volume, spot_sigma=params["spot_sigma"], outline_sigma=params["outline_sigma"])
    return labels
# import napari
# viewer = napari.Viewer()
# # show visual gt and swin preprocessed for each split
# swin_path = df.iloc[9]
# swin_pred = imread(swin_path.path)
# viewer.add_labels(gt_dict["c3"], name="c3_gt")
# swin_pred_processed = models_instance_preprocessing(swin_pred, swin_params)
# viewer.add_labels(swin_pred_processed, name="swin_pred")
# viewer.add_image(np.swapaxes(swin_pred), name="swin_pred_raw", colormap="turbo")
df.set_index(
    [
    "model",
    # "split"
    ],
    inplace=True
             )

Dice score computation#

def model_dices_across_splits(df, verbose=False):
    dices_df = pd.DataFrame(columns=["Dice", "Model", "Split", "GT"])
    for model in model_names:
        for split in df.loc[model]["split"].unique():
            for i, row in df.loc[model][df.loc[model]["split"] == split].iterrows():
                gt = gt_dict[row["gt"]]
                pred = imread(row.path)
                if model == "SwinUNetR":
                    pred = models_instance_preprocessing(pred, swin_params)
                if model == "SegResNet":
                    pred = models_instance_preprocessing(pred, segres_params)
                if model == "WNet3D":
                    # look into path for foreground channel : c0 indicates first channel is foreground, c1 indicates second channel is foreground
                    channel = 1 if "c1" in str(row.path) else 0
                    pred = wnet_preprocessing(pred, channel, wnet_params)
                    if row["gt"] == "visual":
                        pred = pred[:visual_slice,:,:]
                        gt = gt[:visual_slice,:,:]
                if verbose:
                    print(f"Model: {model}, Split: {split}, GT: {row['gt']}")
                    print(f"Image shape: {pred.shape}")
                    print(f"GT shape: {gt_dict[row['gt']].shape}")
                gt = np.where(gt > 0, 1, 0)
                pred = np.where(pred > 0, 1, 0)
                dice = dice_coeff(gt, pred)
                if verbose:
                    print(f"Dice: {dice}")
                    print("_"*20)
                dices_df.loc[len(dices_df)] = [dice, model, split, row["gt"]]
    return dices_df
dices_df = model_dices_across_splits(df, verbose=True)
Model: StarDist - default, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.45813969740382265
____________________
Model: StarDist - default, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 1.7072712683318253e-05
____________________
Model: StarDist - default, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7047307334629646
____________________
Model: StarDist - default, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6900064226075787
____________________
Model: StarDist - default, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 1.7072712683318253e-05
____________________
Model: StarDist - default, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 2.441108263151471e-05
____________________
Model: StarDist - default, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6491599555678979
____________________
Model: StarDist - default, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.5219389414443077
____________________
Model: StarDist - default, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.725773126793911
____________________
Model: StarDist - default, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6641397223887263
____________________
Model: StarDist - default, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.5228543567663311
____________________
Model: StarDist - default, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.6306411323896752
____________________
Model: StarDist, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7079987191802753
____________________
Model: StarDist, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7066773429884418
____________________
Model: StarDist, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.666873158778668
____________________
Model: StarDist, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6870511156437357
____________________
Model: StarDist, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 1.7072712683318253e-05
____________________
Model: StarDist, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.6354648592898188
____________________
Model: StarDist, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.698111261522594
____________________
Model: StarDist, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7395951257585239
____________________
Model: StarDist, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7753721081706993
____________________
Model: StarDist, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6634502993726294
____________________
Model: StarDist, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.6909480066618091
____________________
Model: StarDist, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7941101315957583
____________________
Model: Cellpose - default, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6095447577151236
____________________
Model: Cellpose - default, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 1.7072712683318253e-05
____________________
Model: Cellpose - default, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 2.441108263151471e-05
____________________
Model: Cellpose - default, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.38262849366691093
____________________
Model: Cellpose - default, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.41333207922388543
____________________
Model: Cellpose - default, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.4346342300432785
____________________
Model: Cellpose - default, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.3216480826760946
____________________
Model: Cellpose - default, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.46105638941004795
____________________
Model: Cellpose - default, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.28778826871657753
____________________
Model: Cellpose - default, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.5324423418095802
____________________
Model: Cellpose - default, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.3586485504207554
____________________
Model: Cellpose - default, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.4212584018642991
____________________
Model: Cellpose, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7861751295015355
____________________
Model: Cellpose, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.08595350922131148
____________________
Model: Cellpose, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.08457272790551985
____________________
Model: Cellpose, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7072484308564487
____________________
Model: Cellpose, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7293200647090544
____________________
Model: Cellpose, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7463728863187176
____________________
Model: Cellpose, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6880578161562196
____________________
Model: Cellpose, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7393016007796444
____________________
Model: Cellpose, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.6729493389605614
____________________
Model: Cellpose, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7981322424499242
____________________
Model: Cellpose, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7286757739677883
____________________
Model: Cellpose, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.756078768391886
____________________
Model: SegResNet, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.4221998199819982
____________________
Model: SegResNet, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.28690334887141233
____________________
Model: SegResNet, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.5435601236412202
____________________
Model: SegResNet, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.5148477010728031
____________________
Model: SegResNet, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.4388548365803519
____________________
Model: SegResNet, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.751233228097869
____________________
Model: SegResNet, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7747215240932079
____________________
Model: SegResNet, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8371491990562523
____________________
Model: SegResNet, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.6906855807000525
____________________
Model: SegResNet, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7061559736831078
____________________
Model: SegResNet, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8270254509885067
____________________
Model: SegResNet, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7074302192549407
____________________
Model: SwinUNetR, Split: 10/90, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.8157865498468327
____________________
Model: SwinUNetR, Split: 10/90, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8174437462975013
____________________
Model: SwinUNetR, Split: 10/90, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.7018095117791954
____________________
Model: SwinUNetR, Split: 20/80, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7901964762246602
____________________
Model: SwinUNetR, Split: 20/80, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.7719215148797227
____________________
Model: SwinUNetR, Split: 20/80, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.760966675600822
____________________
Model: SwinUNetR, Split: 60/40, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.8177856932892985
____________________
Model: SwinUNetR, Split: 60/40, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8450172476830877
____________________
Model: SwinUNetR, Split: 60/40, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.8482937883231139
____________________
Model: SwinUNetR, Split: 80/20, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.8279958413580352
____________________
Model: SwinUNetR, Split: 80/20, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8281666840544286
____________________
Model: SwinUNetR, Split: 80/20, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.8328909913502661
____________________
Model: WNet3D, Split: WNet - 10/90, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7501949364268551
____________________
Model: WNet3D, Split: WNet - 10/90, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7432460865174213
____________________
Model: WNet3D, Split: WNet - 10/90, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7162874698995267
____________________
Model: WNet3D, Split: WNet - 20/80, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7421859482493997
____________________
Model: WNet3D, Split: WNet - 20/80, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7438786658383496
____________________
Model: WNet3D, Split: WNet - 20/80, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7254014598540146
____________________
Model: WNet3D, Split: WNet - 60/40, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7340520280550711
____________________
Model: WNet3D, Split: WNet - 60/40, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7233859587228788
____________________
Model: WNet3D, Split: WNet - 60/40, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.674949028745547
____________________
Model: WNet3D, Split: WNet - 80/20, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7368010056376657
____________________
Model: WNet3D, Split: WNet - 80/20, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.7209019295450522
____________________
Model: WNet3D, Split: WNet - 80/20, GT: visual
Image shape: (50, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.6764699294433867
____________________
Model: WNet3D - Pretrained, Split: WNet - Artifacts, GT: visual
Image shape: (65, 127, 214)
GT shape: (65, 127, 214)
Dice: 0.5896556737502726
____________________
Model: WNet3D - Pretrained, Split: WNet - Artifacts, GT: c3
Image shape: (149, 105, 147)
GT shape: (149, 105, 147)
Dice: 0.8113536341409202
____________________
Model: WNet3D - Pretrained, Split: WNet - Artifacts, GT: c5
Image shape: (124, 86, 94)
GT shape: (124, 86, 94)
Dice: 0.8087545264446371
____________________
dices_df
Dice Model Split GT
0 0.458140 StarDist - default 10/90 visual
1 0.000017 StarDist - default 10/90 c3
2 0.704731 StarDist - default 10/90 c5
3 0.690006 StarDist - default 20/80 visual
4 0.000017 StarDist - default 20/80 c3
... ... ... ... ...
82 0.720902 WNet3D WNet - 80/20 visual
83 0.676470 WNet3D WNet - 80/20 visual
84 0.589656 WNet3D - Pretrained WNet - Artifacts visual
85 0.811354 WNet3D - Pretrained WNet - Artifacts c3
86 0.808755 WNet3D - Pretrained WNet - Artifacts c5

87 rows × 4 columns

WNet3D - No artifacts#

wnet_preds_no_artifacts = { # No artifacts
    "WNet3D - No artifacts": {
        "visual": wnet_visual_no_artifact,
        "c3": wnet_c3_pred,
        "c5": wnet_c5_pred,
    }
}
gt_dict_no_artifacts = {
    "visual": visual_gt_no_artifact,
    "c3": c3,
    "c5": c5,
}
for split in wnet_preds_no_artifacts:
    for image in wnet_preds_no_artifacts[split]:
        gt = gt_dict_no_artifacts[image]
        pred = wnet_preds_no_artifacts[split][image]
        gt = np.where(gt > 0, 1, 0)
        pred = np.where(pred > 0, 1, 0)
        dice = dice_coeff(gt, pred)
        print(f"Split: {split}, Image: {image}, Dice: {dice}")
        dices_df.loc[len(dices_df)] = [dice, "WNet3D - Pretrained", split, image]
Split: WNet3D - No artifacts, Image: visual, Dice: 0.8179572126452918
Split: WNet3D - No artifacts, Image: c3, Dice: 0.8113536341409202
Split: WNet3D - No artifacts, Image: c5, Dice: 0.8087545264446371

Detailed stats#

Shows means and stds of select splits for manuscript. This is the data was used in the text of the manuscript.

dice_df_8020 = dices_df[dices_df["Split"] == "80/20"]
dice_df_8020.groupby("Model", sort=False).agg({"Dice": ["mean", "std"]}).sort_values(("Dice", "mean"), ascending=False)
Dice
mean std
Model
SwinUNetR 0.829685 0.002778
Cellpose 0.760962 0.034985
SegResNet 0.746871 0.069419
StarDist 0.716169 0.068885
StarDist - default 0.605878 0.073826
Cellpose - default 0.437450 0.088021
dices_df_1090 = dices_df[dices_df["Split"] == "10/90"]
dices_df_1090.groupby("Model", sort=False).agg({"Dice": ["mean", "std"]}).sort_values(("Dice", "mean"), ascending=False)
Dice
mean std
Model
SwinUNetR 0.778347 0.066288
StarDist 0.693850 0.023372
SegResNet 0.417554 0.128391
StarDist - default 0.387629 0.357609
Cellpose 0.318900 0.404672
Cellpose - default 0.203195 0.351909
dice_df_wnet = dices_df[dices_df["Model"] == "WNet3D - Pretrained"]
dice_df_wnet.groupby("Split").agg({"Dice": ["mean", "std"]})
Dice
mean std
Split
WNet - Artifacts 0.736588 0.127254
WNet3D - No artifacts 0.812688 0.004744
dice_df_wnet_trained_on_subsets = dices_df[dices_df["Model"] == "WNet3D"]
dice_df_wnet_trained_on_subsets.groupby("Split").agg({"Dice": ["mean", "std"]})
Dice
mean std
Split
WNet - 10/90 0.736576 0.017911
WNet - 20/80 0.737155 0.010214
WNet - 60/40 0.710796 0.031499
WNet - 80/20 0.711391 0.031270

Plots#

dices_df['Model_Split'] = dices_df['Model'] + ' (' + dices_df['Split'].astype(str) + ')'

Due to the way seaborn boxplots interact with categories, the axes will be generated for all models/splits, and then separate boxplots are used to have readable boxes for each model.

The data is the same everywhere, this is only to have cleaner figures. The boxplots were then re-assembled separately (while being careful of the axes) to have the final figure.

~~In addition, since SwinUNetR and WNet have much smaller variance, we show a zoomed-in inset for these models.~~ Deprecated.

General plot#

All models (hard to read)#

fig, ax = plt.subplots(figsize=(9, 6), dpi=DPI)
sns.boxplot(
    data=dices_df,
    x="Model", 
    y="Dice", 
    hue="Model_Split", 
    ax=ax, 
    palette=EXPANDED_COLORMAP,
    # dodge=False,
    )
for i, artist in enumerate(ax.artists): # try to center the boxplot on xticks
    for j in range(i*6,i*6+6):
        line = ax.lines[j]
        line.set_color(artist.get_facecolor())
        if j % 6 == 4: 
            x, y = line.get_xydata()[0]
            x_center = i // 2 
            if i % 2: 
                line.set_xdata([x_center + 0.2, x_center + 0.2, x, x_center + 0.2])
            else:  
                line.set_xdata([x_center - 0.2, x_center - 0.2, x, x_center - 0.2])

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.tick_params(axis="both", which="major", labelsize=LEGEND_FONT_SIZE)
ax.tick_params(axis = 'x',   rotation = 65)
plt.ylim([-0.02,1])
ax.set_yticks(np.arange(0,1.1,0.1))
sns.despine(
    left=False,
    right=True,
    bottom=False,
    top=True,
    trim=True,
    offset={"bottom": 40, "left": 15},
)
legend = ax.legend(fontsize=LEGEND_FONT_SIZE, bbox_to_anchor=BBOX_TO_ANCHOR, loc=LOC)
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
legend.get_frame().set_alpha(0)
ax.set_xlabel("", fontsize=LABEL_FONT_SIZE)
ax.set_ylabel("Dice coefficient", fontsize=LABEL_FONT_SIZE)
plt.legend([],[], frameon=False)
plt.show()

if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency.svg", bbox_inches="tight")
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\518262752.py:2: UserWarning: 
The palette list has fewer values (29) than needed (30) and will cycle, which may produce an uninterpretable plot.
  sns.boxplot(
../_images/48ad77d51decb02b0fc83cdb0d2d8de98491b0c8ecb271f4b58d354f293e9ba8.png

Legend#

fig_leg = plt.figure(figsize=(3, 2))
ax_leg = fig_leg.add_subplot(111)
ax_leg.legend(handles=legend.legend_handles, labels=[text.get_text() for text in legend.texts])
ax_leg.axis('off')
plt.show()
if SAVE_PLOTS_AS_PNG:
    fig_leg.savefig("Label_efficiency_legend.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig_leg.savefig("Label_efficiency_legend.svg", bbox_inches="tight")
../_images/c0a691aee4c21b90fb2c565c4110b4a81dfe2362123e21d40d3a914fbd85edbe.png

Self-supervised models#

fig, ax = plt.subplots(figsize=(6, 6), dpi=DPI)
wnet_models = ["WNet3D", "WNet3D - Pretrained"]
dices_df_selfsupervised = dices_df[dices_df["Model"].isin(wnet_models)]
sns.boxplot(
    data=dices_df_selfsupervised,
    x="Model", 
    y="Dice", 
    hue="Model_Split", 
    ax=ax, 
    palette=EXPANDED_COLORMAP[20:]
    # dodge=False,
    )
for i, artist in enumerate(ax.artists): # try to center the boxplot on xticks
    for j in range(i*6,i*6+6):
        line = ax.lines[j]
        line.set_color(artist.get_facecolor())
        if j % 6 == 4: 
            x, y = line.get_xydata()[0]
            x_center = i // 2 
            if i % 2: 
                line.set_xdata([x_center + 0.2, x_center + 0.2, x, x_center + 0.2])
            else:  
                line.set_xdata([x_center - 0.2, x_center - 0.2, x, x_center - 0.2])

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.tick_params(axis="both", which="major", labelsize=LEGEND_FONT_SIZE)
ax.tick_params(axis = 'x',   rotation = 65)
plt.ylim([-0.02,1])
ax.set_yticks(np.arange(0,1.1,0.1))
sns.despine(
    left=False,
    right=True,
    bottom=False,
    top=True,
    trim=True,
    offset={"bottom": 40, "left": 15},
)
legend = ax.legend(fontsize=LEGEND_FONT_SIZE, bbox_to_anchor=BBOX_TO_ANCHOR, loc=LOC)
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
legend.get_frame().set_alpha(0)
ax.set_xlabel("", fontsize=LABEL_FONT_SIZE)
ax.set_ylabel("Dice coefficient", fontsize=LABEL_FONT_SIZE)
plt.legend([],[], frameon=False)
plt.show()

if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency_selfsup.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency_selfsup.svg", bbox_inches="tight")
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\2984398448.py:4: UserWarning: The palette list has more values (9) than needed (6), which may not be intended.
  sns.boxplot(
../_images/7f5b6211aba55cb44c2b5c033d19a8379da03cbcbf69e4504fb906503f00b8ac.png

Boxplots for each model#

models_dfs = [dices_df[dices_df["Model"] == model].copy() for model in model_names]
save_path = Path("label_efficiency")
save_path.mkdir(exist_ok=True)
for i, model in enumerate(model_names):
    print(model)
    # figsize_width = 1 if model != "WNet - Ours" else 0.33 # does not accept float values
    fig, ax = plt.subplots(figsize=(1, 6), dpi=DPI)
    df = models_dfs[i]
    sns.boxplot(
        data=df,
        x="Model_Split", 
        y="Dice", 
        hue="Split", 
        ax=ax, 
        palette=EXPANDED_COLORMAP[0+4*i:4+4*i],
    )
    ax.tick_params(axis = 'x',   rotation = 45)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.tick_params(axis="both", which="major", labelsize=LEGEND_FONT_SIZE)
    ax.set_yticks(np.arange(0,1.1,0.1))
    ax.set_ylim([-0.02,1])
    sns.despine(
        left=False,
        right=True,
        bottom=False,
        top=True,
        trim=True,
        offset={"bottom": 40, "left": 15},
        ax=ax
    )
    ax.set_xticklabels("", fontsize=LABEL_FONT_SIZE)
    ax.set_xlabel("", fontsize=LABEL_FONT_SIZE)
    ax.set_ylabel("Dice coefficient", fontsize=LABEL_FONT_SIZE)
    ax.legend([],[], frameon=False)
    ax.patch.set_alpha(0)
    legend = ax.get_legend()
    if legend:
        legend.remove()
    # ax.set_title(model, fontsize=LABEL_FONT_SIZE)
    fig.patch.set_alpha(0)
    plt.show()
    if SAVE_PLOTS_AS_PNG:
        fig.savefig(f"{str(save_path)}/Label_efficiency_{model}.png", dpi=DPI, bbox_inches="tight")
    if SAVE_PLOTS_AS_SVG:
        fig.savefig(f"{str(save_path)}/Label_efficiency_{model}.svg", bbox_inches="tight")
StarDist - default
../_images/3a800b8350b2511c380f26893f1989c698ee570f28cfdb0d73bc380b872d1409.png
StarDist
../_images/f31292683da0c3031cc8670e623a8632c80425f1878d9a0a7e957c95ec3a0e9c.png
Cellpose - default
../_images/36d2b04488afd7f99fed28b1f896f6b8ab733dc7a3aadca8f76d0c72be40a951.png
Cellpose
../_images/15afddd0e7e9ad69083e471d081850f4c90854caf6595589ded5d29f54b66ab8.png
SegResNet
../_images/a1ca4fb3f8215e6c2900e1fe5da45f13382cf946b2ee6b678c545408fe51fa68.png
SwinUNetR
../_images/4887f1f95cc1a4f9a626388fc005d16f2302535f3f0c72b89f1ed71e69dc8225.png
WNet3D
../_images/7aab5a6befc9c9c609f450b88a3508a8c602773bb483c8f589072f62bfbcdb42.png
WNet3D - Pretrained
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\1066515323.py:9: UserWarning: 
The palette list has fewer values (1) than needed (2) and will cycle, which may produce an uninterpretable plot.
  sns.boxplot(
../_images/9b567da15a2f7f0b93d2523dc02241920e1c4030116b478a942752808a09d1b9.png

Zoom-in for rightmost models#

# same as above with Swin and WNet and y axis on the right
fig, ax = plt.subplots(figsize=(3, 6), dpi=DPI)
zoom_dices_df = dices_df[dices_df["Model"].isin(["SwinUNetR", "WNet3D"])]
sns.boxplot(data=zoom_dices_df, x="Model", y="Dice", hue="Split", ax=ax, palette=EXPANDED_COLORMAP[16:], dodge=True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.tick_params(axis="both", which="major", labelsize=LEGEND_FONT_SIZE)
ax.tick_params(axis = 'x',   rotation = 45,)
# ax.set_ylim([0.8,0.85])
# ax.set_yticks(np.arange(0.8,0.86,0.01))
ax.set_ylim([0.55,0.95])
ax.set_yticks(np.arange(0.55,0.96,0.05))
ax.set_xlabel("", fontsize=LABEL_FONT_SIZE)
ax.set_ylabel("Dice coefficient", fontsize=LABEL_FONT_SIZE)
ax.legend_.remove()
sns.despine(
    left=False,
    right=True,
    bottom=False,
    top=True,
    trim=True,
    offset={"bottom": 40, "right": 15},
)
# legend = ax.legend(fontsize=LEGEND_FONT_SIZE, bbox_to_anchor=BBOX_TO_ANCHOR, loc=LOC)
# legend.get_frame().set_alpha(0)
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency_SWIN_WNET_zoom.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency_SWIN_WNET_zoom.svg", bbox_inches="tight")
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\1916514661.py:4: UserWarning: The palette list has more values (13) than needed (8), which may not be intended.
  sns.boxplot(data=zoom_dices_df, x="Model", y="Dice", hue="Split", ax=ax, palette=EXPANDED_COLORMAP[16:], dodge=True)
../_images/2728df63e1c45e85ef928a76d6116c0ddfd43602e62b182d2269a66fbbdc8a7f.png

Statistical tests#

Here we run a Kruskal-Wallis test to assess if all models differ from each other, when performance is averaged across splits.

Performance across models#

dices_df
Dice Model Split GT Model_Split
0 0.458140 StarDist - default 10/90 visual StarDist - default (10/90)
1 0.000017 StarDist - default 10/90 c3 StarDist - default (10/90)
2 0.704731 StarDist - default 10/90 c5 StarDist - default (10/90)
3 0.690006 StarDist - default 20/80 visual StarDist - default (20/80)
4 0.000017 StarDist - default 20/80 c3 StarDist - default (20/80)
... ... ... ... ... ...
85 0.811354 WNet3D - Pretrained WNet - Artifacts c3 WNet3D - Pretrained (WNet - Artifacts)
86 0.808755 WNet3D - Pretrained WNet - Artifacts c5 WNet3D - Pretrained (WNet - Artifacts)
87 0.817957 WNet3D - Pretrained WNet3D - No artifacts visual WNet3D - Pretrained (WNet3D - No artifacts)
88 0.811354 WNet3D - Pretrained WNet3D - No artifacts c3 WNet3D - Pretrained (WNet3D - No artifacts)
89 0.808755 WNet3D - Pretrained WNet3D - No artifacts c5 WNet3D - Pretrained (WNet3D - No artifacts)

90 rows × 5 columns

Boxplot of model performance across all splits#

from plots import _format_plot

fig, ax = plt.subplots(figsize=(8, 8), dpi=DPI)
test_df = dices_df[dices_df["Split"] != "WNet - Artifacts"]
test_df = test_df.groupby("Model", sort=False).Dice.apply(list).reset_index()
sns.boxplot(data=dices_df, hue="Model", y="Dice", palette=temp_cmap, ax=ax)
_format_plot(ax, xlabel="Model", ylabel="Dice coefficient", title="Dice coefficient across splits")
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
plt.show()

if SAVE_PLOTS_AS_PNG:
    fig.savefig("Label_efficiency_all_models_pooled.png", dpi=DPI)
if SAVE_PLOTS_AS_SVG:
    fig.savefig("Label_efficiency_all_models_pooled.svg", bbox_inches="tight")
C:\Users\Cyril\AppData\Local\Temp\ipykernel_8904\1582456374.py:6: UserWarning: The palette list has more values (10) than needed (8), which may not be intended.
  sns.boxplot(data=dices_df, hue="Model", y="Dice", palette=temp_cmap, ax=ax)
../_images/4fed5a64955c8fc7a9122962e4543fc8d50331d7937113d5fac472907aa438a1.png
print(kruskal(*test_df.Dice.tolist()))
KruskalResult(statistic=49.212286777549934, pvalue=2.0617114022093387e-08)

p-values heatmap - Default notation#

from scikit_posthocs import posthoc_conover
import matplotlib.colors as colors
import matplotlib.patches as mpatches

fig, ax = plt.subplots(figsize=(8, 8), dpi=DPI)
results = posthoc_conover(dices_df, val_col="Dice", group_col="Model", p_adjust="holm")

low_color = COLORMAP[0]
mid_color = COLORMAP[1]
high_color = COLORMAP[2]
equals_1_color = COLORMAP[3]

levels = [0, 0.05, 0.06, 0.99, 1]
colors_list = [low_color, mid_color, high_color, high_color, equals_1_color]
cmap = colors.LinearSegmentedColormap.from_list("", list(zip(levels, colors_list)))
norm = colors.Normalize(vmin=0, vmax=1)

diag = np.diag_indices(results.shape[0])
results.values[diag] = np.nan
sns.heatmap(results, annot=True, cmap=cmap, norm=norm, cbar=False, fmt=".4f", linewidths=0.5, ax=ax)

patches = [mpatches.Patch(color=low_color, label='Above 0 (significant)'),
           mpatches.Patch(color=mid_color, label='Below 0.05 (significant)'),
           mpatches.Patch(color=high_color, label='Above 0.05 (not significant)'),
           mpatches.Patch(color=equals_1_color, label='Equals 1 (not significant)')]

ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
# transparent background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

if SAVE_PLOTS_AS_PNG:
    plt.savefig("Label_efficiency_posthoc_conover.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("Label_efficiency_posthoc_conover.svg", bbox_inches="tight")
../_images/659c39895ee78a5e1d57aa943a3c3772dbffd823cf85fc68dfd7267cdb174a3b.png

p-values heatmap - Scientific notation#

from scikit_posthocs import posthoc_conover
import matplotlib.colors as colors
import matplotlib.patches as mpatches

fig, ax = plt.subplots(figsize=(8, 8), dpi=DPI)
results = posthoc_conover(dices_df, val_col="Dice", group_col="Model", p_adjust="holm")

low_color = COLORMAP[0]
mid_color = COLORMAP[1]
high_color = COLORMAP[2]
equals_1_color = COLORMAP[3]

levels = [0, 0.05, 0.06, 0.99, 1]
colors_list = [low_color, mid_color, high_color, high_color, equals_1_color]
cmap = colors.LinearSegmentedColormap.from_list("", list(zip(levels, colors_list)))
norm = colors.Normalize(vmin=0, vmax=1)

diag = np.diag_indices(results.shape[0])
results.values[diag] = np.nan
sns.heatmap(results, annot=True, cmap=cmap, norm=norm, cbar=False, fmt=".2e", linewidths=0.5, ax=ax)

patches = [mpatches.Patch(color=low_color, label='Above 0 (significant)'),
           mpatches.Patch(color=mid_color, label='Below 0.05 (significant)'),
           mpatches.Patch(color=high_color, label='Above 0.05 (not significant)'),
           mpatches.Patch(color=equals_1_color, label='Equals 1 (not significant)')]

ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
# transparent background
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

if SAVE_PLOTS_AS_PNG:
    plt.savefig("Label_efficiency_posthoc_conover_sci.png", dpi=DPI, bbox_inches="tight")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("Label_efficiency_posthoc_conover_sci.svg", bbox_inches="tight")
../_images/16c8432176c08653a8c455e08a43b35701452c1193639aa35578e3da7b5a5dc9.png

Unused#

test = dices_df.copy()
test["Split"].apply(lambda x: x.split("/")[0][-3:]).values
array(['10', '10', '10', '20', '20', '20', '60', '60', '60', '80', '80',
       '80', '10', '10', '10', '20', '20', '20', '60', '60', '60', '80',
       '80', '80', '10', '10', '10', '20', '20', '20', '60', '60', '60',
       '80', '80', '80', '10', '10', '10', '20', '20', '20', '60', '60',
       '60', '80', '80', '80', '10', '10', '10', '20', '20', '20', '60',
       '60', '60', '80', '80', '80', '10', '10', '10', '20', '20', '20',
       '60', '60', '60', '80', '80', '80', ' 10', ' 10', ' 10', ' 20',
       ' 20', ' 20', ' 60', ' 60', ' 60', ' 80', ' 80', ' 80', 'cts',
       'cts', 'cts', 'cts', 'cts', 'cts'], dtype=object)
dices_df
Dice Model Split GT Model_Split
0 0.458140 StarDist - default 10/90 visual StarDist - default (10/90)
1 0.000017 StarDist - default 10/90 c3 StarDist - default (10/90)
2 0.704731 StarDist - default 10/90 c5 StarDist - default (10/90)
3 0.690006 StarDist - default 20/80 visual StarDist - default (20/80)
4 0.000017 StarDist - default 20/80 c3 StarDist - default (20/80)
... ... ... ... ... ...
85 0.811354 WNet3D - Pretrained WNet - Artifacts c3 WNet3D - Pretrained (WNet - Artifacts)
86 0.808755 WNet3D - Pretrained WNet - Artifacts c5 WNet3D - Pretrained (WNet - Artifacts)
87 0.817957 WNet3D - Pretrained WNet3D - No artifacts visual WNet3D - Pretrained (WNet3D - No artifacts)
88 0.811354 WNet3D - Pretrained WNet3D - No artifacts c3 WNet3D - Pretrained (WNet3D - No artifacts)
89 0.808755 WNet3D - Pretrained WNet3D - No artifacts c5 WNet3D - Pretrained (WNet3D - No artifacts)

90 rows × 5 columns

import statsmodels.api as sm
import statsmodels.formula.api as smf

# Fit a linear mixed effects model
mlm_df = dices_df.copy()
mlm_df["Data_percentage"] = mlm_df["Split"].apply(lambda x: x.split("/")[0][-3:])
# replace "cts" by 100 in the Data_percentage column
mlm_df["Data_percentage"] = mlm_df["Data_percentage"].replace("cts", 100)
mlm_df["Data_percentage"] = mlm_df["Data_percentage"].astype(int)
mlm_df
Dice Model Split GT Model_Split Data_percentage
0 0.458140 StarDist - default 10/90 visual StarDist - default (10/90) 10
1 0.000017 StarDist - default 10/90 c3 StarDist - default (10/90) 10
2 0.704731 StarDist - default 10/90 c5 StarDist - default (10/90) 10
3 0.690006 StarDist - default 20/80 visual StarDist - default (20/80) 20
4 0.000017 StarDist - default 20/80 c3 StarDist - default (20/80) 20
... ... ... ... ... ... ...
85 0.811354 WNet3D - Pretrained WNet - Artifacts c3 WNet3D - Pretrained (WNet - Artifacts) 100
86 0.808755 WNet3D - Pretrained WNet - Artifacts c5 WNet3D - Pretrained (WNet - Artifacts) 100
87 0.817957 WNet3D - Pretrained WNet3D - No artifacts visual WNet3D - Pretrained (WNet3D - No artifacts) 100
88 0.811354 WNet3D - Pretrained WNet3D - No artifacts c3 WNet3D - Pretrained (WNet3D - No artifacts) 100
89 0.808755 WNet3D - Pretrained WNet3D - No artifacts c5 WNet3D - Pretrained (WNet3D - No artifacts) 100

90 rows × 6 columns

sns.scatterplot(data=mlm_df, x="Data_percentage", y="Dice", hue="Model", style="GT")
<Axes: xlabel='Data_percentage', ylabel='Dice'>
../_images/f9d0b0440cc4bb824d1d045a0c15e9843078a4cbc5e4c46526a75407d46836e7.png
sns.scatterplot(data=mlm_df, x="Data_percentage", y="Dice", hue="GT", legend=False)
<Axes: xlabel='Data_percentage', ylabel='Dice'>
../_images/910e3ee84eb5eb67af5ec23e120a8646718ad8cfe45bfb2b2356f0a33926575f.png
model = smf.mixedlm("Dice ~ Data_percentage + C(Model) + C(GT)", mlm_df.copy(), groups="Model")
result = model.fit()

# Create a DataFrame to hold the p-values
pvalues = pd.DataFrame(result.pvalues, columns=["p-value"])
result.summary()
c:\Users\Cyril\anaconda3\envs\cellseg3d-figures\lib\site-packages\statsmodels\regression\mixed_linear_model.py:2262: ConvergenceWarning: The Hessian matrix at the estimated parameter values is not positive definite.
  warnings.warn(msg, ConvergenceWarning)
Model: MixedLM Dependent Variable: Dice
No. Observations: 90 Method: REML
No. Groups: 8 Scale: 0.0285
Min. group size: 6 Log-Likelihood: 10.5757
Max. group size: 12 Converged: Yes
Mean group size: 11.2
Coef. Std.Err. z P>|z| [0.025 0.975]
Intercept 0.450 0.180 2.504 0.012 0.098 0.803
C(Model)[T.Cellpose - default] -0.275 0.248 -1.107 0.268 -0.762 0.212
C(Model)[T.SegResNet] -0.002 0.248 -0.007 0.994 -0.489 0.485
C(Model)[T.StarDist] 0.020 0.248 0.081 0.935 -0.467 0.507
C(Model)[T.StarDist - default] -0.163 0.248 -0.656 0.512 -0.650 0.324
C(Model)[T.SwinUNetR] 0.178 0.248 0.716 0.474 -0.309 0.665
C(Model)[T.WNet3D] 0.040 0.250 0.159 0.874 -0.450 0.530
C(Model)[T.WNet3D - Pretrained] -0.005 0.256 -0.020 0.984 -0.507 0.497
C(GT)[T.c5] 0.070 0.047 1.493 0.135 -0.022 0.162
C(GT)[T.visual] 0.121 0.047 2.586 0.010 0.029 0.213
Data_percentage 0.003 0.001 4.126 0.000 0.001 0.004
Model Var 0.028

# use non-parametric repeated measures ANOVA
from scipy.stats import friedmanchisquare

test_df = dices_df[dices_df["Model"] != "WNet3D - Pretrained"].copy()
test_df = test_df.groupby("Model", sort=False)["Dice"]

fried_stats, p_value = friedmanchisquare(*test_df.apply(list))

print(f"Friedman test p-value: {p_value}, statistic: {fried_stats}")
Friedman test p-value: 8.575653390194426e-08, statistic: 43.67462686567162