Skip to content

Commit

Permalink
core\feat: #68 crowd_seg improvements
Browse files Browse the repository at this point in the history
- improvements in data parsing
- improvements in sample visualization
- made retrieves submodule private
  • Loading branch information
blotero committed May 27, 2024
1 parent f1cd090 commit 251ab99
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 108 deletions.
104 changes: 52 additions & 52 deletions core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions core/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
description = "Framework for handling image segmentation in the context of multiple annotators"
name = "seg_tgce"
version = "0.1.3.dev2"
version = "0.1.3.dev3"
readme = "README.md"
authors = [{ name = "Brandon Lotero", email = "[email protected]" }]
maintainers = [{ name = "Brandon Lotero", email = "[email protected]" }]
Expand All @@ -15,7 +15,7 @@ Issues = "https://github.com/blotero/seg_tgce/issues"

[tool.poetry]
name = "seg_tgce"
version = "0.1.3.dev2"
version = "0.1.3.dev3"
authors = ["Brandon Lotero <[email protected]>"]
description = "A package for the SEG TGCE project"
readme = "README.md"
Expand Down
10 changes: 2 additions & 8 deletions core/seg_tgce/data/crowd_seg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
from .generator import ImageDataGenerator
from .stage import Stage

_DEFAULT_N_CLASSES = 6


def get_all_data(
n_classes: int = _DEFAULT_N_CLASSES,
image_size: Tuple[int, int] = (256, 256),
batch_size: int = 32,
shuffle: bool = True,
shuffle: bool = False,
) -> Tuple[ImageDataGenerator, ...]:
"""
Retrieve all data generators for the crowd segmentation task.
Expand All @@ -19,7 +16,6 @@ def get_all_data(
return tuple(
ImageDataGenerator(
batch_size=batch_size,
n_classes=n_classes,
image_size=image_size,
shuffle=shuffle,
stage=stage,
Expand All @@ -30,17 +26,15 @@ def get_all_data(

def get_stage_data(
stage: Stage,
n_classes: int = _DEFAULT_N_CLASSES,
image_size: Tuple[int, int] = (256, 256),
batch_size: int = 32,
shuffle: bool = True,
shuffle: bool = False,
) -> ImageDataGenerator:
"""
Retrieve a data generator for a specific stage of the crowd segmentation task.
"""
return ImageDataGenerator(
batch_size=batch_size,
n_classes=n_classes,
image_size=image_size,
shuffle=shuffle,
stage=stage,
Expand Down
5 changes: 1 addition & 4 deletions core/seg_tgce/data/crowd_seg/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@

def main() -> None:
train, val, test = get_all_data(batch_size=8)
val.visualize_sample(["NP8", "NP16", "NP21", "expert"])
val.visualize_sample(batch_index=138, sample_indexes=[2, 3, 4, 5])
print(f"Train: {len(train)}")
print(f"Val: {len(val)}")
print(f"Test: {len(test)}")
img, mask = train[0]
print(f"Images shape: {img.shape}")
print(f"Masks shape: {mask.shape}")


main()
File renamed without changes.
Loading

0 comments on commit 251ab99

Please sign in to comment.