Implicit Value Regularization (IVR)#

Why is Offline RL hard?#

Can we solve for an optimal policy using solely in-sample learning?#

Solving a behavior-regularized MDP#

We can directly solve for the optimal policy \(\pi^*\) in the class of MDPs where an f-divergence penalty (measured against a given reference policy \(\mu\)) is subtracted from the reward at each timestep

\[\begin{split} \begin{align} & \max_{\pi} \mathbb{E}_{\tau \sim \pi}[r(\tau)] - \alpha D_{f}(\pi \| \mu) \\ & \equiv \max_{\pi} \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^{\infty} \gamma^t \left( r(s_t, a_t) - \alpha \cdot f \left( \frac{\pi(a_t|s_t)}{\mu(a_t|s_t)} \right) \right) \right] \\ & \equiv \max_{\pi} \mathbb{E}_{s \sim d^\pi}[V(s)] \\ & \equiv \max_{\pi} \mathbb{E}_{s \sim d^{\pi}, a \sim \pi(\cdot|s)}\left[Q(s,a) - \alpha \cdot f\left(\frac{\pi(a|s)}{\mu(a|s)}\right)\right] \\ \end{align} \end{split}\]

There are only two types of constraints in this optimization, both constraining \(\pi\) to be a valid probability distribution

  1. \(\sum_a \pi(a|s)=1 \quad \forall s\)

  2. \(\pi(a|s) \geq 0 \quad \forall a,s\)

We write the Lagrangian with dual variables \(u(s)\) for the first set of constraints and \(\beta(a|s)\) for the second

\[L(\pi, \beta, u) = \underbrace{\sum_s d^{\pi}(s) \sum_a \pi(a|s) \left(Q(s,a) - \alpha f\left(\frac{\pi(a|s)}{\mu(a|s)}\right)\right)}_{J(\pi)} - \sum_s d^{\pi}(s) \bigg[ \underbrace{u(s) \bigg(\sum_a \pi(a|s) - 1\bigg)}_{=0} - \underbrace{\sum_a \beta(a|s) \pi(a|s)}_{\geq 0}\bigg]\]

In a primal maximization problem, the Lagrangian should be an upper bound on the primal optimal value: for feasible \(\pi\) and dual-feasible \((\beta, u)\), \(L(\pi,\beta,u) \geq J(\pi)\). This is why the constraints are subtracted with these signs: at feasibility the \(u\)-term vanishes and the \(\beta\)-term is nonnegative, while any violation decreases the objective. This is apparent when we write the relation, where \(C\) is the feasible set:

\[ J(\pi^*) \leq \max_{\pi \in C} L(\pi, \beta, u) \leq \max_\pi L(\pi, \beta, u) := g(\beta,u)\]

Instead of minimizing the dual function \(g(\beta,u)\) (which is actually a convex problem!), we can use the KKT conditions to derive the optimal solution. Famously, the KKT conditions state that if strong duality holds and we find a tuple \((\pi,\beta,u)\) that satisfies all four conditions, then \(\pi\) is primal optimal and \((\beta, u)\) is dual optimal. They are:

Primal feasibility: \(\sum_a \pi(a|s) = 1 \quad \forall s \text{ and } \pi(a|s) \geq 0 \quad \forall a,s\)

Dual feasibility: \(u \text{ is unconstrained}, \beta(a|s) \geq 0 \quad \forall a,s\)

Complementary slackness: \(\beta(a|s) \cdot \pi(a|s) = 0 \quad \forall a,s\)

Stationarity: \(\frac{\partial L}{\partial \pi(a|s)} = 0\)

\[\frac{\partial}{\partial \pi(a|s)} \left[ \sum_{s'} d^{\pi}(s') \cdot (\text{local objective at } s') \right]\]

Notice that this will involve the product rule on every term in the summation.

\(d^\pi\) depends on \(\pi\): changing the policy at state \(s\) affects the visit frequency of every other state \(s'\), i.e. the stationary distribution of the entire system changes.

To see this, let’s define the “local objective” at state \(s'\) as \(J(s', \pi)\):

\[J(s', \pi) = \left[ \sum_{a'} \pi(a'|s') \left(Q(s',a') - \alpha f\left(\frac{\pi(a'|s')}{\mu(a'|s')}\right)\right) - u(s') \left(\sum_{a'} \pi(a'|s') - 1\right) + \dots \right]\]

The stationarity condition can be restated as

\[\frac{\partial L}{\partial \pi(a|s)} = \frac{\partial}{\partial \pi(a|s)} \left[ \sum_{s'} d^{\pi}(s') \cdot J(s', \pi) \right] = 0\]

Applying the product rule to each term \(s'\):

\[\frac{\partial L}{\partial \pi(a|s)} = \sum_{s'} \left[ \underbrace{\left( \frac{\partial d^{\pi}(s')}{\partial \pi(a|s)} \right) \cdot J(s', \pi)}_{\text{Term 1: Global Impact}} + \underbrace{d^{\pi}(s') \cdot \left( \frac{\partial J(s', \pi)}{\partial \pi(a|s)} \right)}_{\text{Term 2: Local Impact}} \right] = 0\]

If we ignore this dependence and treat \(d^{\pi}(s')\) as a constant with respect to \(\pi(a|s)\), then \(\frac{\partial d^{\pi}(s')}{\partial \pi(a|s)} = 0\) and Term 1 vanishes.

Term 2 collapses to a single term (the one with \(s' = s\)) because the partial derivative of \(J(s',\pi)\) with respect to \(\pi(a|s)\) is zero unless \(s' = s\)

\[\frac{\partial L}{\partial \pi(a|s)} = d^\pi(s) \cdot \frac{ \partial J(s, \pi)}{\partial \pi(a|s)} \]

Then, assuming the Markov chain induced by \(\pi\) is irreducible (every state is reachable from every other state, though not necessarily in one step), we have \(d^\pi(s) > 0 \quad \forall s\), so we can divide both sides of the equation by \(d^\pi(s)\).

\[\frac{\partial L}{\partial \pi(a|s)} = \frac{ \partial J(s, \pi)}{\partial \pi(a|s)} = 0\]

We compute \(\frac{\partial J(s,\pi)}{\partial \pi(a|s)}\) term by term.

\[ \frac{\partial}{\partial \pi(a|s)} \left[ \sum_{a'} \pi(a'|s) Q(s,a') \right] = Q(s,a)\]

\[\begin{split} \begin{align} & \frac{\partial}{\partial \pi(a|s)} \left[-\alpha \sum_{a'} \pi(a'|s) f\left(\frac{\pi(a'|s)}{\mu(a'|s)}\right)\right]\\ &=-\alpha \left[ \frac{\partial \pi(a|s)}{\partial \pi(a|s)} \cdot f\left(\frac{\pi(a|s)}{\mu(a|s)}\right) + \pi(a|s) \cdot \frac{\partial f(\frac{\pi(a|s)}{\mu(a|s)})}{\partial \pi(a|s)} \right]\\ &= -\alpha \cdot \underbrace{ \left(f\left(\frac{\pi(a|s)}{\mu(a|s)}\right) + \frac{\pi(a|s)}{\mu(a|s)}f'\left(\frac{\pi(a|s)}{\mu(a|s)}\right)\right) }_{h'_f(x) \text{ where } h_f(x) := x f(x),\; x = \frac{\pi}{\mu}} \\ & = -\alpha \cdot h'_f\left(\frac{\pi(a|s)}{\mu(a|s)}\right) \end{align} \end{split}\]

\[\frac{\partial}{\partial \pi(a|s)} \left[ -u(s) \left(\sum_{a'} \pi(a'|s) - 1\right) \right] = -u(s)\]

\[\frac{\partial}{\partial \pi(a|s)} \left[ \sum_{a'} \beta(a'|s) \pi(a'|s) \right] = \beta(a|s)\]

Summing the terms, we get

\[Q(s,a) - \alpha h'_f\left(\frac{\pi(a|s)}{\mu(a|s)}\right) - u(s) + \beta(a|s) = 0\]

Finding the form of \(\pi^*\)#

We rearrange the stationarity condition to obtain:

\[\pi(a|s) = \mu(a|s) \cdot g_f\left( \frac{1}{\alpha} \left( Q(s,a) - u(s) + \beta(a|s) \right) \right)\]

where \(g_f\) is the inverse function of \(h'_f\).

We can eliminate one of the dual variables \(\beta\) by leveraging complementary slackness.

The complementary slackness condition, \(\beta(a|s) \cdot \pi(a|s) = 0\), creates two distinct cases for any action \(a\)

Case A (Action is taken, \(\pi(a|s) > 0\)): This forces \(\beta(a|s) = 0\).

Case B (Action is ignored, \(\pi(a|s) = 0\)): This implies that the value of \(g_f\) at the \(\beta\)-free argument, \(g_f\left(\frac{1}{\alpha}(Q(s,a) - u(s))\right)\), must be non-positive.

The max function captures this: it selects the calculated probability if it’s positive (Case A) and selects 0 if it’s negative (Case B).

\[\pi(a|s) = \mu(a|s) \cdot \max\left\{ g_f\left( \frac{1}{\alpha} (Q(s,a) - u(s)) \right), 0 \right\}\]

Finding \(u^*\)#

The optimal dual variable \(u^*(s)\) is found by substituting the optimal policy form into the primal feasibility (normalization) constraint.

\[\sum_a \left[ \mu(a|s) \cdot \max\left\{ g_f\left( \frac{1}{\alpha} (Q(s,a) - u(s)) \right), 0 \right\} \right] = \mathbb E_{a\sim\mu} \left[ \max\left\{ g_f\left( \frac{1}{\alpha} (Q(s,a) - u(s)) \right), 0 \right\} \right] = 1\]

This is a one-dimensional equation in \(u(s)\), which we solve for each state. We will see exactly how to recover this for specific choices of \(f\) in the next section.
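A minimal sketch of such a per-state root-finder, assuming a discrete action space and some increasing \(g_f\) (the toy Q-values, behavior probabilities, and the specific \(g_f\) below are illustrative choices, not from the paper):

```python
import numpy as np

def solve_u(q, mu, g_f, alpha, iters=200):
    """Solve E_{a~mu}[max(g_f((q - u)/alpha), 0)] = 1 for u by bisection.
    For an increasing g_f the left-hand side is non-increasing in u,
    so the root is unique and bracketed by a wide interval around q."""
    def lhs(u):
        return float(np.sum(mu * np.maximum(g_f((q - u) / alpha), 0.0)))
    lo, hi = q.min() - 40.0 * alpha, q.max() + 40.0 * alpha
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if lhs(mid) > 1.0:
            lo = mid  # too much probability mass: raise u
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Toy values; this g_f is just a concrete increasing function
# (it is the inverse of h'_f when f = log).
q = np.array([1.0, 0.3, -0.5])
mu = np.array([0.5, 0.3, 0.2])
u = solve_u(q, mu, lambda y: np.exp(y - 1.0), alpha=1.0)
pi = mu * np.maximum(np.exp((q - u) - 1.0), 0.0)  # resulting pi(.|s)
```

The resulting `pi` sums to one by construction of `u`.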

Finding \(V^*\)#

Once the optimal policy form \(\pi^*\) is found by solving for the unique Lagrange multiplier \(u^* (s)\) (which enforces the normalization \(\sum_a \pi^* (a|s) = 1\)), we can derive the optimal state value \(V^* (s)\). The optimal policy corresponds to the optimal state value

\[\begin{split} \begin{align} V^* (s) &= \sum_a \pi^* (a|s) \left( Q^* (s,a) - \alpha f\left(\frac{\pi^* (a|s)}{\mu(a|s)}\right) \right)\\ \end{align} \end{split}\]

At optimality, the primal optimum \(\pi^*\) and dual optimum \((u^*, \beta^*)\) satisfy the KKT conditions. We can simplify the stationarity condition using complementary slackness: for any action where \(\pi^* (a|s) > 0\), we have \(\beta^*(a|s) = 0\), which gives

\[Q^* (s,a) = u^* (s) + \alpha h'_f\left(\frac{\pi^* (a|s)}{\mu(a|s)}\right)\]

Substitute the definition of \(h'_f\) into \(Q^* (s,a)\).

\[Q^*(s,a) = u^*(s) + \alpha \left[ f\left(\frac{\pi^*(a|s)}{\mu(a|s)}\right) + \frac{\pi^*(a|s)}{\mu(a|s)} f'\left(\frac{\pi^*(a|s)}{\mu(a|s)}\right) \right]\]

Substitute \(Q^* \) back into \( V^* \)

\[\begin{split} \begin{align} V^*(s) &= \sum_a \pi^*(a|s) \left( \left[ u^*(s) + \alpha f(\dots) + \alpha \frac{\pi^* (a|s)}{\mu(a|s)} f'(\dots) \right] - \alpha f(\dots) \right) \\ &= \sum_a \pi^*(a|s) \left( u^*(s) + \alpha \frac{\pi^*(a|s)}{\mu(a|s)} f'\left(\frac{\pi^*(a|s)}{\mu(a|s)}\right) \right)\\ &= \sum_a \pi^*(a|s) u^*(s) + \sum_a \pi^*(a|s) \left( \alpha \frac{\pi^*(a|s)}{\mu(a|s)} f'(\dots) \right) \\ &= u^*(s) \sum_a \pi^*(a|s) + \alpha \sum_a \frac{\pi^*(a|s)^2}{\mu(a|s)} f'(\dots) \\ &= u^*(s) + \alpha \sum_a \mu(a|s) \left[ \left(\frac{\pi^*(a|s)}{\mu(a|s)}\right)^2 f'\left(\frac{\pi^*(a|s)}{\mu(a|s)}\right) \right]\\ & = u^*(s) + \alpha \mathbb{E}_{a \sim \mu} \left[ \left(\frac{\pi^*(a|s)}{\mu(a|s)}\right)^2 f'\left(\frac{\pi^*(a|s)}{\mu(a|s)}\right) \right] \end{align} \end{split}\]

Theorem 1#

In the behavior-regularized MDP, we have the following for any state and action

\[Q^* (s,a) = r(s,a) + \mathbb E_{s'\sim T(\cdot|s,a)} [V^* (s')]\]
\[\pi^* (a|s) = \mu(a|s) \cdot \max\left\{ g_f\left( \frac{1}{\alpha} (Q^* (s,a) - u^* (s)) \right), 0 \right\}\]
\[V^* (s)= u^*(s) + \alpha \mathbb{E}_{a \sim \mu} \left[ \left(\frac{\pi^*(a|s)}{\mu(a|s)}\right)^2 f'\left(\frac{\pi^*(a|s)}{\mu(a|s)}\right) \right]\]

where \(u^* (s)\) ensures \(\sum_a \pi^* (a|s)=1\)

Instantiating a practical algorithm#

In offline RL, in order to completely avoid out-of-distribution actions, we want a zero-forcing support constraint: \(\mu(a|s)=0 \implies \pi(a|s)=0\). This is given by the \(\alpha\)-divergence, a subfamily of the \(f\)-divergences (note this divergence parameter \(\alpha\) is distinct from the regularization strength \(\alpha\) above), which takes the following form for \(\alpha \notin \{0, 1\}\)

\[D_\alpha(\mu,\pi) = \frac{1}{\alpha (\alpha-1)}\mathbb E_\pi \left[\left(\frac{\pi}{\mu}\right)^{-\alpha}-1\right]\]

Specifically, when \(\alpha \leq 0\), the divergence enforces this zero-forcing constraint and is 'mode-seeking'.

Sparse Q-Learning#

\(\alpha=-1\) yields the \(\chi^2\)-divergence, where \(f(x) = x-1 \implies f'(x) = 1\) (recall the penalty is \(\mathbb E_\pi[f(\pi/\mu)]\), and \(\mathbb E_\pi\left[\frac{\pi}{\mu} - 1\right] = \chi^2(\pi \| \mu)\))

And by definition:

\[h'_f(x) = f(x) + xf'(x) = (x-1) + x(1) = 2x - 1\]
\[g_f(y) = \frac{1}{2}y + \frac{1}{2}\]

With this we can define the relevant optimization objectives

  1. Substitute \(g_f(y)\) into general policy form from Theorem 1:

    \[\pi^*(a|s) = \mu(a|s) \cdot \max\left\{ \frac{1}{2} + \frac{Q^*(s,a) - U^*(s)}{2\alpha}, 0 \right\}\]
  2. The optimal Lagrange multiplier \(U^* (s)\) is the unique value such that \(\pi^* \) satisfies the normalization constraint \(\sum_a \pi^* (a|s) = 1\)

    \[\mathbb{E}_{a \sim \mu} \left[ \max\left\{ \frac{1}{2} + \frac{Q^*(s,a) - U^*(s)}{2\alpha}, 0 \right\} \right] = 1\]
  3. Substitute the derivative \(f'(x) = 1\) into general form for \(V^*(s)\) from Theorem 1:

    \[V^*(s) = U^*(s) + \alpha \mathbb{E}_{a \sim \mu} \left[ \left(\frac{\pi^*(a|s)}{\mu(a|s)}\right)^2 \right]\]
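To make the "sparse" in Sparse Q-Learning concrete, here is a small numerical sketch (toy Q-values and a uniform \(\mu\), not from the paper) that solves the normalization constraint by bisection; low-value actions receive exactly zero probability:

```python
import numpy as np

alpha = 0.1                          # regularization strength (toy value)
q = np.array([1.0, 0.2, 0.1, 0.0])   # toy Q*(s, a) for 4 actions
mu = np.full(4, 0.25)                # uniform behavior policy

def pi_of_u(u):
    # pi*(a|s) = mu(a|s) * max{1/2 + (Q - u)/(2 alpha), 0}
    return mu * np.maximum(0.5 + (q - u) / (2 * alpha), 0.0)

# The total mass is decreasing in u, so bisect on sum_a pi(a|s) = 1.
lo, hi = q.min() - 10.0, q.max() + 10.0
for _ in range(200):
    mid = 0.5 * (lo + hi)
    if pi_of_u(mid).sum() > 1.0:
        lo = mid
    else:
        hi = mid
u_star = 0.5 * (lo + hi)

pi_star = pi_of_u(u_star)            # low-Q actions get weight exactly 0
v_star = u_star + alpha * np.sum(mu * (pi_star / mu) ** 2)
```

With this small \(\alpha\), almost all mass lands on the best action and the two lowest-value actions are clipped to exactly zero by the max.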

SQL-U Derivation#

\(U^*\) is not computable#

Unfortunately, \(U^*(s)\) is not tractable to compute from this expectation, so the authors construct a clever minimization problem whose solution \(U^*\) exactly satisfies the probability constraint.

\[L(U) = \mathbb{E}_{a \sim \mu} \left[ \left(\max\left\{ \frac{1}{2} + \frac{Q^*(s,a) - U(s)}{2\alpha}, 0 \right\}\right)^2 \right] + \frac{U(s)}{\alpha}\]

The solution \(U^*\) is precisely the point where \(\nabla_U L(U) = 0\).

To see this, define \(X(a, U) = \frac{1}{2} + \frac{Q^*(s,a) - U(s)}{2\alpha}\) for clarity, so

\[L(U) = \mathbb{E}_{a \sim \mu} \left[ (\max\{X(a, U), 0\})^2 \right] + \frac{U(s)}{\alpha}\]
\[\frac{dL(U)}{dU(s)} = \frac{d}{dU(s)} \left( \mathbb{E}_{a \sim \mu} \left[ (\max\{X(a,U), 0\})^2 \right] + \frac{U(s)}{\alpha} \right)\]

We can move the derivative inside the expectation and differentiate term by term:

\[\frac{dL(U)}{dU(s)} = \mathbb{E}_{a \sim \mu} \left[ \frac{d}{dU(s)} (\max\{X, 0\})^2 \right] + \frac{d}{dU(s)} \left( \frac{U(s)}{\alpha} \right)\]

Applying the chain rule to the first term,

\[\begin{split} \begin{align} \frac{d}{dU(s)} (\max\{X, 0\})^2 & = 2 \max\{X, 0\} \cdot \frac{d}{dU(s)} \max\{X, 0\} \\ & = 2 \max\{X, 0\} \cdot 1(X(a,U)>0) \cdot \frac{d}{dU(s)} X(a,U)\\ & = 2 \max\{X, 0\} \cdot\left( -\frac{1}{2\alpha} \right)\\ & = -\frac{1}{\alpha} \max\{X, 0\}\\ \end{align} \end{split}\]

where we used \(\max\{X,0\} \cdot 1(X>0) = \max\{X,0\}\).

Substituting the definition of \(X(a,U)\) back in and adding the derivative of the second term, the first-order condition at \(U^*\) reads

\[\mathbb E_{a\sim\mu} \left[-\frac{1}{\alpha}\max\left\{\frac{1}{2} + \frac{Q^*(s,a) - U^*(s)}{2\alpha},0\right\}\right] + \frac{1}{\alpha} = 0\]

Subtract \(\frac{1}{\alpha}\) from both sides and multiply by \(-\alpha\) to recover the probability constraint

\[\mathbb{E}_{a \sim \mu} \left[ \max\left\{ \frac{1}{2} + \frac{Q^*(s,a) - U^*(s)}{2\alpha}, 0 \right\} \right] = 1\]
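To verify this numerically (all values below are invented toy numbers), we can run gradient descent on \(L(U)\) and check that the probability constraint holds at the minimizer:

```python
import numpy as np

alpha = 1.0
q = np.array([0.5, 0.0, -0.5])   # toy Q*(s, a)
mu = np.array([0.4, 0.4, 0.2])   # behavior policy probabilities

def constraint(u):
    # E_{a~mu}[max{1/2 + (Q - u)/(2 alpha), 0}]; equals 1 at U*
    return np.sum(mu * np.maximum(0.5 + (q - u) / (2 * alpha), 0.0))

def grad_L(u):
    # dL/dU = (1 - E_mu[max{...}]) / alpha, as derived above
    return (1.0 - constraint(u)) / alpha

u = 0.0
for _ in range(5000):
    u -= 0.1 * grad_L(u)         # plain gradient descent
```

For these numbers all three actions stay active, so the constraint reduces to \(\frac{1}{2} + \frac{\mathbb E_\mu[Q] - u}{2\alpha} = 1\), giving \(U^* = -0.9\).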

\(\pi^*\) is not computable either#

Similarly, \(\pi^*\) cannot be computed from our closed-form expression because we don't have access to \(\mu\). Instead, we can do simple behavior cloning (BC), which minimizes the forward KL divergence \(KL(\pi^*\|\pi)\)

\[\begin{split} \begin{align} & \min_\pi \mathbb E_{a\sim\pi^*} \left[\log \frac{\pi^*(a|s)}{\pi(a|s)}\right] \\ & \equiv \min_\pi -\mathbb E_{a\sim\pi^*} \left[\log \pi(a|s)\right] \\ & \equiv \max_\pi \mathbb E_{a\sim\pi^*} \left[\log \pi(a|s)\right]\\ & \equiv \max_\pi \mathbb{E}_{a \sim \mu} \left[ \max\left\{ \frac{1}{2} + \frac{Q^*(s,a) - U^*(s)}{2\alpha}, 0 \right\}\log \pi(a|s)\right] \end{align} \end{split}\]

where the last step uses importance sampling, \(\mathbb E_{a\sim\pi^*}[\cdot] = \mathbb E_{a\sim\mu}\left[\frac{\pi^*(a|s)}{\mu(a|s)} \cdot\right]\), together with the closed form of \(\frac{\pi^*}{\mu}\). The result is exactly a weighted log-likelihood, i.e. weighted BC

Algorithm 1 (Sparse Q-Learning-U)

Require Dataset \(\mathcal{D}\), regularization strength \(\alpha\)

Initialize Networks \(U_{\kappa}\), \(V_{\psi}\), \(\pi_{\theta}\), \(Q_{\phi}\), and target network \(Q_{\phi'}\).

for \(t = 1, 2, \dots, N\) do:

  1. Sample a mini-batch of \(M\) transitions \(\mathcal B = \{(s_i, a_i, r_i, s'_i)\}_{i=1}^M \sim \mathcal{D}\)

  2. Update V-Network (\(V_\psi\)):

    1. Compute target \(y_{V,i}\), estimated with the single sample \((s_i, a_i)\): \(y_{V,i} = U_\kappa(s_i) + \alpha \left( \max\left\{ \frac{1}{2} + \frac{Q_{\phi'}(s_i,a_i) - U_\kappa(s_i)}{2\alpha}, 0 \right\} \right)^2\)

    2. Minimize MSE loss: \(\mathcal{L}_V(\psi) = \frac{1}{M} \sum_{i=1}^M \left( V_\psi(s_i) - y_{V,i} \right)^2\)

  3. Update U-Network (\(U_\kappa\)): \(\mathcal{L}_U(\kappa) = \frac{1}{M} \sum_{i=1}^M \left[ \left(\max\left\{ \frac{1}{2} + \frac{Q_{\phi'}(s_i,a_i) - U_\kappa(s_i)}{2\alpha}, 0 \right\}\right)^2 + \frac{U_\kappa(s_i)}{\alpha} \right]\)

  4. Update Q-Network (\(Q_\phi\)):

    1. Compute standard Bellman target using the V-network: \(y_{Q,i} = r_i + \gamma V_\psi(s'_i)\)

    2. Minimize MSE loss: \(\mathcal{L}_Q(\phi) = \frac{1}{M} \sum_{i=1}^M \left( Q_\phi(s_i, a_i) - y_{Q,i} \right)^2\)

  5. Update Target Networks: \(\phi' \leftarrow \lambda\phi + (1 - \lambda)\phi'\)

  6. Update Policy Network (\(\pi_\theta\)) (Extraction):

    1. Calculate detached weights \(w_i\) for standard weighted regression (behavior cloning): \(w_i = \max\left\{ \frac{1}{2} + \frac{Q_{\phi'}(s_i,a_i) - U_\kappa(s_i)}{2\alpha}, 0 \right\}\)

    2. Minimize negative weighted log-likelihood: \(\mathcal{L}_\pi(\theta) = - \frac{1}{M} \sum_{i=1}^M w_i \log \pi_\theta(a_i|s_i)\)

Sparse Q-Learning Derivation#

To eliminate the computation and storage of a separate U-network, we notice the following. Since the \(\alpha\)-divergence is mode-seeking, \(\pi^*\) suffers a huge divergence penalty if it places any density where \(\mu\) places none. This is exactly the zero-forcing constraint; as a result, for actions \(a\) sampled under \(\pi^*\), \(\pi^*(a|s) \approx \mu(a|s)\).

Optimization to obtain \(V^*\)#

\[\implies \mathbb E_{a \sim \pi^*(\cdot|s)}\left[\frac{\pi^*(a|s)}{\mu(a|s)}\right] \approx 1\]

In \(V^*\), notice \(\mathbb E_{a \sim \mu(\cdot|s)}\left[\left(\frac{\pi^*(a|s)}{\mu(a|s)}\right)^2\right] = \mathbb E_{a \sim \pi^*(\cdot|s)}\left[\frac{\pi^*(a|s)}{\mu(a|s)}\right]\)

We substitute our approximation into this term to reveal,

\[V^*(s) = U^*(s) + \alpha\]

In \(U\)’s optimization problem, we substitute \(U(s) = V(s) - \alpha\) into both terms

  1. The term with \(U\) inside the expectation becomes

\[ \frac{1}{2} + \frac{Q(s,a) - (V(s) - \alpha)}{2\alpha} = \frac{1}{2} + \frac{Q(s,a) - V(s) + \alpha}{2\alpha} = \frac{1}{2} + \frac{Q(s,a) - V(s)}{2\alpha} + \frac{1}{2} = 1 + \frac{Q(s,a) - V(s)}{2\alpha}\]
  2. The standalone term becomes \(\frac{U(s)}{\alpha} = \frac{V(s)}{\alpha} - 1\)

Putting it together, we obtain \(V\)'s optimization problem

\[\min_V \mathbb{E} \left[ \left(\max\left\{ 1 + \frac{Q(s,a) - V(s)}{2\alpha}, 0 \right\}\right)^2 \right] + \mathbb{E}\left[\frac{V(s)}{\alpha} - 1\right]\]
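As a quick numerical check (toy values) that this substitution is exact: evaluating the original \(U\) objective at \(U = V - \alpha\) gives the same value as the \(V\) objective above.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.5
q = rng.normal(size=5)            # toy Q(s, a) values
mu = np.full(5, 0.2)              # uniform behavior policy
v = rng.normal()                  # an arbitrary V(s)

def loss_u(u):
    # U objective: E_mu[(max{1/2 + (Q - U)/(2 alpha), 0})^2] + U/alpha
    m = np.maximum(0.5 + (q - u) / (2 * alpha), 0.0)
    return np.sum(mu * m ** 2) + u / alpha

def loss_v(v):
    # V objective: E_mu[(max{1 + (Q - V)/(2 alpha), 0})^2] + V/alpha - 1
    m = np.maximum(1.0 + (q - v) / (2 * alpha), 0.0)
    return np.sum(mu * m ** 2) + v / alpha - 1.0

assert np.isclose(loss_u(v - alpha), loss_v(v))
```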

Optimization to obtain \(\pi^*\)#

Plug the approximation \(V^*(s) = U^*(s)+\alpha\) into the max term in \(\pi^*\)

\[\max\left\{ \frac{1}{2} + \frac{Q^*(s,a) - U^*(s)}{2\alpha}, 0 \right\} = \max\left\{ 1 + \frac{Q^*(s,a) - V^*(s)}{2\alpha}, 0 \right\} \]

And follow the same derivation using forward KL

\[\begin{split} \begin{align} &\equiv \max_\pi \mathbb{E}_{a \sim \pi^*(\cdot|s)} \left[\log\pi(a|s)\right]\\ &\equiv \max_\pi \mathbb{E}_{a \sim \mu} \left[ \max\left\{ 1 + \frac{Q(s,a) - V(s)}{2\alpha}, 0 \right\}\log \pi(a|s)\right] \end{align} \end{split}\]

SQL Algorithm#

Algorithm 2 (Sparse Q-Learning)

Require Dataset \(\mathcal{D}\), regularization strength \(\alpha\)

Initialize Networks \(V_{\psi}\), \(\pi_{\theta}\), \(Q_{\phi}\), and target network \(Q_{\phi'}\).

for \(t = 1, 2, \dots, N\) do:

  1. Sample a mini-batch of \(M\) transitions \(\mathcal B = \{(s_i, a_i, r_i, s'_i)\}_{i=1}^M \sim \mathcal{D}\)

  2. Update V-Network (\(V_\psi\)): \(\mathcal{L}_V(\psi) = \frac{1}{M} \sum_{i=1}^M \left[ \left(\max\left\{ 1 + \frac{Q_{\phi'}(s_i,a_i) - V_\psi(s_i)}{2\alpha}, 0 \right\}\right)^2 + \frac{V_\psi(s_i)}{\alpha} \right]\)

  3. Update Q-Network (\(Q_\phi\)):

    1. Compute standard Bellman target using the V-network: \(y_{Q,i} = r_i + \gamma V_\psi(s'_i)\)

    2. Minimize MSE loss: \(\mathcal{L}_Q(\phi) = \frac{1}{M} \sum_{i=1}^M \left( Q_\phi(s_i, a_i) - y_{Q,i} \right)^2\)

  4. Update Target Network: \(\phi' \leftarrow \lambda\phi + (1 - \lambda)\phi'\)

  5. Update Policy Network (\(\pi_\theta\)) (Extraction):

    1. Calculate detached weights \(w_i\) using the approximation \(U(s) \approx V(s) - \alpha\): \(w_i = \max\left\{ 1 + \frac{Q_{\phi'}(s_i,a_i) - V_\psi(s_i)}{2\alpha}, 0 \right\}\)

    2. Minimize negative weighted log-likelihood: \(\mathcal{L}_\pi(\theta) = - \frac{1}{M} \sum_{i=1}^M w_i \log \pi_\theta(a_i|s_i)\)
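The loop above can be sketched in numpy for a toy one-step ("bandit") dataset: a single state, three actions, deterministic rewards, and a terminal next state (so the Bellman target is just \(r\)). Tables stand in for the networks, and all numbers are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, lr = 0.5, 0.1
rewards = {0: 1.0, 1: 0.5, 2: 0.0}
data_actions = rng.integers(0, 3, size=512)   # roughly uniform behavior policy

Q = np.zeros(3)   # Q_phi(s, .) as a table
V = 0.0           # V_psi(s) as a scalar

for _ in range(300):
    # V update: mean gradient of (max{1 + (Q - V)/(2 alpha), 0})^2 + V/alpha
    m = np.maximum(1.0 + (Q[data_actions] - V) / (2 * alpha), 0.0)
    V -= lr * (1.0 - m.mean()) / alpha
    # Q regression toward the Bellman target y = r (terminal next state)
    for a, r in rewards.items():
        Q[a] -= lr * (Q[a] - r)

# Detached policy-extraction weights: higher-value actions get more weight
w = np.maximum(1.0 + (Q - V) / (2 * alpha), 0.0)
```

At convergence \(Q\) matches the rewards, \(V\) settles where the empirical normalization constraint holds, and the weights rank actions by value.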

Exponential Q-Learning#

Alternatively, as the divergence parameter \(\alpha \rightarrow 0\), \(D_\alpha\) recovers the reverse KL, a mode-seeking divergence.

\[f(x) = \log x \implies f'(x) = \frac{1}{x}\]

And by definition:

\[h'_f(x) = f(x) + xf'(x) = \log x + 1\]
\[g_f(y) = \exp (y-1)\]
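A one-line numerical sanity check that these two formulas are mutual inverses (nothing here beyond the definitions above):

```python
import numpy as np

# h'_f and g_f should satisfy g_f(h'_f(x)) = x for x > 0
h_prime = lambda x: np.log(x) + 1.0
g_f = lambda y: np.exp(y - 1.0)

x = np.linspace(0.1, 5.0, 50)
assert np.allclose(g_f(h_prime(x)), x)
```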

With this we can define the relevant optimization objectives

  1. Substitute \(g_f(y)\) into general policy form from Theorem 1; since the exponential is strictly positive, the max is never active:

    \[\begin{split} \begin{align} \pi^*(a|s) &= \mu(a|s) \cdot \max\left\{ \exp\left( \frac{Q^*(s,a) - U^*(s)}{\alpha}-1\right), 0 \right\}\\ &=\mu(a|s) \cdot \exp\left( \frac{Q^*(s,a) - U^*(s)}{\alpha}-1\right) \end{align} \end{split}\]

  2. Substitute the derivative \(f'(x) = \frac{1}{x}\) into general form for \(V^*(s)\) from Theorem 1:

    \[V^*(s) = U^*(s) + \alpha \mathbb{E}_{a \sim \mu} \left[ \left(\frac{\pi^*(a|s)}{\mu(a|s)}\right)^2 \frac{\mu(a|s)}{\pi^*(a|s)} \right]\]

In the above equation, we notice a very interesting fact that leads to a simple equation for \(U^*\)

\[\begin{split} \begin{align} & \mathbb{E}_{a \sim \mu} \left[ \left(\frac{\pi^*(a|s)}{\mu(a|s)}\right)^2 \frac{\mu(a|s)}{\pi^*(a|s)} \right] \\ & = \mathbb E_{a\sim\mu}\left[\frac{\pi^*(a|s)}{\mu(a|s)}\right] \\ &= \sum_a \pi^* (a|s) \\ &= 1 \\ &\implies V^*(s) = U^*(s) + \alpha \end{align} \end{split}\]

Optimization for \(V\)#

This relationship \(V^*(s) = U^*(s) + \alpha\) also simplifies the form of \(\pi^*\)

\[\begin{split} \begin{align} \pi^*(a|s) &= \mu(a|s) \cdot \exp\left( \frac{Q^*(s,a) - U^*(s)}{\alpha}-1\right)\\ &=\mu(a|s) \cdot \exp\left( \frac{Q^*(s,a) - V^*(s)+\alpha}{\alpha}-1\right) \\ &=\mu(a|s) \cdot \exp\left( \frac{Q^*(s,a) - V^*(s)}{\alpha}\right) \\ \end{align} \end{split}\]

Unlike SQL, the dual variable \(U^*\) can be eliminated without any approximation by substituting \(U^*(s) = V^*(s)-\alpha\)

Exactly as in SQL, the Lagrange multiplier \(U^* (s)\) is the unique value such that \(\pi^*\) satisfies the normalization constraint \(\sum_a \pi^* (a|s) = 1\)

\[\mathbb{E}_{a \sim \mu} \left[ \exp\left( \frac{Q^*(s,a) - V^*(s)}{\alpha}\right) \right] = 1\]

We could solve this for \(V^*(s)\) algebraically,

\[\begin{split} \begin{align} 1 &= \exp\left(\frac{-V^*(s)}{\alpha}\right) \cdot \mathbb{E}_{a \sim \mu} \left[ \exp\left(\frac{Q^*(s,a)}{\alpha}\right) \right] \\ 0 &= \frac{-V^*(s)}{\alpha} + \log \mathbb{E}_{a \sim \mu} \left[ \exp\left(\frac{Q^*(s,a)}{\alpha}\right) \right] \\ V^*(s) &= \alpha\log\mathbb{E}_{a \sim \mu} \left[ \exp\left(\frac{Q^*(s,a)}{\alpha}\right) \right] \end{align} \end{split}\]

Unfortunately, this poses a problem in estimating the expectation: in the continuous-state setting our offline dataset is unlikely to contain more than one transition from any given state, and in the discrete-state setting only a handful.

Instead, the authors define an optimization problem analogous to the one for \(U\) in SQL, whose solution \(V^*\) (where the gradient is zero) satisfies the probability constraint. Take the gradient yourself and set it to zero to see it is exactly the probability constraint.

\[\min_V \mathbb{E}_{a \sim \mu} \left[ \exp\left(\frac{Q(s,a) - V(s)}{\alpha}\right) \right] + \frac{V(s)}{\alpha}\]
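A numerical check (toy values) that the minimizer of this objective recovers the closed-form \(V^*\) derived above:

```python
import numpy as np

alpha = 1.0
q = np.array([0.8, 0.1, -0.3])   # toy Q(s, .)
mu = np.array([0.3, 0.4, 0.3])   # behavior policy probabilities

def grad(v):
    # d/dV of E_mu[exp((Q - V)/alpha)] + V/alpha
    return (1.0 - np.sum(mu * np.exp((q - v) / alpha))) / alpha

v = float(q.max())               # start above the root so exp() stays bounded
for _ in range(2000):
    v -= 0.5 * grad(v)           # gradient descent on the convex objective

# Closed form from the algebra above: V* = alpha * log E_mu[exp(Q/alpha)]
closed_form = alpha * np.log(np.sum(mu * np.exp(q / alpha)))
```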

Optimization for \(\pi\)#

We follow the same policy extraction procedure as SQL.

\[\max_\pi \mathbb E_{a\sim\pi^*}\left[\log\pi(a|s)\right]\equiv\max_\pi \mathbb E_{a\sim\mu}\left[\exp\left( \frac{Q^*(s,a) - V^*(s)}{\alpha}\right)\log\pi(a|s)\right]\]

EQL Algorithm#

Algorithm 3 (Exponential Q-Learning)

Require Dataset \(\mathcal{D}\), regularization strength \(\alpha\)

Initialize Networks \(V_{\psi}\), \(\pi_{\theta}\), \(Q_{\phi}\), and target network \(Q_{\phi'}\).

For \(t = 1, 2, \dots, N\) do:

  1. Sample a mini-batch of \(M\) transitions \(\mathcal B = \{(s_i, a_i, r_i, s'_i)\}_{i=1}^M \sim \mathcal{D}\)

  2. Update V-Network (\(V_\psi\)): \(\mathcal{L}_V(\psi) = \frac{1}{M} \sum_{i=1}^M \left[ \exp\left(\frac{Q_{\phi'}(s_i,a_i) - V_\psi(s_i)}{\alpha}\right) + \frac{V_\psi(s_i)}{\alpha} \right]\)

  3. Update Q-Network (\(Q_\phi\)):

    1. Compute standard Bellman target using the V-network: \(y_{Q,i} = r_i + \gamma V_\psi(s'_i)\)

    2. Minimize MSE loss: \(\mathcal{L}_Q(\phi) = \frac{1}{M} \sum_{i=1}^M \left( Q_\phi(s_i, a_i) - y_{Q,i} \right)^2\)

  4. Update Target Network: \(\phi' \leftarrow \lambda\phi + (1 - \lambda)\phi'\)

  5. Update Policy Network (\(\pi_\theta\)) (Extraction):

    1. Calculate detached exponential weights (derived from the transformed distillation objective): \(w_i = \exp\left(\frac{Q_{\phi'}(s_i,a_i) - V_\psi(s_i)}{\alpha}\right)\)

    2. Minimize negative weighted log-likelihood: \(\mathcal{L}_\pi(\theta) = - \frac{1}{M} \sum_{i=1}^M w_i \log \pi_\theta(a_i|s_i)\)
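As with SQL, the loop can be sketched in numpy on a toy one-step ("bandit") dataset, with tables standing in for the networks and all numbers illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, lr = 0.5, 0.1
rewards = {0: 1.0, 1: 0.5, 2: 0.0}
data_actions = rng.integers(0, 3, size=512)   # roughly uniform behavior policy

Q = np.zeros(3)   # Q_phi(s, .) as a table
V = 1.0           # initialize above the rewards so exp() stays tame

for _ in range(500):
    # V update: mean gradient of exp((Q - V)/alpha) + V/alpha
    e = np.exp((Q[data_actions] - V) / alpha)
    V -= lr * (1.0 - e.mean()) / alpha
    # Q regression toward the Bellman target y = r (terminal next state)
    for a, r in rewards.items():
        Q[a] -= lr * (Q[a] - r)

w = np.exp((Q - V) / alpha)   # detached policy-extraction weights
```

At convergence the empirical normalization constraint \(\mathbb E_{\hat\mu}[\exp((Q-V)/\alpha)] = 1\) holds, and the exponential weights rank actions by value.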