Source code for meta_policy_search.samplers.vectorized_env_executor

import numpy as np
import pickle as pickle
from multiprocessing import Process, Pipe
import copy


class MetaIterativeEnvExecutor(object):
    """
    Vectorized front-end over meta_batch_size * envs_per_task deep copies of a single environment.

    All copies live in the calling process and are stepped / reset one after another (no parallelism).
    set_tasks splits the copies into len(tasks) contiguous, equally sized groups and gives each group one task.

    Args:
        env (meta_policy_search.envs.base.MetaEnv): meta environment object
        meta_batch_size (int): number of meta tasks
        envs_per_task (int): number of environments per meta task
        max_path_length (int): maximum length of sampled environment paths - if the max_path_length is reached,
                               the respective environment is reset
    """

    def __init__(self, env, meta_batch_size, envs_per_task, max_path_length):
        total = meta_batch_size * envs_per_task
        clones = [copy.deepcopy(env) for _ in range(total)]
        self.envs = np.asarray(clones)
        # steps taken by each environment since its last reset
        self.ts = np.zeros(len(self.envs), dtype='int')
        self.max_path_length = max_path_length

    def step(self, actions):
        """
        Advance every wrapped environment by one step.

        Args:
            actions (list): one action per environment, i.e. meta_batch_size x envs_per_task entries

        Returns:
            (tuple): (obs, rewards, dones, env_infos). obs, rewards and env_infos are lists and dones is a boolean
            np.ndarray, each with one entry per environment. An environment that reported done or reached
            max_path_length is reset, and its obs entry is the observation returned by reset().
        """
        assert len(actions) == self.num_envs

        transitions = [env.step(action) for env, action in zip(self.envs, actions)]
        obs, rewards, env_dones, env_infos = (list(column) for column in zip(*transitions))

        self.ts += 1
        timed_out = self.ts >= self.max_path_length
        dones = np.logical_or(timed_out, np.asarray(env_dones))

        # restart finished episodes immediately so the next call continues from a fresh state
        for idx in np.flatnonzero(dones):
            obs[idx] = self.envs[idx].reset()
            self.ts[idx] = 0

        return obs, rewards, dones, env_infos

    def set_tasks(self, tasks):
        """
        Assign tasks to the environments, one task per contiguous group of environments.

        Args:
            tasks (list): tasks to assign; the number of environments must be divisible by len(tasks)
                          (np.split raises otherwise)
        """
        for task, group in zip(tasks, np.split(self.envs, len(tasks))):
            for env in group:
                env.set_task(task)

    def reset(self):
        """
        Reset every environment and clear the step counters.

        Returns:
            (list): the initial observation of each environment, in environment order
        """
        initial_obs = list(map(lambda env: env.reset(), self.envs))
        self.ts[:] = 0
        return initial_obs

    @property
    def num_envs(self):
        """
        Total number of wrapped environments.

        Returns:
            (int): meta_batch_size * envs_per_task
        """
        return self.envs.shape[0]
class MetaParallelEnvExecutor(object):
    """
    Wraps multiple environments of the same kind and provides functionality to reset / step the environments
    in a vectorized manner. Thereby the environments are distributed among meta_batch_size processes and
    executed in parallel.

    Each worker process holds envs_per_task unpickled copies of the environment and receives one task
    (see set_tasks). Call close() to shut the workers down.

    Args:
        env (meta_policy_search.envs.base.MetaEnv): meta environment object
        meta_batch_size (int): number of meta tasks
        envs_per_task (int): number of environments per meta task
        max_path_length (int): maximum length of sampled environment paths - if the max_path_length is reached,
                               the respective environment is reset
    """

    def __init__(self, env, meta_batch_size, envs_per_task, max_path_length):
        self.n_envs = meta_batch_size * envs_per_task
        self.meta_batch_size = meta_batch_size
        self.envs_per_task = envs_per_task
        self._closed = False
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(meta_batch_size)])
        # one distinct seed per worker; same values as choosing from range(10**6), without building that array
        seeds = np.random.choice(10 ** 6, size=meta_batch_size, replace=False)
        env_pickle = pickle.dumps(env)  # serialize once, shared by every worker
        # Each worker gets its own end of the pipe (work_remote) plus the parent's end (remote);
        # the latter is passed only so that the worker can close its own copy of it (see worker()).
        self.ps = [
            Process(target=worker, args=(work_remote, remote, env_pickle, envs_per_task, max_path_length, seed))
            for (work_remote, remote, seed) in zip(self.work_remotes, self.remotes, seeds)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        # the worker ends now belong to the child processes; drop the parent's copies
        for remote in self.work_remotes:
            remote.close()

    @staticmethod
    def _flatten(list_of_lists):
        """Concatenate the per-worker lists into one flat list (linear time, unlike sum(x, []))."""
        return [item for sublist in list_of_lists for item in sublist]

    def step(self, actions):
        """
        Executes actions on each env

        Args:
            actions (list): lists of actions, of length meta_batch_size x envs_per_task

        Returns
            (tuple): a length 4 tuple of lists, containing obs (np.array), rewards (float), dones (bool),
             env_infos (dict) each list is of length meta_batch_size x envs_per_task
             (assumes that every task has same number of envs)
        """
        assert len(actions) == self.num_envs

        # split list of actions in list of list of actions per meta task
        actions_per_meta_task = [actions[start: start + self.envs_per_task]
                                 for start in range(0, len(actions), self.envs_per_task)]

        # send every command before receiving any result so the workers step concurrently
        for remote, action_list in zip(self.remotes, actions_per_meta_task):
            remote.send(('step', action_list))

        results = [remote.recv() for remote in self.remotes]

        obs, rewards, dones, env_infos = (self._flatten(field) for field in zip(*results))
        return obs, rewards, dones, env_infos

    def reset(self):
        """
        Resets the environments of each worker

        Returns:
            (list): list of (np.ndarray) with the new initial observations.
        """
        for remote in self.remotes:
            remote.send(('reset', None))
        return self._flatten(remote.recv() for remote in self.remotes)

    def set_tasks(self, tasks=None):
        """
        Sets a list of tasks to each worker

        Args:
            tasks (list): one task per worker, i.e. at least meta_batch_size entries; surplus entries are ignored.
                          Despite the default, a value is required.

        Raises:
            ValueError: if fewer tasks than workers are given (previously this blocked forever waiting
                        for acknowledgements from workers that were never sent a task)
        """
        tasks = list(tasks)
        if len(tasks) < len(self.remotes):
            raise ValueError('set_tasks needs one task per worker: got %d tasks for %d workers'
                             % (len(tasks), len(self.remotes)))
        for remote, task in zip(self.remotes, tasks):
            remote.send(('set_task', task))
        for remote in self.remotes:
            remote.recv()

    def close(self):
        """
        Stop all worker processes and close the parent's ends of the pipes.

        Calling close() more than once is a no-op.
        """
        if self._closed:
            return
        for remote in self.remotes:
            try:
                remote.send(('close', None))
            except OSError:
                # best effort: the worker (or its pipe) is already gone, nothing left to tell it
                pass
        for p in self.ps:
            p.join()
        for remote in self.remotes:
            remote.close()
        self._closed = True

    @property
    def num_envs(self):
        """
        Number of environments

        Returns:
            (int): number of environments
        """
        return self.n_envs
def worker(remote, parent_remote, env_pickle, n_envs, max_path_length, seed):
    """
    Instantiation of a parallel worker for collecting samples. It loops continually checking the task that the
    remote sends to it.

    Commands arrive as (cmd, data) tuples:
        - 'step': data is a list of n_envs actions; replies with (obs, rewards, dones, infos), each a list.
          An environment that reported done or reached max_path_length is reset, its dones entry is set to
          True and its obs entry is replaced by the observation returned by reset().
        - 'reset': resets all environments; replies with the list of initial observations.
        - 'set_task': calls set_task(data) on every environment; replies with None.
        - 'close': ends the loop without replying.
    The loop also ends quietly when recv() raises EOFError (the other end of the pipe is gone).

    Args:
        remote (multiprocessing.Connection): this worker's end of the pipe (commands in, results out)
        parent_remote (multiprocessing.Connection): the parent's end of the pipe; only closed here
        env_pickle (pkl): pickled environment
        n_envs (int): number of environments per worker
        max_path_length (int): maximum path length of the task
        seed (int): random seed for the worker

    Raises:
        NotImplementedError: if an unknown command is received
    """
    # this process only communicates through `remote`; release its copy of the parent's end
    parent_remote.close()
    envs = [pickle.loads(env_pickle) for _ in range(n_envs)]
    np.random.seed(seed)
    ts = np.zeros(n_envs, dtype='int')

    try:
        while True:
            # receive command and data from the remote
            try:
                cmd, data = remote.recv()
            except EOFError:
                # no open copies of the other end remain: nothing left to serve, exit without a traceback
                break

            # do a step in each of the environment of the worker
            if cmd == 'step':
                all_results = [env.step(a) for (a, env) in zip(data, envs)]
                obs, rewards, dones, infos = map(list, zip(*all_results))
                ts += 1
                for i in range(n_envs):
                    if dones[i] or (ts[i] >= max_path_length):
                        dones[i] = True
                        obs[i] = envs[i].reset()
                        ts[i] = 0
                remote.send((obs, rewards, dones, infos))

            # reset all the environments of the worker
            elif cmd == 'reset':
                obs = [env.reset() for env in envs]
                ts[:] = 0
                remote.send(obs)

            # set the specified task for each of the environments of the worker
            elif cmd == 'set_task':
                for env in envs:
                    env.set_task(data)
                remote.send(None)

            # stop the worker
            elif cmd == 'close':
                break

            else:
                raise NotImplementedError('worker received unknown command: %r' % (cmd,))
    finally:
        # release this end of the pipe on every exit path, including errors
        remote.close()