Edge-Enhanced Attention GCN#

Edge-Enhanced Attention GNN definition#

We would now like to make use of edge features in our model, since so far we have been ignoring these additional predictors entirely.

To this end, we will use a custom approach inspired by the Exploiting Edge Features for Graph Neural Networks paper by Gong and Cheng, 2019.

We will adapt their attention-based edge-enhanced neural network, termed EGNN(A), by redefining our attention layer in the following way:

  • Compute the attention score \(\alpha^l_{ijp}\) separately for each edge feature dimension \(p\) such that:

    \(\alpha^l_{ijp} = f^l(X^{l-1}_i, X^{l-1}_j) \cdot E^{l-1}_{ijp}\)

    In this, \(f^l\) is our previously used (scalar) attention score and \(X^{l-1}\) is the output of the previous layer.

    Here we simply multiply this scalar score with each channel of the edge feature tensor \(E\) of shape \(N \times N \times E_{feat}\).

  • Replace their proposed doubly stochastic normalization with Group normalization for simplicity. We only have to ensure that the channel dimension corresponds to the edge features before normalizing; the number of groups is treated as a hyperparameter.

  • The obtained edge-enhanced attention score \(\alpha^l\) of shape \(N \times N \times E_{feat}\) is then used as the edge features \(E^l\) in the next layer, as they suggest, i.e. \(E^{l} = \alpha^l\).

  • Then, not unlike multi-head attention (but with one head per edge feature rather than several randomly initialized heads over the same input), we concatenate along the edge feature dimension and feed the result to a fully-connected layer.

    With this we obtain \(A^l\) of shape \(B \times N \times N\) (with \(B\) the batch size):

    \(A^l = \Big( \big\Vert_{p=1}^{E_{feat}} \alpha^l_{:,:,p} \Big) \, a\)

    where \(\Vert\) is the concatenation operator, applied along the second node dimension so that the concatenated scores have shape \(N \times (E_{feat} N)\), and \(a\) is a learnable weight of shape \((E_{feat} N) \times N\).

  • Finally, the output is obtained by multiplying \(A^l\) with the support \(X^{l-1} W^l\) and applying the activation function:

    \(X^l = \sigma (A^l X^{l-1} W^l)\)

  • We return \(X^l\) and \(E^l\), to be used by the next layer. A minimal code sketch of such a layer is given below.
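For illustration, here is a minimal sketch of such a layer. It is not the EdgeConv module from model.py that is actually used below: the class name, the GAT-style pairwise score standing in for \(f^l\), and the concatenation order are assumptions made for this example (the EdgeConv module printed further down also contains a softmax and its own activation choices, omitted here).

import torch
from torch import nn


class EdgeAttentionLayerSketch(nn.Module):
    """Illustrative sketch of the edge-enhanced attention layer described above."""

    def __init__(self, in_dim, out_dim, n_nodes, edge_feat, num_groups=4):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)  # support weights W^l
        self.score = nn.Linear(2 * out_dim, 1, bias=False)  # stand-in for f^l
        self.att_act = nn.LeakyReLU(0.1)
        self.edge_layer = nn.Linear(edge_feat * n_nodes, n_nodes, bias=False)  # weight a
        self.norm = nn.GroupNorm(num_groups, edge_feat)  # replaces doubly stochastic norm
        self.act = nn.LeakyReLU()

    def forward(self, X, E):
        # X: (B, N, F_in) node features, E: (B, N, N, P) edge features
        B, N, _ = X.shape
        support = self.W(X)  # X^{l-1} W^l, shape (B, N, F_out)
        # pairwise scalar scores f^l(X_i, X_j)
        pairs = torch.cat(
            [
                support.unsqueeze(2).expand(B, N, N, -1),
                support.unsqueeze(1).expand(B, N, N, -1),
            ],
            dim=-1,
        )
        f = self.att_act(self.score(pairs))  # (B, N, N, 1)
        alpha = f * E  # one score per edge feature channel, (B, N, N, P)
        # GroupNorm expects the channel dim first, so move the edge features to dim 1
        alpha = self.norm(alpha.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        E_next = alpha  # E^l = alpha^l, passed on to the next layer
        # concatenate over the edge feature dimension and project back to (B, N, N)
        A = self.edge_layer(alpha.reshape(B, N, N * alpha.shape[-1]))
        X_next = self.act(torch.bmm(A, support))  # X^l = sigma(A^l X^{l-1} W^l)
        return X_next, E_next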

Validating our model#

To show that our proposed architecture is indeed an improvement, we perform a few simple ablation checks to verify that the edge features and the Group normalization actually help.

We will compare performance in the following settings (a rough code sketch follows the list):

  • Setting the edge feature matrix \(E^l\) to ones

  • Setting the edge-enhanced attention matrix \(A^l\) to ones before multiplying with the support

  • Removing the Group normalization
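As a rough illustration only, the three checks could be wired up as below using the sketch layer defined earlier. The switches and example sizes are hypothetical and do not correspond to options of the actual EdgeAttentionGCN model; the real ablation runs are reported on the hyperparameter tuning page.

import torch
from torch import nn

# example sizes only: batch, nodes, edge features, node feature dims
B, N, P, F_in, F_out = 2, 28, 4, 7, 16
layer = EdgeAttentionLayerSketch(F_in, F_out, n_nodes=N, edge_feat=P)
X = torch.randn(B, N, F_in)
E = torch.randn(B, N, N, P)

# 1) edge features set to ones: alpha collapses to the plain scalar attention score
X1, _ = layer(X, torch.ones_like(E))

# 2) attention matrix set to ones: aggregate the support with an all-ones A^l
X2 = layer.act(torch.bmm(torch.ones(B, N, N), layer.W(X)))

# 3) Group normalization removed
layer.norm = nn.Identity()
X3, _ = layer(X, E)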

Results can be seen in Hyperparameter tuning for Edge-Enhanced Attention GCN.

import warnings
import os
from sys import path

warnings.filterwarnings(
    "ignore"
)  # ignore warnings from missing deterministic implementation. These are strictly from wandb and do not affect the reproducibility of the actual runs.
os.environ[
    "WANDB_NOTEBOOK_NAME"
] = "EdgeAttentionGCN.ipynb"  # name the notebook for wandb tracking
path.append("../code")
%load_ext autoreload
%autoreload 2
import torch
from torch import nn
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sns
import model as m
import training as t
import utils
import logging
import sys
from utils import LOG as logger

logging.basicConfig(
    format="%(message)s", level=logging.INFO, stream=sys.stdout
)
logger.setLevel(logging.INFO)
# logger.setLevel(logging.DEBUG)

t.WANDB_MODE = "disabled"

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataloaders, dataset = t.create_dataloaders(
    batch_size=20,
    use_edge_features=True,
    # train_split=0.5,
)
Length of train set: 131
Length of validation set: 28
Length of test set: 29

Training#

Model#

node_features = dataset.node_features.shape[2]
conv_dims = [256, 256, 128, 64]
model = m.EdgeAttentionGCN(
    num_features=node_features,
    conv_dims=conv_dims,
    # fcn_layers=[128],
    activation=nn.LeakyReLU(),
    dropout=0.3,
    pooling="max",
    norm=nn.BatchNorm1d,
)

# utils._print_gradient_hook(model)


print(model)
Initialized model with 4 graph conv layers
Initialized model with 1 fully connected layers
EdgeAttentionGCN(
  (convs_layers): ModuleList(
    (0): EdgeConv(
      (weight): Linear(in_features=7, out_features=256, bias=False)
      (S): Linear(in_features=512, out_features=28, bias=False)
      (edge_layer): Linear(in_features=112, out_features=28, bias=False)
      (activation): LeakyReLU(negative_slope=0.01)
      (att_activation): LeakyReLU(negative_slope=0.1)
      (softmax): Softmax(dim=1)
      (instance_norm): GroupNorm(4, 4, eps=1e-05, affine=True)
    )
    (1): EdgeConv(
      (weight): Linear(in_features=256, out_features=256, bias=False)
      (S): Linear(in_features=512, out_features=28, bias=False)
      (edge_layer): Linear(in_features=112, out_features=28, bias=False)
      (activation): LeakyReLU(negative_slope=0.01)
      (att_activation): LeakyReLU(negative_slope=0.1)
      (softmax): Softmax(dim=1)
      (instance_norm): GroupNorm(4, 4, eps=1e-05, affine=True)
    )
    (2): EdgeConv(
      (weight): Linear(in_features=256, out_features=128, bias=False)
      (S): Linear(in_features=256, out_features=28, bias=False)
      (edge_layer): Linear(in_features=112, out_features=28, bias=False)
      (activation): LeakyReLU(negative_slope=0.01)
      (att_activation): LeakyReLU(negative_slope=0.1)
      (softmax): Softmax(dim=1)
      (instance_norm): GroupNorm(4, 4, eps=1e-05, affine=True)
    )
    (3): EdgeConv(
      (weight): Linear(in_features=128, out_features=64, bias=False)
      (S): Linear(in_features=128, out_features=28, bias=False)
      (edge_layer): Linear(in_features=112, out_features=28, bias=False)
      (activation): Identity()
      (att_activation): LeakyReLU(negative_slope=0.1)
      (softmax): Softmax(dim=1)
      (instance_norm): GroupNorm(4, 4, eps=1e-05, affine=True)
    )
  )
  (batch_norms): ModuleList(
    (0-1): 2 x BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Identity()
  )
  (fcn_layers): ModuleList(
    (0): Linear(in_features=64, out_features=1, bias=True)
  )
  (dropout): Dropout(p=0.3, inplace=False)
  (pooling): MaxPooling()
)

Training parameters#

epochs = 200 # this model converges much faster than the others
learning_rate = 1e-4 # it also benefits from a lower learning rate

model.to(DEVICE)

label_counts = dataset[:]["class_y"].unique(return_counts=True)[1]
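# ratio of negative to positive samples, used to re-balance the classes in BCEWithLogitsLoss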
pos_weight = label_counts[0] / label_counts[1]

loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

history = {
    "epoch": 0,
    "loss": [],
    "acc": [],
    "val-roc": [],
    "val-ap": [],
}

Training loop#

history = t.train_loop(
    history=history,
    train_dataloader=dataloaders[0],
    val_dataloader=dataloaders[1],
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=epochs,
    device=DEVICE,
    use_scheduler=False,
    test_dataloader=dataloaders[2],
    use_edges=True,
)
Watching
Epoch   1/200:Epoch loss: 0.8547 - avg acc: 30.5% - val-roc: 0.7551 - val-ap: 0.7927 (0.5s/epoch)
Epoch   2/200:Epoch loss: 0.8229 - avg acc: 30.5% - val-roc: 0.8367 - val-ap: 0.8719 (0.2s/epoch)
Epoch   3/200:Epoch loss: 0.6942 - avg acc: 30.5% - val-roc: 0.8571 - val-ap: 0.9011 (0.2s/epoch)
Epoch   4/200:Epoch loss: 0.6281 - avg acc: 30.5% - val-roc: 0.8520 - val-ap: 0.9025 (0.2s/epoch)
Epoch   5/200:Epoch loss: 0.5854 - avg acc: 30.5% - val-roc: 0.8367 - val-ap: 0.8865 (0.2s/epoch)
Epoch   6/200:Epoch loss: 0.4621 - avg acc: 30.5% - val-roc: 0.7857 - val-ap: 0.8465 (0.2s/epoch)
Epoch   7/200:Epoch loss: 0.4798 - avg acc: 30.5% - val-roc: 0.7398 - val-ap: 0.8151 (0.2s/epoch)
Epoch   8/200:Epoch loss: 0.4373 - avg acc: 31.3% - val-roc: 0.7347 - val-ap: 0.8136 (0.2s/epoch)
Epoch   9/200:Epoch loss: 0.3781 - avg acc: 32.1% - val-roc: 0.7347 - val-ap: 0.8104 (0.2s/epoch)
Epoch  10/200:Epoch loss: 0.4047 - avg acc: 33.6% - val-roc: 0.7857 - val-ap: 0.8379 (0.2s/epoch)
Epoch  11/200:Epoch loss: 0.3483 - avg acc: 34.4% - val-roc: 0.7857 - val-ap: 0.8274 (0.2s/epoch)
Epoch  12/200:Epoch loss: 0.3539 - avg acc: 37.4% - val-roc: 0.7959 - val-ap: 0.8267 (0.2s/epoch)
Epoch  13/200:Epoch loss: 0.3533 - avg acc: 36.6% - val-roc: 0.7908 - val-ap: 0.8148 (0.2s/epoch)
Epoch  14/200:Epoch loss: 0.3279 - avg acc: 37.4% - val-roc: 0.8061 - val-ap: 0.8228 (0.2s/epoch)
Epoch  15/200:Epoch loss: 0.3074 - avg acc: 36.6% - val-roc: 0.8061 - val-ap: 0.8265 (0.2s/epoch)
Epoch  16/200:Epoch loss: 0.3455 - avg acc: 36.6% - val-roc: 0.8112 - val-ap: 0.8211 (0.2s/epoch)
Epoch  17/200:Epoch loss: 0.2500 - avg acc: 37.4% - val-roc: 0.8112 - val-ap: 0.8106 (0.2s/epoch)
Epoch  18/200:Epoch loss: 0.3071 - avg acc: 37.4% - val-roc: 0.8112 - val-ap: 0.8145 (0.2s/epoch)
Epoch  19/200:Epoch loss: 0.3333 - avg acc: 37.4% - val-roc: 0.8112 - val-ap: 0.8147 (0.2s/epoch)
Epoch  20/200:Epoch loss: 0.2852 - avg acc: 37.4% - val-roc: 0.8163 - val-ap: 0.7983 (0.2s/epoch)
Epoch  21/200:Epoch loss: 0.2973 - avg acc: 37.4% - val-roc: 0.8010 - val-ap: 0.7767 (0.2s/epoch)
Epoch  22/200:Epoch loss: 0.2527 - avg acc: 40.5% - val-roc: 0.7908 - val-ap: 0.7807 (0.2s/epoch)
Epoch  23/200:Epoch loss: 0.2623 - avg acc: 41.2% - val-roc: 0.7959 - val-ap: 0.8056 (0.2s/epoch)
Epoch  24/200:Epoch loss: 0.2349 - avg acc: 42.0% - val-roc: 0.7959 - val-ap: 0.8080 (0.2s/epoch)
Epoch  25/200:Epoch loss: 0.2452 - avg acc: 44.3% - val-roc: 0.8010 - val-ap: 0.8157 (0.2s/epoch)
Epoch  26/200:Epoch loss: 0.2307 - avg acc: 43.5% - val-roc: 0.8061 - val-ap: 0.8159 (0.2s/epoch)
Epoch  27/200:Epoch loss: 0.2093 - avg acc: 45.0% - val-roc: 0.8112 - val-ap: 0.8257 (0.2s/epoch)
Epoch  28/200:Epoch loss: 0.2342 - avg acc: 46.6% - val-roc: 0.8163 - val-ap: 0.8333 (0.2s/epoch)
Epoch  29/200:Epoch loss: 0.1984 - avg acc: 46.6% - val-roc: 0.8163 - val-ap: 0.8277 (0.2s/epoch)
Epoch  30/200:Epoch loss: 0.1390 - avg acc: 46.6% - val-roc: 0.8163 - val-ap: 0.8156 (0.2s/epoch)
Epoch  31/200:Epoch loss: 0.1583 - avg acc: 47.3% - val-roc: 0.8316 - val-ap: 0.8359 (0.2s/epoch)
Epoch  32/200:Epoch loss: 0.1273 - avg acc: 47.3% - val-roc: 0.8316 - val-ap: 0.8244 (0.2s/epoch)
Epoch  33/200:Epoch loss: 0.2196 - avg acc: 50.4% - val-roc: 0.8469 - val-ap: 0.8433 (0.2s/epoch)
Epoch  34/200:Epoch loss: 0.1588 - avg acc: 51.9% - val-roc: 0.8571 - val-ap: 0.8471 (0.2s/epoch)
Epoch  35/200:Epoch loss: 0.1904 - avg acc: 51.9% - val-roc: 0.8622 - val-ap: 0.8493 (0.3s/epoch)
Epoch  36/200:Epoch loss: 0.1403 - avg acc: 51.1% - val-roc: 0.8724 - val-ap: 0.8636 (0.2s/epoch)
Epoch  37/200:Epoch loss: 0.1701 - avg acc: 51.9% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  38/200:Epoch loss: 0.2017 - avg acc: 51.9% - val-roc: 0.8776 - val-ap: 0.8656 (0.2s/epoch)
Epoch  39/200:Epoch loss: 0.1411 - avg acc: 51.9% - val-roc: 0.8776 - val-ap: 0.8656 (0.2s/epoch)
Epoch  40/200:Epoch loss: 0.1241 - avg acc: 54.2% - val-roc: 0.8878 - val-ap: 0.8768 (0.2s/epoch)
Epoch  41/200:Epoch loss: 0.1325 - avg acc: 55.0% - val-roc: 0.8929 - val-ap: 0.8809 (0.2s/epoch)
Epoch  42/200:Epoch loss: 0.1458 - avg acc: 55.0% - val-roc: 0.8929 - val-ap: 0.8809 (0.2s/epoch)
Epoch  43/200:Epoch loss: 0.0882 - avg acc: 56.5% - val-roc: 0.8929 - val-ap: 0.8809 (0.2s/epoch)
Epoch  44/200:Epoch loss: 0.1151 - avg acc: 55.7% - val-roc: 0.8878 - val-ap: 0.8732 (0.2s/epoch)
Epoch  45/200:Epoch loss: 0.1482 - avg acc: 58.0% - val-roc: 0.8929 - val-ap: 0.8809 (0.2s/epoch)
Epoch  46/200:Epoch loss: 0.0878 - avg acc: 58.0% - val-roc: 0.8929 - val-ap: 0.8809 (0.2s/epoch)
Epoch  47/200:Epoch loss: 0.1028 - avg acc: 58.8% - val-roc: 0.8980 - val-ap: 0.8843 (0.2s/epoch)
Epoch  48/200:Epoch loss: 0.0999 - avg acc: 60.3% - val-roc: 0.8980 - val-ap: 0.8843 (0.2s/epoch)
Epoch  49/200:Epoch loss: 0.0654 - avg acc: 58.8% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  50/200:Epoch loss: 0.1199 - avg acc: 60.3% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  51/200:Epoch loss: 0.1378 - avg acc: 63.4% - val-roc: 0.9082 - val-ap: 0.8945 (0.2s/epoch)
Epoch  52/200:Epoch loss: 0.0923 - avg acc: 65.6% - val-roc: 0.9082 - val-ap: 0.8945 (0.2s/epoch)
Epoch  53/200:Epoch loss: 0.0485 - avg acc: 67.2% - val-roc: 0.9082 - val-ap: 0.8945 (0.2s/epoch)
Epoch  54/200:Epoch loss: 0.0706 - avg acc: 66.4% - val-roc: 0.9133 - val-ap: 0.8984 (0.2s/epoch)
Epoch  55/200:Epoch loss: 0.0468 - avg acc: 68.7% - val-roc: 0.9082 - val-ap: 0.8914 (0.2s/epoch)
Epoch  56/200:Epoch loss: 0.0991 - avg acc: 69.5% - val-roc: 0.9031 - val-ap: 0.8838 (0.2s/epoch)
Epoch  57/200:Epoch loss: 0.1314 - avg acc: 69.5% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  58/200:Epoch loss: 0.0765 - avg acc: 71.8% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  59/200:Epoch loss: 0.0730 - avg acc: 69.5% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  60/200:Epoch loss: 0.0632 - avg acc: 71.8% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  61/200:Epoch loss: 0.0682 - avg acc: 73.3% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  62/200:Epoch loss: 0.0533 - avg acc: 75.6% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  63/200:Epoch loss: 0.0880 - avg acc: 74.8% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  64/200:Epoch loss: 0.0847 - avg acc: 74.8% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  65/200:Epoch loss: 0.0609 - avg acc: 77.9% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  66/200:Epoch loss: 0.0503 - avg acc: 79.4% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  67/200:Epoch loss: 0.0408 - avg acc: 81.7% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  68/200:Epoch loss: 0.0571 - avg acc: 83.2% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  69/200:Epoch loss: 0.0463 - avg acc: 86.3% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  70/200:Epoch loss: 0.0654 - avg acc: 87.8% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  71/200:Epoch loss: 0.0577 - avg acc: 87.8% - val-roc: 0.8980 - val-ap: 0.8799 (0.2s/epoch)
Epoch  72/200:Epoch loss: 0.0408 - avg acc: 87.0% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  73/200:Epoch loss: 0.0389 - avg acc: 86.3% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  74/200:Epoch loss: 0.0258 - avg acc: 86.3% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  75/200:Epoch loss: 0.0276 - avg acc: 86.3% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  76/200:Epoch loss: 0.0512 - avg acc: 87.8% - val-roc: 0.9031 - val-ap: 0.8876 (0.2s/epoch)
Epoch  77/200:Epoch loss: 0.0405 - avg acc: 85.5% - val-roc: 0.8980 - val-ap: 0.8843 (0.2s/epoch)
Epoch  78/200:Epoch loss: 0.0409 - avg acc: 87.0% - val-roc: 0.8878 - val-ap: 0.8726 (0.2s/epoch)
Epoch  79/200:Epoch loss: 0.0296 - avg acc: 87.0% - val-roc: 0.8827 - val-ap: 0.8696 (0.2s/epoch)
Epoch  80/200:Epoch loss: 0.0304 - avg acc: 88.5% - val-roc: 0.8827 - val-ap: 0.8696 (0.2s/epoch)
Epoch  81/200:Epoch loss: 0.0387 - avg acc: 90.1% - val-roc: 0.8878 - val-ap: 0.8726 (0.2s/epoch)
Epoch  82/200:Epoch loss: 0.0263 - avg acc: 90.8% - val-roc: 0.8776 - val-ap: 0.8670 (0.2s/epoch)
Epoch  83/200:Epoch loss: 0.0204 - avg acc: 90.1% - val-roc: 0.8878 - val-ap: 0.8737 (0.2s/epoch)
Epoch  84/200:Epoch loss: 0.0411 - avg acc: 90.8% - val-roc: 0.8827 - val-ap: 0.8696 (0.2s/epoch)
Epoch  85/200:Epoch loss: 0.0329 - avg acc: 90.8% - val-roc: 0.8827 - val-ap: 0.8711 (0.2s/epoch)
Epoch  86/200:Epoch loss: 0.0386 - avg acc: 91.6% - val-roc: 0.8827 - val-ap: 0.8711 (0.2s/epoch)
Epoch  87/200:Epoch loss: 0.0520 - avg acc: 91.6% - val-roc: 0.8929 - val-ap: 0.8766 (0.2s/epoch)
Epoch  88/200:Epoch loss: 0.0240 - avg acc: 90.1% - val-roc: 0.9082 - val-ap: 0.8951 (0.2s/epoch)
Epoch  89/200:Epoch loss: 0.0224 - avg acc: 90.1% - val-roc: 0.9031 - val-ap: 0.8832 (0.2s/epoch)
Epoch  90/200:Epoch loss: 0.0585 - avg acc: 90.1% - val-roc: 0.9031 - val-ap: 0.8832 (0.2s/epoch)
Epoch  91/200:Epoch loss: 0.0255 - avg acc: 90.1% - val-roc: 0.8980 - val-ap: 0.8793 (0.2s/epoch)
Epoch  92/200:Epoch loss: 0.0253 - avg acc: 90.8% - val-roc: 0.8980 - val-ap: 0.8763 (0.2s/epoch)
Epoch  93/200:Epoch loss: 0.0780 - avg acc: 90.8% - val-roc: 0.8827 - val-ap: 0.8606 (0.2s/epoch)
Epoch  94/200:Epoch loss: 0.0135 - avg acc: 90.8% - val-roc: 0.8724 - val-ap: 0.8529 (0.2s/epoch)
Epoch  95/200:Epoch loss: 0.0269 - avg acc: 91.6% - val-roc: 0.8673 - val-ap: 0.8458 (0.2s/epoch)
Epoch  96/200:Epoch loss: 0.0454 - avg acc: 92.4% - val-roc: 0.8724 - val-ap: 0.8543 (0.2s/epoch)
Epoch  97/200:Epoch loss: 0.0126 - avg acc: 93.9% - val-roc: 0.8724 - val-ap: 0.8543 (0.2s/epoch)
Epoch  98/200:Epoch loss: 0.0169 - avg acc: 93.9% - val-roc: 0.8673 - val-ap: 0.8517 (0.2s/epoch)
Epoch  99/200:Epoch loss: 0.0254 - avg acc: 93.9% - val-roc: 0.8571 - val-ap: 0.8438 (0.2s/epoch)
Epoch  100/200:Epoch loss: 0.0154 - avg acc: 94.7% - val-roc: 0.8622 - val-ap: 0.8481 (0.2s/epoch)
Epoch  101/200:Epoch loss: 0.0211 - avg acc: 94.7% - val-roc: 0.8622 - val-ap: 0.8481 (0.2s/epoch)
Epoch  102/200:Epoch loss: 0.0142 - avg acc: 95.4% - val-roc: 0.8622 - val-ap: 0.8415 (0.2s/epoch)
Epoch  103/200:Epoch loss: 0.0343 - avg acc: 94.7% - val-roc: 0.8673 - val-ap: 0.8400 (0.2s/epoch)
Epoch  104/200:Epoch loss: 0.0085 - avg acc: 93.9% - val-roc: 0.8520 - val-ap: 0.8249 (0.2s/epoch)
Epoch  105/200:Epoch loss: 0.0303 - avg acc: 92.4% - val-roc: 0.8673 - val-ap: 0.8344 (0.2s/epoch)
Epoch  106/200:Epoch loss: 0.0153 - avg acc: 93.1% - val-roc: 0.8673 - val-ap: 0.8344 (0.2s/epoch)
Epoch  107/200:Epoch loss: 0.0279 - avg acc: 93.9% - val-roc: 0.8673 - val-ap: 0.8344 (0.2s/epoch)
Epoch  108/200:Epoch loss: 0.0131 - avg acc: 93.1% - val-roc: 0.8673 - val-ap: 0.8344 (0.2s/epoch)
Epoch  109/200:Epoch loss: 0.0053 - avg acc: 93.1% - val-roc: 0.8776 - val-ap: 0.8464 (0.2s/epoch)
Epoch  110/200:Epoch loss: 0.0144 - avg acc: 93.9% - val-roc: 0.8776 - val-ap: 0.8464 (0.2s/epoch)
Epoch  111/200:Epoch loss: 0.0063 - avg acc: 94.7% - val-roc: 0.8673 - val-ap: 0.8335 (0.2s/epoch)
Epoch  112/200:Epoch loss: 0.0169 - avg acc: 93.9% - val-roc: 0.8673 - val-ap: 0.8283 (0.2s/epoch)
Epoch  113/200:Epoch loss: 0.0164 - avg acc: 93.1% - val-roc: 0.8673 - val-ap: 0.8283 (0.2s/epoch)
Epoch  114/200:Epoch loss: 0.0258 - avg acc: 94.7% - val-roc: 0.8724 - val-ap: 0.8379 (0.2s/epoch)
Epoch  115/200:Epoch loss: 0.0053 - avg acc: 94.7% - val-roc: 0.8673 - val-ap: 0.8344 (0.2s/epoch)
Epoch  116/200:Epoch loss: 0.0048 - avg acc: 96.2% - val-roc: 0.8673 - val-ap: 0.8344 (0.2s/epoch)
Epoch  117/200:Epoch loss: 0.0114 - avg acc: 96.9% - val-roc: 0.8622 - val-ap: 0.8309 (0.2s/epoch)
Epoch  118/200:Epoch loss: 0.0141 - avg acc: 97.7% - val-roc: 0.8571 - val-ap: 0.8213 (0.2s/epoch)
Epoch  119/200:Epoch loss: 0.0111 - avg acc: 96.9% - val-roc: 0.8571 - val-ap: 0.8213 (0.2s/epoch)
Epoch  120/200:Epoch loss: 0.0118 - avg acc: 96.9% - val-roc: 0.8622 - val-ap: 0.8309 (0.2s/epoch)
Epoch  121/200:Epoch loss: 0.0071 - avg acc: 96.2% - val-roc: 0.8622 - val-ap: 0.8309 (0.2s/epoch)
Epoch  122/200:Epoch loss: 0.0080 - avg acc: 96.2% - val-roc: 0.8673 - val-ap: 0.8507 (0.2s/epoch)
Epoch  123/200:Epoch loss: 0.0166 - avg acc: 93.9% - val-roc: 0.8827 - val-ap: 0.8649 (0.2s/epoch)
Epoch  124/200:Epoch loss: 0.0119 - avg acc: 93.9% - val-roc: 0.8878 - val-ap: 0.8683 (0.2s/epoch)
Epoch  125/200:Epoch loss: 0.0115 - avg acc: 93.9% - val-roc: 0.8980 - val-ap: 0.8785 (0.2s/epoch)
Epoch  126/200:Epoch loss: 0.0236 - avg acc: 94.7% - val-roc: 0.8980 - val-ap: 0.8756 (0.2s/epoch)
Epoch  127/200:Epoch loss: 0.0164 - avg acc: 95.4% - val-roc: 0.8980 - val-ap: 0.8843 (0.2s/epoch)
Epoch  128/200:Epoch loss: 0.0138 - avg acc: 96.2% - val-roc: 0.8980 - val-ap: 0.8843 (0.3s/epoch)
Epoch  129/200:Epoch loss: 0.0053 - avg acc: 96.2% - val-roc: 0.8929 - val-ap: 0.8814 (0.2s/epoch)
Epoch  130/200:Epoch loss: 0.0055 - avg acc: 96.2% - val-roc: 0.8929 - val-ap: 0.8814 (0.2s/epoch)
Epoch  131/200:Epoch loss: 0.0133 - avg acc: 96.2% - val-roc: 0.8929 - val-ap: 0.8814 (0.2s/epoch)
Epoch  132/200:Epoch loss: 0.0093 - avg acc: 96.9% - val-roc: 0.8827 - val-ap: 0.8749 (0.2s/epoch)
Epoch  133/200:Epoch loss: 0.0099 - avg acc: 97.7% - val-roc: 0.8776 - val-ap: 0.8673 (0.2s/epoch)
Epoch  134/200:Epoch loss: 0.0272 - avg acc: 96.9% - val-roc: 0.8776 - val-ap: 0.8695 (0.2s/epoch)
Epoch  135/200:Epoch loss: 0.0057 - avg acc: 97.7% - val-roc: 0.8724 - val-ap: 0.8650 (0.2s/epoch)
Epoch  136/200:Epoch loss: 0.0075 - avg acc: 97.7% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  137/200:Epoch loss: 0.0014 - avg acc: 97.7% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  138/200:Epoch loss: 0.0549 - avg acc: 97.7% - val-roc: 0.8980 - val-ap: 0.8852 (0.2s/epoch)
Epoch  139/200:Epoch loss: 0.0054 - avg acc: 96.9% - val-roc: 0.8929 - val-ap: 0.8814 (0.2s/epoch)
Epoch  140/200:Epoch loss: 0.0057 - avg acc: 96.9% - val-roc: 0.8878 - val-ap: 0.8773 (0.2s/epoch)
Epoch  141/200:Epoch loss: 0.0154 - avg acc: 97.7% - val-roc: 0.8827 - val-ap: 0.8703 (0.2s/epoch)
Epoch  142/200:Epoch loss: 0.0053 - avg acc: 98.5% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  143/200:Epoch loss: 0.0063 - avg acc: 98.5% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  144/200:Epoch loss: 0.0134 - avg acc: 99.2% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  145/200:Epoch loss: 0.0041 - avg acc: 100.0% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  146/200:Epoch loss: 0.0113 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8619 (0.2s/epoch)
Epoch  147/200:Epoch loss: 0.0092 - avg acc: 100.0% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  148/200:Epoch loss: 0.0070 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8626 (0.2s/epoch)
Epoch  149/200:Epoch loss: 0.0064 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8600 (0.2s/epoch)
Epoch  150/200:Epoch loss: 0.0022 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8626 (0.2s/epoch)
Epoch  151/200:Epoch loss: 0.0045 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8626 (0.2s/epoch)
Epoch  152/200:Epoch loss: 0.0055 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8596 (0.2s/epoch)
Epoch  153/200:Epoch loss: 0.0035 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8596 (0.2s/epoch)
Epoch  154/200:Epoch loss: 0.0074 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8513 (0.2s/epoch)
Epoch  155/200:Epoch loss: 0.0055 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8513 (0.2s/epoch)
Epoch  156/200:Epoch loss: 0.0040 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8554 (0.2s/epoch)
Epoch  157/200:Epoch loss: 0.0067 - avg acc: 100.0% - val-roc: 0.8776 - val-ap: 0.8708 (0.2s/epoch)
Epoch  158/200:Epoch loss: 0.0032 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8596 (0.2s/epoch)
Epoch  159/200:Epoch loss: 0.0027 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8596 (0.2s/epoch)
Epoch  160/200:Epoch loss: 0.0021 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8596 (0.2s/epoch)
Epoch  161/200:Epoch loss: 0.0096 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8596 (0.2s/epoch)
Epoch  162/200:Epoch loss: 0.0036 - avg acc: 100.0% - val-roc: 0.8776 - val-ap: 0.8673 (0.2s/epoch)
Epoch  163/200:Epoch loss: 0.0114 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8596 (0.2s/epoch)
Epoch  164/200:Epoch loss: 0.0059 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8632 (0.2s/epoch)
Epoch  165/200:Epoch loss: 0.0044 - avg acc: 100.0% - val-roc: 0.8827 - val-ap: 0.8703 (0.2s/epoch)
Epoch  166/200:Epoch loss: 0.0016 - avg acc: 100.0% - val-roc: 0.8827 - val-ap: 0.8703 (0.2s/epoch)
Epoch  167/200:Epoch loss: 0.0060 - avg acc: 100.0% - val-roc: 0.8827 - val-ap: 0.8703 (0.2s/epoch)
Epoch  168/200:Epoch loss: 0.0179 - avg acc: 100.0% - val-roc: 0.8827 - val-ap: 0.8703 (0.2s/epoch)
Epoch  169/200:Epoch loss: 0.0016 - avg acc: 100.0% - val-roc: 0.8827 - val-ap: 0.8703 (0.2s/epoch)
Epoch  170/200:Epoch loss: 0.0103 - avg acc: 99.2% - val-roc: 0.8878 - val-ap: 0.8779 (0.2s/epoch)
Epoch  171/200:Epoch loss: 0.0135 - avg acc: 99.2% - val-roc: 0.8827 - val-ap: 0.8739 (0.2s/epoch)
Epoch  172/200:Epoch loss: 0.0105 - avg acc: 99.2% - val-roc: 0.8776 - val-ap: 0.8662 (0.2s/epoch)
Epoch  173/200:Epoch loss: 0.0017 - avg acc: 99.2% - val-roc: 0.8724 - val-ap: 0.8534 (0.2s/epoch)
Epoch  174/200:Epoch loss: 0.0020 - avg acc: 99.2% - val-roc: 0.8776 - val-ap: 0.8577 (0.2s/epoch)
Epoch  175/200:Epoch loss: 0.0020 - avg acc: 99.2% - val-roc: 0.8724 - val-ap: 0.8534 (0.2s/epoch)
Epoch  176/200:Epoch loss: 0.0034 - avg acc: 99.2% - val-roc: 0.8776 - val-ap: 0.8577 (0.2s/epoch)
Epoch  177/200:Epoch loss: 0.0012 - avg acc: 99.2% - val-roc: 0.8776 - val-ap: 0.8577 (0.2s/epoch)
Epoch  178/200:Epoch loss: 0.0024 - avg acc: 99.2% - val-roc: 0.8724 - val-ap: 0.8534 (0.2s/epoch)
Epoch  179/200:Epoch loss: 0.0052 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8500 (0.2s/epoch)
Epoch  180/200:Epoch loss: 0.0011 - avg acc: 100.0% - val-roc: 0.8622 - val-ap: 0.8464 (0.2s/epoch)
Epoch  181/200:Epoch loss: 0.0075 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8508 (0.2s/epoch)
Epoch  182/200:Epoch loss: 0.0020 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8551 (0.2s/epoch)
Epoch  183/200:Epoch loss: 0.0009 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8508 (0.2s/epoch)
Epoch  184/200:Epoch loss: 0.0163 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8508 (0.2s/epoch)
Epoch  185/200:Epoch loss: 0.0012 - avg acc: 100.0% - val-roc: 0.8622 - val-ap: 0.8462 (0.2s/epoch)
Epoch  186/200:Epoch loss: 0.0066 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8538 (0.2s/epoch)
Epoch  187/200:Epoch loss: 0.0015 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8538 (0.2s/epoch)
Epoch  188/200:Epoch loss: 0.0047 - avg acc: 100.0% - val-roc: 0.8571 - val-ap: 0.8469 (0.2s/epoch)
Epoch  189/200:Epoch loss: 0.0067 - avg acc: 100.0% - val-roc: 0.8571 - val-ap: 0.8469 (0.2s/epoch)
Epoch  190/200:Epoch loss: 0.0039 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8633 (0.2s/epoch)
Epoch  191/200:Epoch loss: 0.0033 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8633 (0.2s/epoch)
Epoch  192/200:Epoch loss: 0.0013 - avg acc: 100.0% - val-roc: 0.8622 - val-ap: 0.8514 (0.2s/epoch)
Epoch  193/200:Epoch loss: 0.0013 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8584 (0.2s/epoch)
Epoch  194/200:Epoch loss: 0.0033 - avg acc: 100.0% - val-roc: 0.8520 - val-ap: 0.8431 (0.2s/epoch)
Epoch  195/200:Epoch loss: 0.0023 - avg acc: 100.0% - val-roc: 0.8622 - val-ap: 0.8503 (0.2s/epoch)
Epoch  196/200:Epoch loss: 0.0016 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8541 (0.2s/epoch)
Epoch  197/200:Epoch loss: 0.0054 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8541 (0.2s/epoch)
Epoch  198/200:Epoch loss: 0.0017 - avg acc: 100.0% - val-roc: 0.8622 - val-ap: 0.8503 (0.2s/epoch)
Epoch  199/200:Epoch loss: 0.0064 - avg acc: 100.0% - val-roc: 0.8724 - val-ap: 0.8584 (0.2s/epoch)
Epoch  200/200:Epoch loss: 0.0014 - avg acc: 100.0% - val-roc: 0.8673 - val-ap: 0.8541 (0.2s/epoch)
Done!
Test ROC: 0.9111 - Test AP: 0.9649

Plots of training history#

utils.plot_history(history)
[Figure: training history plots]

Validation on unseen data#

test_roc, test_ap, preds, ys = t.test(
    dataloaders[2].dataset, model, return_preds=True
)
print(f"Test ROC-AUC: {test_roc:.4f}")
print(f"Test AP: {test_ap:.4f}")

# Plot the ROC curve
fig = t.plot_roc_curve(dataloaders[2].dataset, model)
print(
    f"Number of correct positive predictions on test set: {torch.sum((preds > 0.5) & (ys == 1)).numpy()} out of {torch.sum(ys == 1).numpy()}"
)
print(
    f"Number of correct negative predictions on test set: {torch.sum((preds < 0.5) & (ys == 0)).numpy()} out of {torch.sum(ys == 0).numpy()}"
)
Test ROC-AUC: 0.9111
Test AP: 0.9649
[Figure: ROC curve on the test set]
Number of correct positive predictions on test set: 17 out of 20
Number of correct negative predictions on test set: 6 out of 9
utils.show_preds_distribution(preds, ys)
[Figure: distribution of predictions and labels on the test set]
This plot shows the distribution of the labels and the predictions;
predictions are overlaid on top of the labels,
showing whether any are missing or in excess.
The labels are 20 positive and 9 negative.
The predictions are 20 positive and 9 negative.
utils.plot_confusion_matrix(preds, ys)
[Figure: confusion matrix on the test set]
mislabeled = utils.find_mislabeled_molecules(preds, ys.numpy().astype(int))
print(f"{len(mislabeled)} mislabeled molecules")
mols = []
for id_mol in mislabeled:
    mols.append(dataloaders[2].dataset[id_mol])

thresh_preds = np.where(preds > 0.5, 1, 0)[mislabeled]
utils.draw_molecule_from_dict(
    mols[:6],
    preds=thresh_preds,
    mol_ids=mislabeled,
    n_cols=3,
    n_rows=2,
    figsize=(6, 4),
)
6 mislabeled molecules
[Figure: mislabeled test set molecules with their predicted labels]

Full dataset performance#

We now check how the model performs on the whole dataset. We also look at which molecules are misclassified, as well as the distribution of predictions, to get a sense of the model’s (un)certainty.

data_all = t.create_dataset_dict(add_edge_features=True)
full_dataset = t.MutagDataset(data_all)

Full ROC curve#

full_roc, full_ap, preds, ys = t.test(dataset, model, return_preds=True)
print(f"Full ROC-AUC: {full_roc:.4f}")
print(f"Full AP: {full_ap:.4f}")

# Plot the ROC curve
fig = t.plot_roc_curve(dataset, model)

preds = preds.cpu().detach().numpy()
ys = ys.cpu().detach().numpy()
print(
    f"Number of correct positive predictions on the full dataset: {np.sum((preds > 0.5) & (ys == 1))} out of {np.sum(ys == 1)}"
)
print(
    f"Number of correct negative predictions on the full dataset: {np.sum((preds < 0.5) & (ys == 0))} out of {np.sum(ys == 0)}"
)
Full ROC-AUC: 0.9792
Full AP: 0.9878
[Figure: ROC curve on the full dataset]
Number of correct positive predictions on the full dataset: 117 out of 125
Number of correct negative predictions on the full dataset: 58 out of 63

Predictions distribution#

utils.show_preds_distribution(preds, ys)
[Figure: distribution of predictions and labels on the full dataset]
This plot shows the distribution of the labels and the predictions;
predictions are overlaid on top of the labels,
showing whether any are missing or in excess.
The labels are 125 positive and 63 negative.
The predictions are 122 positive and 66 negative.

Confusion matrix#

utils.plot_confusion_matrix(preds, ys)
[Figure: confusion matrix on the full dataset]

Check mislabeled molecules#

mislabeled = utils.find_mislabeled_molecules(preds, ys.astype(int))
print(f"{len(mislabeled)} mislabeled molecules")
mols = []
for id_mol in mislabeled:
    mols.append(full_dataset[id_mol])

thresh_preds = np.where(preds > 0.5, 1, 0)[mislabeled]
utils.draw_molecule_from_dict(
    mols[:12],
    preds=thresh_preds,
    mol_ids=mislabeled,
    n_cols=4,
    n_rows=3,
    figsize=(8, 6),
)
13 mislabeled molecules
[Figure: mislabeled molecules from the full dataset with their predicted labels]