Commit c323944

Authored by muchulee8, committed by pytorchmergebot
[AOTInductor] Include constants in AOTInductor .so file. (pytorch#107718)

Summary: Include the constants in the AOTInductor .so file. We do not modify existing API signatures, but instead create the necessary format with the weights lifted out.

Test Plan: test/inductor/test_aot_inductor.py

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#107718
Approved by: https://github.com/angelayi, https://github.com/eellison
Parent: fa49be2
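Editor's note on the mechanism: the heart of this commit is in torch/_inductor/codecache.py below, where every graph constant is flattened to raw bytes before being linked into the .so. A minimal sketch of that serialization step (the constants dict here is a hypothetical stand-in for Inductor's graph.constants):

    import torch

    # Editor's sketch: how the commit turns graph constants into one raw byte
    # blob (mirrors the aot_constants loop added to codecache.py below).
    constants = {"fc.weight": torch.arange(640.0).reshape(10, 64)}

    aot_constants = b""
    for tensor in constants.values():
        # untyped_storage() exposes the raw bytes backing the tensor.
        aot_constants += bytes(tensor.untyped_storage().cpu())

    print(len(aot_constants))  # 640 float32 values -> 2560 bytes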

File tree: 15 files changed, +398 -71 lines

benchmarks/dynamo/common.py

Lines changed: 2 additions & 2 deletions
@@ -1145,9 +1145,9 @@ def load(cls, model, example_inputs, eager_forward):
 
         # Use a utility function for easier benchmarking
         source = """
-        #include <torch/csrc/inductor/aot_inductor_model.h>
+        #include <torch/csrc/inductor/aot_inductor_model_container.h>
 
-        torch::aot_inductor::AOTInductorModel model;
+        torch::aot_inductor::AOTInductorModelContainer model(1);
 
         void run(
             const std::vector<at::Tensor>& input_tensors,

test/cpp/aot_inductor/test.cpp

Lines changed: 15 additions & 4 deletions
@@ -23,17 +23,28 @@ TEST(AotInductorTest, BasicTest) {
   Net net;
   net.to(torch::kCUDA);
 
+  // We should fix the weight over here.
+  // This should match exactly with the one in test.py
+  torch::Tensor weights =
+      at::arange(640, at::dtype(at::kFloat).device(at::kCUDA));
+  weights = at::reshape(weights, {10, 64});
+  torch::Tensor bias = at::zeros({10}, at::dtype(at::kFloat).device(at::kCUDA));
+
+  for (const auto& pair : net.named_parameters()) {
+    if (pair.key().find("weight") != std::string::npos) {
+      pair.value().copy_(weights);
+    } else if (pair.key().find("bias") != std::string::npos) {
+      pair.value().copy_(bias);
+    }
+  }
+
   torch::Tensor x =
       at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
   torch::Tensor y =
       at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
   torch::Tensor results_ref = net.forward(x, y);
 
-  // TODO: we need to provide an API to concatenate args and weights
   std::vector<torch::Tensor> inputs;
-  for (const auto& pair : net.named_parameters()) {
-    inputs.push_back(pair.value());
-  }
   inputs.push_back(x);
   inputs.push_back(y);

test/cpp/aot_inductor/test.py

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,12 @@ class Net(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.fc = torch.nn.Linear(64, 10)
+        weights = torch.arange(640)
+        weights = torch.reshape(weights, (10, 64))
+
+        with torch.no_grad():
+            self.fc.weight.copy_(weights)
+            self.fc.bias.copy_(torch.zeros(10))
 
     def forward(self, x, y):
         return self.fc(torch.sin(x) + torch.cos(y))
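The deterministic weights are what let test.cpp above and this test.py agree across the .so boundary: both pin the Linear layer to the same arange-derived matrix. A quick sanity check of that invariant (editor's sketch, not part of the commit):

    import torch

    # test.py builds torch.arange(640) and copies it into a float32 parameter;
    # test.cpp builds at::arange(640, kFloat). Both yield the same 10x64 matrix.
    w_py = torch.arange(640, dtype=torch.float32).reshape(10, 64)
    w_cpp = torch.arange(640.0).reshape(10, 64)
    assert torch.equal(w_py, w_cpp)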

test/inductor/test_aot_inductor.py

Lines changed: 27 additions & 5 deletions
@@ -43,9 +43,9 @@ def load(cls, model, example_inputs, example_outputs, options=None):
 
         # Use a utility function for easier testing
         source = """
-        #include <torch/csrc/inductor/aot_inductor_model.h>
+        #include <torch/csrc/inductor/aot_inductor_model_container.h>
 
-        torch::aot_inductor::AOTInductorModel model;
+        torch::aot_inductor::AOTInductorModelContainer model(1);
 
         void run(
             const std::vector<at::Tensor>& input_tensors,
@@ -69,12 +69,10 @@ def run(cls, model, example_inputs, example_outputs, options=None):
         optimized, exported, output_tensors, output_spec = AOTInductorModelRunner.load(
             model, example_inputs, example_outputs, options
         )
-        param_buffer_values = list(exported.state_dict.values())
         flat_example_inputs = fx_pytree.tree_flatten_spec(
             example_inputs, exported.call_spec.in_spec
         )
-        all_args = (*param_buffer_values, *flat_example_inputs)
-        optimized(all_args, output_tensors)
+        optimized(flat_example_inputs, output_tensors)
         return pytree.tree_unflatten(output_tensors, output_spec)
 
 
@@ -98,6 +96,30 @@ def forward(self, x, y):
         actual = AOTInductorModelRunner.run(model, example_inputs, expected)
         self.assertTrue(same(actual, expected))
 
+    @requires_cpp_extension()
+    def test_with_offset(self):
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.orig_tensor = torch.randn(2, 15, 10, device="cuda")[0]
+                self.tensor = self.orig_tensor[5:, :]
+
+            def forward(self, x, y):
+                return (
+                    x
+                    + torch.nn.functional.linear(y, self.orig_tensor[:10, :])
+                    + self.tensor
+                )
+
+        model = Repro()
+        example_inputs = (
+            torch.randn(10, 10, device="cuda"),
+            torch.randn(10, 10, device="cuda"),
+        )
+        expected = model(*example_inputs)
+        actual = AOTInductorModelRunner.run(model, example_inputs, expected)
+        self.assertTrue(same(actual, expected))
+
     @requires_cpp_extension()
     def test_missing_output(self):
         class Repro(torch.nn.Module):
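test_with_offset exercises constants that are views: orig_tensor is a slice of a freshly allocated (2, 15, 10) buffer, and self.tensor carries an additional storage offset. Because the packaging serializes bytes(tensor.untyped_storage()), i.e. the whole underlying buffer, the runtime must reapply each view's offset. An illustration of the property being tested (editor's sketch):

    import torch

    base = torch.randn(2, 15, 10)
    view = base[1]                      # shares storage with base
    print(view.storage_offset())        # 150: elements into the shared buffer
    # Serializing the view's storage captures the whole base buffer, not the view:
    print(view.untyped_storage().nbytes() == base.untyped_storage().nbytes())  # True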

test/inductor/test_inductor_freezing.py

Lines changed: 2 additions & 2 deletions
@@ -317,8 +317,8 @@ def foo(mod, x):
         # we unfuse the conv bias, but it should only have one constant in the kernel
         if self.device == "cuda":
             FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
-                "constant"
-            ).check_not("constant").check_next("return").run(code[0])
+                "frozen_param"
+            ).check_not("frozen_param").check_next("return").run(code[0])
 
         self.assertEqual(
             out_optimized_for_infernece, out_eager, atol=1e-2, rtol=1e-2
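The expected marker changes because freezing now emits constants named frozen_param* instead of constant*. For readers unfamiliar with the FileCheck idiom used in this test, it asserts that substrings appear (or do not appear) in order in the generated code; a standalone example (editor's sketch, assuming the usual check_next next-line semantics):

    from torch.testing import FileCheck

    generated = "buf0 = frozen_param0 + x\nreturn buf0"
    # Passes: "frozen_param" occurs, and "return" appears on the following line.
    FileCheck().check("frozen_param").check_next("return").run(generated)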

torch/_export/__init__.py

Lines changed: 40 additions & 5 deletions
@@ -1,3 +1,4 @@
+import copy
 import dataclasses
 import io
 import re
@@ -553,7 +554,6 @@ def aot_compile(
     Returns:
         Path to the generated shared library, and the exported program
     """
-    from torch._inductor.compile_fx import compile_fx_aot
     from torch._inductor.decomposition import select_decomp_table
 
     global DECOMP_TABLE
@@ -562,11 +562,46 @@
     # Reset the global value
     DECOMP_TABLE = core_aten_decompositions()
 
-    param_buffer_values = list(ep.state_dict.values())
     flat_example_inputs = fx_pytree.tree_flatten_spec(
        combine_args_kwargs(args, kwargs), ep.call_spec.in_spec  # type: ignore[arg-type]
    )
-    all_args = (*param_buffer_values, *flat_example_inputs)
 
-    so_path = torch._inductor.aot_compile(ep.graph_module, list(all_args), options)
-    return so_path, ep
+    unlifted_module = ep.module()
+    unlifted_module.graph.set_codegen(torch.fx.CodeGen())  # type: ignore[attr-defined]
+    unlifted_module.recompile()
+    options = (
+        {"from_export": True}
+        if options is None
+        else {**options, "from_export": True}
+    )
+    so_path = torch._inductor.aot_compile(unlifted_module, flat_example_inputs, options)  # type: ignore[arg-type]
+
+    user_inputs = []
+    user_outputs = []
+    for node in unlifted_module.graph.nodes:
+        if node.op == "placeholder":
+            user_inputs.append(node.name)
+        elif node.op == "output":
+            user_outputs = [arg.name for arg in node.args[0]]
+
+    unlifted_ep = ExportedProgram(
+        unlifted_module,
+        unlifted_module.graph,
+        ExportGraphSignature(
+            [],
+            [],
+            user_inputs,
+            user_outputs,
+            {},
+            {},
+            {},
+            None,
+        ),
+        call_spec=copy.deepcopy(ep.call_spec),
+        state_dict={},
+        range_constraints=copy.deepcopy(ep.range_constraints),
+        equality_constraints=copy.deepcopy(ep.equality_constraints),
+        module_call_graph=ep.module_call_graph,
+    )
+
+    return so_path, unlifted_ep
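The pivotal change here is compiling ep.module(), the unlifted module, rather than ep.graph_module: export lifts parameters and buffers into explicit placeholders, while the unlifted form reaches them through get_attr nodes, so Inductor sees only user inputs and can treat everything else as constants to bake into the .so. A minimal fx illustration of the distinction (editor's sketch, not the commit's code):

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("w", torch.randn(10, 64))

        def forward(self, x):
            return torch.nn.functional.linear(x, self.w)

    gm = torch.fx.symbolic_trace(M())
    print([(n.op, str(n.target)) for n in gm.graph.nodes])
    # Roughly: [('placeholder', 'x'), ('get_attr', 'w'),
    #           ('call_function', ...linear...), ('output', 'output')]
    # Only x is an input; w stays an attribute -- the "unlifted" shape that
    # the new aot_compile hands to Inductor.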

torch/_inductor/codecache.py

Lines changed: 89 additions & 17 deletions
@@ -302,13 +302,13 @@ def get_lock_dir():
     return lock_dir
 
 
-def code_hash(code, extra: str = ""):
-    hashing_str = code
+def code_hash(code: Union[str, bytes], extra: str = ""):
+    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
     if extra != "":
-        hashing_str = hashing_str + "||" + extra
+        hashing_str = hashing_str + b"||" + extra.encode("utf-8")
     return (
         "c"
-        + base64.b32encode(hashlib.sha256(hashing_str.encode("utf-8")).digest())[:51]
+        + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
         .decode("utf-8")
         .lower()
     )
@@ -656,6 +656,10 @@ def pick_vec_isa():
     return invalid_vec_isa
 
 
+def get_compile_only(compile_only=True):
+    return "-c" if compile_only else ""
+
+
 def get_shared(shared=True):
     return "-shared -fPIC" if shared else ""
 
@@ -884,6 +888,7 @@
     vec_isa: VecISA = invalid_vec_isa,
     cuda=False,
     aot_mode=False,
+    compile_only=False,
 ):
     ipaths, lpaths, libs, macros = get_include_and_linking_paths(
         include_pytorch, vec_isa, cuda, aot_mode
@@ -913,11 +918,20 @@
             {use_custom_generated_macros()}
             {use_fb_internal_macros()}
             {use_standard_sys_dir_headers()}
+            {get_compile_only(compile_only)}
             -o {out_name}
         """,
     ).strip()
 
 
+def run_command_and_check(cmd: str):
+    cmd = shlex.split(cmd)
+    try:
+        subprocess.check_call(cmd)
+    except subprocess.CalledProcessError as e:
+        raise exc.CppCompileError(cmd, e.output) from e
+
+
 class CudaKernelParamCache:
     cache = dict()
     clear = staticmethod(cache.clear)
@@ -951,12 +965,29 @@ def compile(cls, graph, source_code, cuda):
                 "i", "o", vec_isa=picked_vec_isa, cuda=cuda, aot_mode=graph.aot_mode
             )
         )
+        if config.is_fbcode():
+            ld_command = build_paths.ld()
+            objcopy_command = build_paths.objcopy()
+        else:
+            ld_command = "ld"
+            objcopy_command = "objcopy"
        key, input_path = write(
            source_code,
            "cpp",
            extra=cpp_command,
            specified_dir=config.aot_inductor_output_path,
        )
+
+        aot_constants = b""
+        for tensor in graph.constants.values():
+            aot_constants += bytes(tensor.untyped_storage().cpu())
+
+        consts_key, consts_path = write(
+            aot_constants,
+            "bin",
+            specified_dir=config.aot_inductor_output_path,
+        )
+
         if key not in cls.cache:
             from filelock import FileLock
 
@@ -966,20 +997,61 @@
         output_so = os.path.splitext(input_path)[0] + ".so"
 
         if not os.path.exists(output_so):
-            cmd = shlex.split(
-                cpp_compile_command(
-                    input=input_path,
-                    output=output_so,
-                    vec_isa=picked_vec_isa,
-                    cuda=cuda,
-                    aot_mode=graph.aot_mode,
-                )
+            output_o = os.path.splitext(input_path)[0] + ".o"
+            cmd = cpp_compile_command(
+                input=input_path,
+                output=output_o,
+                vec_isa=picked_vec_isa,
+                cuda=cuda,
+                aot_mode=graph.aot_mode,
+                compile_only=True,
+            )
+            log.debug("aot compilation command: %s", cmd)
+            run_command_and_check(cmd)
+
+            consts_o = os.path.splitext(consts_path)[0] + ".o"
+            cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}"
+            run_command_and_check(cmd)
+            log.debug("aot constant binary command: %s", cmd)
+
+            cmd = (
+                f"{objcopy_command} --rename-section"
+                " .data=.lrodata,alloc,load,readonly,data,contents"
+                f" {consts_o} {consts_o}"
+            )
+            log.debug("aot constant obj command: %s", cmd)
+            run_command_and_check(cmd)
+
+            cmd = f"rm {consts_path}"
+            log.debug("aot constant bin removal command: %s", cmd)
+            run_command_and_check(cmd)
+
+            body = re.sub(r"[\W_]+", "_", consts_path)
+            symbol_list = []
+            symbol_list.append(
+                f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}"
+            )
+            symbol_list.append(
+                f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_size {consts_o}"
+            )
+            symbol_list.append(
+                f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}"
+            )
+            log.debug(
+                "aot constant binary redefine symbol: %s", " ".join(symbol_list)
+            )
+            for cmd in symbol_list:
+                run_command_and_check(cmd)
+
+            cmd = cpp_compile_command(
+                input=f"{output_o} {consts_o}",
+                output=output_so,
+                vec_isa=picked_vec_isa,
+                cuda=cuda,
+                aot_mode=graph.aot_mode,
             )
-            log.debug("aot compilation command: %s", " ".join(cmd))
-            try:
-                subprocess.check_call(cmd)
-            except subprocess.CalledProcessError as e:
-                raise exc.CppCompileError(cmd, e.output) from e
+            log.debug("aot linkage command: %s", cmd)
+            run_command_and_check(cmd)
         else:
             log.debug(
                 "aot_inductor dynamic library already exist: %s", output_so

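The constant-embedding pipeline above is plain GNU binutils: ld -r -b binary wraps any file into a relocatable object exporting _binary_<mangled path>_start/_end/_size symbols, objcopy --rename-section moves the payload into a read-only loadable section, and the --redefine-sym passes replace the path-dependent names with the stable _binary_constants_bin_* symbols the runtime links against. A standalone reproduction (editor's sketch; assumes a Linux toolchain with ld, objcopy, and nm on PATH):

    import pathlib
    import subprocess
    import tempfile

    workdir = pathlib.Path(tempfile.mkdtemp())
    blob = workdir / "constants.bin"
    blob.write_bytes(b"\x00" * 16)      # stand-in for the serialized weights
    obj = workdir / "constants.o"

    # Wrap the raw bytes into a relocatable object file.
    subprocess.check_call(["ld", "-r", "-b", "binary", "-o", str(obj), str(blob)])

    # Move the payload from .data into a read-only, loadable section.
    subprocess.check_call([
        "objcopy", "--rename-section",
        ".data=.lrodata,alloc,load,readonly,data,contents",
        str(obj), str(obj),
    ])

    # The exported symbol names embed the mangled input path -- which is why
    # the commit then rewrites them to _binary_constants_bin_{start,end,size}.
    print(subprocess.check_output(["nm", str(obj)]).decode())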