# Gradient descent: back to basics with F# and C#

After dabbling with high-level AI tools recently, I wanted to revisit the basics and look under the hood. As they say, the best way to learn something is to build it. I've been following the wonderful "from scratch" videos from Andrej Karpathy, where he starts with the essence of all modern AI architectures: a gradient descent engine.

#### Side note: look back at university

I completely forgot that one of my university projects back in 2010 was actually about building an image recognition neural network! It was funny to look back at the source code from that era. Turns out, the core ideas are still relevant; it's the computational advances, new modelling discoveries, and the infrastructure around them that allowed so much progress in the last decade.

I'm not going to recite all the details here. The main (much simplified) ideas to comprehend are as follows (I recommend Deep Learning with PyTorch, the fast.ai course, and the Huggingface course for a much deeper dive):

• We can represent most real-world data, like text, images, sounds, as vast arrays of numbers
• Turns out, we can model a lot of "intelligence" tasks (natural language processing, image recognition, etc.) with sophisticated mathematical functions operating on our numeric representation of real-world data
• The universal approximation theorem was a breakthrough, stating that even the most sophisticated functions can actually be modeled by combining only primitive linear functions `f(x) = a * x + b` and basic non-linear functions such as `tanh` into an inter-connected network (here come the inevitable neuron models and neural networks). We just need a huge number of them, and some computation power to figure out the arguments for those functions
• "Training" such networks is figuring out how to tweak randomly initialized function arguments to make the entire network close to modelling our sophisticated function
• You can certainly mess around with those function arguments randomly until you get what you want, but that's extremely inefficient computationally. However, since we only use simple linear and non-linear functions underneath, your calculus course may remind you that derivatives let you find out analytically how a function reacts to a change of its arguments.

Essentially, it means that neural networks are huge machines crunching partial derivatives of various functions at specific local points (text/images/etc. encoded as numbers), to understand how to tune the `a` and `b` in `f(x) = a * x + b`.
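To make that concrete, here is the single linear function written out. The loss $L$ (a number measuring how wrong the output is) and the step size $\eta$ are notation assumed for this sketch; neither symbol comes from the text above.

$$
f(x) = a x + b, \qquad \frac{\partial f}{\partial a} = x, \qquad \frac{\partial f}{\partial b} = 1
$$

$$
a \leftarrow a - \eta \, \frac{\partial L}{\partial a}, \qquad b \leftarrow b - \eta \, \frac{\partial L}{\partial b}
$$

The first line says how the output reacts to a small change in `a` or `b` at a given input `x`. The second line is one gradient descent step: move each argument a little against its derivative, so that the loss goes down.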

Of course, in the real world you don't even have to think about it these days. Libraries like PyTorch give you nice production-ready, battle-tested and GPU-ready building blocks. Heck, even that is considered too low-level now: for most people the better approach is to use a pre-built model from Huggingface, which you can access with a single function call in Python. What a time to be alive!

But I had this urge to look into the boring old implementation details, have some fun re-building a dumbified PyTorch from scratch, and play with recent generic math features in C#. However, I actually started with F#, as it has been my go-to choice for various scripts and algorithmic tasks for almost 15 years now. It's also praised a lot as a data-oriented language, so it seems like a good fit here.

Spoiler alert: I'm happy with neither the F# nor the C# implementation. Both languages have their quirks, which sometimes make them annoying for this specific task 😒.

## F#: basic implementation

One of the core features of PyTorch that I wanted to re-implement is automatic gradient tracking, meaning that all operations on PyTorch primitives keep track of their origins, which then lets you apply derivatives to build a gradient vector, and apply it to your model's parameters to nudge them in the needed direction.

So if I write `a = b + c` in PyTorch, then `a` is not just some numerical value, it's an object having both the numerical value, and the notion of it originating from addition of `b` and `c`.

First, let's see what our example program might look like:

``````
let a, b, c = Value(1), Value(-4), Value(10)
let p1 = a * b + c // Add(Mult(Value(1), Value(-4)), Value(10))
``````

In my implementation I decided to model it with an expression hierarchy, i.e. the result of `*` or `+` or any other operation is a dedicated case class wrapping the underlying value and the expression nodes it depends upon.

``````
[<AbstractClass>]
type Expr() =
    // Local gradient of this node
    member val Grad = 0.0 with get, set

    // Mutates local gradients of expressions this expression depends on,
    // i.e. "how much should expression I depend upon change to make my value go up by 1 unit?"
    abstract member UpdateGradient: unit -> unit

    // Numerical result of the computation represented by this expression
    abstract member Value: float
``````

For example, the results of addition and multiplication are then modeled as:

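``````
type Add(left: Expr, right: Expr) =
    inherit Expr()

    // z = x + y; x = f(a)
    // dz/da = dz/dx * dx/da = 1 * dx/da = dx/da
    override this.UpdateGradient() =
        // Accumulate, since the same node may feed several parents
        left.Grad <- left.Grad + this.Grad
        right.Grad <- right.Grad + this.Grad

    override _.Value = left.Value + right.Value

type Mult(left: Expr, right: Expr) =
    inherit Expr()

    // z = x * y; x = f(a)
    // dz/da = dz/dx * dx/da = y * dx/da
    override this.UpdateGradient() =
        left.Grad <- left.Grad + right.Value * this.Grad
        right.Grad <- right.Grad + left.Value * this.Grad

    override _.Value = left.Value * right.Value
``````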

We will also need the simplest `Value` case to lift the primitive underlying value into our expression world:

``````
// Just a dumb Expr wrapper around value, most likely to be used as parameter or model input
type Value(initialValue: float) =
    inherit Expr()

    // Mutable because we're going to change it when training
    let mutable value = initialValue

    // Leaf node: there are no dependencies to propagate the gradient to
    override _.UpdateGradient() = ()

    override _.Value with get() = value
``````

Let's also throw in operator overloads to make `a * b + c` possible:

``````
type Expr
with
    static member inline (*) (left, right) = Mult(left, right)
    static member inline (+) (left, right) = Add(left, right)
``````

The immediate question is why not use something more idiomatic like discriminated unions `type Expr = | Value | Add of Expr * Expr | Mult of Expr * Expr` or records?

For this specific task I want expressions to be "open", meaning that it's possible to extend the "base" library by adding more expression types. Discriminated unions are great for modelling some internal state, but not really great for external extensibility.

Records are kind of supposed to be immutable, but training a network is essentially mutating its inner state, so wasting GC on allocating immutable data structures does not sound appealing.

So, we're back to the ugly OOP and class inheritance!

Side note: while playing around with the C# implementation, I ditched the inheritance in the end and lifted method overrides like `UpdateGradient` into parameters. The same trick could be applied to the F# version (listed below) to make it more functional, but for the sake of the experiment, I'm going to proceed the way it worked for me initially.

``````
type Expr = {
    GetValueImpl: Expr -> double
    Dependencies: Expr array
}
``````

## Backward pass

Before we add more expression nodes needed for modelling, like `tanh` or `sigmoid`, let's consider how we're going to use the future model. The model itself is a large function (our expression tree), with leaf `Value` nodes representing both numerical encoding of input data and tunable model parameters.

We want to know how far the model's output (the value of the root node) is from the expected result. This is called the loss metric of the current model state, and we want it to go down. There are numerous strategies for loss calculation; one of the simplest is taking the difference between the expected and the predicted number, and then throwing away its sign, either by calling `abs()` or by squaring it.

Therefore, the loss can be represented as our differentiable expression node as well, which also means we can build a gradient on it and nudge the dependencies to make the loss go down.
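As a small worked illustration of the two options mentioned above, write $y$ for the expected number and $\hat{y}$ for the model's output (both symbols are notation assumed for this sketch):

$$
L_{\text{abs}} = |y - \hat{y}|, \qquad L_{\text{sq}} = (y - \hat{y})^2, \qquad \frac{\partial L_{\text{sq}}}{\partial \hat{y}} = -2\,(y - \hat{y})
$$

For example, with $y = 10$ and $\hat{y} = 6$ the squared loss is $16$ and its derivative with respect to $\hat{y}$ is $-8$. The negative sign means that raising the output lowers the loss, which is exactly the direction the backward pass has to push the parameters.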

PyTorch exposes a `backward` method on its tensors, so let's implement something similar for our `Expr`. First, we need every expression node to expose the nodes it depends on. I'm going to skip insignificant code details; the full code is available in the accompanying repository.

``````
// Only new code listed here, previously defined methods are skipped
type Expr() =
    // ...
    abstract member Children: Expr seq

type Value(initialValue: float) =
    // ...
    override _.Children = Seq.empty

type Mult(left: Expr, right: Expr) =
    // ...
    override _.Children = seq {left; right}

type Add(left: Expr, right: Expr) =
    // ...
    override _.Children = seq {left; right}
``````

Let's add a helper method to get all dependency nodes (down to leaf levels) for a given expression, and then add our `Backward()` method. First, it sets the gradient of the current node to 1 (rudimentary case: to change the current node by 1 unit, we have to change it by 1), and then goes through all dependencies once (order matters) and updates their gradients based on the nodes depending on them (everything is going backwards, yay 😵).

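``````
type Expr
with
    member this.GetAllChildren() =
        let mutable res = []
        let visited = System.Collections.Generic.HashSet<Expr>()
        let rec loop (v: Expr) =
            // Visit every node only once, even if several nodes depend on it
            if visited.Add v then
                for child in v.Children do
                    loop child
                // Prepending after the children puts every node ahead of its dependencies
                res <- v :: res
        loop this
        res

    member this.Backward() =
        // To change the current node by 1 unit, we have to change it by 1
        this.Grad <- 1.0
        for v in this.GetAllChildren() do
            v.UpdateGradient()
``````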

This is once again far from idiomatic functional code, but good enough for our task at hand.
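To see what the backward pass is expected to produce, here is the earlier example `p1 = a * b + c` with `a = 1`, `b = -4`, `c = 10` worked by hand. The name $m$ for the intermediate product is introduced only for this sketch.

$$
m = a \cdot b = -4, \qquad p_1 = m + c = 6
$$

$$
\frac{\partial p_1}{\partial p_1} = 1, \qquad
\frac{\partial p_1}{\partial m} = 1, \qquad
\frac{\partial p_1}{\partial c} = 1
$$

$$
\frac{\partial p_1}{\partial a} = \frac{\partial p_1}{\partial m} \cdot b = -4, \qquad
\frac{\partial p_1}{\partial b} = \frac{\partial p_1}{\partial m} \cdot a = 1
$$

This also shows why the order matters: the gradient of $m$ has to be known before it can be passed on to `a` and `b`.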

## Annoyance creeps in

This is the place where my annoyance started picking up. I feel that I'm using the wrong tool for the job. Modelling class hierarchies in F# just does not feel good. For example, the language does not support a notion of `protected` members, which is kind of by design, but arrgh... Essentially, it means my `Expr` exposes too many internals on its public surface, like the `UpdateGradient` method, which is an internal thing expected to be called by `Backward` only.

Next thing, I wanted to add some optional constructor parameters to prettify the debugging outputs in some cases without enforcing parameter passing in all places. Turns out, optional parameters in F# are more restricted than I expected: they are available only on members (via the `?name` syntax, arriving as `option` values), not on `let` bindings.

F# is famously a single-pass compiler, which has its share of benefits. However, it also means I can't define class-based operator overloads in the base `Expr` definition (and that's the core feature of the engine):

``````
type Expr() =
    // .. some code omitted
    static member (+) (left, right) = Add(left, right) // this does not work as Add is unknown yet

type Add(left: Expr, right: Expr) =
    // .. some code omitted
``````

Fortunately, this can be mitigated with type augmentations, so I can define operators after defining expression types

``````
type Expr() =
    // .. some code omitted

type Add(left: Expr, right: Expr) =
    // .. some code omitted

type Expr
with
    static member inline (+) (left, right) = Add(left, right)
``````

There are also some minor quirks around type inference and math contracts: operators like `exp` and `**` expect the result type to match the argument type, which forces some redundant annotations (note `: Expr` and `:> Expr`) and kind of defeats the purpose.

``````
type Pown(x: Expr, n: int) =
    inherit Expr()
    // ..

type Expr
with
    static member inline Pow (left, right): Expr = Pown(left, right)

let a = Value(2.0)
let b = (a :> Expr) ** 2
``````

In the end, I got the basic setup working, added more expression types, added minor abstractions for neuron, layer, and multi-layer perceptron, but I was somewhat dissatisfied with the result.

## C#: class hierarchy

Then I remembered that recent C# and .NET versions added support for "generic math", based on static abstract interface members, and C# is also less strict on types, so I decided to play with these new features and re-implement the engine in C#.

I started by replicating the same inheritance-based approach, with a base `Expression` class and a subtype for each operation kind. Once again, I'm skipping insignificant bits; the full code is available on GitHub.

``````
internal abstract class Expression {
    public double Gradient { get; set; }

    public abstract double GetValue();

    // Get expressions this node depends on
    protected abstract IEnumerable<Expression> GetDependencies();

    // Propagate gradient from this node to its dependencies
    protected abstract void UpdateGradient();

    public void Backward() => // ... same as before

    private IReadOnlyList<Expression> GetAllDependencies() => // ... same as before

    public static Expression operator +(Expression left, Expression right) => new Add(left, right);

    public static Expression operator *(Expression left, Expression right) => new Mult(left, right);
}

internal class Value : Expression {
    // ...
}

internal class Add : Expression {
    // ...
}

internal class Mult : Expression {
    // ...
}
``````

Here we immediately see the same "problem" as with the F# implementation: `+` and `*` overloads can be defined only on the base class; you can't have them as external extension methods or something like trait implementations. I guess we can live with that for now.
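For readers who want something to run, here is a minimal self-contained sketch of this inheritance-based shape. It is not the code from the repository: the member name `UpdateGradient`, the traversal in `GetAllDependencies`, and the bodies of `Value`, `Add` and `Mult` are assumptions filled in from the description above. It uses top-level statements, so the usage comes first and the types follow.

```csharp
using System;
using System.Collections.Generic;

// Build p = a * b + c and run the backward pass.
var a = new Value(1);
var b = new Value(-4);
var c = new Value(10);
Expression p = a * b + c;
p.Backward();

Console.WriteLine(p.GetValue()); // 6
Console.WriteLine(a.Gradient);   // -4 (dp/da = b)
Console.WriteLine(b.Gradient);   // 1  (dp/db = a)
Console.WriteLine(c.Gradient);   // 1

abstract class Expression {
    public double Gradient { get; set; }
    public abstract double GetValue();
    protected abstract IEnumerable<Expression> GetDependencies();

    // Adds this node's contribution to the gradients of its dependencies.
    protected abstract void UpdateGradient();

    public void Backward() {
        Gradient = 1;
        foreach (var node in GetAllDependencies()) node.UpdateGradient();
    }

    // Reversed post-order DFS: every node comes before the nodes it depends on.
    private List<Expression> GetAllDependencies() {
        var visited = new HashSet<Expression>();
        var order = new List<Expression>();
        void Visit(Expression e) {
            if (!visited.Add(e)) return;
            foreach (var d in e.GetDependencies()) Visit(d);
            order.Add(e);
        }
        Visit(this);
        order.Reverse();
        return order;
    }

    public static Expression operator +(Expression l, Expression r) => new Add(l, r);
    public static Expression operator *(Expression l, Expression r) => new Mult(l, r);
}

sealed class Value : Expression {
    private readonly double _value;
    public Value(double value) => _value = value;
    public override double GetValue() => _value;
    protected override IEnumerable<Expression> GetDependencies() => Array.Empty<Expression>();
    protected override void UpdateGradient() { } // leaf: nothing to propagate
}

sealed class Add : Expression {
    private readonly Expression _left, _right;
    public Add(Expression left, Expression right) { _left = left; _right = right; }
    public override double GetValue() => _left.GetValue() + _right.GetValue();
    protected override IEnumerable<Expression> GetDependencies() => new[] { _left, _right };
    protected override void UpdateGradient() {
        _left.Gradient += Gradient;
        _right.Gradient += Gradient;
    }
}

sealed class Mult : Expression {
    private readonly Expression _left, _right;
    public Mult(Expression left, Expression right) { _left = left; _right = right; }
    public override double GetValue() => _left.GetValue() * _right.GetValue();
    protected override IEnumerable<Expression> GetDependencies() => new[] { _left, _right };
    protected override void UpdateGradient() {
        _left.Gradient += _right.GetValue() * Gradient;
        _right.Gradient += _left.GetValue() * Gradient;
    }
}
```

Gradients are accumulated with `+=` rather than assigned, so a node that is used in several places collects the contributions from all of them.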

## C#: generic math

As an experiment, I wanted to abstract over the underlying value type. Currently it's `double`, but can we make it any "number"-like type? That's where "generic math" comes in, which now lets us add generic type constraints like `where T : IAdditionOperators<T, T, T>`, specifying that we want instances of `T` to support some abstract `+` operator, accepting two `T`s and returning another `T`. Applying this to our expressions (source):

``````
internal abstract class Expression<T> {
    public double Gradient { get; set; }

    public abstract T GetValue();

    // ... most other members unchanged
}

internal class Value<T> : Expression<T> {
    // ... most other members unchanged
}

internal class Add<T> : Expression<T>
    where T : IAdditionOperators<T, T, T> { // <------ HERE

    // ... most other members unchanged

    public override T GetValue() => _left.GetValue() + _right.GetValue();

    protected override void UpdateGradient() {
        // ...
    }
}

internal class Mult<T> : Expression<T>
    where T : IMultiplyOperators<T, T, T>, IMultiplyOperators<T, double, T>, IAdditionOperators<T, double, double> { // <------ WOW

    // ... most other members unchanged

    public override T GetValue() => _left.GetValue() * _right.GetValue();

    protected override void UpdateGradient() {
        // Must use `X = Y + X` form instead of `X += Y` because of IAdditionOperators<T, double, double>
        // ...
    }
}
``````

Note that things start to get out of hand here: we'd like to abstract over the underlying data type, but we still must add and multiply it with our `Gradient`, which is a `double`, so we need more ugly constraints like `IAdditionOperators<T, double, double>`. The caveat here is that built-in types like `int` do not implement such interfaces.

This in turn means that we can't actually have an instance of `Expression<int>`. Also, it's not (yet?) possible to implement these static abstract interface methods in an extension-like manner; only the type itself can implement them, so the following imaginary code is not possible:

``````
struct IntToDouble : IAdditionOperators<int, double, double> {
    ...
}
``````
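What does work is implementing the mixed-type operator interfaces on a type we own. The sketch below assumes .NET 7 or later; `Scalar`, `Scale` and `Accumulate` are hypothetical names made up for the illustration and do not come from the repository.

```csharp
using System;
using System.Numerics;

Console.WriteLine(Scale(new Scalar(3.0), 2.0).Raw);  // 6
Console.WriteLine(Accumulate(new Scalar(3.0), 2.0)); // 5

// Works for any T that can be multiplied by a double and still yield a T.
static T Scale<T>(T value, double factor)
    where T : IMultiplyOperators<T, double, T> => value * factor;

// T + double -> double: the shape needed to fold a T into a double gradient.
static double Accumulate<T>(T value, double gradient)
    where T : IAdditionOperators<T, double, double> => value + gradient;

// Hypothetical wrapper type. `int` cannot be passed to the helpers above,
// because it only implements the same-type forms such as IMultiplyOperators<int, int, int>.
readonly record struct Scalar(double Raw) :
    IMultiplyOperators<Scalar, double, Scalar>,
    IAdditionOperators<Scalar, double, double> {
    public static Scalar operator *(Scalar left, double right) => new(left.Raw * right);
    public static double operator +(Scalar left, double right) => left.Raw + right;
}
```

The price is that every numeric type used with the engine would need such a wrapper, which is hardly better than sticking to `double`.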

Finally, since operator overloads must be defined on the base abstract class, we actually must add type constraints from all implementation nodes on the base `Expression<T>` class, which also makes them required on all descendants as well, even if their implementation does not need specific operators:

``````
internal abstract class Expression<T>
    where T : IAdditionOperators<T, T, T>,
    IMultiplyOperators<T, T, T>,
    IMultiplyOperators<T, double, T>,
    IAdditionOperators<T, double, double> {

    public static Expression<T> operator +(Expression<T> left, Expression<T> right) => new Add<T>(left, right);

    public static Expression<T> operator *(Expression<T> left, Expression<T> right) => new Mult<T>(left, right);

    // ...
}

internal class Value<T> : Expression<T>
    where T : IAdditionOperators<T, T, T>,
    IMultiplyOperators<T, T, T>,
    IMultiplyOperators<T, double, T>,
    IAdditionOperators<T, double, double> {
    // ...
}

internal class Add<T> : Expression<T>
    where T : IAdditionOperators<T, T, T>,
    IMultiplyOperators<T, T, T>,
    IMultiplyOperators<T, double, T>,
    IAdditionOperators<T, double, double> {
    // ...
}

internal class Mult<T> : Expression<T>
    where T : IAdditionOperators<T, T, T>,
    IMultiplyOperators<T, T, T>,
    IMultiplyOperators<T, double, T>,
    IAdditionOperators<T, double, double> {

    // ...
}
``````

Overall, I find the resulting implementation unappealing, redundant, and just not worth it.

## C#: default interface members

While we're at it, I also wanted to play around with default interface implementations from C# 8, so here's another version with an `IExpression` interface instead of the `Expression` abstract class (source):

``````
public interface IExpression {
    double Gradient { get; set; }

    double GetValue();

    protected IEnumerable<IExpression> GetDependencies();

    public void Backward() => // ... same as before

    private IReadOnlyList<IExpression> GetAllDependencies() => // ... same as before

    static IExpression operator +(IExpression left, IExpression right) => new Add(left, right);
}

// Technically, we can use structs for interface implementations
public struct Value : IExpression {
    // ...
}

public class Add : IExpression {
    private readonly IExpression _left;
    private readonly IExpression _right;

    public Add(IExpression left, IExpression right) {
        _left = left;
        _right = right;
    }

    // Must define the implementation
    public double Gradient { get; set; }

    public double GetValue() => _left.GetValue() + _right.GetValue();

    // Implicit interface implementation would require "public" access modifier
    IEnumerable<IExpression> IExpression.GetDependencies() {
        yield return _left;
        yield return _right;
    }
}
``````

It immediately brings up a couple of interesting observations.

First, dropping abstract base classes in favor of interfaces lets us implement some expression nodes as structs instead of classes, which theoretically should be better from a GC point of view. However, in practice it may not work out exactly as you'd expect: structs mixed with interfaces easily lead to boxing at runtime. There's even a nice warning from Rider/ReSharper:

``Using the default implementations of the interface members may result in boxing of the struct 'Value' values at runtime in generic code (where T : IExpression)``
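The effect is easy to reproduce outside of the engine. In this small sketch `IHasGradient` and `GradientCell` are hypothetical stand-ins for `IExpression` and `Value`, not types from the repository:

```csharp
using System;

var cell = new GradientCell { Gradient = 5.0 };

// Assigning a struct to an interface-typed variable boxes it: `boxed` is a heap copy.
IHasGradient boxed = cell;
boxed.Gradient = 1.0;
Console.WriteLine(cell.Gradient);  // 5: the original struct is untouched
Console.WriteLine(boxed.Gradient); // 1

// A default interface member is reachable only through the interface type,
// so calling it on a struct goes through a boxed copy as well.
((IHasGradient)cell).Reset();
Console.WriteLine(cell.Gradient);  // still 5: Reset() zeroed the temporary box

interface IHasGradient {
    double Gradient { get; set; }
    void Reset() => Gradient = 0.0;
}

struct GradientCell : IHasGradient {
    public double Gradient { get; set; }
}
```

For an engine whose whole job is to mutate gradients in place, silently mutating a copy is a correctness problem and not just an extra allocation.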

All in all, this is definitely not the best scenario for default interface members:

• Even though we have `protected GetDependencies()` on the interface, it must be implemented as `public GetDependencies()` in derived types (or using explicit interface implementation).
• We just put the `Backward()` or `operator +` implementations into the `IExpression` interface for code reuse, but types implementing `IExpression` may decide to override them, which is undesirable.

## C#: static abstract interface members

Things get even worse if we decide to plug our expressions into the previously mentioned "generic math" ecosystem! Let's try implementing `IAdditionOperators` for it, so that our expressions can be used with various numeric operations:

``````
public interface IExpression : IAdditionOperators<IExpression, IExpression, IExpression> {
    // ...
}
``````

This fails to compile with the following error:

``[CS8920] The interface 'IExpression' cannot be used as type argument. Static member 'IAdditionOperators<IExpression, IExpression, IExpression>.operator +(IExpression, IExpression)' does not have a most specific implementation in the interface.``

Ehhmm, what? Looking for details of this error leads to an obscure thread on GitHub (well, maybe not so obscure, but I'm past my prime days of being deep into compilers to quickly understand it). Basically, an interface that declares or inherits static abstract members without implementing them can't itself be used as a type argument, and that is exactly what `IExpression` does here by passing itself to `IAdditionOperators`. Another dead end. It feels like most of these features work only for a limited number of use cases in the base library, and are not actually designed for general-purpose extensibility.
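For comparison, the shape the base library itself uses (for example in `INumber<TSelf>`) is a self-referencing generic parameter, so that the type argument is always a concrete type and never the interface. A minimal sketch, assuming .NET 7 or later and a hypothetical single node type `Node`:

```csharp
using System;
using System.Numerics;

var sum = Sum(new Node(1.5), new Node(2.5));
Console.WriteLine(sum.Value); // 4

// Generic code sees only the constraint, never the interface as a type argument.
static T Sum<T>(T left, T right) where T : IAdditionOperators<T, T, T> => left + right;

// TSelf stands for the implementing type, mirroring the shape of INumber<TSelf>.
interface IExpression<TSelf> : IAdditionOperators<TSelf, TSelf, TSelf>
    where TSelf : IExpression<TSelf> {
    double Value { get; }
}

sealed class Node : IExpression<Node> {
    public Node(double value) => Value = value;
    public double Value { get; }
    public static Node operator +(Node left, Node right) => new(left.Value + right.Value);
}
```

This compiles, but it does not rescue the design above: `+` now has to take and return the same concrete type, so a tree mixing separate `Add`, `Mult` and `Value` node types no longer fits.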

## C#: back to data and functions

In the end, I was so annoyed that I decided to ditch the class hierarchy and interface nonsense altogether and just model everything in a dumb data structure. I did a naive straightforward refactoring:

• Lift protected member overrides into constructor parameters. Use static lambdas to share delegate instances and avoid redundant allocations.
• Pass dependencies as a collection parameter to the constructor.

This results in a less type-safe implementation with additional dumb code for indexed dependency access, but it's still quite simple and readable in my opinion (source).

``````
class Expr<T> where T : INumber<T>, IMultiplyOperators<T, double, double> {
    private readonly Func<Expr<T>, T> _getValueImpl;
    private readonly Action<Expr<T>> _updateGradientImpl;

    /// <remarks>
    /// Function parameters are expected to be static, so that each expression instance gets the same delegate, to avoid additional GC pressure.
    /// </remarks>
    public Expr(
        IReadOnlyList<Expr<T>> dependencies,
        Func<Expr<T>, T> getValueImpl,
        Action<Expr<T>> updateGradientImpl,
        // Used only as a hack to pass the underlying value to Value implementation and avoid closure allocations
        T? dummyValue = default) {
        _getValueImpl = getValueImpl;
        _updateGradientImpl = updateGradientImpl;
        Dependencies = dependencies;
        DummyValue = dummyValue;
    }

    public double Gradient { get; private set; }
    public IReadOnlyList<Expr<T>> Dependencies { get; }
    public T? DummyValue { get; }

    public T GetValue() => _getValueImpl(this);

    public void Backward() {
        Gradient = 1;
        foreach (var dep in GetAllDependencies()) {
            dep._updateGradientImpl(dep);
        }
    }

    private IReadOnlyList<Expr<T>> GetAllDependencies() => // ... same as before

    public static Expr<T> operator +(Expr<T> left, Expr<T> right) {
        return new Expr<T>(
            new[] { left, right },
            static e => e.Dependencies[0].GetValue() + e.Dependencies[1].GetValue(),
            static e => {
                e.Dependencies[0].Gradient += e.Gradient;
                e.Dependencies[1].Gradient += e.Gradient;
            });
    }

    public static Expr<T> operator *(Expr<T> left, Expr<T> right) => new(
        new[] { left, right },
        static e => e.Dependencies[0].GetValue() * e.Dependencies[1].GetValue(),
        static e => {
            // T * double -> double, hence the IMultiplyOperators<T, double, double> constraint
            e.Dependencies[0].Gradient += e.Dependencies[1].GetValue() * e.Gradient;
            e.Dependencies[1].Gradient += e.Dependencies[0].GetValue() * e.Gradient;
        });
}
``````

Using `static` lambdas results in the additional hassle of passing the "self" instance, which may require some getting used to, but it quickly becomes natural.
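The difference between the two kinds of lambdas can be seen in isolation. In this sketch `Shared` and `Capturing` are hypothetical helper names. The `static` modifier only guarantees that nothing is captured; reusing one delegate instance for a non-capturing lambda is what current compilers do in practice rather than a language guarantee, so the two comparisons are printed without an expected value.

```csharp
using System;

var first = Shared();
var second = Shared();
var third = Capturing(1);
var fourth = Capturing(1);

Console.WriteLine(first(41)); // 42
Console.WriteLine(third(41)); // 42

// Non-capturing lambda: the compiler is free to hand out one cached delegate.
Console.WriteLine(ReferenceEquals(first, second));
// Capturing lambda: each call needs its own closure over `offset`.
Console.WriteLine(ReferenceEquals(third, fourth));

// `static` makes the compiler reject any accidental capture of locals or `this`.
static Func<int, int> Shared() => static x => x + 1;

// Captures `offset`, so the `static` modifier would not compile here.
static Func<int, int> Capturing(int offset) => x => x + offset;
```

This is why the operators above read everything through the `e` parameter: anything a lambda needs from its expression has to be reachable from the instance passed in.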