Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions example_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ def raise_error():
raise ValueError("Something went wrong!")


def optional_value(value: Optional[int] = None):
"""
A command which accepts an optional value.
"""
if value:
print(value + 1)
else:
print("Unknown")


if __name__ == "__main__":
cli = CLI()
cli.register(say_hello)
Expand All @@ -152,4 +162,5 @@ def raise_error():
cli.register(add, command_name="sum")
cli.register(print_address)
cli.register(raise_error)
cli.register(optional_value)
cli.run()
14 changes: 12 additions & 2 deletions targ/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

from .format import Color, format_text, get_underline

# Only available in Python 3.10 and above:
try:
from types import NoneType, UnionType # type: ignore
except ImportError:
NoneType = type(None) # type: ignore

class UnionType: # type: ignore
pass


__VERSION__ = "0.6.0"


Expand Down Expand Up @@ -210,10 +220,10 @@ def call_with(self, arg_class: Arguments):

if annotation in CONVERTABLE_TYPES:
value = annotation(value)
elif get_origin(annotation) is Union: # type: ignore
elif get_origin(annotation) in [Union, UnionType]: # type: ignore
# Union is used to detect Optional
inner_annotations = get_args(annotation)
filtered = [i for i in inner_annotations if i is not None]
filtered = [i for i in inner_annotations if i is not NoneType]
if len(filtered) == 1:
annotation = filtered[0]
if annotation in CONVERTABLE_TYPES:
Expand Down
182 changes: 134 additions & 48 deletions tests/test_command.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import decimal
import sys
from typing import Any, Optional
from typing import Any, Optional, Union
from unittest import TestCase
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -206,64 +206,89 @@ def test_command(arg1: bool = False):
@patch("targ.CLI._get_cleaned_args")
def test_optional_bool_arg(self, _get_cleaned_args: MagicMock):
"""
Test command arguments which are of type Optional[bool].
Test command arguments which are optional booleans.
"""

def test_command(arg1: Optional[bool] = None):
"""
A command for testing optional boolean arguments.
"""
if arg1 is None:
def print_arg(arg):
if arg is None:
print("arg1 is None")
elif arg1 is True:
elif arg is True:
print("arg1 is True")
elif arg1 is False:
elif arg is False:
print("arg1 is False")
else:
raise ValueError("arg1 is the wrong type")

cli = CLI()
cli.register(test_command)
def test_optional(arg1: Optional[bool] = None):
"""
A command for testing `Optional[bool]` arguments.
"""
print_arg(arg1)

with patch("builtins.print", side_effect=print_) as print_mock:
def test_union(arg1: Union[bool, None] = None):
"""
A command for testing `Union[bool, None]` arguments.
"""
print_arg(arg1)

configs: list[Config] = [
Config(
params=["test_command", "--arg1"],
output="arg1 is True",
),
Config(
params=["test_command", "--arg1=True"],
output="arg1 is True",
),
Config(
params=["test_command", "--arg1=true"],
output="arg1 is True",
),
Config(
params=["test_command", "--arg1=t"],
output="arg1 is True",
),
Config(
params=["test_command", "--arg1=False"],
output="arg1 is False",
),
Config(
params=["test_command", "--arg1=false"],
output="arg1 is False",
),
Config(
params=["test_command", "--arg1=f"],
output="arg1 is False",
),
Config(params=["test_command"], output="arg1 is None"),
]
commands = [test_optional, test_union]

for config in configs:
_get_cleaned_args.return_value = config.params
cli.run()
print_mock.assert_called_with(config.output)
print_mock.reset_mock()
if sys.version_info.major == 3 and sys.version_info.minor >= 10:

def test_union_syntax(arg1: bool | None = None): # type: ignore
"""
A command for testing `bool | None` arguments.
"""
print_arg(arg1)

commands.append(test_union_syntax)

cli = CLI()

for command in commands:
cli.register(command)

with patch("builtins.print", side_effect=print_) as print_mock:
for command in commands:
command_name = command.__name__

configs: list[Config] = [
Config(
params=[command_name, "--arg1"],
output="arg1 is True",
),
Config(
params=[command_name, "--arg1=True"],
output="arg1 is True",
),
Config(
params=[command_name, "--arg1=true"],
output="arg1 is True",
),
Config(
params=[command_name, "--arg1=t"],
output="arg1 is True",
),
Config(
params=[command_name, "--arg1=False"],
output="arg1 is False",
),
Config(
params=[command_name, "--arg1=false"],
output="arg1 is False",
),
Config(
params=[command_name, "--arg1=f"],
output="arg1 is False",
),
Config(params=[command_name], output="arg1 is None"),
]

for config in configs:
_get_cleaned_args.return_value = config.params
cli.run()
print_mock.assert_called_with(config.output)
print_mock.reset_mock()

@patch("targ.CLI._get_cleaned_args")
def test_int_arg(self, _get_cleaned_args: MagicMock):
Expand Down Expand Up @@ -302,6 +327,67 @@ def test_command(arg1: decimal.Decimal):
print_mock.assert_called_with(config.output)
print_mock.reset_mock()

@patch("targ.CLI._get_cleaned_args")
def test_optional_int_arg(self, _get_cleaned_args: MagicMock):
"""
Test command arguments which are optional int.
"""

def print_arg(arg):
if arg is None:
print("arg1 is None")
elif isinstance(arg, int):
print("arg1 is an int")
else:
raise ValueError("arg1 is the wrong type")

def test_optional(arg1: Optional[int] = None):
"""
A command for testing `Optional[int]` arguments.
"""
print_arg(arg1)

def test_union(arg1: Union[int, None] = None):
"""
A command for testing `Union[int, None]` arguments.
"""
print_arg(arg1)

commands = [test_optional, test_union]

if sys.version_info.major == 3 and sys.version_info.minor >= 10:

def test_union_syntax(arg1: int | None = None): # type: ignore
"""
A command for testing `int | None` arguments.
"""
print_arg(arg1)

commands.append(test_union_syntax)

cli = CLI()

for command in commands:
cli.register(command)

with patch("builtins.print", side_effect=print_) as print_mock:
for command in commands:
command_name = command.__name__

configs: list[Config] = [
Config(
params=[command_name, "--arg1=1"],
output="arg1 is an int",
),
Config(params=[command_name], output="arg1 is None"),
]

for config in configs:
_get_cleaned_args.return_value = config.params
cli.run()
print_mock.assert_called_with(config.output)
print_mock.reset_mock()

@patch("targ.CLI._get_cleaned_args")
def test_decimal_arg(self, _get_cleaned_args: MagicMock):
"""
Expand Down