Source code for meta_policy_search.meta_trainer

import tensorflow as tf
import numpy as np
import time
from meta_policy_search.utils import logger


class Trainer(object):
    """
    Performs steps of meta-policy search.

    Pseudocode::

        for iter in n_iter:
            sample tasks
            for task in tasks:
                for adapt_step in num_inner_grad_steps:
                    sample trajectories with policy
                    perform update/adaptation step
                sample trajectories with post-update policy
            perform meta-policy gradient step(s)

    Args:
        algo (Algo) : algorithm providing the inner adaptation and outer meta-policy updates
        env (Env) : environment from which the tasks are sampled
        sampler (Sampler) : sampler used to obtain trajectories
        sample_processor (SampleProcessor) : processes the sampled paths (also holds the baseline)
        policy (Policy) : meta-policy to be trained
        n_itr (int) : number of iterations to train for
        start_itr (int) : number of iterations the policy has already been trained for, if reloading
        num_inner_grad_steps (int) : number of inner gradient steps per meta-iteration
        sess (tf.Session) : current tf session (e.g. if a policy was loaded into it)
    """
    def __init__(
            self,
            algo,
            env,
            sampler,
            sample_processor,
            policy,
            n_itr,
            start_itr=0,
            num_inner_grad_steps=1,
            sess=None,
            ):
        self.algo = algo
        self.env = env
        self.sampler = sampler
        self.sample_processor = sample_processor
        self.baseline = sample_processor.baseline
        self.policy = policy
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.num_inner_grad_steps = num_inner_grad_steps

        if sess is None:
            sess = tf.Session()
        self.sess = sess
    def train(self):
        """
        Trains the policy on the env using the algo

        Pseudocode::

            for itr in n_itr:
                sampler.update_tasks()
                for step in num_inner_grad_steps + 1:
                    sampler.obtain_samples()
                    sample_processor.process_samples()
                    if step < num_inner_grad_steps:
                        algo._adapt()
                algo.optimize_policy()
        """
        with self.sess.as_default() as sess:

            # initialize uninitialized vars (only initialize vars that were not loaded)
            uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
            sess.run(tf.variables_initializer(uninit_vars))

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log("\n ---------------- Iteration %d ----------------" % itr)
                logger.log("Sampling set of tasks/goals for this meta-batch...")

                self.sampler.update_tasks()
                self.policy.switch_to_pre_update()  # Switch to pre-update policy

                all_samples_data, all_paths = [], []
                list_sampling_time, list_inner_step_time, list_outer_step_time, list_proc_samples_time = [], [], [], []
                start_total_inner_time = time.time()
                for step in range(self.num_inner_grad_steps + 1):
                    logger.log('** Step ' + str(step) + ' **')

                    """ -------------------- Sampling --------------------------"""

                    logger.log("Obtaining samples...")
                    time_env_sampling_start = time.time()
                    paths = self.sampler.obtain_samples(log=True, log_prefix='Step_%d-' % step)
                    list_sampling_time.append(time.time() - time_env_sampling_start)
                    all_paths.append(paths)

                    """ ----------------- Processing Samples ---------------------"""

                    logger.log("Processing samples...")
                    time_proc_samples_start = time.time()
                    samples_data = self.sample_processor.process_samples(paths, log='all', log_prefix='Step_%d-' % step)
                    all_samples_data.append(samples_data)
                    list_proc_samples_time.append(time.time() - time_proc_samples_start)

                    self.log_diagnostics(sum(list(paths.values()), []), prefix='Step_%d-' % step)

                    """ ------------------- Inner Policy Update --------------------"""

                    time_inner_step_start = time.time()
                    if step < self.num_inner_grad_steps:
                        logger.log("Computing inner policy updates...")
                        self.algo._adapt(samples_data)
                    list_inner_step_time.append(time.time() - time_inner_step_start)

                total_inner_time = time.time() - start_total_inner_time
                time_maml_opt_start = time.time()

                """ ------------------ Outer Policy Update ---------------------"""

                logger.log("Optimizing policy...")
                # This needs to take all samples_data so that it can construct the graph for meta-optimization.
                time_outer_step_start = time.time()
                self.algo.optimize_policy(all_samples_data)

                """ ------------------- Logging Stuff --------------------------"""

                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps', self.sampler.total_timesteps_sampled)

                logger.logkv('Time-OuterStep', time.time() - time_outer_step_start)
                logger.logkv('Time-TotalInner', total_inner_time)
                logger.logkv('Time-InnerStep', np.sum(list_inner_step_time))
                logger.logkv('Time-SampleProc', np.sum(list_proc_samples_time))
                logger.logkv('Time-Sampling', np.sum(list_sampling_time))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)
                logger.logkv('Time-MAMLSteps', time.time() - time_maml_opt_start)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()

        logger.log("Training finished")
        self.sess.close()
    def get_itr_snapshot(self, itr):
        """
        Gets the current policy and env for storage
        """
        return dict(itr=itr, policy=self.policy, env=self.env, baseline=self.baseline)
    def log_diagnostics(self, paths, prefix):
        # TODO: we aren't using it so far
        self.env.log_diagnostics(paths, prefix)
        self.policy.log_diagnostics(paths, prefix)
        self.baseline.log_diagnostics(paths, prefix)
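
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). It shows
# the interface the Trainer expects from its collaborators by wiring it up
# with minimal stub components and running a single iteration. In real
# experiments the repository's own Algo / Sampler / SampleProcessor / Policy /
# Baseline classes are used instead; the stub classes below and the
# logger.configure(...) arguments are assumptions made for this sketch.
# ---------------------------------------------------------------------------
class _StubBaseline:
    def log_diagnostics(self, paths, prefix):
        pass


class _StubEnv:
    def log_diagnostics(self, paths, prefix):
        pass


class _StubPolicy:
    def switch_to_pre_update(self):
        pass

    def log_diagnostics(self, paths, prefix):
        pass


class _StubSampler:
    """Provides the methods/attributes that train() calls on self.sampler."""
    total_timesteps_sampled = 0

    def update_tasks(self):
        pass

    def obtain_samples(self, log=False, log_prefix=''):
        # one dummy path per meta-task, keyed by task index like the real sampler
        return {0: [dict(observations=np.zeros((2, 1)), rewards=np.zeros(2))]}


class _StubSampleProcessor:
    baseline = _StubBaseline()

    def process_samples(self, paths, log=False, log_prefix=''):
        return paths


class _StubAlgo:
    def _adapt(self, samples_data):
        pass  # inner (adaptation) update

    def optimize_policy(self, all_samples_data):
        pass  # outer (meta) update


if __name__ == "__main__":
    # logger.configure arguments are assumed; adjust to the repo's logger API
    logger.configure(dir='/tmp/trainer_demo', format_strs=['stdout'], snapshot_mode='last')
    trainer = Trainer(
        algo=_StubAlgo(),
        env=_StubEnv(),
        sampler=_StubSampler(),
        sample_processor=_StubSampleProcessor(),
        policy=_StubPolicy(),
        n_itr=1,
        num_inner_grad_steps=1,
    )
    # runs: update_tasks -> (obtain_samples, process_samples, _adapt) -> optimize_policy
    trainer.train()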