Skip to content

Commit

Permalink
Fix failing pkg_is_editable check in PyPI release (#78)
Browse files Browse the repository at this point in the history
* add StrEnums Targets + ModelType used in make_metrics_tables.py

* mv data/figshare matbench_discovery

update pyproject tool.setuptools.package-data and FIGSHARE_DIR

* add Targets col to main metrics table, remove meta cols from other metrics table variants

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update site/ import paths for figshare URLs

* bump package.json deps

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
janosh and pre-commit-ci[bot] authored Jan 28, 2024
1 parent e7b6efa commit 8525349
Show file tree
Hide file tree
Showing 20 changed files with 508 additions and 350 deletions.
3 changes: 0 additions & 3 deletions data/figshare/readme.md

This file was deleted.

29 changes: 24 additions & 5 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pymatviz.utils import styled_html_tag

pkg_name = "matbench-discovery"
direct_url = Distribution.from_name(pkg_name).read_text("direct_url.json") or ""
direct_url = Distribution.from_name(pkg_name).read_text("direct_url.json") or "{}"
pkg_is_editable = json.loads(direct_url).get("dir_info", {}).get("editable", False)

PKG_DIR = os.path.dirname(__file__)
Expand All @@ -25,14 +25,14 @@
SITE_MODELS = f"{ROOT}/site/src/routes/models"
SCRIPTS = f"{ROOT}/scripts" # model and date analysis scripts
PDF_FIGS = f"{ROOT}/paper/figs" # directory for light-themed PDF figures
FIGSHARE_DIR = f"{PKG_DIR}/figshare"

for directory in (SITE_FIGS, SITE_MODELS, PDF_FIGS):
os.makedirs(directory, exist_ok=True)

os.makedirs(MP_DIR := f"{DATA_DIR}/mp", exist_ok=True)
os.makedirs(WBM_DIR := f"{DATA_DIR}/wbm", exist_ok=True)
# JSON files with URLS to data files on figshare
os.makedirs(FIGSHARE_DIR := f"{ROOT}/data/figshare", exist_ok=True)

# directory to store model checkpoints downloaded from wandb cloud storage
CHECKPOINT_DIR = f"{ROOT}/wandb/checkpoints"
Expand Down Expand Up @@ -108,15 +108,34 @@ class Task(StrEnum):
# version of the PES like CGCNN+P)
RP2RE = "RP2RE" # relaxed prototype to relaxed energy
IP2RE = "IP2RE" # initial prototype to relaxed energy
IS2E = "IS2E" # initial structure to energy
IS2RE_SR = "IS2RE-SR" # initial structure to relaxed energy after ML relaxation


@unique
class Targets(StrEnum):
"""Thermodynamic stability prediction task types."""

E = "E"
EFS = "EFS"
EFSM = "EFSM"


@unique
class ModelType(StrEnum):
"""Model types."""

GNN = "GNN"
UIP = "UIP-GNN"
BO_GNN = "BO-GNN"
Fingerprint = "Fingerprint"
Transformer = "Transformer"


# load figshare 1.0.0
with open(f"{FIGSHARE_DIR}/1.0.0.json") as file:
FIGSHARE_URLS = json.load(file)


# --- start global plot settings

ev_per_atom = styled_html_tag(
"(eV/atom)", tag="span", style="font-size: 0.8em; font-weight: lighter;"
)
Expand Down
File renamed without changes.
3 changes: 3 additions & 0 deletions matbench_discovery/figshare/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Figshare File URLs

Files in this directory are auto-generated by [`scripts/upload_to_figshare.py`](../../scripts/upload_to_figshare.py).
2 changes: 1 addition & 1 deletion matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class PredFiles(Files):
# original MEGNet straight from publication, not re-trained
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE.csv.gz"
# # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
# chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv.gz"
# chgnet_megnet = "chgnet/2023-03-06-chgnet-0.2.0-wbm-IS2RE.csv.gz"
# # M3GNet-relaxed structures fed into MEGNet for formation energy prediction
# m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv.gz"
# megnet_rs2re = "megnet/2023-08-23-megnet-wbm-RS2RE.csv.gz"
Expand Down
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ Package = "https://pypi.org/project/matbench-discovery"
test = ["pytest", "pytest-cov"]
# how to specify git deps: https://stackoverflow.com/a/73572379
running-models = [
# aviary commented-out since dep on git repo raises "Invalid value for requires_dist"
# when attempting PyPI publish
# "aviary@git+https://github.com/CompRhys/aviary",
"alignn",
"chgnet",
"jarvis-tools",
# torch needs to install before aviary
"torch",

"aviary@git+https://github.com/CompRhys/aviary",
"m3gnet",
"mace-torch",
"maml",
Expand All @@ -71,7 +70,7 @@ include = ["matbench_discovery*"]
exclude = ["tests", "tests.*"]

[tool.setuptools.package-data]
matbench_discovery = ["../data/figshare/*"]
matbench_discovery = ["figshare/*"]

[tool.distutils.bdist_wheel]
universal = true
Expand Down
62 changes: 38 additions & 24 deletions scripts/model_figs/make_metrics_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from pymatviz.utils import si_fmt
from sklearn.dummy import DummyClassifier

from matbench_discovery import PDF_FIGS, SCRIPTS, SITE_FIGS, Key, Task
from matbench_discovery import (
PDF_FIGS,
SCRIPTS,
SITE_FIGS,
Key,
ModelType,
Targets,
Task,
)
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.metrics import stable_metrics
from matbench_discovery.models import MODEL_METADATA
Expand Down Expand Up @@ -82,25 +90,28 @@


# %% for each model this ontology dict specifies (training type, test type, model type)
model_type_col, targets_col = "Model Type", "Targets"
ontology_cols = ["Trained", "Task", model_type_col, targets_col]
ontology = {
"ALIGNN": (Task.RS2RE, Task.IS2RE, "GNN"),
# "ALIGNN Pretrained": (Task.RS2RE, Task.IS2RE, "GNN"),
"CHGNet": (Task.S2EFSM, "IS2RE-SR", "UIP-GNN"),
"MACE": (Task.S2EFS, "IS2RE-SR", "UIP-GNN"),
"M3GNet": (Task.S2EFS, "IS2RE-SR", "UIP-GNN"),
"MEGNet": (Task.RS2RE, "IS2E", "GNN"),
"MEGNet RS2RE": (Task.RS2RE, "IS2E", "GNN"),
"CGCNN": (Task.RS2RE, "IS2E", "GNN"),
"CGCNN+P": (Task.S2RE, Task.IS2RE, "GNN"),
"Wrenformer": (Task.RP2RE, Task.IP2RE, "Transformer"),
"BOWSR": (Task.RS2RE, "IS2RE-BO", "BO-GNN"),
"Voronoi RF": (Task.RS2RE, "IS2E", "Fingerprint"),
"M3GNet→MEGNet": (Task.S2EFS, "IS2RE-SR", "UIP-GNN"),
"CHGNet→MEGNet": (Task.S2EFSM, "IS2RE-SR", "UIP-GNN"),
"PFP": (Task.S2EFS, Task.IS2RE, "UIP"),
"Dummy": ("", "", ""),
"ALIGNN": (Task.RS2RE, Task.IS2RE, ModelType.GNN, Targets.E),
# "ALIGNN Pretrained": (Task.RS2RE, Task.IS2RE, ModelType.GNN, Targets.E),
"CHGNet": (Task.S2EFSM, Task.IS2RE_SR, ModelType.UIP, Targets.EFSM),
"chgnet_no_relax": (Task.S2EFSM, "IS2RE-STATIC", ModelType.UIP, Targets.EFSM),
"MACE": (Task.S2EFS, Task.IS2RE_SR, ModelType.UIP, Targets.EFS),
"M3GNet": (Task.S2EFS, Task.IS2RE_SR, ModelType.UIP, Targets.EFS),
"MEGNet": (Task.RS2RE, Task.IS2E, ModelType.GNN, Targets.E),
"MEGNet RS2RE": (Task.RS2RE, Task.IS2E, ModelType.GNN, Targets.E),
"CGCNN": (Task.RS2RE, Task.IS2E, ModelType.GNN, Targets.E),
"CGCNN+P": ("S2RE", Task.IS2RE, ModelType.GNN, Targets.E),
"Wrenformer": ("RP2RE", "IP2E", ModelType.Transformer, Targets.E),
"BOWSR": (Task.RS2RE, "IS2RE-BO", ModelType.BO_GNN, Targets.E),
"Voronoi RF": (Task.RS2RE, Task.IS2E, "Fingerprint", Targets.E),
"M3GNet→MEGNet": (Task.S2EFS, Task.IS2RE_SR, ModelType.UIP, Targets.EFS),
"CHGNet→MEGNet": (Task.S2EFSM, Task.IS2RE_SR, ModelType.UIP, Targets.EFSM),
"PFP": (Task.S2EFS, Task.IS2RE, ModelType.UIP, Targets.EFS),
"GNoMe": (Task.S2EFS, Task.IS2RE, ModelType.UIP, Targets.EFS),
"Dummy": ("", "", "", ""),
}
ontology_cols = ["Trained", "Deployed", model_type_col := "Model Type"]
df_ont = pd.DataFrame(ontology, index=ontology_cols)
# RS2RE = relaxed structure to relaxed energy
# RP2RE = relaxed prototype to predicted energy
Expand Down Expand Up @@ -128,10 +139,11 @@
# when setting to True, uncomment the lines chgnet_megnet, m3gnet_megnet, megnet_rs2re
# in PredFiles!
make_uip_megnet_comparison = False
show_cols = (
f"F1,DAF,Precision,Accuracy,TPR,TNR,MAE,RMSE,{R2_col},"
f"{train_size_col},{model_type_col}".split(",")
)
meta_cols = [train_size_col, model_type_col, targets_col]
show_cols = [
*f"F1,DAF,Precision,Accuracy,TPR,TNR,MAE,RMSE,{R2_col}".split(","),
*meta_cols,
]

for label, df in (
("", df_metrics),
Expand All @@ -154,14 +166,16 @@
# abbreviate long column names
df_filtered = df_filtered.rename(columns={"Precision": "Prec", "Accuracy": "Acc"})

if label == "-first-10k":
if "-first-10k" in label:
# hide redundant metrics for first 10k preds (all TPR = 1, TNR = 0)
df_filtered = df_filtered.drop(["TPR", "TNR"], axis="columns")
if label != "-uniq-protos": # only show training size and model type once
df_filtered = df_filtered.drop(meta_cols, axis="columns")

styler = (
df_filtered.style.format(
# render integers without decimal places
{k: "{:,.0f}" for k in "TP FN FP TN".split()},
{key: "{:,.0f}" for key in "TP FN FP TN".split()},
precision=2, # render floats with 2 decimals
na_rep="", # render NaNs as empty string
)
Expand Down
2 changes: 2 additions & 0 deletions scripts/model_figs/per_element_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,10 @@
df_frac_comp * (df_each_err[model].to_numpy()[:, None]),
log=True,
cbar_title=f"{model} convex hull distance errors (eV/atom)",
cbar_title_kwds=dict(fontsize=16),
x_range=(-0.5, 0.5),
symbol_pos=(0.1, 0.8),
colormap="viridis",
)

img_name = f"ptable-each-error-hists-{model.lower().replace(' ', '-')}"
Expand Down
10 changes: 5 additions & 5 deletions site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
"@iconify/svelte": "^3.1.6",
"@rollup/plugin-yaml": "^4.1.2",
"@sveltejs/adapter-static": "^3.0.1",
"@sveltejs/kit": "^2.4.1",
"@sveltejs/vite-plugin-svelte": "^3.0.1",
"@typescript-eslint/eslint-plugin": "^6.19.0",
"@typescript-eslint/parser": "^6.19.0",
"@sveltejs/kit": "^2.5.0",
"@sveltejs/vite-plugin-svelte": "^3.0.2",
"@typescript-eslint/eslint-plugin": "^6.19.1",
"@typescript-eslint/parser": "^6.19.1",
"d3-scale-chromatic": "^3.0.0",
"elementari": "^0.2.3",
"eslint": "^8.56.0",
Expand All @@ -44,7 +44,7 @@
"svelte-multiselect": "^10.2.0",
"svelte-preprocess": "^5.1.3",
"svelte-toc": "^0.5.7",
"svelte-zoo": "^0.4.9",
"svelte-zoo": "^0.4.10",
"svelte2tsx": "^0.7.0",
"tslib": "^2.6.2",
"typescript": "5.3.3",
Expand Down
Loading

0 comments on commit 8525349

Please sign in to comment.