Skip to content

Commit

Permalink
Merge pull request #984 from orionarcher/only_schemas
Browse files Browse the repository at this point in the history
Additional updates to ClassicalMD Schemas
  • Loading branch information
tsmathis authored Apr 17, 2024
2 parents 7c582b4 + 69b687e commit 1af52bc
Show file tree
Hide file tree
Showing 9 changed files with 394 additions and 0 deletions.
1 change: 1 addition & 0 deletions emmet-core/emmet/core/classical_md/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from emmet.core.classical_md.tasks import MoleculeSpec, ClassicalMDTaskDocument
6 changes: 6 additions & 0 deletions emmet-core/emmet/core/classical_md/openmm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from emmet.core.classical_md.openmm.tasks import (
Calculation,
CalculationInput,
CalculationOutput,
OpenMMTaskDocument,
)
235 changes: 235 additions & 0 deletions emmet-core/emmet/core/classical_md/openmm/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""Schemas for OpenMM tasks."""

from __future__ import annotations

from pathlib import Path
from typing import Optional, Union

import pandas as pd # type: ignore[import-untyped]
from emmet.core.vasp.task_valid import TaskState # type: ignore[import-untyped]
from pydantic import BaseModel, Field

from emmet.core.classical_md import ClassicalMDTaskDocument # type: ignore[import-untyped]
from emmet.core.classical_md.tasks import HexBytes # type: ignore[import-untyped]


class CalculationInput(BaseModel, extra="allow"): # type: ignore[call-arg]
"""OpenMM input settings for a job, these are the attributes of the OpenMMMaker."""

n_steps: Optional[int] = Field(
None, description="The number of simulation steps to run."
)

step_size: Optional[float] = Field(
None, description="The size of each simulation step (picoseconds)."
)

temperature: Optional[float] = Field(
None, description="The simulation temperature (kelvin)."
)

friction_coefficient: Optional[float] = Field(
None,
description=(
"The friction coefficient for the integrator (inverse picoseconds)."
),
)

platform_name: Optional[str] = Field(
None,
description=(
"The name of the OpenMM platform to use, passed to "
"Interchange.to_openmm_simulation."
),
)

platform_properties: Optional[dict] = Field(
None,
description=(
"Properties for the OpenMM platform, passed to "
"Interchange.to_openmm_simulation."
),
)

state_interval: Optional[int] = Field(
None,
description=(
"State is saved every `state_interval` timesteps. For no state, set to 0."
),
)

state_file_name: Optional[str] = Field(
None, description="The name of the state file to save."
)

traj_interval: Optional[int] = Field(
None,
description=(
"The trajectory is saved every `traj_interval` timesteps. For no trajectory, set to 0."
),
)

wrap_traj: Optional[bool] = Field(
None, description="Whether to wrap trajectory coordinates."
)

report_velocities: Optional[bool] = Field(
None, description="Whether to report velocities in the trajectory file."
)

traj_file_name: Optional[str] = Field(
None, description="The name of the trajectory file to save."
)

traj_file_type: Optional[str] = Field(
None,
description="The type of trajectory file to save.",
)

embed_traj: Optional[bool] = Field(
None,
description="Whether to embed the trajectory blob in CalculationOutput.",
)


class CalculationOutput(BaseModel):
"""OpenMM calculation output files and extracted data."""

dir_name: Optional[str] = Field(
None, description="The directory for this OpenMM task"
)

traj_file: Optional[str] = Field(
None, description="Path to the trajectory file relative to `dir_name`"
)

traj_blob: Optional[HexBytes] = Field(
None, description="Trajectory file as a binary blob"
)

state_file: Optional[str] = Field(
None, description="Path to the state file relative to `dir_name`"
)

steps_reported: Optional[list[int]] = Field(
None, description="Steps where outputs are reported"
)

time: Optional[list[float]] = Field(None, description="List of times")

potential_energy: Optional[list[float]] = Field(
None, description="List of potential energies"
)

kinetic_energy: Optional[list[float]] = Field(
None, description="List of kinetic energies"
)

total_energy: Optional[list[float]] = Field(
None, description="List of total energies"
)

temperature: Optional[list[float]] = Field(None, description="List of temperatures")

volume: Optional[list[float]] = Field(None, description="List of volumes")

density: Optional[list[float]] = Field(None, description="List of densities")

elapsed_time: Optional[float] = Field(
None, description="Elapsed time for the calculation (seconds)."
)

@classmethod
def from_directory(
cls,
dir_name: Path | str,
state_file_name: str,
traj_file_name: str,
elapsed_time: Optional[float] = None,
n_steps: Optional[int] = None,
state_interval: Optional[int] = None,
embed_traj: bool = False,
) -> CalculationOutput:
"""Extract data from the output files in the directory."""
state_file = Path(dir_name) / state_file_name
column_name_map = {
'#"Step"': "steps_reported",
"Potential Energy (kJ/mole)": "potential_energy",
"Kinetic Energy (kJ/mole)": "kinetic_energy",
"Total Energy (kJ/mole)": "total_energy",
"Temperature (K)": "temperature",
"Box Volume (nm^3)": "volume",
"Density (g/mL)": "density",
}
state_is_not_empty = state_file.exists() and state_file.stat().st_size > 0
state_steps = state_interval and n_steps and n_steps // state_interval or 0
if state_is_not_empty and (state_steps > 0):
data = pd.read_csv(state_file, header=0)
data = data.rename(columns=column_name_map)
data = data.filter(items=column_name_map.values())
data = data.iloc[-state_steps:]
attributes = data.to_dict(orient="list")
else:
attributes = {name: None for name in column_name_map.values()}
state_file_name = None # type: ignore[assignment]

traj_file = Path(dir_name) / traj_file_name
traj_is_not_empty = traj_file.exists() and traj_file.stat().st_size > 0
traj_file_name = traj_file_name if traj_is_not_empty else None # type: ignore

if embed_traj and traj_is_not_empty:
with open(traj_file, "rb") as f:
traj_blob = f.read()
else:
traj_blob = None

return CalculationOutput(
dir_name=str(dir_name),
elapsed_time=elapsed_time,
traj_file=traj_file_name,
state_file=state_file_name,
traj_blob=traj_blob,
**attributes,
)


class Calculation(BaseModel):
"""All input and output data for an OpenMM calculation."""

dir_name: Optional[str] = Field(
None, description="The directory for this OpenMM calculation"
)

has_openmm_completed: Optional[Union[TaskState, bool]] = Field(
None, description="Whether OpenMM completed the calculation successfully"
)

input: Optional[CalculationInput] = Field(
None, description="OpenMM input settings for the calculation"
)
output: Optional[CalculationOutput] = Field(
None, description="The OpenMM calculation output"
)

completed_at: Optional[str] = Field(
None, description="Timestamp for when the calculation was completed"
)
task_name: Optional[str] = Field(
None, description="Name of task given by custodian (e.g., relax1, relax2)"
)

calc_type: Optional[str] = Field(
None,
description="Return calculation type (run type + task_type). or just new thing",
)


class OpenMMTaskDocument(ClassicalMDTaskDocument):
"""Definition of the OpenMM task document."""

calcs_reversed: Optional[list[Calculation]] = Field(
None,
title="Calcs reversed data",
description="Detailed data for each OpenMM calculation contributing to the "
"task document.",
)
97 changes: 97 additions & 0 deletions emmet-core/emmet/core/classical_md/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Schemas for classical MD package."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from typing_extensions import Annotated
import zlib
from typing import Any

from pydantic import (
BaseModel,
Field,
PlainValidator,
PlainSerializer,
WithJsonSchema,
errors,
)
from monty.json import MSONable

from emmet.core.vasp.task_valid import TaskState # type: ignore[import-untyped]


def hex_bytes_validator(o: Any) -> bytes:
if isinstance(o, bytes):
return o
elif isinstance(o, bytearray):
return bytes(o)
elif isinstance(o, str):
return zlib.decompress(bytes.fromhex(o))
raise errors.BytesError()


def hex_bytes_serializer(b: bytes) -> str:
return zlib.compress(b).hex()


HexBytes = Annotated[
bytes,
PlainValidator(hex_bytes_validator),
PlainSerializer(hex_bytes_serializer),
WithJsonSchema({"type": "string"}),
]


@dataclass
class MoleculeSpec(MSONable):
"""A molecule schema to be output by OpenMMGenerators."""

name: str
count: int
charge_scaling: float
charge_method: str
openff_mol: str # a tk.Molecule object serialized with to_json


class ClassicalMDTaskDocument(BaseModel, extra="allow"): # type: ignore[call-arg]
"""Definition of the OpenMM task document."""

tags: Optional[list[str]] = Field(
[], title="tag", description="Metadata tagged to a given task."
)

dir_name: Optional[str] = Field(None, description="The directory for this MD task")

state: Optional[TaskState] = Field(None, description="State of this calculation")

calcs_reversed: Optional[list] = Field(
None,
title="Calcs reversed data",
description="Detailed data for each MD calculation contributing to "
"the task document.",
)

interchange: Optional[HexBytes] = Field(
None,
description="A byte serialized OpenFF interchange object. "
"To generate, the Interchange is serialized to json and"
"the json is transformed to bytes with a utf-8 encoding. ",
)

molecule_specs: Optional[list[MoleculeSpec]] = Field(
None, description="Molecules within the system."
)

force_field: Optional[str] = Field(None, description="The classical MD forcefield.")

task_type: Optional[str] = Field(None, description="The type of calculation.")

# task_label: Optional[str] = Field(None, description="A description of the task")
# TODO: where does task_label get added

last_updated: Optional[datetime] = Field(
None,
description="Timestamp for the most recent calculation for this task document",
)
Empty file.
Empty file.
37 changes: 37 additions & 0 deletions emmet-core/tests/classical_md/openmm_md/test_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pathlib import Path

import numpy as np
from emmet.core.classical_md.openmm import CalculationOutput


def test_calc_output_from_directory(test_dir):
output_dir = test_dir / "classical_md" / "calc_output"
# Call the from_directory function
calc_out = CalculationOutput.from_directory(
output_dir,
"state.csv",
"trajectory.dcd",
elapsed_time=10.0,
n_steps=1500,
state_interval=100,
)

# Assert the expected attributes of the CalculationOutput object
assert isinstance(calc_out, CalculationOutput)
assert calc_out.dir_name == str(output_dir)
assert calc_out.elapsed_time == 10.0
assert calc_out.traj_file == "trajectory.dcd"
assert calc_out.state_file == "state.csv"

# Assert the contents of the state data
assert np.array_equal(calc_out.steps_reported[:3], [100, 200, 300])
assert np.allclose(calc_out.potential_energy[:3], [-26192.4, -25648.6, -25149.6])
assert np.allclose(calc_out.kinetic_energy[:3], [609.4, 1110.4, 1576.4], atol=0.1)
assert np.allclose(calc_out.total_energy[:3], [-25583.1, -24538.1, -23573.2])
assert np.allclose(calc_out.temperature[:3], [29.6, 54.0, 76.6], atol=0.1)
assert np.allclose(calc_out.volume[:3], [21.9, 21.9, 21.9], atol=0.1)
assert np.allclose(calc_out.density[:3], [1.0, 0.99, 0.99], atol=0.1)

# Assert the existence of the DCD and state files
assert Path(calc_out.dir_name, calc_out.traj_file).exists()
assert Path(calc_out.dir_name, calc_out.state_file).exists()
18 changes: 18 additions & 0 deletions test_files/classical_md/calc_output/state.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#"Step","Potential Energy (kJ/mole)","Kinetic Energy (kJ/mole)","Total Energy (kJ/mole)","Temperature (K)","Box Volume (nm^3)","Density (g/mL)"
1000,-26192.479795343505,609.4196914344211,-25583.060103909083,29.614683433627977,21.873459457163882,0.9954543820541577
2000,-25648.552310042112,1110.4255423058567,-24538.126767736256,53.96100811019054,21.896418997832697,0.9944105960647234
100,-26192.479795343505,609.4196914344211,-25583.060103909083,29.614683433627977,21.873459457163882,0.9954543820541577
200,-25648.552310042112,1110.4255423058567,-24538.126767736256,53.96100811019054,21.896418997832697,0.9944105960647234
300,-25149.615762218746,1576.3685252566356,-23573.24723696211,76.60345654458364,21.902472359270778,0.9941357628560747
400,-24737.794927124953,2048.908203465864,-22688.88672365909,99.5664706655355,22.025973627010902,0.988561569901083
500,-24236.83499597991,2362.5727053481387,-21874.26229063177,114.80896292197377,22.427603013306435,0.9708585912814258
600,-23952.5612414965,2744.7859341483563,-21207.775307348144,133.38257308612992,22.451255694017004,0.9698357795248072
700,-23561.24490686599,2981.7360809766687,-20579.508825889323,144.89713962623532,22.617654864431856,0.9627006512314953
800,-23257.599271389336,3263.6894834260456,-19993.90978796329,158.59863446457717,22.561532702010723,0.9650953840284621
900,-22989.069612634717,3515.5816009845585,-19473.48801165016,170.83930444254142,22.982310428217957,0.9474256792121007
1000,-22721.45251535368,3712.6360813537613,-19008.81643399992,180.4151454226322,23.05915218995153,0.9442685007650109
1100,-22390.948659398535,3880.3415879751556,-18510.60707142338,188.5647762246478,23.158223282091488,0.9402289114362175
1200,-22183.04180200165,4033.800234109862,-18149.24156789179,196.02208239526976,23.20718359651434,0.9382453056728586
1300,-22048.592668279394,4216.629234793596,-17831.9634334858,204.90663774167015,22.84203464974821,0.9532439382565145
1400,-21737.752101440245,4308.225194310769,-17429.526907129475,209.35773340370193,23.121043846718113,0.94174083193084
1500,-21603.87279039141,4662.04510079883,-16941.82768959258,226.55157316706197,23.185483890016233,0.9391234261318885
Binary file not shown.

0 comments on commit 1af52bc

Please sign in to comment.