Skip to content

Commit 1deb532

Browse files
annzimmerrlazo
andauthored
Adding get Model for local downloads (#2212)
* Adding get Model for local downloads - only handles local model and unconditional downloads at this time. * Adding loadNewlyDownloadedFiles to listModels. * Adding loadNewlyDownloadedFiles to listModels. * Adding loadNewlyDownloadedFiles to listModels. * Update formatting. * Update firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java Co-authored-by: Rodrigo Lazo <[email protected]> * Update firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java Co-authored-by: Rodrigo Lazo <[email protected]> * Update firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java Co-authored-by: Rodrigo Lazo <[email protected]> * Updating formatting after reviewer sugggested changes. * Updating firebaseInstallationApi component to avoid error message. * Updating firebaseInstallationApi component to avoid error message. Co-authored-by: Rodrigo Lazo <[email protected]>
1 parent 837d136 commit 1deb532

17 files changed

+1465
-69
lines changed

firebase-ml-modeldownloader/api.txt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package com.google.firebase.ml.modeldownloader {
33

44
public class CustomModel {
55
method public long getDownloadId();
6-
method @Nullable public java.io.File getFile();
6+
method @Nullable public java.io.File getFile() throws java.lang.Exception;
77
method @NonNull public String getModelHash();
88
method @NonNull public String getName();
99
method public long getSize();
@@ -33,9 +33,19 @@ package com.google.firebase.ml.modeldownloader {
3333
method @NonNull public com.google.android.gms.tasks.Task<java.lang.Void> deleteDownloadedModel(@NonNull String);
3434
method @NonNull public static com.google.firebase.ml.modeldownloader.FirebaseModelDownloader getInstance();
3535
method @NonNull public static com.google.firebase.ml.modeldownloader.FirebaseModelDownloader getInstance(@NonNull com.google.firebase.FirebaseApp);
36-
method @NonNull public com.google.android.gms.tasks.Task<com.google.firebase.ml.modeldownloader.CustomModel> getModel(@NonNull String, @NonNull com.google.firebase.ml.modeldownloader.DownloadType, @Nullable com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions);
36+
method @NonNull public com.google.android.gms.tasks.Task<com.google.firebase.ml.modeldownloader.CustomModel> getModel(@NonNull String, @NonNull com.google.firebase.ml.modeldownloader.DownloadType, @Nullable com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions) throws java.lang.Exception;
3737
method @NonNull public com.google.android.gms.tasks.Task<java.util.Set<com.google.firebase.ml.modeldownloader.CustomModel>> listDownloadedModels();
3838
}
3939

4040
}
4141

42+
package com.google.firebase.ml.modeldownloader.internal {
43+
44+
public class ModelFileManager {
45+
ctor public ModelFileManager(@NonNull com.google.firebase.FirebaseApp);
46+
method @NonNull public static com.google.firebase.ml.modeldownloader.internal.ModelFileManager getInstance();
47+
method @Nullable @WorkerThread public java.io.File moveModelToDestinationFolder(@NonNull com.google.firebase.ml.modeldownloader.CustomModel, @NonNull android.os.ParcelFileDescriptor) throws java.lang.Exception;
48+
}
49+
50+
}
51+

firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ android {
2525
compileSdkVersion project.targetSdkVersion
2626

2727
defaultConfig {
28-
minSdkVersion project.minSdkVersion
28+
minSdkVersion 16
2929
targetSdkVersion project.targetSdkVersion
3030
multiDexEnabled true
3131
versionName version

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
import androidx.annotation.NonNull;
1818
import androidx.annotation.Nullable;
19+
import androidx.annotation.VisibleForTesting;
1920
import com.google.android.gms.common.internal.Objects;
21+
import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
2022
import java.io.File;
2123

2224
/**
@@ -127,11 +129,30 @@ public String getName() {
127129
* progress, returns null, if file update is in progress returns last fully uploaded model.
128130
*/
129131
@Nullable
130-
public File getFile() {
132+
public File getFile() throws Exception {
133+
return getFile(ModelFileDownloadService.getInstance());
134+
}
135+
136+
/**
137+
* The local model file. If null is returned, use the download Id to check the download status.
138+
*
139+
* @return the local file associated with the model. If the original file download is still in
140+
* progress, returns null. If file update is in progress, returns the last fully uploaded
141+
* model.
142+
*/
143+
@Nullable
144+
@VisibleForTesting
145+
File getFile(ModelFileDownloadService fileDownloadService) throws Exception {
146+
// check for completed download
147+
File newDownloadFile = fileDownloadService.loadNewlyDownloadedModelFile(this);
148+
if (newDownloadFile != null) {
149+
return newDownloadFile;
150+
}
151+
// return local file, if present
131152
if (localFilePath == null || localFilePath.isEmpty()) {
132153
return null;
133154
}
134-
throw new UnsupportedOperationException("Not implemented, file retrieval coming soon.");
155+
return new File(localFilePath);
135156
}
136157

137158
/**
@@ -144,7 +165,11 @@ public long getSize() {
144165
return fileSize;
145166
}
146167

147-
/** @return the model hash */
168+
/**
169+
* Retrieves the model Hash.
170+
*
171+
* @return the model hash
172+
*/
148173
@NonNull
149174
public String getModelHash() {
150175
return modelHash;
@@ -161,6 +186,31 @@ public long getDownloadId() {
161186
return downloadId;
162187
}
163188

189+
@NonNull
190+
@Override
191+
public String toString() {
192+
Objects.ToStringHelper stringHelper =
193+
Objects.toStringHelper(this)
194+
.add("name", name)
195+
.add("modelHash", modelHash)
196+
.add("fileSize", fileSize);
197+
198+
if (localFilePath != null && !localFilePath.isEmpty()) {
199+
stringHelper.add("localFilePath", localFilePath);
200+
}
201+
if (downloadId != 0L) {
202+
stringHelper.add("downloadId", downloadId);
203+
}
204+
if (downloadUrl != null && !downloadUrl.isEmpty()) {
205+
stringHelper.add("downloadUrl", downloadUrl);
206+
}
207+
if (downloadUrlExpiry != 0L && !localFilePath.isEmpty()) {
208+
stringHelper.add("downloadUrlExpiry", downloadUrlExpiry);
209+
}
210+
211+
return stringHelper.toString();
212+
}
213+
164214
@Override
165215
public boolean equals(Object o) {
166216
if (o == this) {
@@ -200,6 +250,8 @@ public long getDownloadUrlExpiry() {
200250
}
201251

202252
/**
253+
* Returns the model download url, usually only present when download is about to occur.
254+
*
203255
* @return the model download url
204256
* <p>Internal use only
205257
* @hide

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,20 @@
1313
// limitations under the License.
1414
package com.google.firebase.ml.modeldownloader;
1515

16+
import android.os.Build.VERSION_CODES;
1617
import androidx.annotation.NonNull;
1718
import androidx.annotation.Nullable;
19+
import androidx.annotation.RequiresApi;
1820
import androidx.annotation.VisibleForTesting;
1921
import com.google.android.gms.common.internal.Preconditions;
2022
import com.google.android.gms.tasks.Task;
2123
import com.google.android.gms.tasks.TaskCompletionSource;
24+
import com.google.android.gms.tasks.Tasks;
2225
import com.google.firebase.FirebaseApp;
2326
import com.google.firebase.FirebaseOptions;
27+
import com.google.firebase.installations.FirebaseInstallationsApi;
28+
import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService;
29+
import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
2430
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
2531
import java.util.Set;
2632
import java.util.concurrent.Executor;
@@ -30,21 +36,32 @@ public class FirebaseModelDownloader {
3036

3137
private final FirebaseOptions firebaseOptions;
3238
private final SharedPreferencesUtil sharedPreferencesUtil;
39+
private final ModelFileDownloadService fileDownloadService;
40+
private final CustomModelDownloadService modelDownloadService;
3341
private final Executor executor;
3442

35-
FirebaseModelDownloader(FirebaseApp firebaseApp) {
43+
@RequiresApi(api = VERSION_CODES.KITKAT)
44+
FirebaseModelDownloader(
45+
FirebaseApp firebaseApp, FirebaseInstallationsApi firebaseInstallationsApi) {
3646
this.firebaseOptions = firebaseApp.getOptions();
47+
this.fileDownloadService = new ModelFileDownloadService(firebaseApp);
3748
this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp);
49+
this.modelDownloadService =
50+
new CustomModelDownloadService(firebaseOptions, firebaseInstallationsApi);
3851
this.executor = Executors.newCachedThreadPool();
3952
}
4053

4154
@VisibleForTesting
4255
FirebaseModelDownloader(
4356
FirebaseOptions firebaseOptions,
4457
SharedPreferencesUtil sharedPreferencesUtil,
58+
ModelFileDownloadService fileDownloadService,
59+
CustomModelDownloadService modelDownloadService,
4560
Executor executor) {
4661
this.firebaseOptions = firebaseOptions;
4762
this.sharedPreferencesUtil = sharedPreferencesUtil;
63+
this.fileDownloadService = fileDownloadService;
64+
this.modelDownloadService = modelDownloadService;
4865
this.executor = executor;
4966
}
5067

@@ -95,21 +112,77 @@ public static FirebaseModelDownloader getInstance(@NonNull FirebaseApp app) {
95112
public Task<CustomModel> getModel(
96113
@NonNull String modelName,
97114
@NonNull DownloadType downloadType,
98-
@Nullable CustomModelDownloadConditions conditions) {
115+
@Nullable CustomModelDownloadConditions conditions)
116+
throws Exception {
117+
CustomModel localModel = sharedPreferencesUtil.getCustomModelDetails(modelName);
118+
switch (downloadType) {
119+
case LOCAL_MODEL:
120+
if (localModel != null) {
121+
return Tasks.forResult(localModel);
122+
}
123+
Task<CustomModel> modelDetails =
124+
modelDownloadService.getCustomModelDetails(
125+
firebaseOptions.getProjectId(), modelName, null);
126+
127+
// no local model - start download.
128+
return modelDetails.continueWithTask(
129+
executor,
130+
modelDetailTask -> {
131+
if (modelDetailTask.isSuccessful()) {
132+
// start download
133+
return fileDownloadService
134+
.download(modelDetailTask.getResult(), conditions)
135+
.continueWithTask(
136+
executor,
137+
downloadTask -> {
138+
if (downloadTask.isSuccessful()) {
139+
// read the updated model
140+
CustomModel downloadedModel =
141+
sharedPreferencesUtil.getCustomModelDetails(modelName);
142+
// TODO(annz) trigger file move here as well... right now it's temp
143+
// call loadNewlyDownloadedModelFile
144+
return Tasks.forResult(downloadedModel);
145+
}
146+
return Tasks.forException(new Exception("File download failed."));
147+
});
148+
}
149+
return Tasks.forException(modelDetailTask.getException());
150+
});
151+
case LATEST_MODEL:
152+
// check for latest model and download newest
153+
break;
154+
case LOCAL_MODEL_UPDATE_IN_BACKGROUND:
155+
// start download in back ground return current model if not null.
156+
break;
157+
}
99158
throw new UnsupportedOperationException("Not yet implemented.");
100159
}
101160

102-
/** @return The set of all models that are downloaded to this device. */
161+
/**
162+
* Triggers the move to permanent storage of successful model downloads and lists all models
163+
* downloaded to device.
164+
*
165+
* @return The set of all models that are downloaded to this device, triggers completion of file
166+
* moves for completed model downloads.
167+
*/
103168
@NonNull
104169
public Task<Set<CustomModel>> listDownloadedModels() {
170+
// trigger completion of file moves for download files.
171+
try {
172+
fileDownloadService.maybeCheckDownloadingComplete();
173+
} catch (Exception ex) {
174+
System.out.println("Error checking for in progress downloads: " + ex.getMessage());
175+
}
176+
105177
TaskCompletionSource<Set<CustomModel>> taskCompletionSource = new TaskCompletionSource<>();
106178
executor.execute(
107179
() -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels()));
108180
return taskCompletionSource.getTask();
109181
}
110182

111-
/*
183+
/**
112184
* Delete old local models, when no longer in use.
185+
*
113186
* @param modelName - name of the model
114187
*/
115188
@NonNull

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import com.google.firebase.components.Dependency;
2525
import com.google.firebase.installations.FirebaseInstallationsApi;
2626
import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService;
27+
import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
28+
import com.google.firebase.ml.modeldownloader.internal.ModelFileManager;
2729
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
2830
import com.google.firebase.platforminfo.LibraryVersionComponent;
2931
import java.util.Arrays;
@@ -44,12 +46,24 @@ public List<Component<?>> getComponents() {
4446
return Arrays.asList(
4547
Component.builder(FirebaseModelDownloader.class)
4648
.add(Dependency.required(FirebaseApp.class))
47-
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseApp.class)))
49+
.add(Dependency.required(FirebaseInstallationsApi.class))
50+
.factory(
51+
c ->
52+
new FirebaseModelDownloader(
53+
c.get(FirebaseApp.class), c.get(FirebaseInstallationsApi.class)))
4854
.build(),
4955
Component.builder(SharedPreferencesUtil.class)
5056
.add(Dependency.required(FirebaseApp.class))
5157
.factory(c -> new SharedPreferencesUtil(c.get(FirebaseApp.class)))
5258
.build(),
59+
Component.builder(ModelFileManager.class)
60+
.add(Dependency.required(FirebaseApp.class))
61+
.factory(c -> new ModelFileManager(c.get(FirebaseApp.class)))
62+
.build(),
63+
Component.builder(ModelFileDownloadService.class)
64+
.add(Dependency.required(FirebaseApp.class))
65+
.factory(c -> new ModelFileDownloadService(c.get(FirebaseApp.class)))
66+
.build(),
5367
Component.builder(CustomModelDownloadService.class)
5468
.add(Dependency.required(FirebaseOptions.class))
5569
.add(Dependency.required(FirebaseInstallationsApi.class))

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@
1414

1515
package com.google.firebase.ml.modeldownloader.internal;
1616

17-
import android.os.Build.VERSION_CODES;
1817
import android.util.JsonReader;
1918
import android.util.Log;
2019
import androidx.annotation.NonNull;
2120
import androidx.annotation.Nullable;
22-
import androidx.annotation.RequiresApi;
2321
import com.google.android.gms.common.util.VisibleForTesting;
2422
import com.google.android.gms.tasks.Task;
2523
import com.google.android.gms.tasks.Tasks;
@@ -34,7 +32,6 @@
3432
import java.net.HttpURLConnection;
3533
import java.net.URL;
3634
import java.nio.charset.Charset;
37-
import java.nio.charset.StandardCharsets;
3835
import java.text.ParseException;
3936
import java.text.SimpleDateFormat;
4037
import java.util.Date;
@@ -51,12 +48,10 @@
5148
*
5249
* @hide
5350
*/
54-
@RequiresApi(api = VERSION_CODES.KITKAT)
55-
public final class CustomModelDownloadService {
56-
51+
public class CustomModelDownloadService {
5752
private static final String TAG = "CustomModelDownloadSer";
5853
private static final int CONNECTION_TIME_OUT_MS = 2000; // 2 seconds.
59-
private static final Charset UTF_8 = StandardCharsets.UTF_8;
54+
private static final Charset UTF_8 = Charset.forName("UTF-8");
6055
private static final String ISO_DATE_PATTERN = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'";
6156
private static final String ACCEPT_ENCODING_HEADER_KEY = "Accept-Encoding";
6257
private static final String CONTENT_ENCODING_HEADER_KEY = "Content-Encoding";
@@ -76,8 +71,8 @@ public final class CustomModelDownloadService {
7671
static final String DOWNLOAD_MODEL_REGEX = "%s/v1beta2/projects/%s/models/%s:download";
7772

7873
private final ExecutorService executorService;
79-
private FirebaseInstallationsApi firebaseInstallations;
80-
private String apiKey;
74+
private final FirebaseInstallationsApi firebaseInstallations;
75+
private final String apiKey;
8176
private String downloadHost = FIREBASE_DOWNLOAD_HOST;
8277

8378
public CustomModelDownloadService(

0 commit comments

Comments
 (0)