Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
esoteric-ephemera committed Jan 21, 2025
1 parent afb7691 commit b7d5c54
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 64 deletions.
15 changes: 9 additions & 6 deletions emmet-core/emmet/core/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@
Vector3D = TypeVar("Vector3D", bound=tuple[float, float, float])
Vector3D.__doc__ = "Real space vector" # type: ignore

Matrix3D = TypeVar("Matrix3D", bound = tuple[Vector3D, Vector3D, Vector3D])
Matrix3D = TypeVar("Matrix3D", bound=tuple[Vector3D, Vector3D, Vector3D])
Matrix3D.__doc__ = "Real space Matrix" # type: ignore

Vector6D = TypeVar("Vector6D", bound=tuple[float, float, float, float, float, float])
Vector6D.__doc__ = "6D Voigt matrix component" # type: ignore

MatrixVoigt = TypeVar("MatrixVoigt",bound=tuple[Vector6D, Vector6D, Vector6D, Vector6D, Vector6D, Vector6D])
MatrixVoigt = TypeVar(
"MatrixVoigt",
bound=tuple[Vector6D, Vector6D, Vector6D, Vector6D, Vector6D, Vector6D],
)
MatrixVoigt.__doc__ = "Voigt representation of a 3x3x3x3 tensor" # type: ignore

Tensor3R = TypeVar("Tensor3R",bound=list[list[list[float]]])
Tensor3R = TypeVar("Tensor3R", bound=list[list[list[float]]])
Tensor3R.__doc__ = "Generic tensor of rank 3" # type: ignore

Tensor4R = TypeVar("Tensor4R",bound=list[list[list[list[float]]]])
Tensor4R = TypeVar("Tensor4R", bound=list[list[list[list[float]]]])
Tensor4R.__doc__ = "Generic tensor of rank 4" # type: ignore

ListVector3D = TypeVar("ListVector3D",bound=list[float])
ListVector3D = TypeVar("ListVector3D", bound=list[float])
ListVector3D.__doc__ = "Real space vector as list" # type: ignore

ListMatrix3D = TypeVar("ListMatrix3D",bound=list[ListVector3D])
ListMatrix3D = TypeVar("ListMatrix3D", bound=list[ListVector3D])
ListMatrix3D.__doc__ = "Real space Matrix as list" # type: ignore
147 changes: 89 additions & 58 deletions emmet-core/emmet/core/structure_replicas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,25 @@
from typing import Any
from typing_extensions import Self


class EmmetReplica(BaseModel):
"""Define strongly typed, fixed schema versions of generic pymatgen objects."""

@classmethod
def from_pymatgen(cls, pmg_obj : Any) -> Self:
def from_pymatgen(cls, pmg_obj: Any) -> Self:
"""Convert pymatgen objects to an EmmetReplica representation."""
raise NotImplementedError

def to_pymatgen(self) -> Any:
"""Convert EmmetReplica object to pymatgen equivalent."""
raise NotImplementedError

@classmethod
def from_dict(cls, dct : dict[str,Any]) -> Self:
def from_dict(cls, dct: dict[str, Any]) -> Self:
"""MSONable-like function to create this object from a dict."""
raise NotImplementedError

def as_dict(self) -> dict[str,Any]:
def as_dict(self) -> dict[str, Any]:
"""MSONable-like function to create dict representation of this object."""
raise NotImplementedError

Expand All @@ -49,6 +50,7 @@ class SiteProperties(Enum):
velocities = "velocities"
selective_dynamics = "selective_dynamics"


class ElementSymbol(Enum):
"""Lightweight representation of a chemical element."""

Expand Down Expand Up @@ -180,6 +182,7 @@ def __str__(self):
"""Get element name."""
return self.name


class LightLattice(tuple):
"""Low memory representation of a Lattice as a tuple of a 3x3 matrix."""

Expand All @@ -188,7 +191,9 @@ def __new__(cls, matrix):
lattice_matrix = np.array(matrix)
if lattice_matrix.shape != (3, 3):
raise ValueError("Lattice matrix must be 3x3.")
return super(LightLattice,cls).__new__(cls,tuple([tuple(v) for v in lattice_matrix.tolist()]))
return super(LightLattice, cls).__new__(
cls, tuple([tuple(v) for v in lattice_matrix.tolist()])
)

def as_dict(self) -> dict[str, list | str]:
"""Define MSONable-like as_dict."""
Expand All @@ -211,7 +216,7 @@ def volume(self) -> float:

class ElementReplica(EmmetReplica):
"""Define a flexible schema for elements and periodic sites.
The only required field in this model is `element`.
This is intended to mimic a `pymatgen` `.Element` object.
Additionally, the `lattice` and coordinates of the site can be specified
Expand Down Expand Up @@ -239,43 +244,59 @@ class ElementReplica(EmmetReplica):
was allowed to relax on.
"""

element : ElementSymbol = Field(description="The element.")
lattice : Matrix3D | None = Field(default = None, description="The lattice in 3x3 matrix form.")
cart_coords : Vector3D | None = Field(default = None, description="The postion of the site in Cartesian coordinates.")
frac_coords : Vector3D | None = Field(default = None, description="The postion of the site in direct lattice vector coordinates.")
charge : float | None = Field(default = None, description="The on-site charge.")
magmom : float | None = Field(default = None, description="The on-site magnetic moment.")
velocities : Vector3D | None = Field(default = None, description="The Cartesian components of the site velocity.")
selective_dynamics : tuple[bool, bool, bool] | None = Field(default = None, description="The degrees of freedom which are allowed to relax on the site.")

def model_post_init(self, __context : Any) -> None:
element: ElementSymbol = Field(description="The element.")
lattice: Matrix3D | None = Field(
default=None, description="The lattice in 3x3 matrix form."
)
cart_coords: Vector3D | None = Field(
default=None, description="The postion of the site in Cartesian coordinates."
)
frac_coords: Vector3D | None = Field(
default=None,
description="The postion of the site in direct lattice vector coordinates.",
)
charge: float | None = Field(default=None, description="The on-site charge.")
magmom: float | None = Field(
default=None, description="The on-site magnetic moment."
)
velocities: Vector3D | None = Field(
default=None, description="The Cartesian components of the site velocity."
)
selective_dynamics: tuple[bool, bool, bool] | None = Field(
default=None,
description="The degrees of freedom which are allowed to relax on the site.",
)

def model_post_init(self, __context: Any) -> None:
"""Ensure both Cartesian and direct coordinates are set, if necessary."""
if self.lattice:
if self.cart_coords is not None:
self.frac_coords = self.frac_coords or np.linalg.solve(
np.array(self.lattice).T, np.array(self.cart_coords)
)
np.array(self.lattice).T, np.array(self.cart_coords)
)
elif self.frac_coords is not None:
self.cart_coords = self.cart_coords or tuple(
np.matmul(np.array(self.lattice).T, np.array(self.frac_coords))
)

@classmethod
def from_pymatgen(cls, pmg_obj : Element | PeriodicSite) -> Self:
def from_pymatgen(cls, pmg_obj: Element | PeriodicSite) -> Self:
"""Convert a pymatgen .PeriodicSite or .Element to .ElementReplica.
Parameters
-----------
site : pymatgen .Element or .PeriodicSite
"""
if isinstance(pmg_obj, Element):
return cls(element = ElementSymbol(pmg_obj.name))
return cls(element=ElementSymbol(pmg_obj.name))

return cls(
element = ElementSymbol(next(iter(pmg_obj.species.remove_charges().as_dict()))),
lattice = LightLattice(pmg_obj.lattice.matrix),
frac_coords = pmg_obj.frac_coords,
cart_coords = pmg_obj.coords,
element=ElementSymbol(
next(iter(pmg_obj.species.remove_charges().as_dict()))
),
lattice=LightLattice(pmg_obj.lattice.matrix),
frac_coords=pmg_obj.frac_coords,
cart_coords=pmg_obj.coords,
)

def to_pymatgen(self) -> PeriodicSite:
Expand All @@ -285,20 +306,20 @@ def to_pymatgen(self) -> PeriodicSite:
self.frac_coords,
Lattice(self.lattice),
coords_are_cartesian=False,
properties = self.properties
properties=self.properties,
)

@property
def species(self) -> dict[str,int]:
def species(self) -> dict[str, int]:
"""Composition-like representation of site."""
return {self.element.name : 1}
return {self.element.name: 1}

@property
def properties(self) -> dict[str,float]:
def properties(self) -> dict[str, float]:
"""Aggregate optional properties defined on the site."""
props = {}
for k in SiteProperties.__members__:
if (prop := getattr(self,k,None)) is not None:
if (prop := getattr(self, k, None)) is not None:
props[k] = prop
return props

Expand All @@ -324,7 +345,7 @@ def Z(self) -> int:
def name(self) -> str:
"""Ensure compatibility with PeriodicSite."""
return self.element.name

@property
def species_string(self) -> str:
"""Ensure compatibility with PeriodicSite."""
Expand All @@ -337,18 +358,18 @@ def label(self) -> str:

def __str__(self):
return self.label

def add_attrs(self, **kwargs) -> ElementReplica:
"""Rapidly create a copy of this instance with additional fields set.
Parameters
-----------
**kwargs
Any of the fields defined in the model. This function is used to
add lattice and coordinate information to each site, and thereby
not store it in the StructureReplica object itself in addition to
each site.
Returns
-----------
ElementReplica
Expand All @@ -357,6 +378,7 @@ def add_attrs(self, **kwargs) -> ElementReplica:
config.update(**kwargs)
return ElementReplica(**config)


class StructureReplica(BaseModel):
"""Define a fixed schema structure.
Expand All @@ -367,10 +389,10 @@ class StructureReplica(BaseModel):
When the `.sites` attr of `StructureReplica` is accessed, all prior attributes
(respective aliases: `lattice`, `frac_coords`, and `coords`) are assigned to the
retrieved sites.
Compare this to pymatgen's .Structure, which stores the `lattice`, `frac_coords`,
Compare this to pymatgen's .Structure, which stores the `lattice`, `frac_coords`,
and `cart_coords` both in the .Structure object and each .PeriodicSite within it.
Parameters
-----------
lattice : LightLattice
Expand All @@ -385,21 +407,25 @@ class StructureReplica(BaseModel):
charge (optional) : float
The total charge on the structure.
"""

lattice : LightLattice = Field(description="The lattice in 3x3 matrix form.")
species : list[ElementReplica] = Field(description="The elements in the structure.")
frac_coords : ListMatrix3D = Field(description="The direct coordinates of the sites in the structure.")
cart_coords : ListMatrix3D = Field(description="The Cartesian coordinates of the sites in the structure.")
charge : float | None = Field(None, description="The net charge on the structure.")

lattice: LightLattice = Field(description="The lattice in 3x3 matrix form.")
species: list[ElementReplica] = Field(description="The elements in the structure.")
frac_coords: ListMatrix3D = Field(
description="The direct coordinates of the sites in the structure."
)
cart_coords: ListMatrix3D = Field(
description="The Cartesian coordinates of the sites in the structure."
)
charge: float | None = Field(None, description="The net charge on the structure.")

@property
def sites(self) -> list[ElementReplica]:
"""Return a list of sites in the structure with lattice and coordinate info."""
return [
species.add_attrs(
lattice = self.lattice,
cart_coords = self.cart_coords[idx],
frac_coords = self.frac_coords[idx],
lattice=self.lattice,
cart_coords=self.cart_coords[idx],
frac_coords=self.frac_coords[idx],
)
for idx, species in enumerate(self.species)
]
Expand Down Expand Up @@ -431,7 +457,7 @@ def num_sites(self) -> int:
@classmethod
def from_pymatgen(cls, pmg_obj: Structure) -> Self:
"""Create a StructureReplica from a pymatgen .Structure.
Parameters
-----------
pmg_obj : pymatgen .Structure
Expand All @@ -444,41 +470,46 @@ def from_pymatgen(cls, pmg_obj: Structure) -> Self:
raise ValueError(
"Currently, `StructureReplica` is intended to represent only ordered materials."
)

lattice = LightLattice(pmg_obj.lattice.matrix)
properties = [{} for _ in range(len(pmg_obj))]
for idx, site in enumerate(pmg_obj):
for k in ("charge","magmom","velocities","selective_dynamics"):
for k in ("charge", "magmom", "velocities", "selective_dynamics"):
if (prop := site.properties.get(k)) is not None:
properties[idx][k] = prop

species = [
ElementReplica(
element = ElementSymbol[next(iter(site.species.remove_charges().as_dict()))],
**properties[idx]
element=ElementSymbol[
next(iter(site.species.remove_charges().as_dict()))
],
**properties[idx],
)
for idx, site in enumerate(pmg_obj)
]

return cls(
lattice=lattice,
species = species,
frac_coords = [site.frac_coords for site in pmg_obj],
cart_coords = [site.coords for site in pmg_obj],
charge = pmg_obj.charge,
species=species,
frac_coords=[site.frac_coords for site in pmg_obj],
cart_coords=[site.coords for site in pmg_obj],
charge=pmg_obj.charge,
)

def to_pymatgen(self) -> Structure:
"""Convert to a pymatgen .Structure."""
return Structure.from_sites([site.to_periodic_site() for site in self], charge = self.charge)

return Structure.from_sites(
[site.to_periodic_site() for site in self], charge=self.charge
)

@classmethod
def from_poscar(cls, poscar_path: str | Path) -> Self:
"""Define convenience method to create a StructureReplica from a VASP POSCAR."""
return cls.from_structure(Poscar.from_file(poscar_path).structure)

def __str__(self):
"""Define format for printing a Structure."""

def _format_float(val: float | int) -> str:
nspace = 2 if val >= 0.0 else 1
return " " * nspace + f"{val:.8f}"
Expand Down

0 comments on commit b7d5c54

Please sign in to comment.