Compiling Einsum to MLIR
January 03, 2024
The Einstein summation convention, or einsum convention for short, is a notational convention for specifying operations on multidimensional arrays.
It has uses in many branches of scientific computing ranging from tensor networks used in physics to expressing new, exotic deep learning operators.
Below, two examples are shown. Each element of the array on the left-hand side of the expression is given by evaluating the right-hand side for all possible index values and summing over all indices that don't appear in the result.
(The notation used here doesn't match NumPy's or PyTorch's way of representing einsums. In these packages, the matrix multiplication example would be α * einsum('ik,kj->ij', A, B). The elementwise addition has no direct counterpart there, since their einsum strings only describe products: einsum('i,i->i', A, B) is the elementwise product. I find the notation used in this post slightly easier to parse.)
# elementwise addition
C[i] = A[i] + B[i]
# matrix multiplication
C[i, j] = α * (A[i, k] * B[k, j])
This post isn't concerned with how to write einsum notation to implement linear algebra operations; rather, it goes over how these expressions can actually be compiled and run on your computer. (For an overview of einsum from a deep learning perspective, you might want to read this write-up by Tim Rocktäschel.)
The Idea
The basic principle of generating code for an einsum expression is to create a series of nested loops. The outer loops iterate over each element in the output array (left-hand side), while the inner loops iterate over any indices that only appear in the inputs (right-hand side). For these inner loops, the result is aggregated.
For example, with arrays A, B and C being of size (N×K), (K×M) and (N×M) respectively, you get:
# (matrix multiplication)
# C[i, j] = A[i, k]*B[k, j]
for i in 1:N
    for j in 1:M
        temp = 0
        for k in 1:K
            temp += A[i, k] * B[k, j]
        end
        C[i, j] = temp
    end
end
This is of course a contrived example as we could simply call into BLAS (C = A*B) and get code that is much faster because it has been optimized to death. (Note also that this is the most naive approach to generating code for einsum expressions. Different implementations might gain a lot of additional performance by reordering loops and trying to map them onto BLAS operations automatically, e.g. opt_einsum (in Python) and TensorOperations.jl (in Julia).)
However, in combination with an optimizing compiler, the simple approach outlined above has proven to be competitive for some cases that don't simply map onto one or two BLAS calls.
Julia and MLIR
Julia
The Julia programming language uses the LLVM compiler and, because of its powerful metaprogramming capabilities, it's the ideal breeding ground for einsum packages and similar domain-specific languages for tensor operations (Einsum.jl, OMEinsum.jl, TensorOperations.jl, Tullio.jl, TensorCast.jl, TensorRules.jl, …).
These different packages typically expose macros to convert an einsum expression (or something similar) to valid Julia code. This code is then compiled just as if it were manually written by the user and out comes the optimized machine code.
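To make the macro idea concrete, here is a minimal sketch of a toy macro that turns a single-index copy expression into a for loop. The name @elementwise_copy is made up for this illustration and this is not how Einsum.jl is implemented; a real einsum macro has to handle arbitrary index patterns and reductions.

```julia
# Hypothetical toy macro, not Einsum.jl's implementation: it only accepts
# expressions of the form `dest[i] = src[i]` and rewrites them into a for loop.
macro elementwise_copy(ex)
    (ex isa Expr && ex.head === :(=)) || error("expected an assignment")
    lhs, rhs = ex.args
    for side in (lhs, rhs)
        (side isa Expr && side.head === :ref && length(side.args) == 2) ||
            error("expected single-index references like y[i]")
    end
    dest, index = lhs.args
    src, rhs_index = rhs.args
    index == rhs_index || error("both sides must use the same index")
    # `esc` makes the generated code refer to the caller's variables.
    return esc(quote
        for $index in 1:length($src)
            $dest[$index] = $src[$index]
        end
    end)
end

x = [1.0, 2.0, 3.0]
y = zeros(3)
@elementwise_copy y[i] = x[i]   # expands to a loop over 1:length(x) and runs it
```

The macro only manipulates expressions; by the time the compiler sees the code, it looks exactly like a hand-written loop.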
MLIR
Now, what's MLIR and where does it come into the picture?
MLIR (Multi-Level Intermediate Representation) is a compiler infrastructure project providing tools to create intermediate representations (IRs) and passes to transform them.
If you're familiar with LLVM, this might sound familiar, and you'd be right. In fact, MLIR is part of the LLVM project but its IR is more flexible and modular than LLVM IR.
For example, in MLIR it's possible to represent high-level deep learning operations (e.g. matmul, convolution, …), low-level machine intrinsics (e.g. AVX512 instructions), and everything in between, all in the same IR. This is done by defining different dialects, each having their own set of operations and types.
MLIR also has the affine dialect, which is used to represent polyhedral structures. This is a very powerful dialect that can be used to represent loops and their dependencies in a way that allows for powerful transformations and optimizations.
What if we generate affine MLIR IR from the einsum expression? This would allow us to use the MLIR infrastructure to optimize the code and generate machine code for different targets (e.g. CPU, GPU, …). Turns out, we can.
The Julia Compilation Pipeline
Julia code typically goes through a few different steps before becoming the machine code that's actually executed on your machine.
First is parsing, where code is converted to an abstract syntax tree without any semantic information. In Julia, macros can be used to alter code after parsing, taking in expressions and returning expressions. Next is lowering, where the AST is converted to Julia IR, a format that can contain type information and to which some optimizations such as inlining and constant propagation are applied. At this level, it's also possible to apply transformations using the same functions the Julia compiler uses internally. (Large parts of the Julia compiler itself are written in Julia.) Finally we have codegen. Here, the Julia IR is converted to LLVM IR, which is then further compiled to machine code and executed.
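Each of these stages can be inspected from regular Julia using the standard reflection tools. The sketch below uses a made-up function copy_first purely for illustration; the exact IR printed differs between Julia versions, so no output is shown.

```julia
using InteractiveUtils  # provides code_lowered, code_typed and code_llvm outside the REPL

# Hypothetical function used only for this illustration.
copy_first(x) = x[1]

# 1. Parsing: source text becomes an expression (AST) without type information.
ast = Meta.parse("y[i] = x[i]")

# 2. Lowering: the AST is turned into untyped Julia IR.
lowered = code_lowered(copy_first, Tuple{Vector{Float64}})

# 3. Type inference and optimization: typed Julia IR plus the inferred return type.
typed = code_typed(copy_first, Tuple{Vector{Float64}})

# 4. Codegen: the LLVM IR that is compiled to machine code, captured as a string.
llvm = sprint(code_llvm, copy_first, Tuple{Vector{Float64}})
```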
Where Did My Loops Go?
Take one of the simplest einsum expressions, an elementwise copy from vector x to vector y:
@einsum y[i] = x[i]
The @einsum macro will expand this to the following code:
for i in 1:length(x)
    y[i] = x[i]
end
Lowering this to Julia IR (IRCode) gives the output below. With Julia 1.11 and beyond, this IR looks quite different because the Array type was redesigned, so I show here the output from Julia 1.9, which is slightly shorter and easier to follow:
1 ── %1 = Base.arraylen(_2)::Int64
│ %2 = Base.sle_int(1, %1)::Bool
└─── goto #3 if not %2
2 ── goto #4
3 ── goto #4
4 ┄─ %6 = φ (#2 => %1, #3 => 0)::Int64
└─── goto #5
5 ── goto #6
6 ── %9 = Base.slt_int(%6, 1)::Bool
└─── goto #8 if not %9
7 ── goto #9
8 ── goto #9
9 ┄─ %13 = φ (#7 => true, #8 => false)::Bool
│ %14 = φ (#8 => 1)::Int64
│ %15 = φ (#8 => 1)::Int64
│ %16 = Base.not_int(%13)::Bool
└─── goto #15 if not %16
10 ┄ %18 = φ (#9 => %14, #14 => %27)::Int64
│ %19 = φ (#9 => %15, #14 => %28)::Int64
│ %20 = Base.arrayref(true, _2, %18)::Float64
│ Base.arrayset(true, _3, %20, %18)::Vector{Float64}
│ %22 = (%19 === %6)::Bool
└─── goto #12 if not %22
11 ─ goto #13
12 ─ %25 = Base.add_int(%19, 1)::Int64
└─── goto #13
13 ┄ %27 = φ (#12 => %25)::Int64
│ %28 = φ (#12 => %25)::Int64
│ %29 = φ (#11 => true, #12 => false)::Bool
│ %30 = Base.not_int(%29)::Bool
└─── goto #15 if not %30
14 ─ goto #10
15 ┄ return nothing
=> Nothing
Looking at the Julia IR, it might be difficult to spot the for loop from the original function. That's because Julia IR only contains unstructured control flow, in the form of (conditional) goto statements. For generating MLIR, that's a problem. Even though MLIR has a dialect for unstructured control flow (the cf dialect), our goal is to keep the IR as structured as possible. This way, it's easier to apply transformations and optimizations.
(Additionally, the phi-nodes (φ) in the IR can be seen as operations returning a different value depending on which block the execution flow is coming from. So φ (#2 => %1, #3 => 0) should be read as “if the execution flow is coming from block 2, return value %1; if it is coming from block 3, return the constant 0”.)
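To get a feel for what unstructured control flow looks like, here is a hand-written sketch of the same copy loop using Julia's @goto and @label macros. The function name copy_goto! is made up for this illustration, and this is not literally what the compiler produces, but it has the same shape as the IR: labelled blocks connected by conditional and unconditional jumps, with nothing that says "this is a loop".

```julia
# Hypothetical illustration: the copy loop written with explicit jumps.
function copy_goto!(y, x)
    i = 1
    @label header
    if i > length(x)        # conditional jump out of the "loop"
        @goto done
    end
    y[i] = x[i]             # loop body
    i += 1
    @goto header            # unconditional jump back to the condition check
    @label done
    return y
end

copy_goto!(zeros(3), [1.0, 2.0, 3.0])
```

Recovering the loop bounds and body from code like this requires control flow analysis, which is exactly what the marker functions introduced next are meant to avoid.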
To still have a way to easily recover a for loop from Julia IR, we can replace the for loop by two carefully crafted function calls, begin_for and yield_for.
# instead of writing:
for i in 1:length(x)
    y[i] = x[i]
end
# write:
i = Brutus.begin_for(1, length(x))
y[i] = x[i]
Brutus.yield_for()
The only use for these functions is to show up in Julia IR. Apart from that, they don't do anything. The Julia code does not work anymore, since the functions don't have a sensible definition, but that's not a problem since we're only interested in the IR.
1 ─ %1 = Base.arraylen(_2)::Int64
│ %2 = invoke Main.begin_for(1::Int64, %1::Int64)::Int64
│ %3 = Base.arrayref(true, _2, %2)::Float64
│ Base.arrayset(true, _3, %3, %2)::Vector{Float64}
│ %5 = invoke Main.yield_for()::Core.Const(nothing)
└── return %5
=> Nothing
The unstructured control flow has disappeared and we now see our two function calls clearly demarcating the loop body. The IR is also much shorter because some of the bounds checking has disappeared. Let's ignore that for now.
We also need a way to represent an accumulator variable that is updated in the loop body. This can be done by making begin_for take an extra argument (the initial value) and return a second output (the accumulator variable). The accumulator variable is then passed to yield_for, which will also return it.
⋮
accumulator, k = Brutus.begin_for(initial_value, start, stop)
accumulator += b[i, k] * c[k, j]
accumulator = Brutus.yield_for(accumulator)
⋮
By hacking the existing Einsum.jl package to emit these functions instead of for loops, we now have a way to generate Julia IR that is ready to be converted to MLIR.
Generating MLIR
Interfacing Julia with MLIR is relatively straightforward thanks to the MLIR C API and Julia's ability to call C functions directly using ccall. The MLIR C API is a minimal wrapper around the MLIR framework, which is written in C++, and allows building IR among other things. In Julia, this C API is wrapped and a higher-level API for IR handling is exposed in the MLIR.jl package. (To my knowledge, this higher-level API was first developed in the Coil.jl project.)
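As a small illustration of how little ceremony ccall needs, the sketch below calls strlen from the C standard library. This is not an MLIR C API call; it only assumes that the C library is available in the running process, which is normally the case. Wrapping the MLIR C API works the same way, with the library and function names swapped in.

```julia
# Call the C standard library's `strlen` directly; no wrapper code is needed.
# Arguments: function name, return type, tuple of argument types, then the values.
n = ccall(:strlen, Csize_t, (Cstring,), "einsum")

# `n` is an unsigned C size; convert it to a regular Julia Int for further use.
len = Int(n)
```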
The MLIR C API doesn't contain a lot of dialect-specific functionality. The mlirOperationCreate function allows building operations to insert into IR, but is overly generic and annoyingly verbose to use when generating many operations. To also get a higher-level Julia API for each dialect, we can use a tool to generate Julia code based on dialect Operation Definition Specifications (ODS). These specifications are stored in tablegen files, a special domain-specific language used by the LLVM project that is typically converted to C++ source code. (This is similar to how the official Python API and mlir-hs (Haskell) integrate more tightly with dialect specifications.)
Generating MLIR comes down to looping over all Julia IR statements and emitting the corresponding MLIR operations. Simple binary operations such as Base.mul_int (integer multiplication) can be mapped directly onto an existing operation in MLIR (e.g. arith.muli).
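The direct mapping can be pictured as a lookup table plus a loop over statements. The sketch below is a toy that emits MLIR-like text; the Stmt type, the table and emit_mlir are made up for this illustration, and the actual code builds operations through the MLIR.jl API instead of strings.

```julia
# Hypothetical lookup table from Julia intrinsics to MLIR operation names.
const INTRINSIC_TO_MLIR = Dict(
    :mul_int => "arith.muli",
    :add_int => "arith.addi",
    :sub_int => "arith.subi",
)

# Toy stand-in for a Julia IR statement: result name, intrinsic, operand names.
struct Stmt
    result::String
    intrinsic::Symbol
    operands::Vector{String}
end

# Loop over the statements and emit one line of MLIR-like text for each.
function emit_mlir(stmts::Vector{Stmt})
    lines = String[]
    for s in stmts
        haskey(INTRINSIC_TO_MLIR, s.intrinsic) ||
            error("no MLIR mapping for $(s.intrinsic)")
        op = INTRINSIC_TO_MLIR[s.intrinsic]
        push!(lines, "$(s.result) = $op $(join(s.operands, ", ")) : i64")
    end
    return lines
end

stmts = [Stmt("%2", :mul_int, ["%0", "%1"]), Stmt("%3", :sub_int, ["%2", "%0"])]
mlir_lines = emit_mlir(stmts)
```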
Unsurprisingly, instructions such as our begin_for and yield_for need a little more care. A begin_for instruction signals that the following instructions need to be nested within the loop body. In practice this means a new MLIR Region is created and used until the corresponding yield_for is encountered. The for loop itself also can't be created immediately because its loop body region is not complete yet. Instead a thunk is put on a stack and will be called once the yield_for is encountered.
(In fact, since the loop body region can only contain one block, two regions are created. The innermost region contains all the blocks from the Julia IR, typically more than one. The outermost region contains a single block with the operation scf.execute_region, which will simply execute that child region (containing multiple blocks). Since we're dealing with simple IR, an MLIR canonicalization pass will typically simplify this by merging multiple blocks and getting rid of the surrounding execute_region.)
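The deferred loop construction can be sketched with two stacks: one holding the body currently being filled and one holding the thunks that will build each loop once its body is complete. Everything here (nest_loops, the symbols standing in for instructions, the string output) is made up for this illustration; the real code creates MLIR regions and scf.for operations instead of strings.

```julia
# Toy instruction stream: symbols mark loop boundaries, strings are ordinary ops.
function nest_loops(instructions)
    bodies = [String[]]     # stack of op lists; bodies[end] is the region being filled
    thunks = Function[]     # stack of deferred loop constructors
    for inst in instructions
        if inst === :begin_for
            push!(bodies, String[])   # start collecting the loop body
            # The loop can only be built once its body is complete, so defer it.
            push!(thunks, body -> "for { " * join(body, "; ") * " }")
        elseif inst === :yield_for
            body = pop!(bodies)
            make_loop = pop!(thunks)
            push!(bodies[end], make_loop(body))   # finished loop becomes an op of its parent
        else
            push!(bodies[end], inst)
        end
    end
    return only(bodies)     # errors if a begin_for was never closed
end

nested = nest_loops(["a", :begin_for, "b", :begin_for, "c", :yield_for, :yield_for, "d"])
```

Because the thunks live on a stack, nested loops close in the right order without any lookahead in the instruction stream.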
We can use the same simple vector copy example. Base.code_ircode can be used to see what Julia IR is generated for a function. The second argument to that function contains the argument types.
(MemRef isn't a builtin Julia type. Rather, it's a custom Array type that dispatches its getindex and setindex! methods to mlir_load and mlir_store! invocations. These functions are similar to our custom begin_for and yield_for in that their only purpose is to show up in the IR; they can't be executed in regular Julia. If you used regular array arguments, a lot more Julia IR would be generated, including bounds checks and error handling. These could eventually be handled correctly, but for now it's easier to just get rid of them in this way.)
f(x, y) = @einsum y[i] = x[i]
Base.code_ircode(f, Tuple{MemRef{Float64, 1}, MemRef{Float64, 1}})
This returns the following Julia IR:
⋮
%13 = invoke Brutus.begin_for(1::Int64, %12::Int64)::Int64
%14 = Base.sub_int(%13, 1)::Int64
%15 = invoke Brutus.mlir_load(_3::MemRef{Float64, 1}, %14::Int64)::Float64
%16 = Base.sub_int(%13, 1)::Int64
invoke Brutus.mlir_store!(_2::MemRef{Float64, 1}, %15::Float64, %16::Int64)::Float64
invoke Brutus.yield_for()::Any
⋮
code_mlir is the function that will loop over all Julia IR and create the corresponding MLIR:
code_mlir(f, Tuple{Brutus.MemRef{Float64, 1}, Brutus.MemRef{Float64, 1}})
which returns:
⋮
scf.for %arg2 = %c1_4 to %6 step %c1_6 {
  scf.execute_region {
    %8 = arith.index_cast %arg2 : index to i64
    %c1_8 = arith.constant 1 : index
    %9 = arith.index_cast %c1_8 : index to i64
    %10 = arith.subi %8, %9 : i64
    %11 = arith.index_cast %10 : i64 to index
    %12 = memref.load %arg1[%11] : memref<?xf64, strided<[1], offset: ?>>
    %13 = arith.index_cast %arg2 : index to i64
    %c1_9 = arith.constant 1 : index
    %14 = arith.index_cast %c1_9 : index to i64
    %15 = arith.subi %13, %14 : i64
    %16 = arith.index_cast %15 : i64 to index
    memref.store %12, %arg0[%16] : memref<?xf64, strided<[1], offset: ?>>
    scf.yield
  }
}
⋮
A lot of arithmetic operations are generated to support the one-based, column-major layout of Julia arrays and conversion between index and integer types. After canonicalization in MLIR, most of these operations are processed away and we're left with the following:
⋮
scf.for %arg2 = %c1 to %4 step %c1 {
  %5 = arith.index_cast %arg2 : index to i64
  %6 = arith.subi %5, %c1_i64 : i64
  %7 = arith.index_cast %6 : i64 to index
  %8 = memref.load %arg1[%7] : memref<?xf64, strided<[1], offset: ?>>
  memref.store %8, %arg0[%7] : memref<?xf64, strided<[1], offset: ?>>
}
⋮
Some details are not discussed here. For example, the Julia IR generally contains control flow operations (br, return, φ-nodes, …) that need to be handled. This, and a lot of the plumbing required to make this work, was developed by Pangoraw in this pull request. Pangoraw also wrote Coil.jl, which similarly compiles Julia IR to MLIR, albeit using a tracing approach instead of the one outlined here.
At the beginning of this post I promised MLIR IR in the affine dialect; however, we ended up with scf operations. Fear not, these can be raised to the affine dialect using Polygeist.
polygeist-opt --allow-unregistered-dialect --raise-scf-to-affine --affine-cfg my_mlir.mlir
which returns:
⋮
affine.for %arg2 = 1 to #map()[%3] {
  %4 = affine.load %arg1[%arg2 - 1] : memref<?xf64, strided<[1], offset: ?>>
  affine.store %4, %arg0[%arg2 - 1] : memref<?xf64, strided<[1], offset: ?>>
}
⋮
Where to Go From Here
It's now possible to lower the generated MLIR to LLVM, which can then be compiled to machine code. Doing so doesn't really get us anything, though. Since we're not running any optimization passes, the performance we get is actually about the same as simply writing the loops yourself.
What this whole exercise did give us, though, is the ability to run MLIR passes on the generated IR. This webpage contains an overview of transformations that are available in the upstream MLIR repository. Additionally, there are other projects providing MLIR passes as well.
The Polymer subproject of Polygeist also allows connecting Pluto, a more “traditional” polyhedral compilation tool, to MLIR.
Other projects for optimizing MLIR include IREE, Accera, Polyblocks (closed-source), and undoubtedly many more.
Upstream MLIR does not contain a lot of automatic optimization passes, meaning that you'll have to develop your own heuristics or cost models to drive the transformations. Some of the aforementioned tools do provide such automated optimizations; however, they typically constrain the type of IR they accept or have their own external dialects. For this reason it isn't trivial to integrate these tools yet, but it's definitely possible and I hope to explore this in the future.
This post started with the goal of generating MLIR in the affine dialect. It might be interesting to also consider the linear algebra dialect (linalg), since that dialect provides an einsum-inspired way to represent linear algebra operations.
The way loops are represented in Julia IR is not ideal. In fact, some more advanced einsum expressions fail to compile because the lowering to Julia IR is too aggressive and gets rid of some blocks in the control flow graph that are needed to generate the MLIR. The lack of high-level loops in Julia IR is a fundamental problem, but more elegant workarounds might be possible. Currently this is definitely something that stands in the way of this being a reliable tool instead of an initial proof-of-concept.
In the same vein, the code I wrote is available on GitHub but comes with no guarantees whatsoever. Things are flaky and might not work as expected, but if you're on Linux and using Julia 1.11.0-DEV.1141 or similar, it should be possible to clone the repository and run the example on the blog branch. (Julia 1.10 won't cut it because the code assumes the redesigned Array implementation.)
On Linux, an MLIR release should automatically be downloaded as an artifact when instantiating the environment. On macOS or Windows, you're on your own to provide a valid MLIR installation.