|
| 1 | +// Copyright 2014 The Flutter Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style license that can be |
| 3 | +// found in the LICENSE file. |
| 4 | + |
| 5 | +import 'dart:io' as io; |
| 6 | +import 'package:args/command_runner.dart'; |
| 7 | +import 'package:builder/repo_finder.dart'; |
| 8 | +import 'package:http/http.dart' as http; |
| 9 | +import 'package:logging/logging.dart'; |
| 10 | +import 'package:path/path.dart' as path; |
| 11 | + |
| 12 | +final _log = Logger('DownloadModelCommand'); |
| 13 | + |
| 14 | +enum Models { |
| 15 | + textclassification, |
| 16 | + languagedetection, |
| 17 | +} |
| 18 | + |
| 19 | +class DownloadModelCommand extends Command with RepoFinderMixin { |
| 20 | + @override |
| 21 | + String description = 'Downloads a given MediaPipe model and places it in ' |
| 22 | + 'the designated location.'; |
| 23 | + @override |
| 24 | + String name = 'model'; |
| 25 | + |
| 26 | + DownloadModelCommand() { |
| 27 | + argParser |
| 28 | + ..addOption( |
| 29 | + 'model', |
| 30 | + abbr: 'm', |
| 31 | + allowed: [ |
| 32 | + // Values will be added to this as the repository gets more |
| 33 | + // integration tests that require new models. |
| 34 | + Models.textclassification.name, |
| 35 | + Models.languagedetection.name, |
| 36 | + ], |
| 37 | + help: 'The desired model to download. Use this option if you want the ' |
| 38 | + 'standard model for a given task. Using this option also removes any ' |
| 39 | + 'need to use the `destination` option, as a value here implies a ' |
| 40 | + 'destination. However, you still can specify a destination to ' |
| 41 | + 'override the default location where the model is placed.\n' |
| 42 | + '\n' |
| 43 | + 'Note: Either this or `custommodel` must be used. If both are ' |
| 44 | + 'supplied, `model` is used.', |
| 45 | + ) |
| 46 | + ..addOption( |
| 47 | + 'custommodel', |
| 48 | + abbr: 'c', |
| 49 | + help: 'The desired model to download. Use this option if you want to ' |
| 50 | + 'specify a specific and nonstandard model. Using this option means ' |
| 51 | + 'you *must* use the `destination` option.\n' |
| 52 | + '\n' |
| 53 | + 'Note: Either this or `model` must be used. If both are supplied, ' |
| 54 | + '`model` is used.', |
| 55 | + ) |
| 56 | + ..addOption( |
| 57 | + 'destination', |
| 58 | + abbr: 'd', |
| 59 | + help: |
| 60 | + 'The location to place the downloaded model. This value is required ' |
| 61 | + 'if you use the `custommodel` option, but optional if you use the ' |
| 62 | + '`model` option.', |
| 63 | + ); |
| 64 | + } |
| 65 | + |
| 66 | + static final Map<String, String> _standardModelSources = { |
| 67 | + Models.textclassification.name: |
| 68 | + 'https://storage.googleapis.com/mediapipe-models/text_classifier/bert_classifier/float32/1/bert_classifier.tflite', |
| 69 | + Models.languagedetection.name: |
| 70 | + 'https://storage.googleapis.com/mediapipe-models/language_detector/language_detector/float32/1/language_detector.tflite', |
| 71 | + }; |
| 72 | + |
| 73 | + static final Map<String, String> _standardModelDestinations = { |
| 74 | + Models.textclassification.name: |
| 75 | + 'packages/mediapipe-task-text/example/assets/', |
| 76 | + Models.languagedetection.name: |
| 77 | + 'packages/mediapipe-task-text/example/assets/', |
| 78 | + }; |
| 79 | + |
| 80 | + @override |
| 81 | + Future<void> run() async { |
| 82 | + final io.Directory flutterMediaPipeDirectory = findFlutterMediaPipeRoot(); |
| 83 | + |
| 84 | + late final String modelSource; |
| 85 | + late final String modelDestination; |
| 86 | + |
| 87 | + if (argResults!['model'] != null) { |
| 88 | + modelSource = _standardModelSources[argResults!['model']]!; |
| 89 | + modelDestination = (_isArgProvided(argResults!['destination'])) |
| 90 | + ? argResults!['destination'] |
| 91 | + : _standardModelDestinations[argResults!['model']]!; |
| 92 | + } else { |
| 93 | + if (argResults!['custommodel'] == null) { |
| 94 | + throw Exception( |
| 95 | + 'You must use either the `model` or `custommodel` option.', |
| 96 | + ); |
| 97 | + } |
| 98 | + if (argResults!['destination'] == null) { |
| 99 | + throw Exception( |
| 100 | + 'If you do not use the `model` option, then you must supply a ' |
| 101 | + '`destination`, as a "standard" destination cannot be used.', |
| 102 | + ); |
| 103 | + } |
| 104 | + modelSource = argResults!['custommodel']; |
| 105 | + modelDestination = argResults!['destination']; |
| 106 | + } |
| 107 | + |
| 108 | + io.File destinationFile = io.File( |
| 109 | + path.joinAll([ |
| 110 | + flutterMediaPipeDirectory.absolute.path, |
| 111 | + modelDestination, |
| 112 | + modelSource.split('/').last, |
| 113 | + ]), |
| 114 | + ); |
| 115 | + ensureFolders(destinationFile); |
| 116 | + await downloadModel(modelSource, destinationFile); |
| 117 | + } |
| 118 | + |
| 119 | + Future<void> downloadModel( |
| 120 | + String modelSource, |
| 121 | + io.File destinationFile, |
| 122 | + ) async { |
| 123 | + _log.info('Downloading $modelSource'); |
| 124 | + |
| 125 | + // TODO(craiglabenz): Convert to StreamedResponse |
| 126 | + final response = await http.get(Uri.parse(modelSource)); |
| 127 | + |
| 128 | + if (response.statusCode != 200) { |
| 129 | + throw Exception('${response.statusCode} ${response.reasonPhrase} :: ' |
| 130 | + '$modelSource'); |
| 131 | + } |
| 132 | + |
| 133 | + if (!(await destinationFile.exists())) { |
| 134 | + _log.fine('Creating file at ${destinationFile.absolute.path}'); |
| 135 | + await destinationFile.create(); |
| 136 | + } |
| 137 | + |
| 138 | + _log.fine('Downloaded ${response.contentLength} bytes'); |
| 139 | + _log.info('Saving to ${destinationFile.absolute.path}'); |
| 140 | + await destinationFile.writeAsBytes(response.bodyBytes); |
| 141 | + } |
| 142 | +} |
| 143 | + |
| 144 | +bool _isArgProvided(String? val) => val != null && val != ''; |
0 commit comments