Skip to content

Commit c1ed6b9

Browse files
committed
model memory troubleshooting
1 parent ebf1df1 commit c1ed6b9

File tree

6 files changed

+73
-74
lines changed

6 files changed

+73
-74
lines changed

packages/mediapipe-core/lib/src/task_options.dart

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,83 @@ import 'third_party/mediapipe/generated/mediapipe_common_bindings.dart'
2121
/// classifier's desired behavior.
2222
class BaseOptions extends Equatable {
2323
/// Generative constructor that creates a [BaseOptions] instance.
24-
const BaseOptions({this.modelAssetBuffer, this.modelAssetPath})
25-
: assert(
24+
const BaseOptions._({
25+
this.modelAssetBuffer,
26+
this.modelAssetPath,
27+
this.modelAssetBufferCount,
28+
required _BaseOptionsType type,
29+
}) : assert(
2630
!(modelAssetBuffer == null && modelAssetPath == null),
2731
'You must supply either `modelAssetBuffer` or `modelAssetPath`',
2832
),
2933
assert(
3034
!(modelAssetBuffer != null && modelAssetPath != null),
3135
'You must only supply one of `modelAssetBuffer` and `modelAssetPath`',
32-
);
36+
),
37+
assert(
38+
(modelAssetBuffer == null) == (modelAssetBufferCount == null),
39+
'modelAssetBuffer and modelAssetBufferCount must only be submitted '
40+
'together',
41+
),
42+
_type = type;
43+
44+
/// Constructor for [BaseOptions] classes using a file system path.
45+
///
46+
/// In practice, this is unsupported, as assets in Flutter are bundled into
47+
/// the build output and not available on disk. However, it can potentially
48+
/// be helpful for testing / development purposes.
49+
factory BaseOptions.path(String path) => BaseOptions._(
50+
modelAssetPath: path,
51+
type: _BaseOptionsType.path,
52+
);
53+
54+
/// Constructor for [BaseOptions] classes using an in-memory pointer to the
55+
/// MediaPipe SDK.
56+
///
57+
/// In practice, this is the only option supported for production builds.
58+
factory BaseOptions.memory(Uint8List buffer) {
59+
return BaseOptions._(
60+
modelAssetBuffer: buffer,
61+
modelAssetBufferCount: buffer.lengthInBytes,
62+
type: _BaseOptionsType.memory,
63+
);
64+
}
3365

3466
/// The model asset file contents as bytes;
3567
final Uint8List? modelAssetBuffer;
3668

69+
/// The size of the model assets buffer (or `0` if not set).
70+
final int? modelAssetBufferCount;
71+
3772
/// Path to the model asset file.
3873
final String? modelAssetPath;
3974

75+
final _BaseOptionsType _type;
76+
4077
/// Converts this pure-Dart representation into C-memory suitable for the
4178
/// MediaPipe SDK to instantiate various classifiers.
4279
Pointer<bindings.BaseOptions> toStruct() {
4380
final struct = calloc<bindings.BaseOptions>();
4481

45-
if (modelAssetPath != null) {
82+
if (_type == _BaseOptionsType.path) {
4683
struct.ref.model_asset_path = prepareString(modelAssetPath!);
4784
}
48-
if (modelAssetBuffer != null) {
85+
if (_type == _BaseOptionsType.memory) {
4986
struct.ref.model_asset_buffer = prepareUint8List(modelAssetBuffer!);
5087
}
5188
return struct;
5289
}
5390

5491
@override
55-
List<Object?> get props => [modelAssetBuffer, modelAssetPath];
92+
List<Object?> get props => [
93+
modelAssetBuffer,
94+
modelAssetPath,
95+
modelAssetBufferCount,
96+
];
5697
}
5798

99+
enum _BaseOptionsType { path, memory }
100+
58101
/// Dart representation of MediaPipe's "ClassifierOptions" concept.
59102
///
60103
/// Classifier options shared across MediaPipe classification tasks.

packages/mediapipe-core/lib/src/third_party/mediapipe/generated/mediapipe_common_bindings.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ final class BaseOptions extends ffi.Struct {
2323
external ffi.Pointer<ffi.Char> model_asset_buffer;
2424

2525
external ffi.Pointer<ffi.Char> model_asset_path;
26+
27+
@ffi.Int()
28+
external int model_asset_buffer_count;
2629
}
2730

2831
final class __mbstate_t extends ffi.Union {

packages/mediapipe-core/test/task_options_test.dart

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,16 @@ import 'package:test/test.dart';
99
import 'package:mediapipe_core/mediapipe_core.dart';
1010

1111
void main() {
12-
group('BaseOptions constructor should', () {
13-
test('enforce exactly one of modelPath and modelBuffer', () {
14-
expect(
15-
() => BaseOptions(
16-
modelAssetPath: 'abc',
17-
modelAssetBuffer: Uint8List.fromList([1, 2, 3]),
18-
),
19-
throwsA(TypeMatcher<AssertionError>()),
20-
);
21-
22-
expect(BaseOptions.new, throwsA(TypeMatcher<AssertionError>()));
23-
});
24-
});
25-
2612
group('BaseOptions.toStruct/fromStruct should', () {
2713
test('allocate memory in C for a modelAssetPath', () {
28-
final options = BaseOptions(modelAssetPath: 'abc');
14+
final options = BaseOptions.path('abc');
2915
final struct = options.toStruct();
3016
expect(toDartString(struct.ref.model_asset_path), 'abc');
3117
expectNullPtr(struct.ref.model_asset_buffer);
3218
});
3319

3420
test('allocate memory in C for a modelAssetBuffer', () {
35-
final options = BaseOptions(
36-
modelAssetBuffer: Uint8List.fromList([1, 2, 3]),
37-
);
21+
final options = BaseOptions.memory(Uint8List.fromList([1, 2, 3]));
3822
final struct = options.toStruct();
3923
expect(
4024
toUint8List(struct.ref.model_asset_buffer),
@@ -44,9 +28,7 @@ void main() {
4428
});
4529

4630
test('allocate memory in C for a modelAssetBuffer containing 0', () {
47-
final options = BaseOptions(
48-
modelAssetBuffer: Uint8List.fromList([1, 2, 0, 3]),
49-
);
31+
final options = BaseOptions.memory(Uint8List.fromList([1, 2, 0, 3]));
5032
final struct = options.toStruct();
5133
expect(
5234
toUint8List(struct.ref.model_asset_buffer),

packages/mediapipe-core/third_party/mediapipe/tasks/c/core/base_options.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,17 @@ limitations under the License.
2020
extern "C" {
2121
#endif
2222

23-
// Base options for MediaPipe C Tasks.
24-
struct BaseOptions {
25-
// The model asset file contents as a string.
26-
char* model_asset_buffer;
27-
28-
// The path to the model asset to open and mmap in memory.
29-
char* model_asset_path;
30-
};
23+
// Base options for MediaPipe C Tasks.
24+
struct BaseOptions {
25+
// The model asset file contents as a string.
26+
const char *model_asset_buffer;
27+
28+
// The path to the model asset to open and mmap in memory.
29+
const char *model_asset_path;
30+
31+
// The size of the model assets buffer (or `0` if not set).
32+
const size_t model_asset_buffer_count;
33+
};
3134

3235
#ifdef __cplusplus
3336
} // extern C

packages/mediapipe-task-text/example/lib/main.dart

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import 'dart:io' as io;
22
import 'dart:typed_data';
33
import 'package:logging/logging.dart';
4-
import 'package:path/path.dart' as path;
54
import 'package:flutter/material.dart';
65
import 'package:mediapipe_text/mediapipe_text.dart';
7-
import 'package:path_provider/path_provider.dart';
86

97
final _log = Logger('TextClassificationExample');
108

@@ -33,54 +31,24 @@ class _MainAppState extends State<MainApp> {
3331
late final TextClassifier _classifier;
3432
final TextEditingController _controller = TextEditingController();
3533
String? results;
34+
late final ByteData classifierBytes;
3635

3736
@override
3837
void initState() {
3938
super.initState();
4039
_controller.text = 'Hello, world!';
4140
_initClassifier();
41+
Future.delayed(const Duration(milliseconds: 500)).then((_) => _classify());
4242
}
4343

4444
Future<void> _initClassifier() async {
45-
// getApplicationDocumentsDirectory().then((dir) {
46-
// print(dir.absolute.path);
47-
// });
48-
// final dir = await getApplicationSupportDirectory();
49-
// print('app support: ${dir.absolute.path}');
50-
51-
// DefaultAssetBundle.of(context).
52-
53-
final ByteData classifierBytes = await DefaultAssetBundle.of(context)
45+
classifierBytes = await DefaultAssetBundle.of(context)
5446
.load('assets/bert_classifier.tflite');
5547

56-
// final dir = io.Directory(path.current);
57-
// final modelPath = path.joinAll(
58-
// [dir.absolute.path, 'assets/bert_classifier.tflite'],
59-
// );
60-
// _log.finest('modelPath: $modelPath');
61-
// if (io.File(modelPath).existsSync()) {
62-
// _log.fine('Successfully found model.');
63-
// } else {
64-
// _log.severe('Invalid model path \n\t$modelPath.\n\nModel not found.');
65-
// io.exit(1);
66-
// }
67-
68-
// final sdkPath = path.joinAll(
69-
// [dir.absolute.path, 'assets/libtext_classifier.dylib'],
70-
// );
71-
// _log.finest('sdkPath: $sdkPath');
72-
// if (io.File(sdkPath).existsSync()) {
73-
// _log.fine('Successfully found SDK.');
74-
// } else {
75-
// _log.severe('Invalid SDK path $sdkPath. SDK not found.');
76-
// io.exit(1);
77-
// }
78-
7948
_classifier = TextClassifier(
80-
// options: TextClassifierOptions.fromAssetPath(modelPath),
8149
options: TextClassifierOptions.fromAssetBuffer(
82-
Uint8List.view(classifierBytes.buffer)),
83-
// sdkPath: sdkPath,
50+
classifierBytes.buffer.asUint8List(),
51+
),
8452
);
8553
}
8654

packages/mediapipe-task-text/lib/src/tasks/text_classification/containers/text_classifier_options.dart

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class TextClassifierOptions {
2727
}) {
2828
assert(!kIsWeb, 'fromAssetPath cannot be used on the web');
2929
return TextClassifierOptions(
30-
baseOptions: BaseOptions(modelAssetPath: assetPath),
30+
baseOptions: BaseOptions.path(assetPath),
3131
classifierOptions: classifierOptions,
3232
);
3333
}
@@ -40,7 +40,7 @@ class TextClassifierOptions {
4040
ClassifierOptions? classifierOptions,
4141
}) =>
4242
TextClassifierOptions(
43-
baseOptions: BaseOptions(modelAssetBuffer: assetBuffer),
43+
baseOptions: BaseOptions.memory(assetBuffer),
4444
classifierOptions: classifierOptions,
4545
);
4646

0 commit comments

Comments
 (0)