from meta_policy_search.utils.utils import remove_scope_from_name
from meta_policy_search.utils import Serializable
import tensorflow as tf
from collections import OrderedDict
class Policy(Serializable):
    """
    A container for storing the current pre-update and post-update policies.

    Also provides functions for executing the policy and updating its parameters.

    Note:
        The pre-update policy is stored as tf.Variables, while the post-update
        policy is stored in numpy arrays and executed through tf.placeholders.

    Args:
        obs_dim (int): dimensionality of the observation space -> specifies the input size of the policy
        action_dim (int): dimensionality of the action space -> specifies the output size of the policy
        name (str): name used for scoping variables in the policy
        hidden_sizes (tuple): sizes of the hidden layers of the network
        learn_std (bool): whether to learn the variance of the network output
        hidden_nonlinearity (Operation): nonlinearity used between hidden layers of the network
        output_nonlinearity (Operation): nonlinearity used after the final layer of the network
        **kwargs: accepted but not used by this base class
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 name='policy',
                 hidden_sizes=(32, 32),
                 learn_std=True,
                 hidden_nonlinearity=tf.tanh,
                 output_nonlinearity=None,
                 **kwargs
                 ):
        # Must stay the first statement: at this point locals() contains exactly
        # the constructor arguments. They are handed to Serializable and later
        # returned as 'init_args' by __getstate__.
        Serializable.quick_init(self, locals())

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.name = name
        self.hidden_sizes = hidden_sizes
        self.learn_std = learn_std
        self.hidden_nonlinearity = hidden_nonlinearity
        self.output_nonlinearity = output_nonlinearity

        # Left as None here; presumably filled in by subclasses (e.g. in build_graph).
        self._dist = None          # distribution object used by the *_sym methods
        self.policy_params = None  # dict of name -> tf.Variable, returned by get_params()

        # Assignment ops/placeholders, built lazily on the first set_params() call
        # so that repeated calls do not keep adding nodes to the graph.
        self._assign_ops = None
        self._assign_phs = None

    def build_graph(self):
        """
        Builds the computational graph for the policy.

        Abstract: must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_action(self, observation):
        """
        Runs a single observation through the specified policy.

        Abstract: must be implemented by subclasses.

        Args:
            observation (array): single observation

        Returns:
            (array): array of arrays of actions for each env
        """
        raise NotImplementedError

    def get_actions(self, observations):
        """
        Runs each set of observations through each task-specific policy.

        Abstract: must be implemented by subclasses.

        Args:
            observations (array): array of arrays of observations generated by each task and env

        Returns:
            (tuple): array of arrays of actions for each env
                (meta_batch_size) x (batch_size) x (action_dim),
                and array of arrays of agent_info dicts
        """
        raise NotImplementedError

    def reset(self, dones=None):
        """
        Hook for resetting per-episode policy state.

        A no-op in this base class; ``dones`` is ignored here.
        """
        pass

    def log_diagnostics(self, paths):
        """
        Log extra information per iteration based on the collected paths.

        A no-op in this base class.
        """
        pass

    @property
    def distribution(self):
        """
        Returns this policy's distribution.

        Abstract: must be implemented by subclasses.

        Returns:
            (Distribution): this policy's distribution
        """
        raise NotImplementedError

    def distribution_info_sym(self, obs_var, params=None):
        """
        Returns the symbolic distribution information about the actions.

        Abstract: must be implemented by subclasses.

        Args:
            obs_var (placeholder): symbolic variable for observations
            params (None or dict): a dictionary of placeholders that contains information about the
                state of the policy at the time it received the observation

        Returns:
            (dict): a dictionary of tf placeholders for the policy output distribution
        """
        raise NotImplementedError

    def distribution_info_keys(self, obs, state_infos):
        """
        Abstract: must be implemented by subclasses.

        NOTE(review): the method name suggests it returns the keys of the
        distribution-info dict, while the documented return below is a dict of
        placeholders. Confirm against the subclass implementations.

        Args:
            obs (placeholder): symbolic variable for observations
            state_infos (dict): a dictionary of placeholders that contains information about the
                state of the policy at the time it received the observation

        Returns:
            (dict): a dictionary of tf placeholders for the policy output distribution
        """
        raise NotImplementedError

    def likelihood_ratio_sym(self, obs, action, dist_info_old, policy_params):
        """
        Computes the likelihood ratio p_new(action|obs) / p_old(action|obs).

        The new distribution is evaluated with ``policy_params``. The old one is
        given directly by ``dist_info_old``.

        Args:
            obs (tf.Tensor): symbolic variable for observations
            action (tf.Tensor): symbolic variable for actions
            dist_info_old (dict): dictionary of tf.placeholders with old policy information
            policy_params (dict): dictionary of the policy parameters (each value is a tf.Tensor)

        Returns:
            (tf.Tensor): likelihood ratio
        """
        distribution_info_new = self.distribution_info_sym(obs, params=policy_params)
        likelihood_ratio = self._dist.likelihood_ratio_sym(action, dist_info_old, distribution_info_new)
        return likelihood_ratio

    def log_likelihood_sym(self, obs, action, policy_params):
        """
        Computes the log likelihood log p(action|obs) under ``policy_params``.

        Args:
            obs (tf.Tensor): symbolic variable for observations
            action (tf.Tensor): symbolic variable for actions
            policy_params (dict): dictionary of the policy parameters (each value is a tf.Tensor)

        Returns:
            (tf.Tensor): log likelihood
        """
        # distribution_info_sym is a *policy* method (as used in likelihood_ratio_sym);
        # the distribution object only turns that info into likelihood values.
        distribution_info_var = self.distribution_info_sym(obs, params=policy_params)
        log_likelihood = self._dist.log_likelihood_sym(action, distribution_info_var)
        return log_likelihood

    # --- methods for serialization ---

    def get_params(self):
        """
        Gets the tf.Variables representing the trainable weights of the network (symbolic).

        Returns:
            (dict): a dict of all trainable Variables (``self.policy_params``)
        """
        return self.policy_params

    def get_param_values(self):
        """
        Gets the current values of all the weights in the network.

        Must be called while a default tf session is active.

        Returns:
            (dict): the result of running ``self.policy_params`` in the default
                session, i.e. a dict with the same keys mapping to the
                variables' current values
        """
        param_values = tf.get_default_session().run(self.policy_params)
        return param_values

    def set_params(self, policy_params):
        """
        Sets the parameters for the graph.

        Must be called while a default tf session is active.

        Args:
            policy_params (dict): variable names and corresponding parameter values.
                The keys must be exactly those of ``get_params()``, in the same order.

        Raises:
            AssertionError: if the keys of ``policy_params`` do not match ``get_params()``
        """
        params = self.get_params()
        # Compare the full ordered key lists. A zip-based comparison would silently
        # ignore extra trailing keys when the lengths differ.
        assert list(params.keys()) == list(policy_params.keys()), \
            "parameter keys must match with variable"

        if self._assign_ops is None:
            # Build one placeholder + assign op per variable, once, in get_params() order.
            assign_ops, assign_phs = [], []
            for var in params.values():
                assign_placeholder = tf.placeholder(dtype=var.dtype)
                assign_ops.append(tf.assign(var, assign_placeholder))
                assign_phs.append(assign_placeholder)
            self._assign_ops = assign_ops
            self._assign_phs = assign_phs

        # Placeholders were created in get_params() order, which the assert above
        # guarantees is also the order of policy_params.
        feed_dict = dict(zip(self._assign_phs, policy_params.values()))
        tf.get_default_session().run(self._assign_ops, feed_dict=feed_dict)

    def __getstate__(self):
        """Returns the constructor arguments plus the current network weights."""
        state = {
            'init_args': Serializable.__getstate__(self),
            'network_params': self.get_param_values()
        }
        return state

    def __setstate__(self, state):
        """Rebuilds the policy from ``init_args`` and restores the saved weights."""
        Serializable.__setstate__(self, state['init_args'])
        # NOTE(review): global_variables_initializer initializes *every* global
        # variable in the default graph, not only this policy's. Other models
        # sharing the session are reset too. Only the policy's own params are
        # restored below.
        tf.get_default_session().run(tf.global_variables_initializer())
        self.set_params(state['network_params'])