From 6eb37a1ab679d54e8916ce959f4838acc4433e65 Mon Sep 17 00:00:00 2001 From: Coline Devin Date: Fri, 12 Jul 2019 10:54:43 -0700 Subject: [PATCH] Added linear scheduling option to epsilon greedy exploration --- rlkit/exploration_strategies/epsilon_greedy.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/rlkit/exploration_strategies/epsilon_greedy.py b/rlkit/exploration_strategies/epsilon_greedy.py index 73fb88f47..fcc4aa548 100644 --- a/rlkit/exploration_strategies/epsilon_greedy.py +++ b/rlkit/exploration_strategies/epsilon_greedy.py @@ -1,17 +1,24 @@ import random from rlkit.exploration_strategies.base import RawExplorationStrategy +from rlkit.util.ml_util import LinearSchedule class EpsilonGreedy(RawExplorationStrategy): """ Take a random discrete action with some probability. """ - def __init__(self, action_space, prob_random_action=0.1): - self.prob_random_action = prob_random_action + def __init__(self, action_space, prob_random_action=0.1, prob_end=None, steps=1e6): + """ + If prob_end is None, this will default to a fixed schedule. + """ + if prob_end is not None: + self.prob_random_action = LinearSchedule(prob_random_action, prob_end, steps) + else: + self.prob_random_action = LinearSchedule(prob_random_action, prob_random_action, steps) self.action_space = action_space - def get_action_from_raw_action(self, action, **kwargs): - if random.random() <= self.prob_random_action: + def get_action_from_raw_action(self, action,t=0 **kwargs): + if random.random() <= self.prob_random_action.get_value(t): return self.action_space.sample() return action