Skip to content

Adding get Model for local downloads #2212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions firebase-ml-modeldownloader/api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.google.firebase.ml.modeldownloader {

public class CustomModel {
method public long getDownloadId();
method @Nullable public java.io.File getFile();
method @Nullable public java.io.File getFile() throws java.lang.Exception;
method @NonNull public String getModelHash();
method @NonNull public String getName();
method public long getSize();
Expand Down Expand Up @@ -33,9 +33,19 @@ package com.google.firebase.ml.modeldownloader {
method @NonNull public com.google.android.gms.tasks.Task<java.lang.Void> deleteDownloadedModel(@NonNull String);
method @NonNull public static com.google.firebase.ml.modeldownloader.FirebaseModelDownloader getInstance();
method @NonNull public static com.google.firebase.ml.modeldownloader.FirebaseModelDownloader getInstance(@NonNull com.google.firebase.FirebaseApp);
method @NonNull public com.google.android.gms.tasks.Task<com.google.firebase.ml.modeldownloader.CustomModel> getModel(@NonNull String, @NonNull com.google.firebase.ml.modeldownloader.DownloadType, @Nullable com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions);
method @NonNull public com.google.android.gms.tasks.Task<com.google.firebase.ml.modeldownloader.CustomModel> getModel(@NonNull String, @NonNull com.google.firebase.ml.modeldownloader.DownloadType, @Nullable com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions) throws java.lang.Exception;
method @NonNull public com.google.android.gms.tasks.Task<java.util.Set<com.google.firebase.ml.modeldownloader.CustomModel>> listDownloadedModels();
}

}

package com.google.firebase.ml.modeldownloader.internal {

public class ModelFileManager {
ctor public ModelFileManager(@NonNull com.google.firebase.FirebaseApp);
method @NonNull public static com.google.firebase.ml.modeldownloader.internal.ModelFileManager getInstance();
method @Nullable @WorkerThread public java.io.File moveModelToDestinationFolder(@NonNull com.google.firebase.ml.modeldownloader.CustomModel, @NonNull android.os.ParcelFileDescriptor) throws java.lang.Exception;
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ android {
compileSdkVersion project.targetSdkVersion

defaultConfig {
minSdkVersion project.minSdkVersion
minSdkVersion 16
targetSdkVersion project.targetSdkVersion
multiDexEnabled true
versionName version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.VisibleForTesting;
import com.google.android.gms.common.internal.Objects;
import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
import java.io.File;

/**
Expand Down Expand Up @@ -127,11 +129,30 @@ public String getName() {
* progress, returns null, if file update is in progress returns last fully uploaded model.
*/
@Nullable
public File getFile() {
public File getFile() throws Exception {
return getFile(ModelFileDownloadService.getInstance());
}

/**
* The local model file. If null is returned, use the download Id to check the download status.
*
* @return the local file associated with the model. If the original file download is still in
* progress, returns null. If file update is in progress, returns the last fully uploaded
* model.
*/
@Nullable
@VisibleForTesting
File getFile(ModelFileDownloadService fileDownloadService) throws Exception {
// check for completed download
File newDownloadFile = fileDownloadService.loadNewlyDownloadedModelFile(this);
if (newDownloadFile != null) {
return newDownloadFile;
}
// return local file, if present
if (localFilePath == null || localFilePath.isEmpty()) {
return null;
}
throw new UnsupportedOperationException("Not implemented, file retrieval coming soon.");
return new File(localFilePath);
}

/**
Expand All @@ -144,7 +165,11 @@ public long getSize() {
return fileSize;
}

/** @return the model hash */
/**
* Retrieves the model Hash.
*
* @return the model hash
*/
@NonNull
public String getModelHash() {
return modelHash;
Expand All @@ -161,6 +186,31 @@ public long getDownloadId() {
return downloadId;
}

@NonNull
@Override
public String toString() {
Objects.ToStringHelper stringHelper =
Objects.toStringHelper(this)
.add("name", name)
.add("modelHash", modelHash)
.add("fileSize", fileSize);

if (localFilePath != null && !localFilePath.isEmpty()) {
stringHelper.add("localFilePath", localFilePath);
}
if (downloadId != 0L) {
stringHelper.add("downloadId", downloadId);
}
if (downloadUrl != null && !downloadUrl.isEmpty()) {
stringHelper.add("downloadUrl", downloadUrl);
}
if (downloadUrlExpiry != 0L && !localFilePath.isEmpty()) {
stringHelper.add("downloadUrlExpiry", downloadUrlExpiry);
}

return stringHelper.toString();
}

@Override
public boolean equals(Object o) {
if (o == this) {
Expand Down Expand Up @@ -200,6 +250,8 @@ public long getDownloadUrlExpiry() {
}

/**
* Returns the model download url, usually only present when download is about to occur.
*
* @return the model download url
* <p>Internal use only
* @hide
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@
// limitations under the License.
package com.google.firebase.ml.modeldownloader;

import android.os.Build.VERSION_CODES;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.RequiresApi;
import androidx.annotation.VisibleForTesting;
import com.google.android.gms.common.internal.Preconditions;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.TaskCompletionSource;
import com.google.android.gms.tasks.Tasks;
import com.google.firebase.FirebaseApp;
import com.google.firebase.FirebaseOptions;
import com.google.firebase.installations.FirebaseInstallationsApi;
import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService;
import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
import java.util.Set;
import java.util.concurrent.Executor;
Expand All @@ -30,21 +36,32 @@ public class FirebaseModelDownloader {

private final FirebaseOptions firebaseOptions;
private final SharedPreferencesUtil sharedPreferencesUtil;
private final ModelFileDownloadService fileDownloadService;
private final CustomModelDownloadService modelDownloadService;
private final Executor executor;

FirebaseModelDownloader(FirebaseApp firebaseApp) {
@RequiresApi(api = VERSION_CODES.KITKAT)
FirebaseModelDownloader(
FirebaseApp firebaseApp, FirebaseInstallationsApi firebaseInstallationsApi) {
this.firebaseOptions = firebaseApp.getOptions();
this.fileDownloadService = new ModelFileDownloadService(firebaseApp);
this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp);
this.modelDownloadService =
new CustomModelDownloadService(firebaseOptions, firebaseInstallationsApi);
this.executor = Executors.newCachedThreadPool();
}

@VisibleForTesting
FirebaseModelDownloader(
FirebaseOptions firebaseOptions,
SharedPreferencesUtil sharedPreferencesUtil,
ModelFileDownloadService fileDownloadService,
CustomModelDownloadService modelDownloadService,
Executor executor) {
this.firebaseOptions = firebaseOptions;
this.sharedPreferencesUtil = sharedPreferencesUtil;
this.fileDownloadService = fileDownloadService;
this.modelDownloadService = modelDownloadService;
this.executor = executor;
}

Expand Down Expand Up @@ -95,21 +112,77 @@ public static FirebaseModelDownloader getInstance(@NonNull FirebaseApp app) {
public Task<CustomModel> getModel(
@NonNull String modelName,
@NonNull DownloadType downloadType,
@Nullable CustomModelDownloadConditions conditions) {
@Nullable CustomModelDownloadConditions conditions)
throws Exception {
CustomModel localModel = sharedPreferencesUtil.getCustomModelDetails(modelName);
switch (downloadType) {
case LOCAL_MODEL:
if (localModel != null) {
return Tasks.forResult(localModel);
}
Task<CustomModel> modelDetails =
modelDownloadService.getCustomModelDetails(
firebaseOptions.getProjectId(), modelName, null);

// no local model - start download.
return modelDetails.continueWithTask(
executor,
modelDetailTask -> {
if (modelDetailTask.isSuccessful()) {
// start download
return fileDownloadService
.download(modelDetailTask.getResult(), conditions)
.continueWithTask(
executor,
downloadTask -> {
if (downloadTask.isSuccessful()) {
// read the updated model
CustomModel downloadedModel =
sharedPreferencesUtil.getCustomModelDetails(modelName);
// TODO(annz) trigger file move here as well... right now it's temp
// call loadNewlyDownloadedModelFile
return Tasks.forResult(downloadedModel);
}
return Tasks.forException(new Exception("File download failed."));
});
}
return Tasks.forException(modelDetailTask.getException());
});
case LATEST_MODEL:
// check for latest model and download newest
break;
case LOCAL_MODEL_UPDATE_IN_BACKGROUND:
// start download in back ground return current model if not null.
break;
}
throw new UnsupportedOperationException("Not yet implemented.");
}

/** @return The set of all models that are downloaded to this device. */
/**
* Triggers the move to permanent storage of successful model downloads and lists all models
* downloaded to device.
*
* @return The set of all models that are downloaded to this device, triggers completion of file
* moves for completed model downloads.
*/
@NonNull
public Task<Set<CustomModel>> listDownloadedModels() {
// trigger completion of file moves for download files.
try {
fileDownloadService.maybeCheckDownloadingComplete();
} catch (Exception ex) {
System.out.println("Error checking for in progress downloads: " + ex.getMessage());
}

TaskCompletionSource<Set<CustomModel>> taskCompletionSource = new TaskCompletionSource<>();
executor.execute(
() -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels()));
return taskCompletionSource.getTask();
}

/*
/**
* Delete old local models, when no longer in use.
*
* @param modelName - name of the model
*/
@NonNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import com.google.firebase.components.Dependency;
import com.google.firebase.installations.FirebaseInstallationsApi;
import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService;
import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
import com.google.firebase.ml.modeldownloader.internal.ModelFileManager;
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
import com.google.firebase.platforminfo.LibraryVersionComponent;
import java.util.Arrays;
Expand All @@ -44,12 +46,24 @@ public List<Component<?>> getComponents() {
return Arrays.asList(
Component.builder(FirebaseModelDownloader.class)
.add(Dependency.required(FirebaseApp.class))
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseApp.class)))
.add(Dependency.required(FirebaseInstallationsApi.class))
.factory(
c ->
new FirebaseModelDownloader(
c.get(FirebaseApp.class), c.get(FirebaseInstallationsApi.class)))
.build(),
Component.builder(SharedPreferencesUtil.class)
.add(Dependency.required(FirebaseApp.class))
.factory(c -> new SharedPreferencesUtil(c.get(FirebaseApp.class)))
.build(),
Component.builder(ModelFileManager.class)
.add(Dependency.required(FirebaseApp.class))
.factory(c -> new ModelFileManager(c.get(FirebaseApp.class)))
.build(),
Component.builder(ModelFileDownloadService.class)
.add(Dependency.required(FirebaseApp.class))
.factory(c -> new ModelFileDownloadService(c.get(FirebaseApp.class)))
.build(),
Component.builder(CustomModelDownloadService.class)
.add(Dependency.required(FirebaseOptions.class))
.add(Dependency.required(FirebaseInstallationsApi.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@

package com.google.firebase.ml.modeldownloader.internal;

import android.os.Build.VERSION_CODES;
import android.util.JsonReader;
import android.util.Log;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.RequiresApi;
import com.google.android.gms.common.util.VisibleForTesting;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.Tasks;
Expand All @@ -34,7 +32,6 @@
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
Expand All @@ -51,12 +48,10 @@
*
* @hide
*/
@RequiresApi(api = VERSION_CODES.KITKAT)
public final class CustomModelDownloadService {

public class CustomModelDownloadService {
private static final String TAG = "CustomModelDownloadSer";
private static final int CONNECTION_TIME_OUT_MS = 2000; // 2 seconds.
private static final Charset UTF_8 = StandardCharsets.UTF_8;
private static final Charset UTF_8 = Charset.forName("UTF-8");
private static final String ISO_DATE_PATTERN = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'";
private static final String ACCEPT_ENCODING_HEADER_KEY = "Accept-Encoding";
private static final String CONTENT_ENCODING_HEADER_KEY = "Content-Encoding";
Expand All @@ -76,8 +71,8 @@ public final class CustomModelDownloadService {
static final String DOWNLOAD_MODEL_REGEX = "%s/v1beta2/projects/%s/models/%s:download";

private final ExecutorService executorService;
private FirebaseInstallationsApi firebaseInstallations;
private String apiKey;
private final FirebaseInstallationsApi firebaseInstallations;
private final String apiKey;
private String downloadHost = FIREBASE_DOWNLOAD_HOST;

public CustomModelDownloadService(
Expand Down
Loading