Skip to content

Commit

Permalink
adding some code generation for sample loading
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Oct 22, 2024
1 parent e1beb32 commit 39ae545
Show file tree
Hide file tree
Showing 19 changed files with 529 additions and 89 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ skip = ["venv", "benchmarks"]
[tool.taskipy.tasks]
validate_release = { cmd = "python repo_utilities/validate.py", help = "validates for a release" }
update_schema_docs = { cmd = "python repo_utilities/update_schema_docs.py", help = "updates the schema related documentation" }
update_sample_data = { cmd = "python repo_utilities/update_sample_data.py", help = "updates sample data code" }
176 changes: 176 additions & 0 deletions repo_utilities/update_sample_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import json
import os
from collections import defaultdict

import yaml

# requirements: cartesian, 3D, grid-based
enabled = sorted(
    [
        "DeeplyNestedZoom",
        "Enzo_64",
        "HiresIsolatedGalaxy",
        "IsolatedGalaxy",
        "PopIII_mini",
        # 'MHDSloshing',
        "GaussianCloud",
        "SmartStars",
        # 'ENZOE_orszag-tang_0.5', # cant handle -, .
        "GalaxyClusterMerger",  # big but neat
        # 'InteractingJets',
        "cm1_tornado_lofs",
    ]
)

# per-sample default field to load and whether to take the log of it;
# any sample not listed explicitly below falls back to the defaults here
sample_field = defaultdict(lambda: ("gas", "density"))
log_field = defaultdict(lambda: True)

# over-ride the default for some
sample_field["cm1_tornado_lofs"] = ("cm1", "dbz")
log_field["cm1_tornado_lofs"] = False


def get_sample_func_name(sample: str):
    """Return the generated loader function name for a sample dataset."""
    return "sample_" + sample.lower()


def pop_a_command(command: str, napari_config: dict):
    """Remove a command entry from the napari config, in place.

    Scans contributions.commands for entries whose "id" matches
    *command*; if any match, the last matching entry is removed.
    No-op when there is no match.
    """
    commands = napari_config["contributions"]["commands"]

    match_index = None
    for index, entry in enumerate(commands):
        if entry["id"] == command:
            match_index = index

    if match_index is not None:
        commands.pop(match_index)


def get_command_name(sample_name: str):
    """Return the fully-qualified napari command id for a sample dataset."""
    return "yt-napari.data." + sample_name.lower()


def get_command_entry(sample_name: str):
    """Build the napari commands-table entry dict for one sample dataset."""
    return {
        "id": get_command_name(sample_name),
        "title": f"Load {sample_name}",
        "python_name": (
            "yt_napari.sample_data._sample_data:"
            f"{get_sample_func_name(sample_name)}"
        ),
    }


def get_sample_table_entry(sample_name: str):
    """Build the napari sample_data-table entry dict for one sample dataset."""
    return {
        "key": sample_name.lower(),
        "display_name": sample_name,
        "command": get_command_name(sample_name),
    }


def update_napari_hooks(napari_yaml):
    """Rewrite the sample_data hooks in a napari plugin manifest, in place.

    Loads *napari_yaml*, removes the commands tied to the current
    sample_data entries, then regenerates both the sample_data table and
    the matching commands from the ``enabled`` sample list before writing
    the manifest back out.
    """
    with open(napari_yaml, "r") as file:
        napari_config = yaml.safe_load(file)

    contributions = napari_config["contributions"]

    # drop the command entries referenced by the existing sample_data table
    for sample_entry in contributions.get("sample_data", []):
        pop_a_command(sample_entry["command"], napari_config)

    # rebuild the sample_data table from scratch, then register a loader
    # command for every enabled sample
    contributions["sample_data"] = [get_sample_table_entry(s) for s in enabled]
    for sample in enabled:
        contributions["commands"].append(get_command_entry(sample))

    with open(napari_yaml, "w") as file:
        yaml.dump(napari_config, file)


def get_load_dict(sample_name):
    """Build a yt-napari load dict for one sample dataset.

    The dict selects a single region containing the sample's default
    field (from ``sample_field``) with its default log setting
    (from ``log_field``).
    """
    field_type, field_name = sample_field[sample_name]

    field_entry = {
        "field_name": field_name,
        "field_type": field_type,
        "take_log": log_field[sample_name],
    }
    region_entry = {"fields": [field_entry]}
    ds_entry = {
        "filename": sample_name,
        "selections": {"regions": [region_entry]},
    }
    return {"datasets": [ds_entry]}


def write_sample_jsons(json_dir):
    """Regenerate the per-sample json files and the sample registry.

    Deletes every existing .json file in *json_dir*, then writes one
    ``sample_<name>.json`` load description per enabled sample plus a
    ``sample_registry.json`` listing all enabled sample names.
    """
    # clear out any previously generated json files
    stale = [f for f in os.listdir(json_dir) if f.endswith(".json")]
    for fname in stale:
        os.remove(os.path.join(json_dir, fname))

    # write a fresh load description for every enabled sample
    for sample in enabled:
        target = os.path.join(json_dir, f"sample_{sample.lower()}.json")
        with open(target, "w") as fi:
            json.dump(get_load_dict(sample), fi, indent=4)

    # record which samples are enabled
    registry_file = os.path.join(json_dir, "sample_registry.json")
    with open(registry_file, "w") as fi:
        json.dump({"enabled": enabled}, fi, indent=4)


def single_sample_loader(sample: str):
    """Return the source lines for one generated sample-loader function.

    The trailing empty strings become the two blank lines that separate
    top-level defs in the generated module.
    """
    func_name = get_sample_func_name(sample)
    return [
        f"def {func_name}():",
        f"    return load_sample_data('{sample}')",
        "",
        "",
    ]


def write_sample_data_python_loaders(sample_data_dir):
    """Generate ``_sample_data.py`` in *sample_data_dir*.

    The generated module starts with an autogenerated-file banner and an
    import of the generic loader, followed by one ``sample_<name>()``
    function per enabled sample.
    """
    sd_py = []
    # fix: banner typos — "byt" -> "by", and name the actual taskipy task
    # (update_sample_data) so the regeneration instructions are correct
    sd_py.append("# this file is autogenerated by the taskipy update_sample_data task")
    sd_py.append("# to re-generate it, along with all the json files in this dir, run:")
    sd_py.append("# task update_sample_data")
    sd_py.append("# (requires taskipy: pip install taskipy)")
    sd_py.append("from yt_napari.sample_data._generic_loader import load_sample_data")
    sd_py.append("")
    sd_py.append("")

    # one generated loader function per enabled sample
    for sample in enabled:
        sd_py += single_sample_loader(sample)

    sd_py.pop(-1)  # only want one blank line at the end

    loader_file = os.path.join(sample_data_dir, "_sample_data.py")
    with open(loader_file, "w") as fi:
        fi.write("\n".join(sd_py))


if __name__ == "__main__":

    # regenerate all sample-data artifacts in place. Paths are relative to
    # the repository root, so run from there (e.g. via `task update_sample_data`).
    update_napari_hooks("src/yt_napari/napari.yaml")
    write_sample_jsons("src/yt_napari/sample_data/")
    write_sample_data_python_loaders("src/yt_napari/sample_data/")
24 changes: 16 additions & 8 deletions src/yt_napari/_ds_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os.path
from os import PathLike
from typing import Optional
from typing import List, Optional

import yt

Expand All @@ -11,15 +12,27 @@

def _load_sample(filename):
    # TODO: check for pooch, pandas.
    # TODO: catch key errors, etc.
    # fetch a named yt sample dataset via yt.load_sample — presumably
    # downloads on first use (NOTE(review): network dependency; confirm)
    ds = yt.load_sample(filename)
    return ds

Check warning on line 16 in src/yt_napari/_ds_cache.py

View check run for this annotation

Codecov / codecov/patch

src/yt_napari/_ds_cache.py#L15-L16

Added lines #L15 - L16 were not covered by tests


def get_sample_set_list() -> List[str]:
    """Return the names of the enabled sample datasets.

    Reads the packaged ``sample_data/sample_registry.json`` resource from
    the ``yt_napari`` package and returns its ``"enabled"`` list.
    """
    # local import keeps importlib.resources off the module import path
    import importlib.resources as importlib_resources

    jdata = json.loads(
        importlib_resources.files("yt_napari")
        .joinpath("sample_data")
        .joinpath("sample_registry.json")
        .read_bytes()
    )
    return jdata["enabled"]


class DatasetCache:
def __init__(self):
self.available = {}
self._most_recent: str = None
self.sample_sets: List[str] = get_sample_set_list()

def add_ds(self, ds, name: str):
if name in self.available:
Expand Down Expand Up @@ -60,12 +73,7 @@ def check_then_load(self, filename: str, cache_if_not_found: bool = True):
ds_callable = getattr(_special_loaders, callable_name)
ds = ds_callable()
else:
# TODO: have this sample files registry come from yt,
# just setting up the napari side for now. Should
# also add a config option maybe to handle name
# conflicts between sample files and local files?
sample_files = ["IsolatedGalaxy"]
if filename in sample_files:
if filename in self.sample_sets:
ds = _load_sample(filename)

Check warning on line 77 in src/yt_napari/_ds_cache.py

View check run for this annotation

Codecov / codecov/patch

src/yt_napari/_ds_cache.py#L77

Added line #L77 was not covered by tests
else:
ds = yt.load(filename)
Expand Down
1 change: 0 additions & 1 deletion src/yt_napari/_model_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,6 @@ def load_from_json_strs(json_strs: List[str]) -> List[Layer]:
for json_str in json_strs:
# InputModel is a pydantic class, the following will validate the json
model = InputModel.model_validate_json(json_str)
print(model)
# now that we have a validated model, we can use the model attributes
# to execute the code that will return our array for the image
layer_lists_j, timeseries_layers_j = _process_validated_model(model)
Expand Down
50 changes: 0 additions & 50 deletions src/yt_napari/_sample_data.py

This file was deleted.

Loading

0 comments on commit 39ae545

Please sign in to comment.