diff --git a/docs/tutorial/parameter-types/enum.md b/docs/tutorial/parameter-types/enum.md index 278908fc72..012db6bfa8 100644 --- a/docs/tutorial/parameter-types/enum.md +++ b/docs/tutorial/parameter-types/enum.md @@ -74,6 +74,37 @@ Training neural network of type: lstm +### Using Enum names instead of values + +Sometimes you want to accept `Enum` names from the command line and convert +that into `Enum` values in the command handler. You can enable this by setting +`enum_by_name=True`: + +{* docs_src/parameter_types/enum/tutorial007_an_py310.py hl[18] *} + +And then the names of the `Enum` can be used instead of values: + +
+ +```console +$ python main.py --log-level debug + +Log level set to DEBUG +``` + +
+ +This can be particularly useful if the enum values are not strings: + +{* docs_src/parameter_types/enum/tutorial005_an_py310.py hl[8:11,18] *} + +```console +$ python main.py --access protected + +Access level: protected (2) +``` + + ### List of Enum values A *CLI parameter* can also take a list of `Enum` values: @@ -112,6 +143,29 @@ Buying groceries: Eggs, Bacon +You can also combine `enum_by_name=True` with a list of enums: + +{* docs_src/parameter_types/enum/tutorial006_an_py310.py hl[18] *} + +This works exactly the same, but you're using the enum names instead of values: + +
+ +```console +// Try it with a single value +$ python main.py --groceries "f1" + +Buying groceries: Eggs + +// Try it with multiple values +$ python main.py --groceries "f1" --groceries "f2" + +Buying groceries: Eggs, Bacon +``` + +
+ + ### Literal choices You can also use `Literal` to represent a set of possible predefined choices, without having to use an `Enum`: diff --git a/docs_src/parameter_types/enum/tutorial005_an_py310.py b/docs_src/parameter_types/enum/tutorial005_an_py310.py new file mode 100644 index 0000000000..96aa2c7b4d --- /dev/null +++ b/docs_src/parameter_types/enum/tutorial005_an_py310.py @@ -0,0 +1,23 @@ +import enum +from typing import Annotated + +import typer + + +class Access(enum.IntEnum): + private = 1 + protected = 2 + public = 3 + open = 4 + + +app = typer.Typer() + + +@app.command() +def main(access: Annotated[Access, typer.Option(enum_by_name=True)] = "private"): + typer.echo(f"Access level: {access.name} ({access.value})") + + +if __name__ == "__main__": + app() diff --git a/docs_src/parameter_types/enum/tutorial005_py310.py b/docs_src/parameter_types/enum/tutorial005_py310.py new file mode 100644 index 0000000000..fc53a87a30 --- /dev/null +++ b/docs_src/parameter_types/enum/tutorial005_py310.py @@ -0,0 +1,22 @@ +import enum + +import typer + + +class Access(enum.IntEnum): + private = 1 + protected = 2 + public = 3 + open = 4 + + +app = typer.Typer() + + +@app.command() +def main(access: Access = typer.Option("private", enum_by_name=True)): + typer.echo(f"Access level: {access.name} ({access.value})") + + +if __name__ == "__main__": + app() diff --git a/docs_src/parameter_types/enum/tutorial006_an_py310.py b/docs_src/parameter_types/enum/tutorial006_an_py310.py new file mode 100644 index 0000000000..0d633f792e --- /dev/null +++ b/docs_src/parameter_types/enum/tutorial006_an_py310.py @@ -0,0 +1,24 @@ +from enum import Enum +from typing import Annotated + +import typer + + +class Food(str, Enum): + f1 = "Eggs" + f2 = "Bacon" + f3 = "Cheese" + + +app = typer.Typer() + + +@app.command() +def main( + groceries: Annotated[list[Food], typer.Option(enum_by_name=True)] = ["f1", "f3"], +): + print(f"Buying groceries: {', '.join([f.value for f in groceries])}") + + +if __name__ == "__main__": + app() diff --git a/docs_src/parameter_types/enum/tutorial006_py310.py b/docs_src/parameter_types/enum/tutorial006_py310.py new file mode 100644 index 0000000000..aa4edd78e9 --- /dev/null +++ b/docs_src/parameter_types/enum/tutorial006_py310.py @@ -0,0 +1,21 @@ +from enum import Enum + +import typer + + +class Food(str, Enum): + f1 = "Eggs" + f2 = "Bacon" + f3 = "Cheese" + + +app = typer.Typer() + + +@app.command() +def main(groceries: list[Food] = typer.Option(["f1", "f3"], enum_by_name=True)): + print(f"Buying groceries: {', '.join([f.value for f in groceries])}") + + +if __name__ == "__main__": + app() diff --git a/docs_src/parameter_types/enum/tutorial007_an_py310.py b/docs_src/parameter_types/enum/tutorial007_an_py310.py new file mode 100644 index 0000000000..eaf9b3da36 --- /dev/null +++ b/docs_src/parameter_types/enum/tutorial007_an_py310.py @@ -0,0 +1,23 @@ +import enum +import logging +from typing import Annotated + +import typer + + +class LogLevel(enum.Enum): + debug = logging.DEBUG + info = logging.INFO + warning = logging.WARNING + + +app = typer.Typer() + + +@app.command() +def main(log_level: Annotated[LogLevel, typer.Option(enum_by_name=True)] = "warning"): + typer.echo(f"Log level set to: {logging.getLevelName(log_level.value)}") + + +if __name__ == "__main__": + app() diff --git a/docs_src/parameter_types/enum/tutorial007_py310.py b/docs_src/parameter_types/enum/tutorial007_py310.py new file mode 100644 index 0000000000..a4cc486c6e --- /dev/null +++ b/docs_src/parameter_types/enum/tutorial007_py310.py @@ -0,0 +1,22 @@ +import enum +import logging + +import typer + + +class LogLevel(enum.Enum): + debug = logging.DEBUG + info = logging.INFO + warning = logging.WARNING + + +app = typer.Typer() + + +@app.command() +def main(log_level: LogLevel = typer.Option("warning", enum_by_name=True)): + typer.echo(f"Log level set to: {logging.getLevelName(log_level.value)}") + + +if __name__ == "__main__": + app() diff --git a/pyproject.toml b/pyproject.toml index 990366bd3e..a70d01734a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -232,6 +232,7 @@ ignore = [ "docs_src/options_autocompletion/tutorial008_an_py310.py" = ["B006"] "docs_src/options_autocompletion/tutorial009_an_py310.py" = ["B006"] "docs_src/parameter_types/enum/tutorial003_an_py310.py" = ["B006"] +"docs_src/parameter_types/enum/tutorial006_an_py310.py" = ["B006"] # Loop control variable `value` not used within loop body "docs_src/progressbar/tutorial001_py310.py" = ["B007"] "docs_src/progressbar/tutorial003_py310.py" = ["B007"] diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py new file mode 100644 index 0000000000..4f09eb1824 --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py @@ -0,0 +1,43 @@ +import importlib +import subprocess +import sys +from types import ModuleType + +import pytest +from typer.testing import CliRunner + +runner = CliRunner() + + +@pytest.fixture( + name="mod", + params=[ + pytest.param("tutorial005_py310"), + pytest.param("tutorial005_an_py310"), + ], +) +def get_mod(request: pytest.FixtureRequest) -> ModuleType: + module_name = f"docs_src.parameter_types.enum.{request.param}" + mod = importlib.import_module(module_name) + return mod + + +def test_int_enum_default(mod: ModuleType): + result = runner.invoke(mod.app) + assert result.exit_code == 0 + assert "Access level: private (1)" in result.output + + +def test_int_enum(mod: ModuleType): + result = runner.invoke(mod.app, ["--access", "open"]) + assert result.exit_code == 0 + assert "Access level: open (4)" in result.output + + +def test_script(mod: ModuleType): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial006.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial006.py new file mode 100644 index 0000000000..0e7308275b --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial006.py @@ -0,0 +1,57 @@ +import importlib +import subprocess +import sys +from types import ModuleType + +import pytest +from typer.testing import CliRunner + +runner = CliRunner() + + +@pytest.fixture( + name="mod", + params=[ + pytest.param("tutorial006_py310"), + pytest.param("tutorial006_an_py310"), + ], +) +def get_mod(request: pytest.FixtureRequest) -> ModuleType: + module_name = f"docs_src.parameter_types.enum.{request.param}" + mod = importlib.import_module(module_name) + return mod + + +def test_help(mod: ModuleType): + result = runner.invoke(mod.app, ["--help"]) + assert result.exit_code == 0 + assert "--groceries" in result.output + assert "[f1|f2|f3]" in result.output + assert "default: f1, f3" in result.output + + +def test_call_no_arg(mod: ModuleType): + result = runner.invoke(mod.app) + assert result.exit_code == 0 + assert "Buying groceries: Eggs, Cheese" in result.output + + +def test_call_single_arg(mod: ModuleType): + result = runner.invoke(mod.app, ["--groceries", "f2"]) + assert result.exit_code == 0 + assert "Buying groceries: Bacon" in result.output + + +def test_call_multiple_arg(mod: ModuleType): + result = runner.invoke(mod.app, ["--groceries", "f1", "--groceries", "f2"]) + assert result.exit_code == 0 + assert "Buying groceries: Eggs, Bacon" in result.output + + +def test_script(mod: ModuleType): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial007.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial007.py new file mode 100644 index 0000000000..51322b0ade --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial007.py @@ -0,0 +1,43 @@ +import importlib +import subprocess +import sys +from types import ModuleType + +import pytest +from typer.testing import CliRunner + +runner = CliRunner() + + +@pytest.fixture( + name="mod", + params=[ + pytest.param("tutorial007_py310"), + pytest.param("tutorial007_an_py310"), + ], +) +def get_mod(request: pytest.FixtureRequest) -> ModuleType: + module_name = f"docs_src.parameter_types.enum.{request.param}" + mod = importlib.import_module(module_name) + return mod + + +def test_enum_names_default(mod: ModuleType): + result = runner.invoke(mod.app) + assert result.exit_code == 0 + assert "Log level set to: WARNING" in result.output + + +def test_enum_names(mod: ModuleType): + result = runner.invoke(mod.app, ["--log-level", "debug"]) + assert result.exit_code == 0 + assert "Log level set to: DEBUG" in result.output + + +def test_script(mod: ModuleType): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout diff --git a/typer/main.py b/typer/main.py index f4f21bb844..3a0c7b5750 100644 --- a/typer/main.py +++ b/typer/main.py @@ -1424,12 +1424,17 @@ def get_command_from_info( return command -def determine_type_convertor(type_: Any) -> Callable[[Any], Any] | None: +def determine_type_convertor( + type_: Any, enum_by_name: bool +) -> Callable[[Any], Any] | None: convertor: Callable[[Any], Any] | None = None if lenient_issubclass(type_, Path): convertor = param_path_convertor if lenient_issubclass(type_, Enum): - convertor = generate_enum_convertor(type_) + if enum_by_name: + convertor = generate_enum_name_convertor(type_) + else: + convertor = generate_enum_convertor(type_) return convertor @@ -1454,6 +1459,18 @@ def convertor(value: Any) -> Any: return convertor +def generate_enum_name_convertor(enum: type[Enum]) -> Callable[..., Any]: + val_map = {str(item.name): item for item in enum} + + def convertor(value: Any) -> Any: + if value is not None: + val = str(value) + if val in val_map: + return val_map[val] + + return convertor + + def generate_list_convertor( convertor: Callable[[Any], Any] | None, default_value: Any | None ) -> Callable[[Sequence[Any] | None], list[Any] | None]: @@ -1467,8 +1484,9 @@ def internal_convertor(value: Sequence[Any] | None) -> list[Any] | None: def generate_tuple_convertor( types: Sequence[Any], + enum_by_name: bool, ) -> Callable[[tuple[Any, ...] | None], tuple[Any, ...] | None]: - convertors = [determine_type_convertor(type_) for type_ in types] + convertors = [determine_type_convertor(type_, enum_by_name) for type_ in types] def internal_convertor( param_args: tuple[Any, ...] | None, @@ -1603,15 +1621,16 @@ def get_click_type( atomic=parameter_info.atomic, ) elif lenient_issubclass(annotation, Enum): + if parameter_info.enum_by_name: + choices = [item.name for item in annotation] + else: + choices = [item.value for item in annotation] # The custom TyperChoice is only needed for Click < 8.2.0, to parse the # command line values matching them to the enum values. Click 8.2.0 added # support for enum values but reading enum names. # Passing here the list of enum values (instead of just the enum) accounts for # Click < 8.2.0. - return TyperChoice( - [item.value for item in annotation], - case_sensitive=parameter_info.case_sensitive, - ) + return TyperChoice(choices, case_sensitive=parameter_info.case_sensitive) elif is_literal_type(annotation): return click.Choice( literal_values(annotation), @@ -1690,13 +1709,14 @@ def get_click_param( parameter_type = get_click_type( annotation=main_type, parameter_info=parameter_info ) - convertor = determine_type_convertor(main_type) + enum_by_name = parameter_info.enum_by_name + convertor = determine_type_convertor(main_type, enum_by_name) if is_list: convertor = generate_list_convertor( convertor=convertor, default_value=default_value ) if is_tuple: - convertor = generate_tuple_convertor(get_args(main_type)) + convertor = generate_tuple_convertor(get_args(main_type), enum_by_name) if isinstance(parameter_info, OptionInfo): if main_type is bool: is_flag = True diff --git a/typer/models.py b/typer/models.py index 3285a96a24..448bbef561 100644 --- a/typer/models.py +++ b/typer/models.py @@ -303,6 +303,7 @@ def __init__( hidden: bool = False, # Choice case_sensitive: bool = True, + enum_by_name: bool = False, # Numbers min: int | float | None = None, max: int | float | None = None, @@ -355,6 +356,7 @@ def __init__( self.hidden = hidden # Choice self.case_sensitive = case_sensitive + self.enum_by_name = enum_by_name # Numbers self.min = min self.max = max @@ -421,6 +423,7 @@ def __init__( show_envvar: bool = True, # Choice case_sensitive: bool = True, + enum_by_name: bool = False, # Numbers min: int | float | None = None, max: int | float | None = None, @@ -467,6 +470,7 @@ def __init__( hidden=hidden, # Choice case_sensitive=case_sensitive, + enum_by_name=enum_by_name, # Numbers min=min, max=max, @@ -540,6 +544,7 @@ def __init__( hidden: bool = False, # Choice case_sensitive: bool = True, + enum_by_name: bool = False, # Numbers min: int | float | None = None, max: int | float | None = None, @@ -586,6 +591,7 @@ def __init__( hidden=hidden, # Choice case_sensitive=case_sensitive, + enum_by_name=enum_by_name, # Numbers min=min, max=max, diff --git a/typer/params.py b/typer/params.py index b325b273c4..3b4e1dc7ac 100644 --- a/typer/params.py +++ b/typer/params.py @@ -49,6 +49,7 @@ def Option( show_envvar: bool = True, # Choice case_sensitive: bool = True, + enum_by_name: bool = False, # Numbers min: int | float | None = None, max: int | float | None = None, @@ -114,6 +115,7 @@ def Option( show_envvar: bool = True, # Choice case_sensitive: bool = True, + enum_by_name: bool = False, # Numbers min: int | float | None = None, max: int | float | None = None, @@ -581,7 +583,8 @@ class NeuralNetwork(str, Enum): @app.command() def main( - network: Annotated[NeuralNetwork, typer.Option(case_sensitive=False)]): + network: Annotated[NeuralNetwork, typer.Option(case_sensitive=False)] + ): print(f"Training neural network of type: {network.value}") ``` @@ -589,6 +592,32 @@ def main( """ ), ] = True, + enum_by_name: Annotated[ + bool, + Doc( + """ + For a CLI Option representing an [Enum (choice)](https://typer.tiangolo.com/tutorial/parameter-types/enum), + accept names from the command line instead of values. + + **Example** + + ```python + from enum import Enum + + class Food(str, Enum): + f1 = "Eggs" + f2 = "Bacon" + f3 = "Cheese" + + @app.command() + def main( + groceries: Annotated[list[Food], typer.Option(enum_by_name=True)] = ["f1", "f3"] + ): + print(f"Buying groceries: {', '.join([f.value for f in groceries])}") + ``` + """ + ), + ] = False, # Numbers min: Annotated[ int | float | None, @@ -974,6 +1003,7 @@ def register( show_envvar=show_envvar, # Choice case_sensitive=case_sensitive, + enum_by_name=enum_by_name, # Numbers min=min, max=max, @@ -1030,6 +1060,7 @@ def Argument( hidden: bool = False, # Choice case_sensitive: bool = True, + enum_by_name: bool = False, # Numbers min: int | float | None = None, max: int | float | None = None, @@ -1086,6 +1117,7 @@ def Argument( hidden: bool = False, # Choice case_sensitive: bool = True, + enum_by_name: bool = False, # Numbers min: int | float | None = None, max: int | float | None = None, @@ -1431,6 +1463,32 @@ def main( """ ), ] = True, + enum_by_name: Annotated[ + bool, + Doc( + """ + For a CLI Argument representing an [Enum (choice)](https://typer.tiangolo.com/tutorial/parameter-types/enum), + accept names from the command line instead of values. + + **Example** + + ```python + from enum import Enum + + class Food(str, Enum): + f1 = "Eggs" + f2 = "Bacon" + f3 = "Cheese" + + @app.command() + def main( + groceries: Annotated[list[Food], typer.Argument(enum_by_name=True)] = ["f1", "f3"] + ): + print(f"Buying groceries: {', '.join([f.value for f in groceries])}") + ``` + """ + ), + ] = False, # Numbers min: Annotated[ int | float | None, @@ -1805,6 +1863,7 @@ def main(name: Annotated[str, typer.Argument()] = "World"): hidden=hidden, # Choice case_sensitive=case_sensitive, + enum_by_name=enum_by_name, # Numbers min=min, max=max,