# Control As Inference

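We introduce a binary optimality variable \(\mathcal{O}_t\) at each time-step, where \(\mathcal{O}_t = 1\) means that time-step \(t\) is optimal. We want to infer the posterior over trajectories given optimality at every step: \(p(\tau | \mathcal{O}_{1:T})\).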

If we choose \(p(\mathcal{O}_t | s_t, a_t) = \mathrm{exp}(r(s_t, a_t))\), then:

\begin{align}
p(\tau | \mathcal{O}_{1:T}) &= \frac{p(\tau,
\mathcal{O}_{1:T})}{p(\mathcal{O}_{1:T})}
\\\

&\propto p(\tau) \prod_{t} \mathrm{exp}(r(s_t, a_t)) \\\

&= p(\tau) \mathrm{exp} \left( \sum_{t}
r(s_t, a_t) \right)
\end{align}

With this probabilistic graphical model, we can:

- model sub-optimal behaviour (important for inverse RL)
- apply inference algorithms to solve control and planning problems
- explain why stochastic behaviour may be preferred (useful for exploration and transfer learning)
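As a minimal sketch of the trajectory posterior, the snippet below enumerates every trajectory of a hypothetical two-step, two-action problem and weights each one by \(p(\tau) \mathrm{exp}(\sum_t r)\). The rewards, the horizon, and the uniform \(p(\tau)\) are assumptions made purely for illustration. Sub-optimal trajectories keep non-zero probability, which is what makes the model usable for inverse RL.

```python
import itertools
import numpy as np

# Hypothetical toy problem: T = 2 steps, 2 actions, reward depends only on the action.
rewards = np.array([-1.0, 0.0])   # r(a) for a in {0, 1}; action 1 is the better one
T = 2

# With no state dynamics, a trajectory is just a sequence of actions.
trajectories = list(itertools.product(range(len(rewards)), repeat=T))
prior = np.full(len(trajectories), 1.0 / len(trajectories))  # p(tau), assumed uniform
returns = np.array([sum(rewards[a] for a in tau) for tau in trajectories])

# p(tau | O_{1:T}) is proportional to p(tau) * exp(sum_t r(s_t, a_t)).
unnormalised = prior * np.exp(returns)
posterior = unnormalised / unnormalised.sum()

for tau, p in zip(trajectories, posterior):
    print(tau, round(float(p), 3))
```

The best trajectory is the most likely one, but every other trajectory is only exponentially less likely in its return gap, not impossible.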

## Inference

- compute the backward messages \(\beta_t (s_t, a_t) = p(\mathcal{O}_{t:T} | s_t, a_t)\)
- compute the policy \(p(a_t | s_t, \mathcal{O}_{1:T})\), i.e. the policy of this model under the assumption of optimality
- compute the forward messages \(\alpha_t(s_t) = p(s_t | \mathcal{O}_{1:t-1})\), which are useful for figuring out which states the optimal policy lands in; this matters for the inverse RL problem and is not needed for forward RL

## Backward Messages

\begin{align}
\beta_t (s_t, a_t) &= p(\mathcal{O}_{t:T} | s_t, a_t) \\\

&= \int p(\mathcal{O}_{t:T}, s_{t+1} | s_t, a_t)
ds_{t+1} \\\

&= \int p(\mathcal{O}_{t+1:T}|s_{t+1})
p(s_{t+1}|s_t,a_t) p(\mathcal{O}_t | s_t, a_t)
ds_{t+1}
\end{align}

\begin{align}
p(\mathcal{O}_{t+1:T} | s_{t+1}) &= \int p(\mathcal{O}_{t+1:T} |
s_{t+1}, a_{t+1})p(a_{t+1}| s_{t+1}) da_{t+1} \\\

&= \int \beta_{t+1}(s_{t+1}, a_{t+1}) p(a_{t+1} | s_{t+1}) da_{t+1} \\\

&\propto \int \beta_{t+1}(s_{t+1}, a_{t+1}) da_{t+1}
\end{align}

where the last step assumes that the action prior \(p(a_t | s_t)\) is uniform, so it contributes only a constant factor. From these equations, we get the following recursion:

For \(t = T-1\) down to \(1\):

\begin{equation} \beta_t(s_t, a_t) = p(\mathcal{O}_t | s_t, a_t) E_{s_{t+1} \sim p(s_{t+1} | s_t, a_t)} \left[ \beta_{t+1} (s_{t+1}) \right] \end{equation}

\begin{equation} \beta_{t}(s_t) = E_{a_t \sim p(a_t | s_t)} \left[ \beta_t(s_t, a_t) \right] \end{equation}
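The recursion can be sketched for a small tabular MDP as follows. This is an illustrative sketch rather than a reference implementation: the function name `backward_messages`, the array layout `P[s, a, s2]`, the 0-indexed time axis, and the two-state example MDP are all assumptions, and rewards are taken to be non-positive so that \(\mathrm{exp}(r)\) is a valid probability.

```python
import numpy as np

def backward_messages(P, r, T):
    """Tabular backward pass with a uniform action prior.

    P[s, a, s2] = p(s2 | s, a), r[s, a] = reward (assumed <= 0).
    Returns beta_sa[t, s, a] and beta_s[t, s], with time 0-indexed.
    """
    S, A, _ = P.shape
    lik = np.exp(r)                            # p(O_t | s_t, a_t)
    beta_sa = np.zeros((T, S, A))
    beta_s = np.zeros((T, S))
    beta_sa[T - 1] = lik                       # final step: only O_T remains
    beta_s[T - 1] = beta_sa[T - 1].mean(axis=1)
    for t in range(T - 2, -1, -1):
        expected_next = P @ beta_s[t + 1]      # E_{s2 ~ p(.|s,a)}[beta_{t+1}(s2)], shape (S, A)
        beta_sa[t] = lik * expected_next
        beta_s[t] = beta_sa[t].mean(axis=1)    # E_{a ~ uniform}[beta_t(s, a)]
    return beta_sa, beta_s

# Hypothetical 2-state MDP: action 0 stays, action 1 switches; state 1 is rewarding.
P = np.zeros((2, 2, 2))
P[0, 0, 0] = P[1, 0, 1] = 1.0   # stay
P[0, 1, 1] = P[1, 1, 0] = 1.0   # switch
r = np.array([[-1.0, -1.0], [0.0, 0.0]])

beta_sa, beta_s = backward_messages(P, r, T=3)
print(beta_sa[0])
```

In this example the message at the first step is larger for the action that reaches (or stays in) the rewarding state, as expected.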

If we choose \(V_t (s_t) = \log \beta_t (s_t)\) and \(Q_t(s_t, a_t) = \log \beta_t (s_t, a_t)\):

\begin{align}
V_t(s_t) &= \log \int \mathrm{exp} (Q_t(s_t, a_t))da_t \\\

&\rightarrow \mathrm{max}_{a_t} Q_t(s_t, a_t) \textrm{ as } Q_t(s_t, a_t) \textrm{ gets bigger}
\end{align}
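A small numerical check of this limit, assuming a discrete action set so that the integral becomes a sum (the Q-values and scales below are hypothetical). The gap between the soft maximum and the hard maximum always lies in \([0, \log |A|]\), so it becomes negligible relative to \(Q\) as the values grow:

```python
import numpy as np

def soft_max_value(q):
    """V = log sum_a exp(Q(a)), computed stably by subtracting the maximum."""
    m = np.max(q)
    return m + np.log(np.sum(np.exp(q - m)))

base_q = np.array([1.0, 2.0, 3.0])     # hypothetical Q-values for three actions
scales = [1.0, 10.0, 100.0]
gaps = []
for scale in scales:
    q = scale * base_q
    gaps.append(soft_max_value(q) - np.max(q))   # soft max minus hard max
    print(scale, gaps[-1])
```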

For \(Q\):

\begin{equation} Q_t (s_t, a_t) = r(s_t, a_t) + \log E_{s_{t+1} \sim p(s_{t+1} | s_t, a_t)} \left[ \mathrm{exp} (V_{t+1} (s_{t+1})) \right] \end{equation}

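With deterministic transitions, the expectation is over a single next state, so the log and exp cancel and we recover the usual backup \(Q_t(s_t, a_t) = r(s_t, a_t) + V_{t+1}(s_{t+1})\). With stochastic transitions, however, \(\log E[\mathrm{exp}(\cdot)]\) is dominated by the best possible next state, so the backup is optimistic about the dynamics, which is not a good idea!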
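The optimism can be seen in a two-action sketch with made-up numbers: a "risky" action that lands in either the best or the worst successor state, and a "safe" action that always lands in a mediocre one. The values and transition probabilities are hypothetical.

```python
import numpy as np

# Hypothetical next-state values V_{t+1}(s') for three successor states.
V_next = np.array([0.0, -10.0, -2.0])

# Rows are actions: "risky" lands in the best or worst state, "safe" always in the middle one.
P_next = np.array([
    [0.5, 0.5, 0.0],   # risky
    [0.0, 0.0, 1.0],   # safe
])

soft_backup = np.log(P_next @ np.exp(V_next))   # log E[exp(V_{t+1})]
expected_backup = P_next @ V_next               # E[V_{t+1}]

print("soft:    ", soft_backup)
print("expected:", expected_backup)
```

The soft backup prefers the risky action because it mostly sees the lucky outcome, whereas the ordinary expected backup prefers the safe one. For the deterministic safe action the two backups coincide.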

What if the action prior is not uniform? **We can always fold the action
prior into the reward!**
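A sketch of why this works, where \(\tilde{r}\) and \(\tilde{Q}_t\) are notation introduced only for this illustration. With a non-uniform prior, the prior stays inside the integral for the state message:

\begin{align}
V_t(s_t) &= \log \int \beta_t(s_t, a_t) p(a_t | s_t) da_t \\\

&= \log \int \mathrm{exp} \left( Q_t(s_t, a_t) + \log p(a_t | s_t) \right) da_t
\end{align}

Defining \(\tilde{r}(s_t, a_t) = r(s_t, a_t) + \log p(a_t | s_t)\) gives \(\tilde{Q}_t(s_t, a_t) = Q_t(s_t, a_t) + \log p(a_t | s_t)\), so \(V_t(s_t) = \log \int \mathrm{exp}(\tilde{Q}_t(s_t, a_t)) da_t\), which has the same form as the uniform-prior case.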

## Policy computation

\begin{align}
p(a_t | s_t, \mathcal{O}_{1:T}) &= \pi (a_t | s_t) \\\

&= p(a_t | s_t, \mathcal{O}_{t:T})
\\\

&= \frac{\beta_t(s_t,
a_t)}{\beta_t(s_t)}p(a_t|s_t) \\\

&= \frac{\beta_t(s_t,
a_t)}{\beta_t(s_t)}
\end{align}

It turns out the policy is just the ratio of the two backward messages (the uniform action prior is a constant that is absorbed into \(\beta_t(s_t)\)). Substituting \(V\) and \(Q\):

\begin{equation} \pi(a_t | s_t) = \mathrm{exp}(Q_t(s_t, a_t) - V_t(s_t)) = \mathrm{exp}(A_t(s_t, a_t)) \end{equation}

One can also add a temperature: \(\pi(a_t | s_t) = \mathrm{exp}(\frac{1}{\alpha} A_t(s_t, a_t))\)
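For a discrete action set, the policy can be sketched as below. The function name `soft_policy` and the Q-values are hypothetical, and taking \(V(s) = \alpha \log \sum_a \mathrm{exp}(Q(s, a) / \alpha)\) is an assumption of this sketch: it is the choice of \(V\) that makes the tempered policy sum to one.

```python
import numpy as np

def soft_policy(Q, alpha=1.0):
    """pi(a | s) = exp((Q(s, a) - V(s)) / alpha) for tabular Q of shape (S, A).

    V(s) = alpha * log sum_a exp(Q(s, a) / alpha) is an assumption of this sketch:
    it is the choice of V that makes each row of pi sum to one.
    """
    scaled = Q / alpha
    m = scaled.max(axis=1, keepdims=True)                 # subtract max for numerical stability
    V = alpha * (m + np.log(np.exp(scaled - m).sum(axis=1, keepdims=True)))
    advantage = Q - V                                     # A(s, a) = Q(s, a) - V(s)
    return np.exp(advantage / alpha)

Q = np.array([[1.0, 2.0, 0.0],
              [0.5, 0.5, 0.5]])   # hypothetical Q-values: 2 states, 3 actions

pi_warm = soft_policy(Q, alpha=1.0)
pi_cold = soft_policy(Q, alpha=0.1)
print(pi_warm)
print(pi_cold)
```

Lowering the temperature concentrates the policy on the greedy action, while states with equal Q-values keep a uniform policy at any temperature.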

## Forward Messages

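Combining the forward and backward messages gives the state marginal under optimality:

\begin{equation} p(s_t | \mathcal{O}_{1:T}) \propto \beta_t(s_t) \alpha_t(s_t) \end{equation}

The derivation is the same as for the forward-backward algorithm in a Hidden Markov Model!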
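For reference, a sketch of the forward recursion implied by the definition \(\alpha_t(s_t) = p(s_t | \mathcal{O}_{1:t-1})\), obtained by marginalising over the previous state and action as in the HMM forward pass. The base case \(\alpha_1(s_1) = p(s_1)\) is the initial state distribution:

\begin{equation} \alpha_t(s_t) \propto \int \int p(s_t | s_{t-1}, a_{t-1}) p(\mathcal{O}_{t-1} | s_{t-1}, a_{t-1}) p(a_{t-1} | s_{t-1}) \alpha_{t-1}(s_{t-1}) ds_{t-1} da_{t-1} \end{equation}

The normalising constant is \(p(\mathcal{O}_{t-1} | \mathcal{O}_{1:t-2})\).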

## Resolving Optimism with Variational Inference

For more, see Levine (n.d.).