Skip to content

Commit

Permalink
core\refac: #40 standard shape for crowd seg
Browse files Browse the repository at this point in the history
- delivered masks as the standard shape for crowd seg data
  • Loading branch information
blotero committed May 21, 2024
1 parent da977db commit 91366b5
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
11 changes: 5 additions & 6 deletions core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions core/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
description = "Framework for handling image segmentation in the context of multiple annotators"
name = "seg_tgce"
version = "0.1.3"
version = "0.1.3.dev1"
readme = "README.md"
authors = [{ name = "Brandon Lotero", email = "[email protected]" }]
maintainers = [{ name = "Brandon Lotero", email = "[email protected]" }]
Expand All @@ -15,7 +15,7 @@ Issues = "https://github.com/blotero/seg_tgce/issues"

[tool.poetry]
name = "seg_tgce"
version = "0.1.3"
version = "0.1.3.dev1"
authors = ["Brandon Lotero <[email protected]>"]
description = "A package for the SEG TGCE project"
readme = "README.md"
Expand Down
5 changes: 4 additions & 1 deletion core/seg_tgce/data/crowd_seg/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@


def main() -> None:
train, val, test = get_all_data()
train, val, test = get_all_data(batch_size=8)
val.visualize_sample(["NP8", "NP16", "NP21", "expert"])
print(f"Train: {len(train)}")
print(f"Val: {len(val)}")
print(f"Test: {len(test)}")
img, mask = train[0]
print(f"Images shape: {img.shape}")
print(f"Masks shape: {mask.shape}")


main()
14 changes: 12 additions & 2 deletions core/seg_tgce/data/crowd_seg/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from keras.preprocessing.image import img_to_array, load_img
from keras.utils import Sequence
from matplotlib import pyplot as plt
from tensorflow import transpose

from .retrieve import fetch_data, get_masks_dir, get_patches_dir
from .stage import Stage
Expand All @@ -15,11 +16,20 @@


class CustomPath(TypedDict):
"""Custom path for image and mask directories."""

image_dir: str
mask_dir: str


class ImageDataGenerator(Sequence): # pylint: disable=too-many-instance-attributes
"""
Data generator for crowd segmentation data.
Delivered data is in the form of images and masks.
Shapes are as follows:
- images: (batch_size, image_size[0], image_size[1], 3)
- masks: (batch_size, image_size[0], image_size[1]), n_classes, n_scorers"""

def __init__( # pylint: disable=too-many-arguments
self,
n_classes: int,
Expand Down Expand Up @@ -81,7 +91,7 @@ def visualize_sample(
for class_num in range(self.n_classes):
axes[scorer_num][0].imshow(images[sample_index].astype(int))
axes[scorer_num][class_num + 1].imshow(
masks[sample_index, scorer_num, class_num]
masks[sample_index, :, :, class_num, scorer_num]
)
axes[scorer_num][0].axis("off")
axes[scorer_num][class_num + 1].axis("off")
Expand Down Expand Up @@ -135,4 +145,4 @@ def __data_generation(self, batch_filenames):

images[batch] = image

return images, masks
return images, transpose(masks, perm=[0, 3, 4, 2, 1])

0 comments on commit 91366b5

Please sign in to comment.