A "Lay" Introduction to "On the Complexity of Neural Computation in Superposition"
This is a writeup based on a lightning talk I gave at an InkHaven hosted by Georgia Ray, where we were supposed to read a paper in about an hour and then present what we learned to the other participants.

Introduction and Background

So. I foolishly thought I could read a theoretical machine learning paper in an hour because it was in my area of expertise. Unfortunately, it turns out that theoretical CS professors know a lot of math and theoretical CS results that they reference constantly in their work. This makes their work very hard to read, even if you’re familiar with the general area. Instead of explaining the substantial math behind the paper, the best I can do is give an overview of the paper’s setup, its contributions, and how they fit in with earlier work.

Back in the olden days (2021) there was a dream that you could just open up a neural network and understand it by looking at individual neurons. For example, you might ask, “is this neuron a ‘cat’ neuron? Or is it the ‘betray all humans’ neuron?” Then you could just check whether the ‘betray all humans’ neuron is on.

But it turned out that neural networks were a lot more complicated than this. For one thing, a serious issue was neuron polysemanticity, where a single neuron fires on a bunch of seemingly unrelated things. We’d see things like the ‘betray all humans’ neuron firing on discussion of cats and the like. Maybe the AI is planning to instigate the grand robotic uprising; maybe it’s just thinking about the genealogy of Maine coons. Though, of course, there is some chance that we were wrong, and there actually is some deep connection between cats and attempts to subjugate humanity.

[Image caption: I doubt anyone asked this Maine coon what he thought about robotic uprisings.]
The leading theory people had was this: in high-dimensional spaces, if you’re okay with a small amount of interference between your representations, you can represent a lot more things than you have dimensions by using near-orthogonal vectors (even random near-orthogonal vectors). Arguably, if you take seriously a result called the Johnson-Lindenstrauss lemma, you should be able to represent exponentially more.

[Image caption: The first image I could find to represent the Johnson-Lindenstrauss lemma.]

The lemma states that you can represent m points in O(log m) dimensions while preserving all pairwise distances up to a small error. More precisely, O(log(m)/ε²) dimensions suffice to keep every pairwise distance within a factor of 1 ± ε. In fact, this is so easy that a random projection works.

This led to a series of research projects in 2022 studying what we would nowadays call representational superposition.[1] People would study toy problems where small networks had to represent many concepts at once (by representing them in near-orthogonal ways). Then they’d use their understanding from these results to construct techniques to extract concepts from LLMs.

(As an aside, yes, I’m aware that every other field uses the word superposition differently, closer to how we use ‘polysemanticity’ in model internals work. For example, in quantum mechanics, superposition just means the system is not in any single definite basis state. And yes, it is pretty funny that the word ‘superposition’ ended up meaning not one thing, but several different concepts.)

But in response to this, there was some amount of work making the point that you can’t just think of a neural network as representing a bunch of things. In fact, it’s probably important to think of neural networks as computing things, given that that is what the interesting parts of the network are doing. As a general rule, neural networks are not just representing concepts that God handed to them in their input. In 2024, I did some work in this area (though my collaborators deserve more of the credit).
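The claims about near-orthogonal vectors and the Johnson-Lindenstrauss lemma are easy to check numerically. Below is a minimal NumPy sketch. The sizes, the use of Gaussian random vectors, and the helper names are illustrative assumptions and are not taken from any paper.

The sketch does two things. First, it packs more random unit vectors than there are dimensions and reports the largest pairwise |cosine similarity|, which stays small. Second, it randomly projects points down to fewer dimensions and reports the worst relative change in any pairwise distance.

```python
import numpy as np

def max_abs_cosine(n_vectors: int, dim: int, rng: np.random.Generator) -> float:
    """Largest |cosine similarity| among n random unit vectors in `dim` dimensions."""
    V = rng.standard_normal((n_vectors, dim))
    V /= np.linalg.norm(V, axis=1, keepdims=True)   # normalize each row to unit length
    cos = V @ V.T
    np.fill_diagonal(cos, 0.0)                      # ignore each vector's similarity with itself
    return float(np.abs(cos).max())

def pairwise_distances(A: np.ndarray) -> np.ndarray:
    """Euclidean distance between every pair of rows of A."""
    sq = (A ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * A @ A.T
    return np.sqrt(np.maximum(d2, 0.0))             # clip tiny negatives caused by rounding

def jl_max_distortion(n_points: int, orig_dim: int, target_dim: int,
                      rng: np.random.Generator) -> float:
    """Worst relative change in pairwise distance after a random Gaussian projection."""
    X = rng.standard_normal((n_points, orig_dim))
    # Entries with variance 1/target_dim preserve squared lengths in expectation.
    R = rng.standard_normal((orig_dim, target_dim)) / np.sqrt(target_dim)
    Y = X @ R                                       # project every point down to target_dim
    iu = np.triu_indices(n_points, k=1)             # each unordered pair exactly once
    ratio = pairwise_distances(Y)[iu] / pairwise_distances(X)[iu]
    return float(np.abs(ratio - 1.0).max())

rng = np.random.default_rng(0)
print("max |cos|, 2000 vectors in 500 dims:", max_abs_cosine(2000, 500, rng))
print("max distortion, 2000 -> 400 dims:", jl_max_distortion(200, 2000, 400, rng))
```

In line with the lemma, the dimension needed for a fixed distortion grows only logarithmically with the number of points. That is the intuition behind "exponentially more".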
Our 2024 work had some clever constructions showing that you can indeed gain some efficiency by computing with concepts in superposition. But the gains in compressing concepts were not exponential; they were basically quadratic. That is, for some toy problems, if you want to compute something that would normally take m pure concepts, you can do it with something like sqrt(m) neurons (with a bunch of unspecified log factors).

Adler and Shavit’s “On the Complexity of Neural Computation in Superposition”

Adler and Shavit’s “On the Complexity of Neural Computation in Superposition” builds on this initial work. While reading it, my main impression was something along the lines of “Wow, this is what real computer scientists are like.” I do have some complaints about how they wrote the paper and presented their results. But one thing that stood out to me is that the paper makes it very clear that they really do know a lot of math. They’re careful with their math and constructions in a way that the work I was involved in just was not. A lot of what we did in this area felt like gesturing at proof sketches that should probably work out.[2] They also cite some theoretical CS results that I didn’t know about and that seem quite relevant.

From my perspective, the first main contribution of their work is this: while our work gave upper bounds on how large networks need to be to do certain types of computation, their work also provides lower bounds. That is, they show that for some classes of toy problems with m pure concepts, you need a network with at least on the order of sqrt(m/log(m)) neurons. Their argument for this is arguably “obvious”, in the sense that it starts from an information-theoretic counting argument: basically, you can’t represent enough things if you don’t have enough parameters. However, it turns out that making this argument rigorous in the presence of noise is complicated.
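To get a feel for the sizes involved, here is a back-of-the-envelope sketch. It compares three neuron counts: one neuron per concept, the roughly sqrt(m) scaling, and the sqrt(m/log(m)) lower bound. Two assumptions are made purely for illustration: the constants hidden by the asymptotic notation are set to 1, and log is taken to be the natural log. Because of this, only the relative sizes mean anything.

```python
import math

def neuron_budgets(m: int) -> dict[str, float]:
    """Rough neuron counts for computing with m concepts.

    Hidden constants are set to 1 and log is the natural log (illustrative
    assumptions), so only the relative sizes of these numbers are meaningful.
    """
    return {
        "one neuron per concept": float(m),
        "superposition, ~sqrt(m)": math.sqrt(m),
        "lower bound, ~sqrt(m / log m)": math.sqrt(m / math.log(m)),
    }

for m in (10**4, 10**6, 10**8):
    print(m, {name: round(v) for name, v in neuron_budgets(m).items()})
```

Take m = 10^6 concepts. The quadratic saving takes you from 1,000,000 neurons down to about 1,000. The log factor only shifts the lower bound by a factor of sqrt(ln 10^6), which is about 3.7.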
I would have guessed that making the counting argument rigorous was easy, but it definitely was not. I found it impressive that they put in the work and managed it anyway.

Adler and Shavit also provide several cleaner constructions for the upper bound, which need O(sqrt(m) log(m)) neurons. (They make the unspecified log factors explicit.) Combining the two results, their paper shows that the sqrt(m) scaling is tight up to log factors. That is, roughly sqrt(m) neurons is how many you need to compute with m concepts.

Their third contribution is less a specific result than a general procedure for constructing models that solve these classes of toy problems (that is, for picking the weights by hand). They envision every single weight matrix as being composed internally of two parts:

- First, a big decompression matrix that takes the small, dense representation and expands it into a large, sparse representation.
- Second, a large computation and compression matrix. It both does the computation on the sparse representation and compresses the result back into a dense representation.

Notably, they imagine this happening inside a single weight matrix, instead of being spread across the weight matrices of different layers. As far as I know, this was a new construction (at least at the time), and it seems useful for hand-constructing neural networks for similar problems.

[Figure 2 from Adler and Shavit 2024. Caption: Adler and Shavit suggest thinking of each weight matrix W_i as composed of a decompression matrix D_i and a compression/computation matrix C_i.]

Of course, I confess that I did not have time to read through their proofs, let alone all of the results they cite. In places, the paper presents a nontrivial theoretical CS result and says that it’s true because of citation 38. Then I’d click on 38 and it’d say “personal communication with another MIT professor”.
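To make the W_i = C_i D_i picture concrete, here is a toy NumPy sketch. It is not Adler and Shavit’s construction. The random ±1 codebook, the use of a fixed feature permutation as the "computation", the sizes, and all function names are hypothetical choices made for illustration.

The setup works like this:

- Each of m sparse features gets a near-orthogonal dense code, stored as a column of E.
- D = E.T decompresses a dense vector into a noisy sparse one.
- C = E @ P applies the permutation P and compresses the result back.
- Only the product W = C @ D, a single d × d matrix, is stored.

```python
import numpy as np

# Hypothetical sizes: m sparse features, d-dimensional dense code, k features active at once.
M_FEATURES, D_DENSE, K_ACTIVE = 2000, 1000, 2

def random_codebook(m: int, d: int, rng: np.random.Generator) -> np.ndarray:
    """d x m matrix E whose column j is feature j's unit-norm, near-orthogonal dense code."""
    return rng.choice([-1.0, 1.0], size=(d, m)) / np.sqrt(d)

def build_layer_weight(E: np.ndarray, P: np.ndarray) -> np.ndarray:
    """Fold decompress -> compute -> compress into one d x d matrix W = C @ D."""
    D = E.T        # (m x d) decompression: dense code -> noisy sparse feature vector
    C = E @ P      # (d x m) computation (permute features), then compression back to d dims
    return C @ D   # (d x d): the single weight matrix the layer actually stores

def run_layer_demo(seed: int = 0):
    rng = np.random.default_rng(seed)
    E = random_codebook(M_FEATURES, D_DENSE, rng)
    # Hypothetical "computation": a fixed permutation of the m features.
    P = np.eye(M_FEATURES)[rng.permutation(M_FEATURES)]
    W = build_layer_weight(E, P)

    x = np.zeros(M_FEATURES)
    x[rng.choice(M_FEATURES, size=K_ACTIVE, replace=False)] = 1.0  # sparse input features
    dense_in = E @ x            # compressed (superposed) input
    dense_out = W @ dense_in    # one matmul does decompress + compute + compress

    scores = E.T @ dense_out                              # read features back out of the dense output
    recovered = np.sort(np.argsort(scores)[-K_ACTIVE:])   # keep the K strongest features
    expected = np.flatnonzero(P @ x)                      # ground truth: permutation applied to x directly
    return W.shape, recovered, expected

shape, recovered, expected = run_layer_demo()
print("W shape:", shape, "recovered:", recovered, "expected:", expected)
```

Because everything here is linear, interference from the near-orthogonal codes accumulates as noise. With these sizes it stays small enough that the correct features still come out on top. Shrinking d or activating more features at once eventually breaks the readout, so a real construction has to manage this noise carefully.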
Since I haven’t read the proofs, I can’t really say that I’ve fully understood the work, nor that I’ve checked it for correctness.

To summarize, my overall impression of the paper is something along the lines of: this was a cool paper. It really does show that theoretical CS people have a lot of expertise in doing proofs carefully and doing the work to make their results go through. I’ll probably spend time in the coming weeks reading it more carefully.

But at the same time, I felt a bit disappointed in the paper. I thought there would be a lot more new content. What I learned instead (while failing to read the paper in an hour) was that it can be pretty nontrivial to make some very basic-seeming arguments in this area mathematically rigorous.

[1] The canonical work is Elhage et al.’s “Toy Models of Superposition”: https://transformer-circuits.pub/2022/toy_model/index.html.

[2] Adding after the fact: it’s worth noting that Dmitry Vaintrob did go through and prove all of our results rigorously (he’s a real mathematician!). But I was much less involved in that part of the work; my contributions mainly stopped at the proof sketch stage. This is also why a lot of our results had a long chain of unspecified log factors. Whoops.
