Runtime Code Specialization

Specializing code is nothing new, but I’m surprised that modern programming languages don’t attempt to make this easier.

Just to give you some context: the idea is that whenever you have an algorithm that takes multiple parameters, where one or more of them are fixed for some period of time, you specialize the code for the fixed parameters. Work that depends only on them is then done once, up front, rather than on every call. A key point is that you may not know at compile time what the fixed value is, and the duration for which a value is fixed could be shorter than the execution of the program (mere milliseconds, even).
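A classic small illustration, not from the original post: raising x to an integer power n, where n is only known at run time but then stays fixed. The sketch below uses a hypothetical `SpecializePow` helper built on the same System.Linq.Expressions API used later in this post, and assumes C# 9 top-level statements. The loop over the bits of n runs once, while generating code, so the compiled delegate contains only the multiplications that depend on x.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical helper: specialize x^n for a fixed, non-negative n.
// The bit loop below runs at generation time, not inside the generated code.
static Func<double, double> SpecializePow(int n)
{
    if (n < 0) throw new ArgumentOutOfRangeException(nameof(n));
    var x = Expression.Parameter(typeof(double), "x");
    var temps = new List<ParameterExpression>();
    var steps = new List<Expression>();
    Expression result = null;       // product of the powers selected so far
    ParameterExpression power = x;   // holds x^(2^k) for the current bit k

    while (n > 0)
    {
        if ((n & 1) != 0)
            result = result == null ? (Expression)power : Expression.Multiply(result, power);
        n >>= 1;
        if (n > 0)
        {
            // Square into a fresh temporary so each power is computed only once.
            var squared = Expression.Variable(typeof(double));
            temps.Add(squared);
            steps.Add(Expression.Assign(squared, Expression.Multiply(power, power)));
            power = squared;
        }
    }
    steps.Add(result ?? Expression.Constant(1.0)); // x^0 == 1
    return Expression.Lambda<Func<double, double>>(Expression.Block(temps, steps), x).Compile();
}

var cube = SpecializePow(3);    // generated body: t1 = x * x; x * t1
var pow10 = SpecializePow(10);  // generated body: 3 squarings and 1 multiply
Console.WriteLine(cube(2.0));   // 8
Console.WriteLine(pow10(2.0));  // 1024
```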

One example would be to compile a regular expression into machine code – at runtime – whose function is to take the input string and match it against the (fixed) regular expression. This can be many times faster than an algorithm based on interpreting the regular expression as you go.
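.NET already ships this particular specialization: passing `RegexOptions.Compiled` to the `Regex` constructor asks the runtime to generate IL for the fixed pattern (which the JIT then turns into machine code) instead of interpreting it. A minimal sketch, assuming C# 9 top-level statements, with the glob “f?o.*” from later in this post translated by hand into a regular expression:

```csharp
using System;
using System.Text.RegularExpressions;

// Glob "f?o.*" as a regular expression: '?' -> '.', '.' -> "\.", '*' -> ".*".
const string pattern = @"^f.o\..*$";

// Default: the pattern is parsed once and then interpreted on every match.
var interpreted = new Regex(pattern);

// RegexOptions.Compiled: code specialized for this fixed pattern is generated,
// trading a higher construction cost for faster matching afterwards.
var compiled = new Regex(pattern, RegexOptions.Compiled);

foreach (var s in new[] { "foo.txt", "fxo.cs", "bar.txt" })
    Console.WriteLine($"{s}: {interpreted.IsMatch(s)} {compiled.IsMatch(s)}");
```

Both objects give the same answers; the compiled one only pays off if the pattern stays fixed for enough matches to amortize the generation cost.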

Other examples are algorithms that suffer runtime inefficiencies due to combinatorial explosion, such as converting one image format to another. If you have enough different image formats, it’s infeasible to write one function for each pair of formats, so you would typically write a single conversion function that converts from the source format into some high-precision intermediate format, and then calls another function to convert to the destination format. This can be inefficient because you have to switch on the data format inside the conversion loop. With code specialization you look up the two conversion functions once, and then call them directly (or even inline them, then simplify away the intermediate format when possible).
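A minimal sketch of just the look-up-once step, assuming two hypothetical packed 32-bit formats (`RGBA8` and `BGRA8`), float RGBA as the intermediate format, and C# 9 top-level statements. It hoists the table lookups out of the pixel loop, but the delegates are still called per pixel; inlining them and removing the intermediate format would need something like the expression trees discussed next.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical formats: RGBA8 has R in the top byte, BGRA8 has B in the top byte.
// The intermediate format is float RGBA in [0, 1].
static uint Quantize(float v) => (uint)Math.Round(v * 255f);

var decoders = new Dictionary<string, Func<uint, (float R, float G, float B, float A)>>
{
    ["RGBA8"] = p => (((p >> 24) & 0xFF) / 255f, ((p >> 16) & 0xFF) / 255f,
                      ((p >> 8) & 0xFF) / 255f, (p & 0xFF) / 255f),
    ["BGRA8"] = p => (((p >> 8) & 0xFF) / 255f, ((p >> 16) & 0xFF) / 255f,
                      ((p >> 24) & 0xFF) / 255f, (p & 0xFF) / 255f),
};
var encoders = new Dictionary<string, Func<(float R, float G, float B, float A), uint>>
{
    ["RGBA8"] = c => (Quantize(c.R) << 24) | (Quantize(c.G) << 16) | (Quantize(c.B) << 8) | Quantize(c.A),
    ["BGRA8"] = c => (Quantize(c.B) << 24) | (Quantize(c.G) << 16) | (Quantize(c.R) << 8) | Quantize(c.A),
};

// Specialize a converter for one (source, destination) pair: both lookups
// happen once here instead of once per pixel inside the loop.
Func<uint[], uint[]> MakeConverter(string source, string destination)
{
    var decode = decoders[source];
    var encode = encoders[destination];
    return pixels =>
    {
        var result = new uint[pixels.Length];
        for (int k = 0; k < pixels.Length; ++k)
            result[k] = encode(decode(pixels[k]));
        return result;
    };
}

var rgbaToBgra = MakeConverter("RGBA8", "BGRA8");
Console.WriteLine(rgbaToBgra(new uint[] { 0x11223344 })[0].ToString("X8")); // 33221144
```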

Now, all this can be done manually, and while things haven’t improved as much as I would like, they are getting better. In C# you could always use the System.Reflection.Emit namespace to generate MSIL at runtime and then compile it into a method. This was tedious and low level, but better than generating x86 assembly at least! Starting with .NET 3.5 you can instead use expression trees (the statement-level nodes used below, such as Block, Loop and Assign, were added in .NET 4). This is much better. Not only can the C# compiler convert simple lambda expressions directly into expression trees, but even building them manually is much higher level.
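To make the difference concrete, here is one trivial function, f(x) = x * 2 + 1 (chosen only for illustration), built three ways: with `DynamicMethod` and `ILGenerator` from System.Reflection.Emit, with hand-built expression tree nodes, and by letting the C# compiler quote a lambda into an `Expression<Func<int, int>>`. The sketch assumes C# 9 top-level statements.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection.Emit;

// 1. Reflection.Emit: write the IL one stack instruction at a time.
var method = new DynamicMethod("TimesTwoPlusOne", typeof(int), new[] { typeof(int) });
var il = method.GetILGenerator();
il.Emit(OpCodes.Ldarg_0);   // push x
il.Emit(OpCodes.Ldc_I4_2);  // push 2
il.Emit(OpCodes.Mul);       // x * 2
il.Emit(OpCodes.Ldc_I4_1);  // push 1
il.Emit(OpCodes.Add);       // x * 2 + 1
il.Emit(OpCodes.Ret);
var viaEmit = (Func<int, int>)method.CreateDelegate(typeof(Func<int, int>));

// 2. Expression tree built by hand: nested nodes instead of stack operations.
var x = Expression.Parameter(typeof(int), "x");
var body = Expression.Add(Expression.Multiply(x, Expression.Constant(2)), Expression.Constant(1));
var viaTree = Expression.Lambda<Func<int, int>>(body, x).Compile();

// 3. Expression tree produced by the C# compiler from a lambda.
Expression<Func<int, int>> quoted = y => y * 2 + 1;
var viaLambda = quoted.Compile();

Console.WriteLine($"{viaEmit(5)} {viaTree(5)} {viaLambda(5)}"); // 11 11 11
```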

For example, start with a simple function to match “glob” patterns (e.g. “f?o.*”) like this:

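// Entry point: match the whole test string against the whole pattern
static bool match(string pattern, string testString)
{
    return matchHelper(pattern, testString, 0, 0);
}

// True if testString from index j onward matches pattern from index i onward
static bool matchHelper(string pattern, string testString, int i, int j)
{
    // End of pattern: we match only if we've also reached the end of the string
    if (i == pattern.Length)
        return j == testString.Length;

    switch (pattern[i])
    {
        case '?':
            // Any single character. If j is already past the end, every
            // later check fails, so no out-of-range access can happen.
            return matchHelper(pattern, testString, i+1, j+1);
        case '*':
            // Try every tail of the string, including the empty one
            for (int k = j; k <= testString.Length; ++k)
            {
                if (matchHelper(pattern, testString, i+1, k))
                    return true;
            }
            return false;
        default:
            // A literal character must match exactly
            return j < testString.Length &&
                   pattern[i] == testString[j] &&
                   matchHelper(pattern, testString, i+1, j+1);
    }
}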

We can (manually) convert this into a function that, given expressions for the string to be tested and the current index, generates an expression representing whether or not the pattern matches. We do this by interpreting the pattern at generation time and emitting code that only tests the passed-in string against it. The generated expression doesn’t call any functions; instead we keep traversing the pattern, doing as much of the testing as we can ahead of time, and only put the parts that genuinely depend on the test string into the expression tree:

using System;
using System.Linq.Expressions;
using System.Reflection;

// Members of the enclosing class: reflection handles for string.Length and the string indexer
static readonly PropertyInfo StringLengthProp = typeof(string).GetProperty("Length");
static readonly MethodInfo GetCharsMethod = typeof(string).GetMethod("get_Chars", new[] { typeof(int) });

// Generates an expression of type bool: does str, starting at index j, match pattern[i..]?
static Expression GenerateMatchExpression(string pattern, int i, ParameterExpression str, Expression j)
{
    var strLenExp = Expression.Property(str, StringLengthProp);

    // If the pattern is exhausted, we match if we've reached the end of the string
    if (i == pattern.Length)
        return Expression.Equal(j, strLenExp);

    // '?': skip a single character of the string
    if (pattern[i] == '?')
        return GenerateMatchExpression(pattern, i+1, str, Expression.Increment(j));

    // '*': remove the wildcard, and try to match every tail of the string
    if (pattern[i] == '*')
    {
        var subStringStart = Expression.Variable(typeof(int));
        var breakLabel = Expression.Label(typeof(bool));
        return Expression.Block(typeof(bool),
            new ParameterExpression[] { subStringStart },
            // start at current index, we'll increment this
            // to check all "tails", including empty string
            Expression.Assign(subStringStart, j),
            Expression.Loop(
                Expression.Block(
                    // subStringStart > str.Length: no tails left, so fail
                    Expression.IfThen(Expression.GreaterThan(subStringStart, strLenExp),
                            Expression.Break(breakLabel, Expression.Constant(false))
                    ),

                    // check the tail starting at subStringStart
                    Expression.IfThen(GenerateMatchExpression(pattern, i + 1, str, subStringStart),
                        Expression.Break(breakLabel, Expression.Constant(true))
                    ),
                    Expression.Assign(subStringStart, Expression.Increment(subStringStart))
                ),
                breakLabel
            )
        );
    }

    // Check a single literal character
    return Expression.AndAlso(
                Expression.LessThan(j, strLenExp),
                Expression.AndAlso(
                    Expression.Equal(Expression.Constant(pattern[i]),
                                     Expression.Call(str, GetCharsMethod, j)),
                    GenerateMatchExpression(pattern, i + 1, str, Expression.Increment(j))
                )
            );
}

Then we wrap this expression in a lambda that takes the test string as its only parameter and starts matching at index 0:

var strParam = Expression.Parameter(typeof(string));
var exp = Expression.Lambda<Func<string, bool>>(
                GenerateMatchExpression(pattern, 0, strParam, Expression.Constant(0)), strParam
          );
Func<string, bool> match = exp.Compile();

For example, if we pass in “a*b*c?”, we get an expression tree that looks like this (as shown by its DebugView):

.Lambda #Lambda1(System.String $var1) {
    0 < $var1.Length && 'a' == .Call $var1.get_Chars(0) && .Block(System.Int32 $var2) {
        $var2 = .Increment(0);
        .Loop  {
            .Block() {
                .If ($var2 > $var1.Length) {
                    .Break #Label1 { False }
                } .Else {
                    .Default(System.Void)
                };
                .If (
                    $var2 < $var1.Length && 'b' == .Call $var1.get_Chars($var2) && .Block(System.Int32 $var3) {
                        $var3 = .Increment($var2);
                        .Loop  {
                            .Block() {
                                .If ($var3 > $var1.Length) {
                                    .Break #Label2 { False }
                                } .Else {
                                    .Default(System.Void)
                                };
                                .If (
                                    $var3 < $var1.Length && 'c' == .Call $var1.get_Chars($var3) && .Increment(.Increment($var3)) == $var1.Length
                                ) {
                                    .Break #Label2 { True }
                                } .Else {
                                    .Default(System.Void)
                                };
                                $var3 = .Increment($var3)
                            }
                        }
                        .LabelTarget #Label2:
                    }
                ) {
                    .Break #Label1 { True }
                } .Else {
                    .Default(System.Void)
                };
                $var2 = .Increment($var2)
            }
        }
        .LabelTarget #Label1:
    }
}

Note how we get one loop generated for each ‘*’ character, in a nested fashion, and how the entire thing is inlined and specialized. This should compile to pretty efficient code. Of course it could be even more efficient if we first generated a DFA from this and then compiled that, but the point here is code specialization, not finite automata.

Now, this works, but holy hell that took a lot of work to get right, and this was for a fairly trivial example. Imagine if this was something more complicated! Imagine if you had to maintain this! The horror!

So, what do we actually need? We need a language that allows you to specialize (almost) arbitrary functions for some of their parameters, returning a new function with all the constants folded away, virtual functions inlined, etc. If this were as easy as writing down a few simple attributes, I suspect such a language could achieve better performance than even native languages like C++, simply because it’s infeasible to do this kind of code specialization on a wide scale without language support. You could potentially write a C# function that does this, i.e. create an expression tree by wrapping the code to be specialized in a lambda, then pass this expression to a function that rewrites the resulting tree given that some of the parameters are constant. Perhaps the JIT compiler can even be convinced to do some of the optimizations for you once it knows that some of the inputs are constants, but more likely than not you have to manually propagate the fact that these things are truly fixed through the expression tree.
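A rough sketch of that idea, using a hypothetical `Specialize` helper (not an existing library API) and assuming C# 9 top-level statements. It binds the first parameter of a two-parameter lambda to a constant, evaluates any binary operator whose operands both become constant, and resolves conditionals whose test becomes constant. It deliberately handles only those node kinds; anything else is left untouched, which is the "manually propagate" problem in miniature.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical helper: fix the first parameter of f to `value`, substitute it
// as a constant, and fold operators and conditionals that became constant.
static Expression<Func<T2, TR>> Specialize<T1, T2, TR>(Expression<Func<T1, T2, TR>> f, T1 value)
{
    var fixedParam = f.Parameters[0];
    var constant = Expression.Constant(value, typeof(T1));

    Expression Fold(Expression e)
    {
        switch (e)
        {
            case ParameterExpression p when p == fixedParam:
                return constant;
            case BinaryExpression b:
                var left = Fold(b.Left);
                var right = Fold(b.Right);
                var rebuilt = b.Update(left, b.Conversion, right);
                return left is ConstantExpression && right is ConstantExpression ? Evaluate(rebuilt) : rebuilt;
            case ConditionalExpression c:
                var test = Fold(c.Test);
                if (test is ConstantExpression known)
                    return Fold((bool)known.Value ? c.IfTrue : c.IfFalse); // other branch disappears
                return c.Update(test, Fold(c.IfTrue), Fold(c.IfFalse));
            default:
                return e; // other node kinds are not folded (still correct, see the Block below)
        }
    }

    // Run a closed subexpression once, now, and keep only its result.
    static Expression Evaluate(Expression e) =>
        Expression.Constant(Expression.Lambda(e).Compile().DynamicInvoke(), e.Type);

    // Binding fixedParam as a block variable keeps any occurrence Fold did not
    // reach (e.g. inside a method call) well-defined.
    var body = Expression.Block(new[] { fixedParam }, Expression.Assign(fixedParam, constant), Fold(f.Body));
    return Expression.Lambda<Func<T2, TR>>(body, f.Parameters[1]);
}

Expression<Func<int, int, int>> scale = (mode, x) => mode == 0 ? x : mode == 1 ? x * 2 : x * x;
var doubler = Specialize(scale, 1);        // both conditionals are resolved at specialization time
Console.WriteLine(doubler.Compile()(21));  // 42
```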

This probably isn’t as simple as I make it out to be. For example, you may want to specialize an animation function on the structure of the blend tree, while allowing the specific blend values of each node to vary each time you invoke it. This could be done by having the programmer manually separate these two pieces of data and only specializing on the first, but ideally you could mix “fixed” data with “varying” data and have specialization “boil away” only the fixed parts, while keeping references to the varying ones.
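One way to get that mix, sketched here with a hypothetical flat-array blend tree and assuming C# 9 top-level statements: walk the fixed structure once while generating an expression tree, and leave the varying values (clip samples and blend weights) as array reads in the generated code. Each clip sample is simplified to a single float. The recursion in `Build` terminates only because the tree is acyclic, a guarantee the caller has to provide.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical blend tree as a flat array. A child index >= 0 names another
// node; a negative child index c names clip number ~c (that is, -c - 1).
// Topology and weight slots are the fixed data; samples and weights vary.
var tree = new (int Left, int Right, int WeightSlot)[]
{
    (1, ~2, 0),   // node 0: blend(node 1, clip 2) by weights[0]
    (~0, ~1, 1),  // node 1: blend(clip 0, clip 1) by weights[1]
};

var clips = Expression.Parameter(typeof(float[]), "clips");
var weights = Expression.Parameter(typeof(float[]), "weights");

// Walk the fixed tree once: its structure becomes the shape of the expression,
// while varying values remain reads from the two array parameters.
Expression Build(int node)
{
    if (node < 0)
        return Expression.ArrayIndex(clips, Expression.Constant(~node));
    var (left, right, slot) = tree[node];
    var t = Expression.ArrayIndex(weights, Expression.Constant(slot));
    // a * (1 - t) + b * t, so each child subtree appears exactly once
    return Expression.Add(
        Expression.Multiply(Build(left), Expression.Subtract(Expression.Constant(1f), t)),
        Expression.Multiply(Build(right), t));
}

var blend = Expression.Lambda<Func<float[], float[], float>>(Build(0), clips, weights).Compile();
Console.WriteLine(blend(new[] { 0f, 10f, 100f }, new[] { 0.5f, 0.25f }));
```

The compiled delegate can then be invoked every frame with new samples and weights, and it only needs to be regenerated when the tree's structure changes.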

My contention is that client-facing apps have a lot of code that could benefit from this sort of strategy. Think about something like a game engine running at 60 Hz: it’s amazing just how much time you spend doing the exact same thing over and over again every frame. The only things that differ from frame to frame are the specific integers and floats that get fed to the ALU; most branches and vtable lookups go the same way every time. Being able to flag things as constant for some duration of time, and so avoid all that “interpretation overhead”, would be a huge win in a bunch of cases.

The next step would be to have a runtime that automatically traces the actual execution and specializes things for you as you go, much like JavaScript JIT compilers do. However, I’m talking about applying this to highly optimized, ahead-of-time compiled programs, so the number of things the tracer would have to look out for would presumably be smaller than for a dynamic language (e.g. we’re not looking to discover that a variable is an integer; we know that already). We’re basically just looking to do some constant folding, some vtable lookups (and subsequent inlining) and stuff like that. Perhaps the static compiler could flag where runtime tracing might be profitable.

I kinda like the idea of having the programmer explicitly define specializations, though, because it means you can put the onus on the programmer to ensure certain things that a compiler wouldn’t want to assume, like termination of recursive functions. That means e.g. a tree traversal function could be specialized with each recursive call being blindly inlined. The programmer would be responsible for proving that, taking the specified parameters as constants, the recursion will eventually stop.

Are there any languages out there that do this?
