Skip to content

Commit 5ffc50d

Browse files
uuuvn and geohot authored
Clang JIT (tinygrad#8481)
Co-authored-by: George Hotz <[email protected]>
1 parent 12fa434 commit 5ffc50d

File tree

8 files changed

+98
-21
lines changed

8 files changed

+98
-21
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ jobs:
2727
python-version: 3.12
2828
- name: Install docs dependencies (no cache)
2929
run: pip install -e '.[docs]'
30+
- name: Install capstone for CLANG disassembly
31+
run: pip install capstone
3032
- name: Use as an external package
3133
run: |
3234
mkdir $HOME/test_external_dir

docs/abstractions2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
print("******** first, the runtime ***********")
99

10-
from tinygrad.runtime.ops_clang import ClangProgram, ClangCompiler, MallocAllocator
10+
from tinygrad.runtime.ops_clang import ClangJITCompiler, MallocAllocator, CPUProgram
1111

1212
# allocate some buffers
1313
out = MallocAllocator.alloc(4)
@@ -19,10 +19,10 @@
1919
MallocAllocator._copyin(b, memoryview(bytearray([3,0,0,0])))
2020

2121
# compile a program to a binary
22-
lib = ClangCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
22+
lib = ClangJITCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
2323

24-
# create a runtime for the program (ctypes.CDLL)
25-
fxn = ClangProgram("add", lib)
24+
# create a runtime for the program
25+
fxn = CPUProgram("add", lib)
2626

2727
# run the program
2828
fxn(out, a, b)
@@ -65,7 +65,7 @@
6565
# compile a program (and print the source)
6666
fxn = CompiledRunner(kernel.to_program())
6767
print(fxn.p.src)
68-
# NOTE: fxn.clprg is the ClangProgram
68+
# NOTE: fxn.clprg is the CPUProgram
6969

7070
# run the program
7171
fxn.exec([out, a, b])

docs/developer/runtime.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ The `Allocator` class is responsible for managing memory on the device. There is
3636

3737
### Program
3838

39-
The `Program` class is created for each loaded program. It is responsible for compiling and executing the program on the device. As an example, here is a `ClangProgram` implementation which loads program and runs it.
39+
The `Program` class is created for each loaded program. It is responsible for executing the program on the device. As an example, here is a `CPUProgram` implementation which loads a program and runs it.
4040

41-
::: tinygrad.runtime.ops_clang.ClangProgram
41+
::: tinygrad.runtime.ops_clang.CPUProgram
4242
options:
4343
members: true
4444

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
"hypothesis",
5959
"nibabel",
6060
"bottle",
61-
"ggml-python"
61+
"ggml-python",
62+
"capstone"
6263
],
6364
'webgpu': ["wgpu"],
6465
'docs': [

tinygrad/helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint
4747
def hi32(x:Any) -> Any: return x >> 32 # Any is sint
4848
def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
4949
def data64_le(data:Any) -> tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
50+
def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << end-start+1) - 1)
51+
def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
5052
def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
5153
kvs = set([(k,v) for d in ds for k,v in d.items()])
5254
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
@@ -267,7 +269,7 @@ def cpu_objdump(lib, objdump_tool='objdump'):
267269
def from_mv(mv:memoryview, to_type=ctypes.c_char):
268270
return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
269271
def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
270-
def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
272+
def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
271273
def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
272274
return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
273275
@functools.lru_cache(maxsize=None)

tinygrad/renderer/cstyle.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,10 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
190190
'#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
191191
'#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
192192
]
193-
prefix += [f"""{(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
193+
# 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
194+
# to just jump to the start of the shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
195+
# wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
196+
prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
194197
AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
195198
AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
196199
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501

tinygrad/runtime/ops_clang.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
import ctypes, subprocess, pathlib, tempfile
1+
import ctypes, ctypes.util, struct, platform, tempfile, pathlib, subprocess
2+
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
3+
from tinygrad.helpers import OSX, mv_address, cpu_time_execution, cpu_objdump
24
from tinygrad.device import Compiled, Compiler, MallocAllocator
3-
from tinygrad.helpers import cpu_time_execution, cpu_objdump
5+
from tinygrad.runtime.support.elf import elf_loader, relocate
46
from tinygrad.renderer.cstyle import ClangRenderer
57

8+
# NOTE: MAP_JIT is added to mmap module in python 3.13
9+
MAP_JIT = 0x0800
10+
11+
# Used by ops_dsp.py
612
class ClangCompiler(Compiler):
713
def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
814
self.args = ['-march=native'] if args is None else args
@@ -18,15 +24,60 @@ def compile(self, src:str) -> bytes:
1824

1925
def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)
2026

21-
class ClangProgram:
27+
class ClangJITCompiler(Compiler):
28+
def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
29+
30+
def compile(self, src:str) -> bytes:
31+
# -fno-math-errno is required for __builtin_sqrt to become an instruction instead of a function call
32+
# x18 is a reserved platform register. It is clobbered on context switch in macos and is used to store TEB pointer in windows on arm, don't use it
33+
args = ['-march=native', f'--target={platform.machine()}-none-unknown-elf', '-O2', '-fPIC', '-ffreestanding', '-fno-math-errno', '-nostdlib']
34+
arch_args = ['-ffixed-x18'] if platform.machine() == 'arm64' else []
35+
obj = subprocess.check_output(['clang', '-c', '-x', 'c', *args, *arch_args, '-', '-o', '-'], input=src.encode('utf-8'))
36+
image, _, relocs = elf_loader(obj)
37+
# This is needed because we have an object file, not a .so that has all internal references (like loads of constants from .rodata) resolved.
38+
for ploc,tgt,r_type,r_addend in relocs:
39+
image[ploc:ploc+4] = struct.pack("<I", relocate(struct.unpack("<I", image[ploc:ploc+4])[0], ploc, tgt+r_addend, r_type))
40+
return bytes(image)
41+
42+
def disassemble(self, lib):
43+
import capstone
44+
match platform.machine():
45+
case 'x86_64': cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
46+
case 'aarch64' | 'arm64': cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
47+
case machine: raise NotImplementedError(f"Capstone disassembly isn't supported for {machine}")
48+
for instr in cs.disasm(lib, 0):
49+
print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
50+
51+
# CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
52+
class CPUProgram:
53+
helper_handle = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'gcc_s'))
54+
2255
def __init__(self, name:str, lib:bytes):
23-
self.name, self.lib = name, lib
24-
# write to disk so we can load it
25-
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
26-
pathlib.Path(cached_file_path.name).write_bytes(lib)
27-
self.fxn = ctypes.CDLL(str(cached_file_path.name))[name]
56+
# On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
57+
# MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
58+
self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
59+
60+
if OSX: CPUProgram.helper_handle.pthread_jit_write_protect_np(False)
61+
self.mem.write(lib)
62+
if OSX: CPUProgram.helper_handle.pthread_jit_write_protect_np(True)
63+
64+
# __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
65+
# libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
66+
# it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
67+
# Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
68+
CPUProgram.helper_handle["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
69+
70+
self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
2871

29-
def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait)
72+
def __call__(self, *bufs, vals=(), wait=False):
73+
args = list(bufs) + list(vals)
74+
# NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later.
75+
# Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64
76+
# https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
77+
# This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures)
78+
# The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+
79+
if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
80+
return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
3081

3182
class ClangDevice(Compiled):
32-
def __init__(self, device:str): super().__init__(device, MallocAllocator, ClangRenderer(), ClangCompiler(), ClangProgram)
83+
def __init__(self, device:str): super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram)

tinygrad/runtime/support/elf.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from dataclasses import dataclass
21
import tinygrad.runtime.autogen.libc as libc
2+
from dataclasses import dataclass
3+
from tinygrad.helpers import getbits, i2u
34

45
@dataclass(frozen=True)
56
class ElfSection: name:str; header:libc.Elf64_Shdr; content:bytes # noqa: E702
@@ -34,3 +35,20 @@ def _to_carray(sh, ctype): return (ctype * (sh.header.sh_size // sh.header.sh_en
3435
relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
3536

3637
return memoryview(image), sections, relocs
38+
39+
def relocate(instr: int, ploc: int, tgt: int, r_type: int):
40+
match r_type:
41+
# https://refspecs.linuxfoundation.org/elf/x86_64-abi-0.95.pdf
42+
case libc.R_X86_64_PC32: return i2u(32, tgt-ploc)
43+
# https://github.com/ARM-software/abi-aa/blob/main/aaelf64/aaelf64.rst for definitions of relocations
44+
# https://www.scs.stanford.edu/~zyedidia/arm64/index.html for instruction encodings
45+
case libc.R_AARCH64_ADR_PREL_PG_HI21:
46+
rel_pg = (tgt & ~0xFFF) - (ploc & ~0xFFF)
47+
return instr | (getbits(rel_pg, 12, 13) << 29) | (getbits(rel_pg, 14, 32) << 5)
48+
case libc.R_AARCH64_ADD_ABS_LO12_NC: return instr | (getbits(tgt, 0, 11) << 10)
49+
case libc.R_AARCH64_CALL26: return instr | getbits(tgt, 2, 27)
50+
case libc.R_AARCH64_LDST16_ABS_LO12_NC: return instr | (getbits(tgt, 1, 11) << 10)
51+
case libc.R_AARCH64_LDST32_ABS_LO12_NC: return instr | (getbits(tgt, 2, 11) << 10)
52+
case libc.R_AARCH64_LDST64_ABS_LO12_NC: return instr | (getbits(tgt, 3, 11) << 10)
53+
case libc.R_AARCH64_LDST128_ABS_LO12_NC: return instr | (getbits(tgt, 4, 11) << 10)
54+
raise NotImplementedError(f"Encountered unknown relocation type {r_type}")

0 commit comments

Comments
 (0)