Skip to content

Commit

Permalink
Add helper class for wrapping MPI open of zarr Group.
Browse files Browse the repository at this point in the history
  • Loading branch information
tskisner committed Jan 2, 2025
1 parent bff181d commit ddc2d68
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 18 deletions.
2 changes: 1 addition & 1 deletion docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ plugins:
show_symbol_type_heading: true
show_symbol_type_toc: true
- mkdocs-jupyter
execute: false
execute: false

nav:
- Introduction: index.md
Expand Down
34 changes: 17 additions & 17 deletions flacarray/tests/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ..array import FlacArray
from ..demo import create_fake_data
from ..zarr import have_zarr, write_array, read_array
from ..zarr import have_zarr, write_array, read_array, ZarrGroup
from ..mpi import use_mpi, MPI

if have_zarr:
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_direct_write_read(self):
tmppath = self.comm.bcast(tmppath, root=0)

i32_file = os.path.join(tmppath, "data_i32.zarr")
with zarr.open_group(i32_file, mode="w") as zf:
with ZarrGroup(i32_file, mode="w", comm=self.comm) as zf:
write_array(
input32,
zf,
Expand All @@ -77,7 +77,7 @@ def test_direct_write_read(self):
)
if self.comm is not None:
self.comm.barrier()
with zarr.open_group(i32_file, mode="r") as zf:
with ZarrGroup(i32_file, mode="r", comm=self.comm) as zf:
check_i32 = read_array(
zf,
keep=None,
Expand All @@ -89,7 +89,7 @@ def test_direct_write_read(self):
)

i64_file = os.path.join(tmppath, "data_i64.zarr")
with zarr.open_group(i64_file, mode="w") as zf:
with ZarrGroup(i64_file, mode="w", comm=self.comm) as zf:
write_array(
input64,
zf,
Expand All @@ -101,7 +101,7 @@ def test_direct_write_read(self):
)
if self.comm is not None:
self.comm.barrier()
with zarr.open_group(i64_file, mode="r") as zf:
with ZarrGroup(i64_file, mode="r", comm=self.comm) as zf:
check_i64 = read_array(
zf,
keep=None,
Expand All @@ -113,7 +113,7 @@ def test_direct_write_read(self):
)

f32_file = os.path.join(tmppath, "data_f32.zarr")
with zarr.open_group(f32_file, mode="w") as zf:
with ZarrGroup(f32_file, mode="w", comm=self.comm) as zf:
write_array(
inputf32,
zf,
Expand All @@ -125,7 +125,7 @@ def test_direct_write_read(self):
)
if self.comm is not None:
self.comm.barrier()
with zarr.open_group(f32_file, mode="r") as zf:
with ZarrGroup(f32_file, mode="r", comm=self.comm) as zf:
check_f32 = read_array(
zf,
keep=None,
Expand All @@ -137,7 +137,7 @@ def test_direct_write_read(self):
)

f64_file = os.path.join(tmppath, "data_f64.zarr")
with zarr.open_group(f64_file, mode="w") as zf:
with ZarrGroup(f64_file, mode="w", comm=self.comm) as zf:
write_array(
inputf64,
zf,
Expand All @@ -149,7 +149,7 @@ def test_direct_write_read(self):
)
if self.comm is not None:
self.comm.barrier()
with zarr.open_group(f64_file, mode="r") as zf:
with ZarrGroup(f64_file, mode="r", comm=self.comm) as zf:
check_f64 = read_array(
zf,
keep=None,
Expand Down Expand Up @@ -231,35 +231,35 @@ def test_array_write_read(self):
tmppath = self.comm.bcast(tmppath, root=0)

i32_file = os.path.join(tmppath, "data_i32.zarr")
with zarr.open_group(i32_file, mode="w") as zf:
with ZarrGroup(i32_file, mode="w", comm=self.comm) as zf:
flcarr_i32.write_zarr(zf)
if self.comm is not None:
self.comm.barrier()
with zarr.open_group(i32_file, mode="r") as zf:
with ZarrGroup(i32_file, mode="r", comm=self.comm) as zf:
check_i32 = FlacArray.read_zarr(zf, mpi_comm=self.comm)

i64_file = os.path.join(tmppath, "data_i64.zarr")
with zarr.open_group(i64_file, mode="w") as zf:
with ZarrGroup(i64_file, mode="w", comm=self.comm) as zf:
flcarr_i64.write_zarr(zf)
if self.comm is not None:
self.comm.barrier()
with zarr.open_group(i64_file, mode="r") as zf:
with ZarrGroup(i64_file, mode="r", comm=self.comm) as zf:
check_i64 = FlacArray.read_zarr(zf, mpi_comm=self.comm)

f32_file = os.path.join(tmppath, "data_f32.zarr")
with zarr.open_group(f32_file, mode="w") as zf:
with ZarrGroup(f32_file, mode="w", comm=self.comm) as zf:
flcarr_f32.write_zarr(zf)
if self.comm is not None:
self.comm.barrier()
with zarr.open_group(f32_file, mode="r") as zf:
with ZarrGroup(f32_file, mode="r", comm=self.comm) as zf:
check_f32 = FlacArray.read_zarr(zf, mpi_comm=self.comm)

f64_file = os.path.join(tmppath, "data_f64.zarr")
with zarr.open_group(f64_file, mode="w") as zf:
with ZarrGroup(f64_file, mode="w", comm=self.comm) as zf:
flcarr_f64.write_zarr(zf)
if self.comm is not None:
self.comm.barrier()
with zarr.open_group(f64_file, mode="r") as zf:
with ZarrGroup(f64_file, mode="r", comm=self.comm) as zf:
check_f64 = FlacArray.read_zarr(zf, mpi_comm=self.comm)

del tmppath
Expand Down
35 changes: 35 additions & 0 deletions flacarray/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,41 @@
from .utils import function_timer


class ZarrGroup(object):
"""Wrapper class containing an open Zarr Group.
The named object is a file opened in the specified mode on the root process.
On other processes the handle will be None.
Args:
name (str): The filesystem path.
mode (str): The opening mode.
comm (MPI.Comm): The MPI communicator or None.
"""
def __init__(self, name, mode, comm=None):
self.handle = None
if comm is None or comm.rank == 0:
self.handle = zarr.open_group(name, mode=mode)
if comm is not None:
comm.barrier()

def close(self):
if hasattr(self, "handle") and self.handle is not None:
self.handle.store.close()
del self.handle
self.handle = None

def __del__(self):
self.close()

def __enter__(self):
return self.handle

def __exit__(self, *args):
self.close()


class WriterZarr:
"""Helper class for the common writer function."""

Expand Down

0 comments on commit ddc2d68

Please sign in to comment.