Skip to content

Commit 06d74e6

Browse files
desertfire authored and pytorchmergebot committed
Revert "[AOTInductor] Include constants in AOTInductor .so file. (#10… (pytorch#108349)
This reverts commit c323944 due to internal test failures. Pull Request resolved: pytorch#108349 Approved by: https://github.com/aakhundov, https://github.com/zhxchen17
1 parent 01dfa76 commit 06d74e6

File tree

15 files changed

+71
-397
lines changed

15 files changed

+71
-397
lines changed

benchmarks/dynamo/common.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1145,9 +1145,9 @@ def load(cls, model, example_inputs, eager_forward):
11451145

11461146
# Use a utility function for easier benchmarking
11471147
source = """
1148-
#include <torch/csrc/inductor/aot_inductor_model_container.h>
1148+
#include <torch/csrc/inductor/aot_inductor_model.h>
11491149
1150-
torch::aot_inductor::AOTInductorModelContainer model(1);
1150+
torch::aot_inductor::AOTInductorModel model;
11511151
11521152
void run(
11531153
const std::vector<at::Tensor>& input_tensors,

test/cpp/aot_inductor/test.cpp

Lines changed: 4 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -23,28 +23,17 @@ TEST(AotInductorTest, BasicTest) {
2323
Net net;
2424
net.to(torch::kCUDA);
2525

26-
// We should fix the weight over here.
27-
// This should match exactly with the one in test.py
28-
torch::Tensor weights =
29-
at::arange(640, at::dtype(at::kFloat).device(at::kCUDA));
30-
weights = at::reshape(weights, {10, 64});
31-
torch::Tensor bias = at::zeros({10}, at::dtype(at::kFloat).device(at::kCUDA));
32-
33-
for (const auto& pair : net.named_parameters()) {
34-
if (pair.key().find("weight") != std::string::npos) {
35-
pair.value().copy_(weights);
36-
} else if (pair.key().find("bias") != std::string::npos) {
37-
pair.value().copy_(bias);
38-
}
39-
}
40-
4126
torch::Tensor x =
4227
at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
4328
torch::Tensor y =
4429
at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
4530
torch::Tensor results_ref = net.forward(x, y);
4631

32+
// TODO: we need to provide an API to concatenate args and weights
4733
std::vector<torch::Tensor> inputs;
34+
for (const auto& pair : net.named_parameters()) {
35+
inputs.push_back(pair.value());
36+
}
4837
inputs.push_back(x);
4938
inputs.push_back(y);
5039

test/cpp/aot_inductor/test.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,6 @@ class Net(torch.nn.Module):
88
def __init__(self):
99
super().__init__()
1010
self.fc = torch.nn.Linear(64, 10)
11-
weights = torch.arange(640)
12-
weights = torch.reshape(weights, (10, 64))
13-
14-
with torch.no_grad():
15-
self.fc.weight.copy_(weights)
16-
self.fc.bias.copy_(torch.zeros(10))
1711

1812
def forward(self, x, y):
1913
return self.fc(torch.sin(x) + torch.cos(y))

test/inductor/test_aot_inductor.py

Lines changed: 5 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,9 @@ def load(cls, model, example_inputs, example_outputs, options=None):
3737

3838
# Use a utility function for easier testing
3939
source = """
40-
#include <torch/csrc/inductor/aot_inductor_model_container.h>
40+
#include <torch/csrc/inductor/aot_inductor_model.h>
4141
42-
torch::aot_inductor::AOTInductorModelContainer model(1);
42+
torch::aot_inductor::AOTInductorModel model;
4343
4444
void run(
4545
const std::vector<at::Tensor>& input_tensors,
@@ -63,10 +63,12 @@ def run(cls, model, example_inputs, example_outputs, options=None):
6363
optimized, exported, output_tensors, output_spec = AOTInductorModelRunner.load(
6464
model, example_inputs, example_outputs, options
6565
)
66+
param_buffer_values = list(exported.state_dict.values())
6667
flat_example_inputs = fx_pytree.tree_flatten_spec(
6768
example_inputs, exported.call_spec.in_spec
6869
)
69-
optimized(flat_example_inputs, output_tensors)
70+
all_args = (*param_buffer_values, *flat_example_inputs)
71+
optimized(all_args, output_tensors)
7072
return pytree.tree_unflatten(output_tensors, output_spec)
7173

7274

@@ -89,29 +91,6 @@ def forward(self, x, y):
8991
actual = AOTInductorModelRunner.run(model, example_inputs, expected)
9092
self.assertTrue(same(actual, expected))
9193

92-
def test_with_offset(self):
93-
class Repro(torch.nn.Module):
94-
def __init__(self):
95-
super().__init__()
96-
self.orig_tensor = torch.randn(2, 15, 10, device="cuda")[0]
97-
self.tensor = self.orig_tensor[5:, :]
98-
99-
def forward(self, x, y):
100-
return (
101-
x
102-
+ torch.nn.functional.linear(y, self.orig_tensor[:10, :])
103-
+ self.tensor
104-
)
105-
106-
model = Repro()
107-
example_inputs = (
108-
torch.randn(10, 10, device="cuda"),
109-
torch.randn(10, 10, device="cuda"),
110-
)
111-
expected = model(*example_inputs)
112-
actual = AOTInductorModelRunner.run(model, example_inputs, expected)
113-
self.assertTrue(same(actual, expected))
114-
11594
def test_missing_output(self):
11695
class Repro(torch.nn.Module):
11796
def __init__(self):

test/inductor/test_inductor_freezing.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -317,8 +317,8 @@ def foo(mod, x):
317317
# we unfuse the conv bias, but it should only have one constant in the kernel
318318
if self.device == "cuda":
319319
FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
320-
"frozen_param"
321-
).check_not("frozen_param").check_next("return").run(code[0])
320+
"constant"
321+
).check_not("constant").check_next("return").run(code[0])
322322

323323
self.assertEqual(
324324
out_optimized_for_infernece, out_eager, atol=1e-2, rtol=1e-2

torch/_export/__init__.py

Lines changed: 5 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import copy
21
import dataclasses
32
import io
43
import re
@@ -627,6 +626,7 @@ def aot_compile(
627626
Returns:
628627
Path to the generated shared library, and the exported program
629628
"""
629+
from torch._inductor.compile_fx import compile_fx_aot
630630
from torch._inductor.decomposition import select_decomp_table
631631

632632
global DECOMP_TABLE
@@ -635,46 +635,11 @@ def aot_compile(
635635
# Reset the global value
636636
DECOMP_TABLE = core_aten_decompositions()
637637

638+
param_buffer_values = list(ep.state_dict.values())
638639
flat_example_inputs = fx_pytree.tree_flatten_spec(
639640
combine_args_kwargs(args, kwargs), ep.call_spec.in_spec # type: ignore[arg-type]
640641
)
642+
all_args = (*param_buffer_values, *flat_example_inputs)
641643

642-
unlifted_module = ep.module()
643-
unlifted_module.graph.set_codegen(torch.fx.CodeGen()) # type: ignore[attr-defined]
644-
unlifted_module.recompile()
645-
options = (
646-
{"from_export": True}
647-
if options is None
648-
else {**options, "from_export": True}
649-
)
650-
so_path = torch._inductor.aot_compile(unlifted_module, flat_example_inputs, options) # type: ignore[arg-type]
651-
652-
user_inputs = []
653-
user_outputs = []
654-
for node in unlifted_module.graph.nodes:
655-
if node.op == "placeholder":
656-
user_inputs.append(node.name)
657-
elif node.op == "output":
658-
user_outputs = [arg.name for arg in node.args[0]]
659-
660-
unlifted_ep = ExportedProgram(
661-
unlifted_module,
662-
unlifted_module.graph,
663-
ExportGraphSignature(
664-
[],
665-
[],
666-
user_inputs,
667-
user_outputs,
668-
{},
669-
{},
670-
{},
671-
None,
672-
),
673-
call_spec=copy.deepcopy(ep.call_spec),
674-
state_dict={},
675-
range_constraints=copy.deepcopy(ep.range_constraints),
676-
equality_constraints=copy.deepcopy(ep.equality_constraints),
677-
module_call_graph=ep.module_call_graph,
678-
)
679-
680-
return so_path, unlifted_ep
644+
so_path = torch._inductor.aot_compile(ep.graph_module, list(all_args), options)
645+
return so_path, ep

torch/_inductor/codecache.py

Lines changed: 17 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -302,13 +302,13 @@ def get_lock_dir():
302302
return lock_dir
303303

304304

305-
def code_hash(code: Union[str, bytes], extra: str = ""):
306-
hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
305+
def code_hash(code, extra: str = ""):
306+
hashing_str = code
307307
if extra != "":
308-
hashing_str = hashing_str + b"||" + extra.encode("utf-8")
308+
hashing_str = hashing_str + "||" + extra
309309
return (
310310
"c"
311-
+ base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
311+
+ base64.b32encode(hashlib.sha256(hashing_str.encode("utf-8")).digest())[:51]
312312
.decode("utf-8")
313313
.lower()
314314
)
@@ -656,10 +656,6 @@ def pick_vec_isa():
656656
return invalid_vec_isa
657657

658658

659-
def get_compile_only(compile_only=True):
660-
return "-c" if compile_only else ""
661-
662-
663659
def get_shared(shared=True):
664660
return "-shared -fPIC" if shared else ""
665661

@@ -888,7 +884,6 @@ def cpp_compile_command(
888884
vec_isa: VecISA = invalid_vec_isa,
889885
cuda=False,
890886
aot_mode=False,
891-
compile_only=False,
892887
):
893888
ipaths, lpaths, libs, macros = get_include_and_linking_paths(
894889
include_pytorch, vec_isa, cuda, aot_mode
@@ -918,20 +913,11 @@ def cpp_compile_command(
918913
{use_custom_generated_macros()}
919914
{use_fb_internal_macros()}
920915
{use_standard_sys_dir_headers()}
921-
{get_compile_only(compile_only)}
922916
-o {out_name}
923917
""",
924918
).strip()
925919

926920

927-
def run_command_and_check(cmd: str):
928-
cmd = shlex.split(cmd)
929-
try:
930-
subprocess.check_call(cmd)
931-
except subprocess.CalledProcessError as e:
932-
raise exc.CppCompileError(cmd, e.output) from e
933-
934-
935921
class CudaKernelParamCache:
936922
cache = dict()
937923
clear = staticmethod(cache.clear)
@@ -965,29 +951,12 @@ def compile(cls, graph, source_code, cuda):
965951
"i", "o", vec_isa=picked_vec_isa, cuda=cuda, aot_mode=graph.aot_mode
966952
)
967953
)
968-
if config.is_fbcode():
969-
ld_command = build_paths.ld()
970-
objcopy_command = build_paths.objcopy()
971-
else:
972-
ld_command = "ld"
973-
objcopy_command = "objcopy"
974954
key, input_path = write(
975955
source_code,
976956
"cpp",
977957
extra=cpp_command,
978958
specified_dir=config.aot_inductor_output_path,
979959
)
980-
981-
aot_constants = b""
982-
for tensor in graph.constants.values():
983-
aot_constants += bytes(tensor.untyped_storage().cpu())
984-
985-
consts_key, consts_path = write(
986-
aot_constants,
987-
"bin",
988-
specified_dir=config.aot_inductor_output_path,
989-
)
990-
991960
if key not in cls.cache:
992961
from filelock import FileLock
993962

@@ -997,61 +966,20 @@ def compile(cls, graph, source_code, cuda):
997966
output_so = os.path.splitext(input_path)[0] + ".so"
998967

999968
if not os.path.exists(output_so):
1000-
output_o = os.path.splitext(input_path)[0] + ".o"
1001-
cmd = cpp_compile_command(
1002-
input=input_path,
1003-
output=output_o,
1004-
vec_isa=picked_vec_isa,
1005-
cuda=cuda,
1006-
aot_mode=graph.aot_mode,
1007-
compile_only=True,
1008-
)
1009-
log.debug("aot compilation command: %s", cmd)
1010-
run_command_and_check(cmd)
1011-
1012-
consts_o = os.path.splitext(consts_path)[0] + ".o"
1013-
cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}"
1014-
run_command_and_check(cmd)
1015-
log.debug("aot constant binary command: %s", cmd)
1016-
1017-
cmd = (
1018-
f"{objcopy_command} --rename-section"
1019-
" .data=.lrodata,alloc,load,readonly,data,contents"
1020-
f" {consts_o} {consts_o}"
1021-
)
1022-
log.debug("aot constant obj command: %s", cmd)
1023-
run_command_and_check(cmd)
1024-
1025-
cmd = f"rm {consts_path}"
1026-
log.debug("aot constant bin removal command: %s", cmd)
1027-
run_command_and_check(cmd)
1028-
1029-
body = re.sub(r"[\W_]+", "_", consts_path)
1030-
symbol_list = []
1031-
symbol_list.append(
1032-
f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}"
1033-
)
1034-
symbol_list.append(
1035-
f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_size {consts_o}"
1036-
)
1037-
symbol_list.append(
1038-
f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}"
1039-
)
1040-
log.debug(
1041-
"aot constant binary redefine symbol: %s", " ".join(symbol_list)
1042-
)
1043-
for cmd in symbol_list:
1044-
run_command_and_check(cmd)
1045-
1046-
cmd = cpp_compile_command(
1047-
input=f"{output_o} {consts_o}",
1048-
output=output_so,
1049-
vec_isa=picked_vec_isa,
1050-
cuda=cuda,
1051-
aot_mode=graph.aot_mode,
969+
cmd = shlex.split(
970+
cpp_compile_command(
971+
input=input_path,
972+
output=output_so,
973+
vec_isa=picked_vec_isa,
974+
cuda=cuda,
975+
aot_mode=graph.aot_mode,
976+
)
1052977
)
1053-
log.debug("aot linkage command: %s", cmd)
1054-
run_command_and_check(cmd)
978+
log.debug("aot compilation command: %s", " ".join(cmd))
979+
try:
980+
subprocess.check_call(cmd)
981+
except subprocess.CalledProcessError as e:
982+
raise exc.CppCompileError(cmd, e.output) from e
1055983
else:
1056984
log.debug(
1057985
"aot_inductor dynamic library already exist: %s", output_so

0 commit comments

Comments
 (0)