Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chemiscope.metatensor_featurizer #357

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ dependencies with:
pip install chemiscope[explore]
```

To use `chemiscope.metatensor_featurizer`, which computes the features for
`chemiscope.explore` from your trained model, install the dependencies with:
```bash
pip install chemiscope[metatensor]
```

## sphinx and sphinx-gallery integration

Chemiscope also provides extensions for `sphinx` and `sphinx-gallery` to
Expand Down
2 changes: 2 additions & 0 deletions docs/src/python/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@
.. autofunction:: chemiscope.ase_tensors_to_ellipsoids

.. autofunction:: chemiscope.explore

.. autofunction:: chemiscope.metatensor_featurizer
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ explore = [
"dscribe",
"scikit-learn",
]

metatensor = [
"metatensor",
"metatensor[torch]"
]
2 changes: 1 addition & 1 deletion python/chemiscope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
extract_properties,
librascal_atomic_environments,
)
from .explore import explore # noqa: F401
from .explore import explore, metatensor_featurizer # noqa: F401
from .version import __version__ # noqa: F401

from .jupyter import show, show_input # noqa
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from ..jupyter import show
from ._soap_pca import soap_pca_featurize
from ._metatensor import metatensor_featurizer

from .jupyter import show
__all__ = ["explore", "metatensor_featurizer"]


def explore(frames, featurize=None, properties=None, environments=None, mode="default"):
Expand Down Expand Up @@ -116,96 +118,9 @@ def soap_kpca_featurize(frames, environments):
# Add dimensionality reduction results to properties
properties["features"] = X_reduced

# Return chemiscope widget
return show(
frames=frames, properties=properties, mode=mode, environments=environments
frames=frames,
properties=properties,
environments=environments,
mode=mode,
)


def soap_pca_featurize(frames, environments=None):
    """
    Computes SOAP features for a given set of atomic structures and performs
    dimensionality reduction using PCA. Custom featurize functions should
    have the same signature.

    :param list frames: structures to featurize; each one must provide
        ``get_chemical_symbols()`` and ``get_pbc()``
    :param list environments: optional list of ``[structure_index, atom_index,
        cutoff]`` entries. When given, only the listed atoms are used as SOAP
        centers, and only the structures they refer to are featurized.

    :returns: the result of ``PCA(n_components=2).fit_transform`` applied to the
        SOAP descriptors
    :raises ImportError: if ``dscribe`` or ``scikit-learn`` can not be imported
    :raises IndexError: if an environment refers to a structure index that is
        not smaller than ``len(frames)``

    Note:
        - The SOAP descriptor parameters are pre-defined.
        - The SOAP descriptors are computed in parallel, using one job per
          frame up to the number of available CPU cores.
    """

    # Check if dependencies were installed
    try:
        from dscribe.descriptors import SOAP
        from sklearn.decomposition import PCA
    except ImportError as e:
        # keep the original exception attached as the cause
        raise ImportError(
            f"Required package not found: {str(e)}. Please install dependency "
            + "using 'pip install chemiscope[explore]'."
        ) from e

    centers = None

    # Pick the centers and the frames related to the environments if provided
    if environments is not None:
        # Sort environments by structure id and atom id *before* grouping them, so
        # that the centers and the selected frames share the same ordering
        environments = sorted(environments, key=lambda x: (x[0], x[1]))

        # Check structure indexes (sorted: a bare set has no guaranteed order)
        unique_structures = sorted({env[0] for env in environments})
        if any(index >= len(frames) for index in unique_structures):
            raise IndexError(
                "Some structure indices in 'environments' are larger than the number "
                "of frames"
            )

        # One list of atom indices per structure, in increasing structure order
        centers = _extract_environment_indices(environments)

        if len(unique_structures) != len(frames):
            # only include frames that are present in the user-provided environments
            frames = [frames[index] for index in unique_structures]

    # Get global species
    species = set()
    for frame in frames:
        species.update(frame.get_chemical_symbols())
    species = list(species)

    # Check if periodic
    is_periodic = all(all(frame.get_pbc()) for frame in frames)

    # Initialize calculator
    soap = SOAP(
        species=species,
        r_cut=4.5,
        n_max=8,
        l_max=6,
        sigma=0.2,
        rbf="gto",
        average="outer",
        periodic=is_periodic,
        weighting={"function": "pow", "c": 1, "m": 5, "d": 1, "r0": 3.5},
        compression={"mode": "mu1nu1"},
    )

    # Calculate descriptors; os.cpu_count() returns None when the number of CPUs
    # can not be determined, fall back to a single job in that case
    n_jobs = min(len(frames), os.cpu_count() or 1)
    feats = soap.create(frames, centers=centers, n_jobs=n_jobs)

    # Compute pca
    pca = PCA(n_components=2)
    return pca.fit_transform(feats)


def _extract_environment_indices(envs):
    """
    Group the atom indices of chemiscope's environments by structure, which is the
    layout used for DScribe's centers selection.

    :param list envs: each element is a list of [env_index, atom_index, cutoff]
    :returns: one list of atom indices per distinct ``env_index``, ordered by the
        first appearance of that ``env_index`` in ``envs``
    """
    atoms_by_env = {}
    for env_index, atom_index, _ in envs:
        # the cutoff (third entry) plays no role in the selection of centers
        atoms_by_env.setdefault(env_index, []).append(atom_index)
    return [*atoms_by_env.values()]
148 changes: 148 additions & 0 deletions python/chemiscope/explore/_metatensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np


def metatensor_featurizer(
    model,
    extensions_directory=None,
    check_consistency=False,
    device=None,
):
    """
    Create a featurizer function using a `metatensor`_ model to obtain the features from
    structures. The model must be able to create a ``"features"`` output.

    :param model: model to use for the calculation. It can be a file path, a Python
        instance of :py:class:`metatensor.torch.atomistic.MetatensorAtomisticModel`, or
        the output of :py:func:`torch.jit.script` on
        :py:class:`metatensor.torch.atomistic.MetatensorAtomisticModel`.
    :param extensions_directory: a directory where model extensions are located
    :param check_consistency: should we check the model for consistency when running,
        defaults to False.
    :param device: a torch device to use for the calculation. If ``None``, the function
        will use the options in model's ``supported_device`` attribute.

    :returns: a function ``featurize(frames, environments=None)`` that takes a list of
        frames (or a single frame) and returns the features as a numpy array. When
        ``environments`` select atoms in some of the frames, per-atom features are
        returned for these atoms only, and frames without any selected atom do not
        contribute any row.
    :raises ImportError: if the metatensor dependencies can not be imported

    To use this function, additional dependencies are required. They can be installed
    with the following command:

    .. code:: bash

        pip install chemiscope[metatensor]

    Here is an example using a pre-trained `metatensor`_ model, stored as a ``model.pt``
    file with the compiled extensions stored in the ``extensions/`` directory. To obtain
    the details on how to get it, see metatensor `tutorial
    <https://lab-cosmo.github.io/metatrain/latest/getting-started/usage.html>`_. The
    frames are obtained by reading structures from a file that `ase <ase-io_>`_ can
    read.

    .. code-block:: python

        import chemiscope
        import ase.io

        # Read the structures from the dataset
        frames = ase.io.read("data/explore_c-gap-20u.xyz", ":")

        # Provide model file ("model.pt") to `metatensor_featurizer`
        featurizer = chemiscope.metatensor_featurizer(
            "model.pt", extensions_directory="extensions"
        )

        chemiscope.explore(frames, featurize=featurizer)

    For more examples, see the related :ref:`documentation
    <chemiscope-explore-metatensor>`.

    .. _metatensor: https://docs.metatensor.org/latest/index.html
    .. _chemiscope-explore-metatensor:
        https://chemiscope.org/docs/examples/7-explore-advanced.html#example-with-metatensor-model
    """

    # Check if dependencies were installed
    try:
        from metatensor.torch.atomistic import ModelOutput
        from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator
    except ImportError as e:
        # keep the original exception attached as the cause
        raise ImportError(
            f"Required package not found: {e}. Please install the dependency using "
            "'pip install chemiscope[metatensor]'."
        ) from e

    # Initialize metatensor calculator, shared by all the calls to `featurize`
    calculator = MetatensorCalculator(
        model=model,
        extensions_directory=extensions_directory,
        check_consistency=check_consistency,
        device=device,
    )

    def get_features(atoms, environments, per_atom):
        """Run the model on a single atomic structure and extract the features"""
        outputs = {"features": ModelOutput(per_atom=per_atom)}
        selected_atoms = _create_selected_atoms(environments)
        output = calculator.run_model(atoms, outputs, selected_atoms)

        return output["features"].block().values.detach().cpu().numpy()

    def featurize(frames, environments=None):
        if not isinstance(frames, list):
            # a single structure was given
            return get_features(frames, environments, environments is not None)

        envs_per_frame = _get_environments_per_frame(environments, len(frames))

        # Decide once for all frames whether the features are per-atom. Taking this
        # decision frame by frame would stack per-structure rows (frames without
        # selected atoms) together with per-atom rows (all the other frames).
        per_atom = any(envs is not None for envs in envs_per_frame)

        outputs = [
            get_features(frame, envs, per_atom)
            for frame, envs in zip(frames, envs_per_frame)
            # with a selection of atoms, skip the frames where nothing is selected
            if envs is not None or not per_atom
        ]
        return np.vstack(outputs)

    return featurize


def _get_environments_per_frame(environments, num_frames):
    """
    Distribute the environments over the frames they belong to.

    :param list environments: a list of atomic environments, the first entry of each
        one being the index of its structure; can be ``None``
    :param int num_frames: total number of frames

    :returns: a list with ``num_frames`` entries, each one being either the list of
        environments of the corresponding frame, or ``None`` if there is none
    """
    per_frame = [None for _ in range(num_frames)]
    if environments is None:
        return per_frame

    # Collect the environments sharing the same structure index
    grouped = {}
    for environment in environments:
        grouped.setdefault(environment[0], []).append(environment)

    # Store each group in the slot of its frame; groups with a structure index
    # equal to or beyond `num_frames` are left out
    for frame_index, frame_envs in grouped.items():
        if not frame_index < num_frames:
            continue
        per_frame[frame_index] = frame_envs

    return per_frame


def _create_selected_atoms(environments):
    """
    Build the ``Labels`` object used as ``selected_atoms`` from a list of environments

    :param environments: a list of atom-centered environments, each one unpacking into
        three entries with the atom index in second position; can be ``None``
    :returns: ``None`` when ``environments`` is ``None``, otherwise a ``Labels`` with
        the ``"system"`` and ``"atom"`` dimensions
    """
    import torch
    from metatensor.torch import Labels

    if environments is None:
        return None

    # The calculator only ever receives one frame at a time, so the system index is
    # always 0, whatever structure id is stored in the environment.
    system_atom_pairs = []
    for _structure_id, atom_index, _cutoff in environments:
        system_atom_pairs.append((0, atom_index))

    label_values = torch.tensor(system_atom_pairs)
    return Labels(names=["system", "atom"], values=label_values)
90 changes: 90 additions & 0 deletions python/chemiscope/explore/_soap_pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os


def soap_pca_featurize(frames, environments=None):
    """
    Computes SOAP features for a given set of atomic structures and performs
    dimensionality reduction using PCA. Custom featurize functions should
    have the same signature.

    :param list frames: structures to featurize; each one must provide
        ``get_chemical_symbols()`` and ``get_pbc()``
    :param list environments: optional list of ``[structure_index, atom_index,
        cutoff]`` entries. When given, only the listed atoms are used as SOAP
        centers, and only the structures they refer to are featurized.

    :returns: the result of ``PCA(n_components=2).fit_transform`` applied to the
        SOAP descriptors
    :raises ImportError: if ``dscribe`` or ``scikit-learn`` can not be imported
    :raises IndexError: if an environment refers to a structure index that is
        not smaller than ``len(frames)``

    Note:
        - The SOAP descriptor parameters are pre-defined.
        - The SOAP descriptors are computed in parallel, using one job per
          frame up to the number of available CPU cores.
    """

    # Check if dependencies were installed
    try:
        from dscribe.descriptors import SOAP
        from sklearn.decomposition import PCA
    except ImportError as e:
        # keep the original exception attached as the cause
        raise ImportError(
            f"Required package not found: {str(e)}. Please install dependency "
            + "using 'pip install chemiscope[explore]'."
        ) from e

    centers = None

    # Pick the centers and the frames related to the environments if provided
    if environments is not None:
        # Sort environments by structure id and atom id *before* grouping them, so
        # that the centers and the selected frames share the same ordering
        environments = sorted(environments, key=lambda x: (x[0], x[1]))

        # Check structure indexes (sorted: a bare set has no guaranteed order)
        unique_structures = sorted({env[0] for env in environments})
        if any(index >= len(frames) for index in unique_structures):
            raise IndexError(
                "Some structure indices in 'environments' are larger than the number "
                "of frames"
            )

        # One list of atom indices per structure, in increasing structure order
        centers = _extract_environment_indices(environments)

        if len(unique_structures) != len(frames):
            # only include frames that are present in the user-provided environments
            frames = [frames[index] for index in unique_structures]

    # Get global species
    species = set()
    for frame in frames:
        species.update(frame.get_chemical_symbols())
    species = list(species)

    # Check if periodic
    is_periodic = all(all(frame.get_pbc()) for frame in frames)

    # Initialize calculator
    soap = SOAP(
        species=species,
        r_cut=4.5,
        n_max=8,
        l_max=6,
        sigma=0.2,
        rbf="gto",
        average="outer",
        periodic=is_periodic,
        weighting={"function": "pow", "c": 1, "m": 5, "d": 1, "r0": 3.5},
        compression={"mode": "mu1nu1"},
    )

    # Calculate descriptors; os.cpu_count() returns None when the number of CPUs
    # can not be determined, fall back to a single job in that case
    n_jobs = min(len(frames), os.cpu_count() or 1)
    feats = soap.create(frames, centers=centers, n_jobs=n_jobs)

    # Compute pca
    pca = PCA(n_components=2)
    return pca.fit_transform(feats)


def _extract_environment_indices(environments):
    """
    Group the atom indices of chemiscope's environments by structure, which is the
    layout used for DScribe's centers selection.

    :param list environments: each element is a list of [env_index, atom_index, cutoff]
    :returns: one list of atom indices per distinct ``env_index``, ordered by the
        first appearance of that ``env_index`` in ``environments``
    """
    atoms_by_env = {}
    for env_index, atom_index, _ in environments:
        # the cutoff (third entry) plays no role in the selection of centers
        atoms_by_env.setdefault(env_index, []).append(atom_index)
    return [*atoms_by_env.values()]
1 change: 1 addition & 0 deletions python/examples/7-explore-advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def mace_mp0_tsne(frames, environments):
fetch_dataset("mace-mp-tsne-m3cd.json.gz")
chemiscope.show_input("data/mace-mp-tsne-m3cd.json.gz")


# %%
#
# Example with SOAP, t-SNE and environments
Expand Down
Loading
Loading