Skip to content

Commit

Permalink
enforce types to avoid json issues
Browse files Browse the repository at this point in the history
  • Loading branch information
JLPM22 committed May 30, 2024
1 parent 04714b8 commit e32e2ca
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
19 changes: 19 additions & 0 deletions pymotion/render/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ def add_skeleton(
"""

assert data.ndim in (2, 3), "'data' must have shape [frames, joints, 3] or [joints, 3]"
assert parents.ndim == 1, "'parents' must have shape [joints]"
assert data.shape[-1] == 3, "'data' must have shape [..., 3]"
assert data.shape[-2] == parents.shape[0], "'data' and 'parents' must have the same number of joints"

if data.dtype != np.float64:
data = data.astype(np.float64)
if parents.dtype != np.int64:
parents = parents.astype(np.int64)

if data.ndim == 2 or data.shape[0] == 1:
self.static_objs.extend(
Expand Down Expand Up @@ -142,6 +150,10 @@ def add_sphere(

assert sphere_mode in ("scatter", "mesh"), "'sphere_mode' must be 'scatter' or 'mesh'"
assert center.ndim in (1, 2), "'center' must have shape [frames, 3] or [3]"
assert center.shape[-1] == 3, "'center' must have shape [..., 3]"

if center.dtype != np.float64:
center = center.astype(np.float64)

if center.ndim == 1 or center.shape[0] == 1:
if sphere_mode == "scatter":
Expand Down Expand Up @@ -185,6 +197,13 @@ def add_line(
assert start.ndim in (1, 2), "'start' must have shape [frames, 3] or [3]"
assert end.ndim in (1, 2), "'end' must have shape [frames, 3] or [3]"
assert start.ndim == end.ndim, "'start' and 'end' must have the same number of dimensions"
assert start.shape[-1] == 3 and end.shape[-1], "'start' and 'end' must have shape [..., 3]"
assert start.shape[0] == end.shape[0], "'start' and 'end' must have the same number of frames"

if start.dtype != np.float64:
start = start.astype(np.float64)
if end.dtype != np.float64:
end = end.astype(np.float64)

if start.ndim == 1 or start.shape[0] == 1:
self.static_objs.append(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ exclude = ["*test*"]

[project]
name = "upc-pymotion"
version = "0.1.9"
version = "0.1.10"
description = "A Python library for working with motion data in NumPy or PyTorch."
readme = "README.md"
authors = [{ name = "Jose Luis Ponton", email = "[email protected]" }]
Expand Down

0 comments on commit e32e2ca

Please sign in to comment.