Showing quantitative benchmarking performance on additional datasets#

  • Show that self-supervised model can perform well on additional datasets, without requiring any additional training.

import numpy as np
from tifffile import imread, imwrite
import sys
import numpy as np

import pyclesperanto_prototype as cle
from stardist.matching import matching_dataset
sys.path.append("../..")

from utils import *
from plots import *
print("Used GPU: ", cle.get_device())
show_params()
#################
SAVE_PLOTS_AS_PNG = False
SAVE_PLOTS_AS_SVG = True
Used GPU:  <NVIDIA GeForce RTX 4070 Ti on Platform: NVIDIA CUDA (1 refs)>
Plot parameters (set in plots.py) : 
- COLORMAP : ███████
- DPI : 200
- Data path : C:\Users\Cyril\Desktop\Code\CELLSEG_BENCHMARK
- Font size : 20
- Title font size : 25.0
- Label font size : 20.0
%load_ext autoreload
%autoreload 2

Data#

data_path = DATA_PATH / "RESULTS/WNET OTHERS/"

# list all folders in the data path
folders = [x for x in data_path.iterdir() if x.is_dir()]
folders
[WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/Mouse-Skull-Nuclei-CBG'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/Platynereis-ISH-Nuclei-CBG'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/Platynereis-Nuclei-CBG'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/processed_instance_labels'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/processed_instance_labels_retrain'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/processed_threshold_only_instance_labels'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/Retrained WNet'),
 WindowsPath('C:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/Seb cFOS')]
def get_predictions(path):
    return [imread(f) for f in path.glob("*.tif")]
#################
gt_folder = "labels"
mouse_skull_gt = get_predictions(folders[0] / gt_folder)[0]
platynereis_ISH_gt = get_predictions(folders[1] / gt_folder)[0]
platynereis_gt = get_predictions(folders[2] / gt_folder)[0]
prediction_folder = "pred"
mouse_skull_pred = get_predictions(folders[0] / prediction_folder)[0]
platynereis_ISH_pred = get_predictions(folders[1] / prediction_folder)[0]
platynereis_pred = get_predictions(folders[2] / prediction_folder)[0]
# get second channel of predictions
mouse_skull_pred = mouse_skull_pred[1]
platynereis_ISH_pred = platynereis_ISH_pred[1]
platynereis_pred = platynereis_pred[1]
def find_files(path):
    files = path.glob("*.tif")
    pretrained = []
    for f in files:
        if f.name.startswith("pretrained"):
            pretrained.append(f)
    return pretrained[0]
cellpose_folder = "cellpose"
cp_pretrained_mouse_skull = find_files(folders[0] / cellpose_folder)
cp_pretrained_platynereis_ISH = find_files(folders[1] / cellpose_folder)
cp_pretrained_platynereis = find_files(folders[2] / cellpose_folder)

cp_pretrained_mouse_skull = imread(cp_pretrained_mouse_skull)
cp_pretrained_platynereis_ISH = imread(cp_pretrained_platynereis_ISH)
cp_pretrained_platynereis = imread(cp_pretrained_platynereis)
stardist_folder = "stardist"
sd_retrained_mouse_skull = get_predictions(folders[0] / stardist_folder)[0]
sd_retrained_platynereis_ISH = get_predictions(folders[1] / stardist_folder)[0]
sd_retrained_platynereis = get_predictions(folders[2] / stardist_folder)[0]
# get validation set to estimate thresholds
mouse_skull_val = imread(folders[0] / "TEST/X2_left_WNet3D_pred_1.tif")[1] # take channel 1 of WNet prediction (0 is background)
mouse_skull_val_gt = imread(folders[0] / "TEST/Y2_left.tif")
###
platynereis_ISH_val = imread(folders[1] / "TEST/downsampled_cropped_X02_train_WNet3D_pred.tif")[1]
platynereis_ISH_val_gt = imread(folders[1] / "TEST/downsampled_cropped_Y02_train.tif")
###
platynereis_val = imread(folders[2] / "TEST/downsampled_cropped_dataset_hdf5_150_0_WNet3D_pred_1.tif")[1]
platynereis_val_gt = imread(folders[2] / "TEST/downsampled_cropped_mask_dataset_hdf5_150_0.tif")

Computations#

Threshold predictions#

GT_labels_val = [mouse_skull_val_gt, platynereis_ISH_val_gt, platynereis_val_gt]
predictions_val = [mouse_skull_val, platynereis_ISH_val, platynereis_val]

thresh = np.arange(0, 1, 0.05)
rows = []
for t in thresh:
    for i, (gt, pred) in enumerate(zip(GT_labels_val, predictions_val)):
        dices_row = {"Threshold": t, "Fold": i, "Dice": dice_coeff(
            np.where(gt > 0, 1, 0),
            np.where(pred > t, 1, 0)
            )}
        rows.append(dices_row)
        
dices_df = pd.DataFrame(rows)

sns.lineplot(data=dices_df, x="Threshold", y="Dice", hue="Fold", palette="tab10")
plt.title("Dice metric for different thresholds for WNet3D and GT")
plt.vlines([0.45, 0.55], 0, 1, colors="red", linestyles="dashed")
plt.show()
../_images/8f20d849b0b895a6f3460a11cc85c01011ac985ee9746ce2921ecb424de5443f.png
dices_df.groupby("Threshold").mean().sort_values("Dice", ascending=False).head(5)
Fold Dice
Threshold
0.50 1.0 0.717895
0.55 1.0 0.708049
0.45 1.0 0.631312
0.60 1.0 0.603701
0.65 1.0 0.455269
predictions = [mouse_skull_pred, platynereis_ISH_pred, platynereis_pred]
GT_labels = [mouse_skull_gt, platynereis_ISH_gt, platynereis_gt]
predictions_thresholded = []
thresholds = [0.45, 0.55, 0.55]
for i, pred in enumerate(predictions):
    predictions_thresholded.append(np.where(pred > thresholds[i], 1, 0))
mouse_skull_instance = np.array(
    cle.voronoi_otsu_labeling(predictions_thresholded[0], outline_sigma=1, spot_sigma=15)
)
platynereis_ISH_instance = np.array(
    cle.voronoi_otsu_labeling(predictions_thresholded[1], outline_sigma=0.5, spot_sigma=2)
)
platynereis_instance = np.array(
    cle.voronoi_otsu_labeling(predictions_thresholded[2], outline_sigma=0.5, spot_sigma=2.75)
)

Additional mouse skull postprocessing#

def remap_image(image, new_min=1, new_max=100):
    min_val = image.min()
    max_val = image.max()
    return (image - min_val) / (max_val - min_val) * (new_max - new_min) + new_min
def mouse_skull_postproc(instance_labels, pred):
    instance_labels = cle.closing_labels(instance_labels, radius=8)
    remap_pred = remap_image(pred)
    instance_labels = cle.merge_labels_with_border_intensity_within_range(
        image=remap_pred,
        labels=instance_labels.astype(np.int32), 
        minimum_intensity=35, 
        maximum_intensity=100
    )
    return np.array(instance_labels)
mouse_skull_image = imread(
    str(Path("c:/Users/Cyril/Desktop/Code/CELLSEG_BENCHMARK/RESULTS/WNET OTHERS/Mouse-Skull-Nuclei-CBG/X1.tif"))
)
mouse_skull_instance_postproc = mouse_skull_postproc(
    mouse_skull_instance, 
    mouse_skull_pred
)
_generate_touch_mean_intensity_matrix.py (30): generate_touch_mean_intensity_matrix is supposed to work with images of integer type only.
Loss of information is possible when passing non-integer images.
_opencl_execute.py (281): overflow encountered in cast
# apply post-processing to cellpose and stardist predictions (keep w/o post-processing for comparison)

cp_pretrained_mouse_skull_postproc = mouse_skull_postproc(cp_pretrained_mouse_skull, mouse_skull_pred)
sd_retrained_mouse_skull_postproc = mouse_skull_postproc(sd_retrained_mouse_skull, mouse_skull_pred)
# import napari
# viewer = napari.Viewer()
# viewer.add_image(predictions_thresholded[0], colormap="turbo")
# viewer.add_labels(mouse_skull_instance)
# viewer.add_labels(mouse_skull_instance, name="mouse_skull_instance_closed")
# viewer.add_image(mouse_skull_remap, name="mouse_skull_pred_remap", colormap="turbo")
# viewer.add_labels(mouse_skull_instance, name="mouse_skull_instance_closed_merged")
# Show the predictions and the instance segmentation
# import napari
# viewer = napari.Viewer()

# viewer.add_image(predictions_thresholded[0], name="mouse_skull_pred", colormap="turbo")
# viewer.add_labels(mouse_skull_instance, name="mouse_skull_instance")
# viewer.add_image(predictions_thresholded[1], name="platynereis_ISH_pred", colormap="turbo")
# viewer.add_labels(platynereis_ISH_instance, name="platynereis_ISH_instance")
# viewer.add_image(predictions_thresholded[2], name="platynereis_pred", colormap="turbo")
# viewer.add_labels(platynereis_instance, name="platynereis_instance")

Plots#

predictions = [
    mouse_skull_instance_postproc,
    mouse_skull_instance,
    platynereis_ISH_instance,
    platynereis_instance,
   ]
GT_labels = [
    mouse_skull_gt,
    mouse_skull_gt,
    platynereis_ISH_gt,
    platynereis_gt,
    ]
names = [
    "Mouse skull w/ post-processing",
    "Mouse skull w/o post-processing",
    "Platynereis ISH",
    "Platynereis",
    ]
# save instance labels
for pred, name in zip(predictions, names):
    save_path = data_path / "processed_instance_labels"
    save_path.mkdir(exist_ok=True)
    name = name.replace("/", "_")
    name = name.replace(" ", "_")
    imwrite(save_path / f"{name}.tif", pred.astype(np.uint32))
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]


model_stats = []
names_stats = []

for i, p in enumerate(predictions):
    print(f"Validating on {names[i]}")
    stats = [matching_dataset(
        GT_labels[i], 
        p,
        thresh=t, 
        show_progress=False
        ) for t in taus]
    model_stats.append(stats)
    for t in taus:
        names_stats.append(names[i])
    # uncomment for ALL plots : 
    plot_performance(taus, stats, name=names[i])
    print("*"*20)
Validating on Mouse skull w/ post-processing
********************
Validating on Mouse skull w/o post-processing
********************
Validating on Platynereis ISH
********************
Validating on Platynereis
********************
../_images/0e7a1c9e446873b362b6cfa75e858640acc9d6a108acc2d3db4cce98a514e0c3.png ../_images/41391a2fb32872ea249ee61ece7c0cbdeb1cb13b58c35439aaff13864034b616.png ../_images/9648fdf802dd6bd204dc43116a2fda19b1d93955fe9653875e28db47783e60cb.png ../_images/0953092dcd0374af0d25723e3487122e712ce6f36f684d561e87723f4a3be7ce.png
dfs = [dataset_matching_stats_to_df(s) for s in model_stats]
df_wnet = pd.concat(dfs)
df_wnet["Dataset"] = names_stats
df_wnet["Model"] = "WNet3D"
df_wnet
criterion fp tp fn precision recall accuracy f1 n_true n_pred mean_true_score mean_matched_score panoptic_quality by_image Dataset Model
thresh
0.1 iou 1096 4061 845 0.787473 0.827762 0.676608 0.807115 4906 5157 0.558399 0.674588 0.544470 False Mouse skull w/ post-processing WNet3D
0.2 iou 1219 3938 968 0.763622 0.802691 0.642939 0.782669 4906 5157 0.555335 0.691842 0.541483 False Mouse skull w/ post-processing WNet3D
0.3 iou 1319 3838 1068 0.744231 0.782307 0.616546 0.762794 4906 5157 0.550126 0.703209 0.536404 False Mouse skull w/ post-processing WNet3D
0.4 iou 1507 3650 1256 0.707776 0.743987 0.569156 0.725430 4906 5157 0.536541 0.721169 0.523158 False Mouse skull w/ post-processing WNet3D
0.5 iou 1952 3205 1701 0.621485 0.653282 0.467337 0.636987 4906 5157 0.495302 0.758175 0.482948 False Mouse skull w/ post-processing WNet3D
0.6 iou 2468 2689 2217 0.521427 0.548104 0.364660 0.534433 4906 5157 0.437325 0.797887 0.426417 False Mouse skull w/ post-processing WNet3D
0.7 iou 3006 2151 2755 0.417103 0.438443 0.271866 0.427507 4906 5157 0.365661 0.834001 0.356541 False Mouse skull w/ post-processing WNet3D
0.8 iou 3670 1487 3419 0.288346 0.303098 0.173391 0.295538 4906 5157 0.263805 0.870361 0.257225 False Mouse skull w/ post-processing WNet3D
0.9 iou 4789 368 4538 0.071359 0.075010 0.037958 0.073139 4906 5157 0.069084 0.920993 0.067361 False Mouse skull w/ post-processing WNet3D
0.1 iou 1718 4116 790 0.705519 0.838973 0.621377 0.766480 4906 5834 0.479463 0.571488 0.438034 False Mouse skull w/o post-processing WNet3D
0.2 iou 1903 3931 975 0.673809 0.801264 0.577324 0.732030 4906 5834 0.474820 0.592590 0.433793 False Mouse skull w/o post-processing WNet3D
0.3 iou 2244 3590 1316 0.615358 0.731757 0.502098 0.668529 4906 5834 0.457198 0.624794 0.417693 False Mouse skull w/o post-processing WNet3D
0.4 iou 2761 3073 1833 0.526740 0.626376 0.400809 0.572253 4906 5834 0.419895 0.670357 0.383614 False Mouse skull w/o post-processing WNet3D
0.5 iou 3389 2445 2461 0.419095 0.498369 0.294756 0.455307 4906 5834 0.362430 0.727233 0.331114 False Mouse skull w/o post-processing WNet3D
0.6 iou 3925 1909 2997 0.327220 0.389115 0.216170 0.355493 4906 5834 0.302478 0.777348 0.276342 False Mouse skull w/o post-processing WNet3D
0.7 iou 4471 1363 3543 0.233630 0.277823 0.145356 0.253818 4906 5834 0.230154 0.828420 0.210268 False Mouse skull w/o post-processing WNet3D
0.8 iou 4895 939 3967 0.160953 0.191398 0.095807 0.174860 4906 5834 0.165426 0.864303 0.151132 False Mouse skull w/o post-processing WNet3D
0.9 iou 5657 177 4729 0.030339 0.036078 0.016757 0.032961 4906 5834 0.033077 0.916812 0.030219 False Mouse skull w/o post-processing WNet3D
0.1 iou 532 2484 168 0.823607 0.936652 0.780151 0.876500 2652 3016 0.630515 0.673159 0.590023 False Platynereis ISH WNet3D
0.2 iou 590 2426 226 0.804377 0.914781 0.748304 0.856034 2652 3016 0.627286 0.685722 0.587002 False Platynereis ISH WNet3D
0.3 iou 652 2364 288 0.783820 0.891403 0.715496 0.834157 2652 3016 0.621627 0.697358 0.581706 False Platynereis ISH WNet3D
0.4 iou 776 2240 412 0.742706 0.844646 0.653442 0.790402 2652 3016 0.605176 0.716485 0.566312 False Platynereis ISH WNet3D
0.5 iou 949 2067 585 0.685345 0.779412 0.574007 0.729358 2652 3016 0.575968 0.738978 0.538980 False Platynereis ISH WNet3D
0.6 iou 1226 1790 862 0.593501 0.674962 0.461578 0.631616 2652 3016 0.518761 0.768578 0.485446 False Platynereis ISH WNet3D
0.7 iou 1622 1394 1258 0.462202 0.525641 0.326158 0.491884 2652 3016 0.421467 0.801816 0.394401 False Platynereis ISH WNet3D
0.8 iou 2311 705 1947 0.233753 0.265837 0.142051 0.248765 2652 3016 0.226469 0.851910 0.211925 False Platynereis ISH WNet3D
0.9 iou 2920 96 2556 0.031830 0.036199 0.017229 0.033874 2652 3016 0.033264 0.918915 0.031128 False Platynereis ISH WNet3D
0.1 iou 58 800 252 0.932401 0.760456 0.720721 0.837696 1052 858 0.525770 0.691387 0.579173 False Platynereis WNet3D
0.2 iou 86 772 280 0.899767 0.733840 0.678383 0.808377 1052 858 0.521990 0.711313 0.575009 False Platynereis WNet3D
0.3 iou 115 743 309 0.865967 0.706274 0.636675 0.778010 1052 858 0.514955 0.729115 0.567259 False Platynereis WNet3D
0.4 iou 152 706 346 0.822844 0.671103 0.586379 0.739267 1052 858 0.502527 0.748808 0.553569 False Platynereis WNet3D
0.5 iou 194 664 388 0.773893 0.631179 0.532905 0.695288 1052 858 0.484465 0.767556 0.533672 False Platynereis WNet3D
0.6 iou 269 589 463 0.686480 0.559886 0.445874 0.616754 1052 858 0.445128 0.795033 0.490340 False Platynereis WNet3D
0.7 iou 369 489 563 0.569930 0.464829 0.344124 0.512042 1052 858 0.383217 0.824425 0.422140 False Platynereis WNet3D
0.8 iou 535 323 729 0.376457 0.307034 0.203529 0.338220 1052 858 0.264275 0.860736 0.291118 False Platynereis WNet3D
0.9 iou 802 56 996 0.065268 0.053232 0.030205 0.058639 1052 858 0.048994 0.920388 0.053970 False Platynereis WNet3D
plot_stat_comparison(taus=taus, stats_list=model_stats, model_names=names, metric="IoU", plt_size=(9, 6))
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison.png")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison.svg", bbox_inches='tight')
../_images/64e5f1e60e4e28beb0477d5c652461d2a821e0f67e87c6c4afbb2c13d47773a0.png

Comparisons on additional datasets#

# check performance for cellpose and stardist
model_stats_others = []
names_stats_others = []

predictions_others = [
    cp_pretrained_mouse_skull_postproc,
    cp_pretrained_mouse_skull,
    sd_retrained_mouse_skull_postproc,
    sd_retrained_mouse_skull,
    cp_pretrained_platynereis_ISH,
    sd_retrained_platynereis_ISH,
    cp_pretrained_platynereis,
    sd_retrained_platynereis,
]
GT_labels_others = [
    mouse_skull_gt,
    mouse_skull_gt,
    mouse_skull_gt,
    mouse_skull_gt,
    platynereis_ISH_gt,
    platynereis_ISH_gt,
    platynereis_gt,
    platynereis_gt,
]
names_others = [
    "CellPose - pretrained w/ post-processing",
    "CellPose - pretrained",
    "StarDist - retrained w/ post-processing",
    "StarDist - retrained",
    "CellPose - pretrained",
    "StarDist - retrained",
    "CellPose - pretrained",
    "StarDist - retrained",
]
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

for i, p in enumerate(predictions_others):
    print(f"Validating on {names_others[i]}")
    stats = [matching_dataset(
        GT_labels_others[i],
        p,
        thresh=t, 
        show_progress=False
        ) for t in taus]
    model_stats_others.append(stats)
    for t in taus:
        names_stats_others.append(names_others[i])  
    # uncomment for ALL plots : 
    plot_performance(taus, stats, name=names_others[i])
    print("*"*20)
Validating on CellPose - pretrained w/ post-processing
********************
Validating on CellPose - pretrained
********************
Validating on StarDist - retrained w/ post-processing
********************
Validating on StarDist - retrained
********************
Validating on CellPose - pretrained
********************
Validating on StarDist - retrained
********************
Validating on CellPose - pretrained
********************
Validating on StarDist - retrained
********************
../_images/e6a82393c7d39be628c8f75587653eba9854a0fd03dbef9c6bd8e40d6bb58393.png ../_images/f80a24cc2757e14fa792e6c0ec2d31e80273a5be2fbc014e21a88d5e10b6c8a9.png ../_images/6183cb4cac24f9d15a5b75d381018fff6834687b991e6aa030831541c104d8d9.png ../_images/88a1de66bf08d76877ef980c398cadccdd2329a74c97f543146ddd77883e8225.png ../_images/3d3a4c45ec840b8453f68467d2d6b5c72fb0eeeb4a0edd07aea62a303a86a67f.png ../_images/c255ee99c1e09e42d8cfcf96d8cb45958fca185c0858f6fce9e35eb7383208cc.png ../_images/d3cf2faca046eff2c0e33361ef2732fe1253f4f3dedde61d29b9d5543dfcd8b5.png ../_images/655c5c4f1ea0d1d1a753453d58d0980f302d8e5edf43c5a1e36ee23a45bde712.png
# plot full comparison of stats
dfs = [dataset_matching_stats_to_df(s) for s in model_stats_others]
df_others = pd.concat(dfs)
df_others["Dataset"] = ["Mouse skull"] * len(taus) * 4 + ["Platynereis ISH"] * len(taus) * 2 + ["Platynereis"] * len(taus) * 2

def get_name(name):
    res = "CellPose" if "CellPose" in name else "StarDist"
    res += "- post-processing" if "post-processing" in name else ""
    return res

df_others["Model"] = [get_name(n) for n in names_stats_others]
df_others
criterion fp tp fn precision recall accuracy f1 n_true n_pred mean_true_score mean_matched_score panoptic_quality by_image Dataset Model
thresh
0.1 iou 13028 4283 623 0.247415 0.873013 0.238820 0.385561 4906 17311 0.547819 0.627504 0.241941 False Mouse skull CellPose- post-processing
0.2 iou 13295 4016 890 0.231991 0.818589 0.220647 0.361525 4906 17311 0.539571 0.659147 0.238298 False Mouse skull CellPose- post-processing
0.3 iou 13549 3762 1144 0.217318 0.766816 0.203847 0.338660 4906 17311 0.526620 0.686762 0.232578 False Mouse skull CellPose- post-processing
0.4 iou 13850 3461 1445 0.199931 0.705463 0.184528 0.311563 4906 17311 0.504832 0.715604 0.222956 False Mouse skull CellPose- post-processing
0.5 iou 14352 2959 1947 0.170932 0.603139 0.153650 0.266373 4906 17311 0.458649 0.760436 0.202559 False Mouse skull CellPose- post-processing
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
0.5 iou 186 667 385 0.781946 0.634030 0.538772 0.700262 1052 853 0.474416 0.748255 0.523975 False Platynereis StarDist
0.6 iou 271 582 470 0.682298 0.553232 0.439909 0.611024 1052 853 0.429873 0.777021 0.474778 False Platynereis StarDist
0.7 iou 384 469 583 0.549824 0.445817 0.326602 0.492388 1052 853 0.359040 0.805351 0.396546 False Platynereis StarDist
0.8 iou 594 259 793 0.303634 0.246198 0.157351 0.271916 1052 853 0.208831 0.848225 0.230646 False Platynereis StarDist
0.9 iou 835 18 1034 0.021102 0.017110 0.009539 0.018898 1052 853 0.015703 0.917728 0.017343 False Platynereis StarDist

72 rows × 16 columns

wnet_color = COLORS_DICT["WNet3D"]
cellpose_color = COLORS_DICT["Cellpose"]
stardist_color = COLORS_DICT["Stardist"]

wnet_shades = get_shades(wnet_color)
cellpose_shades = get_shades(cellpose_color)
stardist_shades = get_shades(stardist_color)

SIMPLE_COLORMAP = [
    wnet_color, 
    cellpose_color,
    stardist_color,
    "black"
]
EXPANDED_COLORMAP = [
    wnet_color, wnet_shades[0],
    cellpose_color, cellpose_shades[0],
    stardist_color, stardist_shades[0]
]
Warning: Saturation in 1.0 is too low or too high in hex color FF4D00
Warning: Value in 1.0 is too low or too high in hex color FF4D00
plot_stat_comparison(taus=taus, stats_list=model_stats[0:2]+model_stats_others[0:4], model_names=["WNet3D w/ postprocessing", "WNet3D w/o postprocessing"]+names_others[0:4], metric="IoU", plt_size=(9, 6), colormap=EXPANDED_COLORMAP, title="Mouse-Skull-Nuclei-CBG")
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison_mouse_skull_full.png")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison_mouse_skull_full.svg", bbox_inches='tight')
../_images/6379fd2d92dd41aa326f019e76e070e90e9d16d9be4db43c34f08fea3ba7cb44.png
plot_stat_comparison(taus=taus, stats_list=[model_stats[0]]+model_stats_others[0:4:2], 
                     model_names=
                    [
                        "WNet3D",
                        "CellPose - pretrained",
                        "StarDist - retrained"
                    ], 
                     metric="IoU", plt_size=(9, 6), colormap=SIMPLE_COLORMAP, title="Mouse-Skull-Nuclei-CBG")
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison_mouse_skull.png")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison_mouse_skull.svg", bbox_inches='tight')
../_images/8493f7da357ca66b5803c5f32efd55a91e759e9c9e83867e19016e2d561c3f21.png
plot_stat_comparison(taus=taus, stats_list=[model_stats[2]]+model_stats_others[4:6], model_names=["WNet3D"]+names_others[4:6], metric="IoU", plt_size=(9, 6), colormap=SIMPLE_COLORMAP, title="Platynereis ISH")
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison_platynereis_ISH.png")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison_platynereis_ISH.svg", bbox_inches='tight')
../_images/7346611ae32ac72b8732bf85672969f2f3a6bc280c2d93a68c10ad76ebada4c2.png
plot_stat_comparison(taus=taus, stats_list=[model_stats[3]]+model_stats_others[6:], model_names=[names[3]+" - WNet3D"]+names_others[6:], metric="IoU", plt_size=(9, 6), colormap=SIMPLE_COLORMAP, title="Platynereis Nuclei")
if SAVE_PLOTS_AS_PNG:
    plt.savefig("f1_comparison_platynereis.png")
if SAVE_PLOTS_AS_SVG:
    plt.savefig("f1_comparison_platynereis.svg", bbox_inches='tight')
../_images/0fd26d592f32281ddf5537ba62639a34afe1097a29368feb6b4e2605d9ffaa4f.png

Statistics tables#

test_df = pd.concat([df_wnet, df_others])

test_df
criterion fp tp fn precision recall accuracy f1 n_true n_pred mean_true_score mean_matched_score panoptic_quality by_image Dataset Model
thresh
0.1 iou 1096 4061 845 0.787473 0.827762 0.676608 0.807115 4906 5157 0.558399 0.674588 0.544470 False Mouse skull w/ post-processing WNet3D
0.2 iou 1219 3938 968 0.763622 0.802691 0.642939 0.782669 4906 5157 0.555335 0.691842 0.541483 False Mouse skull w/ post-processing WNet3D
0.3 iou 1319 3838 1068 0.744231 0.782307 0.616546 0.762794 4906 5157 0.550126 0.703209 0.536404 False Mouse skull w/ post-processing WNet3D
0.4 iou 1507 3650 1256 0.707776 0.743987 0.569156 0.725430 4906 5157 0.536541 0.721169 0.523158 False Mouse skull w/ post-processing WNet3D
0.5 iou 1952 3205 1701 0.621485 0.653282 0.467337 0.636987 4906 5157 0.495302 0.758175 0.482948 False Mouse skull w/ post-processing WNet3D
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
0.5 iou 186 667 385 0.781946 0.634030 0.538772 0.700262 1052 853 0.474416 0.748255 0.523975 False Platynereis StarDist
0.6 iou 271 582 470 0.682298 0.553232 0.439909 0.611024 1052 853 0.429873 0.777021 0.474778 False Platynereis StarDist
0.7 iou 384 469 583 0.549824 0.445817 0.326602 0.492388 1052 853 0.359040 0.805351 0.396546 False Platynereis StarDist
0.8 iou 594 259 793 0.303634 0.246198 0.157351 0.271916 1052 853 0.208831 0.848225 0.230646 False Platynereis StarDist
0.9 iou 835 18 1034 0.021102 0.017110 0.009539 0.018898 1052 853 0.015703 0.917728 0.017343 False Platynereis StarDist

108 rows × 16 columns

# pivot the f1 column of test df to have f1 values across threshold. Each row is a model, each column is a threshold
stats_df = test_df[["Dataset", "Model", "f1"]]

def correct_name(row):
    if "Mouse skull" in row["Dataset"]:
        if "w/ post-processing" in row["Dataset"] and "WNet3D" in row["Model"]:
            return "WNet3D - post-processing"
        elif "w/o post-processing" in row["Dataset"] and "WNet3D" in row["Model"]:
            return "WNet3D"
        else:
            return row["Model"]
    else: 
        return row["Model"]
        
        
stats_df["Model"] = stats_df.apply(correct_name, axis=1)
stats_df["Dataset"] = stats_df["Dataset"].apply(lambda x: x.replace(" w/o post-processing", ""))
stats_df["Dataset"] = stats_df["Dataset"].apply(lambda x: x.replace(" w/ post-processing", ""))
stats_df
1697223708.py (16): 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
1697223708.py (17): 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
1697223708.py (18): 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
Dataset Model f1
thresh
0.1 Mouse skull WNet3D - post-processing 0.807115
0.2 Mouse skull WNet3D - post-processing 0.782669
0.3 Mouse skull WNet3D - post-processing 0.762794
0.4 Mouse skull WNet3D - post-processing 0.725430
0.5 Mouse skull WNet3D - post-processing 0.636987
... ... ... ...
0.5 Platynereis StarDist 0.700262
0.6 Platynereis StarDist 0.611024
0.7 Platynereis StarDist 0.492388
0.8 Platynereis StarDist 0.271916
0.9 Platynereis StarDist 0.018898

108 rows × 3 columns

dataset_tables = {}
for dataset in stats_df["Dataset"].unique():
    table_df = stats_df[stats_df["Dataset"] == dataset].copy().reset_index()
    table_df = table_df.pivot(index="Model", columns="thresh", values="f1")
    # add the mean as an extra "threshold" - getting the mean across all thresholds
    table_df["Mean"] = table_df.mean(axis=1)
    # round all values to 3 decimals
    display(table_df)
    table_df = table_df.round(3)
    latex_table = table_df.to_latex()
    print(f"Dataset: {dataset}")
    print(latex_table)
    dataset_tables[dataset] = table_df
thresh 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 Mean
Model
CellPose 0.137428 0.110595 0.077202 0.053596 0.038268 0.028026 0.020169 0.013995 0.006489 0.053974
CellPose- post-processing 0.385561 0.361525 0.338660 0.311563 0.266373 0.227754 0.189495 0.119638 0.027096 0.247518
StarDist 0.573155 0.533034 0.410503 0.252808 0.135234 0.064751 0.020448 0.002788 0.000000 0.221413
StarDist- post-processing 0.688610 0.648815 0.556764 0.406881 0.275593 0.173687 0.073082 0.009670 0.000186 0.314810
WNet3D 0.766480 0.732030 0.668529 0.572253 0.455307 0.355493 0.253818 0.174860 0.032961 0.445748
WNet3D - post-processing 0.807115 0.782669 0.762794 0.725430 0.636987 0.534433 0.427507 0.295538 0.073139 0.560624
Dataset: Mouse skull
\begin{tabular}{lrrrrrrrrrr}
\toprule
thresh & 0.100000 & 0.200000 & 0.300000 & 0.400000 & 0.500000 & 0.600000 & 0.700000 & 0.800000 & 0.900000 & Mean \\
Model &  &  &  &  &  &  &  &  &  &  \\
\midrule
CellPose & 0.137000 & 0.111000 & 0.077000 & 0.054000 & 0.038000 & 0.028000 & 0.020000 & 0.014000 & 0.006000 & 0.054000 \\
CellPose- post-processing & 0.386000 & 0.362000 & 0.339000 & 0.312000 & 0.266000 & 0.228000 & 0.189000 & 0.120000 & 0.027000 & 0.248000 \\
StarDist & 0.573000 & 0.533000 & 0.411000 & 0.253000 & 0.135000 & 0.065000 & 0.020000 & 0.003000 & 0.000000 & 0.221000 \\
StarDist- post-processing & 0.689000 & 0.649000 & 0.557000 & 0.407000 & 0.276000 & 0.174000 & 0.073000 & 0.010000 & 0.000000 & 0.315000 \\
WNet3D & 0.766000 & 0.732000 & 0.669000 & 0.572000 & 0.455000 & 0.355000 & 0.254000 & 0.175000 & 0.033000 & 0.446000 \\
WNet3D - post-processing & 0.807000 & 0.783000 & 0.763000 & 0.725000 & 0.637000 & 0.534000 & 0.428000 & 0.296000 & 0.073000 & 0.561000 \\
\bottomrule
\end{tabular}
thresh 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 Mean
Model
CellPose 0.896114 0.865718 0.831858 0.777992 0.697961 0.575991 0.362062 0.116583 0.009619 0.570433
StarDist 0.840954 0.821589 0.786444 0.686032 0.536131 0.325982 0.109736 0.011476 0.000000 0.457594
WNet3D 0.876500 0.856034 0.834157 0.790402 0.729358 0.631616 0.491884 0.248765 0.033874 0.610288
Dataset: Platynereis ISH
\begin{tabular}{lrrrrrrrrrr}
\toprule
thresh & 0.100000 & 0.200000 & 0.300000 & 0.400000 & 0.500000 & 0.600000 & 0.700000 & 0.800000 & 0.900000 & Mean \\
Model &  &  &  &  &  &  &  &  &  &  \\
\midrule
CellPose & 0.896000 & 0.866000 & 0.832000 & 0.778000 & 0.698000 & 0.576000 & 0.362000 & 0.117000 & 0.010000 & 0.570000 \\
StarDist & 0.841000 & 0.822000 & 0.786000 & 0.686000 & 0.536000 & 0.326000 & 0.110000 & 0.011000 & 0.000000 & 0.458000 \\
WNet3D & 0.876000 & 0.856000 & 0.834000 & 0.790000 & 0.729000 & 0.632000 & 0.492000 & 0.249000 & 0.034000 & 0.610000 \\
\bottomrule
\end{tabular}
thresh 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 Mean
Model
CellPose 0.691070 0.662912 0.624296 0.593725 0.552695 0.497184 0.416734 0.290426 0.061947 0.487888
StarDist 0.850394 0.832546 0.803150 0.764304 0.700262 0.611024 0.492388 0.271916 0.018898 0.593876
WNet3D 0.837696 0.808377 0.778010 0.739267 0.695288 0.616754 0.512042 0.338220 0.058639 0.598255
Dataset: Platynereis
\begin{tabular}{lrrrrrrrrrr}
\toprule
thresh & 0.100000 & 0.200000 & 0.300000 & 0.400000 & 0.500000 & 0.600000 & 0.700000 & 0.800000 & 0.900000 & Mean \\
Model &  &  &  &  &  &  &  &  &  &  \\
\midrule
CellPose & 0.691000 & 0.663000 & 0.624000 & 0.594000 & 0.553000 & 0.497000 & 0.417000 & 0.290000 & 0.062000 & 0.488000 \\
StarDist & 0.850000 & 0.833000 & 0.803000 & 0.764000 & 0.700000 & 0.611000 & 0.492000 & 0.272000 & 0.019000 & 0.594000 \\
WNet3D & 0.838000 & 0.808000 & 0.778000 & 0.739000 & 0.695000 & 0.617000 & 0.512000 & 0.338000 & 0.059000 & 0.598000 \\
\bottomrule
\end{tabular}

Sanity check : Images only (no WNet)#

The aim here is to check that using the WNet does provide a benefit over using Otsu thresholding and Voronoi-based instance segmentation directly on the images.

mouse_skull_image = imread(folders[0] / "X1.tif")
platynereis_ISH_image = imread(folders[1] / "X01_cropped_downsampled.tif") 
platynereis_image = imread(folders[2] / "downsmapled_cropped_dataset_hdf5_100_0.tif")
mouse_skull_instance_tresh = np.array(
    cle.voronoi_otsu_labeling(cle.threshold_otsu(mouse_skull_image), outline_sigma=1, spot_sigma=15)
)
platynereis_ISH_instance_tresh = np.array(
    cle.voronoi_otsu_labeling(cle.threshold_otsu(platynereis_ISH_image), outline_sigma=0.5, spot_sigma=2)
)
platynereis_instance_tresh = np.array(
    cle.voronoi_otsu_labeling(cle.threshold_otsu(platynereis_image), outline_sigma=0.5, spot_sigma=2.75)
)
mouse_skull_instance_tresh_postproc = mouse_skull_postproc(
    mouse_skull_instance_tresh,
    mouse_skull_pred
)
_generate_touch_mean_intensity_matrix.py (30): generate_touch_mean_intensity_matrix is supposed to work with images of integer type only.
Loss of information is possible when passing non-integer images.
_opencl_execute.py (281): overflow encountered in cast
predictions_tresh = [
    mouse_skull_instance_tresh_postproc,
    mouse_skull_instance_tresh,
    platynereis_ISH_instance_tresh,
    platynereis_instance_tresh,
]
names_tresh = [
    "Otsu Threshold & post-processing",
    "Otsu Threshold",
    "Otsu Threshold",
    "Otsu Threshold",
]
gt_tresh = [
    mouse_skull_gt,
    mouse_skull_gt,
    platynereis_ISH_gt,
    platynereis_gt,
]
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]


model_stats_tresh = []
names_stats_tresh = []

for i, p in enumerate(predictions_tresh):
    print(f"Validating on {names_tresh[i]}")
    stats = [matching_dataset(
        gt_tresh[i],
        p,
        thresh=t, 
        show_progress=False
        ) for t in taus]
    model_stats_tresh.append(stats)
    for t in taus:
        names_stats_tresh.append(names_tresh[i])
    # uncomment for ALL plots : 
    plot_performance(taus, stats, name=names_tresh[i])
    print("*"*20)
Validating on Otsu Threshold & post-processing
********************
Validating on Otsu Threshold
********************
Validating on Otsu Threshold
********************
Validating on Otsu Threshold
********************
../_images/8fe75b8ea0692c24ad6a6a945f1479db0908a6eaaa0112377ce09068de7b850c.png ../_images/bfc207ac2de3f1cfc3aeb2ae45fcd7a9f18ca1a8f3fa4c744a04b23c7b7fc56a.png ../_images/9bdee9928c4eff217956d4f34ce3fa0b152761f961f706be4fd5a8ad9220e70a.png ../_images/3f36d7439d611a5124a83ae77052f9da05087adfa8b316010f39b5bea320bf6c.png
# save threshold-based instance labels
for pred, name in zip(predictions_tresh, names):
    save_path = data_path / "processed_threshold_only_instance_labels"
    save_path.mkdir(exist_ok=True)
    name = name.replace("/", "_")
    name = name.replace(" ", "_")
    imwrite(save_path / f"{name}_otsu.tif", pred.astype(np.uint32))
dfs = [dataset_matching_stats_to_df(s) for s in model_stats_tresh]
df_im_only = pd.concat(dfs)
df_im_only["Dataset"] = names_stats
df_im_only["Model"] = "Otsu Threshold"

def replace_model_name(row):
    if "w/ post-processing" in row["Dataset"]:
        return "Otsu Threshold - post-processing"
    elif "w/o post-processing" in row["Dataset"]:
        return "Otsu Threshold"
    else:
        return row["Model"]

df_im_only["Model"] = df_im_only.apply(replace_model_name, axis=1)
df_im_only["Dataset"] = df_im_only["Dataset"].apply(lambda x: x.replace(" w/o post-processing", ""))
df_im_only["Dataset"] = df_im_only["Dataset"].apply(lambda x: x.replace(" w/ post-processing", ""))
df_w_tresh = pd.concat([df_wnet, df_im_only])
df_im_only
criterion fp tp fn precision recall accuracy f1 n_true n_pred mean_true_score mean_matched_score panoptic_quality by_image Dataset Model
thresh
0.1 iou 730 3004 1902 0.804499 0.612311 0.533002 0.695370 4906 3734 0.429889 0.702076 0.488203 False Mouse skull Otsu Threshold - post-processing
0.2 iou 848 2886 2020 0.772898 0.588259 0.501564 0.668056 4906 3734 0.426253 0.724600 0.484073 False Mouse skull Otsu Threshold - post-processing
0.3 iou 941 2793 2113 0.747991 0.569303 0.477681 0.646528 4906 3734 0.421603 0.740561 0.478793 False Mouse skull Otsu Threshold - post-processing
0.4 iou 1089 2645 2261 0.708356 0.539136 0.441201 0.612269 4906 3734 0.411080 0.762480 0.466842 False Mouse skull Otsu Threshold - post-processing
0.5 iou 1390 2344 2562 0.627745 0.477782 0.372300 0.542593 4906 3734 0.383330 0.802311 0.435328 False Mouse skull Otsu Threshold - post-processing
0.6 iou 1619 2115 2791 0.566417 0.431105 0.324138 0.489583 4906 3734 0.357739 0.829820 0.406266 False Mouse skull Otsu Threshold - post-processing
0.7 iou 1887 1847 3059 0.494644 0.376478 0.271898 0.427546 4906 3734 0.322019 0.855345 0.365700 False Mouse skull Otsu Threshold - post-processing
0.8 iou 2291 1443 3463 0.386449 0.294130 0.200500 0.334028 4906 3734 0.260116 0.884360 0.295401 False Mouse skull Otsu Threshold - post-processing
0.9 iou 3144 590 4316 0.158007 0.120261 0.073292 0.136574 4906 3734 0.110768 0.921066 0.125794 False Mouse skull Otsu Threshold - post-processing
0.1 iou 1071 2988 1918 0.736142 0.609050 0.499916 0.666592 4906 4059 0.405943 0.666518 0.444296 False Mouse skull Otsu Threshold
0.2 iou 1217 2842 2064 0.700172 0.579291 0.464152 0.634021 4906 4059 0.401560 0.693192 0.439499 False Mouse skull Otsu Threshold
0.3 iou 1389 2670 2236 0.657797 0.544232 0.424146 0.595650 4906 4059 0.392655 0.721485 0.429752 False Mouse skull Otsu Threshold
0.4 iou 1524 2535 2371 0.624538 0.516714 0.394246 0.565533 4906 4059 0.383139 0.741490 0.419337 False Mouse skull Otsu Threshold
0.5 iou 1839 2220 2686 0.546933 0.452507 0.329133 0.495259 4906 4059 0.353814 0.781896 0.387241 False Mouse skull Otsu Threshold
0.6 iou 2144 1915 2991 0.471791 0.390338 0.271631 0.427217 4906 4059 0.320217 0.820358 0.350471 False Mouse skull Otsu Threshold
0.7 iou 2403 1656 3250 0.407982 0.337546 0.226570 0.369437 4906 4059 0.285787 0.846660 0.312787 False Mouse skull Otsu Threshold
0.8 iou 2822 1237 3669 0.304755 0.252140 0.160067 0.275962 4906 4059 0.221314 0.877741 0.242223 False Mouse skull Otsu Threshold
0.9 iou 3625 434 4472 0.106923 0.088463 0.050873 0.096821 4906 4059 0.081456 0.920789 0.089152 False Mouse skull Otsu Threshold
0.1 iou 397 2356 296 0.855794 0.888386 0.772712 0.871785 2652 2753 0.588878 0.662863 0.577874 False Platynereis ISH Otsu Threshold
0.2 iou 463 2290 362 0.831820 0.863499 0.735152 0.847364 2652 2753 0.585265 0.677783 0.574328 False Platynereis ISH Otsu Threshold
0.3 iou 544 2209 443 0.802397 0.832956 0.691176 0.817391 2652 2753 0.577540 0.693362 0.566748 False Platynereis ISH Otsu Threshold
0.4 iou 668 2085 567 0.757356 0.786199 0.628012 0.771508 2652 2753 0.561198 0.713812 0.550712 False Platynereis ISH Otsu Threshold
0.5 iou 844 1909 743 0.693425 0.719834 0.546053 0.706383 2652 2753 0.531248 0.738015 0.521321 False Platynereis ISH Otsu Threshold
0.6 iou 1117 1636 1016 0.594261 0.616893 0.434067 0.605365 2652 2753 0.474799 0.769662 0.465927 False Platynereis ISH Otsu Threshold
0.7 iou 1471 1282 1370 0.465674 0.483409 0.310939 0.474376 2652 2753 0.387787 0.802194 0.380541 False Platynereis ISH Otsu Threshold
0.8 iou 2087 666 1986 0.241918 0.251131 0.140536 0.246438 2652 2753 0.213029 0.848277 0.209048 False Platynereis ISH Otsu Threshold
0.9 iou 2683 70 2582 0.025427 0.026395 0.013121 0.025902 2652 2753 0.024221 0.917630 0.023768 False Platynereis ISH Otsu Threshold
0.1 iou 68 743 309 0.916153 0.706274 0.663393 0.797638 1052 811 0.491839 0.696385 0.555464 False Platynereis Otsu Threshold
0.2 iou 91 720 332 0.887793 0.684411 0.629921 0.772947 1052 811 0.488707 0.714055 0.551926 False Platynereis Otsu Threshold
0.3 iou 128 683 369 0.842170 0.649240 0.578814 0.733226 1052 811 0.479873 0.739131 0.541950 False Platynereis Otsu Threshold
0.4 iou 157 654 398 0.806412 0.621673 0.540943 0.702093 1052 811 0.470436 0.756726 0.531293 False Platynereis Otsu Threshold
0.5 iou 193 618 434 0.762022 0.587452 0.496386 0.663446 1052 811 0.455120 0.774736 0.513995 False Platynereis Otsu Threshold
0.6 iou 261 550 502 0.678175 0.522814 0.418888 0.590446 1052 811 0.419812 0.802986 0.474119 False Platynereis Otsu Threshold
0.7 iou 339 472 580 0.581998 0.448669 0.339324 0.506710 1052 811 0.371810 0.828695 0.419908 False Platynereis Otsu Threshold
0.8 iou 498 313 739 0.385943 0.297529 0.201935 0.336017 1052 811 0.257857 0.866661 0.291213 False Platynereis Otsu Threshold
0.9 iou 739 72 980 0.088779 0.068441 0.040201 0.077295 1052 811 0.062995 0.920429 0.071144 False Platynereis Otsu Threshold

Statistics tables#

for dataset in df_im_only["Dataset"].unique():
    table_df = df_im_only[df_im_only["Dataset"] == dataset].copy().reset_index()
    table_df = table_df.pivot(index="Model", columns="thresh", values="f1")
    # add the mean as an extra "threshold" - getting the mean across all thresholds
    table_df["Mean"] = table_df.mean(axis=1)
    # round all values to 3 decimals
    display(table_df)
    table_df = table_df.round(3)
    latex_table = table_df.to_latex()
    print(f"Dataset: {dataset}")
    print(latex_table)
    dataset_tables[dataset] = pd.concat([dataset_tables[dataset], table_df])
thresh 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 Mean
Model
Otsu Threshold 0.666592 0.634021 0.595650 0.565533 0.495259 0.427217 0.369437 0.275962 0.096821 0.458499
Otsu Threshold - post-processing 0.695370 0.668056 0.646528 0.612269 0.542593 0.489583 0.427546 0.334028 0.136574 0.505838
Dataset: Mouse skull
\begin{tabular}{lrrrrrrrrrr}
\toprule
thresh & 0.100000 & 0.200000 & 0.300000 & 0.400000 & 0.500000 & 0.600000 & 0.700000 & 0.800000 & 0.900000 & Mean \\
Model &  &  &  &  &  &  &  &  &  &  \\
\midrule
Otsu Threshold & 0.667000 & 0.634000 & 0.596000 & 0.566000 & 0.495000 & 0.427000 & 0.369000 & 0.276000 & 0.097000 & 0.458000 \\
Otsu Threshold - post-processing & 0.695000 & 0.668000 & 0.647000 & 0.612000 & 0.543000 & 0.490000 & 0.428000 & 0.334000 & 0.137000 & 0.506000 \\
\bottomrule
\end{tabular}
thresh 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 Mean
Model
Otsu Threshold 0.871785 0.847364 0.817391 0.771508 0.706383 0.605365 0.474376 0.246438 0.025902 0.596279
Dataset: Platynereis ISH
\begin{tabular}{lrrrrrrrrrr}
\toprule
thresh & 0.100000 & 0.200000 & 0.300000 & 0.400000 & 0.500000 & 0.600000 & 0.700000 & 0.800000 & 0.900000 & Mean \\
Model &  &  &  &  &  &  &  &  &  &  \\
\midrule
Otsu Threshold & 0.872000 & 0.847000 & 0.817000 & 0.772000 & 0.706000 & 0.605000 & 0.474000 & 0.246000 & 0.026000 & 0.596000 \\
\bottomrule
\end{tabular}
thresh 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 Mean
Model
Otsu Threshold 0.797638 0.772947 0.733226 0.702093 0.663446 0.590446 0.50671 0.336017 0.077295 0.575535
Dataset: Platynereis
\begin{tabular}{lrrrrrrrrrr}
\toprule
thresh & 0.100000 & 0.200000 & 0.300000 & 0.400000 & 0.500000 & 0.600000 & 0.700000 & 0.800000 & 0.900000 & Mean \\
Model &  &  &  &  &  &  &  &  &  &  \\
\midrule
Otsu Threshold & 0.798000 & 0.773000 & 0.733000 & 0.702000 & 0.663000 & 0.590000 & 0.507000 & 0.336000 & 0.077000 & 0.576000 \\
\bottomrule
\end{tabular}
# plot each model vs threshold only pair separately
datasets = [
    "Mouse skull",
    "Mouse skull",
    "Platynereis ISH",
    "Platynereis"
]
wnet_names = [
    "WNet3D & post-processing",
    "WNet3D",
    "WNet3D",
    "WNet3D"
]
for i, stats_model, stats_image_only in zip(range(len(model_stats)), model_stats, model_stats_tresh):
    plot_stat_comparison(taus=taus, stats_list=[stats_model, stats_image_only], model_names=[wnet_names[i], names_tresh[i]], metric="IoU", plt_size=(9, 6))
    plt.title(f"{datasets[i]} - {wnet_names[i]} vs {names_tresh[i]}")
    if SAVE_PLOTS_AS_PNG:
        plt.savefig(f"{names_tresh[i]}_{datasets[i]}_f1_comparison.png")
    if SAVE_PLOTS_AS_SVG:
        plt.savefig(f"{names_tresh[i]}_{datasets[i]}_f1_comparison.svg", bbox_inches='tight')
# plot_stat_comparison(taus=taus, stats_list=model_stats+model_stats_images_only, model_names=df_all.Dataset.unique(), metric="IoU")
../_images/f7a0efb21ad384fe14db8f956ffadd140affc57c01d32d65751941fe1bea3101.png ../_images/ba872b293a2253aeca288464c3c0b0456b4e5d98b6f90e048d00777eece7a16c.png ../_images/d4eff47d8913ed4a5d81bd4465d8c7a7018dc010ae7f99d037559dca7bd39c68.png ../_images/de14942064e063def3732e2cba45fd710b7de9f7a089d113275fa1ff853808b5.png

Combined plots#

# Make plots with all models and thresholds : Cellpose, StarDist, WNet3D, Otsu

plot_stat_comparison(taus=taus, stats_list=model_stats+model_stats_others+model_stats_tresh, model_names=names+names_others+names_tresh, metric="IoU", plt_size=(9, 6), title="All models and thresholds")
../_images/80abdb54c3636aa2f9d4f7e9176c2b1674748df0537ba53c062f72d487fc698e.png
# Make plots with all models and thresholds : Cellpose, StarDist, WNet3D, Otsu
datasets = [
    "Mouse Skull with post-processing",
    "Mouse Skull",
    "Platynereis ISH",
    "Platynereis"
]
for i in range(4):
    if i < 2:
        skip = 2
        offset = 0
    elif i < 3:
        skip = 1
        offset = 2
    else:
        skip = 1
        offset = 3
    plot_stat_comparison(taus=taus, stats_list=[model_stats[i]]+model_stats_others[i+offset:i+1+offset+skip:skip]+[model_stats_tresh[i]], model_names=["WNet3D", "Cellpose - pretrained", "StarDist - retrained", "Otsu threshold"], metric="IoU", plt_size=(9, 6), colormap=SIMPLE_COLORMAP)
    plt.title(datasets[i])
    if SAVE_PLOTS_AS_PNG:
        plt.savefig(f"extra_{datasets[i]}_all_f1_comparison.png")
    if SAVE_PLOTS_AS_SVG:
        plt.savefig(f"extra_{datasets[i]}_all_f1_comparison.svg", bbox_inches='tight')
../_images/e39201bcba3d2f72d7df23d4ef38184f62d6dd4effc23ced8ce1dd5fe621780f.png ../_images/2b58b422601615a30694703fdd18e3a207e0f10e8f216ba3a1ea1f902012e0bc.png ../_images/9156696ec640ad36c51e19e50c65dd241263c1b720b42996c4eb0e0e88d6eadc.png ../_images/2efbf426f500c00f160f857d6535750b27c6d6efcb957b264d88fc3ad2aabece.png

Statistical analysis#

from plots import _format_plot

stats_test_dfs = {}
for dataset in dataset_tables.keys():
    test_df = dataset_tables[dataset].copy()
    test_df.drop(columns="Mean", inplace=True)
    test_df.reset_index(inplace=True)
    test_df = test_df.melt(id_vars=["Model"], var_name="Threshold", value_name="F1")
    display(test_df)
    stats_test_dfs[dataset] = test_df
Model Threshold F1
0 CellPose 0.1 0.137
1 CellPose- post-processing 0.1 0.386
2 StarDist 0.1 0.573
3 StarDist- post-processing 0.1 0.689
4 WNet3D 0.1 0.766
... ... ... ...
67 StarDist- post-processing 0.9 0.000
68 WNet3D 0.9 0.033
69 WNet3D - post-processing 0.9 0.073
70 Otsu Threshold 0.9 0.097
71 Otsu Threshold - post-processing 0.9 0.137

72 rows × 3 columns

Model Threshold F1
0 CellPose 0.1 0.896
1 StarDist 0.1 0.841
2 WNet3D 0.1 0.876
3 Otsu Threshold 0.1 0.872
4 CellPose 0.2 0.866
5 StarDist 0.2 0.822
6 WNet3D 0.2 0.856
7 Otsu Threshold 0.2 0.847
8 CellPose 0.3 0.832
9 StarDist 0.3 0.786
10 WNet3D 0.3 0.834
11 Otsu Threshold 0.3 0.817
12 CellPose 0.4 0.778
13 StarDist 0.4 0.686
14 WNet3D 0.4 0.790
15 Otsu Threshold 0.4 0.772
16 CellPose 0.5 0.698
17 StarDist 0.5 0.536
18 WNet3D 0.5 0.729
19 Otsu Threshold 0.5 0.706
20 CellPose 0.6 0.576
21 StarDist 0.6 0.326
22 WNet3D 0.6 0.632
23 Otsu Threshold 0.6 0.605
24 CellPose 0.7 0.362
25 StarDist 0.7 0.110
26 WNet3D 0.7 0.492
27 Otsu Threshold 0.7 0.474
28 CellPose 0.8 0.117
29 StarDist 0.8 0.011
30 WNet3D 0.8 0.249
31 Otsu Threshold 0.8 0.246
32 CellPose 0.9 0.010
33 StarDist 0.9 0.000
34 WNet3D 0.9 0.034
35 Otsu Threshold 0.9 0.026
Model Threshold F1
0 CellPose 0.1 0.691
1 StarDist 0.1 0.850
2 WNet3D 0.1 0.838
3 Otsu Threshold 0.1 0.798
4 CellPose 0.2 0.663
5 StarDist 0.2 0.833
6 WNet3D 0.2 0.808
7 Otsu Threshold 0.2 0.773
8 CellPose 0.3 0.624
9 StarDist 0.3 0.803
10 WNet3D 0.3 0.778
11 Otsu Threshold 0.3 0.733
12 CellPose 0.4 0.594
13 StarDist 0.4 0.764
14 WNet3D 0.4 0.739
15 Otsu Threshold 0.4 0.702
16 CellPose 0.5 0.553
17 StarDist 0.5 0.700
18 WNet3D 0.5 0.695
19 Otsu Threshold 0.5 0.663
20 CellPose 0.6 0.497
21 StarDist 0.6 0.611
22 WNet3D 0.6 0.617
23 Otsu Threshold 0.6 0.590
24 CellPose 0.7 0.417
25 StarDist 0.7 0.492
26 WNet3D 0.7 0.512
27 Otsu Threshold 0.7 0.507
28 CellPose 0.8 0.290
29 StarDist 0.8 0.272
30 WNet3D 0.8 0.338
31 Otsu Threshold 0.8 0.336
32 CellPose 0.9 0.062
33 StarDist 0.9 0.019
34 WNet3D 0.9 0.059
35 Otsu Threshold 0.9 0.077
from plots import _format_plot
from scipy.stats import kruskal
from scikit_posthocs import posthoc_conover
import matplotlib.colors as colors
import matplotlib.patches as mpatches

def plot_boxplot_and_test(test_df, dataset):
    fig, ax = plt.subplots(figsize=(8, 8), dpi=DPI)
    sns.boxplot(data=test_df, hue="Model", y="F1",
                palette="tab10", ax=ax)
    _format_plot(ax, xlabel="Model", ylabel="F1", title=f"F1 distribution for {dataset}")
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    plt.show()
    print(f"Kruskal-Wallis H test for {dataset}")
    print(kruskal(*test_df.groupby("Model")["F1"].apply(list)))

posthoc_dfs = {}
for dataset in stats_test_dfs.keys():
    if dataset == "Mouse skull":
        test_df = stats_test_dfs[dataset].copy()
        df_postproc = test_df[test_df["Model"].str.contains("post-processing")]
        df_no_postproc = test_df[~test_df["Model"].str.contains("post-processing")]
        plot_boxplot_and_test(df_postproc, dataset+ " - post-processing")
        plot_boxplot_and_test(df_no_postproc, dataset)
        posthoc_dfs["Mouse skull"] = df_no_postproc
        posthoc_dfs["Mouse skull - post-processing"] = df_postproc
    else:
        test_df = stats_test_dfs[dataset].copy()
        plot_boxplot_and_test(test_df, dataset)
        posthoc_dfs[dataset] = test_df
../_images/cefbbd94e75f5c2d64024296a5b7f510ec291b0de7bf14abeeedd91fb67627c3.png
Kruskal-Wallis H test for Mouse skull - post-processing
KruskalResult(statistic=10.12823263531297, pvalue=0.017506978718827077)
../_images/2807d0e392657384b48b6e1ef1176ea71b4df64aaacd248699fb110d6125a409.png
Kruskal-Wallis H test for Mouse skull
KruskalResult(statistic=15.782311465797115, pvalue=0.0012566686698318734)
../_images/e4b7ef606b424e43c1d299f1e875489a2d454fb137362933861801108c30be97.png
Kruskal-Wallis H test for Platynereis ISH
KruskalResult(statistic=1.6026026026026017, pvalue=0.6587998410905409)
../_images/dfdf84ba984eb5b03ff56435f7e9b9f8ad5ac11e54c930c43eb56630867fe66a.png
Kruskal-Wallis H test for Platynereis
KruskalResult(statistic=3.0549477267201435, pvalue=0.3832306625105366)
from scikit_posthocs import posthoc_conover
import matplotlib.colors as colors
import matplotlib.patches as mpatches
for dataset in posthoc_dfs.keys():
    fig, ax = plt.subplots(figsize=(8, 8), dpi=DPI)
    results = posthoc_conover(posthoc_dfs[dataset], val_col="F1", group_col="Model", p_adjust="Holm")

    low_color = COLORMAP[0]
    mid_color = COLORMAP[1]
    high_color = COLORMAP[2]
    equals_1_color = COLORMAP[3]

    levels = [0, 0.05, 0.06, 0.99, 1]
    colors_list = [low_color, mid_color, high_color, high_color, equals_1_color]
    cmap = colors.LinearSegmentedColormap.from_list("", list(zip(levels, colors_list)))
    norm = colors.Normalize(vmin=0, vmax=1)

    diag = np.diag_indices(results.shape[0])
    results.values[diag] = np.nan
    sns.heatmap(results, annot=True, cmap=cmap, norm=norm, cbar=False, fmt=".4f", linewidths=0.5, ax=ax)

    patches = [mpatches.Patch(color=low_color, label='Above 0 (significant)'),
            mpatches.Patch(color=mid_color, label='Below 0.05 (significant)'),
            mpatches.Patch(color=high_color, label='Above 0.05 (not significant)'),
            mpatches.Patch(color=equals_1_color, label='Equals 1 (not significant)')]

    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    # transparent background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
../_images/57a17903869c5f8520756fe20ad607286bbf18a104e0144ffc02a05f6fb15f82.png ../_images/051a34fc0065075f3af5a4dc96312867ad600973b2998028c86896e7c5a49309.png ../_images/c748bc9c56329d57703ef1e084ca8a8b70eb0fda3f51aa934c6bd9bd2bf56157.png ../_images/6ea7a03d6b460472526f98958b58e94f5882a848eb5189308b0839db086900b9.png