DiCE: The Infinitely Differentiable Monte Carlo Estimator

If you’ve stumbled upon this blog post, you’ve probably used policy gradient methods in Reinforcement Learning (RL). Or you might have maximised the likelihood in probabilistic models. In both cases, we need to estimate the gradient of the loss, which is an expectation over random variables.

The problem is that you cannot just differentiate such an objective directly, because the expectation depends on the sampling distribution. Usually, you apply the score function trick (aka the log-likelihood trick) here. We can view this trick as providing a differentiable function whose gradient is an estimate of the gradient of the original objective. We can then let any deep learning toolbox do the automatic differentiation. However, sometimes we need higher-order gradients, e.g., in meta-learning or in multi-agent RL, where we need to differentiate through other agents’ learning steps. This makes life much harder.

Infinitely Differentiable Monte Carlo Estimator (DiCE) [1] to the rescue! You can differentiate the DiCE objective, built with the magic \magic operator, repeatedly (infinitely many times) and obtain correct higher-order gradients under the Stochastic Computation Graph (SCG) formalism [2]. This lets automatic differentiation software do the job instead of us manipulating the graph manually. We illustrate the benefits of our approach by applying “Learning with Opponent-Learning Awareness” (LOLA) [3] to the iterated prisoner’s dilemma.

DiCE

As we mentioned above, in the surrogate loss (SL) approach we construct a surrogate objective whose gradient is an estimate of the gradient of the true objective, and we use this surrogate to do the optimisation.

Sadly, constructing a new surrogate loss from the first-order gradient estimate and differentiating it again leads to an incorrect second-order gradient estimate. Simply put, applying the SL approach twice is not the same as estimating the second-order gradient of the true objective.

This incorrect estimate arises because, in the SL approach, we treat part of the objective as a sampled cost, so the corresponding terms lose their functional dependency on the sampling distribution.

We illustrate our reasoning graphically in the figure below, using the Stochastic Computation Graph (SCG) formalism [2].

Stochastic nodes are in orange, costs in grey, surrogate losses in blue, DiCE in purple, and gradient estimators in red.

We introduce the magic \magic operator, which allows us to compute the gradient to any order we like: \Expect[\nabla_{\theta}^n\calL_{\magic}] \rightarrowtail \nabla_{\theta}^{n}\calL, \forall n \in \{0, 1, 2, ...\}.

DiCE is easy to implement:

(1)   \begin{equation*} \magic(\calW) = \exp{(\tau - \perp(\tau))}, \tau=\sum_{w \in \calW}\log(p(w;\theta)), \end{equation*}

where \perp is an operator that sets the gradient of its operand to zero (detach() in PyTorch and tf.stop_gradient() in TensorFlow):
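As a minimal sketch (assuming TensorFlow; the helper name magic_box is ours, not taken from the official implementation), the operator can be written as:

import tensorflow as tf

def magic_box(tau):
    # tau is the sum of log-probabilities of the stochastic nodes in calW.
    # exp(tau - stop_gradient(tau)) always evaluates to 1, but each
    # differentiation with respect to theta brings down another copy of
    # grad(tau), i.e. exactly the score function term we need.
    return tf.exp(tau - tf.stop_gradient(tau))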

Alternatively, we can rewrite DiCE in the following way:

(2)   \begin{equation*} \magic(\calW) = \frac{\prod_{w \in \calW}p(w;\theta)}{\prod_{w \in \calW} \perp p(w;\theta)}. \end{equation*}

The figure below shows an example of DiCE applied to an RL problem:

DiCE applied to a reinforcement learning problem. A stochastic policy conditioned on s_t and \theta produces actions a_t, which lead to rewards r_t and next states s_{t+1}. Associated with each reward is a DiCE objective that takes as input the set of all causal dependencies that are functions of \theta, i.e., the actions. Arrows from \theta, a_i, and r_i to the gradient estimators are omitted for clarity.

Variance Reduction

Variance reduction is an integral part of Monte Carlo estimation.
Though DiCE is not limited to the RL case, we are most interested in policy gradients that use the score function trick.

DiCE inherently reduces variance by taking causality into account: each cost node c is multiplied by the magic box of only those stochastic nodes that influence c, so in the gradient c is paired only with the score functions of its causal ancestors.

Now we propose another variance reduction mechanism by adding the following term to the DiCE objective:

(3)   \begin{align*} \calB_{\magic}^{(1)} &= \sum_{w \in \calS}{(1-\magic({w}))b_w},\nonumber \end{align*}

where b_w is any function of nodes not influenced by w. The baseline keeps the gradient estimate unbiased and does not change the value of the original objective \calL_{\magic}, since each term (1-\magic({w}))b_w evaluates to zero.
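For a single sampled trajectory, a hedged sketch of this baseline term (reusing the magic_box helper from above; the per-step baselines stand in for any valid b_w):

def first_order_baseline(log_probs, baselines):
    # log_probs: [T] log pi(a_t | s_t), one entry per stochastic node (action).
    # baselines: [T] values b_w that do not depend on the corresponding action.
    # magic_box applied element-wise gives magicbox({w}) for each single node w,
    # so every term evaluates to zero and the objective's value is unchanged.
    return tf.reduce_sum((1.0 - magic_box(log_probs)) * baselines)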

The flaw of \calB_{\magic}^{(1)} becomes apparent when we calculate second-order gradients: in short, some of the terms in the second-order gradient have no control variates, which keeps the variance high.

To fix the problem, we can subtract the following term from the objective to reduce the second-order gradient variance:

(4)   \begin{align*} \calB_{\magic}^{(2)} &= \sum_{w \in \calS'}{\big(1-\magic({w})\big)\big(1-\magic({\calS_w})\big)b_w}, \nonumber \end{align*}

where \calS' is the set of stochastic nodes that depend on \theta and at least one other stochastic node, and \calS_w is the set of stochastic nodes that influence w and depend on \theta.
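In the same single-trajectory RL setting, \calS_w for the action at step t is the set of earlier actions, so a sketch of the second-order term looks like this (again reusing our magic_box helper; names are ours):

def second_order_baseline(log_probs, baselines):
    # For the action at step t, the stochastic ancestors S_w are the actions
    # at steps 0..t-1; their summed log-probability is an exclusive cumsum.
    ancestor_log_probs = tf.cumsum(log_probs, exclusive=True)
    return tf.reduce_sum((1.0 - magic_box(log_probs)) *
                         (1.0 - magic_box(ancestor_log_probs)) * baselines)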

Code example

To show DiCE in action, we apply it to the iterated prisoner’s dilemma (IPD). In IPD, two agents iteratively play matrix games where they can either (C)ooperate or (D)efect. The first agent’s payoffs are: -1 if both cooperate (CC), -2 if both defect (DD), 0 if it defects while the opponent cooperates (DC), and -3 if it cooperates while the opponent defects (CD).

Let’s build policies for both agents first:
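A minimal sketch of tabular IPD policies (TensorFlow 1.x style to match the snippets referenced below; variable and function names are ours): each agent keeps one cooperation logit for the initial state and for each of the four possible previous joint actions.

import tensorflow as tf

n_states = 5  # initial state plus the outcomes CC, CD, DC, DD
theta_1 = tf.Variable(tf.zeros([n_states]), name="theta_1")
theta_2 = tf.Variable(tf.zeros([n_states]), name="theta_2")

def cooperation_prob(theta, state):
    # Probability of playing (C)ooperate in the given state index.
    return tf.sigmoid(theta[state])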

Now, let’s build the DiCE objective:
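A hedged sketch of the objective for one sampled trajectory (using the magic_box helper defined earlier; the discount factor is illustrative):

def dice_objective(log_probs, rewards, gamma=0.96):
    # log_probs: [T] summed log-probabilities of both agents' sampled actions
    #            at each step; rewards: [T] rewards for the agent of interest.
    t = tf.cast(tf.range(tf.shape(rewards)[0]), tf.float32)
    discounts = tf.pow(gamma, t)
    # Each reward r_t is weighted by the magic box of all stochastic nodes
    # that causally influence it, i.e. the actions up to and including step t.
    causal_log_probs = tf.cumsum(log_probs)
    return tf.reduce_sum(magic_box(causal_log_probs) * discounts * rewards)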

Computing the gradient or Hessian with respect to the parameters is just a matter of calling tf.gradients() or tf.hessians() on the DiCE objective:
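Assuming the objective above has been built into a tensor dice_loss from theta_1 (hypothetical names), in TensorFlow 1.x graph mode this is just:

grad_theta_1 = tf.gradients(dice_loss, [theta_1])[0]   # first-order gradient
hess_theta_1 = tf.hessians(dice_loss, [theta_1])[0]    # second-order gradient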

You can find the complete working example here.

Empirical Results

Let’s now look at the empirical verification of DiCE. The figure below shows that the second-order baseline \calL_{\magic}^{b_2} lets us match the analytically derived Hessian, whereas the first-order baseline alone fails to do so.



The following figure shows that, although the quality of the gradient estimate increases with the sample size, \calL_{\magic}^{b_1} never reaches the performance of \calL_{\magic}^{b_2}. Results with the second-order baseline are in orange, those with the first-order baseline only are in blue.

Finally, we show how DiCE helps us get better performance on the IPD using LOLA [3]. Compared with the original formulation, LOLA-DiCE agents discover strategies of high social welfare, replicating the results of the original LOLA paper in a way that is both more direct and more efficient.

Joint average per step returns for different training methods. Shaded areas represent the 95% confidence intervals based on five runs. All agents used batches of size 64, which is more than 60 times smaller than the size required in the original LOLA paper.

As we can see in the figure below, the second-order baseline dramatically improves LOLA performance on the IPD problem:

LOLA performance with \calL_{\magic}^{b_1} (red) and \calL_{\magic}^{b_2} (blue).

Conclusion

In this post, we have described DiCE, a general method for computing gradient estimators of any order for stochastic computation graphs. DiCE is easy to implement, yet it lets us use the full power of auto-differentiation software without manually constructing a graph for each order of the gradient. We believe DiCE will be a stepping stone for further exploration of higher-order learning methods in meta-learning, reinforcement learning, and other applications of stochastic computation graphs.

Whether you want to build upon DiCE or are just curious to find out more, you can find our implementation here. For PyTorch lovers, there is also an implementation by Alexis David Jacq.

References

Blog post by Vitaly Kurin, Jakob Foerster, and Shimon Whiteson.

[1]
J. Foerster, G. Farquhar, M. Al-Shedivat, T. Rocktäschel, E. Xing, and S. Whiteson, “DiCE: The Infinitely Differentiable Monte Carlo Estimator,” in Proceedings of the 35th International Conference on Machine Learning, 2018, vol. 80, pp. 1524–1533.
[2]
J. Schulman, N. Heess, T. Weber, and P. Abbeel, “Gradient estimation using stochastic computation graphs,” in Advances in Neural Information Processing Systems, 2015, pp. 3528–3536.
[3]
J. Foerster, R. Y. Chen, M. Al-Shedivat, S. Whiteson, P. Abbeel, and I. Mordatch, “Learning with opponent-learning awareness,” in Proceedings of the 17th International Conference on Autonomous Agents and MultiAgent Systems, 2018, pp. 122–130.
