Skip to content

Commit 5cb0f14

Browse files
authored
use specialized code when compiling opaque closure expressions (#43320)
invoke specialization when an OC is created at run time
1 parent 98e60ff commit 5cb0f14

File tree

3 files changed

+153
-107
lines changed

3 files changed

+153
-107
lines changed

src/codegen.cpp

Lines changed: 95 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -4554,6 +4554,68 @@ static void emit_stmtpos(jl_codectx_t &ctx, jl_value_t *expr, int ssaval_result)
45544554
}
45554555
}
45564556

4557+
static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_method_t *closure_method, jl_tupletype_t *env_t, jl_tupletype_t *argt_typ, jl_value_t *rettype, bool vaOverride)
4558+
{
4559+
jl_svec_t *sig_args = NULL;
4560+
jl_value_t *sigtype = NULL;
4561+
jl_code_info_t *ir = NULL;
4562+
JL_GC_PUSH3(&sig_args, &sigtype, &ir);
4563+
4564+
size_t nsig = 1 + jl_svec_len(argt_typ->parameters);
4565+
sig_args = jl_alloc_svec_uninit(nsig);
4566+
jl_svecset(sig_args, 0, env_t);
4567+
for (size_t i = 0; i < jl_svec_len(argt_typ->parameters); ++i) {
4568+
jl_svecset(sig_args, 1+i, jl_svecref(argt_typ->parameters, i));
4569+
}
4570+
sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
4571+
4572+
jl_method_instance_t *mi = jl_specializations_get_linfo(closure_method, sigtype, jl_emptysvec);
4573+
jl_code_instance_t *ci = (jl_code_instance_t*)jl_rettype_inferred(mi, ctx.world, ctx.world);
4574+
4575+
if (ci == NULL || (jl_value_t*)ci == jl_nothing || ci->inferred == NULL || ci->inferred == jl_nothing) {
4576+
JL_GC_POP();
4577+
return std::make_pair((Function*)NULL, (Function*)NULL);
4578+
}
4579+
4580+
ir = jl_uncompress_ir(closure_method, ci, (jl_array_t*)ci->inferred);
4581+
4582+
// TODO: Emit this inline and outline it late using LLVM's coroutine support.
4583+
std::unique_ptr<Module> closure_m;
4584+
jl_llvm_functions_t closure_decls;
4585+
std::tie(closure_m, closure_decls) = emit_function(mi, ir, rettype, ctx.emission_context,
4586+
ctx.builder.getContext(), vaOverride);
4587+
4588+
assert(closure_decls.functionObject != "jl_fptr_sparam");
4589+
bool isspecsig = closure_decls.functionObject != "jl_fptr_args";
4590+
4591+
Function *F = NULL;
4592+
std::string fname = isspecsig ?
4593+
closure_decls.functionObject :
4594+
closure_decls.specFunctionObject;
4595+
if (GlobalValue *V = jl_Module->getNamedValue(fname)) {
4596+
F = cast<Function>(V);
4597+
} else {
4598+
F = Function::Create(get_func_sig(jl_LLVMContext),
4599+
Function::ExternalLinkage,
4600+
fname, jl_Module);
4601+
F->setAttributes(get_func_attrs(jl_LLVMContext));
4602+
}
4603+
Function *specF = NULL;
4604+
if (!isspecsig) {
4605+
specF = F;
4606+
} else {
4607+
specF = closure_m->getFunction(closure_decls.specFunctionObject);
4608+
if (specF) {
4609+
jl_returninfo_t returninfo = get_specsig_function(ctx, jl_Module,
4610+
closure_decls.specFunctionObject, sigtype, rettype, true);
4611+
specF = returninfo.decl;
4612+
}
4613+
}
4614+
ctx.oc_modules.push_back(std::move(closure_m));
4615+
JL_GC_POP();
4616+
return std::make_pair(F, specF);
4617+
}
4618+
45574619
// `expr` is not clobbered in JL_TRY
45584620
JL_GCC_IGNORE_START("-Wclobbered")
45594621
static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
@@ -4832,112 +4894,54 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
48324894
}
48334895

48344896
if (can_optimize) {
4835-
// TODO: Emit this inline and outline it late using LLVM's coroutine
4836-
// support.
4837-
jl_method_t *closure_method = (jl_method_t *)source.constant;
4838-
jl_code_info_t *closure_src = jl_uncompress_ir(closure_method, NULL,
4839-
(jl_array_t*)closure_method->source);
4840-
4841-
std::unique_ptr<Module> closure_m;
4842-
jl_llvm_functions_t closure_decls;
4843-
4844-
jl_method_instance_t *li = NULL;
48454897
jl_value_t *closure_t = NULL;
48464898
jl_tupletype_t *env_t = NULL;
4847-
jl_svec_t *sig_args = NULL;
4848-
JL_GC_PUSH5(&li, &closure_src, &closure_t, &env_t, &sig_args);
4849-
4850-
li = jl_new_method_instance_uninit();
4851-
li->def.method = closure_method;
4852-
jl_tupletype_t *argt_typ = (jl_tupletype_t *)argt.constant;
4853-
4854-
closure_t = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt_typ, ub.constant);
4855-
4856-
size_t nsig = 1 + jl_svec_len(argt_typ->parameters);
4857-
sig_args = jl_alloc_svec_uninit(nsig);
4858-
jl_svecset(sig_args, 0, closure_t);
4859-
for (size_t i = 0; i < jl_svec_len(argt_typ->parameters); ++i) {
4860-
jl_svecset(sig_args, 1+i, jl_svecref(argt_typ->parameters, i));
4861-
}
4862-
li->specTypes = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
4863-
jl_gc_wb(li, li->specTypes);
4864-
4865-
std::tie(closure_m, closure_decls) = emit_function(li, closure_src,
4866-
ub.constant, ctx.emission_context, ctx.builder.getContext(), jl_unbox_bool(isva.constant));
4899+
JL_GC_PUSH2(&closure_t, &env_t);
48674900

48684901
jl_value_t **env_component_ts = (jl_value_t**)alloca(sizeof(jl_value_t*) * (nargs-5));
48694902
for (size_t i = 0; i < nargs - 5; ++i) {
48704903
env_component_ts[i] = argv[5+i].typ;
48714904
}
48724905

48734906
env_t = jl_apply_tuple_type_v(env_component_ts, nargs-5);
4874-
jl_cgval_t env(ctx.builder.getContext());
4875-
// TODO: Inline the env at the end of the opaque closure and generate a descriptor for GC
4907+
// we need to know the full env type to look up the right specialization
48764908
if (jl_is_concrete_type((jl_value_t*)env_t)) {
4877-
env = emit_new_struct(ctx, (jl_value_t*)env_t, nargs-5, &argv.data()[5]);
4878-
}
4879-
else {
4880-
Value *env_val = emit_jlcall(ctx, jltuple_func, Constant::getNullValue(ctx.types().T_prjlvalue),
4881-
&argv[5], nargs-5, JLCALL_F_CC);
4882-
env = mark_julia_type(ctx, env_val, true, env_t);
4883-
}
4884-
4885-
assert(closure_decls.functionObject != "jl_fptr_sparam");
4886-
bool isspecsig = closure_decls.functionObject != "jl_fptr_args";
4887-
4888-
Function *F = NULL;
4889-
std::string fname = isspecsig ?
4890-
closure_decls.functionObject :
4891-
closure_decls.specFunctionObject;
4892-
if (GlobalValue *V = jl_Module->getNamedValue(fname)) {
4893-
F = cast<Function>(V);
4894-
}
4895-
else {
4896-
F = Function::Create(get_func_sig(ctx.builder.getContext()),
4897-
Function::ExternalLinkage,
4898-
fname, jl_Module);
4899-
F->setAttributes(get_func_attrs(ctx.builder.getContext()));
4900-
}
4901-
jl_cgval_t jlcall_ptr = mark_julia_type(ctx,
4902-
F, false, jl_voidpointer_type);
4903-
4904-
jl_cgval_t fptr(ctx.builder.getContext());
4905-
if (!isspecsig) {
4906-
fptr = jlcall_ptr;
4907-
} else {
4908-
Function *specptr = closure_m->getFunction(closure_decls.specFunctionObject);
4909-
if (specptr) {
4910-
jl_returninfo_t returninfo = get_specsig_function(ctx, jl_Module,
4911-
closure_decls.specFunctionObject, li->specTypes, ub.constant, true);
4912-
fptr = mark_julia_type(ctx, returninfo.decl, false, jl_voidpointer_type);
4913-
} else {
4914-
fptr = mark_julia_type(ctx,
4915-
(llvm::Value*)Constant::getNullValue(getSizeTy(ctx.builder.getContext())),
4916-
false, jl_voidpointer_type);
4917-
}
4918-
}
4909+
jl_tupletype_t *argt_typ = (jl_tupletype_t*)argt.constant;
4910+
Function *F, *specF;
4911+
std::tie(F, specF) = get_oc_function(ctx, (jl_method_t*)source.constant, env_t, argt_typ, ub.constant, jl_unbox_bool(isva.constant));
4912+
if (F) {
4913+
jl_cgval_t jlcall_ptr = mark_julia_type(ctx, F, false, jl_voidpointer_type);
4914+
jl_cgval_t world_age = mark_julia_type(ctx,
4915+
tbaa_decorate(ctx.tbaa().tbaa_gcframe,
4916+
ctx.builder.CreateAlignedLoad(ctx.world_age_field, Align(sizeof(size_t)))),
4917+
false,
4918+
jl_long_type);
4919+
jl_cgval_t fptr(ctx.builder.getContext());
4920+
if (specF)
4921+
fptr = mark_julia_type(ctx, specF, false, jl_voidpointer_type);
4922+
else
4923+
fptr = mark_julia_type(ctx, (llvm::Value*)Constant::getNullValue(getSizeTy(ctx.builder.getContext())), false, jl_voidpointer_type);
49194924

4920-
jl_cgval_t world_age = mark_julia_type(ctx,
4921-
tbaa_decorate(ctx.tbaa().tbaa_gcframe,
4922-
ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()), ctx.world_age_field, Align(sizeof(size_t)))),
4923-
false,
4924-
jl_long_type);
4925-
4926-
jl_cgval_t closure_fields[6] = {
4927-
env,
4928-
isva,
4929-
world_age,
4930-
source,
4931-
jlcall_ptr,
4932-
fptr
4933-
};
4925+
// TODO: Inline the env at the end of the opaque closure and generate a descriptor for GC
4926+
jl_cgval_t env = emit_new_struct(ctx, (jl_value_t*)env_t, nargs-5, &argv.data()[5]);
49344927

4935-
jl_cgval_t ret = emit_new_struct(ctx, closure_t, 6, closure_fields);
4928+
jl_cgval_t closure_fields[6] = {
4929+
env,
4930+
isva,
4931+
world_age,
4932+
source,
4933+
jlcall_ptr,
4934+
fptr
4935+
};
49364936

4937-
ctx.oc_modules.push_back(std::move(closure_m));
4937+
closure_t = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt_typ, ub.constant);
4938+
jl_cgval_t ret = emit_new_struct(ctx, closure_t, 6, closure_fields);
49384939

4940+
JL_GC_POP();
4941+
return ret;
4942+
}
4943+
}
49394944
JL_GC_POP();
4940-
return ret;
49414945
}
49424946

49434947
return mark_julia_type(ctx,

src/interpreter.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,9 @@ jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **ar
696696
jl_code_info_t *code = jl_uncompress_ir(source, NULL, (jl_array_t*)source->source);
697697
interpreter_state *s;
698698
unsigned nroots = jl_source_nslots(code) + jl_source_nssavalues(code) + 2;
699+
jl_task_t *ct = jl_current_task;
700+
size_t last_age = ct->world_age;
701+
ct->world_age = oc->world;
699702
jl_value_t **locals = NULL;
700703
JL_GC_PUSHFRAME(s, locals, nroots);
701704
locals[0] = (jl_value_t*)oc;
@@ -710,7 +713,6 @@ jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **ar
710713
s->preevaluation = 0;
711714
s->continue_at = 0;
712715
s->mi = NULL;
713-
714716
size_t defargs = source->nargs;
715717
int isva = !!oc->isva;
716718
assert(isva ? nargs + 2 >= defargs : nargs + 1 == defargs);
@@ -722,6 +724,9 @@ jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **ar
722724
}
723725
JL_GC_ENABLEFRAME(s);
724726
jl_value_t *r = eval_body(code->code, s, 0, 0);
727+
locals[0] = r; // GC root
728+
jl_typeassert(r, jl_tparam1(jl_typeof(oc)));
729+
ct->world_age = last_age;
725730
JL_GC_POP();
726731
return r;
727732
}

src/opaque_closure.c

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,23 @@
33
#include "julia.h"
44
#include "julia_internal.h"
55

6-
JL_DLLEXPORT jl_value_t *jl_invoke_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **args, size_t nargs)
6+
jl_value_t *jl_fptr_const_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **args, size_t nargs)
77
{
8-
jl_value_t *ret = NULL;
9-
JL_GC_PUSH1(&ret);
10-
jl_task_t *ct = jl_current_task;
11-
size_t last_age = ct->world_age;
12-
ct->world_age = oc->world;
13-
ret = jl_interpret_opaque_closure(oc, args, nargs);
14-
jl_typeassert(ret, jl_tparam1(jl_typeof(oc)));
15-
ct->world_age = last_age;
8+
return oc->captures;
9+
}
10+
11+
// TODO: remove
12+
jl_value_t *jl_fptr_va_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **args, size_t nargs)
13+
{
14+
size_t defargs = oc->source->nargs;
15+
jl_value_t **newargs;
16+
JL_GC_PUSHARGS(newargs, defargs - 1);
17+
for (size_t i = 0; i < defargs - 2; i++)
18+
newargs[i] = args[i];
19+
newargs[defargs - 2] = jl_f_tuple(NULL, &args[defargs - 2], nargs + 2 - defargs);
20+
jl_value_t *ans = ((jl_fptr_args_t)oc->specptr)((jl_value_t*)oc, newargs, defargs - 1);
1621
JL_GC_POP();
17-
return ret;
22+
return ans;
1823
}
1924

2025
jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *isva,
@@ -31,17 +36,49 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *isv
3136
jl_value_t *oc_type JL_ALWAYS_LEAFTYPE;
3237
oc_type = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt, rt_ub);
3338
JL_GC_PROMISE_ROOTED(oc_type);
34-
jl_value_t *captures = NULL;
35-
JL_GC_PUSH1(&captures);
39+
jl_value_t *captures = NULL, *sigtype = NULL;
40+
jl_svec_t *sig_args = NULL;
41+
JL_GC_PUSH3(&captures, &sigtype, &sig_args);
3642
captures = jl_f_tuple(NULL, env, nenv);
43+
44+
size_t nsig = 1 + jl_svec_len(argt->parameters);
45+
sig_args = jl_alloc_svec_uninit(nsig);
46+
jl_svecset(sig_args, 0, jl_typeof(captures));
47+
for (size_t i = 0; i < nsig-1; ++i) {
48+
jl_svecset(sig_args, 1+i, jl_tparam(argt, i));
49+
}
50+
sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
51+
jl_method_instance_t *mi = jl_specializations_get_linfo((jl_method_t*)source, sigtype, jl_emptysvec);
52+
size_t world = jl_atomic_load_acquire(&jl_world_counter);
53+
jl_code_instance_t *ci = jl_compile_method_internal(mi, world);
54+
3755
jl_opaque_closure_t *oc = (jl_opaque_closure_t*)jl_gc_alloc(ct->ptls, sizeof(jl_opaque_closure_t), oc_type);
3856
JL_GC_POP();
3957
oc->source = (jl_method_t*)source;
4058
oc->isva = jl_unbox_bool(isva);
41-
oc->invoke = (jl_fptr_args_t)jl_invoke_opaque_closure;
42-
oc->specptr = NULL;
4359
oc->captures = captures;
44-
oc->world = jl_atomic_load_acquire(&jl_world_counter);
60+
oc->specptr = NULL;
61+
int compiled = 0;
62+
if (jl_atomic_load_relaxed(&ci->invoke) == jl_fptr_interpret_call) {
63+
oc->invoke = (jl_fptr_args_t)jl_interpret_opaque_closure;
64+
}
65+
else if (jl_atomic_load_relaxed(&ci->invoke) == jl_fptr_args) {
66+
oc->invoke = jl_atomic_load_relaxed(&ci->specptr.fptr1);
67+
compiled = 1;
68+
}
69+
else if (jl_atomic_load_relaxed(&ci->invoke) == jl_fptr_const_return) {
70+
oc->invoke = (jl_fptr_args_t)jl_fptr_const_opaque_closure;
71+
oc->captures = ci->rettype_const;
72+
}
73+
else {
74+
oc->invoke = (jl_fptr_args_t)jl_atomic_load_relaxed(&ci->invoke);
75+
compiled = 1;
76+
}
77+
if (oc->isva && compiled) {
78+
oc->specptr = (jl_fptr_args_t)oc->invoke;
79+
oc->invoke = (jl_fptr_args_t)jl_fptr_va_opaque_closure;
80+
}
81+
oc->world = world;
4582
return oc;
4683
}
4784

0 commit comments

Comments
 (0)