Gradient descent: back to basics with F# and C#

After dabbling with high-level AI tools recently, I wanted to revisit the basics and look under the hood. As they say, the best way to learn something is to build it. I've been following the wonderful "from scratch" videos by Andrej Karpathy, where he starts with the essence of all modern AI architectures: a gradient descent engine.

So, I decided to follow along and build my own engine at https://github.com/khmylov/dumb-gradient. This article is an overview of its implementation, with some thoughts on using F# and C# for a task like this.

Side note: a look back at university

I had completely forgotten that one of my university projects back in 2010 was actually about building an image recognition neural network! It was fun to look back at the source code from that era. Turns out, the core ideas are still relevant; it's the computational advances, new modelling discoveries, and the infrastructure around them that enabled so much progress in the last decade.

I'm not going to recount all the details here; the main (much simplified) ideas are as follows (I recommend Deep Learning with PyTorch, the fast.ai course, and the Hugging Face course for a much deeper dive):

  • We can represent most real-world data, like text, images, and sounds, as vast arrays of numbers
  • Turns out, we can model a lot of "intelligence" tasks (natural language processing, image recognition, etc.) with sophisticated mathematical functions operating on our numeric representation of real-world data
  • The universal approximation theorem was a breakthrough, postulating that even the most sophisticated functions can be modeled by combining only the primitive linear function f(x) = a * x + b and basic non-linear functions such as tanh into an inter-connected network (here come the inevitable neuron models and neural networks); we just need a huge number of them, plus some computing power to figure out the arguments of those functions
  • "Training" such a network means figuring out how to tweak the randomly initialized function arguments so that the entire network comes close to modelling our sophisticated function
  • You can certainly mess around with those function arguments randomly until you get what you want, but that's extremely inefficient computationally. However, since we only use simple linear and non-linear functions underneath, your calculus course may remind you that derivatives let you find out analytically how a function reacts to changes in its arguments.

Essentially, it means that neural networks are huge machines crunching partial derivatives of various functions at specific local points (text/images/etc. encoded as numbers) to figure out how to tune the a and b in f(x) = a * x + b.
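To make this concrete, here's a single hand-rolled gradient descent step (a sketch with manually derived partial derivatives; automating exactly this bookkeeping is what the rest of the article is about):

// One manual gradient descent step for f(x) = a * x + b,
// trying to move f(2.0) towards a target of 5.0
let x, target = 2.0, 5.0
let mutable a = 0.5
let mutable b = 0.1

let prediction = a * x + b              // 1.1
let loss = (prediction - target) ** 2.0 // squared error: 15.21

// Partial derivatives of the loss, derived by hand via the chain rule:
// dLoss/da = 2 * (prediction - target) * x
// dLoss/db = 2 * (prediction - target)
let gradA = 2.0 * (prediction - target) * x // -15.6
let gradB = 2.0 * (prediction - target)     // -7.8

// Nudge the parameters against the gradient to make the loss go down
let learningRate = 0.01
a <- a - learningRate * gradA // 0.656; f(2.0) is now 1.49, closer to 5.0
b <- b - learningRate * gradB // 0.178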

Of course, in the real world you don't even have to think about this these days. Libraries like PyTorch give you nice production-ready, battle-tested, GPU-ready building blocks. Heck, even that is considered too low-level now; for most people the better approach is to use a pre-built model from Hugging Face, which you can access with a single function call in Python. What a time to be alive!

But I had this urge to look into the boring old implementation details, have some fun re-building a dumbified PyTorch from scratch, and play with the recent generic math features in C#. However, I actually started with F#, as it has been my go-to choice for various scripts and algorithmic tasks for almost 15 years now. It's also praised a lot as a data-oriented language, so it seemed like a good fit here.

Spoiler alert: I'm happy with neither the F# nor the C# implementation; both languages have their quirks which sometimes make them annoying for this specific task at hand 😒.

F#: basic implementation

One of the core features of PyTorch that I wanted to re-implement is automatic gradient tracking, meaning that all operations on PyTorch primitives keep track of their origins, which then lets you apply derivatives to build a gradient vector and apply it to your model's parameters to nudge them in the needed direction.

So if I write a = b + c in PyTorch, then a is not just some numerical value: it's an object holding both the numerical value and the notion that it originated from the addition of b and c.

First, let's see what our example program may look like:

let a, b, c = Value(1.0), Value(-4.0), Value(10.0)
let p1 = a * b + c // Add(Mult(Value(1.0), Value(-4.0)), Value(10.0))

In my implementation I decided to model it with an expression hierarchy, i.e. the result of *, +, or any other operation is a dedicated class wrapping the underlying value and the expression nodes it depends upon.

[<AbstractClass>]
type Expr() =
  // Local gradient of this node
  member val Grad = 0.0 with get, set

  // Mutates local gradients of expressions this expression depends on,
  // i.e. "how much should expression I depend upon change to make my value go up by 1 unit?"
  abstract member UpdateGradient: unit -> unit

  // Numerical result of the computation represented by this expression
  abstract member Value: float

For example, the results of addition and multiplication are then modeled as:

type Add(left: Expr, right: Expr) =
  inherit Expr()

  override this.UpdateGradient() =
    // z = x + y; x = f(a)
    // dz/da = dz/dx * dx/da = 1 * dx/da = dx/da
    left.Grad <- left.Grad + this.Grad
    right.Grad <- right.Grad + this.Grad

  override _.Value = left.Value + right.Value
  
type Mult(left: Expr, right: Expr) =
  inherit Expr()

  override this.UpdateGradient() =
    // z = x * y; x = f(a)
    // dz/da = dz/dx * dx/da = y * dx/da
    left.Grad <- left.Grad + right.Value * this.Grad
    right.Grad <- right.Grad + left.Value * this.Grad
 
  override _.Value = left.Value * right.Value

We will also need the simplest Value case to lift a primitive underlying value into our expression world:

// Just a dumb Expr wrapper around value, most likely to be used as parameter or model input
type Value(initialValue: float) =
  inherit Expr()

  // Mutable because we're going to change it when training
  let mutable value = initialValue

  override _.UpdateGradient() = ()
  override _.Value with get() = value

Let's also throw in operator overloads to make a * b + c possible:

type Expr
with
  static member inline (*) (left, right) = Mult(left, right)
  static member inline (+) (left, right) = Add(left, right)

The immediate question is: why not use something more idiomatic, like a discriminated union type Expr = | Value | Add of Expr * Expr | Mult of Expr * Expr, or records?

For this specific task I want expressions to be "open", meaning that it's possible to extend the "base" library by adding more expression types. Discriminated unions are great for modelling some internal state, but not so great for external extensibility.
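For comparison, here's roughly what the closed-world modelling would look like (illustration only, not part of the engine); every new operation means extending both the type and every function matching on it:

type ExprDU =
    | Value of float
    | Add of ExprDU * ExprDU
    | Mult of ExprDU * ExprDU

// Every traversal is a total match; adding a Tanh case would force
// updating this (and every other) match expression
let rec eval expr =
    match expr with
    | Value v -> v
    | Add (l, r) -> eval l + eval r
    | Mult (l, r) -> eval l * eval r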

Records are kind of supposed to be immutable, but training a network essentially means mutating its inner state, so wasting GC cycles on allocating immutable data structures does not sound appealing.

So, we're back to the ugly OOP and class inheritance!

Side note: playing around with the C# implementation, I eventually ditched the inheritance and lifted method overrides like UpdateGradient into constructor parameters. The same trick could be applied to the F# version (listed below) to make it more functional, but for the sake of the experiment, I'm going to proceed the way it worked for me initially.

type Expr = {
  mutable Gradient: double
  GetValueImpl: Expr -> double
  Dependencies: Expr array
  UpdateGradient: Expr -> unit
}
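For illustration, an add constructor function in that encoding might look like this sketch (the actual refactoring is shown in the C# sections below):

// Evaluate a node by passing it to its own implementation function
let value (e: Expr) = e.GetValueImpl e

// Addition becomes plain data plus two functions, no subclassing involved
let add (left: Expr) (right: Expr) : Expr = {
    Gradient = 0.0
    GetValueImpl = fun e -> value e.Dependencies.[0] + value e.Dependencies.[1]
    Dependencies = [| left; right |]
    UpdateGradient = fun e ->
        for dep in e.Dependencies do
            dep.Gradient <- dep.Gradient + e.Gradient
}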

Backward pass

Before we add more expression nodes needed for modelling, like tanh or sigmoid, let's consider how we're going to use the future model. The model itself is one large function (our expression tree), with leaf Value nodes representing both the numerical encoding of input data and the tunable model parameters.

We want to know how far the model's output (the value of the root node) is from our expectations. This is called the loss of the current model state, and we want it to go down. There are numerous strategies for loss calculation; one of the simplest is taking the difference between the actual and the predicted number and then throwing away its sign, either by calling abs() or by squaring it.

Therefore, the loss can be represented as a differentiable expression node as well, which also means we can compute a gradient from it and nudge its dependencies to make the loss go down.
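For example, a squared-error loss can be assembled from the nodes we already have (a sketch; since we haven't defined a Sub node, subtraction is expressed through Add and Mult):

// loss = (prediction - target)^2, with subtraction written as
// prediction + (-1) * target, and the square as self-multiplication
let squaredErrorLoss (prediction: Expr) (target: Expr) =
    let diff = prediction + Value(-1.0) * target
    diff * diff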

PyTorch exposes a backward method on its tensors, so let's implement something similar for our Expr. First, we need every expression node to expose the nodes it depends on. I'm going to skip insignificant code details; the full code is available in the accompanying repository.

// Only new code listed here, previously defined methods are skipped
type Expr() =
  // ...
  abstract member Children: Expr seq

type Value(initialValue: float) =
  // ...
  override _.Children = Seq.empty

type Mult(left: Expr, right: Expr) =
  // ...
  override _.Children = seq {left; right}

type Add(left: Expr, right: Expr) =
  // ...
  override _.Children = seq {left; right}

Let's add a helper method to get all dependency nodes (down to the leaves) for a given expression, and then add our Backward() method. It first sets the gradient of the current node to 1 (the base case: to change the current node by 1 unit, we have to change it by 1 unit), and then goes through all dependencies once (order matters) and updates their gradients based on the nodes depending on them (everything goes backwards, yay 😵).

type Expr
with
  member this.GetAllChildren() =
    let mutable res = []
    let visited = System.Collections.Generic.HashSet<_>()
    let rec loop (v: Expr) =
      if visited.Add v then
        for child in v.Children do
          loop child
        res <- v::res
    loop this
    res

  member this.Backward() =
    this.Grad <- 1.0
    for v in this.GetAllChildren() do
      v.UpdateGradient()

This is once again far from idiomatic functional code, but it's good enough for the task at hand.
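With all the pieces in place, the expression from the beginning of the article can now be differentiated end to end. For p1 = a * b + c, calculus says dp1/da = b, dp1/db = a, and dp1/dc = 1, and that's exactly what we get:

let a, b, c = Value(1.0), Value(-4.0), Value(10.0)
let p1 = a * b + c // value: 1 * -4 + 10 = 6

p1.Backward()

printfn "%.1f %.1f %.1f" a.Grad b.Grad c.Grad // -4.0 1.0 1.0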

Annoyance creeps in

This is the place where my annoyance started picking up. I feel that I'm using the wrong tool for the job. Modelling class hierarchies in F# just does not feel good. For example, the language does not support the notion of protected members, which is kind of by design, but argh... Essentially, it means my Expr exposes too many internals on its public surface, like the UpdateGradient method, which is an internal detail expected to be called by Backward only.

Next thing, I wanted to add some optional parameters to prettify the debugging output in some cases without enforcing parameter passing in all places. Turns out, F# supports optional parameters on type members only, not on let-bound functions.

F# famously has a single-pass compiler, which has its share of benefits. However, it also means I can't define the class-based operator overloads in the base Expr definition (and those are the core feature of the engine):

type Expr() =
  // .. some code omitted
  static member (+) (left, right) = Add(left, right) // this does not work as Add is unknown yet

type Add(left: Expr, right: Expr) =
  // .. some code omitted

Fortunately, this can be mitigated with type augmentations, so I can define the operators after defining the expression types:

type Expr() =
  // .. some code omitted
  
type Add(left: Expr, right: Expr) =
  // .. some code omitted

type Expr
with
  static member inline (+) (left, right) = Add(left, right)

There are also some minor quirks around type inference and math contracts: for example, the exp and pow operators expect the result type to match the argument type, which adds some redundancy (note the : Expr annotation and the :> Expr upcast below) and kind of defeats the purpose.

type Pown(x: Expr, n: int) =
  inherit Expr()
  // ..

type Expr
with
  static member inline Pow (left, right): Expr = Pown(left, right)
  
let a = Value(2.0)
let b = (a :> Expr) ** 2

In the end, I got the basic setup working, added more expression types, and added small abstractions for a neuron, a layer, and a multi-layer perceptron, but I was somewhat dissatisfied with the result.
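To give a flavor of those abstractions, here's a sketch of a neuron built on top of Expr. The Tanh node is an assumption here, defined the same way as Add and Mult; the actual code in the repository differs in details:

// A neuron computes tanh(w1*x1 + ... + wn*xn + bias),
// with the weights and bias as trainable Value parameters.
// Tanh is assumed to be an Expr subtype, just like Add and Mult.
type Neuron(inputCount: int) =
    let rnd = System.Random()
    let weights = Array.init inputCount (fun _ -> Value(rnd.NextDouble() * 2.0 - 1.0))
    let bias = Value(0.0)

    // Parameters to be nudged by their gradients during training
    member _.Parameters = seq {
        for w in weights do yield w :> Expr
        yield bias :> Expr
    }

    member _.Forward(inputs: Expr[]) =
        let mutable acc = bias :> Expr
        for i in 0 .. inputCount - 1 do
            acc <- (acc + weights.[i] * inputs.[i]) :> Expr
        Tanh(acc) :> Expr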

C#: class hierarchy

Then I remembered that recent C# and .NET versions added support for "generic math" based on static abstract interface members, and C# is also less strict about types, so I decided to play with these new features and re-implement the engine in C#.

I started by replicating the same inheritance-based approach, with a base Expression class and a subtype for each operation kind. Once again, I'm skipping insignificant bits; the full code is available on GitHub.

internal abstract class Expression {
    public double Gradient { get; set; }

    public abstract double GetValue();

    // Get expressions this node depends on
    protected abstract IEnumerable<Expression> GetDependencies();

    // Propagate gradient from this node to its dependencies
    protected abstract void UpdateGradient();

    public void Backward()  => // ... same as before

    private IReadOnlyList<Expression> GetAllDependencies() => // ... same as before

    public static Expression operator +(Expression left, Expression right) => new Add(left, right);

    public static Expression operator *(Expression left, Expression right) => new Mult(left, right);
}

internal class Value : Expression {
    // ...
}

internal class Add : Expression {
    // ...
}

internal class Mult : Expression {
    // ...
}

Here we immediately see the same "problem" as in the F# implementation: the + and * overloads can be defined only on the base class; you can't have them as external extension methods or something like trait implementations. I guess we can live with that for now.

C#: generic math

As an experiment, I wanted to abstract over the underlying value type. Currently it's double, but can we make it any "number"-like type? That's where "generic math" comes in: it lets us add generic type constraints like where T : IAdditionOperators<T, T, T>, specifying that we want instances of T to support an abstract + operator accepting two Ts and returning another T. Applying this to our expressions (source):

internal abstract class Expression<T> {
    public double Gradient { get; set; }

    public abstract T GetValue();

    // ... most other members unchanged
}

internal class Value<T> : Expression<T> {
     // ... most other members unchanged
}

internal class Add<T> : Expression<T>
    where T : IAdditionOperators<T, T, T> { // <------ HERE
    
    // ... most other members unchanged

    public override T GetValue() => _left.GetValue() + _right.GetValue();

    protected override void UpdateGradient() {
        _left.Gradient += Gradient;
        _right.Gradient += Gradient;
    }
}

internal class Mult<T> : Expression<T>
    where T : IMultiplyOperators<T, T, T>, IMultiplyOperators<T, double, T>, IAdditionOperators<T, double, double> { // <------ WOW

    // ... most other members unchanged

    public override T GetValue() => _left.GetValue() * _right.GetValue();

    protected override void UpdateGradient() {
        // Must use `X = Y + X` form instead of `X += Y` because of IAdditionOperators<T, double, double>
        _left.Gradient = _right.GetValue() * Gradient + _left.Gradient;
        _right.Gradient = _left.GetValue() * Gradient + _right.Gradient;
    }
}

Note that things start to get out of hand here: we'd like to abstract over the underlying data type, but we still have to add and multiply it with our Gradient, which is double, so we need more ugly constraints like IAdditionOperators<T, double, double>. The caveat here is that built-in types like int do not implement such mixed-type interfaces.

This in turn means that we can't actually have an instance of Expression<int>. Also, it's not (yet?) possible to implement these static abstract interface methods in an extension-like manner; only the T type itself can implement them, so the following imaginary code is not possible:

struct IntToDouble : IAdditionOperators<int, double, double> {
    ...
}

Finally, since operator overloads must be defined on the base abstract class, we actually have to put the type constraints from all implementation nodes onto the base Expression<T> class, which makes them required on all descendants as well, even if a descendant's implementation does not need the specific operators:

internal abstract class Expression<T>
    where T : IAdditionOperators<T, T, T>,
    IMultiplyOperators<T, T, T>,
    IMultiplyOperators<T, double, T>,
    IAdditionOperators<T, double, double> {

    public static Expression<T> operator +(Expression<T> left, Expression<T> right) => new Add<T>(left, right);

    public static Expression<T> operator *(Expression<T> left, Expression<T> right) => new Mult<T>(left, right);

    // ... 
}

internal class Value<T> : Expression<T>
    where T : IAdditionOperators<T, T, T>,
    IMultiplyOperators<T, T, T>,
    IMultiplyOperators<T, double, T>,
    IAdditionOperators<T, double, double> {
    // ...
}

internal class Add<T> : Expression<T>
    where T : IAdditionOperators<T, T, T>,
    IMultiplyOperators<T, T, T>,
    IMultiplyOperators<T, double, T>,
    IAdditionOperators<T, double, double> {
    // ...
}

internal class Mult<T> : Expression<T>
    where T : IMultiplyOperators<T, T, T>,
    IMultiplyOperators<T, double, T>,
    IAdditionOperators<T, double, double>,
    IAdditionOperators<T, T, T> {
    
    // ...
}

Overall, I find the resulting implementation unappealing, redundant, and just not worth it.

C#: default interface members

While we're at it, I also wanted to play around with default interface implementations from C# 8, so here's another version with an IExpression interface instead of the Expression abstract class (source):

public interface IExpression {
    double Gradient { get; set; }

    double GetValue();

    protected void UpdateGradient();

    protected IEnumerable<IExpression> GetDependencies();

    public void Backward() => // ... same as before

    private IReadOnlyList<IExpression> GetAllDependencies() => // ... same as before

    static IExpression operator +(IExpression left, IExpression right) => new Add(left, right);
}

// Technically, we can use structs for interface implementations
public struct Value : IExpression {
    // ...
}

public class Add : IExpression {
    private readonly IExpression _left;
    private readonly IExpression _right;

    public Add(IExpression left, IExpression right) {
        _left = left;
        _right = right;
    }
    
    // Must define the implementation
    public double Gradient { get; set; }

    public double GetValue() => _left.GetValue() + _right.GetValue();

    // Non-public interface members can only be implemented explicitly
    void IExpression.UpdateGradient() {
        _left.Gradient += Gradient;
        _right.Gradient += Gradient;
    }

    IEnumerable<IExpression> IExpression.GetDependencies() {
        yield return _left;
        yield return _right;
    }
}

It immediately brings up a couple of interesting observations.

First, dropping abstract base classes in favor of interfaces lets us implement some expression nodes as structs instead of classes, which in theory should be better from a GC point of view. In practice, however, it may not work out as you'd expect: structs mixed with interfaces easily lead to boxing at runtime. There's even a nice warning from Rider/ReSharper:

Using the default implementations of the interface members may result in boxing of the struct 'Value' values at runtime in generic code (where T : IExpression)

All in all, this is definitely not the best scenario for default interface members:

  • Even though GetDependencies() is protected on the interface, it must be implemented as a public GetDependencies() in implementing types (or via explicit interface implementation).
  • We put the Backward() and operator + implementations into the IExpression interface purely for code reuse, but types implementing IExpression may decide to override them, which is undesirable.

C#: static abstract interface members

Things get even worse if we try to plug our expressions into the previously mentioned "generic math" ecosystem! Let's try implementing IAdditionOperators, so that our expressions can be used with various numeric operations:

public interface IExpression : IAdditionOperators<IExpression, IExpression, IExpression> {
    // ...
}

This fails to compile with the following error:

[CS8920] The interface 'IExpression' cannot be used as type argument. Static member 'IAdditionOperators<IExpression, IExpression, IExpression>.operator +(IExpression, IExpression)' does not have a most specific implementation in the interface.

Ehhmm, what? Looking for the details of this error leads to an obscure thread on GitHub (well, maybe not so obscure, but I'm past my prime days of being deep into compilers to quickly understand it). Essentially, an interface that declares or inherits static abstract members without most-specific implementations for them cannot itself be used as a type argument. Another dead end; it feels like most of these features work only for a limited set of use cases in the base class library, and are not really designed for general-purpose extensibility.

C#: back to data and functions

In the end, I was so annoyed that I decided to ditch the class hierarchy and interface nonsense altogether and just model everything with a dumb data structure. I did a naive, straightforward refactoring:

  • Lift the protected member overrides into constructor parameters. Use static lambdas to share delegate instances and avoid redundant allocations.
  • Pass the dependencies to the constructor as a collection parameter

This results in a less type-safe implementation with some additional dumb code for indexed dependency access, but it's still quite simple and readable in my opinion (source).

class Expr<T> where T : INumber<T>, IMultiplyOperators<T, double, double> {

    private readonly Func<Expr<T>, T> _getValueImpl;
    private readonly Action<Expr<T>> _updateGradient;

    /// <remarks>
    /// Function parameters are expected to be static, so that each expression instance gets the same delegate, to avoid additional GC pressure.
    /// </remarks>
    public Expr(
        IReadOnlyList<Expr<T>> dependencies,
        Func<Expr<T>, T> getValueImpl,
        Action<Expr<T>> updateGradient,
        // Used only as a hack to pass the underlying value to Value implementation and avoid closure allocations
        T? dummyValue = default) {
        _getValueImpl = getValueImpl;
        _updateGradient = updateGradient;
        Dependencies = dependencies;
        DummyValue = dummyValue;
    }

    public double Gradient { get; private set; }
    public IReadOnlyList<Expr<T>> Dependencies { get; }
    public T? DummyValue { get; }

    public T GetValue() => _getValueImpl(this);

    public void AddGradient(double value) {
        Gradient += value;
    }

    public void Backward() {
        Gradient = 1;
        foreach (var dep in GetAllDependencies()) {
            dep._updateGradient(dep);
        }
    }

    private IReadOnlyList<Expr<T>> GetAllDependencies() => // ... same as before

    public static Expr<T> operator +(Expr<T> left, Expr<T> right) {
        return new Expr<T>(
            new[] { left, right },
            static e => e.Dependencies[0].GetValue() + e.Dependencies[1].GetValue(),
            static e => {
                e.Dependencies[0].AddGradient(e.Gradient);
                e.Dependencies[1].AddGradient(e.Gradient);
            });
    }

    public static Expr<T> operator *(Expr<T> left, Expr<T> right) => new(
        new[] { left, right },
        static e => e.Dependencies[0].GetValue() * e.Dependencies[1].GetValue(),
        static e => {
            e.Dependencies[0].AddGradient(e.Dependencies[1].GetValue() * e.Gradient);
            e.Dependencies[1].AddGradient(e.Dependencies[0].GetValue() * e.Gradient);
        });
}

Using static lambdas results in the additional hassle of passing the "self" instance around, which may require some getting used to, but it quickly becomes natural.

Conclusion

And now we've come full circle, back to the simplest "data and functions" approach, which ironically resembles that infamous IQ bell curve meme.

So, what's next? By no means is this expected to be a production-grade implementation; I'm just playing around with the most basic ideas, trying to build on them, and brushing up my knowledge of the internals in the meantime. I plan to proceed with building more complex machine learning primitives. It would also be great to do some load testing on the various implementations, extend the engine to multi-dimensional tensors, plug it into CUDA, and so on.