From c964133e009776d08867df92932084af8a9c66db Mon Sep 17 00:00:00 2001 From: Giovanni Fregonese Date: Fri, 29 Mar 2024 18:02:13 +0100 Subject: [PATCH 1/6] Convert to gymnasium --- brax/envs/wrappers/gym.py | 6 +++--- brax/v1/envs/__init__.py | 2 +- brax/v1/envs/wrappers.py | 6 +++--- setup.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index 69ff2e9b5..ea1bb3276 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -17,9 +17,9 @@ from brax.envs.base import PipelineEnv from brax.io import image -import gym -from gym import spaces -from gym.vector import utils +import gymnasium as gym +from gymnasium import spaces +from gymnasium.vector import utils import jax import numpy as np diff --git a/brax/v1/envs/__init__.py b/brax/v1/envs/__init__.py index b5240ab94..6745a4c8f 100644 --- a/brax/v1/envs/__init__.py +++ b/brax/v1/envs/__init__.py @@ -39,7 +39,7 @@ from brax.v1.envs import walker2d from brax.v1.envs import wrappers from brax.v1.envs.env import Env, State, Wrapper -import gym +import gymnasium as gym _envs = { 'acrobot': acrobot.Acrobot, diff --git a/brax/v1/envs/wrappers.py b/brax/v1/envs/wrappers.py index 6326c103f..3e20831ed 100644 --- a/brax/v1/envs/wrappers.py +++ b/brax/v1/envs/wrappers.py @@ -22,9 +22,9 @@ import dm_env from dm_env import specs import flax -import gym -from gym import spaces -from gym.vector import utils +import gymnasium as gym +from gymnasium import spaces +from gymnasium.vector import utils import jax import jax.numpy as jnp diff --git a/setup.py b/setup.py index 13bd2c164..975a31eca 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ "flax", # TODO: remove grpcio and gym after dropping legacy v1 code "grpcio", - "gym", + "gymnasium", "jax>=0.4.6", "jaxlib>=0.4.6", "jaxopt", From a7b6870bc36ffb3ab0f345e2bf9deaf79c065004 Mon Sep 17 00:00:00 2001 From: Giovanni Fregonese Date: Tue, 14 May 2024 11:15:12 +0200 Subject: [PATCH 2/6] Update gymnasium import to gym in brax code, as extra alternatives --- brax/envs/wrappers/gym.py | 11 ++++++++--- brax/v1/envs/__init__.py | 5 ++++- brax/v1/envs/wrappers.py | 11 ++++++++--- setup.py | 3 ++- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index ea1bb3276..d89a1e333 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -17,9 +17,14 @@ from brax.envs.base import PipelineEnv from brax.io import image -import gymnasium as gym -from gymnasium import spaces -from gymnasium.vector import utils +try: + import gym + from gym import spaces + from gym.vector import utils +except ImportError: + from gymnasium import gym + from gymnasium import spaces + from gymnasium.vector import utils import jax import numpy as np diff --git a/brax/v1/envs/__init__.py b/brax/v1/envs/__init__.py index 6745a4c8f..aa67d3732 100644 --- a/brax/v1/envs/__init__.py +++ b/brax/v1/envs/__init__.py @@ -39,7 +39,10 @@ from brax.v1.envs import walker2d from brax.v1.envs import wrappers from brax.v1.envs.env import Env, State, Wrapper -import gymnasium as gym +try: + import gym +except ImportError: + import gymnasium as gym _envs = { 'acrobot': acrobot.Acrobot, diff --git a/brax/v1/envs/wrappers.py b/brax/v1/envs/wrappers.py index 3e20831ed..ac5bcd019 100644 --- a/brax/v1/envs/wrappers.py +++ b/brax/v1/envs/wrappers.py @@ -22,9 +22,14 @@ import dm_env from dm_env import specs import flax -import gymnasium as gym -from gymnasium import spaces -from gymnasium.vector import utils +try: + import gym + from gym import spaces + from gym.vector import utils +except ImportError: + from gymnasium import gym + from gymnasium import spaces + from gymnasium.vector import utils import jax import jax.numpy as jnp diff --git a/setup.py b/setup.py index 975a31eca..83418310c 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,6 @@ "flax", # TODO: remove grpcio and gym after dropping legacy v1 code "grpcio", - "gymnasium", "jax>=0.4.6", "jaxlib>=0.4.6", "jaxopt", @@ -66,6 +65,8 @@ ], extras_require={ "develop": ["pytest", "transforms3d"], + "gym": ["gym"], + "gymnasium": ["gymnasium"], }, classifiers=[ "Development Status :: 4 - Beta", From 7b0bc0ed1dcbc27a63dec1b5753c4a783ba13764 Mon Sep 17 00:00:00 2001 From: Giovanni Fregonese Date: Tue, 14 May 2024 11:56:04 +0200 Subject: [PATCH 3/6] Update gymnasium import to gym in brax code --- brax/envs/wrappers/gym.py | 4 +++- brax/v1/envs/wrappers.py | 7 +++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index d89a1e333..5ac3a097f 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -17,14 +17,16 @@ from brax.envs.base import PipelineEnv from brax.io import image + try: import gym from gym import spaces from gym.vector import utils except ImportError: - from gymnasium import gym + import gymnasium as gym from gymnasium import spaces from gymnasium.vector import utils + import jax import numpy as np diff --git a/brax/v1/envs/wrappers.py b/brax/v1/envs/wrappers.py index ac5bcd019..ffde56fd7 100644 --- a/brax/v1/envs/wrappers.py +++ b/brax/v1/envs/wrappers.py @@ -19,17 +19,16 @@ from brax.v1 import jumpy as jp from brax.v1.envs import env as brax_env -import dm_env -from dm_env import specs -import flax + try: import gym from gym import spaces from gym.vector import utils except ImportError: - from gymnasium import gym + import gymnasium as gym from gymnasium import spaces from gymnasium.vector import utils + import jax import jax.numpy as jnp From d5d46f2a40fe7a69cb1ce0d8039455a50c20baa0 Mon Sep 17 00:00:00 2001 From: Giovanni Fregonese Date: Fri, 7 Jun 2024 18:17:12 +0200 Subject: [PATCH 4/6] Remove gym dependency and add gymnasium --- brax/envs/wrappers/gym.py | 11 +++-------- setup.py | 3 +-- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index 5ac3a097f..5b0ca29b7 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -18,14 +18,9 @@ from brax.envs.base import PipelineEnv from brax.io import image -try: - import gym - from gym import spaces - from gym.vector import utils -except ImportError: - import gymnasium as gym - from gymnasium import spaces - from gymnasium.vector import utils +import gymnasium as gym +from gymnasium import spaces +from gymnasium.vector import utils import jax import numpy as np diff --git a/setup.py b/setup.py index 83418310c..520523a7e 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ "flask_cors", "flax", # TODO: remove grpcio and gym after dropping legacy v1 code + "gymnasium", "grpcio", "jax>=0.4.6", "jaxlib>=0.4.6", @@ -65,8 +66,6 @@ ], extras_require={ "develop": ["pytest", "transforms3d"], - "gym": ["gym"], - "gymnasium": ["gymnasium"], }, classifiers=[ "Development Status :: 4 - Beta", From 1db22fbd97540220cec60469d58e1475ea47a033 Mon Sep 17 00:00:00 2001 From: Giovanni Fregonese Date: Fri, 7 Jun 2024 18:17:38 +0200 Subject: [PATCH 5/6] Revert v1 code --- brax/v1/envs/__init__.py | 5 +---- brax/v1/envs/wrappers.py | 11 +++-------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/brax/v1/envs/__init__.py b/brax/v1/envs/__init__.py index aa67d3732..b5240ab94 100644 --- a/brax/v1/envs/__init__.py +++ b/brax/v1/envs/__init__.py @@ -39,10 +39,7 @@ from brax.v1.envs import walker2d from brax.v1.envs import wrappers from brax.v1.envs.env import Env, State, Wrapper -try: - import gym -except ImportError: - import gymnasium as gym +import gym _envs = { 'acrobot': acrobot.Acrobot, diff --git a/brax/v1/envs/wrappers.py b/brax/v1/envs/wrappers.py index ffde56fd7..dffb065dd 100644 --- a/brax/v1/envs/wrappers.py +++ b/brax/v1/envs/wrappers.py @@ -20,14 +20,9 @@ from brax.v1 import jumpy as jp from brax.v1.envs import env as brax_env -try: - import gym - from gym import spaces - from gym.vector import utils -except ImportError: - import gymnasium as gym - from gymnasium import spaces - from gymnasium.vector import utils +import gym +from gym import spaces +from gym.vector import utils import jax import jax.numpy as jnp From c5bb67f40a4a3c87a5cabc9af78119ab5b006894 Mon Sep 17 00:00:00 2001 From: Giovanni Fregonese Date: Fri, 7 Jun 2024 18:18:06 +0200 Subject: [PATCH 6/6] Adapt to new gymnasium step api --- brax/envs/wrappers/gym.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index 5b0ca29b7..c1ed97c55 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -74,7 +74,7 @@ def reset(self): def step(self, action): self._state, obs, reward, done, info = self._step(self._state, action) # We return device arrays for pytorch users. - return obs, reward, done, info + return gym.utils.step_api_compatibility.convert_to_terminated_truncated_step_api((obs, reward, done, info)) def seed(self, seed: int = 0): self._key = jax.random.PRNGKey(seed) @@ -131,7 +131,7 @@ def reset(key): def step(state, action): state = self._env.step(state, action) info = {**state.metrics, **state.info} - return state, state.obs, state.reward, state.done, info + return gym.utils.step_api_compatibility.convert_to_terminated_truncated_step_api((state, state.obs, state.reward, state.done, info), is_vector_env=True) self._step = jax.jit(step, backend=self.backend)