diff --git a/mp_api/client/core/__init__.py b/mp_api/client/core/__init__.py index 76b81870..dc2047f0 100644 --- a/mp_api/client/core/__init__.py +++ b/mp_api/client/core/__init__.py @@ -1,4 +1,5 @@ from __future__ import annotations -from .client import BaseRester, MPRestError, MPRestWarning +from .client import BaseRester +from .exceptions import MPRestError, MPRestWarning from .settings import MAPIClientSettings diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 793d0902..918c88f6 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -30,6 +30,7 @@ from tqdm.auto import tqdm from urllib3.util.retry import Retry +from mp_api.client.core.exceptions import MPRestError from mp_api.client.core.settings import MAPIClientSettings from mp_api.client.core.utils import load_json, validate_ids @@ -92,11 +93,11 @@ def __init__( session: requests.Session | None = None, s3_client: Any | None = None, debug: bool = False, - monty_decode: bool = True, use_document_model: bool = True, timeout: int = 20, headers: dict | None = None, mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS, + **kwargs, ): """Initialize the REST API helper class. @@ -121,13 +122,13 @@ def __init__( advanced usage only. s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. debug: if True, print the URL for every request - monty_decode: Decode the data using monty into python objects use_document_model: If False, skip the creating the document model and return data as a dictionary. This can be simpler to work with but bypasses data validation and will not give auto-complete for available fields. timeout: Time in seconds to wait until a request timeout error is thrown headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. + **kwargs: access to legacy kwargs that may be in the process of being deprecated """ # TODO: think about how to migrate from PMG_MAPI_KEY self.api_key = api_key or os.getenv("MP_API_KEY") @@ -136,7 +137,6 @@ def __init__( ) self.debug = debug self.include_user_agent = include_user_agent - self.monty_decode = monty_decode self.use_document_model = use_document_model self.timeout = timeout self.headers = headers or {} @@ -151,6 +151,12 @@ def __init__( self._session = session self._s3_client = s3_client + if "monty_decode" in kwargs: + warnings.warn( + "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." + "The client by default returns results consistent with `monty_decode=True`." + ) + @property def session(self) -> requests.Session: if not self._session: @@ -265,7 +271,7 @@ def _post_resource( response = self.session.post(url, json=payload, verify=True, params=params) if response.status_code == 200: - data = load_json(response.text, deser=self.monty_decode) + data = load_json(response.text) if self.document_model and use_document_model: if isinstance(data["data"], dict): data["data"] = self.document_model.model_validate(data["data"]) # type: ignore @@ -333,7 +339,7 @@ def _patch_resource( response = self.session.patch(url, json=payload, verify=True, params=params) if response.status_code == 200: - data = load_json(response.text, deser=self.monty_decode) + data = load_json(response.text) if self.document_model and use_document_model: if isinstance(data["data"], dict): data["data"] = self.document_model.model_validate(data["data"]) # type: ignore @@ -384,10 +390,7 @@ def _query_open_data( Returns: dict: MontyDecoded data """ - if not decoder: - - def decoder(x): - return load_json(x, deser=self.monty_decode) + decoder = decoder or load_json file = open( f"s3://{bucket}/{key}", @@ -997,7 +1000,7 @@ def _submit_request_and_process( ) if response.status_code == 200: - data = load_json(response.text, deser=self.monty_decode) + data = load_json(response.text) # other sub-urls may use different document models # the client does not handle this in a particularly smart way currently if self.document_model and use_document_model: @@ -1302,12 +1305,10 @@ def count(self, criteria: dict | None = None) -> int | str: """ criteria = criteria or {} user_preferences = ( - self.monty_decode, self.use_document_model, self.mute_progress_bars, ) - self.monty_decode, self.use_document_model, self.mute_progress_bars = ( - False, + self.use_document_model, self.mute_progress_bars = ( False, True, ) # do not waste cycles decoding @@ -1329,7 +1330,6 @@ def count(self, criteria: dict | None = None) -> int | str: ) ( - self.monty_decode, self.use_document_model, self.mute_progress_bars, ) = user_preferences @@ -1351,11 +1351,3 @@ def __str__(self): # pragma: no cover f"{self.__class__.__name__} connected to {self.endpoint}\n\n" f"Available fields: {', '.join(self.available_fields)}\n\n" ) - - -class MPRestError(Exception): - """Raised when the query has problems, e.g., bad query format.""" - - -class MPRestWarning(Warning): - """Raised when a query is malformed but interpretable.""" diff --git a/mp_api/client/core/exceptions.py b/mp_api/client/core/exceptions.py new file mode 100644 index 00000000..fa9f8793 --- /dev/null +++ b/mp_api/client/core/exceptions.py @@ -0,0 +1,10 @@ +"""Define custom exceptions and warnings for the client.""" +from __future__ import annotations + + +class MPRestError(Exception): + """Raised when the query has problems, e.g., bad query format.""" + + +class MPRestWarning(Warning): + """Raised when a query is malformed but interpretable.""" diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index 200b6778..9c2955d5 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -2,7 +2,7 @@ from multiprocessing import cpu_count from typing import List -from pydantic import Field +from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict from pymatgen.core import _load_pmg_settings @@ -14,6 +14,7 @@ _MUTE_PROGRESS_BAR = PMG_SETTINGS.get("MPRESTER_MUTE_PROGRESS_BARS", False) _MAX_HTTP_URL_LENGTH = PMG_SETTINGS.get("MPRESTER_MAX_HTTP_URL_LENGTH", 2000) _MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000) +_DEFAULT_ENDPOINT = "https://api.materialsproject.org/" try: CPU_COUNT = cpu_count() @@ -80,11 +81,21 @@ class MAPIClientSettings(BaseSettings): ) MIN_EMMET_VERSION: str = Field( - "0.54.0", description="Minimum compatible version of emmet-core for the client." + "0.86.3rc0", + description="Minimum compatible version of emmet-core for the client.", ) MAX_LIST_LENGTH: int = Field( _MAX_LIST_LENGTH, description="Maximum length of query parameter list" ) + ENDPOINT: str = Field( + _DEFAULT_ENDPOINT, description="The default API endpoint to use." + ) + model_config = SettingsConfigDict(env_prefix="MPRESTER_") + + @field_validator("ENDPOINT", mode="before") + def _get_endpoint_from_env(cls, v: str | None) -> str: + """Support setting endpoint via MP_API_ENDPOINT environment variable.""" + return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index d46852f2..d68b632e 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from typing import TYPE_CHECKING, Literal import orjson @@ -8,6 +9,7 @@ from monty.json import MontyDecoder from packaging.version import parse as parse_version +from mp_api.client.core.exceptions import MPRestError from mp_api.client.core.settings import MAPIClientSettings if TYPE_CHECKING: @@ -50,6 +52,25 @@ def load_json( return MontyDecoder().process_decoded(data) if deser else data +def validate_api_key(api_key: str | None = None) -> str: + """Find and validate an API key.""" + # SETTINGS tries to read API key from ~/.config/.pmgrc.yaml + api_key = api_key or os.getenv("MP_API_KEY") + if not api_key: + from pymatgen.core import SETTINGS + + api_key = SETTINGS.get("PMG_MAPI_KEY") + + if not api_key or len(api_key) != 32: + addendum = " Valid API keys are 32 characters." if api_key else "" + raise MPRestError( + "Please obtain a valid API key from https://materialsproject.org/api " + f"and export it as an environment variable `MP_API_KEY`.{addendum}" + ) + + return api_key + + def validate_ids(id_list: list[str]) -> list[str]: """Function to validate material and task IDs. @@ -57,13 +78,13 @@ def validate_ids(id_list: list[str]) -> list[str]: id_list (List[str]): List of material or task IDs. Raises: - ValueError: If at least one ID is not formatted correctly. + MPRestError: If at least one ID is not formatted correctly. Returns: id_list: Returns original ID list if everything is formatted correctly. """ if len(id_list) > MAPIClientSettings().MAX_LIST_LENGTH: - raise ValueError( + raise MPRestError( "List of material/molecule IDs provided is too long. Consider removing the ID filter to automatically pull" " data for all IDs and filter locally." ) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 589854c1..f989266a 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -1,7 +1,6 @@ from __future__ import annotations import itertools -import os import warnings from collections import defaultdict from functools import cache, lru_cache @@ -16,17 +15,17 @@ from packaging import version from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.pourbaix_diagram import IonEntry -from pymatgen.core import SETTINGS, Composition, Element, Structure +from pymatgen.core import Composition, Element, Structure from pymatgen.core.ion import Ion from pymatgen.entries.computed_entries import ComputedStructureEntry from pymatgen.io.vasp import Chgcar from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get -from mp_api.client.core import BaseRester, MPRestError +from mp_api.client.core import BaseRester, MPRestError, MPRestWarning from mp_api.client.core._oxygen_evolution import OxygenEvolution from mp_api.client.core.settings import MAPIClientSettings -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import load_json, validate_api_key, validate_ids from mp_api.client.routes import GeneralStoreRester, MessagesRester, UserSettingsRester from mp_api.client.routes.materials import ( AbsorptionRester, @@ -125,11 +124,11 @@ def __init__( endpoint: str | None = None, notify_db_version: bool = False, include_user_agent: bool = True, - monty_decode: bool = True, use_document_model: bool = True, session: Session | None = None, headers: dict | None = None, mute_progress_bars: bool = _MAPI_SETTINGS.MUTE_PROGRESS_BARS, + **kwargs, ): """Initialize the MPRester. @@ -157,29 +156,20 @@ def __init__( making the API request. This helps MP support pymatgen users, and is similar to what most web browsers send with each page request. Set to False to disable the user agent. - monty_decode: Decode the data using monty into python objects use_document_model: If False, skip the creating the document model and return data as a dictionary. This can be simpler to work with but bypasses data validation and will not give auto-complete for available fields. session: Session object to use. By default (None), the client will create one. headers: Custom headers for localhost connections. mute_progress_bars: Whether to mute progress bars. - + **kwargs: access to legacy kwargs that may be in the process of being deprecated """ - # SETTINGS tries to read API key from ~/.config/.pmgrc.yaml - api_key = api_key or os.getenv("MP_API_KEY") or SETTINGS.get("PMG_MAPI_KEY") + self.api_key = validate_api_key(api_key) - if api_key and len(api_key) != 32: - raise ValueError( - "Please use a new API key from https://materialsproject.org/api " - "Keys for the new API are 32 characters, whereas keys for the legacy " - "API are 16 characters." - ) + self.endpoint = endpoint or _MAPI_SETTINGS.ENDPOINT + if not self.endpoint.endswith("/"): + self.endpoint += "/" - self.api_key = api_key - self.endpoint = endpoint or os.getenv( - "MP_API_ENDPOINT", "https://api.materialsproject.org/" - ) self.headers = headers or {} self.session = session or BaseRester._create_session( api_key=self.api_key, @@ -187,7 +177,6 @@ def __init__( headers=self.headers, ) self.use_document_model = use_document_model - self.monty_decode = monty_decode self.mute_progress_bars = mute_progress_bars self._contribs = None @@ -221,8 +210,13 @@ def __init__( "chemenv", ] - if not self.endpoint.endswith("/"): - self.endpoint += "/" + if "monty_decode" in kwargs: + warnings.warn( + "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." + "The client by default returns results consistent with `monty_decode=True`.", + stacklevel=2, + category=MPRestWarning, + ) # Check if emmet version of server is compatible emmet_version = MPRester.get_emmet_version(self.endpoint) @@ -260,7 +254,6 @@ def __init__( endpoint=self.endpoint, include_user_agent=include_user_agent, session=self.session, - monty_decode=self.monty_decode, use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, @@ -278,15 +271,11 @@ def __init__( suffix_split = cls.suffix.split("/") if len(suffix_split) == 1: - # Disable monty decode on nested data which may give errors - monty_disable = cls in [TaskRester, ProvenanceRester] - monty_decode = False if monty_disable else self.monty_decode rester = cls( api_key=api_key, endpoint=self.endpoint, include_user_agent=include_user_agent, session=self.session, - monty_decode=monty_decode, use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, @@ -303,20 +292,15 @@ def __init__( elif "molecules" in suffix_split: _sub_rester_suffix_map["molecules"][attr] = cls - # TODO: Enable monty decoding when tasks and SNL schema is normalized - # # Allow lazy loading of nested resters under materials and molecules using custom __getattr__ methods def __core_custom_getattr(_self, _attr, _rester_map): if _attr in _rester_map: cls = _rester_map[_attr] - monty_disable = cls in [TaskRester, ProvenanceRester] - monty_decode = False if monty_disable else self.monty_decode rester = cls( api_key=api_key, endpoint=self.endpoint, include_user_agent=include_user_agent, session=self.session, - monty_decode=monty_decode, use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, @@ -752,7 +736,7 @@ def get_entries( # Need to store object to permit de-duplication entries.add(ComputedStructureEntry.from_dict(entry_dict)) - return [e if self.monty_decode else e.as_dict() for e in entries] + return list(entries) def get_pourbaix_entries( self, @@ -1190,18 +1174,12 @@ def get_entries_in_chemsys( ) ) - if not self.monty_decode: - entries = [ComputedStructureEntry.from_dict(entry) for entry in entries] - if use_gibbs: # replace the entries with GibbsComputedStructureEntry from pymatgen.entries.computed_entries import GibbsComputedStructureEntry entries = GibbsComputedStructureEntry.from_entries(entries, temp=use_gibbs) - if not self.monty_decode: - entries = [entry.as_dict() for entry in entries] - return entries def get_bandstructure_by_material_id( @@ -1312,7 +1290,7 @@ def get_charge_density_from_task_id( kwargs = dict( bucket="materialsproject-parsed", key=f"chgcars/{validate_ids([task_id])[0]}.json.gz", - decoder=lambda x: load_json(x, deser=self.monty_decode), + decoder=lambda x: load_json(x, deser=True), ) chgcar = self.materials.tasks._query_open_data(**kwargs)[0] if not chgcar: @@ -1493,17 +1471,11 @@ def get_cohesive_energy( conventional_unit_cell=False, ) for entry in entries: - # Ensure that this works with monty_decode = False and True - if not self.monty_decode: - entry["uncorrected_energy_per_atom"] = entry["energy"] / sum( - entry["composition"].values() - ) - else: - entry = { - "data": entry.data, - "uncorrected_energy_per_atom": entry.uncorrected_energy_per_atom, - "composition": entry.composition, - } + entry = { + "data": entry.data, + "uncorrected_energy_per_atom": entry.uncorrected_energy_per_atom, + "composition": entry.composition, + } mp_id = entry["data"]["material_id"] if (run_type := entry["data"]["run_type"]) not in energies[mp_id]: diff --git a/mp_api/client/routes/_general_store.py b/mp_api/client/routes/_general_store.py index 659d0606..2ed73097 100644 --- a/mp_api/client/routes/_general_store.py +++ b/mp_api/client/routes/_general_store.py @@ -9,7 +9,6 @@ class GeneralStoreRester(BaseRester): # pragma: no cover suffix = "_general_store" document_model = GeneralStoreDoc # type: ignore primary_key = "submission_id" - monty_decode = False use_document_model = False def add_item(self, kind: str, markdown: str, meta: dict): # pragma: no cover diff --git a/mp_api/client/routes/_messages.py b/mp_api/client/routes/_messages.py index 64f796ad..a1e85c85 100644 --- a/mp_api/client/routes/_messages.py +++ b/mp_api/client/routes/_messages.py @@ -11,7 +11,6 @@ class MessagesRester(BaseRester): # pragma: no cover suffix = "_messages" document_model = MessagesDoc # type: ignore primary_key = "title" - monty_decode = False use_document_model = False def set_message( diff --git a/mp_api/client/routes/_user_settings.py b/mp_api/client/routes/_user_settings.py index 0f8d0bf3..a1eea304 100644 --- a/mp_api/client/routes/_user_settings.py +++ b/mp_api/client/routes/_user_settings.py @@ -9,7 +9,6 @@ class UserSettingsRester(BaseRester): # pragma: no cover suffix = "_user_settings" document_model = UserSettingsDoc # type: ignore primary_key = "consumer_id" - monty_decode = False use_document_model = False def create_user_settings(self, consumer_id, settings): diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 277779ac..37a3a1e1 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -14,7 +14,7 @@ from pymatgen.electronic_structure.core import OrbitalType, Spin from mp_api.client.core import BaseRester, MPRestError -from mp_api.client.core.utils import validate_ids +from mp_api.client.core.utils import load_json, validate_ids if TYPE_CHECKING: from pymatgen.electronic_structure.dos import CompleteDos @@ -158,7 +158,6 @@ def es_rester(self) -> ElectronicStructureRester: endpoint=self.base_endpoint, include_user_agent=self.include_user_agent, session=self.session, - monty_decode=self.monty_decode, use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, @@ -269,6 +268,7 @@ def get_bandstructure_from_task_id(self, task_id: str): result = self._query_open_data( bucket="materialsproject-parsed", key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", + decoder=lambda x: load_json(x, deser=True), )[0] except OSError: result = None @@ -473,6 +473,7 @@ def get_dos_from_task_id(self, task_id: str) -> CompleteDos: result = self._query_open_data( bucket="materialsproject-parsed", key=f"dos/{validate_ids([task_id])[0]}.json.gz", + decoder=lambda x: load_json(x, deser=True), )[0] except OSError: result = None diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index 7df557ef..2a09140e 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -128,7 +128,7 @@ def get_structure_by_material_id( if response and response[0]: response = response[0] - # Ensure that return type is a Structure regardless of `monty_decode` or `model_dump` output + # Ensure that return type is a Structure regardless of `model_dump` if isinstance(response[field], dict): response[field] = Structure.from_dict(response[field]) elif isinstance(response[field], list) and any( diff --git a/pyproject.toml b/pyproject.toml index 044015c5..24920b08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "typing-extensions>=3.7.4.1", "requests>=2.23.0", "monty>=2024.12.10", - "emmet-core>=0.86.2", + "emmet-core>=0.86.3rc0", "smart_open", "boto3", "orjson >= 3.10,<4", diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt index 87f83f4a..6d8b721b 100644 --- a/requirements/requirements-ubuntu-latest_py3.11.txt +++ b/requirements/requirements-ubuntu-latest_py3.11.txt @@ -24,7 +24,7 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib -emmet-core==0.86.2 +emmet-core==0.86.3rc0 # via mp-api (pyproject.toml) fonttools==4.61.0 # via matplotlib diff --git a/requirements/requirements-ubuntu-latest_py3.11_extras.txt b/requirements/requirements-ubuntu-latest_py3.11_extras.txt index 256664d1..ea20d785 100644 --- a/requirements/requirements-ubuntu-latest_py3.11_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.11_extras.txt @@ -62,7 +62,7 @@ dnspython==2.8.0 # pymongo docutils==0.22.3 # via sphinx -emmet-core[all]==0.86.2 +emmet-core[all]==0.86.3rc0 # via mp-api (pyproject.toml) execnet==2.1.2 # via pytest-xdist diff --git a/requirements/requirements-ubuntu-latest_py3.12.txt b/requirements/requirements-ubuntu-latest_py3.12.txt index b1fd5663..98658bd8 100644 --- a/requirements/requirements-ubuntu-latest_py3.12.txt +++ b/requirements/requirements-ubuntu-latest_py3.12.txt @@ -24,7 +24,7 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib -emmet-core==0.86.2 +emmet-core==0.86.3rc0 # via mp-api (pyproject.toml) fonttools==4.61.0 # via matplotlib diff --git a/requirements/requirements-ubuntu-latest_py3.12_extras.txt b/requirements/requirements-ubuntu-latest_py3.12_extras.txt index 37bf1832..bcdeb137 100644 --- a/requirements/requirements-ubuntu-latest_py3.12_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.12_extras.txt @@ -62,7 +62,7 @@ dnspython==2.8.0 # pymongo docutils==0.22.3 # via sphinx -emmet-core[all]==0.86.2 +emmet-core[all]==0.86.3rc0 # via mp-api (pyproject.toml) execnet==2.1.2 # via pytest-xdist diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index a1159104..8c6696e4 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -3,6 +3,8 @@ from packaging.version import parse as parse_version import pytest +from mp_api.client.core.exceptions import MPRestError + def test_emmet_core_version_checks(monkeypatch: pytest.MonkeyPatch): ref_ver = (1, 2, "3rc5") @@ -32,7 +34,7 @@ def test_id_validation(): max_num_idxs = MAPIClientSettings().MAX_LIST_LENGTH - with pytest.raises(ValueError, match="too long"): + with pytest.raises(MPRestError, match="too long"): _ = validate_ids([f"mp-{x}" for x in range(max_num_idxs + 1)]) # For all legacy MPIDs, ensure these validate correctly @@ -46,3 +48,40 @@ def test_id_validation(): isinstance(x, str) and AlphaID(x).string == x for x in validate_ids([y + AlphaID._cut_point for y in range(max_num_idxs)]) ) + + +def test_api_key_validation(monkeypatch: pytest.MonkeyPatch): + from mp_api.client.core.utils import validate_api_key + import pymatgen.core + + # Ensure any user settings are ignored + monkeypatch.setenv("MP_API_KEY", "") + monkeypatch.setenv("PMG_MAPI_KEY", "") + non_api_key_settings = { + k: v for k, v in pymatgen.core.SETTINGS.items() if k != "PMG_MAPI_KEY" + } + monkeypatch.setattr(pymatgen.core, "SETTINGS", non_api_key_settings) + + with pytest.raises(MPRestError, match="32 characters"): + validate_api_key("invalid_key") + + with pytest.raises(MPRestError, match="Please obtain a valid"): + validate_api_key() + + junk_api_key = "a" * 32 + monkeypatch.setenv("MP_API_KEY", junk_api_key) + assert validate_api_key() == junk_api_key + assert validate_api_key(junk_api_key) == junk_api_key + + other_junk_api_key = "b" * 32 + monkeypatch.setattr( + pymatgen.core, + "SETTINGS", + {**non_api_key_settings, "PMG_MAPI_KEY": other_junk_api_key}, + ) + # MP API environment variable takes precedence + assert validate_api_key() == junk_api_key + + # Check that pymatgen API key is used + monkeypatch.setenv("MP_API_KEY", "") + assert validate_api_key() == other_junk_api_key diff --git a/tests/materials/test_electronic_structure.py b/tests/materials/test_electronic_structure.py index c3d4f666..c744e3df 100644 --- a/tests/materials/test_electronic_structure.py +++ b/tests/materials/test_electronic_structure.py @@ -4,7 +4,7 @@ import pytest from pymatgen.analysis.magnetism import Ordering -from mp_api.client.core.client import MPRestError +from mp_api.client.core.exceptions import MPRestError from mp_api.client.routes.materials.electronic_structure import ( BandStructureRester, DosRester, diff --git a/tests/materials/test_phonon.py b/tests/materials/test_phonon.py index 3b04dfaa..8805176c 100644 --- a/tests/materials/test_phonon.py +++ b/tests/materials/test_phonon.py @@ -5,7 +5,7 @@ from emmet.core.phonon import PhononBS, PhononDOS -from mp_api.client.core import MPRestError +from mp_api.client.core.exceptions import MPRestError from mp_api.client.routes.materials.phonon import PhononRester from ..conftest import client_search_testing, requires_api_key diff --git a/tests/materials/test_summary.py b/tests/materials/test_summary.py index ba21e027..77784233 100644 --- a/tests/materials/test_summary.py +++ b/tests/materials/test_summary.py @@ -7,7 +7,7 @@ from pymatgen.analysis.magnetism import Ordering from mp_api.client.routes.materials.summary import SummaryRester -from mp_api.client.core.client import MPRestWarning, MPRestError +from mp_api.client.core.exceptions import MPRestWarning, MPRestError excluded_params = [ "include_gnome", diff --git a/tests/materials/test_tasks.py b/tests/materials/test_tasks.py index 1a92169f..d8bef85a 100644 --- a/tests/materials/test_tasks.py +++ b/tests/materials/test_tasks.py @@ -10,7 +10,7 @@ @pytest.fixture def rester(): - rester = TaskRester(monty_decode=False) + rester = TaskRester() yield rester rester.session.close() diff --git a/tests/test_client.py b/tests/test_client.py index c3445d00..ef1f6541 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -63,8 +63,6 @@ def test_generic_get_methods(rester): endpoint=mpr.endpoint, include_user_agent=True, session=mpr.session, - # Disable monty decode on nested data which may give errors - monty_decode=rester not in [TaskRester, ProvenanceRester], use_document_model=True, ) diff --git a/tests/test_mprester.py b/tests/test_mprester.py index eb0d3531..f497dc1a 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -33,7 +33,7 @@ from pymatgen.io.vasp import Chgcar from mp_api.client import MPRester -from mp_api.client.core.client import MPRestError +from mp_api.client.core import MPRestError, MPRestWarning from mp_api.client.core.settings import MAPIClientSettings from .conftest import requires_api_key @@ -391,12 +391,12 @@ def test_get_default_api_key_endpoint(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.delenv("MP_API_KEY", raising=False) monkeypatch.delenv("PMG_MAPI_KEY", raising=False) monkeypatch.setitem(SETTINGS, "PMG_MAPI_KEY", None) - with pytest.raises(MPRestError, match="No API key found in request"): + with pytest.raises(MPRestError, match="Please obtain a valid API key"): MPRester().get_structure_by_material_id("mp-149") def test_invalid_api_key(self, monkeypatch): monkeypatch.setenv("MP_API_KEY", "INVALID") - with pytest.raises(ValueError, match="Keys for the new API are 32 characters"): + with pytest.raises(MPRestError, match="Valid API keys are 32 characters"): MPRester().get_structure_by_material_id("mp-149") def test_get_cohesive_energy_per_atom_utility(self): @@ -453,14 +453,16 @@ def test_get_cohesive_energy(self): }, } e_coh = {} - for monty_decode in (True, False): + for use_document_model in (True, False): with MPRester( - use_document_model=monty_decode, monty_decode=monty_decode + use_document_model=use_document_model, ) as _mpr: for norm, refs in ref_e_coh.items(): _e_coh = _mpr.get_cohesive_energy(list(refs), normalization=norm) if norm == "atom": - e_coh["serial" if monty_decode else "noserial"] = _e_coh.copy() + e_coh[ + "serial" if use_document_model else "noserial" + ] = _e_coh.copy() # Ensure energies match reference data assert all(v == pytest.approx(refs[k]) for k, v in _e_coh.items()) @@ -573,3 +575,7 @@ def test_oxygen_evolution_bad_input(self, mpr): with pytest.raises(ValueError, match="No available insertion electrode data"): _ = mpr.get_oxygen_evolution("mp-2207", "Al") + + def test_monty_decode_warning(self): + with pytest.warns(MPRestWarning, match="Ignoring `monty_decode`"): + MPRester(monty_decode=False)