diff --git a/examples/train_dagger_atari_interactive_policy.py b/examples/train_dagger_atari_interactive_policy.py
index bb32f7194..103c2a4b6 100644
--- a/examples/train_dagger_atari_interactive_policy.py
+++ b/examples/train_dagger_atari_interactive_policy.py
@@ -7,28 +7,48 @@
 import gymnasium as gym
 import numpy as np
-from stable_baselines3.common import vec_env
+import torch as th
+from stable_baselines3.common import torch_layers, vec_env
 
 from imitation.algorithms import bc, dagger
-from imitation.policies import interactive
+from imitation.data import wrappers as data_wrappers
+from imitation.policies import base as policy_base
+from imitation.policies import interactive, obs_update_wrapper
+
+
+def lr_schedule(_: float):
+    # Set lr_schedule to max value to force error if policy.optimizer
+    # is used by mistake (should use self.optimizer instead).
+    return th.finfo(th.float32).max
+
 
 if __name__ == "__main__":
     rng = np.random.default_rng(0)
 
-    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)])
-    env.seed(0)
+    env = gym.make("PongNoFrameskip-v4", render_mode="rgb_array")
+    hr_env = data_wrappers.HumanReadableWrapper(env)
+    venv = vec_env.DummyVecEnv([lambda: hr_env])
+    venv.seed(0)
 
-    expert = interactive.AtariInteractivePolicy(env)
+    expert = interactive.AtariInteractivePolicy(venv)
+    policy = policy_base.FeedForward32Policy(
+        observation_space=env.observation_space,
+        action_space=env.action_space,
+        lr_schedule=lr_schedule,
+        features_extractor_class=torch_layers.FlattenExtractor,
+    )
+    wrapped_policy = obs_update_wrapper.RemoveHR(policy, lr_schedule=lr_schedule)
 
     bc_trainer = bc.BC(
         observation_space=env.observation_space,
        action_space=env.action_space,
+        policy=wrapped_policy,
         rng=rng,
     )
 
     with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
         dagger_trainer = dagger.SimpleDAggerTrainer(
-            venv=env,
+            venv=venv,
             scratch_dir=tmpdir,
             expert_policy=expert,
             bc_trainer=bc_trainer,
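A quick sanity check one could append to the script above (illustrative only, not part of the patch; it reuses the script's `hr_env`, `wrapped_policy`, and `data_wrappers` names and assumes the Pong ROM is installed): each observation from the wrapped env is a dict with the raw frame under the default key "ORI_OBS" and the rendered frame under HR_OBS_KEY, and RemoveHR strips the latter before the policy network sees it.

    obs, _ = hr_env.reset(seed=0)
    assert set(obs) == {"ORI_OBS", data_wrappers.HR_OBS_KEY}
    action, _ = wrapped_policy.predict(obs, deterministic=True)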
diff --git a/src/imitation/algorithms/dagger.py b/src/imitation/algorithms/dagger.py
index fb68713e6..1e859b19e 100644
--- a/src/imitation/algorithms/dagger.py
+++ b/src/imitation/algorithms/dagger.py
@@ -11,7 +11,7 @@
 import os
 import pathlib
 import uuid
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch as th
@@ -161,13 +161,16 @@ class InteractiveTrajectoryCollector(vec_env.VecEnvWrapper):
     """
 
     traj_accum: Optional[rollout.TrajectoryAccumulator]
-    _last_obs: Optional[np.ndarray]
+    _last_obs: Optional[Union[Dict[str, np.ndarray], np.ndarray]]
     _last_user_actions: Optional[np.ndarray]
 
     def __init__(
         self,
         venv: vec_env.VecEnv,
-        get_robot_acts: Callable[[np.ndarray], np.ndarray],
+        get_robot_acts: Callable[
+            [Union[Dict[str, np.ndarray], np.ndarray]],
+            np.ndarray,
+        ],
         beta: float,
         save_dir: types.AnyPath,
         rng: np.random.Generator,
@@ -213,7 +216,7 @@ def seed(self, seed: Optional[int] = None) -> List[Optional[int]]:
         self.rng = np.random.default_rng(seed=seed)
         return list(self.venv.seed(seed))
 
-    def reset(self) -> np.ndarray:
+    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
         """Resets the environment.
 
         Returns:
@@ -221,8 +224,12 @@ def reset(self) -> np.ndarray:
         """
         self.traj_accum = rollout.TrajectoryAccumulator()
         obs = self.venv.reset()
-        assert isinstance(obs, np.ndarray)
-        for i, ob in enumerate(obs):
+        assert isinstance(
+            obs,
+            (np.ndarray, dict),
+        ), "Tuple observations are not supported."
+        dictobs = types.maybe_wrap_in_dictobs(obs)
+        for i, ob in enumerate(dictobs):
             self.traj_accum.add_step({"obs": ob}, key=i)
         self._last_obs = obs
         self._is_reset = True
@@ -256,7 +263,9 @@ def step_async(self, actions: np.ndarray) -> None:
 
         mask = self.rng.uniform(0, 1, size=(self.num_envs,)) > self.beta
         if np.sum(mask) != 0:
-            actual_acts[mask] = self.get_robot_acts(self._last_obs[mask])
+            last_obs = types.maybe_wrap_in_dictobs(self._last_obs)
+            obs_for_robot = types.maybe_unwrap_dictobs(last_obs[mask])
+            actual_acts[mask] = self.get_robot_acts(obs_for_robot)
 
         self._last_user_actions = actions
         self.venv.step_async(actual_acts)
@@ -270,9 +279,13 @@ def step_wait(self) -> VecEnvStepReturn:
             Observation, reward, dones (is terminal?) and info dict.
         """
         next_obs, rews, dones, infos = self.venv.step_wait()
-        assert isinstance(next_obs, np.ndarray)
         assert self.traj_accum is not None
         assert self._last_user_actions is not None
+        assert isinstance(
+            next_obs,
+            (np.ndarray, dict),
+        ), "Tuple observations are not supported."
+
         self._last_obs = next_obs
         fresh_demos = self.traj_accum.add_steps_and_auto_finish(
             obs=next_obs,
@@ -508,7 +521,7 @@ def create_trajectory_collector(self) -> InteractiveTrajectoryCollector:
         beta = self.beta_schedule(self.round_num)
         collector = InteractiveTrajectoryCollector(
             venv=self.venv,
-            get_robot_acts=lambda acts: self.bc_trainer.policy.predict(acts)[0],
+            get_robot_acts=lambda obs: self.bc_trainer.policy.predict(obs)[0],
             beta=beta,
             save_dir=save_dir,
             rng=self.rng,
@@ -550,7 +563,7 @@ def save_trainer(self) -> Tuple[pathlib.Path, pathlib.Path]:
 
 
 class SimpleDAggerTrainer(DAggerTrainer):
-    """Simpler subclass of DAggerTrainer for training with synthetic feedback."""
+    """Simpler subclass of DAggerTrainer for training with feedback from an expert."""
 
     def __init__(
         self,
@@ -571,7 +584,7 @@ def __init__(
                 simultaneously for that timestep.
             scratch_dir: Directory to use to store intermediate training information
                 (e.g. for resuming training).
-            expert_policy: The expert policy used to generate synthetic demonstrations.
+            expert_policy: The expert policy used to generate demonstrations.
             rng: Random state to use for the random number generator.
             expert_trajs: Optional starting dataset that is inserted into the round 0
                 dataset.
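A minimal sketch of the beta-mixing path in `step_async` with dict observations (array shapes are made up for illustration): the raw observation is wrapped into a `DictObs`, a boolean mask selects the environments where the robot acts, and the result is unwrapped before being handed to `get_robot_acts`.

    import numpy as np

    from imitation.data import types

    last_obs = {"t": np.arange(4.0), "HR_OBS": np.zeros((4, 8, 8, 3))}
    mask = np.array([True, False, True, False])  # robot acts in envs 0 and 2

    wrapped = types.maybe_wrap_in_dictobs(last_obs)
    masked = wrapped[mask]  # boolean indexing, enabled by the DictObs change below
    assert masked.get("t").shape == (2,)
    obs_for_robot = types.maybe_unwrap_dictobs(masked)  # form accepted by policy.predict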
diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py
index 573176ffe..be6c5312a 100644
--- a/src/imitation/data/types.py
+++ b/src/imitation/data/types.py
@@ -94,7 +94,7 @@ def dict_len(self):
 
     def __getitem__(
         self,
-        key: Union[int, slice, Tuple[Union[int, slice], ...]],
+        key: Union[int, slice, Tuple[Union[int, slice], ...], np.ndarray],
     ) -> "DictObs":
         """Indexes or slices each array.
diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py
index 94c88111d..ee0148635 100644
--- a/src/imitation/data/wrappers.py
+++ b/src/imitation/data/wrappers.py
@@ -1,14 +1,18 @@
 """Environment wrappers for collecting rollouts."""
 
-from typing import List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import gymnasium as gym
 import numpy as np
 import numpy.typing as npt
+from gymnasium.core import Env
 from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper
 
 from imitation.data import rollout, types
 
+# The key for human readable data in the observation.
+HR_OBS_KEY = "HR_OBS"
+
 
 class BufferingWrapper(VecEnvWrapper):
     """Saves transitions of underlying VecEnv.
@@ -170,7 +174,7 @@ def pop_transitions(self) -> types.TransitionsWithRew:
 
 
 class RolloutInfoWrapper(gym.Wrapper):
-    """Add the entire episode's rewards and observations to `info` at episode end.
+    """Adds the entire episode's rewards and observations to `info` at episode end.
 
     Whenever done=True, `info["rollouts"]` is a dict with keys "obs" and "rews", whose
     corresponding values hold the NumPy arrays containing the raw observations and
@@ -206,3 +210,54 @@ def step(self, action):
                 "rews": np.stack(self._rews),
             }
         return obs, rew, terminated, truncated, info
+
+
+class HumanReadableWrapper(gym.ObservationWrapper):
+    """Adds a human-readable observation to `obs` at every step."""
+
+    def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"):
+        """Builds HumanReadableWrapper.
+
+        Args:
+            env: Environment to wrap.
+            original_obs_key: The key under which the original observation is stored
+                when it is not already in dict format.
+
+        Raises:
+            ValueError: If `env.render_mode` is not "rgb_array".
+        """
+        if env.render_mode != "rgb_array":
+            raise ValueError(
+                "HumanReadableWrapper requires render_mode='rgb_array', "
+                f"got {env.render_mode!r}",
+            )
+        self._original_obs_key = original_obs_key
+        super().__init__(env)
+
+    def observation(
+        self,
+        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> Dict[str, np.ndarray]:
+        """Adds a human-readable observation to obs.
+
+        Transforms obs into a dictionary if it is not one already, and adds the
+        human-readable observation from `env.render()` under the key HR_OBS_KEY.
+
+        Args:
+            obs: Observation from the environment.
+
+        Returns:
+            Observation dictionary with the human-readable data.
+
+        Raises:
+            KeyError: When the key HR_OBS_KEY already exists in the observation
+                dictionary.
+        """
+        if not isinstance(obs, dict):
+            obs = {self._original_obs_key: obs}
+
+        if HR_OBS_KEY in obs:
+            raise KeyError(f"{HR_OBS_KEY!r} already exists in observation dict")
+        obs[HR_OBS_KEY] = self.env.render()  # type: ignore[assignment]
+        return obs
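Usage sketch for the new wrapper (assumes any renderable Gymnasium env; CartPole-v1 with pygame installed is used here purely for illustration): a non-dict observation is moved under the default key "ORI_OBS" and the rendered RGB frame is attached under HR_OBS_KEY.

    import gymnasium as gym

    from imitation.data.wrappers import HR_OBS_KEY, HumanReadableWrapper

    env = HumanReadableWrapper(gym.make("CartPole-v1", render_mode="rgb_array"))
    obs, _ = env.reset(seed=0)
    assert set(obs) == {"ORI_OBS", HR_OBS_KEY}
    assert obs[HR_OBS_KEY].ndim == 3  # H x W x 3 RGB frame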
diff --git a/src/imitation/policies/interactive.py b/src/imitation/policies/interactive.py
index d9934f9a0..d3bc19035 100644
--- a/src/imitation/policies/interactive.py
+++ b/src/imitation/policies/interactive.py
@@ -11,6 +11,7 @@
 from stable_baselines3.common import vec_env
 
 import imitation.policies.base as base_policies
+from imitation.data import wrappers
 from imitation.util import util
 
 
@@ -64,9 +65,6 @@ def _choose_action(
         if self.clear_screen_on_query:
             util.clear_screen()
 
-        if isinstance(obs, dict):
-            raise ValueError("Dictionary observations are not supported here")
-
         context = self._render(obs)
         key = self._get_input_key()
         self._clean_up(context)
@@ -87,7 +85,10 @@ def _get_input_key(self) -> str:
         return key
 
     @abc.abstractmethod
-    def _render(self, obs: np.ndarray) -> Optional[object]:
+    def _render(
+        self,
+        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> Optional[object]:
         """Renders an observation, optionally returns a context for later cleanup."""
 
     def _clean_up(self, context: object) -> None:
@@ -97,7 +98,7 @@ def _clean_up(self, context: object) -> None:
 
 class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
     """DiscreteInteractivePolicy that renders image observations."""
 
-    def _render(self, obs: np.ndarray) -> plt.Figure:
+    def _render(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> plt.Figure:
         img = self._prepare_obs_image(obs)
 
         fig, ax = plt.subplots()
@@ -110,9 +111,16 @@ def _render(self, obs: np.ndarray) -> plt.Figure:
     def _clean_up(self, context: plt.Figure) -> None:
         plt.close(context)
 
-    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
+    def _prepare_obs_image(
+        self,
+        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> np.ndarray:
         """Applies any required observation processing to get an image to show."""
-        return obs
+        if not isinstance(obs, dict):
+            return obs
+        if wrappers.HR_OBS_KEY not in obs:
+            raise KeyError(f"Observation does not contain {wrappers.HR_OBS_KEY!r}")
+        return obs[wrappers.HR_OBS_KEY]
 
 
 ATARI_ACTION_NAMES_TO_KEYS = {
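Sketch of the observation handling that `ImageObsDiscreteInteractivePolicy._prepare_obs_image` now performs; `extract_display_image` is just an illustrative stand-in for the method, not part of the patch.

    import numpy as np

    from imitation.data import wrappers

    def extract_display_image(obs):
        # Plain arrays are shown as-is; dict observations must carry the rendered frame.
        if not isinstance(obs, dict):
            return obs
        if wrappers.HR_OBS_KEY not in obs:
            raise KeyError(f"Observation does not contain {wrappers.HR_OBS_KEY!r}")
        return obs[wrappers.HR_OBS_KEY]

    frame = np.zeros((210, 160, 3), dtype=np.uint8)
    assert extract_display_image(frame) is frame
    dict_obs = {"ORI_OBS": np.zeros(4), wrappers.HR_OBS_KEY: frame}
    assert extract_display_image(dict_obs) is frame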
diff --git a/src/imitation/policies/obs_update_wrapper.py b/src/imitation/policies/obs_update_wrapper.py
new file mode 100644
index 000000000..4c79e694a
--- /dev/null
+++ b/src/imitation/policies/obs_update_wrapper.py
@@ -0,0 +1,121 @@
+"""Policy wrappers that update observations before the policy uses them."""
+
+import abc
+from typing import Dict, Tuple, Union
+
+import numpy as np
+import torch as th
+from stable_baselines3.common.policies import ActorCriticPolicy
+from stable_baselines3.common.type_aliases import Schedule
+
+from imitation.data import wrappers as data_wrappers
+
+
+class Base(ActorCriticPolicy, abc.ABC):
+    """Updates the observation before the wrapped policy uses it."""
+
+    def __init__(self, policy: ActorCriticPolicy, lr_schedule: Schedule):
+        """Builds the wrapper base and initializes the policy.
+
+        Args:
+            policy: The policy to wrap.
+            lr_schedule: The learning rate schedule.
+        """
+        if policy.use_sde:
+            assert policy.dist_kwargs is not None
+            full_std = policy.dist_kwargs["full_std"]
+            use_expln = policy.dist_kwargs["use_expln"]
+        else:
+            full_std = True
+            use_expln = False
+        super().__init__(
+            observation_space=policy.observation_space,
+            action_space=policy.action_space,
+            lr_schedule=lr_schedule,
+            net_arch=policy.net_arch,
+            activation_fn=policy.activation_fn,
+            ortho_init=policy.ortho_init,
+            use_sde=policy.use_sde,
+            log_std_init=policy.log_std_init,
+            full_std=full_std,
+            use_expln=use_expln,
+            share_features_extractor=policy.share_features_extractor,
+            squash_output=policy.squash_output,
+            features_extractor_class=policy.features_extractor_class,
+            features_extractor_kwargs=policy.features_extractor_kwargs,
+            normalize_images=policy.normalize_images,
+            optimizer_class=policy.optimizer_class,
+            optimizer_kwargs=policy.optimizer_kwargs,
+        )
+
+    @abc.abstractmethod
+    def _update_ob(
+        self,
+        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+        """Updates the observation for the policy to use."""
+
+    def _predict(
+        self,
+        observation: th.Tensor,
+        deterministic: bool = False,
+    ) -> th.Tensor:
+        """Gets the action according to the policy for a given observation."""
+        return super()._predict(observation, deterministic)
+
+    def is_vectorized_observation(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> bool:
+        """Checks whether or not the observation is vectorized."""
+        observation = self._update_ob(observation)
+        return super().is_vectorized_observation(observation)
+
+    def obs_to_tensor(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> Tuple[th.Tensor, bool]:
+        """Converts an observation to a PyTorch tensor that can be fed to a model."""
+        observation = self._update_ob(observation)
+        return super().obs_to_tensor(observation)
+
+
+class RemoveHR(Base):
+    """Removes the human-readable observation before the policy uses it."""
+
+    def __init__(self, policy: ActorCriticPolicy, lr_schedule: Schedule):
+        """Builds the wrapper that removes the human-readable observation.
+
+        Args:
+            policy: The policy to wrap.
+            lr_schedule: The learning rate schedule.
+        """
+        super().__init__(policy, lr_schedule)
+
+    def _update_ob(
+        self,
+        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+        """Removes the human-readable observation, if present."""
+        return _remove_hr_obs(obs)
+
+
+def _remove_hr_obs(
+    obs: Union[np.ndarray, Dict[str, np.ndarray]],
+) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    """Removes the human-readable observation, if present."""
+    if not isinstance(obs, dict):
+        return obs
+    if data_wrappers.HR_OBS_KEY not in obs:
+        return obs
+    if len(obs) == 1:
+        raise ValueError(
+            "Only human readable observation exists, can't remove it",
+        )
+    # Keep the original observation unchanged in case it is used elsewhere.
+    new_obs = obs.copy()
+    del new_obs[data_wrappers.HR_OBS_KEY]
+    if len(new_obs) == 1:
+        # Unwrap the single remaining entry from the dict.
+        return next(iter(new_obs.values()))  # type: ignore[return-value]
+    return new_obs
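Behaviour sketch for `_remove_hr_obs` (mirrors the helper above): the human-readable entry is dropped, a single remaining key is unwrapped back to the raw array, and the input dict is left unmodified.

    import numpy as np

    from imitation.data.wrappers import HR_OBS_KEY
    from imitation.policies.obs_update_wrapper import _remove_hr_obs

    obs = {"t": np.array([1.0]), HR_OBS_KEY: np.zeros((64, 64, 3))}
    stripped = _remove_hr_obs(obs)
    assert isinstance(stripped, np.ndarray)  # unwrapped because only "t" remained
    assert HR_OBS_KEY in obs  # the original dict is not mutated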
diff --git a/tests/data/test_types.py b/tests/data/test_types.py
index 74c658c26..e16b7393a 100644
--- a/tests/data/test_types.py
+++ b/tests/data/test_types.py
@@ -519,3 +519,10 @@ def test_dict_obs():
 
     with pytest.raises(TypeError):
         types.DictObs({"a": "not an array"})  # type: ignore[wrong-arg-types]
+
+
+def test_dict_obs_indexing_bool():
+    # Boolean indexing requires the same length in the indexing dimension.
+    A, B = np.random.rand(4, 3), np.random.rand(4, 2)
+    d = types.DictObs({"a": A, "b": B})
+    np.testing.assert_equal(d[np.array([True, False, False, False])].get("a"), A[[0]])
diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py
index 33677c68f..8f430bbdb 100644
--- a/tests/data/test_wrappers.py
+++ b/tests/data/test_wrappers.py
@@ -1,6 +1,6 @@
 """Tests for `imitation.data.wrappers`."""
 
-from typing import List, Sequence, Type
+from typing import Dict, List, Optional, Sequence
 
 import gymnasium as gym
 import numpy as np
@@ -8,7 +8,7 @@
 from stable_baselines3.common.vec_env import DummyVecEnv
 
 from imitation.data import types
-from imitation.data.wrappers import BufferingWrapper
+from imitation.data.wrappers import HR_OBS_KEY, BufferingWrapper, HumanReadableWrapper
 
 
 class _CountingEnv(gym.Env):  # pragma: no cover
@@ -24,14 +24,19 @@ class _CountingEnv(gym.Env):  # pragma: no cover
     ```
     """
 
-    def __init__(self, episode_length=5):
+    def __init__(self, episode_length: int = 5, render_mode: Optional[str] = None):
         assert episode_length >= 1
         self.observation_space = gym.spaces.Box(low=0, high=np.inf, shape=())
         self.action_space = gym.spaces.Box(low=0, high=np.inf, shape=())
         self.episode_length = episode_length
-        self.timestep = None
+        self.timestep: Optional[int] = None
+        self._render_mode = render_mode
 
-    def reset(self, seed=None):
+    @property
+    def render_mode(self):
+        return self._render_mode
+
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         t, self.timestep = 0, 1
         return t, {}
 
@@ -47,19 +52,27 @@ def step(self, action):
         done = t == self.episode_length
         return t, t * 10, done, False, {}
 
+    def render(self):
+        if self._render_mode != "rgb_array":
+            raise ValueError(f"Invalid render mode {self._render_mode}")
+        return np.array([self.timestep] * 10)
+
 
 class _CountingDictEnv(_CountingEnv):  # pragma: no cover
     """Similar to _CountingEnv, but with Dict observation."""
 
-    def __init__(self, episode_length=5):
-        super().__init__(episode_length)
+    def __init__(self, episode_length: int = 5, render_mode: Optional[str] = None):
+        super().__init__(episode_length, render_mode)
         self.observation_space = gym.spaces.Dict(
-            spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())},
+            spaces={
+                "t": gym.spaces.Box(low=0, high=np.inf, shape=()),
+                "2t": gym.spaces.Box(low=0, high=np.inf, shape=()),
+            },
         )
 
-    def reset(self, seed=None):
-        t, self.timestep = 0.0, 1.0
-        return {"t": t}, {}
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
+        t, self.timestep = 0, 1
+        return {"t": t, "2t": 2 * t}, {}
 
     def step(self, action):
         if self.timestep is None:
@@ -71,16 +84,17 @@ def step(self, action):
 
         t, self.timestep = self.timestep, self.timestep + 1
         done = t == self.episode_length
-        return {"t": t}, t * 10, done, False, {}
+        return {"t": t, "2t": 2 * t}, t * 10, done, False, {}
 
 
 Envs = [_CountingEnv, _CountingDictEnv]
 
 
 def _make_buffering_venv(
-    Env: Type[gym.Env],
+    Env,
     error_on_premature_reset: bool,
 ) -> BufferingWrapper:
     venv = DummyVecEnv([Env] * 2)
     wrapped_venv = BufferingWrapper(venv, error_on_premature_reset)
     wrapped_venv.reset()
@@ -121,7 +135,7 @@ def concat(x):
 @pytest.mark.parametrize("n_steps", [1, 2, 20, 21])
 @pytest.mark.parametrize("extra_pop_timesteps", [(), (1,), (4, 8)])
 def test_pop(
-    Env: Type[gym.Env],
+    Env,
     episode_lengths: Sequence[int],
     n_steps: int,
     extra_pop_timesteps: Sequence[int],
@@ -151,6 +165,7 @@ def test_pop(
     ```
 
     Args:
+        Env: Environment class type.
         episode_lengths: The number of timesteps before episode end in each dummy
             environment.
@@ -223,20 +244,30 @@ def make_env(ep_len):
         actual_next_obs = types.DictObs.stack(trans.next_obs).get("t")
     _assert_equal_scrambled_vectors(actual_obs, expect_obs)
     _assert_equal_scrambled_vectors(actual_next_obs, expect_next_obs)
     _assert_equal_scrambled_vectors(trans.acts, expect_acts)
     _assert_equal_scrambled_vectors(trans.rews, expect_rews)
 
 
 @pytest.mark.parametrize("Env", Envs)
-def test_reset_error(Env: Type[gym.Env]):
+def test_reset_error(Env):
     # Resetting before a `step()` is okay.
     for flag in [True, False]:
         venv = _make_buffering_venv(Env, flag)
         for _ in range(10):
             venv.reset()
 
     # Resetting after a `step()` is not okay if error flag is True.
     venv = _make_buffering_venv(Env, True)
     zeros = np.array([0.0, 0.0], dtype=venv.action_space.dtype)
     venv.step(zeros)
     with pytest.raises(RuntimeError, match="before samples were accessed"):
@@ -264,7 +297,7 @@
 
 
 @pytest.mark.parametrize("Env", Envs)
-def test_n_transitions_and_empty_error(Env: Type[gym.Env]):
+def test_n_transitions_and_empty_error(Env):
     venv = _make_buffering_venv(Env, True)
     trajs, ep_lens = venv.pop_trajectories()
     assert trajs == []
@@ -278,3 +311,30 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]):
     assert venv.n_transitions == 0
     with pytest.raises(RuntimeError, match=".* empty .*"):
         venv.pop_transitions()
+
+
+@pytest.mark.parametrize("Env", Envs)
+@pytest.mark.parametrize("original_obs_key", ["k1", "k2"])
+def test_human_readable_wrapper(Env, original_obs_key):
+    num_obs_key_expected = 2 if Env is _CountingEnv else 3
+    origin_obs_key = original_obs_key if Env is _CountingEnv else "t"
+    ori_env = Env(render_mode="rgb_array")
+    hr_env = HumanReadableWrapper(
+        ori_env,
+        original_obs_key=original_obs_key,
+    )
+    assert hr_env.observation_space == ori_env.observation_space
+
+    obs, _ = hr_env.reset()
+    assert isinstance(obs, Dict)
+    assert HR_OBS_KEY in obs
+    assert len(obs) == num_obs_key_expected
+    assert obs[origin_obs_key] == 0
+    _assert_equal_scrambled_vectors(obs[HR_OBS_KEY], np.array([1] * 10))
+
+    next_obs, *_ = hr_env.step(hr_env.action_space.sample())
+    assert isinstance(next_obs, Dict)
+    assert HR_OBS_KEY in next_obs
+    assert len(next_obs) == num_obs_key_expected
+    assert next_obs[origin_obs_key] == 1
+    _assert_equal_scrambled_vectors(next_obs[HR_OBS_KEY], np.array([2] * 10))
diff --git a/tests/policies/test_interactive.py b/tests/policies/test_interactive.py
index 6326ef433..95824a424 100644
--- a/tests/policies/test_interactive.py
+++ b/tests/policies/test_interactive.py
@@ -1,7 +1,7 @@
 """Tests interactive policies."""
 
 import collections
-from typing import cast
+from typing import Dict, Union, cast
 from unittest import mock
 
 import gymnasium as gym
@@ -23,7 +23,7 @@
 class NoRenderingDiscreteInteractivePolicy(interactive.DiscreteInteractivePolicy):
     """DiscreteInteractivePolicy with no rendering."""
 
-    def _render(self, obs: np.ndarray) -> None:
+    def _render(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> None:
         pass
diff --git a/tests/policies/test_obs_update_wrapper.py b/tests/policies/test_obs_update_wrapper.py
new file mode 100644
index 000000000..2f18c7d4e
--- /dev/null
+++ b/tests/policies/test_obs_update_wrapper.py
@@ -0,0 +1,90 @@
+"""Tests for `imitation.policies.obs_update_wrapper`."""
+
+from typing import Dict
+
+import gymnasium as gym
+import numpy as np
+import pytest
+from stable_baselines3.common import torch_layers
+
+from imitation.data.wrappers import HR_OBS_KEY, HumanReadableWrapper
+from imitation.policies import base as policy_base
+from imitation.policies.obs_update_wrapper import RemoveHR, _remove_hr_obs
+
+
+@pytest.mark.parametrize("use_hr_wrapper", [True, False])
+def test_remove_hr(use_hr_wrapper: bool):
+    env = gym.make("PongNoFrameskip-v4", render_mode="rgb_array")
+    new_env = HumanReadableWrapper(env) if use_hr_wrapper else env
+    policy = policy_base.FeedForward32Policy(
+        observation_space=env.observation_space,
+        action_space=env.action_space,
+        lr_schedule=lambda _: 1.0,
+        features_extractor_class=torch_layers.FlattenExtractor,
+    )
+    wrapped_policy = RemoveHR(policy, lr_schedule=lambda _: 1.0)
+    assert wrapped_policy.net_arch == policy.net_arch
+
+    obs, _ = env.reset(seed=0)
+    pred_action, _ = wrapped_policy.predict(obs, deterministic=True)
+
+    new_obs, _ = new_env.reset(seed=0)
+    pred_action_with_hr, _ = wrapped_policy.predict(new_obs, deterministic=True)
+    assert np.equal(pred_action, pred_action_with_hr).all()
+
+
+@pytest.mark.parametrize(
+    ("testname", "obs", "expected_obs"),
+    [
+        (
+            "np.ndarray",
+            np.array([1]),
+            np.array([1]),
+        ),
+        (
+            "dict with np.ndarray",
+            {"a": np.array([1])},
+            {"a": np.array([1])},
+        ),
+        (
+            "dict hr removed successfully and got unwrapped from dict",
+            {
+                "a": np.array([1]),
+                HR_OBS_KEY: np.array([3]),
+            },
+            np.array([1]),
+        ),
+        (
+            "dict hr removed successfully and got dict",
+            {
+                "a": np.array([1]),
+                "b": np.array([2]),
+                HR_OBS_KEY: np.array([3]),
+            },
+            {
+                "a": np.array([1]),
+                "b": np.array([2]),
+            },
+        ),
+    ],
+)
+def test_remove_hr_obs(testname, obs, expected_obs):
+    got_obs = _remove_hr_obs(obs)
+    assert type(got_obs) is type(expected_obs)
+    if isinstance(got_obs, (Dict, gym.spaces.Dict)):
+        assert len(got_obs.keys()) == len(expected_obs.keys())
+        for k, v in got_obs.items():
+            assert v == expected_obs[k]
+    else:
+        assert got_obs == expected_obs
+
+
+def test_remove_hr_obs_failure():
+    with pytest.raises(ValueError, match="Only human readable observation"):
+        _remove_hr_obs({HR_OBS_KEY: np.array([1])})
+
+
+def test_remove_hr_obs_does_not_mutate_input():
+    obs = {"a": np.array([1]), HR_OBS_KEY: np.array([2])}
+    _remove_hr_obs(obs)
+    assert HR_OBS_KEY in obs
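End-to-end sketch tying the pieces together (mirrors `test_remove_hr` above; assumes the Atari ROM for Pong is installed): the wrapped policy predicts the same action whether or not the human-readable frame is attached.

    import gymnasium as gym
    import numpy as np
    from stable_baselines3.common import torch_layers

    from imitation.data.wrappers import HumanReadableWrapper
    from imitation.policies import base as policy_base
    from imitation.policies.obs_update_wrapper import RemoveHR

    env = gym.make("PongNoFrameskip-v4", render_mode="rgb_array")
    policy = policy_base.FeedForward32Policy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        lr_schedule=lambda _: 1.0,
        features_extractor_class=torch_layers.FlattenExtractor,
    )
    wrapped = RemoveHR(policy, lr_schedule=lambda _: 1.0)

    plain_obs, _ = env.reset(seed=0)
    hr_obs, _ = HumanReadableWrapper(env).reset(seed=0)
    act, _ = wrapped.predict(plain_obs, deterministic=True)
    hr_act, _ = wrapped.predict(hr_obs, deterministic=True)
    assert np.array_equal(act, hr_act)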