# Source code for meta_policy_search.samplers.meta_sample_processor

from meta_policy_search.samplers.base import SampleProcessor
from meta_policy_search.samplers.dice_sample_processor import DiceSampleProcessor
from meta_policy_search.utils import utils
import numpy as np

class MetaSampleProcessor(SampleProcessor):
    """Sample processor that operates on a whole meta-batch of sampled paths.

    Per-task processing is delegated to ``self._compute_samples_data`` and
    logging to ``self._log_path_stats``; both are expected to be provided by
    the base class (they are not defined in this class).
    """

    def process_samples(self, paths_meta_batch, log=False, log_prefix=''):
        """
        Processes sampled paths for every task of the meta-batch.

        Each task's paths are handed to ``self._compute_samples_data``. Per the
        original documentation of this method, that processing involves:
            - computing discounted rewards (returns)
            - fitting baseline estimator using the path returns and predicting the return baselines
            - estimating the advantages using GAE (+ advantage normalization if desired)
            - stacking the path data

        Afterwards this method itself:
            - adds normalized rewards under the key 'adj_avg_rewards' (for E-MAML), using the
              mean / std of the rewards pooled over all tasks
            - logs statistics of all paths via ``self._log_path_stats``

        Args:
            paths_meta_batch (dict): sampled paths per meta-task (one dict entry per task);
                documented size: [meta_batch_size] x (batch_size) x [5] x (max_path_length)
            log (boolean): indicates whether to log
            log_prefix (str): prefix for the logging keys

        Returns:
            (list of dicts) : Processed sample data among the meta-batch, one dict per task in the
                iteration order of ``paths_meta_batch``;
                documented size: [meta_batch_size] x [7] x (batch_size x max_path_length)

        Raises:
            AssertionError: if ``paths_meta_batch`` is not a dict or ``self.baseline`` is falsy.
        """
        assert isinstance(paths_meta_batch, dict), 'paths must be a dict'
        assert self.baseline, 'baseline must be specified'

        samples_data_meta_batch = []
        all_paths = []

        # The task keys are not needed here, only the paths of each task.
        for paths in paths_meta_batch.values():
            # fits baseline, compute advantages and stack path data
            samples_data, paths = self._compute_samples_data(paths)

            samples_data_meta_batch.append(samples_data)
            all_paths.extend(paths)

        # compute normalized trajectory-batch rewards (for E-MAML)
        # Pool the rewards of all tasks once and reuse the array for both statistics.
        # NOTE: an empty meta-batch is not handled specially; np.concatenate rejects an empty list.
        all_rewards = np.concatenate([samples_data['rewards'] for samples_data in samples_data_meta_batch])
        overall_avg_reward = np.mean(all_rewards)
        overall_avg_reward_std = np.std(all_rewards)

        for samples_data in samples_data_meta_batch:
            # 1e-8 guards against division by zero when all rewards are identical.
            samples_data['adj_avg_rewards'] = (samples_data['rewards'] - overall_avg_reward) / (overall_avg_reward_std + 1e-8)

        # log statistics if desired
        self._log_path_stats(all_paths, log=log, log_prefix=log_prefix)

        return samples_data_meta_batch
class DiceMetaSampleProcessor(DiceSampleProcessor):
    """DiceSampleProcessor variant whose ``process_samples`` is the meta-batch implementation.

    The method is not re-implemented: the function object defined on
    MetaSampleProcessor is bound here as a class attribute, so it runs with
    this class's instances as ``self``.
    """

    # Reuse the meta-batch processing routine as-is.
    process_samples = MetaSampleProcessor.process_samples