# meta_policy_search.samplers.meta_sampler
from meta_policy_search.samplers.base import Sampler
from meta_policy_search.samplers.vectorized_env_executor import MetaParallelEnvExecutor, MetaIterativeEnvExecutor
from meta_policy_search.utils import utils, logger
from collections import OrderedDict
from pyprind import ProgBar
import numpy as np
import time
import itertools


class MetaSampler(Sampler):
    """
    Sampler for Meta-RL

    Args:
        env (meta_policy_search.envs.base.MetaEnv) : environment object
        policy (meta_policy_search.policies.base.Policy) : policy object
        rollouts_per_meta_task (int) : number of rollouts (trajectories) collected per task
        meta_batch_size (int) : number of meta tasks
        max_path_length (int) : max number of steps per trajectory
        envs_per_task (int) : number of envs to run vectorized for each task (influences the memory usage)
        parallel (bool) : whether to use the parallel environment executor (MetaParallelEnvExecutor)
            instead of the iterative one (MetaIterativeEnvExecutor)
    """
    def __init__(
            self,
            env,
            policy,
            rollouts_per_meta_task,
            meta_batch_size,
            max_path_length,
            envs_per_task=None,
            parallel=False
            ):
        super(MetaSampler, self).__init__(env, policy, rollouts_per_meta_task, max_path_length)
        assert hasattr(env, 'set_task')

        self.envs_per_task = rollouts_per_meta_task if envs_per_task is None else envs_per_task
        self.meta_batch_size = meta_batch_size
        self.total_samples = meta_batch_size * rollouts_per_meta_task * max_path_length
        self.parallel = parallel
        self.total_timesteps_sampled = 0

        # setup vectorized environment
        if self.parallel:
            self.vec_env = MetaParallelEnvExecutor(env, self.meta_batch_size, self.envs_per_task, self.max_path_length)
        else:
            self.vec_env = MetaIterativeEnvExecutor(env, self.meta_batch_size, self.envs_per_task, self.max_path_length)
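
    # Note: one obtain_samples() call gathers
    # meta_batch_size * rollouts_per_meta_task * max_path_length transitions,
    # e.g. 40 tasks * 20 rollouts * 100 steps = 80,000 env steps (illustrative numbers).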

    def update_tasks(self):
        """
        Samples a new goal for each meta task
        """
        tasks = self.env.sample_tasks(self.meta_batch_size)
        assert len(tasks) == self.meta_batch_size
        self.vec_env.set_tasks(tasks)

    def obtain_samples(self, log=False, log_prefix=''):
        """
        Collect rollouts_per_meta_task trajectories from each of the meta_batch_size tasks

        Args:
            log (boolean): whether to log sampling times
            log_prefix (str) : prefix for logger

        Returns:
            (dict) : a dict mapping each task index (0 ... meta_batch_size - 1) to a list
                of paths; each path is a dict with the 5 keys observations, actions,
                rewards, env_infos and agent_infos, each of length up to max_path_length
        """
        # initial setup / preparation
        paths = OrderedDict()
        for i in range(self.meta_batch_size):
            paths[i] = []

        n_samples = 0
        running_paths = [_get_empty_running_paths_dict() for _ in range(self.vec_env.num_envs)]

        pbar = ProgBar(self.total_samples)
        policy_time, env_time = 0, 0

        policy = self.policy

        # initial reset of envs
        obses = self.vec_env.reset()

        while n_samples < self.total_samples:
            # execute policy
            t = time.time()
            obs_per_task = np.split(np.asarray(obses), self.meta_batch_size)
            actions, agent_infos = policy.get_actions(obs_per_task)
            policy_time += time.time() - t

            # step environments
            t = time.time()
            actions = np.concatenate(actions)  # stack meta batch
            next_obses, rewards, dones, env_infos = self.vec_env.step(actions)
            env_time += time.time() - t

            # stack agent_infos and if no infos were provided (--> None) create empty dicts
            agent_infos, env_infos = self._handle_info_dicts(agent_infos, env_infos)
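            # the flat env list is ordered task-major: envs [0, envs_per_task) belong to
            # task 0, the next envs_per_task to task 1, and so on -- hence the
            # idx // self.envs_per_task lookup below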
            new_samples = 0
            for idx, observation, action, reward, env_info, agent_info, done in zip(
                    itertools.count(), obses, actions, rewards, env_infos, agent_infos, dones):
                # append new samples to running paths
                running_paths[idx]["observations"].append(observation)
                running_paths[idx]["actions"].append(action)
                running_paths[idx]["rewards"].append(reward)
                running_paths[idx]["env_infos"].append(env_info)
                running_paths[idx]["agent_infos"].append(agent_info)

                # if running path is done, add it to paths and empty the running path
                if done:
                    paths[idx // self.envs_per_task].append(dict(
                        observations=np.asarray(running_paths[idx]["observations"]),
                        actions=np.asarray(running_paths[idx]["actions"]),
                        rewards=np.asarray(running_paths[idx]["rewards"]),
                        env_infos=utils.stack_tensor_dict_list(running_paths[idx]["env_infos"]),
                        agent_infos=utils.stack_tensor_dict_list(running_paths[idx]["agent_infos"]),
                    ))
                    new_samples += len(running_paths[idx]["rewards"])
                    running_paths[idx] = _get_empty_running_paths_dict()

            pbar.update(new_samples)
            n_samples += new_samples
            obses = next_obses
        pbar.stop()

        self.total_timesteps_sampled += self.total_samples
        if log:
            logger.logkv(log_prefix + "PolicyExecTime", policy_time)
            logger.logkv(log_prefix + "EnvExecTime", env_time)

        return paths

    def _handle_info_dicts(self, agent_infos, env_infos):
        if not env_infos:
            env_infos = [dict() for _ in range(self.vec_env.num_envs)]
        if not agent_infos:
            agent_infos = [dict() for _ in range(self.vec_env.num_envs)]
        else:
            assert len(agent_infos) == self.meta_batch_size
            assert len(agent_infos[0]) == self.envs_per_task
            agent_infos = sum(agent_infos, [])  # concatenate the per-task info lists into one flat list
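            # e.g. with meta_batch_size=2, envs_per_task=3 (illustrative):
            #   [[a0, a1, a2], [b0, b1, b2]] -> [a0, a1, a2, b0, b1, b2]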
        assert len(agent_infos) == self.meta_batch_size * self.envs_per_task == len(env_infos)
        return agent_infos, env_infos


def _get_empty_running_paths_dict():
    return dict(observations=[], actions=[], rewards=[], env_infos=[], agent_infos=[])
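

# ---------------------------------------------------------------------------
# Hedged sketch of the interfaces this sampler relies on. The Dummy* classes
# below are illustrative stand-ins, not part of this package: the env must
# provide sample_tasks/set_task/reset/step, and policy.get_actions must take a
# list of per-task observation batches and return (actions, agent_infos) with
# one entry per task (agent_infos as a list of envs_per_task dicts per task).
#
#   class DummyMetaEnv:
#       def sample_tasks(self, n_tasks):
#           return list(np.random.uniform(-1., 1., size=n_tasks))
#       def set_task(self, task):
#           self.goal = task
#       def reset(self):
#           return np.zeros(3)
#       def step(self, action):
#           reward = -abs(float(action[0]) - self.goal)
#           return np.zeros(3), reward, False, {}
#
#   class DummyPolicy:
#       def get_actions(self, obs_per_task):
#           actions = [np.random.randn(len(obs), 1) for obs in obs_per_task]
#           agent_infos = [[dict() for _ in obs] for obs in obs_per_task]
#           return actions, agent_infos
#
#   sampler = MetaSampler(DummyMetaEnv(), DummyPolicy(),
#                         rollouts_per_meta_task=2, meta_batch_size=4,
#                         max_path_length=10)
#   sampler.update_tasks()       # assigns a fresh goal to every task's envs
#   paths = sampler.obtain_samples()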