I'm lately doing for the first time some optimizations of Julia code and I sort of find it super beautiful.
This is how I started a message on the Julia language Slack in response to a question about why optimising Julia code is so difficult compared to other languages. In the message, I argued against that claim. Optimising isn't hard in Julia if you compare it to Python or R, where you have to be an expert in both Python or R and C/C++. Also, in that message I gave a high-level overview of how I approached optimising. The next day, Frames Catherine White, who is a true Julia veteran, suggested that I write a blog post about my overview, so here we are.
In this blog post, I'll describe what type stability is and why it is important for performance. Unlike most other posts, I'll discuss it both in the context of performance (raw throughput) and in the context of time to first X (TTFX). Julia is sort of notorious for having really bad TTFX in certain cases. For example, creating a plot with the Makie.jl package takes 40 seconds at the time of writing, while the second call takes about 0.001 seconds. This blog post explains a workflow that you can use to reduce both running time and TTFX.
Let's first talk about that type stability thing that everyone keeps talking about. Why is it important? To show this, let's write naive Julia code. Specifically, for this example, we write code which hides the type from the compiler; that is, we add some kind of indirection so that the compiler cannot infer the types. This can be done via a dictionary. Note that our dictionary contains values of different types, namely a Float32 and a Float64:
numbers = Dict(:one => 1f0, :two => 2.0);
function double(mapping, key::Symbol)
return 2 * mapping[key]
end;
This code works; we can pass :one or :two and the number will be doubled:
double(numbers, :one)
2.0f0
double(numbers, :two)
4.0
Let's look at the types that the compiler inferred via @code_warntype. Here, you can ignore the with_terminal; it's only needed because this blog post is running in a Pluto.jl notebook.
using PlutoUI: with_terminal
with_terminal() do
@code_warntype double(numbers, :one)
end
MethodInstance for Main.var"workspace#7".double(::Dict{Symbol, AbstractFloat}, ::Symbol)
  from double(mapping, key::Symbol) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#a36ef63b-436d-45e2-8edf-625df1575e7a:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".double)
  mapping::Dict{Symbol, AbstractFloat}
  key::Symbol
Body::Any
1 ─ %1 = Main.var"workspace#7".:*::Core.Const(*)
│   %2 = Base.getindex(mapping, key)::AbstractFloat
│   %3 = (%1)(2, %2)::Any
└──      return %3
Ouch. The code looks quite good, with one Base.getindex and one multiplication by 2, but we do get some big red warnings about the output type, which is Any (see Body::Any). That color indicates that something is wrong. What is wrong is that a value of type Any cannot easily be put into a memory spot. For a concrete type such as Float64, we know exactly how much space we need, so we don't need a pointer to a boxed value and we can even keep the number close to the CPU, for example in a register, so that it can be accessed quickly. To see whether a type is concrete, we can use isconcretetype:
isconcretetype(Float64)
true
isconcretetype(AbstractFloat)
false
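To make the memory argument a bit more tangible, here is a small sketch: a concrete type such as Float64 has a definite size, whereas asking for the size of an abstract type fails. The helper `size_or_nothing` is hypothetical and only exists for this example; `sizeof` and `isbitstype` are from Base.

```julia
# Hypothetical helper: return the size in bytes of a type, or `nothing`
# when the type has no definite size (as is the case for abstract types).
function size_or_nothing(T::Type)
    try
        return sizeof(T)
    catch
        return nothing
    end
end

size_or_nothing(Float64)        # 8: the compiler knows exactly how much space to reserve
size_or_nothing(AbstractFloat)  # nothing: a value of this type has no fixed size

# `isbitstype` is true for concrete immutable types without references,
# such as Float64, which can be stored inline without a pointer.
isbitstype(Float64)        # true
isbitstype(AbstractFloat)  # false
```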
To make matters worse, Julia does a lot of optimizing, but it cannot do much for abstract types. For example, let's write a very simple function which calls double:
function use_double(mapping, x)
doubled = 2 * double(mapping, x)
string(doubled)
end;
use_double(numbers, :one)
"4.0"
This is how the @code_warntype output looks:
with_terminal() do
@code_warntype use_double(numbers, :one)
end
MethodInstance for Main.var"workspace#7".use_double(::Dict{Symbol, AbstractFloat}, ::Symbol)
  from use_double(mapping, x) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#da60c86b-ed10-4475-988c-099d5929946e:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".use_double)
  mapping::Dict{Symbol, AbstractFloat}
  x::Symbol
Locals
  doubled::Any
Body::Any
1 ─ %1 = Main.var"workspace#7".:*::Core.Const(*)
│   %2 = Main.var"workspace#7".double(mapping, x)::Any
│        (doubled = (%1)(2, %2))
│   %4 = doubled::Any
│   %5 = Main.var"workspace#7".string(%4)::Any
└──      return %5
The Any type propagated. Now, the use_double function also has an Any output type. Moreover, the type of the variable doubled isn't known when the function is compiled, which means that the call string(doubled) ends up being a runtime dispatch. In other words, Julia has to look up the right method in the method table at run time. If the type were known, Julia could hardcode the call to the right method and thus avoid the method table lookup, or it could even copy the body of the called function into the caller to avoid the call altogether. The latter is called inlining.
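To check this without reading @code_warntype output, one option is Base.return_types, which returns the inferred return type(s) of a function for given argument types. The sketch below repeats the definitions from above so that it runs on its own; note that Base.return_types is a reflection utility, so the exact inferred types may differ between Julia versions.

```julia
numbers = Dict(:one => 1f0, :two => 2.0)  # Dict{Symbol, AbstractFloat}

double(mapping, key::Symbol) = 2 * mapping[key]

function use_double(mapping, x)
    doubled = 2 * double(mapping, x)
    string(doubled)
end

# Inferred return type when the dictionary has an abstract value type.
abstract_rt = only(Base.return_types(use_double, (typeof(numbers), Symbol)))

# Inferred return type when the dictionary has a concrete value type.
concrete_rt = only(Base.return_types(use_double, (Dict{Symbol, Float64}, Symbol)))

println(abstract_rt)  # a non-concrete type (Any on Julia 1.11)
println(concrete_rt)  # String
```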
To see that in action, let's go on a little digression and take a look at optimised code for the case when the types are known. For this, consider two simple functions:
inner(x) = 2 * x;
outer(x) = 3 * inner(x);
We can now call this for, say, an Int and get an output:
outer(2)
12
Let's look at the LLVM code for this function:
with_terminal() do
@code_llvm outer(2)
end
; Function Signature: outer(Int64)
;  @ /builds/rikh/blog/posts/notebooks/inference.jl#==#879d0d78-0c25-45d6-92be-9350b85def96:1 within `outer`
define i64 @julia_outer_13588(i64 signext %"x::Int64") #0 {
top:
; ┌ @ int.jl:88 within `*`
   %0 = mul i64 %"x::Int64", 6
   ret i64 %0
; └
}
Hopefully, you're now thinking "WOW!". The compiler figured out that inner is just 2 * x, so there is no need to step into that function; it can just calculate 2 * x directly. But then, it figures out that 3 * (2 * x) = 6 * x, so the answer can be computed with a single LLVM instruction.
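To see that the single `mul` really comes from inlining, one can compare against a version where inlining is disabled with `@noinline`. This is a sketch, assuming that `code_llvm` from InteractiveUtils accepts an IO argument and the `debuginfo` keyword (as it does in recent Julia versions); `inner_noinline` and `outer_noinline` are hypothetical names, and the exact IR text differs between Julia and LLVM versions.

```julia
using InteractiveUtils: code_llvm

inner(x) = 2 * x
outer(x) = 3 * inner(x)

# Same computation, but the compiler is asked not to inline the callee.
@noinline inner_noinline(x) = 2 * x
outer_noinline(x) = 3 * inner_noinline(x)

# Capture the optimized LLVM IR as a string, without source-location comments.
llvm_ir(f, types) = sprint(io -> code_llvm(io, f, types; debuginfo=:none))

ir_inlined = llvm_ir(outer, (Int,))
ir_noinline = llvm_ir(outer_noinline, (Int,))

# The inlined version multiplies by 6 directly; the other one still calls the callee.
println(ir_inlined)
println(ir_noinline)
```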
On the other hand, what if we add a Base.inferencebarrier to block inference inside the outer function?
blocked_outer(x) = 3 * inner(Base.inferencebarrier(x));
with_terminal() do
@code_llvm blocked_outer(2)
end
; Function Signature: blocked_outer(Int64)
;  @ /builds/rikh/blog/posts/notebooks/inference.jl#==#e02d6f85-22da-4262-9dad-3911c2222280:1 within `blocked_outer`
define nonnull ptr @julia_blocked_outer_13915(i64 signext %"x::Int64") #0 {
top:
  %jlcallframe1 = alloca [2 x ptr], align 8
  %gcframe2 = alloca [3 x ptr], align 16
  call void @llvm.memset.p0.i64(ptr align 16 %gcframe2, i8 0, i64 24, i1 true)
  %thread_ptr = call ptr asm "movq %fs:0, $0", "=r"() #8
  %tls_ppgcstack = getelementptr i8, ptr %thread_ptr, i64 -8
  %tls_pgcstack = load ptr, ptr %tls_ppgcstack, align 8
  store i64 4, ptr %gcframe2, align 16
  %frame.prev = getelementptr inbounds ptr, ptr %gcframe2, i64 1
  %task.gcstack = load ptr, ptr %tls_pgcstack, align 8
  store ptr %task.gcstack, ptr %frame.prev, align 8
  store ptr %gcframe2, ptr %tls_pgcstack, align 8
  %box_Int64 = call nonnull align 8 dereferenceable(8) ptr @ijl_box_int64(i64 signext %"x::Int64") #2
  %gc_slot_addr_0 = getelementptr inbounds ptr, ptr %gcframe2, i64 2
  store ptr %box_Int64, ptr %gc_slot_addr_0, align 16
  store ptr %box_Int64, ptr %jlcallframe1, align 8
  %0 = call nonnull ptr @ijl_apply_generic(ptr nonnull @"jl_global#13919.jit", ptr nonnull %jlcallframe1, i32 1)
  store ptr %0, ptr %gc_slot_addr_0, align 16
  store ptr @"jl_global#13922.jit", ptr %jlcallframe1, align 8
  %1 = getelementptr inbounds ptr, ptr %jlcallframe1, i64 1
  store ptr %0, ptr %1, align 8
  %2 = call nonnull ptr @ijl_apply_generic(ptr nonnull @"jl_global#13921.jit", ptr nonnull %jlcallframe1, i32 2)
  %frame.prev7 = load ptr, ptr %frame.prev, align 8
  store ptr %frame.prev7, ptr %tls_pgcstack, align 8
  ret ptr %2
}
To see the difference in running time, we can compare the @benchmark output for both:
using BenchmarkTools: @benchmark
@benchmark outer(2)
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  1.800 ns … 10.026 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.260 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.310 ns ± 141.748 ns ┊ GC (mean ± σ):  0.00% ± 0.00%
 ▂▃▂▇█▃
 ▄▆▄▅██████▄▃▂▂▂▁▁▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂ ▃
 1.8 ns          Histogram: frequency by time          5.57 ns <
 Memory estimate: 0 bytes, allocs estimate: 0.
@benchmark blocked_outer(2)
BenchmarkTools.Trial: 10000 samples with 985 evaluations.
 Range (min … max):  52.934 ns … 406.944 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     61.938 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   63.289 ns ±   6.585 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%
 ▁▄▆█▇▆▄▂
 ▁▁▁▁▁▁▂▂▃▄▅▅▆██████████▆▃▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁ ▃
 52.9 ns         Histogram: frequency by time         82.6 ns <
 Memory estimate: 0 bytes, allocs estimate: 0.
So, even though benchmarks of only a few nanoseconds aren't very reliable, we can see that the inferable function (outer) is much faster. Next, we'll show that this difference is not only due to the extra call to Base.inferencebarrier.
We've seen that knowing the types is important for the compiler, so let's improve the type inference for the function above. We could fix it in a few ways. One way is to add a type assertion, sometimes called a type hint, inside the function. For example, it could look like this:
function with_type_hint(x)
    # Assert that the value coming out of the inference barrier is an Int.
    Base.inferencebarrier(x)::Int
end;
With this, the output type of the function body is known:
with_terminal() do
@code_warntype with_type_hint(1)
end
MethodInstance for Main.var"workspace#7".with_type_hint(::Int64)
  from with_type_hint(x) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#f6bece65-9735-4e4b-9e74-99735013fb93:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".with_type_hint)
  x::Int64
Body::Int64
1 ─ %1 = Base.inferrencebarrier::Any
│   %2 = (%1)(x)::Any
│   %3 = Main.var"workspace#7".Int::Core.Const(Int64)
│   %4 = Core.typeassert(%2, %3)::Int64
└──      return %4
This solves further inference problems for code that uses this method, but it is a bit risky. The Core.typeassert will assert the type and throw an error if the type turns out to be wrong, which hinders writing generic code. Also, it takes the system a little bit of time to actually check the type.
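The trade-off can be sketched in a few lines: the assertion gives the compiler a concrete return type, but it turns a wrong guess into a runtime error. `as_int` is a hypothetical function for this example.

```julia
# The type assertion lets inference conclude that the result is an Int...
as_int(x) = Base.inferencebarrier(x)::Int

# ...even for a Float64 argument, because the compiler trusts the assertion.
inferred = only(Base.return_types(as_int, (Float64,)))  # Int64

as_int(1)  # returns 1

# Passing anything other than an Int fails at run time.
result = try
    as_int(1.0)
catch err
    err
end

result isa TypeError  # true
```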
So, instead, it would be better to go to the root of the problem. Above, we had a dictionary numbers:
numbers
Dict{Symbol, AbstractFloat} with 2 entries:
  :two => 2.0
  :one => 1.0
The type is:
typeof(numbers)
Dict{Symbol, AbstractFloat}
Here, AbstractFloat is a non-concrete type, meaning that it cannot have direct instances and, more importantly, that the compiler cannot say with certainty which method should be called for a value of such a type.
We can make this type concrete by manually specifying the type of the dictionary. Now, Julia will automatically convert our Float32 to a Float64:
typednumbers = Dict{Symbol, Float64}(:one => 1f0, :two => 2.0);
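A quick way to inspect the value type of a dictionary is valtype. This sketch compares both dictionaries and shows the conversion on insertion; writing the literals as Float64 values in the first place (1.0 instead of 1f0) would also give a concretely typed Dict.

```julia
numbers = Dict(:one => 1f0, :two => 2.0)
typednumbers = Dict{Symbol, Float64}(:one => 1f0, :two => 2.0)

valtype(numbers)       # AbstractFloat: the common supertype of Float32 and Float64
valtype(typednumbers)  # Float64

typednumbers[:one]          # 1.0, converted from Float32 on insertion
typeof(typednumbers[:one])  # Float64

valtype(Dict(:one => 1.0, :two => 2.0))  # Float64, inferred from the literals
```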
Let's look at the @code_warntype again:
with_terminal() do
@code_warntype use_double(typednumbers, :one)
end
MethodInstance for Main.var"workspace#7".use_double(::Dict{Symbol, Float64}, ::Symbol)
  from use_double(mapping, x) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#da60c86b-ed10-4475-988c-099d5929946e:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".use_double)
  mapping::Dict{Symbol, Float64}
  x::Symbol
Locals
  doubled::Float64
Body::String
1 ─ %1 = Main.var"workspace#7".:*::Core.Const(*)
│   %2 = Main.var"workspace#7".double(mapping, x)::Float64
│        (doubled = (%1)(2, %2))
│   %4 = doubled::Float64
│   %5 = Main.var"workspace#7".string(%4)::String
└──      return %5
Great! So, this is now exactly the same function as above, but all the types are known and the compiler is happy.
Let's run the benchmarks for both numbers and typednumbers:
@benchmark use_double(numbers, :one)
BenchmarkTools.Trial: 10000 samples with 550 evaluations.
 Range (min … max):  200.873 ns … 29.160 μs  ┊ GC (min … max): 0.00% … 98.54%
 Time  (median):     261.109 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   288.076 ns ± 436.233 ns ┊ GC (mean ± σ):  5.82% ± 4.81%
 ▂▄██▆▃
 ▁▁▁▂▂▂▂▂▂▂▂▂▂▃▃▄▅▅████████▅▅▄▄▄▃▄▄▄▃▃▃▃▃▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁ ▃
 201 ns          Histogram: frequency by time          359 ns <
 Memory estimate: 168 bytes, allocs estimate: 5.
@benchmark use_double(typednumbers, :one)
BenchmarkTools.Trial: 10000 samples with 873 evaluations.
 Range (min … max):   80.195 ns … 23.536 μs  ┊ GC (min … max):  0.00% …  0.00%
 Time  (median):     169.650 ns              ┊ GC (median):     0.00%
 Time  (mean ± σ):   206.071 ns ± 464.042 ns ┊ GC (mean ± σ):  14.59% ± 11.24%
 ▆▁▆█▅ ▃▃ ▁
 ██████▅▆███▇▄▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▄▄▃▄▄▁▃▁▄▄▅▄▄▄▄▄▄▅▅▄▅▅ █
 80.2 ns       Histogram: log(frequency) by time       1.74 μs <
 Memory estimate: 432 bytes, allocs estimate: 3.
So, that's a reduction in running time (the median went from about 261 ns to 170 ns) which we basically got for free. The only thing we needed to do was look through our naive code and help out the compiler a bit by adding more information.
And this is exactly what I find so beautiful about the Julia language. You have this high-level language where you can be very expressive, write in whatever style you want, and don't have to bother with putting type annotations on all your functions. Instead, you first focus on your proof of concept and get your code working, and only then do you start digging into optimizing your code. To do this, you can often get pretty far already by looking at @code_warntype.
But, what if your code contains more than a few functions? Let's take a look at some of the available tooling.
The most common tool for improving performance is a profiler. Julia has a profiler in the standard library:
using Profile
This is a sampling profiler, meaning that it periodically records which functions are running and uses these samples to estimate how much time is spent in each function.
@profile foreach(x -> blocked_outer(2), 1:100)
We can now call Profile.print() to see the output and how many samples were taken in each function. However, in most cases we want to have a nice plot. Here, I use ProfileSVG.jl, but other options are also listed in the Julia Profiling documentation. See especially PProf.jl, since that viewer can show call graphs as well as flame graphs.
using ProfileSVG: @profview
@profview foreach(x -> blocked_outer(2), 1:10_000_000)
In the resulting image, you can click on an element to see the location of the called function. Unfortunately, because the page that you're looking at was generated inside a Pluto notebook, this output shows a bit of noise. You can focus on everything above eval in boot.jl to see where the time was spent. In essence, the wider a block, the more time is spent in it. Also, a block lying on top of another block indicates that it was called from inside that lower block. As can be seen, the profiler is very useful to get an idea of which function takes the most time to run.
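If a graphical viewer isn't available, the text output of Profile.print can also be made easier to scan. A sketch using the `format` and `sortedby` keywords of the Profile standard library; the workload mirrors the one above.

```julia
using Profile

inner(x) = 2 * x
blocked_outer(x) = 3 * inner(Base.inferencebarrier(x))

blocked_outer(2)  # compile first so that the profile focuses on running time

Profile.clear()  # discard samples from earlier @profile runs
@profile foreach(x -> blocked_outer(2), 1:10_000_000)

# Flat listing instead of a tree, sorted by the number of samples.
report = sprint(io -> Profile.print(io; format=:flat, sortedby=:count))
print(report)
```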
However, this doesn't tell us exactly what is happening. For that, we need to dive deeper and look critically at the source code of the function which takes long. Sometimes, that already provides enough information to see what can be optimized. In other cases, the problem isn't so obvious. Often, there is a type inference problem, because, as shown in the section above, that can make a huge difference. One way would then be to go to the function which takes the most time to run and see how the type inference looks via @code_warntype. Unfortunately, this can be a bit tricky. Consider, for example, a function with keyword arguments:
with_keyword_arguments(a; b=3) = a + b;
with_terminal() do
@code_warntype with_keyword_arguments(1)
end
MethodInstance for Main.var"workspace#7".with_keyword_arguments(::Int64)
  from with_keyword_arguments(a; b) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#8a64d202-811c-4189-b306-103054acdc28:1
Arguments
  #self#::Core.Const(Main.var"workspace#7".with_keyword_arguments)
  a::Int64
Body::Int64
1 ─ %1 = Main.var"workspace#7".:(var"#with_keyword_arguments#3")::Core.Const(Main.var"workspace#7".var"#with_keyword_arguments#3")
│   %2 = (%1)(3, #self#, a)::Int64
└──      return %2
Here, we don't see the a + b as we would expect; instead, we see that with_keyword_arguments calls another function without keyword arguments. Now, we would need to manually call this nested function, which has a generated name (var"#with_keyword_arguments#3" in the output above), with exactly the right inputs to see what @code_warntype reports inside this function. Even worse, imagine that you have a function which calls a function which calls a function...
To solve this, there is Cthulhu.jl. With Cthulhu, it is possible to @descend into a function and see its @code_warntype-style output. Next, the arrow keys and Enter can be used to step into a called function and see the same output for that one. By continuously stepping into and out of functions, it is much easier to see which code is calling what and where exactly type inference starts to fail. Often, by solving a type inference problem at exactly the right spot, inference problems for a whole bunch of functions can be fixed. For more information about Cthulhu, see its GitHub page.
A complementary tool to find the root of type problems is JET.jl. Basically, this tool can automate the process described above. It relies on Julia's compiler and can point to the root of type inference problems. Let's do a demo. Here, we use the optimization analysis:
using JET: @report_opt
@report_opt blocked_outer(2)
═════ 2 possible errors found ═════
┌ blocked_outer(x::Int64) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#e02d6f85-22da-4262-9dad-3911c2222280:1
│ runtime dispatch detected: Main.var"workspace#7".inner(%1::Any)::Any
└────────────────────
┌ blocked_outer(x::Int64) @ Main.var"workspace#7" /builds/rikh/blog/posts/notebooks/inference.jl#==#e02d6f85-22da-4262-9dad-3911c2222280:1
│ runtime dispatch detected: (3 Main.var"workspace#7".:* %2::Any)::Any
└────────────────────
In this case, the tool points out exactly the problem we've had. Namely, because the function definition is 3 * inner(Base.inferencebarrier(x)), the inner function call cannot be optimized, since the type of its argument is unknown at that point. Also, the output type of inner(Base.inferencebarrier(x)) is unknown, so the multiplication by 3 is another runtime dispatch.
For extremely long outputs, it can be useful to write the output of JET to a file so that it is easier to navigate.
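For a lighter-weight check in a package's test suite, the Test standard library provides @inferred, which throws an error when the inferred return type of a call does not match the type of the value it actually returns. Unlike JET, it only looks at the return type of the outermost call, so in this sketch it flags blocked_outer as a whole rather than pointing at the inner runtime dispatch.

```julia
using Test: @inferred

inner(x) = 2 * x
outer(x) = 3 * inner(x)
blocked_outer(x) = 3 * inner(Base.inferencebarrier(x))

# Passes: the inferred return type (Int64) matches the actual one.
@inferred outer(2)

# Throws: inference only knows `Any`, but the call returns an Int64.
failed = try
    @inferred blocked_outer(2)
    false
catch err
    err isa ErrorException
end
```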
These are the most important tools to improve performance. If this is all you care about, then feel free to stop reading here. In the next section, let's take a look at how to reduce the time to first X.
As described above, Julia does lots of optimizations on your code. For example, it removes unnecessary function calls and hardcodes method calls if possible. This takes time, and that is a problem. As mentioned above, Makie runs extremely quickly after the first plot has been created, going from 40 seconds to something like 0.001 seconds. And we need to wait all these seconds every time that we restart Julia. Of course, Julia developers don't develop by changing their plotting code and waiting 40 seconds to see the output. We use tools such as Pluto.jl or Revise.jl to apply code changes without restarting Julia. Still, sometimes it is necessary to restart Julia, so what can we do to reduce the compilation time?
Well, we can reduce the compilation time by shouting "I am the compiler now!" and writing optimized code manually. For example, this is done in OrdinaryDiffEq.jl#1465. In some cases, this can be a great last-resort solution to make some compilation time disappear.
However, it is quite laborious and not suitable in all cases. A very nice alternative is to move the compilation time into the precompilation stage. Precompilation occurs right after package installation or when loading a package after it has been changed. The results of this compilation are retained even after restarting the Julia instance. So, instead of having to compile things after each restart, we only compile when the package changes! Sounds like a good deal.
It is a good deal. Except, we have to note that we're working with the Julia language. Not all functions have typed arguments, let alone concretely typed arguments, so the precompile phase cannot always know what it should compile. Even more, Julia by default doesn't precompile all methods with concretely typed arguments. It simply assumes that some of them will probably not be used, so there is no need to precompile them. This is on purpose: otherwise, developers would be encouraged to put concrete types everywhere, which would make Julia packages less composable. That is a very fair argument.
Anyway, we can fix this by adding precompile directives ourselves. For example, we can create a new function, call precompile on it for integers, and look at the existing method specializations:
begin
add_one(x) = x + 1
precompile(add_one, (Int,))
methods(add_one)[1].specializations
end
MethodInstance for Main.var"workspace#7".add_one(::Int64)
A method specialization is just another way of saying a compiled instance of a method, so a specialization is always for specific argument types. This method specialization shows that add_one is compiled even though we haven't called add_one yet. The function is completely ready for use with integers. If we pass another type, the function still needs to be compiled.
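The same check also works for several signatures. This sketch assumes Julia 1.10 or later, where Base.specializations iterates over the compiled method instances of a method; `add_two` is a hypothetical function for this example.

```julia
add_two(x) = x + 2  # hypothetical function for this sketch

# `precompile` returns true when it could compile the given signature.
ok_int = precompile(add_two, (Int,))
ok_float = precompile(add_two, (Float64,))

m = only(methods(add_two))

# One MethodInstance per compiled signature, without calling add_two.
spec_types = [mi.specTypes for mi in Base.specializations(m)]
```

Each entry has the form `Tuple{typeof(add_two), Int64}`; a signature that was never precompiled or called, such as one with a String argument, does not show up.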
What is nice about this is that precompile will compile everything recursively. So, say, we have a large codebase handling some kind of notebooks, and the package has some kind of open function with concrete argument types, such as a ServerSession to open the notebook into and a String with the path to the notebook. Then, we can add the following precompile directive inside this large codebase:
precompile(open, (ServerSession, String))
Since the open function calls many other functions, this precompile directive will compile many functions and can reduce the time to first X by a lot. This is what happened in Pluto.jl#1934. We added one line of code to reduce the time to first open a notebook from 11 to 8 seconds. That is a reduction of roughly 30% in running time from a single line of code. To figure out where exactly you need to add precompile directives, you can use SnoopCompile.jl.
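A minimal sketch of where such a directive could live. Everything in it (NotebookServerSketch, ServerSession, open_notebook, read_notebook) is hypothetical and is not Pluto's actual code; the function is called open_notebook rather than open to avoid extending Base.open. In a real package, the module would be the package's source, so the directive runs during precompilation and its results are cached; when evaluated in a plain script, as here, precompile simply compiles eagerly.

```julia
# Hypothetical stand-in for a large package.
module NotebookServerSketch

struct ServerSession
    notebooks::Vector{String}
end

# Hypothetical helper that the entry point calls.
read_notebook(path::String) = "contents of " * path

function open_notebook(session::ServerSession, path::String)
    push!(session.notebooks, path)
    return read_notebook(path)
end

# A single directive for the entry point; callees whose argument types
# can be inferred are compiled along with it.
precompile(open_notebook, (ServerSession, String))

end # module
```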
Alas, now you probably wonder why we didn't get a 100% reduction. The answer is type inference. precompile will go through all the functions recursively, but once a type becomes non-concrete, it cannot know what to compile. To fix this, we can use the tools presented above to fix the type inference problems.
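The effect of an inference barrier on precompile can be checked directly, assuming Julia 1.10 or later for Base.specializations. In this sketch (all function names are hypothetical), precompiling an inferable caller also creates a method instance for its callee with an Int argument, whereas precompiling the blocked version does not, because the callee's argument type is unknown to inference.

```julia
inner_a(x) = 2 * x
outer_a(x) = 3 * inner_a(x)

inner_b(x) = 2 * x
blocked_b(x) = 3 * inner_b(Base.inferencebarrier(x))

precompile(outer_a, (Int,))
precompile(blocked_b, (Int,))

# Signatures for which a method instance of `f` exists.
compiled_signatures(f) = [mi.specTypes for mi in Base.specializations(only(methods(f)))]

Tuple{typeof(inner_a), Int} in compiled_signatures(inner_a)  # reached through inference
Tuple{typeof(inner_b), Int} in compiled_signatures(inner_b)  # not reached: argument type unknown
```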
In conclusion, this is what I find so beautiful about the language. You can hack your proof-of-concept together in very naive ways and then throw in a few precompiles if you want to reduce the TTFX. Then, once you need performance, you can pinpoint which method takes the most time, look at the generated code, and start fixing problems such as type inference. Improving inferability will often make code more readable, reduce running time, and reduce time to first X; all at the same time.
Thanks to Michael Helton, Rafael Fourquet and Guillaume Dalle for providing feedback on this blog post.
Built with Julia 1.11.2 and BenchmarkTools 1.5.0