Skip to content

Commit b794c46

Browse files
committed
Adds DownloadModelCommand
1 parent 271c244 commit b794c46

File tree

5 files changed

+171
-15
lines changed

5 files changed

+171
-15
lines changed

tool/builder/bin/main.dart

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,24 @@
22
// Use of this source code is governed by a BSD-style license that can be
33
// found in the LICENSE file.
44

5+
import 'dart:io' as io;
56
import 'package:args/command_runner.dart';
7+
import 'package:builder/download_model.dart';
68
import 'package:builder/sync_headers.dart';
9+
import 'package:logging/logging.dart';
710

811
final runner = CommandRunner(
912
'build',
1013
'Performs build operations for google/flutter-mediapipe that '
1114
'depend on contents in this repository',
12-
)..addCommand(SyncHeadersCommand());
15+
)
16+
..addCommand(SyncHeadersCommand())
17+
..addCommand(DownloadModelCommand());
1318

14-
void main(List<String> arguments) => runner.run(arguments);
19+
void main(List<String> arguments) {
20+
Logger.root.onRecord.listen((LogRecord record) {
21+
io.stdout
22+
.writeln('${record.level.name}: ${record.time}: ${record.message}');
23+
});
24+
runner.run(arguments);
25+
}

tool/builder/lib/download_model.dart

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright 2014 The Flutter Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style license that can be
3+
// found in the LICENSE file.
4+
5+
import 'dart:io' as io;
6+
import 'package:args/command_runner.dart';
7+
import 'package:builder/repo_finder.dart';
8+
import 'package:http/http.dart' as http;
9+
import 'package:logging/logging.dart';
10+
import 'package:path/path.dart' as path;
11+
12+
final _log = Logger('DownloadModelCommand');
13+
14+
enum Models {
15+
textclassification,
16+
languagedetection,
17+
}
18+
19+
class DownloadModelCommand extends Command with RepoFinderMixin {
20+
@override
21+
String description = 'Downloads a given MediaPipe model and places it in '
22+
'the designated location.';
23+
@override
24+
String name = 'model';
25+
26+
DownloadModelCommand() {
27+
argParser
28+
..addOption(
29+
'model',
30+
abbr: 'm',
31+
allowed: [
32+
// Values will be added to this as the repository gets more
33+
// integration tests that require new models.
34+
Models.textclassification.name,
35+
Models.languagedetection.name,
36+
],
37+
help: 'The desired model to download. Use this option if you want the '
38+
'standard model for a given task. Using this option also removes any '
39+
'need to use the `destination` option, as a value here implies a '
40+
'destination. However, you still can specify a destination to '
41+
'override the default location where the model is placed.\n'
42+
'\n'
43+
'Note: Either this or `custommodel` must be used. If both are '
44+
'supplied, `model` is used.',
45+
)
46+
..addOption(
47+
'custommodel',
48+
abbr: 'c',
49+
help: 'The desired model to download. Use this option if you want to '
50+
'specify a specific and nonstandard model. Using this option means '
51+
'you *must* use the `destination` option.\n'
52+
'\n'
53+
'Note: Either this or `model` must be used. If both are supplied, '
54+
'`model` is used.',
55+
)
56+
..addOption(
57+
'destination',
58+
abbr: 'd',
59+
help:
60+
'The location to place the downloaded model. This value is required '
61+
'if you use the `custommodel` option, but optional if you use the '
62+
'`model` option.',
63+
);
64+
}
65+
66+
static final Map<String, String> _standardModelSources = {
67+
Models.textclassification.name:
68+
'https://storage.googleapis.com/mediapipe-models/text_classifier/bert_classifier/float32/1/bert_classifier.tflite',
69+
Models.languagedetection.name:
70+
'https://storage.googleapis.com/mediapipe-models/language_detector/language_detector/float32/1/language_detector.tflite',
71+
};
72+
73+
static final Map<String, String> _standardModelDestinations = {
74+
Models.textclassification.name:
75+
'packages/mediapipe-task-text/example/assets/',
76+
Models.languagedetection.name:
77+
'packages/mediapipe-task-text/example/assets/',
78+
};
79+
80+
@override
81+
Future<void> run() async {
82+
final io.Directory flutterMediaPipeDirectory = findFlutterMediaPipeRoot();
83+
84+
late final String modelSource;
85+
late final String modelDestination;
86+
87+
if (argResults!['model'] != null) {
88+
modelSource = _standardModelSources[argResults!['model']]!;
89+
modelDestination = (_isArgProvided(argResults!['destination']))
90+
? argResults!['destination']
91+
: _standardModelDestinations[argResults!['model']]!;
92+
} else {
93+
if (argResults!['custommodel'] == null) {
94+
throw Exception(
95+
'You must use either the `model` or `custommodel` option.',
96+
);
97+
}
98+
if (argResults!['destination'] == null) {
99+
throw Exception(
100+
'If you do not use the `model` option, then you must supply a '
101+
'`destination`, as a "standard" destination cannot be used.',
102+
);
103+
}
104+
modelSource = argResults!['custommodel'];
105+
modelDestination = argResults!['destination'];
106+
}
107+
108+
io.File destinationFile = io.File(
109+
path.joinAll([
110+
flutterMediaPipeDirectory.absolute.path,
111+
modelDestination,
112+
modelSource.split('/').last,
113+
]),
114+
);
115+
ensureFolders(destinationFile);
116+
await downloadModel(modelSource, destinationFile);
117+
}
118+
119+
Future<void> downloadModel(
120+
String modelSource,
121+
io.File destinationFile,
122+
) async {
123+
_log.info('Downloading $modelSource');
124+
125+
// TODO(craiglabenz): Convert to StreamedResponse
126+
final response = await http.get(Uri.parse(modelSource));
127+
128+
if (response.statusCode != 200) {
129+
throw Exception('${response.statusCode} ${response.reasonPhrase} :: '
130+
'$modelSource');
131+
}
132+
133+
if (!(await destinationFile.exists())) {
134+
_log.fine('Creating file at ${destinationFile.absolute.path}');
135+
await destinationFile.create();
136+
}
137+
138+
_log.fine('Downloaded ${response.contentLength} bytes');
139+
_log.info('Saving to ${destinationFile.absolute.path}');
140+
await destinationFile.writeAsBytes(response.bodyBytes);
141+
}
142+
}
143+
144+
bool _isArgProvided(String? val) => val != null && val != '';

tool/builder/lib/repo_finder.dart

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,17 @@ mixin RepoFinderMixin on Command {
104104
),
105105
).existsSync();
106106
}
107+
108+
/// Builds any missing folders between the file and the root of the repository
109+
void ensureFolders(io.File file) {
110+
io.Directory parent = file.parent;
111+
List<io.Directory> dirsToCreate = [];
112+
while (!parent.existsSync()) {
113+
dirsToCreate.add(parent);
114+
parent = parent.parent;
115+
}
116+
for (io.Directory dir in dirsToCreate.reversed) {
117+
dir.createSync();
118+
}
119+
}
107120
}

tool/builder/lib/sync_headers.dart

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,6 @@ class SyncHeadersCommand extends Command with RepoFinderMixin {
133133
}
134134
}
135135
}
136-
137-
/// Builds any missing folders between the file and the root of the repository
138-
void ensureFolders(io.File file) {
139-
io.Directory parent = file.parent;
140-
List<io.Directory> dirsToCreate = [];
141-
while (!parent.existsSync()) {
142-
dirsToCreate.add(parent);
143-
parent = parent.parent;
144-
}
145-
for (io.Directory dir in dirsToCreate.reversed) {
146-
dir.createSync();
147-
}
148-
}
149136
}
150137

151138
class Options {

tool/builder/pubspec.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ environment:
99
# Add regular dependencies here.
1010
dependencies:
1111
args: ^2.4.2
12+
http: ^1.1.0
1213
io: ^1.0.4
1314
logging: ^1.2.0
1415
path: ^1.8.0

0 commit comments

Comments
 (0)