Skip to content

Commit

Permalink
Implement Nan padding in blockwise coregistration stats (#663)
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn authored Dec 16, 2024
1 parent 524311d commit cd86d69
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 42 deletions.
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ repos:
pass_filenames: false
additional_dependencies: [tomli, pyyaml]

# Add license header to the source files
- repo: local
hooks:
- id: add-license-header
name: Add License Header
entry: python .github/scripts/apply_license_header.py
language: python
files: \.py$
# # Add license header to the source files
# - repo: local
# hooks:
# - id: add-license-header
# name: Add License Header
# entry: python .github/scripts/apply_license_header.py
# language: python
# files: \.py$
43 changes: 37 additions & 6 deletions tests/test_coreg/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import xdem
from xdem import coreg, examples, misc, spatialstats
from xdem._typing import NDArrayf
from xdem.coreg import BlockwiseCoreg
from xdem.coreg.base import Coreg, apply_matrix, dict_key_to_str


Expand Down Expand Up @@ -924,11 +925,8 @@ def test_blockwise_coreg_large_gaps(self) -> None:

stats = blockwise.stats()

# We expect holes in the blockwise coregistration, so there should not be 64 "successful" blocks.
assert stats.shape[0] < 64

# Statistics are only calculated on finite values, so all of these should be finite as well.
assert np.all(np.isfinite(stats))
# We expect holes in the blockwise coregistration, but not in stats due to nan padding for failing chunks
assert stats.shape[0] == 64

# Copy the TBA DEM and set a square portion to nodata
tba = self.tba.copy()
Expand All @@ -938,7 +936,7 @@ def test_blockwise_coreg_large_gaps(self) -> None:

blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), 8, warn_failures=False)

# Align the DEM and apply the blockwise to a zero-array (to get the zshift)
# Align the DEM and apply blockwise to a zero-array (to get the z_shift)
aligned = blockwise.fit(self.ref, tba).apply(tba)
zshift, _ = blockwise.apply(np.zeros_like(tba.data), transform=tba.transform, crs=tba.crs)

Expand All @@ -952,6 +950,39 @@ def test_blockwise_coreg_large_gaps(self) -> None:
assert abs(np.nanmedian(ddem_pre)) > abs(np.nanmedian(ddem_post))
# assert np.nanstd(ddem_pre) > np.nanstd(ddem_post)

def test_failed_chunks_return_nan(self) -> None:
blockwise = BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=4)
blockwise.fit(**self.fit_params)
# Missing chunk 1 to simulate failure
blockwise._meta["step_meta"] = [meta for meta in blockwise._meta["step_meta"] if meta.get("i") != 1]

result_df = blockwise.stats()

# Check that chunk 1 (index 1) has NaN values for the statistics
assert np.isnan(result_df.loc[1, "inlier_count"])
assert np.isnan(result_df.loc[1, "nmad"])
assert np.isnan(result_df.loc[1, "median"])
assert isinstance(result_df.loc[1, "center_x"], float)
assert isinstance(result_df.loc[1, "center_y"], float)
assert np.isnan(result_df.loc[1, "center_z"])
assert np.isnan(result_df.loc[1, "x_off"])
assert np.isnan(result_df.loc[1, "y_off"])
assert np.isnan(result_df.loc[1, "z_off"])

def test_successful_chunks_return_values(self) -> None:
blockwise = BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=2)
blockwise.fit(**self.fit_params)
result_df = blockwise.stats()

# Check that the correct statistics are returned for successful chunks
assert result_df.loc[0, "inlier_count"] == blockwise._meta["step_meta"][0]["inlier_count"]
assert result_df.loc[0, "nmad"] == blockwise._meta["step_meta"][0]["nmad"]
assert result_df.loc[0, "median"] == blockwise._meta["step_meta"][0]["median"]

assert result_df.loc[1, "inlier_count"] == blockwise._meta["step_meta"][1]["inlier_count"]
assert result_df.loc[1, "nmad"] == blockwise._meta["step_meta"][1]["nmad"]
assert result_df.loc[1, "median"] == blockwise._meta["step_meta"][1]["median"]


class TestAffineManipulation:

Expand Down
103 changes: 75 additions & 28 deletions xdem/coreg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3039,6 +3039,7 @@ def __init__(
super().__init__()

self._meta: CoregDict = {"step_meta": []}
self._groups: NDArrayf = np.array([])

def fit(
self: CoregType,
Expand Down Expand Up @@ -3094,9 +3095,9 @@ def fit(
else:
mask = inlier_mask

groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape)
self._groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape)

indices = np.unique(groups)
indices = np.unique(self._groups)

progress_bar = tqdm(
total=indices.size, desc="Processing chunks", disable=logging.getLogger().getEffectiveLevel() > logging.INFO
Expand All @@ -3111,7 +3112,7 @@ def process(i: int) -> dict[str, Any] | BaseException | None:
* If it fails: The associated exception.
* If the block is empty: None
"""
group_mask = groups == i
group_mask = self._groups == i

# Find the corresponding slice of the inlier_mask to subset the data
rows, cols = np.where(group_mask)
Expand Down Expand Up @@ -3275,24 +3276,44 @@ def to_points(self) -> NDArrayf:
if len(self._meta["step_meta"]) == 0:
raise AssertionError("No coreg results exist. Has '.fit()' been called?")
points = np.empty(shape=(0, 3, 2))
for meta in self._meta["step_meta"]:
self._restore_metadata(meta)

# x_coord, y_coord = rio.transform.xy(meta["transform"], meta["representative_row"],
# meta["representative_col"])
x_coord, y_coord = meta["representative_x"], meta["representative_y"]
for i in range(self.subdivision):
# Try to restore the metadata for this chunk (if it succeeded)
chunk_meta = next((meta for meta in self._meta["step_meta"] if meta["i"] == i), None)

old_pos_arr = np.reshape([x_coord, y_coord, meta["representative_val"]], (1, 3))
if chunk_meta is not None:
# Successful chunk: Retrieve the representative X, Y, Z coordinates
self._restore_metadata(chunk_meta)
x_coord, y_coord = chunk_meta["representative_x"], chunk_meta["representative_y"]
repr_val = chunk_meta["representative_val"]
else:
# Failed chunk: Calculate the approximate center using the group's bounds
rows, cols = np.where(self._groups == i)
center_row = (rows.min() + rows.max()) // 2
center_col = (cols.min() + cols.max()) // 2

transform = self._meta["step_meta"][0]["transform"] # Assuming all chunks share a transform
x_coord, y_coord = rio.transform.xy(transform, center_row, center_col)
repr_val = np.nan # No valid Z value for failed chunks

# Old position based on the calculated or retrieved coordinates
old_pos_arr = np.reshape([x_coord, y_coord, repr_val], (1, 3))
old_position = gpd.GeoDataFrame(
geometry=gpd.points_from_xy(x=old_pos_arr[:, 0], y=old_pos_arr[:, 1], crs=None),
data={"z": old_pos_arr[:, 2]},
)

new_position = self.procstep.apply(old_position)
new_pos_arr = np.reshape(
[new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3)
)
if chunk_meta is not None:
# Successful chunk: Apply the transformation
new_position = self.procstep.apply(old_position)
new_pos_arr = np.reshape(
[new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3)
)
else:
# Failed chunk: Keep the new position the same as the old position (no transformation)
new_pos_arr = old_pos_arr.copy()

# Append the result
points = np.append(points, np.dstack((old_pos_arr, new_pos_arr)), axis=0)

return points
Expand All @@ -3310,6 +3331,7 @@ def stats(self) -> pd.DataFrame:
:raises ValueError: If no coregistration results exist yet.
:returns: A dataframe of statistics for each chunk.
If a chunk fails (not present in `chunk_meta`), the statistics will be returned as `NaN`.
"""
points = self.to_points()

Expand All @@ -3318,20 +3340,34 @@ def stats(self) -> pd.DataFrame:
statistics: list[dict[str, Any]] = []
for i in range(points.shape[0]):
if i not in chunk_meta:
continue
statistics.append(
{
"center_x": points[i, 0, 0],
"center_y": points[i, 1, 0],
"center_z": points[i, 2, 0],
"x_off": points[i, 0, 1] - points[i, 0, 0],
"y_off": points[i, 1, 1] - points[i, 1, 0],
"z_off": points[i, 2, 1] - points[i, 2, 0],
"inlier_count": chunk_meta[i]["inlier_count"],
"nmad": chunk_meta[i]["nmad"],
"median": chunk_meta[i]["median"],
}
)
# For missing chunks, return NaN for all stats
statistics.append(
{
"center_x": points[i, 0, 0],
"center_y": points[i, 1, 0],
"center_z": points[i, 2, 0],
"x_off": np.nan,
"y_off": np.nan,
"z_off": np.nan,
"inlier_count": np.nan,
"nmad": np.nan,
"median": np.nan,
}
)
else:
statistics.append(
{
"center_x": points[i, 0, 0],
"center_y": points[i, 1, 0],
"center_z": points[i, 2, 0],
"x_off": points[i, 0, 1] - points[i, 0, 0],
"y_off": points[i, 1, 1] - points[i, 1, 0],
"z_off": points[i, 2, 1] - points[i, 2, 0],
"inlier_count": chunk_meta[i]["inlier_count"],
"nmad": chunk_meta[i]["nmad"],
"median": chunk_meta[i]["median"],
}
)

stats_df = pd.DataFrame(statistics)
stats_df.index.name = "chunk"
Expand Down Expand Up @@ -3367,6 +3403,11 @@ def _apply_rst(
raise NotImplementedError("Option `resample=False` not supported for coreg method BlockwiseCoreg.")

points = self.to_points()
# Check for NaN values across both the old and new positions for each point
mask = ~np.isnan(points).any(axis=(1, 2))

# Filter out points where there are no NaN values
points = points[mask]

bounds = _bounds(transform=transform, shape=elev.shape)
resolution = _res(transform)
Expand Down Expand Up @@ -3410,6 +3451,12 @@ def _apply_pts(
"""Apply the scaling model to a set of points."""
points = self.to_points()

# Check for NaN values across both the old and new positions for each point
mask = ~np.isnan(points).any(axis=(1, 2))

# Filter out points where there are no NaN values
points = points[mask]

new_coords = np.array([elev.geometry.x.values, elev.geometry.y.values, elev["z"].values]).T

for dim in range(0, 3):
Expand Down Expand Up @@ -3518,7 +3565,7 @@ def warp_dem(
order = {"nearest": 0, "linear": 1, "cubic": 3}

with warnings.catch_warnings():
# An skimage warning that will hopefully be fixed soon. (2021-06-08)
# A skimage warning that will hopefully be fixed soon. (2021-06-08)
warnings.filterwarnings("ignore", message="Passing `np.nan` to mean no clipping in np.clip")
warped = skimage.transform.warp(
image=np.where(dem_mask, np.nan, dem_arr),
Expand Down

0 comments on commit cd86d69

Please sign in to comment.