diff --git a/piccolo/apps/migrations/tables.py b/piccolo/apps/migrations/tables.py index 91906be4f..18104e21d 100644 --- a/piccolo/apps/migrations/tables.py +++ b/piccolo/apps/migrations/tables.py @@ -4,10 +4,20 @@ from piccolo.columns import Timestamp, Varchar from piccolo.columns.defaults.timestamp import TimestampNow +from piccolo.conf.overrides import get_override from piccolo.table import Table +TABLENAME = get_override(key="PICCOLO_MIGRATION_TABLENAME", default=None) -class Migration(Table): + +SCHEMA = get_override(key="PICCOLO_MIGRATION_SCHEMA", default=None) + + +class Migration( + Table, + tablename=TABLENAME, + schema=SCHEMA, +): name = Varchar(length=200) app_name = Varchar(length=200) ran_on = Timestamp(default=TimestampNow()) diff --git a/piccolo/conf/overrides.py b/piccolo/conf/overrides.py new file mode 100644 index 000000000..7c3e62290 --- /dev/null +++ b/piccolo/conf/overrides.py @@ -0,0 +1,101 @@ +import logging +import os +import typing as t + +try: + import tomllib # type: ignore +except ImportError: + import tomli as tomllib # type: ignore + + +logger = logging.getLogger(__name__) + + +class OverrideLoader: + """ + The user can optionally add a ``piccolo_overrides.toml`` file at the root + of their project (wherever the Python interpreter will be started). + + This can be used to modify some of the internal workings of Piccolo. + Examples are the ``tablename`` and ``schema`` of the ``Migration`` table. + + The reason we don't have these values in ``piccolo_conf.py`` is we want the + values to be static. Also, they could be required at any point by Piccolo, + before the app has fully loaded all of the Python code, so having a static + file helps prevent circular imports. + + """ + + _toml_contents: t.Optional[t.Dict[str, t.Any]] + + def __init__( + self, + toml_file_name: str = "piccolo_overrides.toml", + folder_path: t.Optional[str] = None, + ): + self.toml_file_name = toml_file_name + self.folder_path = folder_path or os.getcwd() + self._toml_contents = None + + def get_toml_contents(self) -> t.Dict[str, t.Any]: + """ + Returns the contents of the TOML file. + """ + if self._toml_contents is None: + toml_file_path = os.path.join( + self.folder_path, + self.toml_file_name, + ) + + if os.path.exists(toml_file_path): + with open(toml_file_path, "rb") as f: + self._toml_contents = tomllib.load(f) + else: + raise FileNotFoundError("The TOML file wasn't found.") + + return self._toml_contents + + def get_override(self, key: str, default: t.Any = None) -> t.Any: + """ + Returns the corresponding value from the TOML file if found, otherwise + the default value is returned instead. + """ + try: + toml_contents = self.get_toml_contents() + except FileNotFoundError: + logger.debug(f"{self.toml_file_name} wasn't found") + return default + else: + value = toml_contents.get(key, ...) + + if value is ...: + return default + + logger.debug(f"Loaded {key} from TOML file.") + return value + + def validate(self): + """ + Can be used to manually check the TOML file can be loaded:: + + from piccolo.conf.overrides import OverrideLoader + + >>> OverrideLoader().validate() + + """ + try: + self.get_toml_contents() + except Exception as e: + raise e + else: + print("Successfully loaded TOML file!") + + +DEFAULT_LOADER = OverrideLoader() + + +def get_override(key: str, default: t.Any = None) -> t.Any: + """ + A convenience, instead of ``DEFAULT_LOADER.get_override``. + """ + return DEFAULT_LOADER.get_override(key=key, default=default) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0a5ee6244..378540f70 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -5,3 +5,4 @@ targ>=0.3.7 inflection>=0.5.1 typing-extensions>=4.3.0 pydantic[email]==2.* +tomli;python_version<="3.10" diff --git a/tests/conf/data/test_overrides.toml b/tests/conf/data/test_overrides.toml new file mode 100644 index 000000000..591d31afa --- /dev/null +++ b/tests/conf/data/test_overrides.toml @@ -0,0 +1,2 @@ +# Used by test_overrides.py +MY_TEST_VALUE = "hello world" diff --git a/tests/conf/test_overrides.py b/tests/conf/test_overrides.py new file mode 100644 index 000000000..d55107c4e --- /dev/null +++ b/tests/conf/test_overrides.py @@ -0,0 +1,27 @@ +import os +from unittest import TestCase + +from piccolo.conf.overrides import OverrideLoader + + +class TestOverrideLoader(TestCase): + def test_get_override(self): + """ + Make sure we can get overrides from a TOML file. + """ + loader = OverrideLoader( + toml_file_name="test_overrides.toml", + folder_path=os.path.join(os.path.dirname(__file__), "data"), + ) + self.assertEqual( + loader.get_override( + key="MY_TEST_VALUE", + ), + "hello world", + ) + + self.assertIsNone( + loader.get_override( + key="WRONG_KEY", + ) + )