Source code for meta_policy_search.policies.meta_gaussian_mlp_policy

from meta_policy_search.policies.base import MetaPolicy
from meta_policy_search.policies.gaussian_mlp_policy import GaussianMLPPolicy
import numpy as np
import tensorflow as tf
from meta_policy_search.policies.networks.mlp import forward_mlp


[docs]class MetaGaussianMLPPolicy(GaussianMLPPolicy, MetaPolicy): def __init__(self, meta_batch_size, *args, **kwargs): self.quick_init(locals()) # store init arguments for serialization self.meta_batch_size = meta_batch_size self.pre_update_action_var = None self.pre_update_mean_var = None self.pre_update_log_std_var = None self.post_update_action_var = None self.post_update_mean_var = None self.post_update_log_std_var = None super(MetaGaussianMLPPolicy, self).__init__(*args, **kwargs)
[docs] def build_graph(self): """ Builds computational graph for policy """ # Create pre-update policy by calling build_graph of the super class super(MetaGaussianMLPPolicy, self).build_graph() self.pre_update_action_var = tf.split(self.action_var, self.meta_batch_size) self.pre_update_mean_var = tf.split(self.mean_var, self.meta_batch_size) self.pre_update_log_std_var = [self.log_std_var for _ in range(self.meta_batch_size)] # Create lightweight policy graph that takes the policy parameters as placeholders with tf.variable_scope(self.name + "_ph_graph"): mean_network_phs_meta_batch, log_std_network_phs_meta_batch = [], [] self.post_update_action_var = [] self.post_update_mean_var = [] self.post_update_log_std_var = [] # build meta_batch_size graphs for post-update policies --> thereby the policy parameters are placeholders obs_var_per_task = tf.split(self.obs_var, self.meta_batch_size, axis=0) for idx in range(self.meta_batch_size): with tf.variable_scope("task_%i" % idx): with tf.variable_scope("mean_network"): # create mean network parameter placeholders mean_network_phs = self._create_placeholders_for_vars( scope=self.name + "/mean_network") # -> returns ordered dict mean_network_phs_meta_batch.append(mean_network_phs) # forward pass through the mean mpl _, mean_var = forward_mlp(output_dim=self.action_dim, hidden_sizes=self.hidden_sizes, hidden_nonlinearity=self.hidden_nonlinearity, output_nonlinearity=self.output_nonlinearity, input_var=obs_var_per_task[idx], mlp_params=mean_network_phs, ) with tf.variable_scope("log_std_network"): # create log_stf parameter placeholders log_std_network_phs = self._create_placeholders_for_vars(scope=self.name + "/log_std_network") # -> returns ordered dict log_std_network_phs_meta_batch.append(log_std_network_phs) log_std_var = list(log_std_network_phs.values())[0] # weird stuff since log_std_network_phs is ordered dict action_var = mean_var + tf.random_normal(shape=tf.shape(mean_var)) * tf.exp(log_std_var) self.post_update_action_var.append(action_var) self.post_update_mean_var.append(mean_var) self.post_update_log_std_var.append(log_std_var) # merge mean_network_phs and log_std_network_phs into policies_params_phs self.policies_params_phs = [] for idx, odict in enumerate(mean_network_phs_meta_batch): # Mutate mean_network_ph here odict.update(log_std_network_phs_meta_batch[idx]) self.policies_params_phs.append(odict) self.policy_params_keys = list(self.policies_params_phs[0].keys())
[docs] def get_action(self, observation, task=0): """ Runs a single observation through the specified policy and samples an action Args: observation (ndarray) : single observation - shape: (obs_dim,) Returns: (ndarray) : single action - shape: (action_dim,) """ observation = np.repeat(np.expand_dims(np.expand_dims(observation, axis=0), axis=0), self.meta_batch_size, axis=0) action, agent_infos = self.get_actions(observation) action, agent_infos = action[task][0], dict(mean=agent_infos[task][0]['mean'], log_std=agent_infos[task][0]['log_std']) return action, agent_infos
[docs] def get_actions(self, observations): """ Args: observations (list): List of numpy arrays of shape (meta_batch_size, batch_size, obs_dim) Returns: (tuple) : A tuple containing a list of numpy arrays of action, and a list of list of dicts of agent infos """ assert len(observations) == self.meta_batch_size if self._pre_update_mode: actions, agent_infos = self._get_pre_update_actions(observations) else: actions, agent_infos = self._get_post_update_actions(observations) assert len(actions) == self.meta_batch_size return actions, agent_infos
def _get_pre_update_actions(self, observations): """ Args: observations (list): List of numpy arrays of shape (meta_batch_size, batch_size, obs_dim) """ batch_size = observations[0].shape[0] assert all([obs.shape[0] == batch_size for obs in observations]) assert len(observations) == self.meta_batch_size obs_stack = np.concatenate(observations, axis=0) feed_dict = {self.obs_var: obs_stack} sess = tf.get_default_session() actions, means, log_stds = sess.run([self.pre_update_action_var, self.pre_update_mean_var, self.pre_update_log_std_var], feed_dict=feed_dict) log_stds = np.concatenate(log_stds) # Get rid of fake batch size dimension (would be better to do this in tf, if we can match batch sizes) agent_infos = [[dict(mean=mean, log_std=log_stds[idx]) for mean in means[idx]] for idx in range(self.meta_batch_size)] return actions, agent_infos def _get_post_update_actions(self, observations): """ Args: observations (list): List of numpy arrays of shape (meta_batch_size, batch_size, obs_dim) """ assert self.policies_params_vals is not None obs_stack = np.concatenate(observations, axis=0) feed_dict = {self.obs_var: obs_stack} feed_dict.update(self.policies_params_feed_dict) sess = tf.get_default_session() actions, means, log_stds = sess.run([self.post_update_action_var, self.post_update_mean_var, self.post_update_log_std_var], feed_dict=feed_dict) log_stds = np.concatenate(log_stds) # Get rid of fake batch size dimension (would be better to do this in tf, if we can match batch sizes) agent_infos = [[dict(mean=mean, log_std=log_stds[idx]) for mean in means[idx]] for idx in range(self.meta_batch_size)] return actions, agent_infos