Gradient descent: back to basics with F# and C#
After dabbling with high-level AI tools recently, I wanted to revisit the basics and look under the hood. As they say, the best way to learn something is to build it. I've been following the wonderful "from scratch" videos by Andrej Karpathy, where he starts with the essence of all modern AI architectures: a gradient descent engine.
So, I decided to follow along and build my own engine at https://github.com/khmylov/dumb-gradient. This article is an overview of its implementation, with some thoughts on using F# and C# for a task like this.
Side note: a look back at university
I had completely forgotten that one of my university projects back in 2010 was actually about building an image recognition neural network! It was fun to look back at the source code from that era. Turns out, the core ideas are still relevant; it's the computational advances, new modelling discoveries, and the infrastructure around them that enabled so much progress in the last decade.
I'm not going to recite all the details here. The main (much simplified) ideas to grasp are as follows (I recommend Deep Learning with PyTorch, the fast.ai course, and the Hugging Face course for a much deeper dive):
- We can represent most real-world data, like text, images, and sounds, as vast arrays of numbers
- Turns out, we can model a lot of "intelligence" tasks (natural language processing, image recognition, etc.) with sophisticated mathematical functions operating on our numeric representation of real-world data
- The universal approximation theorem was a breakthrough, postulating that even the most sophisticated functions can be modeled by combining only primitive linear functions f(x) = a * x + b and basic non-linear functions such as tanh into an inter-connected network (here come the inevitable neuron models and neural networks); we just need a huge number of them, and some computation power to figure out the arguments of those functions
- "Training" such a network means figuring out how to tweak the randomly initialized function arguments to bring the entire network close to modelling our sophisticated function
- You can certainly mess around with those function arguments randomly until you get what you want, but that's extremely inefficient computationally. However, since we only use simple linear and non-linear functions underneath, your calculus course may remind you that derivatives let you find out analytically how a function reacts to changes in its arguments.
Essentially, it means that neural networks are huge machines crunching partial derivatives of various functions at specific local points (text/images/etc. encoded as numbers) to understand how to tune the a and b in f(x) = a * x + b.
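To make this concrete, here is a minimal sketch of the idea in plain F# (no engine yet): fit f(x) = a * x + b to a single data point by repeatedly nudging a and b along the analytical derivatives of the squared error.
// One data point: we want f(2.0) to become 9.0
let x, target = 2.0, 9.0
let learningRate = 0.1
let mutable a = 0.5
let mutable b = 0.0

for _ in 1 .. 20 do
    let error = a * x + b - target
    // loss = error^2, so d(loss)/da = 2 * error * x and d(loss)/db = 2 * error
    a <- a - learningRate * (2.0 * error * x)
    b <- b - learningRate * (2.0 * error)

printfn "a = %f, b = %f, f(x) = %f" a b (a * x + b) // f(x) converges to 9.0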
Of course, in the real world you don't even have to think about this these days. Libraries like PyTorch give you nice production-ready, battle-tested, GPU-ready building blocks. Heck, even that is considered too low-level now; for most people the better approach is to use a pre-built model from Hugging Face, which you can access with a single function call in Python. What a time to be alive!
But I had this urge to look into the boring old implementation details, have some fun re-building a dumbed-down PyTorch from scratch, and play with the recent generic math features in C#. However, I actually started with F#, as it has been my go-to choice for various scripts and algorithmic tasks for almost 15 years now. It's also praised a lot as a data-oriented language, so it seemed like a good fit here.
Spoiler alert: I'm not happy with either the F# or the C# implementation; both languages have quirks which sometimes make them annoying for this specific task 😒.
F#: basic implementation
One of the core features of PyTorch that I wanted to re-implement is automatic gradient tracking: all operations on PyTorch primitives keep track of their origins, which then lets you apply derivatives to build a gradient vector and use it to nudge your model's parameters in the needed direction.
So if I write a = b + c in PyTorch, then a is not just some numerical value; it's an object holding both the numerical value and the notion of it originating from the addition of b and c.
First, let's see what our example program may look like:
let a, b, c = Value(1.0), Value(-4.0), Value(10.0)
let p1 = a * b + c // Add(Mult(Value(1.0), Value(-4.0)), Value(10.0))
In my implementation I decided to model this with an expression hierarchy, i.e. the result of *, +, or any other operation is a dedicated case class wrapping the underlying value and the expression nodes it depends upon.
[<AbstractClass>]
type Expr() =
    // Local gradient of this node
    member val Grad = 0.0 with get, set

    // Mutates local gradients of expressions this expression depends on,
    // i.e. "how much should the expression I depend upon change to make my value go up by 1 unit?"
    abstract member UpdateGradient: unit -> unit

    // Numerical result of the computation represented by this expression
    abstract member Value: float
For example, the results of addition and multiplication are then modeled as:
type Add(left: Expr, right: Expr) =
    inherit Expr()

    override this.UpdateGradient() =
        // z = x + y; x = f(a)
        // dz/da = dz/dx * dx/da = 1 * dx/da = dx/da
        left.Grad <- left.Grad + this.Grad
        right.Grad <- right.Grad + this.Grad

    override _.Value = left.Value + right.Value

type Mult(left: Expr, right: Expr) =
    inherit Expr()

    override this.UpdateGradient() =
        // z = x * y; x = f(a)
        // dz/da = dz/dx * dx/da = y * dx/da
        left.Grad <- left.Grad + right.Value * this.Grad
        right.Grad <- right.Grad + left.Value * this.Grad

    override _.Value = left.Value * right.Value
We will also need the simplest Value case to lift a primitive underlying value into our expression world:
// Just a dumb Expr wrapper around a value, most likely to be used as a parameter or model input
type Value(initialValue: float) =
    inherit Expr()

    // Mutable because we're going to change it when training
    let mutable value = initialValue

    override _.UpdateGradient() = ()
    override _.Value with get() = value
Let's also throw in operator overloads to make a * b + c possible:
type Expr with
    static member inline (*) (left, right) = Mult(left, right)
    static member inline (+) (left, right) = Add(left, right)
The immediate question is: why not use something more idiomatic, like discriminated unions (type Expr = | Value | Add of Expr * Expr | Mult of Expr * Expr) or records?
For this specific task I want expressions to be "open", meaning it should be possible to extend the "base" library by adding more expression types. Discriminated unions are great for modelling closed internal state, but not so great for external extensibility.
Records are kind of supposed to be immutable, but training a network essentially means mutating its inner state, so burning GC cycles on allocating immutable data structures does not sound appealing.
So, we're back to the ugly OOP and class inheritance!
Side note: while playing around with the C# implementation, I eventually ditched the inheritance and lifted method overrides like UpdateGradient into parameters. The same trick could be applied to the F# version (listed below) to make it more functional, but for the sake of the experiment, I'm going to proceed the way it worked for me initially.
type Expr = {
    mutable Gradient: double
    GetValueImpl: Expr -> double
    Dependencies: Expr array
    UpdateGradient: Expr -> unit
}
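With this shape, an addition node could be built roughly like this (a sketch, not code from the repo):
// Helper to evaluate a node by passing it to its own implementation function
let getValue (e: Expr) = e.GetValueImpl e

// Addition as plain data: dependencies plus two functions
let add (left: Expr) (right: Expr) = {
    Gradient = 0.0
    Dependencies = [| left; right |]
    GetValueImpl = fun e -> getValue e.Dependencies.[0] + getValue e.Dependencies.[1]
    UpdateGradient = fun e ->
        for dep in e.Dependencies do
            dep.Gradient <- dep.Gradient + e.Gradient
}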
Backward pass
Before we add more expression nodes needed for modelling, like tanh or sigmoid, let's consider how we're going to use the future model. The model itself is one large function (our expression tree), with leaf Value nodes representing both the numerical encoding of the input data and the tunable model parameters.
We want to know how far the model's output (the value of the root node) is from the expected output. This is called the loss of the current model state, and we want it to go down. There are numerous strategies for loss calculation; one of the simplest is just taking the difference between the actual and the predicted number and throwing away its sign, either by calling abs() or by squaring it.
Therefore, the loss can be represented as a differentiable expression node as well, which also means we can compute a gradient on it and nudge its dependencies to make the loss go down.
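For example, a squared-difference loss could itself be built from expression nodes. The sketch below is illustrative only: it assumes a Sub expression type with a (-) overload, analogous to the Add/(+) pair defined earlier.
// loss = (predicted - expected)^2, built from expression nodes
// so that the backward pass can later push gradients into the parameters
let squaredLoss (predicted: Expr) (expected: float) =
    let diff = predicted - Value(expected) // assumes a hypothetical Sub node
    diff * diff :> Expr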
PyTorch exposes a backward method on its modules, so let's implement something similar for our Expr. First, we need every expression node to expose the nodes it depends on. I'm going to skip insignificant code details; the full code is available in the accompanying repository.
// Only new code listed here, previously defined members are skipped
type Expr() =
    // ...
    abstract member Children: Expr seq

type Value(initialValue: float) =
    // ...
    override _.Children = Seq.empty

type Mult(left: Expr, right: Expr) =
    // ...
    override _.Children = seq { left; right }

type Add(left: Expr, right: Expr) =
    // ...
    override _.Children = seq { left; right }
Let's add a helper method to get all dependency nodes (down to the leaf level) for a given expression, and then add our Backward() method. First, it sets the gradient of the current node to 1 (the trivial case: to change the current node by 1 unit, we have to change it by 1 unit), and then goes through all dependencies once (the order matters) and updates their gradients based on the nodes depending on them (everything goes backwards, yay 😵).
type Expr with
    // Collects all dependency nodes in topological order:
    // every node comes before the nodes it depends on
    member this.GetAllChildren() =
        let mutable res = []
        let visited = System.Collections.Generic.HashSet<_>()
        let rec loop (v: Expr) =
            if visited.Add v then
                for child in v.Children do
                    loop child
                res <- v :: res
        loop this
        res

    member this.Backward() =
        this.Grad <- 1.0
        for v in this.GetAllChildren() do
            v.UpdateGradient()
This is once again far from idiomatic functional code, but good enough for our task at hand.
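As a quick sanity check, here is the backward pass in action on the p1 example from the beginning:
let a, b, c = Value(1.0), Value(-4.0), Value(10.0)
let p1 = a * b + c
p1.Backward()
printfn "%f" p1.Value // 6.0
// dp1/da = b = -4, dp1/db = a = 1, dp1/dc = 1
printfn "%f %f %f" a.Grad b.Grad c.Grad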
Annoyance creeps in
This is where my annoyance started picking up: I felt I was using the wrong tool for the job. Modelling class hierarchies in F# just does not feel good. For example, the language does not support a notion of protected members, which is kind of by design, but argh... Essentially, it means my Expr exposes too many internals on its public surface, like the UpdateGradient method, which is an implementation detail expected to be called only by Backward.
Next, I wanted to add some optional parameters to prettify the debugging output in some cases without having to pass them explicitly in all places. Turns out, F# supports optional parameters on type members only, not on let-bound functions.
F# famously has a single-pass compiler, which has its share of benefits. However, it also means I can't define class-based operator overloads in the base Expr definition (and those are the core feature of the engine):
type Expr() =
    // .. some code omitted
    static member (+) (left, right) = Add(left, right) // this does not work as Add is unknown yet

type Add(left: Expr, right: Expr) =
    // .. some code omitted
Fortunately, this can be mitigated with type augmentations, so I can define the operators after defining the expression types:
type Expr() =
    // .. some code omitted

type Add(left: Expr, right: Expr) =
    // .. some code omitted

type Expr with
    static member inline (+) (left, right) = Add(left, right)
There are also some minor quirks around type inference and math contracts, like the exp and pow operators expecting the result type to match the argument types, which adds some redundancy (note the : Expr and :> Expr below) and kind of defeats the purpose.
type Pown(x: Expr, n: int) =
    inherit Expr()
    // ..

type Expr with
    static member inline Pow (left, right): Expr = Pown(left, right)

let a = Value(2.0)
let b = (a :> Expr) ** 2
In the end, I got the basic setup working, added more expression types, and added small abstractions for a neuron, a layer, and a multi-layer perceptron, but I was somewhat dissatisfied with the result.
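For a taste of what those abstractions look like, here is a condensed neuron sketch on top of Expr. It is illustrative rather than the exact repository code, and it assumes a Tanh expression type analogous to Add and Mult.
// tanh(w1*x1 + ... + wn*xn + bias) with randomly initialized weights
type Neuron(inputCount: int) =
    let rnd = System.Random()
    let weights = Array.init inputCount (fun _ -> Value(rnd.NextDouble() * 2.0 - 1.0))
    let bias = Value(0.0)

    member _.Forward(inputs: Expr[]) =
        let mutable acc = bias :> Expr
        for i in 0 .. inputCount - 1 do
            acc <- (acc + weights.[i] * inputs.[i]) :> Expr
        Tanh(acc) :> Expr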
C#: class hierarchy
Then I remembered that recent C# and .NET versions added support for "generic math", based on static abstract interface members, and C# is also less strict about types than F#, so I decided to play with these new features and re-implement the engine in C#.
I started by replicating the same inheritance-based approach, with a base Expression class and a subtype for each operation kind. Once again, I'm skipping insignificant bits; the full code is available on GitHub.
internal abstract class Expression {
    public double Gradient { get; set; }

    public abstract double GetValue();

    // Get the expressions this node depends on
    protected abstract IEnumerable<Expression> GetDependencies();

    // Propagate the gradient from this node to its dependencies
    protected abstract void UpdateGradient();

    public void Backward() => // ... same as before

    private IReadOnlyList<Expression> GetAllDependencies() => // ... same as before

    public static Expression operator +(Expression left, Expression right) => new Add(left, right);
    public static Expression operator *(Expression left, Expression right) => new Mult(left, right);
}

internal class Value : Expression {
    // ...
}

internal class Add : Expression {
    // ...
}

internal class Mult : Expression {
    // ...
}
Here we immediately see the same "problem" as in the F# implementation: the + and * overloads can be defined only on the base class; you can't have them as external extension methods or something like trait implementations. I guess we can live with that for now.
C#: generic math
As an experiment, I wanted to abstract over the underlying value type. Currently it's double, but can we make it any "number"-like type? That's where "generic math" comes in: it lets us add generic type constraints like where T : IAdditionOperators<T, T, T>, specifying that we want instances of T to support some abstract + operator accepting two Ts and returning another T. Applying this to our expressions (source):
internal abstract class Expression<T> {
    public double Gradient { get; set; }
    public abstract T GetValue();
    // ... most other members unchanged
}

internal class Value<T> : Expression<T> {
    // ... most other members unchanged
}

internal class Add<T> : Expression<T>
    where T : IAdditionOperators<T, T, T> { // <------ HERE
    // ... most other members unchanged
    public override T GetValue() => _left.GetValue() + _right.GetValue();

    protected override void UpdateGradient() {
        _left.Gradient += Gradient;
        _right.Gradient += Gradient;
    }
}

internal class Mult<T> : Expression<T>
    where T : IMultiplyOperators<T, T, T>, IMultiplyOperators<T, double, T>, IAdditionOperators<T, double, double> { // <------ WOW
    // ... most other members unchanged
    public override T GetValue() => _left.GetValue() * _right.GetValue();

    protected override void UpdateGradient() {
        // Must use `X = Y + X` form instead of `X += Y` because of IAdditionOperators<T, double, double>
        _left.Gradient = _right.GetValue() * Gradient + _left.Gradient;
        _right.Gradient = _left.GetValue() * Gradient + _right.Gradient;
    }
}
Note that things start to get out of hand here: we'd like to abstract over the underlying data type, but we still have to add it to and multiply it by our Gradient, which is a double, so we need more ugly constraints like IAdditionOperators<T, double, double>. The caveat is that built-in types like int do not implement such interfaces.
This in turn means that we can't actually have an instance of Expression<int>. Also, it's not (yet?) possible to implement these static abstract interface methods in an extension-like manner; only the T type itself can implement them, so the following imaginary code is not possible:
struct IntToDouble : IAdditionOperators<int, double, double> {
...
}
Finally, since operator overloads must be defined on the base abstract class, we actually have to put the type constraints from all the implementation nodes onto the base Expression<T> class, which makes them required on all descendants as well, even if their implementations don't need the specific operators:
internal abstract class Expression<T>
    where T : IAdditionOperators<T, T, T>,
              IMultiplyOperators<T, T, T>,
              IMultiplyOperators<T, double, T>,
              IAdditionOperators<T, double, double> {
    public static Expression<T> operator +(Expression<T> left, Expression<T> right) => new Add<T>(left, right);
    public static Expression<T> operator *(Expression<T> left, Expression<T> right) => new Mult<T>(left, right);
    // ...
}

internal class Value<T> : Expression<T>
    where T : IAdditionOperators<T, T, T>,
              IMultiplyOperators<T, T, T>,
              IMultiplyOperators<T, double, T>,
              IAdditionOperators<T, double, double> {
    // ...
}

internal class Add<T> : Expression<T>
    where T : IAdditionOperators<T, T, T>,
              IMultiplyOperators<T, T, T>,
              IMultiplyOperators<T, double, T>,
              IAdditionOperators<T, double, double> {
    // ...
}

internal class Mult<T> : Expression<T>
    where T : IMultiplyOperators<T, T, T>,
              IMultiplyOperators<T, double, T>,
              IAdditionOperators<T, double, double>,
              IAdditionOperators<T, T, T> {
    // ...
}
Overall, I find the resulting implementation unappealing, redundant, and just not worth it.
C#: default interface members
While we're at it, I also wanted to play around with default interface implementations from C# 8, so here's another version with an IExpression interface instead of the Expression abstract class (source):
public interface IExpression {
    double Gradient { get; set; }
    double GetValue();

    protected void UpdateGradient();
    protected IEnumerable<IExpression> GetDependencies();

    public void Backward() => // ... same as before
    private IReadOnlyList<IExpression> GetAllDependencies() => // ... same as before

    static IExpression operator +(IExpression left, IExpression right) => new Add(left, right);
}

// Technically, we can use structs for interface implementations
public struct Value : IExpression {
    // ...
}

public class Add : IExpression {
    private readonly IExpression _left;
    private readonly IExpression _right;

    public Add(IExpression left, IExpression right) {
        _left = left;
        _right = right;
    }

    // Must define the implementation
    public double Gradient { get; set; }

    public double GetValue() => _left.GetValue() + _right.GetValue();

    // Implicit interface implementation would require a "public" access modifier
    void IExpression.UpdateGradient() {
        _left.Gradient += Gradient;
        _right.Gradient += Gradient;
    }

    IEnumerable<IExpression> IExpression.GetDependencies() {
        yield return _left;
        yield return _right;
    }
}
It immediately brings up a couple of interesting observations.
First, dropping abstract base classes in favor of interfaces lets us implement some expression nodes as structs instead of classes, which in theory should be better from a GC point of view. In practice, however, it may not work out exactly as you'd expect: structs mixed with interfaces easily lead to boxing at runtime. There's even a nice warning from Rider/ReSharper:
Using the default implementations of the interface members may result in boxing of the struct 'Value' values at runtime in generic code (where T : IExpression)
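To see where the boxing comes from, consider a made-up generic helper: a struct like Value does not implement Backward() itself, so the runtime has to box the value to dispatch to the interface's default method body.
// Hypothetical helper; T is constrained to IExpression
static void RunBackward<T>(T expr) where T : IExpression {
    // Backward() exists only as a default interface implementation,
    // so this call boxes 'expr' whenever T is a struct
    expr.Backward();
}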
All in all, this is definitely not the best scenario for default interface members:
- Even though GetDependencies() is declared protected on the interface, it must be implemented as a public GetDependencies() in the implementing types (or via explicit interface implementation).
- We put the Backward() and operator + implementations into the IExpression interface purely for code reuse, but types implementing IExpression may decide to override them, which is undesirable.
C#: static abstract interface members
Things get even worse if we decide to plug our expressions into the previously mentioned "generic math" ecosystem! Let's try implementing IAdditionOperators for them, so that our expressions can be used with various numeric operations:
public interface IExpression : IAdditionOperators<IExpression, IExpression, IExpression> {
    // ...
}
This fails to compile with the following error:
[CS8920] The interface 'IExpression' cannot be used as type argument. Static member 'IAdditionOperators<IExpression, IExpression, IExpression>.operator +(IExpression, IExpression)' does not have a most specific implementation in the interface.
Ehhmm, what? Searching for the details of this error leads to an obscure GitHub thread (well, maybe not so obscure, but I'm past my prime days of being deep into compilers, so it took a while to digest). Basically, an interface that has a static abstract member without a most specific implementation (here, the inherited operator +) can't be used as a type argument, so IExpression can't fill in the T of IAdditionOperators<T, T, T>. Another dead end; it feels like most of these features work only for a limited set of use cases in the base class library and are not really designed for general-purpose extensibility.
C#: back to data and functions
In the end, I was so annoyed that I decided to ditch the class hierarchy and interface nonsense altogether and just model everything with a dumb data structure. I did a naive, straightforward refactoring:
- Lift the protected member overrides into constructor parameters, using static lambdas to share delegate instances and avoid redundant allocations
- Pass the dependencies as a collection constructor parameter
This results in a less type-safe implementation with additional dumb code for indexed dependency access, but it's still quite simple and readable in my opinion (source).
class Expr<T> where T : INumber<T>, IMultiplyOperators<T, double, double> {
    private readonly Func<Expr<T>, T> _getValueImpl;
    private readonly Action<Expr<T>> _updateGradient;

    /// <remarks>
    /// Function parameters are expected to be static, so that each expression instance gets the same delegate, to avoid additional GC pressure.
    /// </remarks>
    public Expr(
        IReadOnlyList<Expr<T>> dependencies,
        Func<Expr<T>, T> getValueImpl,
        Action<Expr<T>> updateGradient,
        // Used only as a hack to pass the underlying value to the Value implementation and avoid closure allocations
        T? dummyValue = default) {
        _getValueImpl = getValueImpl;
        _updateGradient = updateGradient;
        Dependencies = dependencies;
        DummyValue = dummyValue;
    }

    public double Gradient { get; private set; }
    public IReadOnlyList<Expr<T>> Dependencies { get; }
    public T? DummyValue { get; }

    public T GetValue() => _getValueImpl(this);

    public void AddGradient(double value) {
        Gradient += value;
    }

    public void Backward() {
        Gradient = 1;
        foreach (var dep in GetAllDependencies()) {
            dep._updateGradient(dep);
        }
    }

    private IReadOnlyList<Expr<T>> GetAllDependencies() => // ... same as before

    public static Expr<T> operator +(Expr<T> left, Expr<T> right) {
        return new Expr<T>(
            new[] { left, right },
            static e => e.Dependencies[0].GetValue() + e.Dependencies[1].GetValue(),
            static e => {
                e.Dependencies[0].AddGradient(e.Gradient);
                e.Dependencies[1].AddGradient(e.Gradient);
            });
    }

    public static Expr<T> operator *(Expr<T> left, Expr<T> right) => new(
        new[] { left, right },
        static e => e.Dependencies[0].GetValue() * e.Dependencies[1].GetValue(),
        static e => {
            e.Dependencies[0].AddGradient(e.Dependencies[1].GetValue() * e.Gradient);
            e.Dependencies[1].AddGradient(e.Dependencies[0].GetValue() * e.Gradient);
        });
}
Using static lambdas results in the additional hassle of passing the "self" instance around, which may require some getting used to, but it quickly becomes natural.
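To wrap up, here is a quick usage sketch of this final version. Num is a hypothetical helper (not from the repository) that builds a leaf node on top of the dummyValue hack shown above.
// Hypothetical leaf constructor: stores the number in DummyValue
// and propagates no gradient of its own
static Expr<double> Num(double x) => new(
    Array.Empty<Expr<double>>(),
    static e => e.DummyValue!,
    static _ => { },
    x);

var a = Num(1.0);
var b = Num(-4.0);
var c = Num(10.0);
var p = a * b + c;
p.Backward();
// p = 6, dp/da = b = -4, dp/db = a = 1, dp/dc = 1
Console.WriteLine($"{p.GetValue()} {a.Gradient} {b.Gradient} {c.Gradient}");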
Conclusion
And now we've come full circle, back to the simplest "data and functions" approach, which ironically resembles that infamous bell curve IQ meme.
So, what's next? By no means is this expected to be a production-grade implementation; I'm just playing around with the most basic ideas, trying to build on them, and brushing up my knowledge of the internals in the meantime. I plan to proceed with building more complex machine learning primitives. It would also be great to do some load testing on the various implementations, extend the engine to multi-dimensional tensors, plug it into CUDA, etc.