Source code for meta_policy_search.optimizers.conjugate_gradient_optimizer

from meta_policy_search.utils import logger
import numpy as np
import tensorflow as tf
from collections import OrderedDict
from meta_policy_search.optimizers.base import Optimizer


class FiniteDifferenceHvp(Optimizer):
    def __init__(self, base_eps=1e-5, symmetric=True, grad_clip=None):
        self.base_eps = np.cast['float32'](base_eps)
        self.symmetric = symmetric
        self.grad_clip = grad_clip
        self._target = None
        self.reg_coeff = None
        self._constraint_gradient = None
        self._input_ph_dict = None

    def build_graph(self, constraint_obj, target, input_val_dict, reg_coeff):
        """
        Sets the objective function and target weights for the optimize function

        Args:
            constraint_obj (tf_op) : constraint objective
            target (Policy) : Policy whose values we are optimizing over
            inputs (list) : tuple of tf.placeholders for input data which may be subsampled. The first dimension corresponds to the number of data points
            reg_coeff (float): regularization coefficient
        """
        self._target = target
        self.reg_coeff = reg_coeff
        self._input_ph_dict = input_val_dict

        params = list(target.get_params().values())
        constraint_grads = tf.gradients(constraint_obj, xs=params)

        for idx, (grad, param) in enumerate(zip(constraint_grads, params)):
            if grad is None:
                constraint_grads[idx] = tf.zeros_like(param)

        constraint_gradient = tf.concat([tf.reshape(grad, [-1]) for grad in constraint_grads], axis=0)

        self._constraint_gradient = constraint_gradient

    def constraint_gradient(self, input_val_dict):
        """
        Computes the gradient of the constraint objective

        Args:
            inputs (list): inputs needed to compute the gradient

        Returns:
            (np.ndarray): flattened gradient
        """

        sess = tf.get_default_session()
        feed_dict = self.create_feed_dict(input_val_dict)
        constraint_gradient = sess.run(self._constraint_gradient, feed_dict)
        return constraint_gradient

    def Hx(self, input_val_dict, x):
        """
        Compute the second derivative of the constraint val in the direction of the vector x
        Args:
            inputs (list): inputs needed to compute the gradient of the constraint objective
            x (np.ndarray): vector indicating the direction on which the Hessian has to be computed

        Returns: (np.ndarray): second derivative in the direction of x

        """
        assert isinstance(x, np.ndarray)

        param_vals = self._target.get_param_values().copy()
        flat_param_vals = _flatten_params(param_vals)
        eps = self.base_eps
        params_plus_eps_vals = _unflatten_params(flat_param_vals + eps * x, params_example=param_vals)
        self._target.set_params(params_plus_eps_vals)
        constraint_grad_plus_eps = self.constraint_gradient(input_val_dict)
        self._target.set_params(param_vals)

        if self.symmetric:
            params_minus_eps_vals = _unflatten_params(flat_param_vals - eps * x, params_example=param_vals)
            self._target.set_params(params_minus_eps_vals)
            constraint_grad_minus_eps = self.constraint_gradient(input_val_dict)
            self._target.set_params(param_vals)
            hx = (constraint_grad_plus_eps - constraint_grad_minus_eps)/(2 * eps)

        else:
            constraint_grad = self.constraint_gradient(input_val_dict)
            hx = (constraint_grad_plus_eps - constraint_grad)/eps
        return hx

    def build_eval(self, inputs):
        """
        Build the Hessian evaluation function. It let's you evaluate the hessian of the constraint objective
        in any direction.
        Args:
            inputs (list): inputs needed to compute the gradient of the constraint objective

        Returns:
            (function): function that evaluates the Hessian of the constraint objective in the input direction
        """
        def evaluate_hessian(x):
            return self.Hx(inputs, x) + self.reg_coeff * x

        return evaluate_hessian


[docs]class ConjugateGradientOptimizer(Optimizer): """ Performs constrained optimization via line search. The search direction is computed using a conjugate gradient algorithm, which gives x = A^{-1}g, where A is a second order approximation of the constraint and g is the gradient of the loss function. Args: cg_iters (int) : The number of conjugate gradients iterations used to calculate A^-1 g reg_coeff (float) : A small value so that A -> A + reg*I subsample_factor (float) : Subsampling factor to reduce samples when using "conjugate gradient. Since the computation time for the descent direction dominates, this can greatly reduce the overall computation time. backtrack_ratio (float) : ratio for decreasing the step size for the line search max_backtracks (int) : maximum number of backtracking iterations for the line search debug_nan (bool) : if set to True, NanGuard will be added to the compilation, and ipdb will be invoked when nan is detected accept_violation (bool) : whether to accept the descent step if it violates the line search condition after exhausting all backtracking budgets hvp_approach (obj) : Hessian vector product approach """ def __init__( self, cg_iters=10, reg_coeff=0, subsample_factor=1., backtrack_ratio=0.8, max_backtracks=15, debug_nan=False, accept_violation=False, hvp_approach=FiniteDifferenceHvp(), ): self._cg_iters = cg_iters self._reg_coeff = reg_coeff self._subsample_factor = subsample_factor self._backtrack_ratio = backtrack_ratio self._max_backtracks = max_backtracks self._target = None self._max_constraint_val = None self._constraint_name = "kl-div" self._debug_nan = debug_nan self._accept_violation = accept_violation self._hvp_approach = hvp_approach self._loss = None self._gradient = None self._constraint_objective = None self._input_ph_dict = None
[docs] def build_graph(self, loss, target, input_ph_dict, leq_constraint): """ Sets the objective function and target weights for the optimize function Args: loss (tf_op) : minimization objective target (Policy) : Policy whose values we are optimizing over inputs (list) : tuple of tf.placeholders for input data which may be subsampled. The first dimension corresponds to the number of data points extra_inputs (list) : tuple of tf.placeholders for hyperparameters (e.g. learning rate, if annealed) leq_constraint (tuple) : A constraint provided as a tuple (f, epsilon), of the form f(*inputs) <= epsilon. """ assert isinstance(loss, tf.Tensor) assert hasattr(target, 'get_params') assert isinstance(input_ph_dict, dict) constraint_objective, constraint_value = leq_constraint self._target = target self._constraint_objective = constraint_objective self._max_constraint_val = constraint_value self._input_ph_dict = input_ph_dict self._loss = loss # build the graph of the hessian vector product (hvp) self._hvp_approach.build_graph(constraint_objective, target, self._input_ph_dict, self._reg_coeff) # build the graph of the gradients params = list(target.get_params().values()) grads = tf.gradients(loss, xs=params) for idx, (grad, param) in enumerate(zip(grads, params)): if grad is None: grads[idx] = tf.zeros_like(param) gradient = tf.concat([tf.reshape(grad, [-1]) for grad in grads], axis=0) self._gradient = gradient
[docs] def loss(self, input_val_dict): """ Computes the value of the loss for given inputs Args: inputs (list): inputs needed to compute the loss function extra_inputs (list): additional inputs needed to compute the loss function Returns: (float): value of the loss """ sess = tf.get_default_session() feed_dict = self.create_feed_dict(input_val_dict) loss = sess.run(self._loss, feed_dict=feed_dict) return loss
[docs] def constraint_val(self, input_val_dict): """ Computes the value of the KL-divergence between pre-update policies for given inputs Args: inputs (list): inputs needed to compute the inner KL extra_inputs (list): additional inputs needed to compute the inner KL Returns: (float): value of the loss """ sess = tf.get_default_session() feed_dict = self.create_feed_dict(input_val_dict) constrain_val = sess.run(self._constraint_objective, feed_dict) return constrain_val
[docs] def gradient(self, input_val_dict): """ Computes the gradient of the loss function Args: inputs (list): inputs needed to compute the gradient extra_inputs (list): additional inputs needed to compute the loss function Returns: (np.ndarray): flattened gradient """ sess = tf.get_default_session() feed_dict = self.create_feed_dict(input_val_dict) gradient = sess.run(self._gradient, feed_dict) return gradient
[docs] def optimize(self, input_val_dict): """ Carries out the optimization step Args: inputs (list): inputs for the optimization extra_inputs (list): extra inputs for the optimization subsample_grouped_inputs (None or list): subsample data from each element of the list """ logger.log("Start CG optimization") logger.log("computing loss before") loss_before = self.loss(input_val_dict) logger.log("performing update") logger.log("computing gradient") gradient = self.gradient(input_val_dict) logger.log("gradient computed") logger.log("computing descent direction") Hx = self._hvp_approach.build_eval(input_val_dict) descent_direction = conjugate_gradients(Hx, gradient, cg_iters=self._cg_iters) initial_step_size = np.sqrt(2.0 * self._max_constraint_val * (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8))) if np.isnan(initial_step_size): logger.log("Initial step size is NaN! Rejecting the step!") return initial_descent_step = initial_step_size * descent_direction logger.log("descent direction computed") prev_params = self._target.get_param_values() prev_params_values = _flatten_params(prev_params) loss, constraint_val, n_iter, violated = 0, 0, 0, False for n_iter, ratio in enumerate(self._backtrack_ratio ** np.arange(self._max_backtracks)): cur_step = ratio * initial_descent_step cur_params_values = prev_params_values - cur_step cur_params = _unflatten_params(cur_params_values, params_example=prev_params) self._target.set_params(cur_params) loss, constraint_val = self.loss(input_val_dict), self.constraint_val(input_val_dict) if loss < loss_before and constraint_val <= self._max_constraint_val: break """ ------------------- Logging Stuff -------------------------- """ if np.isnan(loss): violated = True logger.log("Line search violated because loss is NaN") if np.isnan(constraint_val): violated = True logger.log("Line search violated because constraint %s is NaN" % self._constraint_name) if loss >= loss_before: violated = True logger.log("Line search violated because loss not improving") if constraint_val >= self._max_constraint_val: violated = True logger.log("Line search violated because constraint %s is violated" % self._constraint_name) if violated and not self._accept_violation: logger.log("Line search condition violated. Rejecting the step!") self._target.set_params(prev_params) logger.log("backtrack iters: %d" % n_iter) logger.log("computing loss after") logger.log("optimization finished")
def _unflatten_params(flat_params, params_example): unflat_params = [] idx = 0 for key, param in params_example.items(): size_param = np.prod(param.shape) reshaped_param = np.reshape(flat_params[idx:idx+size_param], newshape=param.shape) unflat_params.append((key, reshaped_param)) idx += size_param return OrderedDict(unflat_params) def _flatten_params(params): return np.concatenate([param.reshape(-1) for param in params.values()]) def conjugate_gradients(f_Ax, b, cg_iters=10, verbose=False, residual_tol=1e-10): """ Demmel p 312 """ p = b.copy() r = b.copy() x = np.zeros_like(b, dtype=np.float32) rdotr = r.dot(r) fmtstr = "%10i %10.3g %10.3g" titlestr = "%10s %10s %10s" if verbose: print(titlestr % ("iter", "residual norm", "soln norm")) for i in range(cg_iters): if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x))) z = f_Ax(p) v = rdotr / p.dot(z) x += v * p r -= v * z newrdotr = r.dot(r) mu = newrdotr / rdotr p = r + mu * p rdotr = newrdotr if rdotr < residual_tol: break if verbose: print(fmtstr % (i + 1, rdotr, np.linalg.norm(x))) return x