Writing JIT Compiler on the Plane

Feel free to join the discussion on HackerNews.

Compilers always seemed a little bit like magic to me. You write code in some language and the compiler spits out machine code that a small crystal inside your computer understands.

Some databases ship with a specialized compiler inside them. This compiler can generate machine code for parts of the query that are executed often. This technique is called JIT compilation and can greatly reduce the interpretation overhead during query execution. The same technique is used in some virtual machines (for instance, JVM and Dart).

I have read some details about how JIT compilers are implemented, but I have never written my own. Recently, I had quite a long flight and decided to see whether I could write a small JIT compiler without any Internet access.

I am using a MacBook with a 2.3 GHz Quad-Core Intel Core i7 and macOS 12.3.

Problem statement

We will imagine ourselves as database developers. Our customers often complain that their queries with arithmetic formulas run really slow, for example:

SELECT (MyColumn + 10) * 5 FROM MyTable

After much profiling, you figure out that 90% of the execution time is spent computing the formula (x + 10) * 5 for each row in the table. Having tried everything else, you decide to JIT compile simple arithmetic formulas with one integer column.

Unfortunately, the evil sharks in the Pacific Ocean chew some cables and you no longer have Internet access! You have no other choice but to use just your memory, common sense and man pages.

Learning from the best

I had a university course covering Linux basics and assembly programming. For example, I wrote a function searching for a value in a binary tree using pure assembly. Understandably, my memory quickly erased such a horrifying experience, leaving only a general understanding of some basics.

Thankfully, I have a buddy to learn from: the clang compiler. We can compile a simple function with no optimizations and see what kind of assembly is generated. Let’s try that!

int getNumber() {
    return 42;
}

int main() {
    return 0;
}

We compile the code above with clang++ -std=c++17 test.cpp -o test and disassemble the generated code with objdump -d test. clang is able to output assembly directly, but I personally find the output of objdump slightly easier to read. In the relevant section, we see:

<__Z9getNumberv>:
55              pushq	%rbp
48 89 e5        movq	%rsp, %rbp
b8 2a 00 00 00  movl	$42, %eax
5d              popq	%rbp
c3              retq

Wonderful, we have the assembly code for a function that always returns the number 42. On the right you can see human-readable assembly instructions and on the left there is the corresponding machine code in hex format. Let us try to write a simple JIT compiler that always generates the machine code we have just read.
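Looking closer at the bytes `b8 2a 00 00 00`, the constant 42 (`0x2a`) appears to be stored right after the `b8` opcode as a 32-bit value, least significant byte first. A minimal sketch of that encoding, where the helper name `encodeMovlToEAX` is made up for illustration:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: encodes `movl $value, %eax` as opcode 0xb8 followed
// by the 32-bit immediate in little-endian byte order.
std::vector<uint8_t> encodeMovlToEAX(uint32_t value) {
    std::vector<uint8_t> bytes;
    bytes.push_back(0xb8);
    for (int shift = 0; shift < 32; shift += 8) {
        // Take the next 8 bits, starting from the least significant byte.
        bytes.push_back(static_cast<uint8_t>((value >> shift) & 0xff));
    }
    return bytes;
}
```

Calling `encodeMovlToEAX(42)` reproduces the five bytes from the objdump output above.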

Dynamically generated code

The question is: how do we turn the machine code above into a function we can call from our code? In C++, there is a thing called reinterpret_cast. It allows one to convert from one type to another without performing any checks. We will use it to convert an array with machine code into a pointer to a callable function.

#include <cstdint>
#include <iostream>

int main() {
    // Array with machine code generated by clang.
    uint8_t functionCode[] = {
        0x55,
        0x48, 0x89, 0xe5,
        0xb8, 0x2a, 0x00, 0x00, 0x00,
        0x5d,
        0xc3
    };

    // C++ type meaning "a function that accepts 0 arguments and returns
    // an integer".
    using Type = int();

    // We convert a pointer to the array with machine code into a pointer to
    // the function.
    auto function = reinterpret_cast<Type*>(functionCode);

    // Call the function and print the result.
    int result = function();
    std::cout << result << std::endl;

    return 0;
}

This program crashes with a segmentation fault (of course it does). A quick investigation with a debugger reveals an EXC_BAD_ACCESS error on the first instruction of functionCode. A distant memory comes to my mind: all memory has one or more “modes” (read, write and execute), much like files on the filesystem. It totally makes sense that by default we can read from and write to memory, but execution is forbidden for security reasons.

To mark some memory as executable, we will need to allocate it ourselves. Usually, the mmap system call is used for that, and its man page is most helpful. Let us try the code above again, but now execute it from memory allocated with mmap:

#include <iostream>
#include <cstdint>
#include <cstdio>
#include <cstring> // memcpy, strerror

#include <sys/mman.h>
#include <sys/errno.h>

int main() {
    // Allocate memory which can be executed from.
    void* executableMemory = mmap(
        nullptr /* addr */,
        // We do not actually need 4096 bytes, but `mmap` can only
        // allocate `N * PAGE_SIZE` bytes at a time and `PAGE_SIZE` equals
        // to 4096 on my system.
        4096 /* len */,
        // Specify that we want to write and execute from this memory.
        PROT_WRITE | PROT_EXEC /* prot */,
        // `MAP_JIT` is a macOS specific flag. Man page says that this flag is
        // required in hardened runtime if `PROT_WRITE` and `PROT_EXEC` are
        // specified. I have no idea what is hardened runtime, so I decided to
        // add it for the good measure.
        MAP_PRIVATE | MAP_ANONYMOUS | MAP_JIT /* flags */,
        -1 /* fd */,
        0 /* offset */
    );
    if (executableMemory == MAP_FAILED) {
        std::cout << "mmap failed: " << strerror(errno) << std::endl;
        return -1;
    }

    uint8_t functionCode[] = {
        0x55,
        0x48, 0x89, 0xe5,
        0xb8, 0x2a, 0x00, 0x00, 0x00,
        0x5d,
        0xc3
    };

    // Copy our code into the executable memory.
    memcpy(executableMemory, functionCode, sizeof(functionCode));

    using Type = int();
    // This time, `function` refers to `executableMemory`, not `functionCode`.
    auto function = reinterpret_cast<Type*>(executableMemory);

    int result = function();
    std::cout << result << std::endl;

    // `munmap` takes the length of the mapping as its second argument.
    munmap(executableMemory, 4096);

    return 0;
}

This code prints 42! I was genuinely surprised when it worked. Now, let us try to generate some assembly ourselves, but still with a little help from clang.
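As an aside on the memory “modes” mentioned earlier: the program above asks for write and execute permissions at the same time. A more restrictive variation is to make the page writable first and flip it to executable with `mprotect` once the code is copied. The sketch below is only an illustration: the helper name `copyToExecutablePage` and the constant `kPageSize` are made up, and I have not checked how this behaves under the macOS hardened runtime, where `MAP_JIT` may still be needed.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

#include <sys/mman.h>

// Assumption: one 4096-byte page, as in the program above.
constexpr std::size_t kPageSize = 4096;

// Hypothetical helper: copies `code` into a fresh page that is writable
// first and executable afterwards, but never both at once. Returns nullptr
// on failure. The caller releases the page with `munmap(ptr, kPageSize)`.
void* copyToExecutablePage(const std::vector<uint8_t>& code) {
    if (code.empty() || code.size() > kPageSize) {
        return nullptr;
    }
    void* page = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
                      MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (page == MAP_FAILED) {
        return nullptr;
    }
    std::memcpy(page, code.data(), code.size());
    // Drop the write permission and add the execute permission.
    if (mprotect(page, kPageSize, PROT_READ | PROT_EXEC) != 0) {
        munmap(page, kPageSize);
        return nullptr;
    }
    return page;
}
```

The returned pointer can then be passed to `reinterpret_cast` exactly like `executableMemory` above.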

Manually generating assembly

First, let us refactor our code a little bit. We will move all the code generation into a Compiler struct and replace the static machine code array with a dynamic one:

struct Compiler {
    void emit(uint8_t inst) {
        code.push_back(inst);
    }

    void emit(std::vector<uint8_t> instrs) {
        for (const auto& instr : instrs) {
            emit(instr);
        }
    }

    void compile() {
        // This is the same assembly that `clang` generated before.
        emit(0x55);
        emit({0x48, 0x89, 0xe5});
        emit({0xb8, 0x2a, 0x00, 0x00, 0x00});
        emit(0x5d);
        emit(0xc3);
    }

    std::vector<uint8_t> code;
};

In the main function we now use Compiler instead of static array:

Compiler compiler;
compiler.compile();

memcpy(executableMemory, compiler.code.data(), compiler.code.size());

One part I remember from my university is that each function starts with a “prologue” and ends with an “epilogue”. Looking at the assembly above:

// This is the prologue.
pushq	%rbp
movq	%rsp, %rbp

// This is the epilogue.
popq	%rbp
retq

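The prologue saves the caller's frame pointer (%rbp) on the stack and sets up a new one for our function. The epilogue restores the caller's frame pointer and returns to the caller, whose return address was placed on the stack by the call instruction. Let us extract them into separate methods of our Compiler struct: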

void emitPrologue() {
    emit(0x55); // pushq %rbp
    emit({0x48, 0x89, 0xe5}); // movq %rsp, %rbp
}

void emitEpilogue() {
    emit(0x5d); // popq %rbp
    emit(0xc3); // retq
}

void compile() {
    emitPrologue();

    emit({0xb8, 0x2a, 0x00, 0x00, 0x00}); // movl $42, %eax

    emitEpilogue();
}

Defining the language

Before we actually JIT compile any expression, we need to define our expression language. The problem is, there is only so much you can understand about an instruction set for the CPU without proper documentation. So we will limit ourselves to a very simple expression language which looks something like this:

((MyColumn + 1) * 2) - 4

The limitations are the following:

  • There are only 3 kinds of operations (+, -, *)
  • The input variable MyColumn has type uint8 and can be used only once
  • First argument to the operation is either MyColumn or the result of the child expression
  • Second argument to the operation is always a constant

We will assume that there is already a parser for this expression language created for us, so we are going to just define the syntax tree:

enum class OpType {
    Addition,
    Subtraction,
    Multiplication,
};

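struct Node {
    OpType type;
    // Note the naming: `rhs` holds the child expression, which is the first
    // argument of the operation (nullptr means MyColumn itself), and `lhs`
    // holds the constant, which is the second argument.
    std::unique_ptr<Node> rhs;
    uint8_t lhs;

    Node(OpType type, std::unique_ptr<Node> rhs, uint8_t lhs)
        : type(type), rhs(std::move(rhs)), lhs(lhs)
    {}
};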

Then the example expression ((MyColumn + 1) * 2) - 4 can be created like this:

auto root = std::make_unique<Node>( // (...) - 4
    OpType::Subtraction,
    std::make_unique<Node>( // (...) * 2
        OpType::Multiplication,
        std::make_unique<Node>(OpType::Addition, nullptr, 1), // (MyColumn + 1)
        2
    ),
    4
);
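Before compiling anything, it can help to have a plain interpreter for this tree to compare the JIT output against. The following is a sketch under the limitations listed above; `evaluate` and `makeExample` are names invented for this illustration, and the `Node` definition is repeated only to keep the snippet self-contained.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>

enum class OpType { Addition, Subtraction, Multiplication };

struct Node {
    OpType type;
    std::unique_ptr<Node> rhs; // child expression; nullptr means MyColumn
    uint8_t lhs;               // constant operand

    Node(OpType type, std::unique_ptr<Node> rhs, uint8_t lhs)
        : type(type), rhs(std::move(rhs)), lhs(lhs) {}
};

// Hypothetical reference interpreter: walks the tree for every input value.
int evaluate(const Node* node, int myColumn) {
    assert(node != nullptr);
    // The first operand is either the child expression or MyColumn itself.
    int value = node->rhs ? evaluate(node->rhs.get(), myColumn) : myColumn;
    switch (node->type) {
        case OpType::Addition:       return value + node->lhs;
        case OpType::Subtraction:    return value - node->lhs;
        case OpType::Multiplication: return value * node->lhs;
    }
    return value; // not reached for valid OpType values
}

// Builds the example expression ((MyColumn + 1) * 2) - 4.
std::unique_ptr<Node> makeExample() {
    auto add = std::make_unique<Node>(OpType::Addition, nullptr, 1);
    auto mul = std::make_unique<Node>(OpType::Multiplication, std::move(add), 2);
    return std::make_unique<Node>(OpType::Subtraction, std::move(mul), 4);
}
```

This tree walk, repeated for each row, is the kind of interpretation overhead a JIT compiler is meant to remove.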

Compiling the expressions

Now that we have our language “defined”, let us try to compile the expressions.

As you probably already know, assembly is a very low-level language. It works in terms of instructions, each performing a very simple computation. Instructions have access to the stack (which is a contiguous piece of memory the program can use) and a fixed number of registers (you can think of them as small variables). We will need to somehow translate our expression into these terms.

One of the ways to do that is to use register %eax as our “intermediate result” variable. Then ((MyColumn + 1) * 2) - 4 compiled into pseudo-assembly will look something like this:

%eax = MyColumn
%eax = %eax + 1
%eax = %eax * 2
%eax = %eax - 4
return %eax

One can notice how each line in this pseudo-assembly is a very simple operation working with just 1 register.
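Because every operation takes the previous result and one constant, the whole expression flattens into a straight list of steps. As a rough sketch, the pseudo-assembly above could be produced like this; the `Step` struct and `toPseudoAssembly` are hypothetical names, not part of the compiler built in this post.

```cpp
#include <string>
#include <vector>

// Hypothetical flattened form of one operation: an operator character
// ('+', '-' or '*') and the constant it is applied with.
struct Step {
    char op;
    int constant;
};

// Produces one pseudo-assembly line per step, all working on %eax.
std::vector<std::string> toPseudoAssembly(const std::vector<Step>& steps) {
    std::vector<std::string> lines = {"%eax = MyColumn"};
    for (const Step& step : steps) {
        lines.push_back(std::string("%eax = %eax ") + step.op + " " +
                        std::to_string(step.constant));
    }
    lines.push_back("return %eax");
    return lines;
}
```

For the steps `{'+', 1}, {'*', 2}, {'-', 4}` this yields the five lines shown above.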

In the same way we did before, we will figure out the exact assembly with the help of clang by feeding various sample programs to it. For example, the following program:

#include <iostream>

int readNumber() {
    int input;
    std::cin >> input;
    return input;
}

int getNumber(int a) {
    return a + 5;
}

int main() {
    return getNumber(readNumber());
}

When compiled using -O2 optimization level produces the following line of assembly for the a + 5 expression:

83 c0 05  addl	$5, %eax

We can clearly see our “variable” a as a register %eax and our constant 5 as the last hex component on the left.

By replacing constant 5 in our code with 10, we can see how the respective assembly line changes:

83 c0 0a  addl	$10, %eax

Notice how the last component of the hex on the left changed from 05 to 0a. This strongly suggests that 83 c0 part is “perform addition of %eax and some constant and store the result in %eax”, whereas the last part is the actual constant in question.

Applying the same technique to other operations, we can see that:

  • 83 c0 $constant = addl $constant, %eax
  • 6b c0 $constant = imull $constant, %eax, %eax (multiply %eax with $constant and store the result in %eax)
  • 83 e8 $constant = subl $constant, %eax (subtraction, but the arguments are the same as with addition)
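With documentation at hand, the second byte in these encodings turns out to be the so-called ModRM byte, which packs three fields: a 2-bit mode, a 3-bit register (or operation selector) and a 3-bit operand. As far as I can tell from the Intel manual, for opcode 83 the middle field selects the operation (0 for add, 5 for sub), which would explain why addition and subtraction differ only in that byte. A small sketch, with `makeModRM` being a name made up here:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: packs the ModRM fields as mod (2 bits), reg (3 bits)
// and rm (3 bits), from the most significant bit to the least significant.
uint8_t makeModRM(uint8_t mod, uint8_t reg, uint8_t rm) {
    assert(mod < 4 && reg < 8 && rm < 8);
    return static_cast<uint8_t>((mod << 6) | (reg << 3) | rm);
}
```

Here `makeModRM(3, 0, 0)` gives `0xc0` (register mode, operation 0, %eax) and `makeModRM(3, 5, 0)` gives `0xe8` (register mode, operation 5, %eax).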

Let us add all three instructions to our compiler:

struct Compiler {
    void addlToEAX(uint8_t constant) {
        emit({0x83, 0xc0, constant}); // addl	$constant, %eax
    }

    void sublToEAX(uint8_t constant) {
        emit({0x83, 0xe8, constant}); // subl	$constant, %eax
    }

    void imullFromEAXToEAX(uint8_t constant) {
        emit({0x6b, 0xc0, constant}); // imull	$constant, %eax, %eax
    }
};
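One caveat worth knowing: as far as I understand the instruction set, the one-byte constant in the `83 c0` form is sign-extended, so a constant such as 200 (`0xc8`) would be treated as -56. A possible workaround is to fall back to the longer `05` form of `addl`, which takes a 32-bit constant. The sketch below covers addition only, and `encodeAddlToEAX` is a hypothetical name. Subtraction (`2d`) and multiplication (`69 c0`) should have similar 32-bit forms, but I have not exercised them here.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: encodes `addl $constant, %eax` so that the constant
// is always treated as an unsigned value in the range 0..255.
std::vector<uint8_t> encodeAddlToEAX(uint8_t constant) {
    if (constant <= 127) {
        // Short form: the 8-bit immediate is sign-extended, which is
        // harmless while the top bit is clear.
        return {0x83, 0xc0, constant};
    }
    // Long form: 32-bit little-endian immediate, upper bytes are zero.
    return {0x05, constant, 0x00, 0x00, 0x00};
}
```

For example, `encodeAddlToEAX(5)` keeps the three-byte form, while `encodeAddlToEAX(200)` produces `05 c8 00 00 00`.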

Now we have some instructions that perform simple arithmetic on our “intermediate result” variable, the %eax register. But how do we actually get the initial value of MyColumn into this register? This is also a part which is hard to do without the documentation. There are (at least) two ways an argument can be passed to a function: on the stack and through a register. Depending on the compiler and the optimization level you choose, it can be the stack, a register, or both. In the absence of any documentation, we will assume that our compiler always passes the argument to the function through the %edi register. Then, we can reuse some of the assembly from clang to actually move it (indirectly) to the %eax register:

// Move the function argument from the %edi register to the stack.
89 7d fc                    	movl	%edi, -4(%rbp)

// Move the function argument from the stack to the %eax register.
8b 45 fc                    	movl	-4(%rbp), %eax

This is, of course, not a reliable solution. Switch the compiler, change the optimization parameters - it will break. But it is good enough for us as Software Engineers without Internet writing a compiler for a language with 3 operations. We are going to add these two instructions right after the emitPrologue() call:

emitPrologue();

// Move from %edi to %eax.
emit({0x89, 0x7d, 0xfc});
emit({0x8b, 0x45, 0xfc});
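The round trip through the stack is what clang emits without optimizations. As far as I know, passing the first integer argument in %edi matches the System V AMD64 calling convention used on macOS and Linux, and the value can also be copied with a single register-to-register move. A sketch comparing the two byte sequences; both function names are invented for this illustration:

```cpp
#include <cstdint>
#include <vector>

// The two-instruction sequence taken from the unoptimized clang output.
std::vector<uint8_t> loadArgumentViaStack() {
    return {0x89, 0x7d, 0xfc,  // movl %edi, -4(%rbp)
            0x8b, 0x45, 0xfc}; // movl -4(%rbp), %eax
}

// Hypothetical shorter alternative: copy %edi straight into %eax.
std::vector<uint8_t> loadArgumentDirect() {
    return {0x89, 0xf8};       // movl %edi, %eax
}
```

Either sequence leaves the argument in %eax, so the arithmetic instructions that follow stay the same.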

Finally, let us make the Compiler::compile() method accept the expression syntax tree and perform the compilation:

void compile(const Node* root) {
    emitPrologue();

    // Move from %edi to %eax.
    emit({0x89, 0x7d, 0xfc});
    emit({0x8b, 0x45, 0xfc});

    // Recursively compile the expression.
    doCompile(root);

    emitEpilogue();
}

void doCompile(const Node* root) {
    // If we have a child expression, compile it first.
    if (root->rhs) {
        doCompile(root->rhs.get());
    }

    // Call the appropriate code generation function depending
    // on the type.
    switch (root->type) {
        case OpType::Addition:
            addlToEAX(root->lhs);
            break;
        case OpType::Multiplication:
            imullFromEAXToEAX(root->lhs);
            break;
        case OpType::Subtraction:
            sublToEAX(root->lhs);
            break;
    }
}

After that we pass our expression syntax tree to the compiler:

compiler.compile(root.get());

...

int result = function(3);
std::cout << result << std::endl;

And the program outputs 4, which is the expected result for the expression ((MyColumn + 1) * 2) - 4 and the value MyColumn = 3. Huzzah! We have just created a JIT compiler without any Internet access on the flight across Europe. Customers are happy, our beloved database is saved.
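To double-check the generated code without executing it, one option is to dump the bytes as hex and compare them with objdump-style output. The following is a condensed, self-contained sketch of the compiler from this post; `appendOps`, `compileToBytes`, `toHex` and `compileExampleToHex` are names chosen for this illustration only.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>

enum class OpType { Addition, Subtraction, Multiplication };

struct Node {
    OpType type;
    std::unique_ptr<Node> rhs; // child expression; nullptr means MyColumn
    uint8_t lhs;               // constant operand
    Node(OpType type, std::unique_ptr<Node> rhs, uint8_t lhs)
        : type(type), rhs(std::move(rhs)), lhs(lhs) {}
};

// Appends the arithmetic instructions, innermost expression first.
void appendOps(const Node* node, std::vector<uint8_t>& code) {
    if (node->rhs) appendOps(node->rhs.get(), code);
    switch (node->type) {
        case OpType::Addition:       code.insert(code.end(), {0x83, 0xc0, node->lhs}); break;
        case OpType::Subtraction:    code.insert(code.end(), {0x83, 0xe8, node->lhs}); break;
        case OpType::Multiplication: code.insert(code.end(), {0x6b, 0xc0, node->lhs}); break;
    }
}

std::vector<uint8_t> compileToBytes(const Node* root) {
    std::vector<uint8_t> code = {0x55, 0x48, 0x89, 0xe5, // prologue
                                 0x89, 0x7d, 0xfc,       // movl %edi, -4(%rbp)
                                 0x8b, 0x45, 0xfc};      // movl -4(%rbp), %eax
    appendOps(root, code);
    code.push_back(0x5d); // popq %rbp
    code.push_back(0xc3); // retq
    return code;
}

// Formats the bytes as space-separated lowercase hex pairs.
std::string toHex(const std::vector<uint8_t>& code) {
    std::string out;
    char buffer[4];
    for (std::size_t i = 0; i < code.size(); ++i) {
        std::snprintf(buffer, sizeof(buffer), "%02x", static_cast<unsigned>(code[i]));
        if (i > 0) out += ' ';
        out += buffer;
    }
    return out;
}

// Compiles ((MyColumn + 1) * 2) - 4 and returns the hex dump.
std::string compileExampleToHex() {
    auto add = std::make_unique<Node>(OpType::Addition, nullptr, 1);
    auto mul = std::make_unique<Node>(OpType::Multiplication, std::move(add), 2);
    auto root = std::make_unique<Node>(OpType::Subtraction, std::move(mul), 4);
    return toHex(compileToBytes(root.get()));
}
```

For the example expression the dump reads `55 48 89 e5 89 7d fc 8b 45 fc 83 c0 01 6b c0 02 83 e8 04 5d c3`, which lines up instruction by instruction with the pieces collected throughout the post.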

Conclusion

Of course, the compiler above has a lot of problems. It supports a very limited expression language, it works on uint8 type only, it does not handle negative numbers or overflows like it should. But since this was a purely educational exercise, the fact that it works is the most important one.

If you are interested in further reading, I recommend the following resources:

Thank you for reading!
