
Chapter 8: RMS Normalisation and Residual Connections

Gary Jackson · DEV Community

What You'll Build

Two architectural patterns that make deep networks trainable: RMSNorm (keeps activations from exploding or vanishing) and residual connections (gives gradients a highway to flow through).

Prerequisites: Chapters 1-2 (Value), Chapter 5 (Helpers).

As data flows through many Linear operations and activation functions like ReLU (both of which you've already seen), the magnitude of the numbers can drift: they grow huge, or shrink to near zero. Both are catastrophic for training. RMSNorm rescales the numbers after each layer to keep them in a stable range, and residual connections let the original signal bypass each layer entirely.

RMSNorm

Imagine a vector of numbers flowing through the network. After a few Linear operations, those numbers might have drifted to very large values like [500, 800, 300] or very small ones like [0.001, 0.002, 0.001]. RMSNorm fixes this by measuring the overall "size" of the vector (using the root mean square: the square root of the average of the squared values) and then dividing each element by that size. The result is a vector whose overall magnitude is always close to 1, regardless of what happened in previous layers.

Why root mean square specifically? This is the same RMS pattern we saw in Adam's squared-gradient average in Chapter 7, and for the same two reasons:

- Makes values positive. We care about overall magnitude, not direction: a vector [-5, 5] has the same "size" as [5, -5], and squaring makes the calculation agree.
- Emphasises larger values. A value of 10 contributes 100 to the sum; a value of 1 contributes just 1. So the measure is dominated by the biggest elements rather than being smeared across all of them.

Squaring on the way in and square-rooting on the way out gives a single number that represents the vector's "typical size". Dividing by that scale leaves a vector whose overall magnitude is ~1.

Add it to Helpers.cs:

```csharp
// --- Helpers.cs (add inside the Helpers class) ---

/// <summary>
/// Rescales a vector so its overall magnitude is close to 1, using the root mean
/// square of its values. Keeps activations stable across deep networks.
/// </summary>
public static List<Value> RmsNorm(List<Value> x)
{
    // Mean of the squared values...
    var sumSq = new Value(0);
    foreach (Value xi in x)
    {
        sumSq += xi * xi;
    }
    Value ms = sumSq / x.Count;

    // ...then multiply each element by 1/sqrt(ms), i.e. divide by the RMS
    Value scale = (ms + 1e-5).Pow(-0.5);
    return [.. x.Select(xi => xi * scale)];
}
```

The 1e-5 prevents division by zero if all values happen to be zero.

RMSNorm was introduced by Zhang & Sennrich (2019) as a simpler alternative to LayerNorm (used in the original GPT-2). It drops the learned scale/shift parameters and the mean-subtraction step, making it faster while achieving similar results. See the References section for the paper.
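To make that comparison concrete, here is a rough sketch of what LayerNorm would look like in the same style. This is illustration only, not part of this chapter's code: the `gamma` and `beta` parameters are hypothetical names for the learned scale/shift that RMSNorm drops, and the sketch assumes your Value class supports subtraction alongside the operators used above.

```csharp
// Illustration only - NOT part of this chapter's code.
// LayerNorm in the same style as RmsNorm, to show what RMSNorm drops:
// the mean-subtraction step and the learned gamma/beta parameters.
public static List<Value> LayerNorm(List<Value> x, List<Value> gamma, List<Value> beta)
{
    // Mean-subtraction step (RMSNorm skips this)
    var sum = new Value(0);
    foreach (Value xi in x) { sum += xi; }
    Value mean = sum / x.Count;

    // Variance of the centred values
    var sumSq = new Value(0);
    foreach (Value xi in x) { sumSq += (xi - mean) * (xi - mean); }
    Value scale = (sumSq / x.Count + 1e-5).Pow(-0.5);

    // Learned scale/shift, one pair per element (RMSNorm drops these entirely)
    return [.. x.Select((xi, i) => (xi - mean) * scale * gamma[i] + beta[i])];
}
```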
Residual Connections

A residual connection simply adds a layer's input back to its output. It isn't a separate function; it's a pattern applied inline wherever a transformation occurs:

```csharp
// Pattern - not a standalone function, used inside Model.cs in Chapter 11
var xResidual = new List<Value>(x);
x = SomeTransformation(x);
for (int i = 0; i < x.Count; i++)
{
    x[i] += xResidual[i];
}
```

Because the output is the transformation's output plus x, the gradient flowing back to x is the transformation's gradient plus 1: even if a layer's own gradient is tiny, the "+1" from the skip path carries the signal straight through. That's the highway.

Test It

Create Chapter8Exercise.cs:

```csharp
// --- Chapter8Exercise.cs ---
public static class Chapter8Exercise
{
    public static void Run()
    {
        // ── Test RmsNorm ──
        // RMS of [3, 4] = sqrt((9 + 16) / 2) = sqrt(12.5) ≈ 3.536
        // normed ≈ [3/3.536, 4/3.536] = [0.849, 1.131]
        var testVec = new List<Value> { new(3.0), new(4.0) };
        List<Value> normed = RmsNorm(testVec);
        Console.WriteLine("--- RmsNorm ---");
        Console.WriteLine("Expected: 0.849 1.131");
        Console.Write("Got: ");
        foreach (Value v in normed) { Console.Write($"{v.Data:F3} "); }
        Console.WriteLine();

        // Try it with a "drifted" vector - large values get scaled down
        // RMS of [500, 800, 300] ≈ 571.548; normed ≈ [0.875, 1.400, 0.525]
        // Values are now close to 1.0 in magnitude, regardless of the original scale.
        var bigVec = new List<Value> { new(500.0), new(800.0), new(300.0) };
        List<Value> bigNormed = RmsNorm(bigVec);
        Console.WriteLine("--- RmsNorm on large values ---");
        Console.WriteLine("Expected: 0.875 1.400 0.525");
        Console.Write("Got: ");
        foreach (Value v in bigNormed) { Console.Write($"{v.Data:F3} "); }
        Console.WriteLine();

        // ── Test Residual Connection ──
        // Start with [1, 2], apply a transformation (double each value),
        // then add the original back: [2+1, 4+2] = [3, 6]
        var x = new List<Value> { new(1.0), new(2.0) };
        var xResidual = new List<Value>(x);

        // "Transformation": double each value
        x = [.. x.Select(xi => xi * 2.0)];

        // Residual: add original back
        for (int i = 0; i < x.Count; i++)
        {
            x[i] += xResidual[i];
        }

        Console.WriteLine("--- Residual Connection ---");
        Console.WriteLine("Expected: 3.0 6.0 (transformation output + original input)");
        Console.Write("Got: ");
        foreach (Value v in x) { Console.Write($"{v.Data:F1} "); }
        Console.WriteLine();
    }
}
```

Uncomment the Chapter 8 case in the dispatcher in Program.cs:

```csharp
case "ch8": Chapter8Exercise.Run(); break;
```

Then run it:

```
dotnet run -- ch8
```
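If everything is wired up correctly, the run should print something close to this (the Got values follow directly from the arithmetic in the comments, so any drift here means a bug):

```
--- RmsNorm ---
Expected: 0.849 1.131
Got: 0.849 1.131
--- RmsNorm on large values ---
Expected: 0.875 1.400 0.525
Got: 0.875 1.400 0.525
--- Residual Connection ---
Expected: 3.0 6.0 (transformation output + original input)
Got: 3.0 6.0
```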
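As a final experiment, you can watch the "gradient highway" directly. Backprop through y = f(x) + x gives dy/dx = f'(x) + 1, so even when a transformation squashes gradients, the +1 from the skip path survives. Here is a minimal sketch, assuming your Value exposes a Grad property and a Backward() method as built in the early chapters (rename to match your implementation):

```csharp
// Minimal experiment (not part of the chapter's files): compare the gradient
// reaching the input with and without a residual connection.
// Assumes Value has .Grad and .Backward() - adjust names if yours differ.

// Without residual: y = 0.01 * x (a transformation that squashes the signal)
var a = new Value(2.0);
Value yPlain = a * 0.01;
yPlain.Backward();
Console.WriteLine($"No residual:   dy/dx = {a.Grad}");   // 0.01 - nearly vanished

// With residual: y = 0.01 * x + x (same transformation plus the skip path)
var b = new Value(2.0);
Value yRes = b * 0.01 + b;
yRes.Backward();
Console.WriteLine($"With residual: dy/dx = {b.Grad}");   // 1.01 - the "+1 highway"
```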