Skip to content

Commit

Permalink
all\refac: #5 trimmed scorers
Browse files Browse the repository at this point in the history
- trimmed scorers instead of global forced balance for crowd_seg
  • Loading branch information
blotero committed Dec 6, 2024
1 parent 53e246c commit fd89287
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 21 deletions.
4 changes: 2 additions & 2 deletions core/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
description = "Framework for handling image segmentation in the context of multiple annotators"
name = "seg_tgce"
version = "0.1.7"
version = "0.1.8"
readme = "README.md"
authors = [{ name = "Brandon Lotero", email = "[email protected]" }]
maintainers = [{ name = "Brandon Lotero", email = "[email protected]" }]
Expand All @@ -15,7 +15,7 @@ Issues = "https://github.com/blotero/seg_tgce/issues"

[tool.poetry]
name = "seg_tgce"
version = "0.1.7"
version = "0.1.8"
authors = ["Brandon Lotero <[email protected]>"]
description = "A package for the SEG TGCE project"
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions core/seg_tgce/data/crowd_seg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def get_all_data(
batch_size: int = 32,
shuffle: bool = False,
with_sparse_data: bool = False,
force_balance: bool = False,
trim_n_scorers: int | None = None,
) -> Tuple[ImageDataGenerator, ...]:
"""
Retrieve all data generators for the crowd segmentation task.
Expand All @@ -24,7 +24,7 @@ def get_all_data(
shuffle=shuffle,
stage=stage,
schema=DataSchema.MA_SPARSE if with_sparse_data else DataSchema.MA_RAW,
force_balance=force_balance,
trim_n_scorers=trim_n_scorers,
)
for stage in (Stage.TRAIN, Stage.VAL, Stage.TEST)
)
Expand Down
4 changes: 2 additions & 2 deletions core/seg_tgce/data/crowd_seg/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ def main() -> None:
print(f"Val: {len(val)} batches, {len(val) * val.batch_size} samples")
print(f"Test: {len(test)} batches, {len(test) * test.batch_size} samples")

print("Loading train data with forced balance...")
print("Loading train data with trimmed scorers...")
train = ImageDataGenerator(
batch_size=8,
force_balance=True,
trim_n_scorers=6,
)
print(f"Train: {len(train)} batches, {len(train) * train.batch_size} samples")

Expand Down
32 changes: 21 additions & 11 deletions core/seg_tgce/data/crowd_seg/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,16 @@ class DataSchema(str, Enum):
MA_SPARSE = "ma_sparse"


def find_n_scorers(data: dict[str, dict[str, Any]], n: int) -> List[str]:
    """Return the ids of the *n* scorers with the highest image counts.

    Each value in ``data`` is expected to carry a ``"total"`` entry giving
    the number of images that scorer annotated; ties keep dict insertion
    order (stable sort).
    """
    by_count_desc = sorted(data, key=lambda scorer: data[scorer]["total"], reverse=True)
    return by_count_desc[:n]


def get_image_filenames(
image_dir: str, stage: Stage, *, force_balance: bool = False
image_dir: str, stage: Stage, *, trim_n_scorers: int | None
) -> List[str]:
if not force_balance:
if trim_n_scorers is None:
return sorted(
[
filename
Expand All @@ -70,11 +76,16 @@ def get_image_filenames(
inverted_data_path = f"{METADATA_PATH}/{stage.name.lower()}_inverted.json"
with open(inverted_data_path, "r", newline="", encoding="utf-8") as json_file:
inverted_data: dict[str, Any] = json.load(json_file)
# determine `limit` as the lowest number of images scored by a scorer
limit = min(data["total"] for data in inverted_data.values())
LOGGER.info("Forced balance: limiting to %d images per scorer.", limit)
for scorer_data in inverted_data.values():
filenames.update(scorer_data["scored"][:limit])
# trim to n scorers which scored the most images:
trimmed_scorers = find_n_scorers(inverted_data, trim_n_scorers)

LOGGER.info(
"Limiting dataset to only images scored by the top %d scorers: %s",
trim_n_scorers,
trimmed_scorers,
)
for scorer in trimmed_scorers:
filenames.update(inverted_data[scorer]["scored"])
return list(filenames)


Expand All @@ -93,7 +104,7 @@ class ImageDataGenerator(Sequence): # pylint: disable=too-many-instance-attribu
- stage: Stage = Stage.TRAIN: Stage of the dataset.
- paths: Optional[CustomPath] = None: Custom paths for image and mask directories.
- schema: DataSchema = DataSchema.MA_RAW: Data schema for the dataset.
- force_balance: bool = False: Force balance the dataset by downsampling.
- trim_n_scorers: int | None = None: Trim the dataset, keeping only the top n scorers.
"""

Expand All @@ -106,7 +117,7 @@ def __init__( # pylint: disable=too-many-arguments
stage: Stage = Stage.TRAIN,
paths: Optional[CustomPath] = None,
schema: DataSchema = DataSchema.MA_RAW,
force_balance: bool = False,
trim_n_scorers: int | None = None,
) -> None:
if paths is not None:
image_dir = paths["image_dir"]
Expand All @@ -121,7 +132,7 @@ def __init__( # pylint: disable=too-many-arguments
self.batch_size = batch_size
self.shuffle = shuffle
self.image_filenames = get_image_filenames(
image_dir, stage, force_balance=force_balance
image_dir, stage, trim_n_scorers=trim_n_scorers
)
self.scorers_tags = sorted(os.listdir(mask_dir))
self.on_epoch_end()
Expand All @@ -131,7 +142,6 @@ def __init__( # pylint: disable=too-many-arguments
for filename in self.image_filenames
}
self.stage = stage
self.force_balance = force_balance

@property
def classes_definition(self) -> dict[int, str]:
Expand Down
10 changes: 6 additions & 4 deletions docs/source/experiments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,20 @@ A single stage can also be fetched, and even visualized:
val.visualize_sample()
For architecture debugging purposes, you can also fetch the data with a
downsampling for forcing balance between scorers:
downsampling for trimming scorers and leaving only the ones who scored
the most patches:

.. code:: python
train = get_stage_data(stage = Stage.TRAIN, batch_size=8, force_balance=True)
train = get_stage_data(stage = Stage.TRAIN, batch_size=8, trim_n_scorers=6)
Output:

.. code:: text
Loading train data with forced balance...
INFO:seg_tgce.data.crowd_seg.generator:Forced balance: limiting to 102 images per scorer.
Loading train data with trimmed scorers...
INFO:seg_tgce.data.crowd_seg.generator:Limiting dataset to only images scored by the top 6 scorers: ['MV', 'STAPLE', 'expert', 'NP6', 'NP10', 'NP3']
Train: 1272 batches, 10176 samples
When running the ``visualize_sample`` method, the generator will load
the images and masks from the disk and display them, with a result
Expand Down

0 comments on commit fd89287

Please sign in to comment.