HUIJZER.XYZ

Optimizing Julia code

2022-03-19

I'm lately doing for the first time some optimizations of Julia code and I sort of find it super beautiful.

This is how I started a message on the Julia language Slack in response to a question about why optimizing Julia code is so difficult compared to other languages. In the message, I argued against that claim: optimizing isn't hard in Julia if you compare it to Python or R, where you have to be an expert in Python or R and in C/C++. In that message, I also went through a high-level overview of how I approach optimizing. The next day, Frames Catherine White, who is a true Julia veteran, suggested that I write a blog post about my overview, so here we are.

In this blog post, I'll describe what type stability is and why it is important for performance. Unlike most other posts, I'll discuss it both in the context of raw runtime performance (throughput) and in the context of time to first X (TTFX). Julia is somewhat notorious for having really bad TTFX in certain cases. For example, creating a plot with the Makie.jl package takes 40 seconds at the time of writing. On the second call, it takes about 0.001 seconds. This blog post explains the workflow that you can use to reduce running time and TTFX.
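
To see the difference between the first and later calls for yourself, here is a quick sketch, assuming some plotting package (here, Plots.jl) is installed; any heavy package shows the same pattern:

using Plots: plot
@time plot(1:10);  # first call: dominated by compilation (TTFX)
@time plot(1:10);  # second call: reuses compiled code, orders of magnitude faster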

Type stability

Let's first talk about that type stability thing that everyone keeps talking about. Why is it important? To show this, let's write some naive Julia code. Specifically, for this example, we write code that hides types from the compiler; that is, we add some kind of indirection so that the compiler cannot infer the types. This can be done via a dictionary. Note that our dictionary holds values of different types, namely a Float32 and a Float64:

numbers = Dict(:one => 1f0, :two => 2.0);
function double(mapping, key::Symbol)
    return 2 * mapping[key]
end;

This code works, we can pass :one or :two and the number will be doubled:

double(numbers, :one)
2.0f0
double(numbers, :two)
4.0
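
As an aside, we can also ask the compiler what it can infer about the return type. Base.return_types is an unexported reflection utility, so treat this as a sketch whose exact output format may differ between Julia versions; consistent with the @code_warntype output below, it reports Any:

Base.return_types(double, (typeof(numbers), Symbol))
1-element Vector{Any}:
 Any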

Let's look at the type-inferred code via @code_warntype. Here, you can ignore the with_terminal; it's only needed because this blog post is running in a Pluto.jl notebook.

using PlutoUI: with_terminal
with_terminal() do
    @code_warntype double(numbers, :one)
end
MethodInstance for Main.var"workspace#7".double(::Dict{Symbol, AbstractFloat}, ::Symbol)
  from double(mapping, key::Symbol) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#a36ef63b-436d-45e2-8edf-625df1575e7a:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".double)
  mapping::Dict{Symbol, AbstractFloat}
  key::Symbol
Body::Any
1 ─ %1 = Base.getindex(mapping, key)::AbstractFloat
│   %2 = (2 * %1)::Any
└──      return %2

Ouch. The code itself looks quite good, with one Base.getindex and a 2 * %1, but we get a big warning about the output type, which is Any (printed in red in a terminal). That warning indicates that something is wrong: a value of type Any cannot easily be put into a memory spot. For a concrete type such as Float64, we know exactly how much space we need, so we don't need a pointer, and we can even put the number nearer to the CPU so that it can be accessed quickly. To see whether a type is concrete, we can use isconcretetype:

isconcretetype(Float64)
true
isconcretetype(AbstractFloat)
false
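
Relatedly, isbitstype (also in Base) tells us whether values of a type can be stored inline, without a pointer; this is what makes the memory layout of a Float64 so cheap:

isbitstype(Float64)
true
isbitstype(AbstractFloat)
false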

To make matters worse, Julia does a lot of optimizing, but it cannot do much with abstract types. For example, let's write a simple function that calls double:

function use_double(mapping, x)
    doubled = 2 * double(mapping, x)
    string(doubled)
end;
use_double(numbers, :one)
"4.0"

This is how the @code_warntype looks:

with_terminal() do
    @code_warntype use_double(numbers, :one)
end
MethodInstance for Main.var"workspace#7".use_double(::Dict{Symbol, AbstractFloat}, ::Symbol)
  from use_double(mapping, x) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#da60c86b-ed10-4475-988c-099d5929946e:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".use_double)
  mapping::Dict{Symbol, AbstractFloat}
  x::Symbol
Locals
  doubled::Any
Body::Any
1 ─ %1 = Main.var"workspace#7".double(mapping, x)::Any
│        (doubled = 2 * %1)
│   %3 = Main.var"workspace#7".string(doubled)::Any
└──      return %3

The Any type propagated: now the use_double function also has an Any output type. Furthermore, the type of the variable doubled isn't known when the function is compiled, meaning that the call string(doubled) ends up being a runtime dispatch. This means that Julia has to look up the right method in the method table at runtime. If the type were known, Julia would hardcode the link to the right method and thus avoid a method table lookup, or it would just copy-paste the content of the function to avoid jumping at all. This is called inlining.

To see that in action, let's go on a little digression and take a look at optimized code for the case where the types are known. For this, consider two simple functions:

inner(x) = 2 * x;
outer(x) = 3 * inner(x);

We can now call this with, say, an Int and get an output:

outer(2)
12

Let's look at the LLVM code for this function:

with_terminal() do
    @code_llvm outer(2)
end
;  @ /builds/rikh/blog/posts/notebooks/inference.jl#==#879d0d78-0c25-45d6-92be-9350b85def96:1 within `outer`
define i64 @julia_outer_2105(i64 signext %0) #0 {
top:
; ┌ @ int.jl:88 within `*`
   %1 = mul i64 %0, 6
   ret i64 %1
; └
}

Hopefully, you're now thinking "WOW!". The compiler figured out that inner is just 2 * x, so there is no need to step into that function; we can just calculate 2 * x directly. But then, it also figures out that 3 * (2 * x) = 6 * x, so we get the answer in a single LLVM instruction.

On the other hand, what if we add a Base.inferencebarrier to block inference inside the outer function:

blocked_outer(x) = 3 * inner(Base.inferencebarrier(x));
with_terminal() do
    @code_llvm blocked_outer(2)
end
;  @ /builds/rikh/blog/posts/notebooks/inference.jl#==#e02d6f85-22da-4262-9dad-3911c2222280:1 within `blocked_outer`
define nonnull {}* @julia_blocked_outer_2143(i64 signext %0) #0 {
top:
  %1 = alloca [2 x {}*], align 8
  %gcframe2 = alloca [3 x {}*], align 16
  %gcframe2.sub = getelementptr inbounds [3 x {}*], [3 x {}*]* %gcframe2, i64 0, i64 0
  %.sub = getelementptr inbounds [2 x {}*], [2 x {}*]* %1, i64 0, i64 0
  %2 = bitcast [3 x {}*]* %gcframe2 to i8*
  call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 24, i1 true)
  %thread_ptr = call i8* asm "movq %fs:0, $0", "=r"() #6
  %tls_ppgcstack = getelementptr i8, i8* %thread_ptr, i64 -8
  %3 = bitcast i8* %tls_ppgcstack to {}****
  %tls_pgcstack = load {}***, {}**** %3, align 8
  %4 = bitcast [3 x {}*]* %gcframe2 to i64*
  store i64 4, i64* %4, align 16
  %5 = getelementptr inbounds [3 x {}*], [3 x {}*]* %gcframe2, i64 0, i64 1
  %6 = bitcast {}** %5 to {}***
  %7 = load {}**, {}*** %tls_pgcstack, align 8
  store {}** %7, {}*** %6, align 8
  %8 = bitcast {}*** %tls_pgcstack to {}***
  store {}** %gcframe2.sub, {}*** %8, align 8
  %9 = call nonnull {}* @ijl_box_int64(i64 signext %0)
  %10 = getelementptr inbounds [3 x {}*], [3 x {}*]* %gcframe2, i64 0, i64 2
  store {}* %9, {}** %10, align 16
  store {}* %9, {}** %.sub, align 8
  %11 = call nonnull {}* @ijl_apply_generic({}* inttoptr (i64 137879257644608 to {}*), {}** nonnull %.sub, i32 1)
  store {}* %11, {}** %10, align 16
  store {}* inttoptr (i64 137879940236192 to {}*), {}** %.sub, align 8
  %12 = getelementptr inbounds [2 x {}*], [2 x {}*]* %1, i64 0, i64 1
  store {}* %11, {}** %12, align 8
  %13 = call nonnull {}* @ijl_apply_generic({}* inttoptr (i64 137879720038880 to {}*), {}** nonnull %.sub, i32 2)
  %14 = load {}*, {}** %5, align 8
  %15 = bitcast {}*** %tls_pgcstack to {}**
  store {}* %14, {}** %15, align 8
  ret {}* %13
}

To see the difference in running time, we can compare the output of @benchmark for both:

using BenchmarkTools: @benchmark
@benchmark outer(2)
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  1.549 ns … 37.220 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.140 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.025 ns ±  0.831 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █                                                           
  █▂▁▁▁▂▁▁▁▁▁▂▂▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▄▃▃▅▃▄▅▄▅▃▃▃▃▃▃▂▂ ▂
  1.55 ns        Histogram: frequency by time        2.34 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
@benchmark blocked_outer(2)
BenchmarkTools.Trial: 10000 samples with 993 evaluations.
 Range (min … max):  33.555 ns …  10.152 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     57.503 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   60.659 ns ± 209.277 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▇▅                   ▁▂▃▃▃▂▂▂▂▂▃▄▆███▇▆▂▁▁▁▂▃▂▃▂  ▁▂▂▃▂▂▁    ▃
  ██▇█▇▇▇▆▆▅▆▆▆▁▃▃▃▄▃▃▅████████████████████████████████████▇▇▇ █
  33.6 ns       Histogram: log(frequency) by time      74.1 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

So, even though benchmarks in the low-nanosecond range aren't fully reliable, we can see that the inferable function (outer) is much faster. And as we'll see next, this is not all due to the extra call to Base.inferencebarrier.

We've seen that knowing the types is important for the compiler, so let's improve the type inference for the function above. We could fix it in a few ways. One way is to add a type hint inside the function. For example, a type hint could look like this:

function with_type_hint(x)
    Base.inferencebarrier(x)::Int
end;

With this, the output type of the function body is known:

with_terminal() do
    @code_warntype with_type_hint(1)
end
MethodInstance for Main.var"workspace#7".with_type_hint(::Int64)
  from with_type_hint(x) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#f6bece65-9735-4e4b-9e74-99735013fb93:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".with_type_hint)
  x::Int64
Body::Int64
1 ─ %1 = Base.inferencebarrier(x)::Any
│   %2 = Core.typeassert(%1, Main.var"workspace#7".Int)::Int64
└──      return %2

which solves further inference problems if we use this method, but it is a bit risky. The Core.typeassert will assert the type and throw an error if the type turns out to be wrong. This hinders writing generic code. Also, it takes the system a little bit of time to actually assert the type.

So, instead it would be better to go to the root of the problem. Above, we had a dictionary numbers:

numbers
Dict{Symbol, AbstractFloat} with 2 entries:
  :two => 2.0
  :one => 1.0

The type is:

typeof(numbers)
Dict{Symbol, AbstractFloat}

Here, AbstractFloat is a non-concrete (abstract) type, meaning that it cannot have direct instances and, more importantly, that we cannot say with certainty which method should be called for an object of such a type.

We can make this type concrete by manually specifying the type of the dictionary. Now, Julia will automatically convert our Float32 to a Float64:

typednumbers = Dict{Symbol, Float64}(:one => 1f0, :two => 2.0);
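
To double-check, valtype (defined in Base for dictionaries) returns the value type, which is now concrete:

valtype(typednumbers)
Float64
isconcretetype(valtype(typednumbers))
true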

Let's look at the @code_warntype again:

with_terminal() do
    @code_warntype use_double(typednumbers, :one)
end
MethodInstance for Main.var"workspace#7".use_double(::Dict{Symbol, Float64}, ::Symbol)
  from use_double(mapping, x) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#da60c86b-ed10-4475-988c-099d5929946e:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".use_double)
  mapping::Dict{Symbol, Float64}
  x::Symbol
Locals
  doubled::Float64
Body::String
1 ─ %1 = Main.var"workspace#7".double(mapping, x)::Float64
│        (doubled = 2 * %1)
│   %3 = Main.var"workspace#7".string(doubled)::String
└──      return %3

Great! So, this is now exactly the same function as above, but all the types are known and the compiler is happy.

Let's run the benchmarks for both numbers and typednumbers:

@benchmark use_double(numbers, :one)
BenchmarkTools.Trial: 10000 samples with 580 evaluations.
 Range (min … max):  134.931 ns … 48.045 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     264.742 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   373.725 ns ±  1.463 μs  ┊ GC (mean ± σ):  3.53% ± 4.26%

                                  ▄█▁                           
  ▃▂▂▂▂▂▁▁▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▃▃▃▃▄▆████▇▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▃
  135 ns          Histogram: frequency by time          365 ns <

 Memory estimate: 168 bytes, allocs estimate: 4.
@benchmark use_double(typednumbers, :one)
BenchmarkTools.Trial: 10000 samples with 918 evaluations.
 Range (min … max):   81.983 ns …  12.172 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     156.797 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   187.358 ns ± 495.181 ns  ┊ GC (mean ± σ):  5.03% ± 7.88%

  ▄▁▄▅█▆▂                                                       ▁
  ███████▆▅▄▄▁▃▆▇█▇▄▁▁▁▁▃▄▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▃▁▁▄▃▃▁▁▁▁▁▁▁▃▁▄ █
  82 ns         Histogram: log(frequency) by time        1.1 μs <

 Memory estimate: 432 bytes, allocs estimate: 2.

So, that's roughly a 40% reduction in median running time, which we basically got for free. The only thing we needed to do was look through our naive code and help the compiler out a bit by adding more information.

And this is exactly what I find so beautiful about the Julia language. You have this high-level language where you can be very expressive, write in whatever style you want, and don't have to bother putting type annotations on all your functions. Instead, you first focus on your proof of concept and get your code working, and only then do you start digging into optimizing your code. To do this, you can often get pretty far just by looking at @code_warntype.

But, what if your code contains more than a few functions? Let's take a look at some of the available tooling.

Tooling

The most common tool for improving performance is a profiler. Julia has a profiler in the standard library:

using Profile

This is a sampling-based profiler, meaning that it estimates how much time is spent in each function by periodically taking samples.
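
The sample buffer size and the sampling interval can be tuned via Profile.init; a small sketch using the stdlib's documented keyword arguments (the defaults are usually fine):

Profile.init(n = 10^7, delay = 0.001)  # buffer size and seconds between samples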

@profile foreach(x -> blocked_outer(2), 1:100)

We can now call Profile.print() to see the output and how many samples were taken in each function. However, in most cases we want a nice plot. Here, I use ProfileSVG.jl, but other options are also listed in the Julia Profiling documentation. See especially PProf.jl, since that viewer can show call graphs as well as flame graphs.

using ProfileSVG: @profview
@profview foreach(x -> blocked_outer(2), 1:10_000_000)
[Flame graph: the blocks above eval in boot.jl show the anonymous function calling blocked_outer, which spends its time in inner and in * (int.jl:88); the remaining blocks are Pluto and compiler overhead.]

In this image, you can click on an element to see the location of the called function. Unfortunately, because the page that you're looking at was running inside a Pluto notebook, this output shows a bit of noise. You can focus on everything above eval in boot.jl to see where the time was spent. In essence, the idea is that the wider a block, the more time is spent on it. Also, blocks that lie on top of another block were called from inside that block. As can be seen, the profiler is very useful to get an idea of which function takes the most time to run.
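
If you want to inspect the flame graph outside the notebook, ProfileSVG can also write the most recent profile to a file; ProfileSVG.save is part of the same package:

import ProfileSVG
ProfileSVG.save("profile.svg")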

However, this doesn't tell us what is happening exactly. For that, we need to dive deeper and look critically at the source code of the function that takes the longest. Sometimes, that already provides enough information to see what can be optimized. In other cases, the problem isn't so obvious. Often, there is a type inference problem, because, as shown in the section above, that can make huge differences. One approach is then to go to the function that takes the most time to run and see how the type inference looks via @code_warntype. Unfortunately, this can be a bit tricky. Consider, for example, a function with keyword arguments:

with_keyword_arguments(a; b=3) = a + b;
with_terminal() do
    @code_warntype with_keyword_arguments(1)
end
MethodInstance for Main.var"workspace#7".with_keyword_arguments(::Int64)
  from with_keyword_arguments(a; b) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#8a64d202-811c-4189-b306-103054acdc28:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".with_keyword_arguments)
  a::Int64
Body::Int64
1 ─ %1 = Main.var"workspace#7".:(var"#with_keyword_arguments#3")(3, #self#, a)::Int64
└──      return %1

Here, we don't see the a + b as we would expect, but instead see that with_keyword_arguments calls another function without keyword arguments. Now, we would need to manually call this nested function with the generated name var"#with_keyword_arguments#3" and exactly the right inputs to see what @code_warntype does inside it. Even worse, imagine that you have a function which calls a function which calls a function...

To solve this, there is Cthulhu.jl. With Cthulhu, it is possible to @descend into a function and see its @code_warntype output. Next, the arrow keys and enter can be used to step into a function and see the output for that one. By continuously stepping into and out of functions, it is much easier to see what code is calling what and where exactly the type inference starts to fail. Often, by solving a type inference problem at exactly the right spot, inference problems for a whole bunch of functions can be fixed. For more information about Cthulhu, see the GitHub page linked above.
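
A typical session looks like the sketch below; note that @descend is interactive, so it needs a REPL rather than a notebook:

using Cthulhu: @descend
@descend with_keyword_arguments(1)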

A complementary tool to find the root of type problems is JET.jl. Basically, this tool can automate the process described above. It relies on Julia's compiler and can point to the root of type inference problems. Let's do a demo. Here, we use the optimization analysis:

using JET: @report_opt
@report_opt blocked_outer(2)
═════ 2 possible errors found ═════
┌ blocked_outer(x::Int64) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#e02d6f85-22da-4262-9dad-3911c2222280:1
│ runtime dispatch detected: Main.var"workspace#7".inner(x::Int64)::Any
└────────────────────
┌ blocked_outer(x::Int64) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#e02d6f85-22da-4262-9dad-3911c2222280:1
│ runtime dispatch detected: (3 Main.var"workspace#7".:* %1::Any)::Any
└────────────────────

In this case, the tool points out exactly the problem we've had. Namely, because the function definition is 3 * inner(Base.inferencebarrier(x)), the inner function call cannot be optimized because the type is unknown at that point. Also, the output of inner(Base.inferencebarrier(x)) is unknown, so the multiplication by 3 is another runtime dispatch.

For extremely long outputs, it can be useful to write the output of JET to a file to easily navigate through it.
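
A minimal sketch of how that could look, assuming the report object prints via the standard show machinery (which is also how it is displayed in the REPL):

report = @report_opt blocked_outer(2)
open("jet_report.txt", "w") do io
    show(io, MIME("text/plain"), report)
end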

These are the most important tools to improve performance. If this is all you care about, then feel free to stop reading here. In the next section, let's take a look at how to reduce the time to first X.

Precompilation

As described above, Julia does lots of optimizations on your code. For example, it removes unnecessary function calls and hardcodes method calls if possible. This takes time, and that is a problem. As mentioned above, Makie runs extremely quickly after the first plot has been created, going from 40 seconds to something like 0.001 seconds. But we pay those seconds again every time we restart Julia. Of course, Julia developers don't develop by changing their plotting code and waiting 40 seconds to see the output. We use tools such as Pluto.jl or Revise.jl to apply code changes without restarting Julia. Still, sometimes it is necessary to restart Julia, so what can we do to reduce the compilation time?

Well, we can reduce the compilation time by shouting "I am the compiler now!" and writing optimized code manually. For example, this is done in OrdinaryDiffEq.jl#1465. In some cases, this can be a great last-resort solution to make some compilation time disappear.

However, it is quite laborious and not suitable in all cases. A very nice alternative is to move the compilation time into the precompilation stage. Precompilation occurs right after package installation or when loading a package after it has been changed. The results of this compilation are retained even after restarting the Julia instance. So, instead of having to compile things on each restart, we compile them only when the package changes! Sounds like a good deal.

It is a good deal. Except, we have to remember that we're working with the Julia language: not all functions have typed arguments, let alone concretely typed arguments, so the precompilation phase cannot always know what it should compile. What's more, Julia by default doesn't even compile all functions whose arguments are concretely typed; it assumes that some functions will probably never be called, so there is no need to precompile them. This is on purpose: forcing full compilation would push developers towards putting concrete types everywhere, which would make Julia packages less composable, and that is a very fair argument.

Anyway, we can fix this by adding precompile directives ourselves. For example, we can create a new function, call precompile on it for integers and look at the existing method specializations:

begin
    add_one(x) = x + 1
    precompile(add_one, (Int,))
    methods(add_one)[1].specializations
end
MethodInstance for Main.var"workspace#7".add_one(::Int64)

A method specialization is just another way of saying a compiled instance of a method. So, a specialization is always for specific concrete types. This method specialization shows that add_one is compiled even though we haven't called it yet: the function is completely ready for use with integers. If we pass another type, the function would still need to be compiled for that type.
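
For example, following the same pattern, we can also prepare the function for Float64; the list of specializations then contains a MethodInstance for Float64 next to the one for Int64:

precompile(add_one, (Float64,))
methods(add_one)[1].specializations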

What is nice about this is that precompile will compile everything recursively. So, say we have a large codebase handling some kind of notebooks, and the package has some kind of open function with concrete argument types, such as a ServerSession to open the notebook into and a String with the path of the notebook location. Then, we can add a precompile on that function inside this large codebase as follows:

precompile(open, (ServerSession, String))

Since the open function calls many other functions, this one precompile statement will compile many functions and can reduce the time to first X by a lot. This is what happened in Pluto.jl#1934: we added one line of code to reduce the time to first open a notebook from 11 to 8 seconds. That is a roughly 30% reduction in TTFX from one line of code. To figure out exactly where you need to add precompile directives, you can use SnoopCompile.jl.
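
To give an idea, here is a sketch of how such an analysis might start, assuming SnoopCompile.jl is installed (in older versions, the macro is called @snoopi_deep); open(session, path) stands in for whatever call you want to speed up:

using SnoopCompileCore
tinf = @snoop_inference open(session, path)
using SnoopCompile
inference_triggers(tinf)  # lists where inference was triggered at runtime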

At this point, you probably wonder why we didn't get a 100% reduction. The answer is type inference: precompile goes through all the functions recursively, but once a type becomes non-concrete, it cannot know what to compile. To fix this, we can use the tools presented above to solve the type inference problems.

In conclusion, this is what I find so beautiful about the language. You can hack your proof of concept together in very naive ways and then throw in a few precompile directives if you want to reduce the TTFX. Then, once you need performance, you can pinpoint which method takes the most time, look at the generated code, and start fixing problems such as type inference. Improving inferability will often make code more readable, reduce running time, and reduce time to first X, all at the same time.

Acknowledgements

Thanks to Michael Helton, Rafael Fourquet and Guillaume Dalle for providing feedback on this blog post.

Appendix

Built with Julia 1.10.6 and

BenchmarkTools 1.4.0
JET 0.8.27
PlutoUI 0.7.55
ProfileSVG 0.2.1