Off Policy Evaluation
Justin Fu edited this page Jul 13, 2020 · 5 revisions
D4RL currently supports off-policy evaluation on the following tasks:
Policy Name | Environment Name | Returns (undiscounted, 10-seed average) |
---|---|---|
cheetah-random | halfcheetah-random-v0 | -199 |
cheetah-medium | halfcheetah-medium-v0 | 3985 |
cheetah-medium-high | halfcheetah-medium-v0 | 6751 |
cheetah | halfcheetah-expert-v0 | 12330 |
hopper-random | hopper-random-v0 | 1257 |
hopper-medium | hopper-medium-v0 | 2260 |
hopper-medium-high | hopper-medium-v0 | 3256 |
hopper | hopper-expert-v0 | 3624 |
walker2d-random | walker2d-random-v0 | 896 |
walker2d-medium-low | walker2d-medium-v0 | 1555 |
walker2d-medium | walker2d-medium-v0 | 2760 |
walker2d | walker2d-expert-v0 | 4005 |
Policies can be downloaded from http://rail.eecs.berkeley.edu/datasets/offline_rl/ope_policies.
Policies are provided in two formats:
- `[policy_name]_params.pkl`: A pickle file (loadable with joblib) that contains the network weights as numpy arrays.
- A pair of ONNX files:
  - `[policy_name]_params.sampler.onnx` takes as input (observations, gaussian noise) and outputs (actions, mean, log_std). The mean and log_std are the parameters of the Gaussian before the tanh is applied.
  - `[policy_name]_params.log_prob.onnx` takes as input (observations, actions) and outputs log_probs.
A sample script for loading the ONNX files is the OPE rollout script. This script requires a separate installation of `onnx` and `onnxruntime`.
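To make the sampler outputs described above concrete, here is a minimal pure-Python sketch of a tanh-squashed Gaussian policy head. It assumes the common convention that the action is `tanh(mean + exp(log_std) * noise)` and that the log-probability includes the change-of-variables correction for the tanh; whether the provided ONNX files use exactly this computation is an assumption, and the function names are hypothetical rather than part of D4RL.

```python
import math


def sample_tanh_gaussian(mean, log_std, noise):
    """Squash a reparameterized Gaussian sample through tanh, per action dimension.

    Assumed convention (not confirmed by the D4RL docs):
    action = tanh(mean + exp(log_std) * noise).
    """
    return [math.tanh(m + math.exp(s) * e) for m, s, e in zip(mean, log_std, noise)]


def tanh_gaussian_log_prob(action, mean, log_std, eps=1e-6):
    """Log-density of a tanh-squashed diagonal Gaussian evaluated at `action`."""
    total = 0.0
    for a, m, s in zip(action, mean, log_std):
        # Clamp away from +/-1 so atanh and log(1 - a^2) stay finite.
        a = min(max(a, -1.0 + eps), 1.0 - eps)
        u = math.atanh(a)  # pre-tanh value
        std = math.exp(s)
        # Gaussian log-density of the pre-tanh value.
        gauss = -0.5 * ((u - m) / std) ** 2 - s - 0.5 * math.log(2.0 * math.pi)
        # Change of variables: da/du = 1 - tanh(u)^2 = 1 - a^2.
        total += gauss - math.log(1.0 - a * a)
    return total
```

With zero noise the sampled action is simply `tanh(mean)`, which is a convenient way to check that a loaded sampler behaves as expected.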
The `d4rl.ope` module contains metrics for off-policy evaluation. Each metric takes policy string IDs (listed in the Policy Name column of the table above) and can be computed using discounted or undiscounted returns by passing a `discounted=True/False` flag. We provide the following metrics:
- `ope.ranking_correlation_metric(policy_ids)`: Computes Spearman's rank correlation coefficient for a list of policy IDs.
- `ope.precision_at_k_metric(policy_ids, k, n_rel)`: Computes precision@k.
- `ope.recall_at_k_metric(policy_ids, k, n_rel)`: Computes recall@k.
- `ope.value_error_metric(policy_id, value)`: Computes the error in the estimated policy value.
- `ope.policy_regret_metric(policy_id, expert_ids)`: Computes the difference between the best return among the experts and the selected policy's return.
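The quantities behind these metrics can be sketched in a few lines of plain Python. This is an illustrative sketch, not the `d4rl.ope` implementation: the helper names and signatures are hypothetical, ties are assumed absent, and the `estimates` values are made up for the example. The true returns are the halfcheetah entries from the table above.

```python
def ranks(values):
    """Return the rank (0 = smallest) of each entry; assumes no ties."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, idx in enumerate(order):
        result[idx] = rank
    return result


def spearman(estimated, true):
    """Spearman's rho for tie-free data: 1 - 6 * sum(d^2) / (n * (n^2 - 1))."""
    n = len(true)
    d2 = sum((a - b) ** 2 for a, b in zip(ranks(estimated), ranks(true)))
    return 1.0 - 6.0 * d2 / (n * (n * n - 1))


def top_indices(values, count):
    """Indices of the `count` largest entries."""
    return set(sorted(range(len(values)), key=lambda i: -values[i])[:count])


def precision_at_k(estimated, true, k, n_rel):
    """Fraction of the top-k policies by estimate that are among the n_rel best by true return."""
    return len(top_indices(estimated, k) & top_indices(true, n_rel)) / k


def recall_at_k(estimated, true, k, n_rel):
    """Fraction of the n_rel truly best policies that appear in the top-k by estimate."""
    return len(top_indices(estimated, k) & top_indices(true, n_rel)) / n_rel


def regret(estimated, true):
    """Best true return minus the true return of the policy with the highest estimate."""
    chosen = max(range(len(estimated)), key=lambda i: estimated[i])
    return max(true) - true[chosen]


# True returns: cheetah-random, cheetah-medium, cheetah-medium-high, cheetah.
true_returns = [-199.0, 3985.0, 6751.0, 12330.0]
# Hypothetical OPE estimates for the same four policies (made up for illustration).
estimates = [100.0, 5000.0, 4000.0, 9000.0]
```

Calling `spearman(estimates, true_returns)` then scores how well the estimates order the policies, while `regret(estimates, true_returns)` measures the cost of deploying the policy that the estimates rank highest.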