core/refac: #68 prepare bayesian optimization
- prepare unet arch for bayesian optimizer
blotero committed Jun 12, 2024
1 parent 4adcd61 commit c6f9762
Showing 4 changed files with 158 additions and 27 deletions.
2 changes: 1 addition & 1 deletion core/seg_tgce/data/crowd_seg/generator.py
@@ -6,9 +6,9 @@
from keras.preprocessing.image import img_to_array, load_img
from keras.utils import Sequence
from matplotlib import pyplot as plt
-from tensorflow import transpose
from tensorflow import Tensor
from tensorflow import argmax as tf_argmax
+from tensorflow import transpose

from .__retrieve import fetch_data, get_masks_dir, get_patches_dir
from .stage import Stage
16 changes: 16 additions & 0 deletions core/seg_tgce/layers/__init__.py
@@ -0,0 +1,16 @@
import keras.backend as K
from keras.layers import Layer
from tensorflow import Tensor


class SparseSoftmax(Layer):
"""Custom layer implementing the sparse softmax activation function."""

    def __init__(self, name="SparseSoftmax", **kwargs):
        super().__init__(name=name, **kwargs)

def call(self, inputs: Tensor) -> Tensor: # pylint: disable=arguments-differ
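        # Shift logits by their channel-wise max for numerical stability.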
e_x = K.exp(inputs - K.max(inputs, axis=-1, keepdims=True))
sum_e_x = K.sum(e_x, axis=-1, keepdims=True)
output = e_x / (sum_e_x + K.epsilon())
return output
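
A minimal sketch of the new layer in isolation (shapes and the print are illustrative, not part of the commit):

import tensorflow as tf
from seg_tgce.layers import SparseSoftmax

layer = SparseSoftmax()
logits = tf.random.normal((1, 4, 4, 3))
probs = layer(logits)
# Each pixel's channel vector sums to slightly under 1 because of the
# K.epsilon() term added to the denominator.
print(probs.shape)  # (1, 4, 4, 3)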
64 changes: 64 additions & 0 deletions core/seg_tgce/loss/tgce.py
@@ -117,3 +117,67 @@ def get_config(
"""
base_config = super().get_config()
return {**base_config, "q": self.q}


class TcgeSsSparse(TcgeSs):
"""
Truncated generalized cross entropy
for semantic segmentation loss.
    This sparser variant completes the missing dimensions so that it can
    be used within the Bayesian U-Net.
    """

def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
y_true = cast(y_true, TARGET_DATA_TYPE)
y_pred = cast(y_pred, TARGET_DATA_TYPE)

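        # Keep only the class and annotator-reliability channels; filler
        # channels appended by the Bayesian U-Net are discarded.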
y_pred = y_pred[..., : self.num_classes + self.num_annotators] # type:ignore
y_true = tf.reshape(
y_true, (y_true.shape[:-1]) + (self.num_classes, self.num_annotators)
)
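        # Split per-annotator reliability maps (lambda_r) from class scores.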
lambda_r = y_pred[..., self.num_classes :] # type:ignore
y_pred_ = y_pred[..., : self.num_classes]
n_samples, width, height, _ = y_pred_.shape
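        # Add an annotator axis and tile the class probabilities across it.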
y_pred_ = y_pred_[..., tf.newaxis] # type:ignore
y_pred_ = tf.repeat(y_pred_, repeats=[self.num_annotators], axis=-1)

epsilon = 1e-8
y_pred_ = tf.clip_by_value(y_pred_, epsilon, 1.0 - epsilon)

term_r = tf.math.reduce_mean(
tf.math.multiply(
y_true,
(
tf.ones(
[
n_samples,
width,
height,
self.num_classes,
self.num_annotators,
]
)
- tf.pow(y_pred_, self.q)
)
/ (self.q + epsilon + self.smooth),
),
axis=-2,
)

term_c = tf.math.multiply(
tf.ones([n_samples, width, height, self.num_annotators]) - lambda_r,
(
tf.ones([n_samples, width, height, self.num_annotators])
- tf.pow(
(1 / self.num_classes + self.smooth)
* tf.ones([n_samples, width, height, self.num_annotators]),
self.q,
)
)
/ (self.q + epsilon + self.smooth),
)

loss = tf.math.reduce_mean(tf.math.multiply(lambda_r, term_r) + term_c)
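        # Replace a NaN loss with a tiny constant so training does not diverge.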
loss = tf.where(tf.math.is_nan(loss), tf.constant(1e-8), loss)

return loss
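
A hedged usage sketch (hypothetical wiring: the TcgeSs constructor is not shown in this diff, so the keyword names below are assumptions):

from seg_tgce.loss.tgce import TcgeSsSparse

# Assumed constructor keywords (q, num_classes, num_annotators); adjust to the
# real TcgeSs signature.
loss_fn = TcgeSsSparse(q=0.7, num_classes=2, num_annotators=5)
model.compile(optimizer="adam", loss=loss_fn)  # model: a Bayesian U-Net instance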
103 changes: 77 additions & 26 deletions core/seg_tgce/models/unet.py
@@ -7,101 +7,116 @@
Concatenate,
Conv2D,
Input,
Layer,
MaxPool2D,
UpSampling2D,
)
from keras.models import Model
from keras.utils import get_custom_objects

from seg_tgce.layers import SparseSoftmax

from .ma_model import ModelMultipleAnnotators

get_custom_objects()["sparse_softmax"] = SparseSoftmax()
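# Registering under "sparse_softmax" lets layers reference the activation
# by name, e.g. activation="sparse_softmax".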

DefaultConv2D = partial(Conv2D, kernel_size=3, activation="relu", padding="same")

DefaultPooling = partial(MaxPool2D, pool_size=2)
DilatedConv = partial(
Conv2D,
kernel_size=3,
activation="relu",
padding="same",
dilation_rate=10,
name="DilatedConv",
)
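# Note: the call sites below override kernel_size to (1, 1), where dilation has
# no effect; the dilation_rate only matters if a larger kernel is used.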

-upsample = partial(UpSampling2D, (2, 2))
+UpSample = partial(UpSampling2D, (2, 2))


def kernel_initializer(seed: int) -> GlorotUniform:
return GlorotUniform(seed=seed)


-def unet_tgce(  # pylint: disable=too-many-statements
-    input_shape: Tuple[int, int, int],
-    name: str = "UNET",
-    out_channels: int = 2,
-    n_scorers: int = 5,
-    out_act_functions: Tuple[str, str] = ("softmax", "sigmoid"),
-) -> ModelMultipleAnnotators:
-    # Encoder
-    input_layer = Input(shape=input_shape)
+def build_encoder(input_layer: Layer) -> Tuple[Layer, Layer, Layer, Layer, Layer]:
x = BatchNormalization(name="Batch00")(input_layer)

x = DefaultConv2D(8, kernel_initializer=kernel_initializer(34), name="Conv10")(x)
x = BatchNormalization(name="Batch10")(x)
x = level_1 = DefaultConv2D(
8, kernel_initializer=kernel_initializer(4), name="Conv11"
)(x)
x = BatchNormalization(name="Batch11")(x)
x = DefaultPooling(name="Pool10")(x) # 128x128 -> 64x64

x = DefaultConv2D(16, kernel_initializer=kernel_initializer(56), name="Conv20")(x)
x = BatchNormalization(name="Batch20")(x)
x = level_2 = DefaultConv2D(
16, kernel_initializer=kernel_initializer(32), name="Conv21"
)(x)
x = BatchNormalization(name="Batch22")(x)
x = DefaultPooling(name="Pool20")(x) # 64x64 -> 32x32

x = DefaultConv2D(32, kernel_initializer=kernel_initializer(87), name="Conv30")(x)
x = BatchNormalization(name="Batch30")(x)
x = level_3 = DefaultConv2D(
32, kernel_initializer=kernel_initializer(30), name="Conv31"
)(x)
x = BatchNormalization(name="Batch31")(x)
x = DefaultPooling(name="Pool30")(x) # 32x32 -> 16x16

x = DefaultConv2D(64, kernel_initializer=kernel_initializer(79), name="Conv40")(x)
x = BatchNormalization(name="Batch40")(x)
x = level_4 = DefaultConv2D(
64, kernel_initializer=kernel_initializer(81), name="Conv41"
)(x)
x = BatchNormalization(name="Batch41")(x)
x = DefaultPooling(name="Pool40")(x) # 16x16 -> 8x8
return x, level_1, level_2, level_3, level_4


-    # Decoder
+def build_decoder(
+    x: Layer, level_1: Layer, level_2: Layer, level_3: Layer, level_4: Layer
+) -> Layer:
x = DefaultConv2D(128, kernel_initializer=kernel_initializer(89), name="Conv50")(x)
x = BatchNormalization(name="Batch50")(x)
x = DefaultConv2D(128, kernel_initializer=kernel_initializer(42), name="Conv51")(x)
x = BatchNormalization(name="Batch51")(x)

x = upsample(name="Up60")(x) # 8x8 -> 16x16
x = UpSample(name="Up60")(x) # 8x8 -> 16x16
x = Concatenate(name="Concat60")([level_4, x])
x = DefaultConv2D(64, kernel_initializer=kernel_initializer(91), name="Conv60")(x)
x = BatchNormalization(name="Batch60")(x)
x = DefaultConv2D(64, kernel_initializer=kernel_initializer(47), name="Conv61")(x)
x = BatchNormalization(name="Batch61")(x)

x = upsample(name="Up70")(x) # 16x16 -> 32x32
x = UpSample(name="Up70")(x) # 16x16 -> 32x32
x = Concatenate(name="Concat70")([level_3, x])
x = DefaultConv2D(32, kernel_initializer=kernel_initializer(21), name="Conv70")(x)
x = BatchNormalization(name="Batch70")(x)
x = DefaultConv2D(32, kernel_initializer=kernel_initializer(96), name="Conv71")(x)
x = BatchNormalization(name="Batch71")(x)

x = upsample(name="Up80")(x) # 32x32 -> 64x64
x = UpSample(name="Up80")(x) # 32x32 -> 64x64
x = Concatenate(name="Concat80")([level_2, x])
x = DefaultConv2D(16, kernel_initializer=kernel_initializer(96), name="Conv80")(x)
x = BatchNormalization(name="Batch80")(x)
x = DefaultConv2D(16, kernel_initializer=kernel_initializer(98), name="Conv81")(x)
x = BatchNormalization(name="Batch81")(x)

x = upsample(name="Up90")(x) # 64x64 -> 128x128
x = UpSample(name="Up90")(x) # 64x64 -> 128x128
x = Concatenate(name="Concat90")([level_1, x])
x = DefaultConv2D(8, kernel_initializer=kernel_initializer(35), name="Conv90")(x)
x = BatchNormalization(name="Batch90")(x)
x = DefaultConv2D(8, kernel_initializer=kernel_initializer(7), name="Conv91")(x)
x = BatchNormalization(name="Batch91")(x)
return x


def unet_tgce(
input_shape: Tuple[int, int, int],
name: str = "UNET",
out_channels: int = 2,
n_scorers: int = 5,
out_act_functions: Tuple[str, str] = ("softmax", "sigmoid"),
) -> ModelMultipleAnnotators:
input_layer = Input(shape=input_shape)
x, level_1, level_2, level_3, level_4 = build_encoder(input_layer)
x = build_decoder(x, level_1, level_2, level_3, level_4)
xy = DefaultConv2D(
out_channels,
kernel_size=(1, 1),
@@ -117,7 +132,43 @@ def unet_tgce( # pylint: disable=too-many-statements
name="Conv101-Lambda",
)(x)
y = Concatenate()([xy, x_lambda])
return ModelMultipleAnnotators(input_layer, y, name=name)


def bayesian_unet_tgce( # pylint: disable=too-many-arguments, too-many-locals
input_shape: Tuple[int, int, int],
*,
name: str = "UNET",
out_channels: int = 2,
out_filler_channels: int = 2,
n_scorers: int = 5,
out_act_functions: Tuple[str, str] = ("sparse_softmax", "sparse_softmax"),
) -> Model:
input_layer = Input(shape=input_shape)
x, level_1, level_2, level_3, level_4 = build_encoder(input_layer)
x = build_decoder(x, level_1, level_2, level_3, level_4)

-    model = ModelMultipleAnnotators(input_layer, y, name=name)
xy = DefaultConv2D(
out_channels,
kernel_size=(1, 1),
activation=out_act_functions[0],
kernel_initializer=kernel_initializer(42),
name="Conv100",
)(x)
xyy = DefaultConv2D(
out_filler_channels,
kernel_size=(1, 1),
activation=out_act_functions[0],
kernel_initializer=kernel_initializer(42),
name="Conv101",
)(x)
x_lambda = DilatedConv(
n_scorers,
kernel_size=(1, 1),
activation=out_act_functions[1],
kernel_initializer=kernel_initializer(42),
name="Conv101-Lambda",
)(x)
y = Concatenate()([xy, x_lambda, xyy])

-    return model
+    return Model(input_layer, y, name=name)
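
A short sketch of building the new Bayesian U-Net (the input size and channel counts are illustrative assumptions):

from seg_tgce.models.unet import bayesian_unet_tgce

model = bayesian_unet_tgce(
    input_shape=(128, 128, 3),
    out_channels=2,
    out_filler_channels=2,
    n_scorers=5,
)
model.summary()
# The final Concatenate stacks out_channels + n_scorers + out_filler_channels
# = 9 maps; TcgeSsSparse later slices this back to num_classes + num_annotators.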
