Skip to content

Commit

Permalink
core\feat: #68 sparse tgce
Browse files Browse the repository at this point in the history
- added sparse tgce loss
  • Loading branch information
blotero committed Jun 12, 2024
1 parent 68f7ba2 commit 229ef28
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions core/seg_tgce/loss/tgce.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,67 @@ def get_config(
"""
base_config = super().get_config()
return {**base_config, "q": self.q}


class TcgeSsSparse(TcgeSs):
"""
Truncated generalized cross entropy
for semantic segmentation loss.
This is a much more sparse version which completes missing dimensions
for facilitating its use in the Bayesian U-Net.
"""

def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
y_true = cast(y_true, TARGET_DATA_TYPE)
y_pred = cast(y_pred, TARGET_DATA_TYPE)

y_pred = y_pred[..., : self.num_classes + self.num_annotators] # type:ignore
y_true = tf.reshape(
y_true, (y_true.shape[:-1]) + (self.num_classes, self.num_annotators)
)
lambda_r = y_pred[..., self.num_classes :] # type:ignore
y_pred_ = y_pred[..., : self.num_classes]
n_samples, width, height, _ = y_pred_.shape
y_pred_ = y_pred_[..., tf.newaxis] # type:ignore
y_pred_ = tf.repeat(y_pred_, repeats=[self.num_annotators], axis=-1)

epsilon = 1e-8
y_pred_ = tf.clip_by_value(y_pred_, epsilon, 1.0 - epsilon)

term_r = tf.math.reduce_mean(
tf.math.multiply(
y_true,
(
tf.ones(
[
n_samples,
width,
height,
self.num_classes,
self.num_annotators,
]
)
- tf.pow(y_pred_, self.q)
)
/ (self.q + epsilon + self.smooth),
),
axis=-2,
)

term_c = tf.math.multiply(
tf.ones([n_samples, width, height, self.num_annotators]) - lambda_r,
(
tf.ones([n_samples, width, height, self.num_annotators])
- tf.pow(
(1 / self.num_classes + self.smooth)
* tf.ones([n_samples, width, height, self.num_annotators]),
self.q,
)
)
/ (self.q + epsilon + self.smooth),
)

loss = tf.math.reduce_mean(tf.math.multiply(lambda_r, term_r) + term_c)
loss = tf.where(tf.math.is_nan(loss), tf.constant(1e-8), loss)

return loss

0 comments on commit 229ef28

Please sign in to comment.