Computing probabilities of a transformed distribution#

Background#

Change of variable#

Given \(x\sim p_X\) and \(y=g(x)\) where \(g\) is a bijective mapping, the probability density of the transformed variable \(y\) is

\[\begin{split}\begin{array}{rcl} p_Y(y) &=& p_X(g^{-1}(y))\Big|\text{det}\big[\frac{d g(x)}{d x}\Big|_{x=g^{-1}(y)}\big]\Big|^{-1}\\ &=& p_X(g^{-1}(y)) h(g^{-1}(y))\\ \end{array}\end{split}\]

Note that the inverted absolute determinant of the Jacobian can just be treated as a function \(h\) of \(x\), and we would like to evaluate its value for each particular value \(\hat{y}\), i.e., \(p_Y(\hat{y})=p_X(g^{-1}(\hat{y}))h(g^{-1}(\hat{y}))\).
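As a quick numerical check of the formula above, the following minimal sketch evaluates \(p_Y(\hat{y})\) for an affine bijection and compares it with the closed-form normal density. The helper names (`normal_pdf`, `transformed_pdf`) and the values of `mu`, `sigma` and `y_hat` are illustrative assumptions, not part of any library.

```python
import math


def normal_pdf(v, mean=0.0, std=1.0):
    """Density of N(mean, std^2) at v."""
    z = (v - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def transformed_pdf(y, p_x, g_inv, abs_det_jacobian):
    """p_Y(y) = p_X(g^{-1}(y)) * h(g^{-1}(y)), with h = 1 / |det dg/dx|."""
    x = g_inv(y)
    return p_x(x) / abs_det_jacobian(x)


# Hypothetical affine bijection g(x) = x * sigma + mu applied to x ~ N(0, 1).
mu, sigma = 0.5, 2.0
g_inv = lambda y: (y - mu) / sigma
abs_det_jacobian = lambda x: abs(sigma)  # dg/dx is constant for an affine map

y_hat = 1.3
p_y = transformed_pdf(y_hat, normal_pdf, g_inv, abs_det_jacobian)
p_ref = normal_pdf(y_hat, mean=mu, std=sigma)  # closed form of N(mu, sigma^2)
print(p_y, p_ref)  # the two values agree up to floating-point rounding
```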

The following two common cases involve computing the (log) probability of a changed variable, but should be treated differently.

Policy gradient (PG)#

The policy gradient for training a policy \(\pi_{\theta}(a)\) is

\[-\int_{a}\pi_{\theta}(a)\nabla_{\theta}\log\pi_{\theta}(a) A(a) da\]

where we omit the dependency on the observation \(s\) for simplicity. Note that the above formula is an expectation of gradient vectors. If we denote \(\mathbf{g}(\cdot)=\nabla_{\theta}\log\pi_{\theta}(\cdot)\), then policy gradient tells us to evaluate \(\mathbf{g}(a)\) at each sampled \(a\) and average the results, where \(a\) is treated as a constant argument. In an autograd framework, this means first detaching \(a\) (it might have been produced by re-parameterization) and then computing \(\nabla_{\theta}\log\pi_{\theta}(a)\).

The same care is needed for other objectives such as maximum likelihood estimation (MLE).

Entropy gradient#

The gradient of entropy can be estimated by the re-parameterization trick:

\[\begin{split}\begin{array}{rl} &-\nabla_{\theta}\int_{a}\pi_{\theta}(a)\log\pi_{\theta}(a) da\\ =&-\nabla_{\theta}\int_{\epsilon}p(\epsilon)\log\pi_{\theta}(f(\epsilon;\theta)) d\epsilon\\ =&-\int_{\epsilon}p(\epsilon)\nabla_{\theta}\log\pi_{\theta}(f(\epsilon;\theta)) d\epsilon\\ \end{array}\end{split}\]

That is, the derivative takes place after the substitution. As a result, we should not detach the gradient of \(f(\epsilon;\theta)\) when computing the log probability in this case. Also see the analysis in Estimating the derivative of an expectation.
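The identity above can be checked on a univariate Gaussian policy, whose entropy is \(\log\sigma\) plus a constant, so its gradient is \(0\) w.r.t. \(\mu\) and \(1/\sigma\) w.r.t. \(\sigma\). The sketch below is only an illustration under that assumption; the parameter values and sample count are arbitrary.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor(0.3, requires_grad=True)
sigma = torch.tensor(1.5, requires_grad=True)
pi = Normal(mu, sigma)

# a = f(eps; theta) = mu + sigma * eps; it stays attached to mu and sigma.
a = pi.rsample((1000,))

# Monte Carlo estimate of the entropy; note that `a` is NOT detached here.
entropy_estimate = -pi.log_prob(a).mean()
entropy_estimate.backward()

# Analytically: d(entropy)/d(mu) = 0 and d(entropy)/d(sigma) = 1 / sigma.
print(mu.grad, sigma.grad)
```

If `a` were detached before calling `log_prob`, the same code would instead produce a score-function term, which is not the gradient of the entropy.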

In the PG scenario, it is much easier to make the mistake of forgetting to detach the sampled \(a\), so we discuss it in depth below.

What is the essential difference between detaching and not detaching \(y\) for PG?#

Back to the general case \(y=g(x)\): if we just want to compute the probability without taking the derivative w.r.t. any parameters of \(g\), then whether \(y\) is detached or not does not matter. Namely, the following two values are equal:

\[\begin{split}\begin{array}{rclr} p_Y(y) &=& p_X(g^{-1}(y)) h(g^{-1}(y)) & \text{(detach)} \\ &=& p_X(g^{-1}(g(x)))h(g^{-1}(g(x)))=p_X(x)h(x) & \text{(no detach)}\\ \end{array}\end{split}\]

However, when \(g\) contains trainable parameters we’d like to optimize, whether to detach \(y\) or not must follow the gradient formula strictly. Otherwise, the gradient might be incorrect. For example, the second form \(p_X(x)h(x)\) might drop some parameters from the computation because \(g^{-1}\) is never evaluated.

An example#

For a simple example, if \(x\sim\mathcal{N}(0,1)\), \(y=g(x;\sigma,\mu)=x\sigma + \mu\), then with detached \(y\)

\[\begin{split}\begin{array}{rcl} p_Y(y) &=& p_X(g^{-1}(y)) \Big|\frac{\partial g(x)}{\partial x}\big|_{x=g^{-1}(y)}\Big|^{-1} \\ &=& p_X(\frac{y-\mu}{\sigma}) \sigma^{-1}\\ &=& \frac{1}{\sqrt{2\pi}\sigma} e^{-\frac{(y-\mu)^2}{2\sigma^2}}\\ \end{array}\end{split}\]

which is the correct p.d.f. for training \(\mu\) with PG. However, with \(y\) undetached:

\[\begin{split}\begin{array}{rcl} p_Y(y) &=& p_X(x) \Big|\frac{\partial g(x)}{\partial x}\big|_{x=x}\Big|^{-1} \\ &=& \frac{1}{\sqrt{2\pi}\sigma} e^{-\frac{x^2}{2}}\\ \end{array}\end{split}\]

where \(x=g^{-1}(g(x))\), so \(\mu\) does not appear in the resulting p.d.f. Thus PG won’t be able to adjust \(\mu\).
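The two cases of this example can be reproduced without any autograd library by differentiating \(\log p_Y\) w.r.t. \(\mu\) with central finite differences. In this sketch, "detached" simply means that \(y\) is held fixed while \(\mu\) is perturbed; the helper names and the numeric values are illustrative assumptions.

```python
import math


def log_p_y(y, mu, sigma):
    """log p_Y(y) for y = x * sigma + mu with x ~ N(0, 1)."""
    x = (y - mu) / sigma  # g^{-1}(y)
    return -0.5 * x * x - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def d_dmu(fn, mu, eps=1e-6):
    """Central finite-difference derivative of fn at mu."""
    return (fn(mu + eps) - fn(mu - eps)) / (2.0 * eps)


x, mu, sigma = 0.7, 0.5, 2.0
y_fixed = x * sigma + mu  # "detached" y: a constant w.r.t. mu

# Detached y: mu enters only through g^{-1}(y).
grad_detached = d_dmu(lambda m: log_p_y(y_fixed, m, sigma), mu)

# Undetached y: y is recomputed from mu, so g^{-1}(g(x)) = x and mu cancels.
grad_undetached = d_dmu(lambda m: log_p_y(x * sigma + m, m, sigma), mu)

print(grad_detached)    # close to (y - mu) / sigma**2 = x / sigma
print(grad_undetached)  # close to 0
```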

Transform caches#

A cache is usually helpful for a bijective transform whose inverse is computationally expensive or numerically unstable. PyTorch uses the cache as follows: for a query \(y\), it checks whether \(y\) is the same object as \(y_{old}\) from the cached pair \((x_{old}, y_{old})\). If so, \(x_{old}\) is returned. Any out-of-place operation (e.g., detach()) that makes \(y\) a different object will miss the cache, and the inverse has to be computed.

If a transform \(g\) has its cache turned on and \(y\) is undetached, it skips the computation of \(g^{-1}\) and retrieves \(x(\equiv g^{-1}(g(x)))\) directly, which avoids the numerical issue. Again, the PyTorch transform cache only works when \(y\) is not detached; the role of a cache is no more than solving the numerical issue for an undetached \(y\).

| | Cache on | Cache off |
|---|---|---|
| \(y\) detached | \(x \leftarrow g^{-1}(y)\) (*) | \(x \leftarrow g^{-1}(y)\) (*) |
| \(y\) undetached | \(x \leftarrow x\) | \(x \leftarrow g^{-1}(g(x))\) (*) |

‘*’ means potential numerical issues when inverting.
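The cache behavior described above can be observed directly with `TanhTransform` and `cache_size=1`. This is a minimal sketch; the input values are chosen only so that the last entry saturates `tanh` in float32 and makes the recomputed inverse visibly wrong.

```python
import torch
from torch.distributions.transforms import TanhTransform

x = torch.tensor([0.5, 3.0, 20.0])

cached = TanhTransform(cache_size=1)
y = cached(x)  # stores the pair (x, y)

# Same tensor object as the cached y: x is returned without inverting.
x_from_cache = cached.inv(y)

# detach() creates a new tensor object, so the inverse is actually computed.
x_recomputed = cached.inv(y.detach())

print(x_from_cache is x)  # True
print(x_recomputed)       # last entry differs from 20: tanh(20) rounds to 1
```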

A composition of transforms#

Suppose \(g\) is now a composition of transforms \(g=g_N\circ \ldots \circ g_k \circ \ldots \circ g_1\), where \(g_k\) is the first (from left to right) transform that contains trainable parameters. Then for PG, it is completely fine to do either of the following things:

  1. Simply detach \(y\) and invalidate all caches.

  2. Don’t detach \(y\), turn on caches for transforms \(g_N, \ldots, g_{k+1}\), and detach the input \(g^{-1}_{k+1}\circ \ldots \circ g^{-1}_N(y)\) to \(g_1^{-1}\circ \ldots \circ g^{-1}_k\).
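The first option can be sketched with a two-transform composition, assuming a trainable affine \(g_1\) followed by a parameter-free \(g_2=\tanh\). The setup (parameter values, float64, a standard normal base) is an illustrative assumption rather than how any particular algorithm is implemented.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, TanhTransform

torch.manual_seed(0)
f64 = torch.float64
mu = torch.tensor(0.1, dtype=f64, requires_grad=True)  # trainable, lives in g_1
sigma = torch.tensor(0.5, dtype=f64)

# p_X has no parameters; g = g_2 o g_1 with g_1 affine and g_2 = tanh.
p_x = Normal(torch.tensor(0.0, dtype=f64), torch.tensor(1.0, dtype=f64))
pi = TransformedDistribution(p_x, [AffineTransform(mu, sigma), TanhTransform()])

# Option 1: detach y, so that every inverse is recomputed from a constant y.
y = pi.rsample().detach()
(grad_mu,) = torch.autograd.grad(pi.log_prob(y), mu)

# Analytic score for this setup: (atanh(y) - mu) / sigma^2.
print(grad_mu, (torch.atanh(y) - mu.detach()) / sigma ** 2)
```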

PyTorch’s sample() and rsample()#

So far we’ve assumed that \(p_X(\cdot)\) does not contain any trainable parameter, and all model parameters are in the transforms. This is in fact a very general and valid assumption, and it works well in practice. However, PyTorch’s implementation differs by assuming a parameterized base distribution \(p_{\theta}(x)\), which can be viewed as merging the original \(p_X\) with the first parameterized transform \(g_1(\cdot;\theta)\). This is convenient if \(p_X(\cdot;\theta)\) has a closed-form p.d.f. (e.g. \(\mathcal{N}(\mu,\sigma)\)). For this parameterized distribution, PyTorch provides two functions, sample() and rsample(). We can simply assume that

| Call | Assumed behavior |
|---|---|
| rsample() of \(p_{\theta}(x)\) | \(x\sim p_X, y_1=g_1(x;\theta)\) |
| sample() of \(p_{\theta}(x)\) | additional detach on the rsampled \(y_1\) |

When we call sample() (rsample()) on the transformed distribution \(p_Y(y)\), PyTorch will first call sample() (rsample()) of \(p_{\theta}(x)\), and then apply the remaining transforms. This means that even with sample(), the final \(y=g_N\circ\ldots\circ g_2(y_1)\) itself is not detached by default! This might cause errors when it is directly used for computing PG w.r.t. parameters in \(g_2,\ldots,g_N\) (e.g., normalizing flow transforms). Thus for PG, no matter whether an action is sampled by rsample() or sample(), the safest way is to always detach it before computing the probability.

However, if you are certain that there is no trainable parameter in \(g_2,\ldots,g_N\), then detach is not necessary for sample(). But detach is still necessary for rsample() because of \(g_1\) (\(p_{\theta}(x)\)) parameters.
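For the base distribution alone, the difference between the two calls and the effect of detaching can be seen on a `Normal` with a trainable mean. This sketch covers only \(p_{\theta}(x)\) with no further transforms; the parameter values are arbitrary.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(1.0)
pi = Normal(mu, sigma)

a_sample = pi.sample()    # drawn without gradient tracking
a_rsample = pi.rsample()  # mu + sigma * eps, still attached to mu
print(a_sample.requires_grad, a_rsample.requires_grad)  # False True

# Score with the rsampled action left attached: mu cancels out.
# retain_graph=True keeps the graph shared with the second call intact.
(grad_attached,) = torch.autograd.grad(
    pi.log_prob(a_rsample), mu, retain_graph=True)

# Score with the action detached first: (a - mu) / sigma^2, as PG requires.
a = a_rsample.detach()
(grad_detached,) = torch.autograd.grad(pi.log_prob(a), mu)
print(grad_attached, grad_detached)
```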

SAC and DDPG are rsample safe by nature#

For non-policy-gradient methods like SAC and DDPG, it’s safe to use an undetached action from rsample() anywhere in the code and have all transform caches turned on. The only place they need to compute its log probability is when estimating the entropy gradient. But as explained above, in this case, the action must not be detached.

Action transformations as environment wrappers#

ALF’s AC and PPO algorithms always detach the action from sample() for the PG loss, without checking whether the transforms have trainable parameters. This simple rule invalidates the caches and sometimes causes numerical issues even when none of the transforms has trainable parameters.

If we know the transforms are not trainable, then a better way is to leave the action undetached and exploit the cache for PG to avoid inverting. When the transform \(g=g_N\circ\ldots\circ g_2\) has no trainable parameters (e.g., StableTanh), the parameters exist only in \(p_{\theta}(x)\). It follows that

\[\begin{split}\begin{array}{rcl} \log p_Y(y) &=& \log p_{\theta}(g^{-1}(y)) - \log \Big|\frac{\partial g(x)}{\partial x}\big|_{x=g^{-1}(y)}\Big|\\ \int_y p_Y(y)\nabla_{\theta}\log p_Y(y) dy &=& \int_y p_Y(y)\nabla_{\theta}\log p_{\theta}(g^{-1}(y)) dy\\ &=&\int_x p_{\theta}(x) \nabla_{\theta}\log p_{\theta}(x) dx\\ \end{array}\end{split}\]

because the Jacobian determinant does not depend on \(\theta\) and can be discarded under \(\nabla_{\theta}\). Thus in this case, regarding PG, it’s equivalent to directly training \(p_{\theta}(x)\) in the untransformed action space \(X\) and applying the transformation on the environment side. If we do so, there is no longer an instability issue associated with PPO and AC.
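The equivalence can be checked numerically for a single sample, assuming a Gaussian \(p_{\theta}(x)\) with a trainable mean and a plain `TanhTransform` standing in for the parameter-free \(g\). The sketch compares the score computed in the transformed space with the score computed directly on \(x\); float64 is used so that the inversion error stays negligible.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

torch.manual_seed(0)
mu = torch.tensor(0.2, dtype=torch.float64, requires_grad=True)
sigma = torch.tensor(0.8, dtype=torch.float64)

base = Normal(mu, sigma)                                     # p_theta(x)
squashed = TransformedDistribution(base, [TanhTransform()])  # p_Y(y)

x = base.sample()  # untransformed action, carries no gradient
y = torch.tanh(x)  # what an environment-side wrapper would apply

# Score in the transformed space; the Jacobian term has no mu dependence.
(grad_y,) = torch.autograd.grad(squashed.log_prob(y), mu, retain_graph=True)

# Score in the untransformed space, with no inversion involved.
(grad_x,) = torch.autograd.grad(base.log_prob(x), mu)

print(grad_y, grad_x)  # equal up to the rounding error of atanh(tanh(x))
```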

One caveat of applying nonparameterized transformations on the environment side is that the actual entropy of the environment actions is difficult to estimate on the algorithm side. One solution is to still apply \(g\) to \(p_{\theta}(x)\) for the entropy calculation, while the PG loss directly uses \(p_{\theta}(x)\).