Mechanistic estimation for wide random MLPs
This post covers joint work with Wilson Wu, George Robinson, Mike Winer, Victor Lecomte and Paul Christiano. Thanks to Geoffrey Irving and Jess Riedel for comments on the post.

In ARC's latest paper, we study the following problem: given a randomly initialized multilayer perceptron (MLP), produce an estimate for the expected output of the model under Gaussian input. The usual approach to this problem is to sample many possible inputs, run them all through the model, and take the average. Instead, we produce an estimate "mechanistically", without running the model even once. For wide models, our approach produces more accurate estimates, both in theory and in practice.

Paper: Estimating the expected output of wide random MLPs more efficiently than sampling
Code: mlp_cumulant_propagation GitHub repo

We are excited about this result as an early step towards our goal of producing mechanistic estimates that outperform random sampling for any trained neural network. Drawing an analogy between this goal and a proof by induction, we see this result as (part of) the "base case": handling networks at initialization. We have a vision for the "inductive step", although we expect that to be much more difficult.
Summary of results

In our paper, we consider MLPs $M_\theta : \mathbb{R}^n \to \mathbb{R}$ with weights $\theta = \left(\mathbf{W}^{(1)}, \ldots, \mathbf{W}^{(L+1)}\right)$, where $\mathbf{W}^{(1)}, \ldots, \mathbf{W}^{(L)} \in \mathbb{R}^{n \times n}$ and $\mathbf{W}^{(L+1)} \in \mathbb{R}^{1 \times n}$, defined by

$$M_\theta(\mathbf{x}) = \mathbf{W}^{(L+1)} \phi\left(\mathbf{W}^{(L)} \cdots \phi\left(\mathbf{W}^{(1)} \mathbf{x}\right) \cdots\right),$$

where $\phi$ is applied coordinatewise, and is taken to be $\mathrm{ReLU}(z) = \max(0, z)$ by default.

Schematic of our ReLU MLP.

An estimation algorithm takes in $\theta$ and a tolerance parameter $\varepsilon$, and aims to estimate $\mathbb{E}_{X \sim \mathcal{N}(0, \mathbf{I})}\left[M_\theta(X)\right]$. We evaluate estimation algorithms by checking their mean squared error over weights with randomly initialized entries drawn independently from $\mathcal{N}\left(0, \frac{2}{n}\right)$.[1]

Our baseline is Monte Carlo sampling, which draws $\lceil 1/\varepsilon \rceil$ samples, runs them through the model, and averages the results. As a function of the width $n$ and the tolerance $\varepsilon$ (holding the depth $L$ constant), in the average case over initializations, this has mean squared error $\Theta(\varepsilon)$ and runs in time $\Theta(n^2/\varepsilon)$. Our best-performing algorithm, on the other hand, has mean squared error $O(\varepsilon)$ and runs in time – a factor of faster.[2] This means that for any given depth, our algorithm is guaranteed to outperform Monte Carlo sampling at large enough width, although the dependence of our algorithms on depth is worse.[3]

Since this theoretical result provides no guarantee about the performance at any particular width, we empirically check the performance of our algorithm at realistic widths, looking at mean squared error versus the number of floating point operations (FLOPs) performed.[4] For ReLU MLPs with 4 hidden layers and width 256, our best algorithms outperform Monte Carlo sampling at FLOP budgets spanning 7 orders of magnitude, in some cases achieving the same mean squared error using fewer than as many FLOPs.[5]

Performance of our best cumulant propagation algorithms at estimating the mean of random ReLU MLPs with 4 hidden layers and width 256.
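The following minimal NumPy sketch makes the setup concrete by implementing the random ReLU MLP and the Monte Carlo baseline. It assumes the conventions described above:

- $L$ hidden layers of width $n$ and a $1 \times n$ readout.
- Every weight entry drawn i.i.d. from $\mathcal{N}(0, 2/n)$ (He initialization).
- An estimate formed by averaging the output over $\lceil 1/\varepsilon \rceil$ inputs $X \sim \mathcal{N}(0, \mathbf{I})$.

The function names (`init_mlp`, `mlp_forward`, `monte_carlo_estimate`) are hypothetical and are not taken from the mlp_cumulant_propagation repo.

```python
import math

import numpy as np


def init_mlp(n, num_hidden, rng):
    """Sample weights for a ReLU MLP R^n -> R with `num_hidden` hidden layers of width n.

    Every entry is drawn i.i.d. from N(0, 2/n) (He initialization); the readout is 1 x n.
    """
    std = math.sqrt(2.0 / n)
    hidden = [rng.normal(0.0, std, size=(n, n)) for _ in range(num_hidden)]
    readout = rng.normal(0.0, std, size=(1, n))
    return hidden + [readout]


def mlp_forward(weights, x):
    """Evaluate the MLP on a batch x of shape (batch, n); returns an array of shape (batch,)."""
    h = x
    for w in weights[:-1]:
        h = np.maximum(h @ w.T, 0.0)  # linear layer followed by coordinatewise ReLU
    return (h @ weights[-1].T)[:, 0]  # final linear readout, no activation


def monte_carlo_estimate(weights, eps, rng, batch_size=4096):
    """Average the model output over ceil(1/eps) inputs X ~ N(0, I).

    Each forward pass costs about L * n^2 multiply-adds, so for fixed depth the
    total cost grows like n^2 / eps.
    """
    n = weights[0].shape[1]
    num_samples = math.ceil(1.0 / eps)
    total, done = 0.0, 0
    while done < num_samples:
        b = min(batch_size, num_samples - done)
        total += mlp_forward(weights, rng.standard_normal((b, n))).sum()
        done += b
    return total / num_samples


rng = np.random.default_rng(0)
weights = init_mlp(n=256, num_hidden=4, rng=rng)
print(monte_carlo_estimate(weights, eps=1e-4, rng=rng))
```

Batching only bounds memory use. The estimator is still a plain average of $\lceil 1/\varepsilon \rceil$ independent samples, so its mean squared error is the output variance divided by the sample count.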
Perhaps more importantly, our algorithm outperforms Monte Carlo sampling in the tails of the distribution to an even greater extent than in the body. For low probability estimation, Monte Carlo sampling is essentially useless for probabilities significantly below $1/N$, where $N$ is the number of samples used. On the other hand, using a similar number of FLOPs to Monte Carlo sampling with $N$ samples, our algorithm sometimes achieves a relative error of under 30% for probabilities 100 times lower than $1/N$.[6]

Finally, it is straightforward to extend our method from the expected output of the network to bona fide loss functions. We demonstrate this for a simple distillation loss between two networks of different widths. Since our estimates are differentiable, we can train the student network using our estimates of this distillation loss, a process we refer to as mechanistic distillation. Although we do not outperform ordinary training, this serves as a proof of concept for using mechanistic estimation for training.

Significance of results

We are excited about these results not simply because of the performance improvements, but because our algorithms are "mechanistic": rather than running the network many times and seeing what it does, we read off behavioral properties of the network directly from the weights. If we could produce good mechanistic estimates for frontier models, we should be able to catch deceptive alignment at train time, even if the model's behavior looks benign on every training input.

This potentially offers us a completely different way to train models that avoids deceptive alignment. Instead of using stochastic gradient descent, we can apply gradient descent to a mechanistic estimate, a process we refer to as mechanistic training.
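The remark that Monte Carlo sampling is essentially useless for probabilities far below $1/N$ comes down to binomial arithmetic. Suppose we take $N$ independent samples of an event with probability $p$:

- The hit-fraction estimator has relative standard deviation $\sqrt{(1 - p)/(Np)}$.
- With probability $(1 - p)^N \approx e^{-Np}$, the event is never observed, so the estimate is exactly zero.

The Python sketch below simply evaluates both formulas. The function names are our own.

```python
import math


def mc_relative_std(p, num_samples):
    """Relative standard deviation of the hit-fraction estimator for an event of probability p."""
    return math.sqrt((1.0 - p) / (num_samples * p))


def prob_never_observed(p, num_samples):
    """Probability that none of num_samples independent samples hits the event."""
    return (1.0 - p) ** num_samples


N = 10**6
for factor in (1, 10, 100):
    p = 1.0 / (factor * N)  # event probability `factor` times lower than 1/N
    print(f"p = 1/({factor}N): relative std {mc_relative_std(p, N):.2f}, "
          f"P(no hits) {prob_never_observed(p, N):.3f}")
```

At $p = 1/(100N)$, the relative standard deviation is about 10 and the event goes unobserved roughly 99% of the time. This is why a relative error under 30% at that probability is out of reach for sampling with a comparable budget.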
Even if mechanistic training is about as efficient as ordinary training overall, it would produce models that allocate capacity in a very different way, because the kinds of errors that mechanistic estimates make are very different. These models would therefore generalize very differently. For example, consider a loss function with a certain "dangerous" event that is very rare but has very high loss. Stochastic gradient descent may fail to sample the event even once during training, resulting in a model that does not account for the dangerous event at all. On the other hand, mechanistic training could notice how the rare event might transpire, resulting in a model that does a better job of avoiding it (potentially allocating capacity to this at the expense of other aspects of performance). This could be especially important if we are concerned about the dangerous event becoming significantly more likely under distribution shift.

Of course, we are a long way from being able to train frontier models using mechanistic training. Randomly initialized networks serve as the natural starting point for this goal, both because they are the simplest networks, and because every network starts out randomly initialized. We have produced mechanistic estimates for these networks that significantly outperform random sampling (for wide networks),[7] and have also provided demonstrations of low probability estimation and mechanistic training.

Note that the high-level algorithm in our paper is not new: it is a form of cumulant propagation, an algorithm we introduced in 2022 in Formalizing the presumption of independence (Appendix D). This works by propagating an approximate probability distribution through the model, without running the model on any particular input.
However, a number of specific details about our algorithm are new, and these are essential for achieving low mean squared error.[8] It is only because we reoriented our research around outperforming random sampling that we ended up discovering these details and obtaining our results.

Extending to trained networks

Although we are able to outperform random sampling, our methods rely on the networks being randomly initialized. Roughly speaking, our methods start from Gaussian approximations to the activation distributions, and then track the lowest-order deviations from that. In a trained network, specific higher-order deviations can become much more important, and so our algorithm would need to be adapted to track these. We suspect that this can be made to work with some sort of auxiliary "advice" given to the estimation algorithm, which somehow points out which of these higher-order deviations to track, and which gets updated incrementally with each gradient step. In our analogy with a proof by induction, the algorithm for this incremental update would be our "inductive step".

One idea for what this could look like is for the advice to provide a different, more structured weight distribution such that the model's weights can be viewed as having been drawn randomly from this distribution. This structured distribution can be thought of as an efficient compression of the model's weights, an idea we discuss further in Compression as a possible MSP approach.[9]

Regardless, being able to handle trained networks is clearly essential for our methods to be of practical utility, and we are pursuing this direction in further research.

Conclusion

Our paper achieves our goal of producing mechanistic estimates that outperform random sampling in the case of wide, randomly initialized MLPs. We are working on extending this to trained networks in order to produce new training methods that reduce the likelihood of deceptive alignment.
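The following NumPy sketch illustrates what "starting from Gaussian approximations to the activation distributions" can look like at the lowest order, using plain Gaussian moment propagation:

1. Treat each layer's pre-activations as jointly Gaussian.
2. Push the mean vector and covariance matrix through the ReLU using the closed-form moments of a rectified Gaussian.
3. Read off the output mean without sampling any inputs.

This is a simplified stand-in, not the cumulant propagation algorithm from the paper. It tracks no higher-order cumulants. It also approximates each off-diagonal ReLU covariance by its first-order term $\Phi(\mu_i/\sigma_i)\,\Phi(\mu_j/\sigma_j)\,\Sigma_{ij}$, an assumption chosen here for brevity. The function names are hypothetical.

With a single hidden layer, the output-mean estimate is exact, because the first-layer pre-activations really are Gaussian. Deeper layers incur approximation error.

```python
import math

import numpy as np

_erf = np.vectorize(math.erf)  # elementwise error function using only the standard library


def relu_moments(mu, var):
    """Elementwise E[ReLU(z)], Var[ReLU(z)] and E[ReLU'(z)] for z ~ N(mu, var)."""
    sigma = np.sqrt(np.maximum(var, 1e-12))
    a = mu / sigma
    pdf = np.exp(-0.5 * a**2) / math.sqrt(2.0 * math.pi)  # standard normal density at a
    cdf = 0.5 * (1.0 + _erf(a / math.sqrt(2.0)))         # standard normal CDF at a
    mean = mu * cdf + sigma * pdf
    second = (mu**2 + sigma**2) * cdf + mu * sigma * pdf
    return mean, np.maximum(second - mean**2, 0.0), cdf


def gaussian_propagation_estimate(weights):
    """Estimate E[M(X)], X ~ N(0, I), by propagating means and covariances layer by layer."""
    n = weights[0].shape[1]
    mu, cov = np.zeros(n), np.eye(n)  # moments of the input X ~ N(0, I)
    for w in weights[:-1]:
        pre_mu, pre_cov = w @ mu, w @ cov @ w.T  # exact moments of the linear map
        post_mu, post_var, slope = relu_moments(pre_mu, np.diag(pre_cov))
        cov = np.outer(slope, slope) * pre_cov  # first-order off-diagonal approximation
        np.fill_diagonal(cov, post_var)          # exact marginal variances under the Gaussian assumption
        mu = post_mu
    return float((weights[-1] @ mu)[0])


rng = np.random.default_rng(0)
n, num_hidden = 128, 3
std = math.sqrt(2.0 / n)
weights = [rng.normal(0.0, std, size=(n, n)) for _ in range(num_hidden)]
weights.append(rng.normal(0.0, std, size=(1, n)))

h = rng.standard_normal((20_000, n))  # Monte Carlo reference, for comparison only
for w in weights[:-1]:
    h = np.maximum(h @ w.T, 0.0)
print("Gaussian propagation:", gaussian_propagation_estimate(weights))
print("Monte Carlo:         ", float((h @ weights[-1].T).mean()))
```

The first-order covariance rule follows from Price's theorem. The derivative of $\mathbb{E}[\mathrm{ReLU}(z_i)\,\mathrm{ReLU}(z_j)]$ with respect to $\Sigma_{ij}$ is $\mathbb{E}[\mathrm{ReLU}'(z_i)\,\mathrm{ReLU}'(z_j)]$, which equals $\Phi(\mu_i/\sigma_i)\,\Phi(\mu_j/\sigma_j)$ at $\Sigma_{ij} = 0$.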
[1] For $\phi = \mathrm{ReLU}$, this is the standard He initialization, and ensures that the outputs of $M_\theta$ have variance around 1.
[2] This works provided that grows like for some .
[3] The depth scaling of our algorithms is discussed in Appendix D of the paper.
[4] We also empirically validate that the scaling with width matches our theoretical predictions, as discussed in Section 6.3 of the paper.
[5] Our algorithms often underperform Monte Carlo sampling in wall-clock time, since we make no serious attempt to optimize their performance on hardware. We provide wall-clock times in Appendix I of the paper.
[6] We explain this result in Section 6.6 of the paper.
[7] It remains an open problem to match the performance of random sampling for random MLPs in other regimes, for example, with the depth growing linearly in the width.
[8] An ablation of some of these details is discussed in Section 6.4 of the paper.
[9] We don't think this approach works quite as we have stated it. For example, if the model starts by computing a one-way function of the weights, we may want an efficient compression of the output of this one-way function, rather than of the weights themselves.
