-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #984 from orionarcher/only_schemas
Additional updates to ClassicalMD Schemas
- Loading branch information
Showing
9 changed files
with
394 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from emmet.core.classical_md.tasks import MoleculeSpec, ClassicalMDTaskDocument |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from emmet.core.classical_md.openmm.tasks import ( | ||
Calculation, | ||
CalculationInput, | ||
CalculationOutput, | ||
OpenMMTaskDocument, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
"""Schemas for OpenMM tasks.""" | ||
|
||
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
from typing import Optional, Union | ||
|
||
import pandas as pd # type: ignore[import-untyped] | ||
from emmet.core.vasp.task_valid import TaskState # type: ignore[import-untyped] | ||
from pydantic import BaseModel, Field | ||
|
||
from emmet.core.classical_md import ClassicalMDTaskDocument # type: ignore[import-untyped] | ||
from emmet.core.classical_md.tasks import HexBytes # type: ignore[import-untyped] | ||
|
||
|
||
class CalculationInput(BaseModel, extra="allow"): # type: ignore[call-arg] | ||
"""OpenMM input settings for a job, these are the attributes of the OpenMMMaker.""" | ||
|
||
n_steps: Optional[int] = Field( | ||
None, description="The number of simulation steps to run." | ||
) | ||
|
||
step_size: Optional[float] = Field( | ||
None, description="The size of each simulation step (picoseconds)." | ||
) | ||
|
||
temperature: Optional[float] = Field( | ||
None, description="The simulation temperature (kelvin)." | ||
) | ||
|
||
friction_coefficient: Optional[float] = Field( | ||
None, | ||
description=( | ||
"The friction coefficient for the integrator (inverse picoseconds)." | ||
), | ||
) | ||
|
||
platform_name: Optional[str] = Field( | ||
None, | ||
description=( | ||
"The name of the OpenMM platform to use, passed to " | ||
"Interchange.to_openmm_simulation." | ||
), | ||
) | ||
|
||
platform_properties: Optional[dict] = Field( | ||
None, | ||
description=( | ||
"Properties for the OpenMM platform, passed to " | ||
"Interchange.to_openmm_simulation." | ||
), | ||
) | ||
|
||
state_interval: Optional[int] = Field( | ||
None, | ||
description=( | ||
"State is saved every `state_interval` timesteps. For no state, set to 0." | ||
), | ||
) | ||
|
||
state_file_name: Optional[str] = Field( | ||
None, description="The name of the state file to save." | ||
) | ||
|
||
traj_interval: Optional[int] = Field( | ||
None, | ||
description=( | ||
"The trajectory is saved every `traj_interval` timesteps. For no trajectory, set to 0." | ||
), | ||
) | ||
|
||
wrap_traj: Optional[bool] = Field( | ||
None, description="Whether to wrap trajectory coordinates." | ||
) | ||
|
||
report_velocities: Optional[bool] = Field( | ||
None, description="Whether to report velocities in the trajectory file." | ||
) | ||
|
||
traj_file_name: Optional[str] = Field( | ||
None, description="The name of the trajectory file to save." | ||
) | ||
|
||
traj_file_type: Optional[str] = Field( | ||
None, | ||
description="The type of trajectory file to save.", | ||
) | ||
|
||
embed_traj: Optional[bool] = Field( | ||
None, | ||
description="Whether to embed the trajectory blob in CalculationOutput.", | ||
) | ||
|
||
|
||
class CalculationOutput(BaseModel): | ||
"""OpenMM calculation output files and extracted data.""" | ||
|
||
dir_name: Optional[str] = Field( | ||
None, description="The directory for this OpenMM task" | ||
) | ||
|
||
traj_file: Optional[str] = Field( | ||
None, description="Path to the trajectory file relative to `dir_name`" | ||
) | ||
|
||
traj_blob: Optional[HexBytes] = Field( | ||
None, description="Trajectory file as a binary blob" | ||
) | ||
|
||
state_file: Optional[str] = Field( | ||
None, description="Path to the state file relative to `dir_name`" | ||
) | ||
|
||
steps_reported: Optional[list[int]] = Field( | ||
None, description="Steps where outputs are reported" | ||
) | ||
|
||
time: Optional[list[float]] = Field(None, description="List of times") | ||
|
||
potential_energy: Optional[list[float]] = Field( | ||
None, description="List of potential energies" | ||
) | ||
|
||
kinetic_energy: Optional[list[float]] = Field( | ||
None, description="List of kinetic energies" | ||
) | ||
|
||
total_energy: Optional[list[float]] = Field( | ||
None, description="List of total energies" | ||
) | ||
|
||
temperature: Optional[list[float]] = Field(None, description="List of temperatures") | ||
|
||
volume: Optional[list[float]] = Field(None, description="List of volumes") | ||
|
||
density: Optional[list[float]] = Field(None, description="List of densities") | ||
|
||
elapsed_time: Optional[float] = Field( | ||
None, description="Elapsed time for the calculation (seconds)." | ||
) | ||
|
||
@classmethod | ||
def from_directory( | ||
cls, | ||
dir_name: Path | str, | ||
state_file_name: str, | ||
traj_file_name: str, | ||
elapsed_time: Optional[float] = None, | ||
n_steps: Optional[int] = None, | ||
state_interval: Optional[int] = None, | ||
embed_traj: bool = False, | ||
) -> CalculationOutput: | ||
"""Extract data from the output files in the directory.""" | ||
state_file = Path(dir_name) / state_file_name | ||
column_name_map = { | ||
'#"Step"': "steps_reported", | ||
"Potential Energy (kJ/mole)": "potential_energy", | ||
"Kinetic Energy (kJ/mole)": "kinetic_energy", | ||
"Total Energy (kJ/mole)": "total_energy", | ||
"Temperature (K)": "temperature", | ||
"Box Volume (nm^3)": "volume", | ||
"Density (g/mL)": "density", | ||
} | ||
state_is_not_empty = state_file.exists() and state_file.stat().st_size > 0 | ||
state_steps = state_interval and n_steps and n_steps // state_interval or 0 | ||
if state_is_not_empty and (state_steps > 0): | ||
data = pd.read_csv(state_file, header=0) | ||
data = data.rename(columns=column_name_map) | ||
data = data.filter(items=column_name_map.values()) | ||
data = data.iloc[-state_steps:] | ||
attributes = data.to_dict(orient="list") | ||
else: | ||
attributes = {name: None for name in column_name_map.values()} | ||
state_file_name = None # type: ignore[assignment] | ||
|
||
traj_file = Path(dir_name) / traj_file_name | ||
traj_is_not_empty = traj_file.exists() and traj_file.stat().st_size > 0 | ||
traj_file_name = traj_file_name if traj_is_not_empty else None # type: ignore | ||
|
||
if embed_traj and traj_is_not_empty: | ||
with open(traj_file, "rb") as f: | ||
traj_blob = f.read() | ||
else: | ||
traj_blob = None | ||
|
||
return CalculationOutput( | ||
dir_name=str(dir_name), | ||
elapsed_time=elapsed_time, | ||
traj_file=traj_file_name, | ||
state_file=state_file_name, | ||
traj_blob=traj_blob, | ||
**attributes, | ||
) | ||
|
||
|
||
class Calculation(BaseModel): | ||
"""All input and output data for an OpenMM calculation.""" | ||
|
||
dir_name: Optional[str] = Field( | ||
None, description="The directory for this OpenMM calculation" | ||
) | ||
|
||
has_openmm_completed: Optional[Union[TaskState, bool]] = Field( | ||
None, description="Whether OpenMM completed the calculation successfully" | ||
) | ||
|
||
input: Optional[CalculationInput] = Field( | ||
None, description="OpenMM input settings for the calculation" | ||
) | ||
output: Optional[CalculationOutput] = Field( | ||
None, description="The OpenMM calculation output" | ||
) | ||
|
||
completed_at: Optional[str] = Field( | ||
None, description="Timestamp for when the calculation was completed" | ||
) | ||
task_name: Optional[str] = Field( | ||
None, description="Name of task given by custodian (e.g., relax1, relax2)" | ||
) | ||
|
||
calc_type: Optional[str] = Field( | ||
None, | ||
description="Return calculation type (run type + task_type). or just new thing", | ||
) | ||
|
||
|
||
class OpenMMTaskDocument(ClassicalMDTaskDocument): | ||
"""Definition of the OpenMM task document.""" | ||
|
||
calcs_reversed: Optional[list[Calculation]] = Field( | ||
None, | ||
title="Calcs reversed data", | ||
description="Detailed data for each OpenMM calculation contributing to the " | ||
"task document.", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
"""Schemas for classical MD package.""" | ||
|
||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from datetime import datetime | ||
from typing import Optional | ||
from typing_extensions import Annotated | ||
import zlib | ||
from typing import Any | ||
|
||
from pydantic import ( | ||
BaseModel, | ||
Field, | ||
PlainValidator, | ||
PlainSerializer, | ||
WithJsonSchema, | ||
errors, | ||
) | ||
from monty.json import MSONable | ||
|
||
from emmet.core.vasp.task_valid import TaskState # type: ignore[import-untyped] | ||
|
||
|
||
def hex_bytes_validator(o: Any) -> bytes: | ||
if isinstance(o, bytes): | ||
return o | ||
elif isinstance(o, bytearray): | ||
return bytes(o) | ||
elif isinstance(o, str): | ||
return zlib.decompress(bytes.fromhex(o)) | ||
raise errors.BytesError() | ||
|
||
|
||
def hex_bytes_serializer(b: bytes) -> str: | ||
return zlib.compress(b).hex() | ||
|
||
|
||
HexBytes = Annotated[ | ||
bytes, | ||
PlainValidator(hex_bytes_validator), | ||
PlainSerializer(hex_bytes_serializer), | ||
WithJsonSchema({"type": "string"}), | ||
] | ||
|
||
|
||
@dataclass | ||
class MoleculeSpec(MSONable): | ||
"""A molecule schema to be output by OpenMMGenerators.""" | ||
|
||
name: str | ||
count: int | ||
charge_scaling: float | ||
charge_method: str | ||
openff_mol: str # a tk.Molecule object serialized with to_json | ||
|
||
|
||
class ClassicalMDTaskDocument(BaseModel, extra="allow"): # type: ignore[call-arg] | ||
"""Definition of the OpenMM task document.""" | ||
|
||
tags: Optional[list[str]] = Field( | ||
[], title="tag", description="Metadata tagged to a given task." | ||
) | ||
|
||
dir_name: Optional[str] = Field(None, description="The directory for this MD task") | ||
|
||
state: Optional[TaskState] = Field(None, description="State of this calculation") | ||
|
||
calcs_reversed: Optional[list] = Field( | ||
None, | ||
title="Calcs reversed data", | ||
description="Detailed data for each MD calculation contributing to " | ||
"the task document.", | ||
) | ||
|
||
interchange: Optional[HexBytes] = Field( | ||
None, | ||
description="A byte serialized OpenFF interchange object. " | ||
"To generate, the Interchange is serialized to json and" | ||
"the json is transformed to bytes with a utf-8 encoding. ", | ||
) | ||
|
||
molecule_specs: Optional[list[MoleculeSpec]] = Field( | ||
None, description="Molecules within the system." | ||
) | ||
|
||
force_field: Optional[str] = Field(None, description="The classical MD forcefield.") | ||
|
||
task_type: Optional[str] = Field(None, description="The type of calculation.") | ||
|
||
# task_label: Optional[str] = Field(None, description="A description of the task") | ||
# TODO: where does task_label get added | ||
|
||
last_updated: Optional[datetime] = Field( | ||
None, | ||
description="Timestamp for the most recent calculation for this task document", | ||
) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
from emmet.core.classical_md.openmm import CalculationOutput | ||
|
||
|
||
def test_calc_output_from_directory(test_dir): | ||
output_dir = test_dir / "classical_md" / "calc_output" | ||
# Call the from_directory function | ||
calc_out = CalculationOutput.from_directory( | ||
output_dir, | ||
"state.csv", | ||
"trajectory.dcd", | ||
elapsed_time=10.0, | ||
n_steps=1500, | ||
state_interval=100, | ||
) | ||
|
||
# Assert the expected attributes of the CalculationOutput object | ||
assert isinstance(calc_out, CalculationOutput) | ||
assert calc_out.dir_name == str(output_dir) | ||
assert calc_out.elapsed_time == 10.0 | ||
assert calc_out.traj_file == "trajectory.dcd" | ||
assert calc_out.state_file == "state.csv" | ||
|
||
# Assert the contents of the state data | ||
assert np.array_equal(calc_out.steps_reported[:3], [100, 200, 300]) | ||
assert np.allclose(calc_out.potential_energy[:3], [-26192.4, -25648.6, -25149.6]) | ||
assert np.allclose(calc_out.kinetic_energy[:3], [609.4, 1110.4, 1576.4], atol=0.1) | ||
assert np.allclose(calc_out.total_energy[:3], [-25583.1, -24538.1, -23573.2]) | ||
assert np.allclose(calc_out.temperature[:3], [29.6, 54.0, 76.6], atol=0.1) | ||
assert np.allclose(calc_out.volume[:3], [21.9, 21.9, 21.9], atol=0.1) | ||
assert np.allclose(calc_out.density[:3], [1.0, 0.99, 0.99], atol=0.1) | ||
|
||
# Assert the existence of the DCD and state files | ||
assert Path(calc_out.dir_name, calc_out.traj_file).exists() | ||
assert Path(calc_out.dir_name, calc_out.state_file).exists() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#"Step","Potential Energy (kJ/mole)","Kinetic Energy (kJ/mole)","Total Energy (kJ/mole)","Temperature (K)","Box Volume (nm^3)","Density (g/mL)" | ||
1000,-26192.479795343505,609.4196914344211,-25583.060103909083,29.614683433627977,21.873459457163882,0.9954543820541577 | ||
2000,-25648.552310042112,1110.4255423058567,-24538.126767736256,53.96100811019054,21.896418997832697,0.9944105960647234 | ||
100,-26192.479795343505,609.4196914344211,-25583.060103909083,29.614683433627977,21.873459457163882,0.9954543820541577 | ||
200,-25648.552310042112,1110.4255423058567,-24538.126767736256,53.96100811019054,21.896418997832697,0.9944105960647234 | ||
300,-25149.615762218746,1576.3685252566356,-23573.24723696211,76.60345654458364,21.902472359270778,0.9941357628560747 | ||
400,-24737.794927124953,2048.908203465864,-22688.88672365909,99.5664706655355,22.025973627010902,0.988561569901083 | ||
500,-24236.83499597991,2362.5727053481387,-21874.26229063177,114.80896292197377,22.427603013306435,0.9708585912814258 | ||
600,-23952.5612414965,2744.7859341483563,-21207.775307348144,133.38257308612992,22.451255694017004,0.9698357795248072 | ||
700,-23561.24490686599,2981.7360809766687,-20579.508825889323,144.89713962623532,22.617654864431856,0.9627006512314953 | ||
800,-23257.599271389336,3263.6894834260456,-19993.90978796329,158.59863446457717,22.561532702010723,0.9650953840284621 | ||
900,-22989.069612634717,3515.5816009845585,-19473.48801165016,170.83930444254142,22.982310428217957,0.9474256792121007 | ||
1000,-22721.45251535368,3712.6360813537613,-19008.81643399992,180.4151454226322,23.05915218995153,0.9442685007650109 | ||
1100,-22390.948659398535,3880.3415879751556,-18510.60707142338,188.5647762246478,23.158223282091488,0.9402289114362175 | ||
1200,-22183.04180200165,4033.800234109862,-18149.24156789179,196.02208239526976,23.20718359651434,0.9382453056728586 | ||
1300,-22048.592668279394,4216.629234793596,-17831.9634334858,204.90663774167015,22.84203464974821,0.9532439382565145 | ||
1400,-21737.752101440245,4308.225194310769,-17429.526907129475,209.35773340370193,23.121043846718113,0.94174083193084 | ||
1500,-21603.87279039141,4662.04510079883,-16941.82768959258,226.55157316706197,23.185483890016233,0.9391234261318885 |
Binary file not shown.