# Source code for tianshou.env.worker.ray

import contextlib
from collections.abc import Callable
from typing import Any

import gymnasium as gym
import numpy as np

from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type
from tianshou.env.worker import EnvWorker

with contextlib.suppress(ImportError):
    import ray


# mypy: disable-error-code="unused-ignore"


class _SetAttrWrapper(gym.Wrapper):
    """Gymnasium wrapper that turns attribute access into plain method calls.

    ``RayEnvWorker`` uses this class as its Ray actor. Attribute reads and
    writes on the hosted environment can then be issued as remote method
    calls (``get_env_attr.remote`` / ``set_env_attr.remote``).
    """

    def set_env_attr(self, key: str, value: Any) -> None:
        """Assign ``value`` to attribute ``key`` on the innermost env.

        NOTE(review): writes target ``self.env.unwrapped``, whereas
        ``get_env_attr`` reads from ``self.env``. The two are asymmetric when
        the env is itself wrapped; confirm this is intended.
        """
        innermost = self.env.unwrapped
        setattr(innermost, key, value)

    def get_env_attr(self, key: str) -> Any:
        """Return attribute ``key`` as looked up on the directly wrapped env."""
        wrapped = self.env
        return getattr(wrapped, key)


class RayEnvWorker(EnvWorker):
    """Ray worker used in RayVectorEnv.

    The environment produced by ``env_fn`` is wrapped in ``_SetAttrWrapper``
    and hosted in a Ray actor created with ``num_cpus=0``. Every method of this
    worker forwards to that actor through ``.remote(...)`` calls.
    """

    def __init__(
        self,
        env_fn: Callable[[], ENV_TYPE],
    ) -> None:
        # TODO: is ENV_TYPE actually correct?
        # env_fn() is evaluated here, in the calling process, and the resulting
        # env object is handed to the actor constructor.
        # self.env is assigned before the base initializer runs; presumably the
        # base class queries the env (e.g. via get_env_attr), so the actor must
        # already exist. TODO confirm against EnvWorker.__init__.
        self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn())  # type: ignore
        super().__init__(env_fn)

    def get_env_attr(self, key: str) -> Any:
        """Fetch attribute ``key`` from the remote env, blocking until it arrives."""
        return ray.get(self.env.get_env_attr.remote(key))

    def set_env_attr(self, key: str, value: Any) -> None:
        """Set attribute ``key`` on the remote env, blocking until it is applied."""
        ray.get(self.env.set_env_attr.remote(key, value))

    def reset(self, **kwargs: Any) -> Any:
        """Reset the remote env synchronously and return its reset output.

        If ``seed`` is among ``kwargs``, the base-class seeding is applied
        locally before the remote reset is issued.
        """
        if "seed" in kwargs:
            super().seed(kwargs["seed"])
        return ray.get(self.env.reset.remote(**kwargs))

    @staticmethod
    def wait(  # type: ignore
        workers: list["RayEnvWorker"],
        wait_num: int,
        timeout: float | None = None,
    ) -> list["RayEnvWorker"]:
        """Return the workers whose pending result (``worker.result``) is ready.

        :param workers: workers that each hold a pending ``ray.ObjectRef`` in
            ``result`` (set by :meth:`send`).
        :param wait_num: how many ready workers to wait for. Clamped to
            ``len(workers)``, because ``ray.wait`` rejects a ``num_returns``
            larger than the number of refs it is given.
        :param timeout: maximum seconds to wait; ``None`` waits indefinitely.
        :return: the ready workers, in the order ``ray.wait`` reports them.
        """
        # Map each pending ref straight back to its worker (avoids list.index).
        ref_to_worker = {worker.result: worker for worker in workers}
        ready_refs, _ = ray.wait(
            list(ref_to_worker),
            num_returns=min(wait_num, len(ref_to_worker)),
            timeout=timeout,
        )
        return [ref_to_worker[ref] for ref in ready_refs]

    def send(self, action: np.ndarray | None, **kwargs: Any) -> None:
        """Dispatch a step, or a reset when ``action`` is None, without blocking.

        The returned ``ray.ObjectRef`` handle is stored in ``self.result``. It
        is later resolved by :meth:`recv` or polled by :meth:`wait`.
        """
        if action is None:
            # Keep parity with reset(): seed locally when a seed is supplied.
            if "seed" in kwargs:
                super().seed(kwargs["seed"])
            self.result = self.env.reset.remote(**kwargs)
        else:
            self.result = self.env.step.remote(action)

    def recv(self) -> gym_new_venv_step_type:
        """Block until the call dispatched by :meth:`send` finishes; return its output.

        NOTE(review): if the last send() was a reset, this returns the reset
        output rather than a step tuple, despite the annotation.
        """
        return ray.get(self.result)  # type: ignore

    def seed(self, seed: int | None = None) -> list[int] | None:
        """Seed the worker locally and the remote env.

        First tries the env's own ``seed`` method through the actor. If that is
        unavailable (AttributeError) or raises NotImplementedError, falls back
        to ``reset(seed=seed)`` and returns None.

        NOTE(review): a NotImplementedError raised remotely is re-raised by
        ``ray.get`` wrapped in a Ray task error; confirm it still matches the
        except clause below.
        """
        super().seed(seed)
        try:
            return ray.get(self.env.seed.remote(seed))
        except (AttributeError, NotImplementedError):
            # Await the fallback reset so a failure surfaces here. Otherwise it
            # would be lost along with the discarded ObjectRef.
            ray.get(self.env.reset.remote(seed=seed))
            return None

    def render(self, **kwargs: Any) -> Any:
        """Render the remote env, blocking until the result is available."""
        return ray.get(self.env.render.remote(**kwargs))

    def close_env(self) -> None:
        """Close the remote env, blocking until the close call completes."""
        ray.get(self.env.close.remote())