Skip to content

Commit

Permalink
core\refac: #68 handle symbolic tensor dims in loss
Browse files Browse the repository at this point in the history
- handled sym tensor dimensions in tgce loss for allowing its usage as a
  keras sequence directly into fit methods
  • Loading branch information
blotero committed Jun 18, 2024
1 parent 8178cf8 commit 631cd26
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 38 deletions.
66 changes: 33 additions & 33 deletions core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions core/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
description = "Framework for handling image segmentation in the context of multiple annotators"
name = "seg_tgce"
version = "0.1.4.dev2"
version = "0.1.4.dev3"
readme = "README.md"
authors = [{ name = "Brandon Lotero", email = "[email protected]" }]
maintainers = [{ name = "Brandon Lotero", email = "[email protected]" }]
Expand All @@ -15,7 +15,7 @@ Issues = "https://github.com/blotero/seg_tgce/issues"

[tool.poetry]
name = "seg_tgce"
version = "0.1.4.dev2"
version = "0.1.4.dev3"
authors = ["Brandon Lotero <[email protected]>"]
description = "A package for the SEG TGCE project"
readme = "README.md"
Expand Down
16 changes: 13 additions & 3 deletions core/seg_tgce/loss/tgce.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

import keras.backend as K
from matplotlib.pylab import f
import tensorflow as tf
from keras.losses import Loss
from tensorflow import cast
Expand Down Expand Up @@ -132,12 +133,21 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
y_pred = cast(y_pred, TARGET_DATA_TYPE)

y_pred = y_pred[..., : self.num_classes + self.num_annotators] # type:ignore
y_true = tf.reshape(
y_true, (y_true.shape[:-1]) + (self.num_classes, self.num_annotators)

y_true_shape = tf.shape(y_true)

new_shape = tf.concat(
[y_true_shape[:-1], [self.num_classes, self.num_annotators]], axis=0
)
y_true = tf.reshape(y_true, new_shape)

lambda_r = y_pred[..., self.num_classes :] # type:ignore
y_pred_ = y_pred[..., : self.num_classes]
n_samples, width, height, _ = y_pred_.shape

n_samples = tf.shape(y_pred_)[0]
width = tf.shape(y_pred_)[1]
height = tf.shape(y_pred_)[2]

y_pred_ = y_pred_[..., tf.newaxis] # type:ignore
y_pred_ = tf.repeat(y_pred_, repeats=[self.num_annotators], axis=-1)

Expand Down

0 comments on commit 631cd26

Please sign in to comment.