from typing import Optional
import gymnasium as gym
import numpy as np
[docs]
class PyTorchObsWrapper(gym.ObservationWrapper):
    """
    Reverse the axis order of 3-D image observations for PyTorch.

    Each observation is transposed with axis order ``(2, 1, 0)`` and the
    wrapper's observation space is rebuilt with the correspondingly reversed
    shape. Presumably the wrapped environment emits channel-last images and
    the result is channel-first -- confirm against the wrapped environment.
    """

    def __init__(self, env):
        """
        Wrap ``env`` and declare the transposed observation space.

        The scalar bounds of the new space are taken from element
        ``[0, 0, 0]`` of the wrapped space's ``low``/``high`` arrays, so the
        bounds are assumed to be uniform across the whole image.
        """
        super().__init__(env)

        wrapped_space = self.observation_space
        dim0, dim1, dim2 = wrapped_space.shape[:3]

        self.observation_space = gym.spaces.Box(
            low=wrapped_space.low[0, 0, 0],
            high=wrapped_space.high[0, 0, 0],
            shape=[dim2, dim1, dim0],
            dtype=wrapped_space.dtype,
        )

    def observation(self, observation):
        """Return ``observation`` with its three axes in reversed order."""
        reversed_axes = (2, 1, 0)
        return observation.transpose(reversed_axes)
[docs]
class GreyscaleWrapper(gym.ObservationWrapper):
    """
    Convert image observations from RGB to greyscale.

    The three colour channels (last axis) are combined with the weights
    0.30 / 0.59 / 0.11 into a single channel, giving observations of shape
    ``(dim0, dim1, 1)``. The returned array uses the dtype declared by the
    wrapper's observation space, so observations remain members of it.
    """

    def __init__(self, env):
        """
        Wrap ``env`` and declare the single-channel observation space.

        The scalar bounds are taken from element ``[0, 0, 0]`` of the wrapped
        space's ``low``/``high`` arrays (bounds are assumed to be uniform),
        and the dtype of the wrapped space is kept.
        """
        super().__init__(env)
        obs_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(
            self.observation_space.low[0, 0, 0],
            self.observation_space.high[0, 0, 0],
            (obs_shape[0], obs_shape[1], 1),
            dtype=self.observation_space.dtype,
        )

    def observation(self, obs):
        """
        Return the greyscale version of ``obs`` with a trailing channel axis.

        The weighted sum is computed in floating point. When the declared
        observation space has an integer dtype the result is rounded to the
        nearest integer before casting (a plain cast would truncate, e.g.
        254.999... -> 254); for floating-point spaces the cast only
        normalises the precision.
        """
        grey = 0.30 * obs[:, :, 0] + 0.59 * obs[:, :, 1] + 0.11 * obs[:, :, 2]

        # The weights sum to 1, so the result stays within the original
        # value range and needs no clipping; only the dtype must be restored
        # to match the space declared in __init__.
        target_dtype = self.observation_space.dtype
        if np.issubdtype(target_dtype, np.integer):
            grey = np.rint(grey)

        return np.expand_dims(grey, axis=2).astype(target_dtype, copy=False)
[docs]
class StochasticActionWrapper(gym.ActionWrapper):
    """
    Add stochasticity to the actions.

    With probability ``prob`` the agent's action is passed through unchanged.
    Otherwise (probability ``1 - prob``) it is replaced: by ``random_action``
    if one was provided, else by an action drawn uniformly from the wrapped
    environment's action space.
    """

    # Range used by the original implementation (``integers(0, 6)``); kept
    # only as a fallback for action spaces that are not ``Discrete``.
    _FALLBACK_NUM_ACTIONS = 6

    def __init__(self, env, prob: float = 0.9, random_action: Optional[int] = None):
        """
        Args:
            env: Environment to wrap.
            prob: Probability of keeping the action chosen by the agent.
            random_action: Fixed replacement action; when ``None`` the
                replacement is sampled at random instead.
        """
        super().__init__(env)
        self.prob = prob
        self.random_action = random_action

    def action(self, action):
        """
        Return ``action`` itself or a replacement, chosen at random.

        All randomness comes from ``self.np_random`` so the behaviour is
        reproducible under the environment's seed.
        """
        if self.np_random.uniform() < self.prob:
            return action

        if self.random_action is not None:
            return self.random_action

        space = self.action_space
        if isinstance(space, gym.spaces.Discrete):
            # Sample over the real action range rather than a hard-coded 0..5,
            # which is invalid for smaller spaces and incomplete for larger
            # ones. ``start`` is absent on older Discrete implementations.
            start = int(getattr(space, "start", 0))
            return self.np_random.integers(start, start + int(space.n))

        # Non-Discrete action space: preserve the legacy behaviour.
        return self.np_random.integers(0, self._FALLBACK_NUM_ACTIONS)