-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- added methods for retrieving all crowdseg stages
- Loading branch information
Showing
6 changed files
with
126 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ docs/build | |
**/Seed-Detection-2-1 | ||
notebooks/**/.ipynb_checkpoints | ||
core/dist | ||
__data__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Tuple | ||
|
||
from .generator import ImageDataGenerator | ||
from .stage import Stage | ||
|
||
_DEFAULT_N_CLASSES = 6 | ||
|
||
|
||
def get_all_data( | ||
n_classes: int = _DEFAULT_N_CLASSES, | ||
image_size: Tuple[int, int] = (256, 256), | ||
batch_size: int = 32, | ||
shuffle: bool = True, | ||
) -> Tuple[ImageDataGenerator, ...]: | ||
""" | ||
Retrieve all data generators for the crowd segmentation task. | ||
returns a tuple of ImageDataGenerator instances for the train, val, and test stages. | ||
""" | ||
return tuple( | ||
ImageDataGenerator( | ||
batch_size=batch_size, | ||
n_classes=n_classes, | ||
image_size=image_size, | ||
shuffle=shuffle, | ||
stage=stage, | ||
) | ||
for stage in (Stage.TRAIN, Stage.VAL, Stage.TEST) | ||
) | ||
|
||
|
||
def get_stage_data( | ||
stage: Stage, | ||
n_classes: int = _DEFAULT_N_CLASSES, | ||
image_size: Tuple[int, int] = (256, 256), | ||
batch_size: int = 32, | ||
shuffle: bool = True, | ||
) -> ImageDataGenerator: | ||
""" | ||
Retrieve a data generator for a specific stage of the crowd segmentation task. | ||
""" | ||
return ImageDataGenerator( | ||
batch_size=batch_size, | ||
n_classes=n_classes, | ||
image_size=image_size, | ||
shuffle=shuffle, | ||
stage=stage, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from seg_tgce.data.crowd_seg import get_all_data | ||
|
||
|
||
def main() -> None: | ||
train, val, test = get_all_data() | ||
val.visualize_sample(["NP8", "NP16", "NP21", "expert"]) | ||
print(f"Train: {len(train)}") | ||
print(f"Val: {len(val)}") | ||
print(f"Test: {len(test)}") | ||
|
||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import logging | ||
import os | ||
import zipfile | ||
|
||
import gdown | ||
|
||
from .stage import Stage | ||
|
||
_DATA_URL = "https://drive.google.com/drive/folders/17VukoKpwZclRrDcWSK1aYd_lPeqWNM8N?usp=sharing=" | ||
TARGET_DIR = "__data__/crowd_seg" | ||
|
||
|
||
def get_masks_dir(stage: Stage) -> str: | ||
return os.path.join(TARGET_DIR, "masks", stage.value) | ||
|
||
|
||
def get_patches_dir(stage: Stage) -> str: | ||
return os.path.join(TARGET_DIR, "patches", stage.value) | ||
|
||
|
||
def unzip_dirs() -> None: | ||
for root, _, files in os.walk(TARGET_DIR): | ||
for file in files: | ||
if file.endswith(".zip"): | ||
with zipfile.ZipFile(os.path.join(root, file), "r") as zip_ref: | ||
zip_ref.extractall(root) | ||
os.remove(os.path.join(root, file)) | ||
|
||
|
||
def fetch_data() -> None: | ||
if not os.path.exists(TARGET_DIR): | ||
logging.info("Downloading data...") | ||
gdown.download_folder(_DATA_URL, quiet=False, output=TARGET_DIR) | ||
unzip_dirs() | ||
return | ||
logging.info("Data already exists.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from enum import Enum | ||
|
||
|
||
class Stage(Enum): | ||
""" | ||
Enum class for the stage of the data generator. | ||
""" | ||
|
||
TRAIN = "Train" | ||
VAL = "Val" | ||
TEST = "Test" |