Skip to content

Commit

Permalink
core\feat: #40 crowd seg stages
Browse files Browse the repository at this point in the history
- added methods for retrieving all crowdseg stages
  • Loading branch information
blotero committed May 20, 2024
1 parent 4e4488e commit 88bb094
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ docs/build
**/Seed-Detection-2-1
notebooks/**/.ipynb_checkpoints
core/dist
__data__
47 changes: 47 additions & 0 deletions core/seg_tgce/data/crowd_seg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Tuple

from .generator import ImageDataGenerator
from .stage import Stage

_DEFAULT_N_CLASSES = 6


def get_all_data(
n_classes: int = _DEFAULT_N_CLASSES,
image_size: Tuple[int, int] = (256, 256),
batch_size: int = 32,
shuffle: bool = True,
) -> Tuple[ImageDataGenerator, ...]:
"""
Retrieve all data generators for the crowd segmentation task.
returns a tuple of ImageDataGenerator instances for the train, val, and test stages.
"""
return tuple(
ImageDataGenerator(
batch_size=batch_size,
n_classes=n_classes,
image_size=image_size,
shuffle=shuffle,
stage=stage,
)
for stage in (Stage.TRAIN, Stage.VAL, Stage.TEST)
)


def get_stage_data(
stage: Stage,
n_classes: int = _DEFAULT_N_CLASSES,
image_size: Tuple[int, int] = (256, 256),
batch_size: int = 32,
shuffle: bool = True,
) -> ImageDataGenerator:
"""
Retrieve a data generator for a specific stage of the crowd segmentation task.
"""
return ImageDataGenerator(
batch_size=batch_size,
n_classes=n_classes,
image_size=image_size,
shuffle=shuffle,
stage=stage,
)
12 changes: 12 additions & 0 deletions core/seg_tgce/data/crowd_seg/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from seg_tgce.data.crowd_seg import get_all_data


def main() -> None:
train, val, test = get_all_data()
val.visualize_sample(["NP8", "NP16", "NP21", "expert"])
print(f"Train: {len(train)}")
print(f"Val: {len(val)}")
print(f"Test: {len(test)}")


main()
40 changes: 19 additions & 21 deletions core/seg_tgce/data/crowd_seg/generator.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
import logging
import os
from typing import List, Tuple
from typing import List, Optional, Tuple, TypedDict

import numpy as np
from keras.preprocessing.image import img_to_array, load_img
from keras.utils import Sequence
from matplotlib import pyplot as plt

from .retrieve import fetch_data, get_masks_dir, get_patches_dir
from .stage import Stage

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.WARNING)


class CustomPath(TypedDict):
image_dir: str
mask_dir: str


class ImageDataGenerator(Sequence): # pylint: disable=too-many-instance-attributes
def __init__( # pylint: disable=too-many-arguments
self,
image_dir: str,
mask_dir: str,
n_classes: int,
image_size: Tuple[int, int] = (256, 256),
batch_size: int = 32,
shuffle: bool = True,
stage: Stage = Stage.TRAIN,
paths: Optional[CustomPath] = None,
):
if paths is not None:
image_dir = paths["image_dir"]
mask_dir = paths["mask_dir"]
else:
fetch_data()
image_dir = get_patches_dir(stage)
mask_dir = get_masks_dir(stage)
self.image_dir = image_dir
self.mask_dir = mask_dir
self.image_size = image_size
Expand All @@ -35,7 +50,7 @@ def __init__( # pylint: disable=too-many-arguments
)
self.n_scorers = len(os.listdir(mask_dir))
self.scorers_tags = sorted(os.listdir(mask_dir))
print(f"Scorer tags: {self.scorers_tags}")
LOGGER.info("Scorer tags: %s", self.scorers_tags)
self.n_classes = n_classes
self.on_epoch_end()

Expand Down Expand Up @@ -121,20 +136,3 @@ def __data_generation(self, batch_filenames):
images[batch] = image

return images, masks


if __name__ == "__main__":
val_gen = ImageDataGenerator(
image_dir="/home/brandon/unal/maestria/datasets/Histology Data/patches/Val",
mask_dir="/home/brandon/unal/maestria/datasets/Histology Data/masks/Val",
batch_size=16,
n_classes=6,
)
print(f"Train len: {len(val_gen)}")
print(f"Train masks scorers: {val_gen.n_scorers}")
print(f"Train masks scorers tags: {val_gen.scorers_tags}")
val_gen.visualize_sample(
batch_index=8,
sample_index=8,
scorers=["NP8", "NP16", "NP21", "expert"],
)
36 changes: 36 additions & 0 deletions core/seg_tgce/data/crowd_seg/retrieve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import logging
import os
import zipfile

import gdown

from .stage import Stage

_DATA_URL = "https://drive.google.com/drive/folders/17VukoKpwZclRrDcWSK1aYd_lPeqWNM8N?usp=sharing="
TARGET_DIR = "__data__/crowd_seg"


def get_masks_dir(stage: Stage) -> str:
return os.path.join(TARGET_DIR, "masks", stage.value)


def get_patches_dir(stage: Stage) -> str:
return os.path.join(TARGET_DIR, "patches", stage.value)


def unzip_dirs() -> None:
for root, _, files in os.walk(TARGET_DIR):
for file in files:
if file.endswith(".zip"):
with zipfile.ZipFile(os.path.join(root, file), "r") as zip_ref:
zip_ref.extractall(root)
os.remove(os.path.join(root, file))


def fetch_data() -> None:
if not os.path.exists(TARGET_DIR):
logging.info("Downloading data...")
gdown.download_folder(_DATA_URL, quiet=False, output=TARGET_DIR)
unzip_dirs()
return
logging.info("Data already exists.")
11 changes: 11 additions & 0 deletions core/seg_tgce/data/crowd_seg/stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from enum import Enum


class Stage(Enum):
"""
Enum class for the stage of the data generator.
"""

TRAIN = "Train"
VAL = "Val"
TEST = "Test"

0 comments on commit 88bb094

Please sign in to comment.