Skip to content

Commit 457c76c

Browse files
pavithranraofacebook-github-bot
authored andcommitted
Extending _get_bytecode_version to support flatbuffers format (#75021)
Summary: Pull Request resolved: #75021 Extending `_get_bytecode_version` to support flatbuffers. ghstack-source-id: 152771695 (Note: this ignores all push blocking failures!) Test Plan: ``` ~/fbsource/xplat] cd ~/fbsource/xplat/ && buck test //xplat/caffe2:test_lite_interpreter Building: finished in 0.8 sec (100%) 327/327 jobs, 0/327 updated Total time: 0.9 sec Testing: finished in 06:59.5 min (85 PASS/0 FAIL) BUILD SUCCEEDED RESULTS FOR //xplat/caffe2:test_lite_interpreter PASS 412.3s 85 Passed 0 Skipped 0 Failed //xplat/caffe2:test_lite_interpreter TESTS PASSED ``` Reviewed By: iseeyuan Differential Revision: D34900498 fbshipit-source-id: 65743076d43a933c5381ec128d0268f22c0a8441
1 parent dcbd524 commit 457c76c

File tree

8 files changed

+106
-5
lines changed

8 files changed

+106
-5
lines changed

test/cpp/jit/test_flatbuffer.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,23 @@ TEST(FlatbufferTest, Inline) {
249249
AT_ASSERT(output.toTensor().item<float>() == 7.0);
250250
}
251251

252+
#if defined ENABLE_FLATBUFFER
253+
TEST(FlatbufferTest, GetByteCodeVersion) {
254+
Module m("m");
255+
m.define(R"(
256+
def forward(self, input: Tensor):
257+
return input + 1
258+
)");
259+
std::stringstream ss;
260+
m._save_for_mobile(ss, {}, false, /*use_flatbuffer=*/true);
261+
auto version = _get_model_bytecode_version(ss);
262+
AT_ASSERT(version == caffe2::serialize::kProducedBytecodeVersion);
263+
ss.seekg(0, ss.beg);
264+
auto version_again = _get_model_bytecode_version(ss);
265+
AT_ASSERT(version == version_again);
266+
}
267+
#endif
268+
252269
TEST(FlatbufferTest, Tuple) {
253270
Module m("m");
254271
m.define(R"JIT(

test/cpp/jit/test_lite_interpreter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,12 +656,14 @@ void backportAllVersionCheck(
656656

657657
// Check backport model version
658658
auto backport_version = _get_model_bytecode_version(oss);
659+
backport_version = _get_model_bytecode_version(oss);
659660
AT_ASSERT(backport_version == current_to_version);
660661

661662
// Load and run the backport model, then compare the result with expect
662663
// result
663664
runAndCheckBytecodeModel(
664665
oss, input_data, expect_result_list, current_to_version);
666+
oss.seekg(0, oss.beg);
665667
runAndCheckTorchScriptModel(
666668
oss, input_data, expect_result_list, current_to_version);
667669

torch/csrc/jit/mobile/compatibility/model_compatibility.cpp

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#include <caffe2/serialize/inline_container.h>
44
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
55
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
6+
#include <torch/csrc/jit/mobile/file_format.h>
7+
#if defined(ENABLE_FLATBUFFER)
8+
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
9+
#endif
610
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
711
#include <torch/csrc/jit/mobile/type_parser.h>
812
#include <torch/csrc/jit/serialization/import_export_constants.h>
@@ -69,13 +73,52 @@ uint64_t _get_model_bytecode_version(
6973
const std::vector<IValue>& bytecode_ivalues);
7074

7175
uint64_t _get_model_bytecode_version(std::istream& in) {
72-
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
73-
return _get_model_bytecode_version(std::move(rai));
76+
auto orig_pos = in.tellg();
77+
auto format = getFileFormat(in);
78+
switch (format) {
79+
case FileFormat::FlatbufferFileFormat: {
80+
#if !defined(ENABLE_FLATBUFFER)
81+
TORCH_CHECK(
82+
false,
83+
"Flatbuffer input file but the build hasn't enabled flatbuffer");
84+
#else
85+
return get_bytecode_version(in);
86+
#endif
87+
}
88+
case FileFormat::ZipFileFormat: {
89+
std::unique_ptr<IStreamAdapter> rai =
90+
std::make_unique<IStreamAdapter>(&in);
91+
auto version = _get_model_bytecode_version(std::move(rai));
92+
in.seekg(orig_pos, in.beg);
93+
return version;
94+
}
95+
96+
default:
97+
TORCH_CHECK(false, "Unrecognized data format");
98+
}
7499
}
75100

76101
uint64_t _get_model_bytecode_version(const std::string& filename) {
77-
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
78-
return _get_model_bytecode_version(std::move(rai));
102+
auto format = getFileFormat(filename);
103+
switch (format) {
104+
case FileFormat::FlatbufferFileFormat: {
105+
#if !defined(ENABLE_FLATBUFFER)
106+
TORCH_CHECK(
107+
false,
108+
"Flatbuffer input file but the build hasn't enabled flatbuffer");
109+
#else
110+
return get_bytecode_version(filename);
111+
#endif
112+
}
113+
case FileFormat::ZipFileFormat: {
114+
std::unique_ptr<FileAdapter> rai =
115+
std::make_unique<FileAdapter>(filename);
116+
return _get_model_bytecode_version(std::move(rai));
117+
}
118+
119+
default:
120+
TORCH_CHECK(false, "Unrecognized data format");
121+
}
79122
}
80123

81124
uint64_t _get_model_bytecode_version(

torch/csrc/jit/mobile/flatbuffer_loader.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,5 +663,27 @@ mobile::Module load_mobile_module_from_file(
663663
return parse_and_initialize_mobile_module(std::move(data), size, device);
664664
}
665665

666+
uint64_t get_bytecode_version(std::istream& in) {
667+
std::shared_ptr<char> data;
668+
size_t size = 0;
669+
std::tie(data, size) = get_stream_content(in);
670+
TORCH_CHECK(
671+
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
672+
"Format error");
673+
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
674+
return flatbuffer_module->bytecode_version();
675+
}
676+
677+
uint64_t get_bytecode_version(const std::string& filename) {
678+
std::shared_ptr<char> data;
679+
size_t size = 0;
680+
std::tie(data, size) = get_file_content(filename.c_str());
681+
TORCH_CHECK(
682+
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
683+
"Format error");
684+
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
685+
return flatbuffer_module->bytecode_version();
686+
}
687+
666688
} // namespace jit
667689
} // namespace torch

torch/csrc/jit/mobile/flatbuffer_loader.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
6262
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
6363
std::istream& in);
6464

65+
TORCH_API uint64_t get_bytecode_version(std::istream& in);
66+
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
67+
6568
class TORCH_API FlatbufferLoader {
6669
public:
6770
FlatbufferLoader();

torch/csrc/jit/mobile/module.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,21 @@ class TORCH_API Module {
135135
mem_to_delete_ = delete_mem;
136136
}
137137

138+
void set_bytecode_version(int64_t version) {
139+
bytecode_version_ = version;
140+
}
141+
142+
int64_t bytecode_version() const {
143+
return bytecode_version_;
144+
}
145+
138146
private:
139147
c10::intrusive_ptr<c10::ivalue::Object> object_;
140148
std::unordered_map<std::string, std::string> metadata_;
141149
std::shared_ptr<CompilationUnit> cu_;
142150
MobileDebugTable debug_table_;
143151
bool has_debug_handles_ = false;
152+
int64_t bytecode_version_;
144153

145154
// Extra handle for the module to delete when itself is deleted
146155
std::shared_ptr<char> mem_to_delete_;

torch/csrc/jit/serialization/export_bytecode.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,8 @@ mobile::Module jitModuleToMobile(
377377
backend_debug_info_map.begin(), backend_debug_info_map.end());
378378
m.setDebugTable(MobileDebugTable(
379379
debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()));
380+
381+
m.set_bytecode_version(options.model_version);
380382
return m;
381383
}
382384

torch/csrc/jit/serialization/flatbuffer_serializer.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,12 @@ flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
386386
jit_constants_indexes.emplace_back(storeIValueAndGetIndex(fbb, ival));
387387
}
388388

389+
const uint32_t bytecode_version =
390+
static_cast<uint32_t>(module.bytecode_version());
391+
389392
auto mod = CreateModule(
390393
fbb,
391-
0, /* version */
394+
/*bytecode_version=*/bytecode_version,
392395
extra_files_offset, /* extra_files */
393396
functions_offset,
394397
ivalue_index,

0 commit comments

Comments
 (0)