diff --git a/docs/tutorial/parameter-types/enum.md b/docs/tutorial/parameter-types/enum.md
index 278908fc72..012db6bfa8 100644
--- a/docs/tutorial/parameter-types/enum.md
+++ b/docs/tutorial/parameter-types/enum.md
@@ -74,6 +74,37 @@ Training neural network of type: lstm
+### Using Enum names instead of values
+
+Sometimes you want to accept `Enum` names from the command line and convert
+that into `Enum` values in the command handler. You can enable this by setting
+`enum_by_name=True`:
+
+{* docs_src/parameter_types/enum/tutorial007_an_py310.py hl[18] *}
+
+And then the names of the `Enum` can be used instead of values:
+
+
+
+```console
+$ python main.py --log-level debug
+
+Log level set to DEBUG
+```
+
+
+
+This can be particularly useful if the enum values are not strings:
+
+{* docs_src/parameter_types/enum/tutorial005_an_py310.py hl[8:11,18] *}
+
+```console
+$ python main.py --access protected
+
+Access level: protected (2)
+```
+
+
### List of Enum values
A *CLI parameter* can also take a list of `Enum` values:
@@ -112,6 +143,29 @@ Buying groceries: Eggs, Bacon
+You can also combine `enum_by_name=True` with a list of enums:
+
+{* docs_src/parameter_types/enum/tutorial006_an_py310.py hl[18] *}
+
+This works exactly the same, but you're using the enum names instead of values:
+
+
+
+```console
+// Try it with a single value
+$ python main.py --groceries "f1"
+
+Buying groceries: Eggs
+
+// Try it with multiple values
+$ python main.py --groceries "f1" --groceries "f2"
+
+Buying groceries: Eggs, Bacon
+```
+
+
+
+
### Literal choices
You can also use `Literal` to represent a set of possible predefined choices, without having to use an `Enum`:
diff --git a/docs_src/parameter_types/enum/tutorial005_an_py310.py b/docs_src/parameter_types/enum/tutorial005_an_py310.py
new file mode 100644
index 0000000000..96aa2c7b4d
--- /dev/null
+++ b/docs_src/parameter_types/enum/tutorial005_an_py310.py
@@ -0,0 +1,23 @@
+import enum
+from typing import Annotated
+
+import typer
+
+
+class Access(enum.IntEnum):
+ private = 1
+ protected = 2
+ public = 3
+ open = 4
+
+
+app = typer.Typer()
+
+
+@app.command()
+def main(access: Annotated[Access, typer.Option(enum_by_name=True)] = "private"):
+ typer.echo(f"Access level: {access.name} ({access.value})")
+
+
+if __name__ == "__main__":
+ app()
diff --git a/docs_src/parameter_types/enum/tutorial005_py310.py b/docs_src/parameter_types/enum/tutorial005_py310.py
new file mode 100644
index 0000000000..fc53a87a30
--- /dev/null
+++ b/docs_src/parameter_types/enum/tutorial005_py310.py
@@ -0,0 +1,22 @@
+import enum
+
+import typer
+
+
+class Access(enum.IntEnum):
+ private = 1
+ protected = 2
+ public = 3
+ open = 4
+
+
+app = typer.Typer()
+
+
+@app.command()
+def main(access: Access = typer.Option("private", enum_by_name=True)):
+ typer.echo(f"Access level: {access.name} ({access.value})")
+
+
+if __name__ == "__main__":
+ app()
diff --git a/docs_src/parameter_types/enum/tutorial006_an_py310.py b/docs_src/parameter_types/enum/tutorial006_an_py310.py
new file mode 100644
index 0000000000..0d633f792e
--- /dev/null
+++ b/docs_src/parameter_types/enum/tutorial006_an_py310.py
@@ -0,0 +1,24 @@
+from enum import Enum
+from typing import Annotated
+
+import typer
+
+
+class Food(str, Enum):
+ f1 = "Eggs"
+ f2 = "Bacon"
+ f3 = "Cheese"
+
+
+app = typer.Typer()
+
+
+@app.command()
+def main(
+ groceries: Annotated[list[Food], typer.Option(enum_by_name=True)] = ["f1", "f3"],
+):
+ print(f"Buying groceries: {', '.join([f.value for f in groceries])}")
+
+
+if __name__ == "__main__":
+ app()
diff --git a/docs_src/parameter_types/enum/tutorial006_py310.py b/docs_src/parameter_types/enum/tutorial006_py310.py
new file mode 100644
index 0000000000..aa4edd78e9
--- /dev/null
+++ b/docs_src/parameter_types/enum/tutorial006_py310.py
@@ -0,0 +1,21 @@
+from enum import Enum
+
+import typer
+
+
+class Food(str, Enum):
+ f1 = "Eggs"
+ f2 = "Bacon"
+ f3 = "Cheese"
+
+
+app = typer.Typer()
+
+
+@app.command()
+def main(groceries: list[Food] = typer.Option(["f1", "f3"], enum_by_name=True)):
+ print(f"Buying groceries: {', '.join([f.value for f in groceries])}")
+
+
+if __name__ == "__main__":
+ app()
diff --git a/docs_src/parameter_types/enum/tutorial007_an_py310.py b/docs_src/parameter_types/enum/tutorial007_an_py310.py
new file mode 100644
index 0000000000..eaf9b3da36
--- /dev/null
+++ b/docs_src/parameter_types/enum/tutorial007_an_py310.py
@@ -0,0 +1,23 @@
+import enum
+import logging
+from typing import Annotated
+
+import typer
+
+
+class LogLevel(enum.Enum):
+ debug = logging.DEBUG
+ info = logging.INFO
+ warning = logging.WARNING
+
+
+app = typer.Typer()
+
+
+@app.command()
+def main(log_level: Annotated[LogLevel, typer.Option(enum_by_name=True)] = "warning"):
+ typer.echo(f"Log level set to: {logging.getLevelName(log_level.value)}")
+
+
+if __name__ == "__main__":
+ app()
diff --git a/docs_src/parameter_types/enum/tutorial007_py310.py b/docs_src/parameter_types/enum/tutorial007_py310.py
new file mode 100644
index 0000000000..a4cc486c6e
--- /dev/null
+++ b/docs_src/parameter_types/enum/tutorial007_py310.py
@@ -0,0 +1,22 @@
+import enum
+import logging
+
+import typer
+
+
+class LogLevel(enum.Enum):
+ debug = logging.DEBUG
+ info = logging.INFO
+ warning = logging.WARNING
+
+
+app = typer.Typer()
+
+
+@app.command()
+def main(log_level: LogLevel = typer.Option("warning", enum_by_name=True)):
+ typer.echo(f"Log level set to: {logging.getLevelName(log_level.value)}")
+
+
+if __name__ == "__main__":
+ app()
diff --git a/pyproject.toml b/pyproject.toml
index 990366bd3e..a70d01734a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -232,6 +232,7 @@ ignore = [
"docs_src/options_autocompletion/tutorial008_an_py310.py" = ["B006"]
"docs_src/options_autocompletion/tutorial009_an_py310.py" = ["B006"]
"docs_src/parameter_types/enum/tutorial003_an_py310.py" = ["B006"]
+"docs_src/parameter_types/enum/tutorial006_an_py310.py" = ["B006"]
# Loop control variable `value` not used within loop body
"docs_src/progressbar/tutorial001_py310.py" = ["B007"]
"docs_src/progressbar/tutorial003_py310.py" = ["B007"]
diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py
new file mode 100644
index 0000000000..4f09eb1824
--- /dev/null
+++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py
@@ -0,0 +1,43 @@
+import importlib
+import subprocess
+import sys
+from types import ModuleType
+
+import pytest
+from typer.testing import CliRunner
+
+runner = CliRunner()
+
+
+@pytest.fixture(
+ name="mod",
+ params=[
+ pytest.param("tutorial005_py310"),
+ pytest.param("tutorial005_an_py310"),
+ ],
+)
+def get_mod(request: pytest.FixtureRequest) -> ModuleType:
+ module_name = f"docs_src.parameter_types.enum.{request.param}"
+ mod = importlib.import_module(module_name)
+ return mod
+
+
+def test_int_enum_default(mod: ModuleType):
+ result = runner.invoke(mod.app)
+ assert result.exit_code == 0
+ assert "Access level: private (1)" in result.output
+
+
+def test_int_enum(mod: ModuleType):
+ result = runner.invoke(mod.app, ["--access", "open"])
+ assert result.exit_code == 0
+ assert "Access level: open (4)" in result.output
+
+
+def test_script(mod: ModuleType):
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"],
+ capture_output=True,
+ encoding="utf-8",
+ )
+ assert "Usage" in result.stdout
diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial006.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial006.py
new file mode 100644
index 0000000000..0e7308275b
--- /dev/null
+++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial006.py
@@ -0,0 +1,57 @@
+import importlib
+import subprocess
+import sys
+from types import ModuleType
+
+import pytest
+from typer.testing import CliRunner
+
+runner = CliRunner()
+
+
+@pytest.fixture(
+ name="mod",
+ params=[
+ pytest.param("tutorial006_py310"),
+ pytest.param("tutorial006_an_py310"),
+ ],
+)
+def get_mod(request: pytest.FixtureRequest) -> ModuleType:
+ module_name = f"docs_src.parameter_types.enum.{request.param}"
+ mod = importlib.import_module(module_name)
+ return mod
+
+
+def test_help(mod: ModuleType):
+ result = runner.invoke(mod.app, ["--help"])
+ assert result.exit_code == 0
+ assert "--groceries" in result.output
+ assert "[f1|f2|f3]" in result.output
+ assert "default: f1, f3" in result.output
+
+
+def test_call_no_arg(mod: ModuleType):
+ result = runner.invoke(mod.app)
+ assert result.exit_code == 0
+ assert "Buying groceries: Eggs, Cheese" in result.output
+
+
+def test_call_single_arg(mod: ModuleType):
+ result = runner.invoke(mod.app, ["--groceries", "f2"])
+ assert result.exit_code == 0
+ assert "Buying groceries: Bacon" in result.output
+
+
+def test_call_multiple_arg(mod: ModuleType):
+ result = runner.invoke(mod.app, ["--groceries", "f1", "--groceries", "f2"])
+ assert result.exit_code == 0
+ assert "Buying groceries: Eggs, Bacon" in result.output
+
+
+def test_script(mod: ModuleType):
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"],
+ capture_output=True,
+ encoding="utf-8",
+ )
+ assert "Usage" in result.stdout
diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial007.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial007.py
new file mode 100644
index 0000000000..51322b0ade
--- /dev/null
+++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial007.py
@@ -0,0 +1,43 @@
+import importlib
+import subprocess
+import sys
+from types import ModuleType
+
+import pytest
+from typer.testing import CliRunner
+
+runner = CliRunner()
+
+
+@pytest.fixture(
+ name="mod",
+ params=[
+ pytest.param("tutorial007_py310"),
+ pytest.param("tutorial007_an_py310"),
+ ],
+)
+def get_mod(request: pytest.FixtureRequest) -> ModuleType:
+ module_name = f"docs_src.parameter_types.enum.{request.param}"
+ mod = importlib.import_module(module_name)
+ return mod
+
+
+def test_enum_names_default(mod: ModuleType):
+ result = runner.invoke(mod.app)
+ assert result.exit_code == 0
+ assert "Log level set to: WARNING" in result.output
+
+
+def test_enum_names(mod: ModuleType):
+ result = runner.invoke(mod.app, ["--log-level", "debug"])
+ assert result.exit_code == 0
+ assert "Log level set to: DEBUG" in result.output
+
+
+def test_script(mod: ModuleType):
+ result = subprocess.run(
+ [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"],
+ capture_output=True,
+ encoding="utf-8",
+ )
+ assert "Usage" in result.stdout
diff --git a/typer/main.py b/typer/main.py
index f4f21bb844..3a0c7b5750 100644
--- a/typer/main.py
+++ b/typer/main.py
@@ -1424,12 +1424,17 @@ def get_command_from_info(
return command
-def determine_type_convertor(type_: Any) -> Callable[[Any], Any] | None:
+def determine_type_convertor(
+ type_: Any, enum_by_name: bool
+) -> Callable[[Any], Any] | None:
convertor: Callable[[Any], Any] | None = None
if lenient_issubclass(type_, Path):
convertor = param_path_convertor
if lenient_issubclass(type_, Enum):
- convertor = generate_enum_convertor(type_)
+ if enum_by_name:
+ convertor = generate_enum_name_convertor(type_)
+ else:
+ convertor = generate_enum_convertor(type_)
return convertor
@@ -1454,6 +1459,18 @@ def convertor(value: Any) -> Any:
return convertor
+def generate_enum_name_convertor(enum: type[Enum]) -> Callable[..., Any]:
+ val_map = {str(item.name): item for item in enum}
+
+ def convertor(value: Any) -> Any:
+ if value is not None:
+ val = str(value)
+ if val in val_map:
+ return val_map[val]
+
+ return convertor
+
+
def generate_list_convertor(
convertor: Callable[[Any], Any] | None, default_value: Any | None
) -> Callable[[Sequence[Any] | None], list[Any] | None]:
@@ -1467,8 +1484,9 @@ def internal_convertor(value: Sequence[Any] | None) -> list[Any] | None:
def generate_tuple_convertor(
types: Sequence[Any],
+ enum_by_name: bool,
) -> Callable[[tuple[Any, ...] | None], tuple[Any, ...] | None]:
- convertors = [determine_type_convertor(type_) for type_ in types]
+ convertors = [determine_type_convertor(type_, enum_by_name) for type_ in types]
def internal_convertor(
param_args: tuple[Any, ...] | None,
@@ -1603,15 +1621,16 @@ def get_click_type(
atomic=parameter_info.atomic,
)
elif lenient_issubclass(annotation, Enum):
+ if parameter_info.enum_by_name:
+ choices = [item.name for item in annotation]
+ else:
+ choices = [item.value for item in annotation]
# The custom TyperChoice is only needed for Click < 8.2.0, to parse the
# command line values matching them to the enum values. Click 8.2.0 added
# support for enum values but reading enum names.
# Passing here the list of enum values (instead of just the enum) accounts for
# Click < 8.2.0.
- return TyperChoice(
- [item.value for item in annotation],
- case_sensitive=parameter_info.case_sensitive,
- )
+ return TyperChoice(choices, case_sensitive=parameter_info.case_sensitive)
elif is_literal_type(annotation):
return click.Choice(
literal_values(annotation),
@@ -1690,13 +1709,14 @@ def get_click_param(
parameter_type = get_click_type(
annotation=main_type, parameter_info=parameter_info
)
- convertor = determine_type_convertor(main_type)
+ enum_by_name = parameter_info.enum_by_name
+ convertor = determine_type_convertor(main_type, enum_by_name)
if is_list:
convertor = generate_list_convertor(
convertor=convertor, default_value=default_value
)
if is_tuple:
- convertor = generate_tuple_convertor(get_args(main_type))
+ convertor = generate_tuple_convertor(get_args(main_type), enum_by_name)
if isinstance(parameter_info, OptionInfo):
if main_type is bool:
is_flag = True
diff --git a/typer/models.py b/typer/models.py
index 3285a96a24..448bbef561 100644
--- a/typer/models.py
+++ b/typer/models.py
@@ -303,6 +303,7 @@ def __init__(
hidden: bool = False,
# Choice
case_sensitive: bool = True,
+ enum_by_name: bool = False,
# Numbers
min: int | float | None = None,
max: int | float | None = None,
@@ -355,6 +356,7 @@ def __init__(
self.hidden = hidden
# Choice
self.case_sensitive = case_sensitive
+ self.enum_by_name = enum_by_name
# Numbers
self.min = min
self.max = max
@@ -421,6 +423,7 @@ def __init__(
show_envvar: bool = True,
# Choice
case_sensitive: bool = True,
+ enum_by_name: bool = False,
# Numbers
min: int | float | None = None,
max: int | float | None = None,
@@ -467,6 +470,7 @@ def __init__(
hidden=hidden,
# Choice
case_sensitive=case_sensitive,
+ enum_by_name=enum_by_name,
# Numbers
min=min,
max=max,
@@ -540,6 +544,7 @@ def __init__(
hidden: bool = False,
# Choice
case_sensitive: bool = True,
+ enum_by_name: bool = False,
# Numbers
min: int | float | None = None,
max: int | float | None = None,
@@ -586,6 +591,7 @@ def __init__(
hidden=hidden,
# Choice
case_sensitive=case_sensitive,
+ enum_by_name=enum_by_name,
# Numbers
min=min,
max=max,
diff --git a/typer/params.py b/typer/params.py
index b325b273c4..3b4e1dc7ac 100644
--- a/typer/params.py
+++ b/typer/params.py
@@ -49,6 +49,7 @@ def Option(
show_envvar: bool = True,
# Choice
case_sensitive: bool = True,
+ enum_by_name: bool = False,
# Numbers
min: int | float | None = None,
max: int | float | None = None,
@@ -114,6 +115,7 @@ def Option(
show_envvar: bool = True,
# Choice
case_sensitive: bool = True,
+ enum_by_name: bool = False,
# Numbers
min: int | float | None = None,
max: int | float | None = None,
@@ -581,7 +583,8 @@ class NeuralNetwork(str, Enum):
@app.command()
def main(
- network: Annotated[NeuralNetwork, typer.Option(case_sensitive=False)]):
+ network: Annotated[NeuralNetwork, typer.Option(case_sensitive=False)]
+ ):
print(f"Training neural network of type: {network.value}")
```
@@ -589,6 +592,32 @@ def main(
"""
),
] = True,
+ enum_by_name: Annotated[
+ bool,
+ Doc(
+ """
+ For a CLI Option representing an [Enum (choice)](https://typer.tiangolo.com/tutorial/parameter-types/enum),
+ accept names from the command line instead of values.
+
+ **Example**
+
+ ```python
+ from enum import Enum
+
+ class Food(str, Enum):
+ f1 = "Eggs"
+ f2 = "Bacon"
+ f3 = "Cheese"
+
+ @app.command()
+ def main(
+ groceries: Annotated[list[Food], typer.Option(enum_by_name=True)] = ["f1", "f3"]
+ ):
+ print(f"Buying groceries: {', '.join([f.value for f in groceries])}")
+ ```
+ """
+ ),
+ ] = False,
# Numbers
min: Annotated[
int | float | None,
@@ -974,6 +1003,7 @@ def register(
show_envvar=show_envvar,
# Choice
case_sensitive=case_sensitive,
+ enum_by_name=enum_by_name,
# Numbers
min=min,
max=max,
@@ -1030,6 +1060,7 @@ def Argument(
hidden: bool = False,
# Choice
case_sensitive: bool = True,
+ enum_by_name: bool = False,
# Numbers
min: int | float | None = None,
max: int | float | None = None,
@@ -1086,6 +1117,7 @@ def Argument(
hidden: bool = False,
# Choice
case_sensitive: bool = True,
+ enum_by_name: bool = False,
# Numbers
min: int | float | None = None,
max: int | float | None = None,
@@ -1431,6 +1463,32 @@ def main(
"""
),
] = True,
+ enum_by_name: Annotated[
+ bool,
+ Doc(
+ """
+ For a CLI Argument representing an [Enum (choice)](https://typer.tiangolo.com/tutorial/parameter-types/enum),
+ accept names from the command line instead of values.
+
+ **Example**
+
+ ```python
+ from enum import Enum
+
+ class Food(str, Enum):
+ f1 = "Eggs"
+ f2 = "Bacon"
+ f3 = "Cheese"
+
+ @app.command()
+ def main(
+ groceries: Annotated[list[Food], typer.Argument(enum_by_name=True)] = ["f1", "f3"]
+ ):
+ print(f"Buying groceries: {', '.join([f.value for f in groceries])}")
+ ```
+ """
+ ),
+ ] = False,
# Numbers
min: Annotated[
int | float | None,
@@ -1805,6 +1863,7 @@ def main(name: Annotated[str, typer.Argument()] = "World"):
hidden=hidden,
# Choice
case_sensitive=case_sensitive,
+ enum_by_name=enum_by_name,
# Numbers
min=min,
max=max,