diff --git a/core/poetry.lock b/core/poetry.lock index 2446290..146ceda 100644 --- a/core/poetry.lock +++ b/core/poetry.lock @@ -1744,19 +1744,18 @@ pyasn1 = ">=0.1.3" [[package]] name = "setuptools" -version = "69.5.1" +version = "70.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-69.5.1-py3-none-any.whl", hash = "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32"}, - {file = "setuptools-69.5.1.tar.gz", hash = "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987"}, + {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, + {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" diff --git a/core/pyproject.toml b/core/pyproject.toml index 5c9f050..52d066d 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -1,7 +1,7 @@ [project] description = "Framework for handling image segmentation in the context of multiple annotators" name = "seg_tgce" -version = "0.1.3" +version = "0.1.3.dev1" readme = "README.md" authors = [{ name = "Brandon Lotero", email = "blotero@gmail.com" }] maintainers = [{ name = "Brandon Lotero", email = "blotero@gmail.com" }] @@ -15,7 +15,7 @@ Issues = "https://github.com/blotero/seg_tgce/issues" [tool.poetry] name = "seg_tgce" -version = "0.1.3" +version = "0.1.3.dev1" authors = ["Brandon Lotero "] description = "A package for the SEG TGCE project" readme = "README.md" diff --git a/core/seg_tgce/data/crowd_seg/__main__.py b/core/seg_tgce/data/crowd_seg/__main__.py index 2b0b0a8..19ceb69 100644 --- a/core/seg_tgce/data/crowd_seg/__main__.py +++ b/core/seg_tgce/data/crowd_seg/__main__.py @@ -2,11 +2,14 @@ def main() -> None: - train, val, test = get_all_data() + train, val, test = get_all_data(batch_size=8) val.visualize_sample(["NP8", "NP16", "NP21", "expert"]) print(f"Train: {len(train)}") print(f"Val: {len(val)}") print(f"Test: {len(test)}") + img, mask = train[0] + print(f"Images shape: {img.shape}") + print(f"Masks shape: {mask.shape}") main() diff --git a/core/seg_tgce/data/crowd_seg/generator.py b/core/seg_tgce/data/crowd_seg/generator.py index cb77521..0f6aee0 100644 --- a/core/seg_tgce/data/crowd_seg/generator.py +++ b/core/seg_tgce/data/crowd_seg/generator.py @@ -6,6 +6,7 @@ from keras.preprocessing.image import img_to_array, load_img from keras.utils import Sequence from matplotlib import pyplot as plt +from tensorflow import transpose from .retrieve import fetch_data, get_masks_dir, get_patches_dir from .stage import Stage @@ -15,11 +16,20 @@ class CustomPath(TypedDict): + """Custom path for image and mask directories.""" + image_dir: str mask_dir: str class ImageDataGenerator(Sequence): # pylint: disable=too-many-instance-attributes + """ + Data generator for crowd segmentation data. + Delivered data is in the form of images and masks. + Shapes are as follows: + - images: (batch_size, image_size[0], image_size[1], 3) + - masks: (batch_size, image_size[0], image_size[1]), n_classes, n_scorers""" + def __init__( # pylint: disable=too-many-arguments self, n_classes: int, @@ -81,7 +91,7 @@ def visualize_sample( for class_num in range(self.n_classes): axes[scorer_num][0].imshow(images[sample_index].astype(int)) axes[scorer_num][class_num + 1].imshow( - masks[sample_index, scorer_num, class_num] + masks[sample_index, :, :, class_num, scorer_num] ) axes[scorer_num][0].axis("off") axes[scorer_num][class_num + 1].axis("off") @@ -135,4 +145,4 @@ def __data_generation(self, batch_filenames): images[batch] = image - return images, masks + return images, transpose(masks, perm=[0, 3, 4, 2, 1])