Source code for meta_policy_search.policies.base

from meta_policy_search.utils.utils import remove_scope_from_name
from meta_policy_search.utils import Serializable
import tensorflow as tf
from collections import OrderedDict


class Policy(Serializable):
    """
    A container for storing the current pre- and post-update policies.
    Also provides functions for executing the policy and updating its parameters.

    Note:
        The pre-update policy is stored as tf.Variables, while the post-update
        policy is stored in numpy arrays and executed through tf.placeholders.

    Args:
        obs_dim (int): dimensionality of the observation space -> specifies the input size of the policy
        action_dim (int): dimensionality of the action space -> specifies the output size of the policy
        name (str): name used for scoping variables in the policy
        hidden_sizes (tuple): sizes of the hidden layers of the network
        learn_std (bool): whether to learn the variance of the network output
        hidden_nonlinearity (Operation): nonlinearity used between hidden layers of the network
        output_nonlinearity (Operation): nonlinearity used after the final layer of the network
    """
    def __init__(self,
                 obs_dim,
                 action_dim,
                 name='policy',
                 hidden_sizes=(32, 32),
                 learn_std=True,
                 hidden_nonlinearity=tf.tanh,
                 output_nonlinearity=None,
                 **kwargs
                 ):
        Serializable.quick_init(self, locals())

        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.name = name
        self.hidden_sizes = hidden_sizes
        self.learn_std = learn_std
        self.hidden_nonlinearity = hidden_nonlinearity
        self.output_nonlinearity = output_nonlinearity

        self._dist = None
        self.policy_params = None
        self._assign_ops = None
        self._assign_phs = None

    def build_graph(self):
        """
        Builds the computational graph for the policy
        """
        raise NotImplementedError

    def get_action(self, observation):
        """
        Runs a single observation through the specified policy

        Args:
            observation (array): single observation

        Returns:
            (array): single action
        """
        raise NotImplementedError

    def get_actions(self, observations):
        """
        Runs each set of observations through each task-specific policy

        Args:
            observations (array): array of arrays of observations generated by each task and env

        Returns:
            (tuple): array of arrays of actions for each env of shape
                     (meta_batch_size) x (batch_size) x (action_dim),
                     and array of arrays of agent_info dicts
        """
        raise NotImplementedError

    def reset(self, dones=None):
        pass

    def log_diagnostics(self, paths):
        """
        Log extra information per iteration based on the collected paths
        """
        pass

    @property
    def distribution(self):
        """
        Returns this policy's distribution

        Returns:
            (Distribution): this policy's distribution
        """
        raise NotImplementedError

    def distribution_info_sym(self, obs_var, params=None):
        """
        Returns the symbolic distribution information about the actions.

        Args:
            obs_var (placeholder): symbolic variable for observations
            params (None or dict): a dictionary of placeholders that contains information about the
                                   state of the policy at the time it received the observation

        Returns:
            (dict): a dictionary of tf placeholders for the policy output distribution
        """
        raise NotImplementedError

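    # For illustration only (an assumption, not part of this base class): a
    # Gaussian subclass would typically return a dict of distribution-parameter
    # tensors, e.g.
    #
    #   dist_info = policy.distribution_info_sym(obs_ph)
    #   # dist_info -> {'mean': <tf.Tensor>, 'log_std': <tf.Tensor>}
    #
    # where 'mean' and 'log_std' parameterize a diagonal Gaussian over actions.
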
    def distribution_info_keys(self, obs, state_infos):
        """
        Args:
            obs (placeholder): symbolic variable for observations
            state_infos (dict): a dictionary of placeholders that contains information about the
                                state of the policy at the time it received the observation

        Returns:
            (dict): a dictionary of tf placeholders for the policy output distribution
        """
        raise NotImplementedError

    def likelihood_ratio_sym(self, obs, action, dist_info_old, policy_params):
        """
        Computes the likelihood ratio p_new(action|obs) / p_old(action|obs)
        between the new and the old policy

        Args:
            obs (tf.Tensor): symbolic variable for observations
            action (tf.Tensor): symbolic variable for actions
            dist_info_old (dict): dictionary of tf.placeholders with old policy information
            policy_params (dict): dictionary of the policy parameters (each value is a tf.Tensor)

        Returns:
            (tf.Tensor): likelihood ratio
        """
        distribution_info_new = self.distribution_info_sym(obs, params=policy_params)
        likelihood_ratio = self._dist.likelihood_ratio_sym(action, dist_info_old, distribution_info_new)
        return likelihood_ratio

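    # A minimal sketch of how this ratio is typically used to build a surrogate
    # objective (the placeholder names and `advantage_ph` are assumptions for
    # illustration, not part of this class):
    #
    #   ratio = policy.likelihood_ratio_sym(obs_ph, action_ph, dist_info_old_phs, policy.policy_params)
    #   surr_obj = -tf.reduce_mean(ratio * advantage_ph)
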
    def log_likelihood_sym(self, obs, action, policy_params):
        """
        Computes the log likelihood log p(action|obs)

        Args:
            obs (tf.Tensor): symbolic variable for observations
            action (tf.Tensor): symbolic variable for actions
            policy_params (dict): dictionary of the policy parameters (each value is a tf.Tensor)

        Returns:
            (tf.Tensor): log likelihood
        """
        distribution_info_var = self.distribution_info_sym(obs, params=policy_params)
        log_likelihood = self._dist.log_likelihood_sym(action, distribution_info_var)
        return log_likelihood

""" --- methods for serialization --- """
    def get_params(self):
        """
        Gets the tf.Variables representing the trainable weights of the network (symbolic)

        Returns:
            (dict): a dict of all trainable Variables
        """
        return self.policy_params

    def get_param_values(self):
        """
        Gets the current numeric values of all the weights in the network

        Returns:
            (dict): dictionary of variable names mapped to their current values
        """
        param_values = tf.get_default_session().run(self.policy_params)
        return param_values

    def set_params(self, policy_params):
        """
        Sets the parameters of the graph

        Args:
            policy_params (dict): dictionary of variable names and corresponding parameter values
        """
        assert all([k1 == k2 for k1, k2 in zip(self.get_params().keys(), policy_params.keys())]), \
            "parameter keys must match the variable names"

        if self._assign_ops is None:
            assign_ops, assign_phs = [], []
            for var in self.get_params().values():
                assign_placeholder = tf.placeholder(dtype=var.dtype)
                assign_op = tf.assign(var, assign_placeholder)
                assign_ops.append(assign_op)
                assign_phs.append(assign_placeholder)
            self._assign_ops = assign_ops
            self._assign_phs = assign_phs

        feed_dict = dict(zip(self._assign_phs, policy_params.values()))
        tf.get_default_session().run(self._assign_ops, feed_dict=feed_dict)

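    # A minimal usage sketch (assumes a concrete Policy subclass and an active
    # default tf.Session; `my_policy` is a hypothetical instance):
    #
    #   params = my_policy.get_param_values()  # dict of numpy arrays
    #   ...                                    # e.g. perturb or restore the weights
    #   my_policy.set_params(params)           # writes the values back into the tf.Variables
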
    def __getstate__(self):
        state = {
            'init_args': Serializable.__getstate__(self),
            'network_params': self.get_param_values()
        }
        return state

    def __setstate__(self, state):
        Serializable.__setstate__(self, state['init_args'])
        tf.get_default_session().run(tf.global_variables_initializer())
        self.set_params(state['network_params'])

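# The __getstate__/__setstate__ pair above makes policies picklable: the
# constructor arguments are captured via Serializable and the network weights
# via get_param_values(). A minimal sketch (assumes an active default session
# and a hypothetical concrete policy `my_policy`):
#
#   import pickle
#   serialized = pickle.dumps(my_policy)
#   restored = pickle.loads(serialized)  # re-runs __init__, then restores the weights

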
class MetaPolicy(Policy):

    def __init__(self, *args, **kwargs):
        super(MetaPolicy, self).__init__(*args, **kwargs)
        self._pre_update_mode = True
        self.policies_params_vals = None
        self.policy_params_keys = None
        self.policies_params_phs = None
        self.meta_batch_size = None

    def build_graph(self):
        """
        Builds the computational graph for the policy; should also create
        the lists of variables and the corresponding assign ops
        """
        raise NotImplementedError

    def switch_to_pre_update(self):
        """
        Switches get_actions to the pre-update policy
        """
        self._pre_update_mode = True
        # replicate the pre-update policy params meta_batch_size times
        self.policies_params_vals = [self.get_param_values() for _ in range(self.meta_batch_size)]

    def get_actions(self, observations):
        if self._pre_update_mode:
            return self._get_pre_update_actions(observations)
        else:
            return self._get_post_update_actions(observations)

    def _get_pre_update_actions(self, observations):
        """
        Args:
            observations (list): list of size meta-batch size with numpy arrays of shape batch_size x obs_dim
        """
        raise NotImplementedError

    def _get_post_update_actions(self, observations):
        """
        Args:
            observations (list): list of size meta-batch size with numpy arrays of shape batch_size x obs_dim
        """
        raise NotImplementedError

    def update_task_parameters(self, updated_policies_parameters):
        """
        Args:
            updated_policies_parameters (list): list of size meta-batch size, where each element
                                                is a dict of the policy parameters as numpy arrays
        """
        self.policies_params_vals = updated_policies_parameters
        self._pre_update_mode = False

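    # A sketch of the intended pre-/post-update cycle (hypothetical variable
    # names; the inner-loop adaptation itself lives in the meta-algorithm):
    #
    #   meta_policy.switch_to_pre_update()                  # all tasks share the pre-update params
    #   actions, agent_infos = meta_policy.get_actions(obs)
    #   ...                                                 # compute adapted params per task
    #   meta_policy.update_task_parameters(adapted_params)  # list of per-task param dicts
    #   actions, agent_infos = meta_policy.get_actions(obs) # now uses the post-update params
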
    def _create_placeholders_for_vars(self, scope, graph_keys=tf.GraphKeys.TRAINABLE_VARIABLES):
        var_list = tf.get_collection(graph_keys, scope=scope)
        placeholders = []
        for var in var_list:
            var_name = remove_scope_from_name(var.name, scope.split('/')[0])
            placeholders.append((var_name, tf.placeholder(tf.float32, shape=var.shape, name="%s_ph" % var_name)))
        return OrderedDict(placeholders)

    @property
    def policies_params_feed_dict(self):
        """
        Returns a fully prepared feed dict for feeding the currently saved policy
        parameter values into the lightweight policy graph
        """
        return dict(list((self.policies_params_phs[i][key], self.policies_params_vals[i][key])
                         for key in self.policy_params_keys
                         for i in range(self.meta_batch_size)))
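
# A minimal sketch of how the feed dict is meant to be consumed (assumes a
# concrete subclass has built action ops against the per-task parameter
# placeholders; the names below are illustrative, not part of this file):
#
#   feed_dict = meta_policy.policies_params_feed_dict
#   feed_dict.update({obs_ph: obs_batch})
#   actions = sess.run(action_var, feed_dict=feed_dict)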