Action Robust Reinforcement Learning
C. Tessler, Y. Efroni, and S. Mannor, “Action robust reinforcement learning and applications in continuous control,” in ICML, 2019, pp. 10846–10855.
http://proceedings.mlr.press/v97/tessler19a.html
https://chentessler.wixsite.com/tessler/post/action-robust-reinforcement-learning
https://tesslerc.github.io/
Motivation
- The gap between simulation and the real world is often so large that policy learning approaches fail to transfer.
- Even if policy learning is done in the real world, data scarcity leads to poor generalization from training to test scenarios (e.g. due to different friction or object masses).
Robust Policy
A policy is said to be robust if it maximizes the reward while considering a bad, or even adversarial, model. The core idea is that modeling errors should be viewed as extra forces or disturbances in the system.
It is not clear how a learned policy would generalize under perturbations such as:
- traction
- friction
- tire pressure
- humidity
- vehicle mass
- road conditions
Different from Transition Dynamics Uncertainty Robust RL
Transition dynamics uncertainty
These methods solve a max-min optimization problem over a set of possible model parameters, an uncertainty set (see the sketch after this list).
- It is not clear how to obtain these uncertainty sets.
- It is not clear if and how these approaches may be extended to non-linear function approximation schemes, e.g. neural networks.
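To make the contrast concrete, here is a minimal sketch (my own toy example, not from the paper) of the classic transition-uncertainty robust Bellman backup, where the "uncertainty set" is simply a finite list of candidate transition kernels for a 2-state, 2-action MDP; all numbers are made up for illustration.
import numpy as np

gamma = 0.9
r = np.array([[1.0, 0.0],
              [0.0, 1.0]])                           # r[s, a]
# Two candidate transition kernels P[s, a, s'] standing in for an uncertainty set.
P_nominal = np.array([[[0.9, 0.1], [0.2, 0.8]],
                      [[0.7, 0.3], [0.1, 0.9]]])
P_shifted = np.array([[[0.6, 0.4], [0.5, 0.5]],
                      [[0.4, 0.6], [0.3, 0.7]]])
uncertainty_set = [P_nominal, P_shifted]

V = np.zeros(2)
for _ in range(200):
    # Robust backup: score each action against the worst kernel in the set.
    Q = np.min([r + gamma * P @ V for P in uncertainty_set], axis=0)
    V = Q.max(axis=1)

print(V, Q.argmax(axis=1))   # robust value and the corresponding max-min policy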
Action uncertainty
Consider action perturbations:
- Probabilistic action: corresponds to abrupt interruptions, such as a sudden push.
- Noisy action: corresponds to a constant disturbing force; it may be seen as an adversary applying force in the opposite direction.
Consider Two Scenarios of Action Uncertainty
- Policy space: the agent attempts to perform an action \(a\), and with probability \(\alpha\) an alternative, adversarial action \(\bar{a}\) is taken instead.
- Action space: an adversary adds a perturbation to the selected action \(a\) (both scenarios are contrasted in the sketch below).
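A minimal sketch (not from the paper's code) of how the executed action differs between the two scenarios for a 1-D continuous action; agent_action and adv_action stand for hypothetical outputs of the agent and adversary policies at the current state.
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.1
agent_action = np.array([0.8])
adv_action = np.array([-1.0])

# Policy space (PR-MDP): with probability alpha the adversary's action replaces the agent's.
pr_action = adv_action if rng.random() < alpha else agent_action

# Action space (NR-MDP): the executed action is a convex combination of the two actions.
nr_action = (1 - alpha) * agent_action + alpha * adv_action

print(pr_action, nr_action)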
Zero-Sum Games
Two-player zero-sum game:
- Agent and adversary actions: \(a \in A\), \(\bar{a} \in \bar{A}\).
- Reward: \(r(s_t,a_t,\bar{a}_t)\).
- Transitions: \(s_{t+1} \sim P(\cdot \vert s_t,a_t,\bar{a}_t)\).
Value function of the game:
\[V_{\pi,\bar{\pi}}(s) = \mathbb{E}_{\pi,\bar{\pi}} \left[ \sum_{t = 0}^\infty \gamma^t r(s_t,a_t,\bar{a}_t) \bigg\vert s_0 = s \right], \quad \forall s \in S.\]
\[\begin{aligned} V_*(s) & = \max_{\pi \in P(\Pi)} \min_{\bar{\pi} \in \Pi} \mathbb{E}_{\pi,\bar{\pi}} \left[ \sum_{t=0}^\infty \gamma^t r(s_t,a_t,\bar{a}_t) \bigg\vert s_0 = s \right], \\ & = \min_{\bar{\pi} \in P(\Pi)} \max_{\pi \in \Pi} \mathbb{E}_{\pi,\bar{\pi}} \left[ \sum_{t=0}^\infty \gamma^t r(s_t,a_t,\bar{a}_t) \bigg\vert s_0 = s \right], \end{aligned}\]
where
- \(\Pi\) is the set of stationary deterministic policies on \(A\).
- \(P(\Pi)\) denotes the set of stationary stochastic policies.
Note that policies which attain this value, \(\pi^\star\) and \(\bar{\pi}^\star\) for the two players, are said to be in Nash equilibrium. In this scenario, neither player can improve its outcome by deviating unilaterally, i.e. \(\forall \pi,\bar{\pi} \in P(\Pi),\; V^{\pi,\bar{\pi}^\star} \le V^\star \le V^{\pi^\star,\bar{\pi}}\).
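As a quick numeric illustration of the minimax equality over stochastic policies, consider a one-shot 2x2 matrix game (matching pennies; a toy of my own, not from the paper) and evaluate both orderings on a grid of mixed strategies.
import numpy as np

R = np.array([[1.0, -1.0],
              [-1.0, 1.0]])                 # payoff to the maximizing player

ps = np.linspace(0, 1, 1001)                # agent's probability of the first row
qs = np.linspace(0, 1, 1001)                # adversary's probability of the first column
agent_mix = np.stack([ps, 1 - ps], axis=1)
adv_mix = np.stack([qs, 1 - qs], axis=1)

V = agent_mix @ R @ adv_mix.T               # V[i, j] = expected payoff of the strategy pair
max_min = V.min(axis=1).max()               # agent commits first
min_max = V.max(axis=0).min()               # adversary commits first

print(max_min, min_max)                     # both are (approximately) the game value 0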
Probabilistic Action Robust MDP
Let \(\pi,\bar{\pi}\) be the policies of the agent and the adversary, respectively. Their probabilistic joint policy \(\pi_{P,\alpha}^{mix}(\pi,\bar{\pi})\) follows the agent's policy with probability \(1-\alpha\) and the adversary's policy with probability \(\alpha\):
\[\pi_{P,\alpha}^{mix}(\pi,\bar{\pi})(a \vert s) = (1-\alpha)\,\pi(a \vert s) + \alpha\,\bar{\pi}(a \vert s).\]
The value function of the joint policy is defined by:
\[V_{P,\alpha}^\pi = \min_{\bar{\pi} \in \Pi} \mathbb{E}_{\pi_{P,\alpha}^{mix}(\pi,\bar{\pi})} \left[ \sum_{t=0}^\infty \gamma^t r(s_t,a_t) \right], \quad a_t \sim \pi_{P,\alpha}^{mix}(\pi(s_t),\bar{\pi}(s_t)).\]
The optimal probabilistic robust policy is:
\[\pi_{P,\alpha}^* \in \arg\max_{\pi \in P(\Pi)} V_{P,\alpha}^\pi.\]
MDP Model for PR-MDP
By considering the probabilistic action, a class of models is implicitly defined, and the probabilistic robust policy is optimal w.r.t. the worst possible model in this class.
Define the following class of models:
\[P_\alpha = \{ (1-\alpha)P + \alpha P^\pi : \pi \in P(\Pi) \},\]
\[R_\alpha = \{ (1-\alpha)r + \alpha r^\pi : \pi \in P(\Pi) \}.\]
Then, the optimal policy is:
\[\pi_{P,\alpha}^* \in \arg\max_{\pi' \in \Pi} \min_{P \in P_\alpha,\, r \in R_\alpha} \mathbb{E}_{P,\pi'} \left[ \sum_{t=0}^\infty \gamma^t r(s_t,a_t) \right].\]
Use policy iteration schemes to solve for the optimal policy
- Probabilistic Robust PI
  - Given a fixed adversary strategy, it calculates the agent's optimal counter strategy (see the tabular sketch below).
  - It then updates the adversary to the 1-step greedy policy w.r.t. the value function of the mixture policy.
- Soft Probabilistic Robust PI
  - Instead of solving for the adversary exactly, it uses gradient information to update the adversary policy: it finds the valid policy with the highest correlation, i.e. inner product, with the direction of gradient descent and performs a step towards it.
- Algorithms of Probabilistic Robust PI
When the adversary update step size \(\eta = 1\), the two algorithms are completely equivalent.
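Below is a minimal tabular sketch of the two building blocks referenced above, on a randomly generated toy MDP of my own (the names evaluate and best_response are hypothetical): evaluating the \(\alpha\)-mixture of a deterministic agent policy and adversary policy, and computing the agent's optimal counter strategy against a fixed adversary.
import numpy as np

gamma, alpha = 0.9, 0.2
nS, nA = 3, 2
rng = np.random.default_rng(0)
P = rng.dirichlet(np.ones(nS), size=(nS, nA))    # transition kernel P[s, a, s']
r = rng.random((nS, nA))                         # reward r[s, a]

def evaluate(pi, pi_bar):
    # Value of the mixture policy (agent's action w.p. 1 - alpha, adversary's w.p. alpha),
    # with both policies deterministic and given as arrays of action indices.
    s_idx = np.arange(nS)
    P_mix = (1 - alpha) * P[s_idx, pi] + alpha * P[s_idx, pi_bar]
    r_mix = (1 - alpha) * r[s_idx, pi] + alpha * r[s_idx, pi_bar]
    return np.linalg.solve(np.eye(nS) - gamma * P_mix, r_mix)

def best_response(pi_bar, iters=500):
    # Optimal counter strategy against a fixed adversary: value iteration on the MDP
    # in which, w.p. alpha, the adversary's action is executed instead of the agent's.
    s_idx = np.arange(nS)
    V = np.zeros(nS)
    for _ in range(iters):
        Q = (1 - alpha) * (r + gamma * P @ V) \
            + alpha * (r[s_idx, pi_bar] + gamma * P[s_idx, pi_bar] @ V)[:, None]
        V = Q.max(axis=1)
    return Q.argmax(axis=1)

pi_bar = np.zeros(nS, dtype=int)   # an arbitrary fixed adversary
pi = best_response(pi_bar)
print(pi, evaluate(pi, pi_bar))    # counter strategy and its mixture-policy value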
Noisy Action Robust MDP
Instead of a stochastic perturbation in the policy space, we now consider a perturbation in the action space.
- For the probabilistic action robust MDP, a deterministic stationary optimal policy exists.
- For the noisy action robust MDP, the optimal policy may need to be a stochastic policy (see the one-step example below).
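A one-step numeric illustration of this distinction (a toy of my own, not the paper's example): actions lie in \([-1, 1]\), the reward is \(r(a) = |a|\), and the adversary perturbs the executed action via the noisy (convex-combination) model.
import numpy as np

alpha = 0.3
adv_grid = np.linspace(-1, 1, 201)       # candidate (deterministic) adversary actions

# Deterministic agent playing a = 1: the adversary can pull the executed action toward 0.
det_value = np.min(np.abs((1 - alpha) * 1.0 + alpha * adv_grid))

# Agent playing +1 or -1 with probability 1/2 each: no single perturbation hurts both.
stoch_value = np.min(0.5 * np.abs((1 - alpha) * 1.0 + alpha * adv_grid)
                     + 0.5 * np.abs(-(1 - alpha) * 1.0 + alpha * adv_grid))

print(det_value, stoch_value)            # 1 - 2*alpha = 0.4 vs. 1 - alpha = 0.7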
Define the noisy joint policy \(\pi_{N,\alpha}^{mix}(\pi,\bar{\pi})\): at state \(s\), draw \(a \sim \pi(\cdot \vert s)\) and \(\bar{a} \sim \bar{\pi}(\cdot \vert s)\) and execute the convex combination \((1-\alpha)a + \alpha\bar{a}\).
The value function is defined by:
\[V_{N,\alpha}^\pi = \min_{\bar{\pi} \in \Pi} \mathbb{E}_{\pi_{N,\alpha}^{mix}(\pi,\bar{\pi})} \left[ \sum_{t=0}^\infty \gamma^t r(s_t,a_t) \right], \quad a_t \sim \pi_{N,\alpha}^{mix}(\pi(s_t), \bar{\pi}(s_t)).\]
The optimal noisy robust policy is:
\[\pi_{N,\alpha}^* \in \arg\max_{\pi \in P(\Pi)} V_{N,\alpha}^\pi.\]
Use policy iteration to solve for the optimal policy
\[\pi_k \in \arg \max_{\pi \in \Pi} V_{\pi_{N,\alpha}^{mix}(\pi,\bar{\pi}_k)},\]
\[\bar{\pi}_{k+1} \in \arg \min_{\bar{\pi} \in P(\Pi)} \max_{\pi \in \Pi} \left[ r_{\pi_{N,\alpha}^{mix}(\pi,\bar{\pi})} + \gamma P_{\pi_{N,\alpha}^{mix}(\pi,\bar{\pi})} V_{\pi_{N,\alpha}^{mix}(\pi_k, \bar{\pi}_k)} \right],\]
where
\[r_{\pi_{N,\alpha}^{mix}(\pi,\bar{\pi})}(s) = \mathbb{E}_{a \sim \pi, \bar{a} \sim \bar{\pi}} [r(s, (1-\alpha)a + \alpha \bar{a})],\]
\[P_{\pi_{N,\alpha}^{mix}(\pi,\bar{\pi})}(s,s') = \mathbb{E}_{a \sim \pi, \bar{a} \sim \bar{\pi}} [P(s' \vert s, (1-\alpha)a + \alpha \bar{a})].\]
Applications in Continuous Control
- MuJoCo domains: robotic manipulation (continuous control problems).
- Probabilistic actor: models the occurrence of large abrupt forces (e.g. someone suddenly pushes the robot).
- Noisy actor: models mass uncertainty, i.e. the robot is heavier or lighter than expected.
Action Robust DDPG
Train two networks, the actor and the adversary, denoted by \(\mu_\theta\) and \(\bar{\mu}_{\bar{\theta}}\).
The probabilistic action robust joint policy is:
\[\pi_{P,\alpha}^{mix}(\mu_\theta, \bar{\mu}_{\bar{\theta}})(a \vert s) = (1-\alpha)\,\delta(a - \mu_\theta(s)) + \alpha\,\delta(a - \bar{\mu}_{\bar{\theta}}(s)),\]
where \(\delta(\cdot)\) is the Dirac delta function:
\[\delta(x-x_0) = 0 \ \text{ for } x \ne x_0, \qquad \delta(x-x_0) = \infty \ \text{ for } x = x_0,\]
\[\int \delta(x-x_0)\,dx = 1, \qquad \int f(x)\,\delta(x-x_0)\,dx = f(x_0).\]
The noisy action robust joint policy is:
\[\pi_{N,\alpha}^{mix}(\mu_\theta, \bar{\mu}_{\bar{\theta}})(a \vert s) = \delta\big(a - (1-\alpha)\mu_\theta(s) - \alpha\bar{\mu}_{\bar{\theta}}(s)\big).\]
The performance objective is defined by:
\[J(\pi(\mu_\theta, \bar{\mu}_{\bar{\theta}})) = \mathbb{E}_{s \sim p^\pi} [V_\pi(s)].\]
The gradients with respect to the actor and adversary parameters are:
\[\nabla_\theta J(\pi(\mu_\theta, \bar{\mu}_{\bar{\theta}})) = (1-\alpha)\, \mathbb{E}_{s \sim p^\pi}[\nabla_\theta \mu_\theta(s)\,\nabla_a Q_\pi(s,a)],\]
\[\nabla_{\bar{\theta}} J(\pi(\mu_\theta, \bar{\mu}_{\bar{\theta}})) = \alpha\, \mathbb{E}_{s \sim p^\pi}[\nabla_{\bar{\theta}} \bar{\mu}_{\bar{\theta}}(s)\, \nabla_{\bar{a}} Q_\pi(s,\bar{a})],\]
where
- for the probabilistic action case, \(a = \mu_\theta(s)\) and \(\bar{a} = \bar{\mu}_{\bar{\theta}}(s)\);
- for the noisy action case, \(a = \bar{a} = (1-\alpha)\mu_\theta(s) + \alpha \bar{\mu}_{\bar{\theta}}(s)\) (checked numerically below).
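A small autograd check (a sketch under assumed toy functions; mu, mu_bar, and Q below are hand-made scalar functions, not the paper's networks) that differentiating through the noisy mixed action reproduces the \((1-\alpha)\) chain-rule factor above.
import torch

torch.manual_seed(0)
alpha = 0.3
s = torch.tensor(0.7)
theta = torch.tensor(0.5, requires_grad=True)        # actor parameter
theta_bar = torch.tensor(-0.2, requires_grad=True)   # adversary parameter

def mu(s): return torch.tanh(theta * s)              # toy actor mu_theta(s)
def mu_bar(s): return torch.tanh(theta_bar * s)      # toy adversary
def Q(s, a): return -(a - 0.4 * s) ** 2              # toy critic

# Autograd through the mixed action ...
a_mix = (1 - alpha) * mu(s) + alpha * mu_bar(s)
grad_autograd, = torch.autograd.grad(Q(s, a_mix), theta)

# ... equals (1 - alpha) * dmu/dtheta * dQ/da evaluated at a = a_mix.
a = a_mix.detach().requires_grad_(True)
dQ_da, = torch.autograd.grad(Q(s, a), a)
dmu_dtheta, = torch.autograd.grad(mu(s), theta)
print(grad_autograd.item(), ((1 - alpha) * dmu_dtheta * dQ_da).item())  # the two values agree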
- Algorithm of Action Robust DDPG
Code of AR-DDPG
Select the action
- NR-MDP: \(a_t = (1-\alpha)\pi_\theta(s_t) + \alpha\bar{\pi}_{\bar{\theta}}(s_t)\).
- PR-MDP: \(a_t = \pi_\theta(s_t)\) w.p. \(1-\alpha\), or \(a_t = \bar{\pi}_{\bar{\theta}}(s_t)\) w.p. \(\alpha\).
- MDP: \(a_t = \pi_\theta(s_t)\).
def select_action(self, state, action_noise=None, param_noise=None, mdp_type='mdp'):
    state = normalize(Variable(state).to(self.device), self.obs_rms, self.device)
    if mdp_type != 'mdp':
        if mdp_type == 'nr_mdp':
            # NR-MDP: execute the convex combination of actor and adversary actions.
            if param_noise is not None:
                mu = self.actor_perturbed(state)
            else:
                mu = self.actor(state)
            mu = mu.data
            if action_noise is not None:
                mu += self.Tensor(action_noise()).to(self.device)
            mu = mu.clamp(-1, 1) * (1 - self.alpha)
            if param_noise is not None:
                adv_mu = self.adversary_perturbed(state)
            else:
                adv_mu = self.adversary(state)
            adv_mu = adv_mu.data.clamp(-1, 1) * self.alpha
            mu += adv_mu
        else:  # mdp_type == 'pr_mdp':
            # PR-MDP: the actor acts w.p. 1 - alpha, the adversary w.p. alpha.
            if np.random.rand() < (1 - self.alpha):
                if param_noise is not None:
                    mu = self.actor_perturbed(state)
                else:
                    mu = self.actor(state)
                mu = mu.data
                if action_noise is not None:
                    mu += self.Tensor(action_noise()).to(self.device)
                mu = mu.clamp(-1, 1)
            else:
                if param_noise is not None:
                    mu = self.adversary_perturbed(state)
                else:
                    mu = self.adversary(state)
                mu = mu.data.clamp(-1, 1)
    else:
        # Plain MDP: only the actor acts.
        if param_noise is not None:
            mu = self.actor_perturbed(state)
        else:
            mu = self.actor(state)
        mu = mu.data
        if action_noise is not None:
            mu += self.Tensor(action_noise()).to(self.device)
        mu = mu.clamp(-1, 1)
    return mu
Store transition
\[(s_t,a_t,r_t,s_{t+1}) \in B.\]
def store_transition(self, state, action, mask, next_state, reward):
    B = state.shape[0]
    for b in range(B):
        self.memory.push(state[b], action[b], mask[b], next_state[b], reward[b])
        if self.normalize_observations:
            self.obs_rms.update(state[b].cpu().numpy())
        if self.normalize_returns:
            self.ret = self.ret * self.gamma + reward[b]
            self.ret_rms.update(np.array([self.ret]))
            if mask[b] == 0:  # if terminal is True
                self.ret = 0
Sample batch from replay buffer
transitions = self.memory.sample(batch_size)
batch = Transition(*zip(*transitions))
if mdp_type != 'mdp':
    robust_update_type = 'full'
elif exploration_method != 'mdp':
    robust_update_type = 'adversary'
else:
    robust_update_type = None
state_batch = normalize(Variable(torch.stack(batch.state)).to(self.device), self.obs_rms, self.device)
action_batch = Variable(torch.stack(batch.action)).to(self.device)
reward_batch = normalize(Variable(torch.stack(batch.reward)).to(self.device).unsqueeze(1), self.ret_rms, self.device)
mask_batch = Variable(torch.stack(batch.mask)).to(self.device).unsqueeze(1)
next_state_batch = normalize(Variable(torch.stack(batch.next_state)).to(self.device), self.obs_rms, self.device)
Update parameters of actor & critic & adversary
def update_robust(self, state_batch, action_batch, reward_batch, mask_batch, next_state_batch, adversary_update,
                  mdp_type, robust_update_type):
    # TRAIN CRITIC
    if robust_update_type == 'full':
        if mdp_type == 'nr_mdp':
            next_action_batch = (1 - self.alpha) * self.actor_target(next_state_batch) \
                + self.alpha * self.adversary_target(next_state_batch)
            next_state_action_values = self.critic_target(next_state_batch, next_action_batch)
        else:  # mdp_type == 'pr_mdp':
            next_action_actor_batch = self.actor_target(next_state_batch)
            next_action_adversary_batch = self.adversary_target(next_state_batch)
            next_state_action_values = self.critic_target(next_state_batch, next_action_actor_batch) * (1 - self.alpha) \
                + self.critic_target(next_state_batch, next_action_adversary_batch) * self.alpha

        expected_state_action_batch = reward_batch + self.gamma * mask_batch * next_state_action_values

        self.critic_optim.zero_grad()
        state_action_batch = self.critic(state_batch, action_batch)
        value_loss = F.mse_loss(state_action_batch, expected_state_action_batch)
        value_loss.backward()
        self.critic_optim.step()
        value_loss = value_loss.item()
    else:
        value_loss = 0

    if adversary_update:
        # TRAIN ADVERSARY
        self.adversary_optim.zero_grad()
        if mdp_type == 'nr_mdp':
            with torch.no_grad():
                real_action = self.actor_target(next_state_batch)
            action = (1 - self.alpha) * real_action + self.alpha * self.adversary(next_state_batch)
            adversary_loss = self.critic(state_batch, action)
        else:  # mdp_type == 'pr_mdp'
            action = self.adversary(next_state_batch)
            adversary_loss = self.critic(state_batch, action) * self.alpha
        adversary_loss = adversary_loss.mean()
        adversary_loss.backward()
        self.adversary_optim.step()
        adversary_loss = adversary_loss.item()
        policy_loss = 0
    else:
        if robust_update_type == 'full':
            # TRAIN ACTOR
            self.actor_optim.zero_grad()
            if mdp_type == 'nr_mdp':
                with torch.no_grad():
                    adversary_action = self.adversary_target(next_state_batch)
                action = (1 - self.alpha) * self.actor(next_state_batch) + self.alpha * adversary_action
                policy_loss = -self.critic(state_batch, action)
            else:  # mdp_type == 'pr_mdp':
                action = self.actor(next_state_batch)
                policy_loss = -self.critic(state_batch, action) * (1 - self.alpha)
            policy_loss = policy_loss.mean()
            policy_loss.backward()
            self.actor_optim.step()
            policy_loss = policy_loss.item()
            adversary_loss = 0
        else:
            policy_loss = 0
            adversary_loss = 0

    return value_loss, policy_loss, adversary_loss
Soft update of target networks
soft_update(self.actor_target, self.actor, self.tau)
soft_update(self.adversary_target, self.adversary, self.tau)
soft_update(self.critic_target, self.critic, self.tau)

def soft_update(target, source, tau):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
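A tiny usage sketch of the Polyak averaging above, with throwaway nn.Linear modules (names are mine, not from the repository): after the call, each target parameter moves a fraction tau toward the corresponding source parameter.
import torch
import torch.nn as nn

torch.manual_seed(0)
source, target = nn.Linear(2, 2), nn.Linear(2, 2)
before = target.weight.data.clone()
soft_update(target, source, tau=0.005)
# Each target weight is now (1 - tau) * old_target + tau * source.
print(torch.allclose(target.weight.data, before * 0.995 + source.weight.data * 0.005))  # True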