diff --git a/firebase-ml-modeldownloader/api.txt b/firebase-ml-modeldownloader/api.txt index 659d5f06674..2feaff8bf10 100644 --- a/firebase-ml-modeldownloader/api.txt +++ b/firebase-ml-modeldownloader/api.txt @@ -3,7 +3,7 @@ package com.google.firebase.ml.modeldownloader { public class CustomModel { method public long getDownloadId(); - method @Nullable public java.io.File getFile(); + method @Nullable public java.io.File getFile() throws java.lang.Exception; method @NonNull public String getModelHash(); method @NonNull public String getName(); method public long getSize(); @@ -33,9 +33,19 @@ package com.google.firebase.ml.modeldownloader { method @NonNull public com.google.android.gms.tasks.Task deleteDownloadedModel(@NonNull String); method @NonNull public static com.google.firebase.ml.modeldownloader.FirebaseModelDownloader getInstance(); method @NonNull public static com.google.firebase.ml.modeldownloader.FirebaseModelDownloader getInstance(@NonNull com.google.firebase.FirebaseApp); - method @NonNull public com.google.android.gms.tasks.Task getModel(@NonNull String, @NonNull com.google.firebase.ml.modeldownloader.DownloadType, @Nullable com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions); + method @NonNull public com.google.android.gms.tasks.Task getModel(@NonNull String, @NonNull com.google.firebase.ml.modeldownloader.DownloadType, @Nullable com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions) throws java.lang.Exception; method @NonNull public com.google.android.gms.tasks.Task> listDownloadedModels(); } } +package com.google.firebase.ml.modeldownloader.internal { + + public class ModelFileManager { + ctor public ModelFileManager(@NonNull com.google.firebase.FirebaseApp); + method @NonNull public static com.google.firebase.ml.modeldownloader.internal.ModelFileManager getInstance(); + method @Nullable @WorkerThread public java.io.File moveModelToDestinationFolder(@NonNull com.google.firebase.ml.modeldownloader.CustomModel, @NonNull android.os.ParcelFileDescriptor) throws java.lang.Exception; + } + +} + diff --git a/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle b/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle index 685a7e39f21..5d004717e31 100644 --- a/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle +++ b/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle @@ -25,7 +25,7 @@ android { compileSdkVersion project.targetSdkVersion defaultConfig { - minSdkVersion project.minSdkVersion + minSdkVersion 16 targetSdkVersion project.targetSdkVersion multiDexEnabled true versionName version diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java index 5b30f01ef76..232851d54c5 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java @@ -16,7 +16,9 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; import com.google.android.gms.common.internal.Objects; +import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService; import java.io.File; /** @@ -127,11 +129,30 @@ public String getName() { * progress, returns null, if file update is in progress returns last fully uploaded model. */ @Nullable - public File getFile() { + public File getFile() throws Exception { + return getFile(ModelFileDownloadService.getInstance()); + } + + /** + * The local model file. If null is returned, use the download Id to check the download status. + * + * @return the local file associated with the model. If the original file download is still in + * progress, returns null. If file update is in progress, returns the last fully uploaded + * model. + */ + @Nullable + @VisibleForTesting + File getFile(ModelFileDownloadService fileDownloadService) throws Exception { + // check for completed download + File newDownloadFile = fileDownloadService.loadNewlyDownloadedModelFile(this); + if (newDownloadFile != null) { + return newDownloadFile; + } + // return local file, if present if (localFilePath == null || localFilePath.isEmpty()) { return null; } - throw new UnsupportedOperationException("Not implemented, file retrieval coming soon."); + return new File(localFilePath); } /** @@ -144,7 +165,11 @@ public long getSize() { return fileSize; } - /** @return the model hash */ + /** + * Retrieves the model Hash. + * + * @return the model hash + */ @NonNull public String getModelHash() { return modelHash; @@ -161,6 +186,31 @@ public long getDownloadId() { return downloadId; } + @NonNull + @Override + public String toString() { + Objects.ToStringHelper stringHelper = + Objects.toStringHelper(this) + .add("name", name) + .add("modelHash", modelHash) + .add("fileSize", fileSize); + + if (localFilePath != null && !localFilePath.isEmpty()) { + stringHelper.add("localFilePath", localFilePath); + } + if (downloadId != 0L) { + stringHelper.add("downloadId", downloadId); + } + if (downloadUrl != null && !downloadUrl.isEmpty()) { + stringHelper.add("downloadUrl", downloadUrl); + } + if (downloadUrlExpiry != 0L && !localFilePath.isEmpty()) { + stringHelper.add("downloadUrlExpiry", downloadUrlExpiry); + } + + return stringHelper.toString(); + } + @Override public boolean equals(Object o) { if (o == this) { @@ -200,6 +250,8 @@ public long getDownloadUrlExpiry() { } /** + * Returns the model download url, usually only present when download is about to occur. + * * @return the model download url *

Internal use only * @hide diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java index 21499578aac..9152094d133 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java @@ -13,14 +13,20 @@ // limitations under the License. package com.google.firebase.ml.modeldownloader; +import android.os.Build.VERSION_CODES; import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import androidx.annotation.RequiresApi; import androidx.annotation.VisibleForTesting; import com.google.android.gms.common.internal.Preconditions; import com.google.android.gms.tasks.Task; import com.google.android.gms.tasks.TaskCompletionSource; +import com.google.android.gms.tasks.Tasks; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; +import com.google.firebase.installations.FirebaseInstallationsApi; +import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService; +import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService; import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil; import java.util.Set; import java.util.concurrent.Executor; @@ -30,11 +36,18 @@ public class FirebaseModelDownloader { private final FirebaseOptions firebaseOptions; private final SharedPreferencesUtil sharedPreferencesUtil; + private final ModelFileDownloadService fileDownloadService; + private final CustomModelDownloadService modelDownloadService; private final Executor executor; - FirebaseModelDownloader(FirebaseApp firebaseApp) { + @RequiresApi(api = VERSION_CODES.KITKAT) + FirebaseModelDownloader( + FirebaseApp firebaseApp, FirebaseInstallationsApi firebaseInstallationsApi) { this.firebaseOptions = firebaseApp.getOptions(); + this.fileDownloadService = new ModelFileDownloadService(firebaseApp); this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp); + this.modelDownloadService = + new CustomModelDownloadService(firebaseOptions, firebaseInstallationsApi); this.executor = Executors.newCachedThreadPool(); } @@ -42,9 +55,13 @@ public class FirebaseModelDownloader { FirebaseModelDownloader( FirebaseOptions firebaseOptions, SharedPreferencesUtil sharedPreferencesUtil, + ModelFileDownloadService fileDownloadService, + CustomModelDownloadService modelDownloadService, Executor executor) { this.firebaseOptions = firebaseOptions; this.sharedPreferencesUtil = sharedPreferencesUtil; + this.fileDownloadService = fileDownloadService; + this.modelDownloadService = modelDownloadService; this.executor = executor; } @@ -95,21 +112,77 @@ public static FirebaseModelDownloader getInstance(@NonNull FirebaseApp app) { public Task getModel( @NonNull String modelName, @NonNull DownloadType downloadType, - @Nullable CustomModelDownloadConditions conditions) { + @Nullable CustomModelDownloadConditions conditions) + throws Exception { + CustomModel localModel = sharedPreferencesUtil.getCustomModelDetails(modelName); + switch (downloadType) { + case LOCAL_MODEL: + if (localModel != null) { + return Tasks.forResult(localModel); + } + Task modelDetails = + modelDownloadService.getCustomModelDetails( + firebaseOptions.getProjectId(), modelName, null); + + // no local model - start download. + return modelDetails.continueWithTask( + executor, + modelDetailTask -> { + if (modelDetailTask.isSuccessful()) { + // start download + return fileDownloadService + .download(modelDetailTask.getResult(), conditions) + .continueWithTask( + executor, + downloadTask -> { + if (downloadTask.isSuccessful()) { + // read the updated model + CustomModel downloadedModel = + sharedPreferencesUtil.getCustomModelDetails(modelName); + // TODO(annz) trigger file move here as well... right now it's temp + // call loadNewlyDownloadedModelFile + return Tasks.forResult(downloadedModel); + } + return Tasks.forException(new Exception("File download failed.")); + }); + } + return Tasks.forException(modelDetailTask.getException()); + }); + case LATEST_MODEL: + // check for latest model and download newest + break; + case LOCAL_MODEL_UPDATE_IN_BACKGROUND: + // start download in back ground return current model if not null. + break; + } throw new UnsupportedOperationException("Not yet implemented."); } - /** @return The set of all models that are downloaded to this device. */ + /** + * Triggers the move to permanent storage of successful model downloads and lists all models + * downloaded to device. + * + * @return The set of all models that are downloaded to this device, triggers completion of file + * moves for completed model downloads. + */ @NonNull public Task> listDownloadedModels() { + // trigger completion of file moves for download files. + try { + fileDownloadService.maybeCheckDownloadingComplete(); + } catch (Exception ex) { + System.out.println("Error checking for in progress downloads: " + ex.getMessage()); + } + TaskCompletionSource> taskCompletionSource = new TaskCompletionSource<>(); executor.execute( () -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels())); return taskCompletionSource.getTask(); } - /* + /** * Delete old local models, when no longer in use. + * * @param modelName - name of the model */ @NonNull diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java index 200abf8dce4..bed08dcac87 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java @@ -24,6 +24,8 @@ import com.google.firebase.components.Dependency; import com.google.firebase.installations.FirebaseInstallationsApi; import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService; +import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService; +import com.google.firebase.ml.modeldownloader.internal.ModelFileManager; import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil; import com.google.firebase.platforminfo.LibraryVersionComponent; import java.util.Arrays; @@ -44,12 +46,24 @@ public List> getComponents() { return Arrays.asList( Component.builder(FirebaseModelDownloader.class) .add(Dependency.required(FirebaseApp.class)) - .factory(c -> new FirebaseModelDownloader(c.get(FirebaseApp.class))) + .add(Dependency.required(FirebaseInstallationsApi.class)) + .factory( + c -> + new FirebaseModelDownloader( + c.get(FirebaseApp.class), c.get(FirebaseInstallationsApi.class))) .build(), Component.builder(SharedPreferencesUtil.class) .add(Dependency.required(FirebaseApp.class)) .factory(c -> new SharedPreferencesUtil(c.get(FirebaseApp.class))) .build(), + Component.builder(ModelFileManager.class) + .add(Dependency.required(FirebaseApp.class)) + .factory(c -> new ModelFileManager(c.get(FirebaseApp.class))) + .build(), + Component.builder(ModelFileDownloadService.class) + .add(Dependency.required(FirebaseApp.class)) + .factory(c -> new ModelFileDownloadService(c.get(FirebaseApp.class))) + .build(), Component.builder(CustomModelDownloadService.class) .add(Dependency.required(FirebaseOptions.class)) .add(Dependency.required(FirebaseInstallationsApi.class)) diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java index 2096a1d0b7d..68e5ab983b9 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java @@ -14,12 +14,10 @@ package com.google.firebase.ml.modeldownloader.internal; -import android.os.Build.VERSION_CODES; import android.util.JsonReader; import android.util.Log; import androidx.annotation.NonNull; import androidx.annotation.Nullable; -import androidx.annotation.RequiresApi; import com.google.android.gms.common.util.VisibleForTesting; import com.google.android.gms.tasks.Task; import com.google.android.gms.tasks.Tasks; @@ -34,7 +32,6 @@ import java.net.HttpURLConnection; import java.net.URL; import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.Date; @@ -51,12 +48,10 @@ * * @hide */ -@RequiresApi(api = VERSION_CODES.KITKAT) -public final class CustomModelDownloadService { - +public class CustomModelDownloadService { private static final String TAG = "CustomModelDownloadSer"; private static final int CONNECTION_TIME_OUT_MS = 2000; // 2 seconds. - private static final Charset UTF_8 = StandardCharsets.UTF_8; + private static final Charset UTF_8 = Charset.forName("UTF-8"); private static final String ISO_DATE_PATTERN = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; private static final String ACCEPT_ENCODING_HEADER_KEY = "Accept-Encoding"; private static final String CONTENT_ENCODING_HEADER_KEY = "Content-Encoding"; @@ -76,8 +71,8 @@ public final class CustomModelDownloadService { static final String DOWNLOAD_MODEL_REGEX = "%s/v1beta2/projects/%s/models/%s:download"; private final ExecutorService executorService; - private FirebaseInstallationsApi firebaseInstallations; - private String apiKey; + private final FirebaseInstallationsApi firebaseInstallations; + private final String apiKey; private String downloadHost = FIREBASE_DOWNLOAD_HOST; public CustomModelDownloadService( diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadService.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadService.java new file mode 100644 index 00000000000..9c8a9c3e4bf --- /dev/null +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadService.java @@ -0,0 +1,371 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.firebase.ml.modeldownloader.internal; + +import android.app.DownloadManager; +import android.app.DownloadManager.Query; +import android.app.DownloadManager.Request; +import android.content.BroadcastReceiver; +import android.content.Context; +import android.content.Intent; +import android.content.IntentFilter; +import android.database.Cursor; +import android.net.Uri; +import android.os.Build.VERSION; +import android.os.Build.VERSION_CODES; +import android.os.ParcelFileDescriptor; +import android.util.LongSparseArray; +import androidx.annotation.GuardedBy; +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; +import androidx.annotation.WorkerThread; +import com.google.android.gms.tasks.Task; +import com.google.android.gms.tasks.TaskCompletionSource; +import com.google.android.gms.tasks.Tasks; +import com.google.firebase.FirebaseApp; +import com.google.firebase.ml.modeldownloader.CustomModel; +import com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions; +import java.io.File; +import java.io.FileNotFoundException; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Calls the Android Download service to copy the model file to device (temp location) and then + * moves file to it's permanent location, updating the model details in shared preferences + * throughout. + * + * @hide + */ +public class ModelFileDownloadService { + + private final DownloadManager downloadManager; + private final Context context; + private final ModelFileManager fileManager; + private final SharedPreferencesUtil sharedPreferencesUtil; + + @GuardedBy("this") + // Mapping from download id to broadcast receiver. Because models can update, we cannot just keep + // one instance of DownloadBroadcastReceiver per RemoteModelDownloadManager object. + private final LongSparseArray receiverMaps = new LongSparseArray<>(); + + @GuardedBy("this") + // Mapping from download id to TaskCompletionSource. Because models can update, we cannot just + // keep one instance of TaskCompletionSource per RemoteModelDownloadManager object. + private final LongSparseArray> taskCompletionSourceMaps = + new LongSparseArray<>(); + + private CustomModelDownloadConditions downloadConditions = + new CustomModelDownloadConditions.Builder().build(); + + public ModelFileDownloadService(@NonNull FirebaseApp firebaseApp) { + this.context = firebaseApp.getApplicationContext(); + downloadManager = (DownloadManager) context.getSystemService(Context.DOWNLOAD_SERVICE); + this.fileManager = ModelFileManager.getInstance(); + this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp); + } + + @VisibleForTesting + ModelFileDownloadService( + @NonNull FirebaseApp firebaseApp, + DownloadManager downloadManager, + ModelFileManager fileManager, + SharedPreferencesUtil sharedPreferencesUtil) { + this.context = firebaseApp.getApplicationContext(); + this.downloadManager = downloadManager; + this.fileManager = fileManager; + this.sharedPreferencesUtil = sharedPreferencesUtil; + } + + /** + * Get ModelFileDownloadService instance using the firebase app returned by {@link + * FirebaseApp#getInstance()} + * + * @return ModelFileDownloadService + */ + @NonNull + public static ModelFileDownloadService getInstance() { + return FirebaseApp.getInstance().get(ModelFileDownloadService.class); + } + + public Task download( + CustomModel customModel, CustomModelDownloadConditions downloadConditions) { + this.downloadConditions = downloadConditions; + // todo add url tests here + return ensureModelDownloaded(customModel); + } + + @VisibleForTesting + Task ensureModelDownloaded(CustomModel customModel) { + // todo check model not already in progress of being downloaded + + // todo remove any failed download attempts + + // schedule new download of model file + Long newDownloadId = scheduleModelDownload(customModel); + if (newDownloadId == null) { + return Tasks.forException(new Exception("Failed to schedule the download task")); + } + + return registerReceiverForDownloadId(newDownloadId); + } + + private synchronized DownloadBroadcastReceiver getReceiverInstance(long downloadId) { + DownloadBroadcastReceiver receiver = receiverMaps.get(downloadId); + if (receiver == null) { + receiver = + new DownloadBroadcastReceiver(downloadId, getTaskCompletionSourceInstance(downloadId)); + receiverMaps.put(downloadId, receiver); + } + return receiver; + } + + private Task registerReceiverForDownloadId(long downloadId) { + BroadcastReceiver broadcastReceiver = getReceiverInstance(downloadId); + // It is okay to always register here. Since the broadcast receiver is the same via the lookup + // for the same download id, the same broadcast receiver will be notified only once. + context.registerReceiver( + broadcastReceiver, new IntentFilter(DownloadManager.ACTION_DOWNLOAD_COMPLETE)); + + return getTaskCompletionSourceInstance(downloadId).getTask(); + } + + @VisibleForTesting + synchronized TaskCompletionSource getTaskCompletionSourceInstance(long downloadId) { + TaskCompletionSource taskCompletionSource = taskCompletionSourceMaps.get(downloadId); + if (taskCompletionSource == null) { + taskCompletionSource = new TaskCompletionSource<>(); + taskCompletionSourceMaps.put(downloadId, taskCompletionSource); + } + + return taskCompletionSource; + } + + @VisibleForTesting + synchronized Long scheduleModelDownload(@NonNull CustomModel customModel) { + if (downloadManager == null) { + return null; + } + + if (customModel.getDownloadUrl() == null || customModel.getDownloadUrl().isEmpty()) { + return null; + } + // todo handle expired url here and figure out what to do about delayed downloads too.. + + // Schedule a new downloading + Request downloadRequest = new Request(Uri.parse(customModel.getDownloadUrl())); + // check Url is not expired - get new one if necessary... + + // By setting the destination uri to null, the downloaded file will be stored in + // DownloadManager's purgeable cache. As a result, WRITE_EXTERNAL_STORAGE permission is not + // needed. + downloadRequest.setDestinationUri(null); + if (VERSION.SDK_INT >= VERSION_CODES.N) { + downloadRequest.setRequiresCharging(downloadConditions.isChargingRequired()); + downloadRequest.setRequiresDeviceIdle(downloadConditions.isDeviceIdleRequired()); + } + + if (downloadConditions.isWifiRequired()) { + downloadRequest.setAllowedNetworkTypes(Request.NETWORK_WIFI); + } + + long id = downloadManager.enqueue(downloadRequest); + // update the custom model to store the download id - do not lose current local file - in case + // this is a background update. + sharedPreferencesUtil.setDownloadingCustomModelDetails( + new CustomModel( + customModel.getName(), + customModel.getModelHash(), + customModel.getSize(), + id, + customModel.getLocalFilePath())); + return id; + } + + @Nullable + @VisibleForTesting + synchronized Integer getDownloadingModelStatusCode(Long downloadingId) { + if (downloadManager == null || downloadingId == null) { + return null; + } + + Integer statusCode = null; + + try (Cursor cursor = downloadManager.query(new Query().setFilterById(downloadingId))) { + + if (cursor != null && cursor.moveToFirst()) { + statusCode = cursor.getInt(cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)); + } + + if (statusCode == null) { + return null; + } + + if (statusCode != DownloadManager.STATUS_RUNNING + && statusCode != DownloadManager.STATUS_PAUSED + && statusCode != DownloadManager.STATUS_PENDING + && statusCode != DownloadManager.STATUS_SUCCESSFUL + && statusCode != DownloadManager.STATUS_FAILED) { + // Unknown status + statusCode = null; + } + return statusCode; + } + } + + @Nullable + private synchronized ParcelFileDescriptor getDownloadedFile(Long downloadingId) { + if (downloadManager == null || downloadingId == null) { + return null; + } + + ParcelFileDescriptor fileDescriptor = null; + try { + fileDescriptor = downloadManager.openDownloadedFile(downloadingId); + } catch (FileNotFoundException e) { + System.out.println("Downloaded file is not found"); + } + return fileDescriptor; + } + + public void maybeCheckDownloadingComplete() throws Exception { + for (String key : sharedPreferencesUtil.getSharedPreferenceKeySet()) { + // if a local file path is present - get model details. + Matcher matcher = + Pattern.compile(SharedPreferencesUtil.DOWNLOADING_MODEL_ID_MATCHER).matcher(key); + if (matcher.find()) { + String modelName = matcher.group(matcher.groupCount()); + CustomModel downloadingModel = sharedPreferencesUtil.getCustomModelDetails(modelName); + Integer statusCode = getDownloadingModelStatusCode(downloadingModel.getDownloadId()); + if (statusCode == DownloadManager.STATUS_SUCCESSFUL + || statusCode == DownloadManager.STATUS_FAILED) { + loadNewlyDownloadedModelFile(downloadingModel); + } + } + } + } + + @Nullable + @WorkerThread + public File loadNewlyDownloadedModelFile(CustomModel model) throws Exception { + Long downloadingId = model.getDownloadId(); + String downloadingModelHash = model.getModelHash(); + + if (downloadingId == null || downloadingModelHash == null) { + // no downloading model file or incomplete info. + return null; + } + + Integer statusCode = getDownloadingModelStatusCode(downloadingId); + if (statusCode == null) { + return null; + } + + if (statusCode == DownloadManager.STATUS_SUCCESSFUL) { + // Get downloaded file. + ParcelFileDescriptor fileDescriptor = getDownloadedFile(downloadingId); + if (fileDescriptor == null) { + // reset original model - removing download id. + sharedPreferencesUtil.setFailedUploadedCustomModelDetails(model.getName()); + // todo call the download register? + return null; + } + + // Try to move it to destination folder. + File newModelFile = fileManager.moveModelToDestinationFolder(model, fileDescriptor); + + if (newModelFile == null) { + // reset original model - removing download id. + // todo call the download register? + sharedPreferencesUtil.setFailedUploadedCustomModelDetails(model.getName()); + return null; + } + + // Successfully moved, update share preferences + sharedPreferencesUtil.setUploadedCustomModelDetails( + new CustomModel( + model.getName(), model.getModelHash(), model.getSize(), 0, newModelFile.getPath())); + + // Cleans up the old files if it is the initial creation. + return newModelFile; + } else if (statusCode == DownloadManager.STATUS_FAILED) { + // reset original model - removing download id. + sharedPreferencesUtil.setFailedUploadedCustomModelDetails(model.getName()); + // todo - determine if the temp files need to be clean up? Does one exist? + } + // Other cases, return as null and wait for download finish. + return null; + } + + // This class runs totally on worker thread because we registered the receiver with a worker + // thread handler. + @WorkerThread + private class DownloadBroadcastReceiver extends BroadcastReceiver { + + // Download Id is captured inside this class in memory. So there is no concern of inconsistency + // with the persisted download id in shared preferences. + private final long downloadId; + private final TaskCompletionSource taskCompletionSource; + + private DownloadBroadcastReceiver( + long downloadId, TaskCompletionSource taskCompletionSource) { + this.downloadId = downloadId; + this.taskCompletionSource = taskCompletionSource; + } + + @Override + public void onReceive(Context context, Intent intent) { + long id = intent.getLongExtra(DownloadManager.EXTRA_DOWNLOAD_ID, -1); + if (id != downloadId) { + return; + } + + Integer statusCode = getDownloadingModelStatusCode(downloadId); + synchronized (ModelFileDownloadService.this) { + try { + context.getApplicationContext().unregisterReceiver(this); + } catch (IllegalArgumentException e) { + // If we try to unregister a receiver that was never registered or has been unregistered, + // IllegalArgumentException will be thrown by the Android Framework. + // Our current code does not have this problem. However, in order to be safer in the + // future, we just ignore the exception here, because it is not a big deal. The code can + // move on. + } + + receiverMaps.remove(downloadId); + taskCompletionSourceMaps.remove(downloadId); + } + + if (statusCode != null) { + if (statusCode == DownloadManager.STATUS_FAILED) { + // todo add failure reason and logging + System.out.println("Download Failed for id: " + id); + taskCompletionSource.setException(new Exception("Failed")); + return; + } + + if (statusCode == DownloadManager.STATUS_SUCCESSFUL) { + System.out.println("Download Succeeded for id: " + id); + taskCompletionSource.setResult(null); + return; + } + } + + // Status code is null or not one of success or fail. + taskCompletionSource.setException(new Exception("Model downloading failed")); + } + } +} diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java new file mode 100644 index 00000000000..c573d2cb37a --- /dev/null +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java @@ -0,0 +1,166 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.firebase.ml.modeldownloader.internal; + +import android.content.Context; +import android.os.Build.VERSION; +import android.os.Build.VERSION_CODES; +import android.os.ParcelFileDescriptor; +import android.os.ParcelFileDescriptor.AutoCloseInputStream; +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; +import androidx.annotation.WorkerThread; +import com.google.firebase.FirebaseApp; +import com.google.firebase.ml.modeldownloader.CustomModel; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +/** Model File Manager is used to move the downloaded file to the appropriate locations. */ +public class ModelFileManager { + + @VisibleForTesting + static final String CUSTOM_MODEL_ROOT_PATH = "com.google.firebase.ml.custom.models"; + + private static final int INVALID_INDEX = -1; + private final Context context; + private final FirebaseApp firebaseApp; + + public ModelFileManager(@NonNull FirebaseApp firebaseApp) { + this.context = firebaseApp.getApplicationContext(); + this.firebaseApp = firebaseApp; + } + + /** + * Get ModelFileDownloadService instance using the firebase app returned by {@link + * FirebaseApp#getInstance()} + * + * @return ModelFileDownloadService + */ + @NonNull + public static ModelFileManager getInstance() { + return FirebaseApp.getInstance().get(ModelFileManager.class); + } + + /** + * Get the directory where the model is supposed to reside. This method does not ensure that the + * directory specified does exist. If you need to ensure its existence, you should call + * getDirImpl. + */ + @Nullable + private File getModelDirUnsafe(@NonNull String modelName) { + String modelTypeSpecificRoot = CUSTOM_MODEL_ROOT_PATH; + File root; + if (VERSION.SDK_INT >= VERSION_CODES.LOLLIPOP) { + root = new File(context.getNoBackupFilesDir(), modelTypeSpecificRoot); + } else { + root = context.getApplicationContext().getDir(modelTypeSpecificRoot, Context.MODE_PRIVATE); + } + File firebaseAppDir = new File(root, firebaseApp.getPersistenceKey()); + return new File(firebaseAppDir, modelName); + } + + /** + * Gets the directory in the following schema: + * app_root/model_type_specific_root/[temp]/firebase_app_persistence_key/model_name. + */ + @VisibleForTesting + @WorkerThread + File getDirImpl(@NonNull String modelName) throws Exception { + File modelDir = getModelDirUnsafe(modelName); + if (!modelDir.exists()) { + if (!modelDir.mkdirs()) { + throw new Exception("Failed to create model folder: " + modelDir); + } + } else if (!modelDir.isDirectory()) { + throw new Exception( + "Can not create model folder, since an existing file has the same name: " + modelDir); + } + return modelDir; + } + + /** + * Since the model files under the model folder are named with numbers, and the later one is the + * newer, the latest model is the file name with largest number. + */ + @WorkerThread + private int getLatestCachedModelVersion(@NonNull File modelDir) { + File[] modelFiles = modelDir.listFiles(); + if (modelFiles == null || modelFiles.length == 0) { + return INVALID_INDEX; + } + + int index = INVALID_INDEX; + for (File modelFile : modelFiles) { + try { + index = Math.max(index, Integer.parseInt(modelFile.getName())); + } catch (NumberFormatException e) { + System.out.println("Contains non-integer file name " + modelFile.getName()); + } + } + return index; + } + + @VisibleForTesting + @Nullable + File getModelFileDestination(@NonNull CustomModel model) throws Exception { + File destFolder = getDirImpl(model.getName()); + int index = getLatestCachedModelVersion(destFolder); + return new File(destFolder, String.valueOf(index + 1)); + } + + /** + * Moves a downloaded file from external storage to private folder. + * + *

The private file path pattern is /%private_folder%/%firebaseapp_persistentkey%/%model_name%/ + * + *

The model file under the model folder are named with numbers starting from 0. The larger one + * is the newer model downloaded from cloud. + * + *

The caller is supposed to cleanup the previous downloaded files after this call, even when + * this call throws exception. + * + * @return null if the movement failed. Otherwise, return the destination file. + */ + @Nullable + @WorkerThread + public synchronized File moveModelToDestinationFolder( + @NonNull CustomModel customModel, @NonNull ParcelFileDescriptor modelFileDescriptor) + throws Exception { + File modelFileDestination = getModelFileDestination(customModel); + + // Moves to the final destination file in app private folder to avoid the downloaded file from + // being changed by + // other apps. + try (FileInputStream fis = new AutoCloseInputStream(modelFileDescriptor); + FileOutputStream fos = new FileOutputStream(modelFileDestination)) { + byte[] buffer = new byte[4096]; + int read; + while ((read = fis.read(buffer)) != -1) { + fos.write(buffer, 0, read); + } + // Let's be extra sure it is all written before we return. + fos.getFD().sync(); + } catch (IOException e) { + // Failed to copy to destination - clean up. + System.out.println("Failed to copy downloaded model file to destination folder: " + e); + modelFileDestination.delete(); + return null; + } + + return modelFileDestination; + } +} diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java index ca4046d1ead..573ec13700e 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java @@ -31,6 +31,8 @@ /** @hide */ public class SharedPreferencesUtil { + public static final String DOWNLOADING_MODEL_ID_MATCHER = "downloading_model_id_(.*?)_([^/]+)/?"; + @VisibleForTesting static final String PREFERENCES_PACKAGE_NAME = "com.google.firebase.ml.modelDownloader"; @@ -38,7 +40,6 @@ public class SharedPreferencesUtil { private static final String LOCAL_MODEL_HASH_PATTERN = "current_model_hash_%s_%s"; private static final String LOCAL_MODEL_FILE_PATH_PATTERN = "current_model_path_%s_%s"; private static final String LOCAL_MODEL_FILE_PATH_MATCHER = "current_model_path_(.*?)_([^/]+)/?"; - private static final String LOCAL_MODEL_FILE_SIZE_PATTERN = "current_model_size_%s_%s"; // details about model during download. private static final String DOWNLOADING_MODEL_HASH_PATTERN = "downloading_model_hash_%s_%s"; @@ -176,6 +177,19 @@ public synchronized void setUploadedCustomModelDetails(@NonNull CustomModel cust .commit(); } + /** + * The information about a failed custom model download. Updates the local model information and + * clears the download details associated with this model. Does not update the local file model. + * + * @param customModelName custom model details to be stored. + * @hide + */ + public synchronized void setFailedUploadedCustomModelDetails(@NonNull String customModelName) + throws IllegalArgumentException { + Editor editor = getSharedPreferences().edit(); + clearDownloadingModelDetails(editor, customModelName); + } + /** * Clears all stored data related to a local custom model, including download details. * @@ -196,6 +210,22 @@ public synchronized void clearModelDetails(@NonNull String modelName, boolean cl .commit(); } + /** + * Set of all keys associated with this firebase app. + * + * @return + */ + public Set getSharedPreferenceKeySet() { + return getSharedPreferences().getAll().keySet(); + } + + /** + * Lists the current set of downloaded model, does not include downloads in progress. Call + * ModelFileManager.maybeGetUpdatedModels() before calling this to trigger successful download + * completions. + * + * @return list of Custom Models. + */ public synchronized Set listDownloadedModels() { Set customModels = new HashSet<>(); Set keySet = getSharedPreferences().getAll().keySet(); @@ -209,28 +239,11 @@ public synchronized Set listDownloadedModels() { if (extractModel != null) { customModels.add(extractModel); } - } else { - matcher = Pattern.compile(DOWNLOADING_MODEL_ID_PATTERN).matcher(key); - if (matcher.find()) { - String modelName = matcher.group(matcher.groupCount()); - CustomModel extractModel = maybeGetUpdatedModel(modelName); - if (extractModel != null) { - customModels.add(extractModel); - } - } } } return customModels; } - synchronized CustomModel maybeGetUpdatedModel(String modelName) { - CustomModel downloadModel = getCustomModelDetails(modelName); - // TODO(annz) check here if download currently in progress have completed. - // if yes, then complete file relocation and return the updated model, otherwise return null - - return null; - } - /** * Clears all stored data related to a custom model download. * diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelDownloadConditionsTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelDownloadConditionsTest.java index d7cdba9ed21..1f0257b12f7 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelDownloadConditionsTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelDownloadConditionsTest.java @@ -49,7 +49,7 @@ public void testTwoDefaultConditionsSame() { } @Test - public void testTwoConfigedConditionsSame() { + public void testTwoConfiguredConditionsSame() { CustomModelDownloadConditions conditions1 = new CustomModelDownloadConditions.Builder().requireDeviceIdle().requireCharging().build(); CustomModelDownloadConditions conditions2 = @@ -58,7 +58,7 @@ public void testTwoConfigedConditionsSame() { } @Test - public void testTwoConfigedConditionsDifferent() { + public void testTwoConfiguredConditionsDifferent() { CustomModelDownloadConditions conditions1 = new CustomModelDownloadConditions.Builder().requireCharging().build(); CustomModelDownloadConditions conditions2 = diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java index b58732fc48c..1072543aaf8 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java @@ -17,9 +17,22 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNull; - +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import androidx.test.core.app.ApplicationProvider; +import com.google.firebase.FirebaseApp; +import com.google.firebase.FirebaseOptions; +import com.google.firebase.FirebaseOptions.Builder; +import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService; +import java.io.File; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.robolectric.RobolectricTestRunner; @RunWith(RobolectricTestRunner.class) @@ -27,13 +40,31 @@ public class CustomModelTest { public static final String MODEL_NAME = "ModelName"; public static final String MODEL_HASH = "dsf324"; + public static final File TEST_MODEL_FILE = new File("fakeFile.tflite"); + public static final File TEST_MODEL_FILE_UPDATED = new File("fakeUpdateFile.tflite"); public static final String MODEL_URL = "https://project.firebase.com/modelName/23424.jpg"; + public static final String TEST_PROJECT_ID = "777777777777"; + public static final FirebaseOptions FIREBASE_OPTIONS = + new Builder() + .setApplicationId("1:123456789:android:abcdef") + .setProjectId(TEST_PROJECT_ID) + .build(); private static final long URL_EXPIRATION = 604800L; - CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); - - CustomModel CUSTOM_MODEL_URL = + final CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); + final CustomModel CUSTOM_MODEL_URL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION); + final CustomModel CUSTOM_MODEL_FILE = + new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, TEST_MODEL_FILE.getPath()); + @Mock ModelFileDownloadService fileDownloadService; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + FirebaseApp.clearInstancesForTest(); + // default app + FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); + } @Test public void customModel_getName() { @@ -56,8 +87,32 @@ public void customModel_getDownloadId() { } @Test - public void customModel_getFile_downloadIncomplete() { - assertNull(CUSTOM_MODEL.getFile()); + public void customModel_getFile_noLocalNoDownloadIncomplete() throws Exception { + when(fileDownloadService.loadNewlyDownloadedModelFile(any(CustomModel.class))).thenReturn(null); + assertNull(CUSTOM_MODEL.getFile(fileDownloadService)); + verify(fileDownloadService, times(1)).loadNewlyDownloadedModelFile(any()); + } + + @Test + public void customModel_getFile_localModelNoDownload() throws Exception { + when(fileDownloadService.loadNewlyDownloadedModelFile(any(CustomModel.class))).thenReturn(null); + assertEquals(CUSTOM_MODEL_FILE.getFile(fileDownloadService), TEST_MODEL_FILE); + verify(fileDownloadService, times(1)).loadNewlyDownloadedModelFile(any()); + } + + @Test + public void customModel_getFile_localModelDownloadComplete() throws Exception { + when(fileDownloadService.loadNewlyDownloadedModelFile(any(CustomModel.class))) + .thenReturn(TEST_MODEL_FILE_UPDATED); + assertEquals(CUSTOM_MODEL_FILE.getFile(fileDownloadService), TEST_MODEL_FILE_UPDATED); + verify(fileDownloadService, times(1)).loadNewlyDownloadedModelFile(any()); + } + + @Test + public void customModel_getFile_noLocalDownloadComplete() throws Exception { + when(fileDownloadService.loadNewlyDownloadedModelFile(any())).thenReturn(TEST_MODEL_FILE); + assertEquals(CUSTOM_MODEL.getFile(fileDownloadService), TEST_MODEL_FILE); + verify(fileDownloadService, times(1)).loadNewlyDownloadedModelFile(any()); } @Test diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java index 6e1455cb1fd..aba709eb512 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java @@ -17,17 +17,24 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import androidx.test.core.app.ApplicationProvider; import com.google.android.gms.tasks.Task; +import com.google.android.gms.tasks.Tasks; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; import com.google.firebase.FirebaseOptions.Builder; +import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService; +import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService; import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil; import java.util.Collections; import java.util.Set; -import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import org.junit.Before; @@ -51,12 +58,16 @@ public class FirebaseModelDownloaderTest { new CustomModelDownloadConditions.Builder().build(); public static final String MODEL_HASH = "dsf324"; + public static final CustomModelDownloadConditions DOWNLOAD_CONDITIONS = + new CustomModelDownloadConditions.Builder().requireWifi().build(); + // TODO replace with uploaded model. - CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); + final CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); FirebaseModelDownloader firebaseModelDownloader; @Mock SharedPreferencesUtil mockPrefs; - + @Mock ModelFileDownloadService mockFileDownloadService; + @Mock CustomModelDownloadService mockModelDownloadService; ExecutorService executor; @Before @@ -67,7 +78,13 @@ public void setUp() { FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); executor = Executors.newSingleThreadExecutor(); - firebaseModelDownloader = new FirebaseModelDownloader(FIREBASE_OPTIONS, mockPrefs, executor); + firebaseModelDownloader = + new FirebaseModelDownloader( + FIREBASE_OPTIONS, + mockPrefs, + mockFileDownloadService, + mockModelDownloadService, + executor); } @Test @@ -76,13 +93,73 @@ public void getModel_unimplemented() { UnsupportedOperationException.class, () -> FirebaseModelDownloader.getInstance() - .getModel(MODEL_NAME, DownloadType.LOCAL_MODEL, DEFAULT_DOWNLOAD_CONDITIONS)); + .getModel( + MODEL_NAME, + DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, + DEFAULT_DOWNLOAD_CONDITIONS)); } @Test - public void listDownloadedModels_returnsEmptyModelList() - throws ExecutionException, InterruptedException { + public void getModel_localExists() throws Exception { + when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(CUSTOM_MODEL); + TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); + Task task = + firebaseModelDownloader.getModel(MODEL_NAME, DownloadType.LOCAL_MODEL, DOWNLOAD_CONDITIONS); + task.addOnCompleteListener(executor, onCompleteListener); + CustomModel customModel = onCompleteListener.await(); + + verify(mockPrefs, times(1)).getCustomModelDetails(eq(MODEL_NAME)); + assertThat(task.isComplete()).isTrue(); + assertEquals(customModel, CUSTOM_MODEL); + } + + @Test + public void getModel_noLocalModel() throws Exception { + when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(null).thenReturn(CUSTOM_MODEL); + when(mockModelDownloadService.getCustomModelDetails( + eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null))) + .thenReturn(Tasks.forResult(CUSTOM_MODEL)); + when(mockFileDownloadService.download(any(), eq(DOWNLOAD_CONDITIONS))) + .thenReturn(Tasks.forResult(null)); + TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); + Task task = + firebaseModelDownloader.getModel(MODEL_NAME, DownloadType.LOCAL_MODEL, DOWNLOAD_CONDITIONS); + task.addOnCompleteListener(executor, onCompleteListener); + CustomModel customModel = onCompleteListener.await(); + + verify(mockPrefs, times(2)).getCustomModelDetails(eq(MODEL_NAME)); + assertThat(task.isComplete()).isTrue(); + assertEquals(customModel, CUSTOM_MODEL); + } + + @Test + public void getModel_noLocalModel_error() throws Exception { + when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(null).thenReturn(CUSTOM_MODEL); + when(mockModelDownloadService.getCustomModelDetails( + eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null))) + .thenReturn(Tasks.forResult(CUSTOM_MODEL)); + when(mockFileDownloadService.download(any(), eq(DOWNLOAD_CONDITIONS))) + .thenReturn(Tasks.forException(new Exception("bad download"))); + TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); + Task task = + firebaseModelDownloader.getModel(MODEL_NAME, DownloadType.LOCAL_MODEL, DOWNLOAD_CONDITIONS); + task.addOnCompleteListener(executor, onCompleteListener); + try { + onCompleteListener.await(); + } catch (Exception ex) { + assertThat(ex.getMessage().contains("download failed")).isTrue(); + } + + verify(mockPrefs, times(1)).getCustomModelDetails(eq(MODEL_NAME)); + assertThat(task.isComplete()).isTrue(); + assertThat(task.isSuccessful()).isFalse(); + } + + @Test + public void listDownloadedModels_returnsEmptyModelList() throws Exception { when(mockPrefs.listDownloadedModels()).thenReturn(Collections.emptySet()); + doNothing().when(mockFileDownloadService).maybeCheckDownloadingComplete(); + TestOnCompleteListener> onCompleteListener = new TestOnCompleteListener<>(); Task> task = firebaseModelDownloader.listDownloadedModels(); task.addOnCompleteListener(executor, onCompleteListener); @@ -93,9 +170,9 @@ public void listDownloadedModels_returnsEmptyModelList() } @Test - public void listDownloadedModels_returnsModelList() - throws ExecutionException, InterruptedException { + public void listDownloadedModels_returnsModelList() throws Exception { when(mockPrefs.listDownloadedModels()).thenReturn(Collections.singleton(CUSTOM_MODEL)); + doNothing().when(mockFileDownloadService).maybeCheckDownloadingComplete(); TestOnCompleteListener> onCompleteListener = new TestOnCompleteListener<>(); Task> task = firebaseModelDownloader.listDownloadedModels(); diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/TestOnCompleteListener.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/TestOnCompleteListener.java index 332c786684c..a4080d0023a 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/TestOnCompleteListener.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/TestOnCompleteListener.java @@ -46,7 +46,7 @@ public void onComplete(@NonNull Task task) { } /** Blocks until the {@link #onComplete} is called. */ - public TResult await() throws InterruptedException, ExecutionException { + public TResult await() throws Exception { if (!latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)) { throw new InterruptedException("timed out waiting for result"); } @@ -56,10 +56,13 @@ public TResult await() throws InterruptedException, ExecutionException { if (exception instanceof InterruptedException) { throw (InterruptedException) exception; } - // todo(annz) add firebase ml exception handling here. if (exception instanceof IOException) { throw new ExecutionException(exception); } + // TODO(annz) replace with firebase ml exception handling. + if (exception instanceof Exception) { + throw exception; + } throw new IllegalStateException("got an unexpected exception type", exception); } } diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadServiceTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadServiceTest.java index c69b3ce7178..35735a1a915 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadServiceTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadServiceTest.java @@ -130,7 +130,7 @@ public void parseTokenExpirationTimestamp_failed() { } @Test - public void testDownloadService_noHashSuccess() throws Exception { + public void downloadService_noHashSuccess() throws Exception { String downloadPath = String.format(CustomModelDownloadService.DOWNLOAD_MODEL_REGEX, "", PROJECT_ID, MODEL_NAME); stubFor( @@ -165,7 +165,7 @@ public void testDownloadService_noHashSuccess() throws Exception { } @Test - public void testDownloadService_withHashSuccess_noMatch() throws Exception { + public void downloadService_withHashSuccess_noMatch() throws Exception { String downloadPath = String.format(CustomModelDownloadService.DOWNLOAD_MODEL_REGEX, "", PROJECT_ID, MODEL_NAME); stubFor( @@ -200,7 +200,7 @@ public void testDownloadService_withHashSuccess_noMatch() throws Exception { } @Test - public void testDownloadService_withHashSuccess_match() throws Exception { + public void downloadService_withHashSuccess_match() throws Exception { String downloadPath = String.format(CustomModelDownloadService.DOWNLOAD_MODEL_REGEX, "", PROJECT_ID, MODEL_NAME); stubFor( @@ -233,7 +233,7 @@ public void testDownloadService_withHashSuccess_match() throws Exception { } @Test - public void testDownloadService_modelNotFound() throws Exception { + public void downloadService_modelNotFound() throws Exception { String downloadPath = String.format(CustomModelDownloadService.DOWNLOAD_MODEL_REGEX, "", PROJECT_ID, MODEL_NAME); stubFor( diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadServiceTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadServiceTest.java new file mode 100644 index 00000000000..5c74f490722 --- /dev/null +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadServiceTest.java @@ -0,0 +1,452 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.firebase.ml.modeldownloader.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import android.app.DownloadManager; +import android.app.DownloadManager.Request; +import android.content.Intent; +import android.database.MatrixCursor; +import android.net.Uri; +import android.os.ParcelFileDescriptor; +import androidx.test.core.app.ApplicationProvider; +import com.google.android.gms.tasks.Task; +import com.google.firebase.FirebaseApp; +import com.google.firebase.FirebaseOptions; +import com.google.firebase.FirebaseOptions.Builder; +import com.google.firebase.ml.modeldownloader.CustomModel; +import com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions; +import com.google.firebase.ml.modeldownloader.TestOnCompleteListener; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public class ModelFileDownloadServiceTest { + + private static final String TEST_PROJECT_ID = "777777777777"; + private static final FirebaseOptions FIREBASE_OPTIONS = + new Builder() + .setApplicationId("1:123456789:android:abcdef") + .setProjectId(TEST_PROJECT_ID) + .build(); + + private static final String MODEL_NAME = "MODEL_NAME_1"; + private static final String MODEL_HASH = "dsf324"; + public static final String MODEL_URL = "https://project.firebase.com/modelName/23424.jpg"; + private static final long URL_EXPIRATION = 604800L; + + private static final Long DOWNLOAD_ID = 987923L; + + private static final CustomModel CUSTOM_MODEL_NO_URL = + new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); + private static final CustomModel CUSTOM_MODEL_URL = + new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION); + private static final CustomModel CUSTOM_MODEL_DOWNLOADING = + new CustomModel(MODEL_NAME, MODEL_HASH, 100, DOWNLOAD_ID); + CustomModel customModelDownloadComplete; + + private static final CustomModelDownloadConditions DOWNLOAD_CONDITIONS_CHARGING_IDLE = + new CustomModelDownloadConditions.Builder().requireCharging().requireDeviceIdle().build(); + + File testTempModelFile; + File testAppModelFile; + + private ModelFileDownloadService modelFileDownloadService; + private SharedPreferencesUtil sharedPreferencesUtil; + @Mock DownloadManager mockDownloadManager; + @Mock ModelFileManager mockFileManager; + + ExecutorService executor; + private MatrixCursor matrixCursor; + FirebaseApp app; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + FirebaseApp.clearInstancesForTest(); + app = FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); + + executor = Executors.newSingleThreadExecutor(); + sharedPreferencesUtil = new SharedPreferencesUtil(app); + sharedPreferencesUtil.clearModelDetails(MODEL_NAME, false); + + modelFileDownloadService = + new ModelFileDownloadService( + app, mockDownloadManager, mockFileManager, sharedPreferencesUtil); + + matrixCursor = new MatrixCursor(new String[] {DownloadManager.COLUMN_STATUS}); + try { + testTempModelFile = File.createTempFile("fakeTempFile", ".tflite"); + + testAppModelFile = File.createTempFile("fakeAppFile", ".tflite"); + customModelDownloadComplete = + new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, testAppModelFile.getPath()); + } catch (IOException ex) { + System.out.println("Error creating test files"); + } + } + + @After + public void teardown() { + if (testAppModelFile.isFile()) { + testAppModelFile.delete(); + } + if (testTempModelFile.isFile()) { + testTempModelFile.delete(); + } + } + + @Test + public void downloaded_success_chargingAndIdle() throws Exception { + Request downloadRequest = new Request(Uri.parse(CUSTOM_MODEL_URL.getDownloadUrl())); + downloadRequest.setRequiresCharging(true); + downloadRequest.setRequiresDeviceIdle(true); + + when(mockDownloadManager.enqueue(any())).thenReturn(DOWNLOAD_ID); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + + TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); + Task task = + modelFileDownloadService.download(CUSTOM_MODEL_URL, DOWNLOAD_CONDITIONS_CHARGING_IDLE); + + // Complete the download + Intent downloadCompleteIntent = new Intent(DownloadManager.ACTION_DOWNLOAD_COMPLETE); + downloadCompleteIntent.putExtra(DownloadManager.EXTRA_DOWNLOAD_ID, DOWNLOAD_ID); + app.getApplicationContext().sendBroadcast(downloadCompleteIntent); + + task.addOnCompleteListener(executor, onCompleteListener); + onCompleteListener.await(); + + assertTrue(task.isComplete()); + assertTrue(task.isSuccessful()); + assertNull(task.getResult()); + assertEquals( + sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME), + CUSTOM_MODEL_DOWNLOADING); + + verify(mockDownloadManager, times(1)).enqueue(any()); + verify(mockDownloadManager, atLeastOnce()).query(any()); + } + + @Test + public void downloaded_success_wifi() throws Exception { + Request downloadRequest = new Request(Uri.parse(CUSTOM_MODEL_URL.getDownloadUrl())); + downloadRequest.setRequiresCharging(true); + downloadRequest.setAllowedNetworkTypes(Request.NETWORK_WIFI); + + when(mockDownloadManager.enqueue(any())).thenReturn(DOWNLOAD_ID); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + + TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); + Task task = + modelFileDownloadService.download( + CUSTOM_MODEL_URL, new CustomModelDownloadConditions.Builder().requireWifi().build()); + + // Complete the download + Intent downloadCompleteIntent = new Intent(DownloadManager.ACTION_DOWNLOAD_COMPLETE); + downloadCompleteIntent.putExtra(DownloadManager.EXTRA_DOWNLOAD_ID, DOWNLOAD_ID); + app.getApplicationContext().sendBroadcast(downloadCompleteIntent); + + task.addOnCompleteListener(executor, onCompleteListener); + onCompleteListener.await(); + + assertTrue(task.isComplete()); + assertTrue(task.isSuccessful()); + assertNull(task.getResult()); + + assertEquals( + sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME), + CUSTOM_MODEL_DOWNLOADING); + verify(mockDownloadManager, times(1)).enqueue(any()); + verify(mockDownloadManager, atLeastOnce()).query(any()); + } + + @Test + public void ensureModelDownloaded_noUrl() { + when(mockDownloadManager.enqueue(any())).thenReturn(DOWNLOAD_ID); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + + TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); + Task task = modelFileDownloadService.ensureModelDownloaded(CUSTOM_MODEL_NO_URL); + + assertTrue(task.isComplete()); + + // Complete the download + Intent downloadCompleteIntent = new Intent(DownloadManager.ACTION_DOWNLOAD_COMPLETE); + downloadCompleteIntent.putExtra(DownloadManager.EXTRA_DOWNLOAD_ID, DOWNLOAD_ID); + app.getApplicationContext().sendBroadcast(downloadCompleteIntent); + + task.addOnCompleteListener(executor, onCompleteListener); + assertThrows(Exception.class, () -> onCompleteListener.await()); + + assertTrue(task.isComplete()); + assertFalse(task.isSuccessful()); + + assertNull(sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME)); + verify(mockDownloadManager, never()).enqueue(any()); + verify(mockDownloadManager, never()).query(any()); + } + + @Test + public void ensureModelDownloaded_success() throws Exception { + when(mockDownloadManager.enqueue(any())).thenReturn(DOWNLOAD_ID); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + + TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); + Task task = modelFileDownloadService.ensureModelDownloaded(CUSTOM_MODEL_URL); + + // Complete the download + Intent downloadCompleteIntent = new Intent(DownloadManager.ACTION_DOWNLOAD_COMPLETE); + downloadCompleteIntent.putExtra(DownloadManager.EXTRA_DOWNLOAD_ID, DOWNLOAD_ID); + app.getApplicationContext().sendBroadcast(downloadCompleteIntent); + + task.addOnCompleteListener(executor, onCompleteListener); + onCompleteListener.await(); + + assertTrue(task.isComplete()); + assertTrue(task.isSuccessful()); + assertNull(task.getResult()); + assertEquals( + sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME), + CUSTOM_MODEL_DOWNLOADING); + + verify(mockDownloadManager, times(1)).enqueue(any()); + verify(mockDownloadManager, atLeastOnce()).query(any()); + } + + @Test + public void ensureModelDownloaded_downloadFailed() { + when(mockDownloadManager.enqueue(any())).thenReturn(DOWNLOAD_ID); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_FAILED}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + + TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); + Task task = modelFileDownloadService.ensureModelDownloaded(CUSTOM_MODEL_URL); + + try { + // Complete the download + Intent downloadCompleteIntent = new Intent(DownloadManager.ACTION_DOWNLOAD_COMPLETE); + downloadCompleteIntent.putExtra(DownloadManager.EXTRA_DOWNLOAD_ID, DOWNLOAD_ID); + app.getApplicationContext().sendBroadcast(downloadCompleteIntent); + + task.addOnCompleteListener(executor, onCompleteListener); + onCompleteListener.await(); + } catch (Exception ex) { + assertTrue(ex.getMessage().contains("Failed")); + } + + assertTrue(task.isComplete()); + assertFalse(task.isSuccessful()); + assertTrue(task.getException().getMessage().contains("Failed")); + assertEquals( + sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME), + CUSTOM_MODEL_DOWNLOADING); + + verify(mockDownloadManager, times(1)).enqueue(any()); + verify(mockDownloadManager, atLeastOnce()).query(any()); + } + + @Test + public void scheduleModelDownload_success() { + when(mockDownloadManager.enqueue(any())).thenReturn(DOWNLOAD_ID); + Long id = modelFileDownloadService.scheduleModelDownload(CUSTOM_MODEL_URL); + assertEquals(DOWNLOAD_ID, id); + assertEquals( + sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME), + CUSTOM_MODEL_DOWNLOADING); + verify(mockDownloadManager, times(1)).enqueue(any()); + } + + @Test + public void scheduleModelDownload_noUri() { + assertNull(modelFileDownloadService.scheduleModelDownload(CUSTOM_MODEL_NO_URL)); + verify(mockDownloadManager, never()).enqueue(any()); + } + + @Test + public void scheduleModelDownload_failed() { + when(mockDownloadManager.enqueue(any())).thenThrow(new IllegalArgumentException("bad enqueue")); + assertThrows( + IllegalArgumentException.class, + () -> modelFileDownloadService.scheduleModelDownload(CUSTOM_MODEL_URL)); + assertNull(sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME)); + verify(mockDownloadManager, times(1)).enqueue(any()); + } + + @Test + public void getDownloadStatus_NullCursor() { + // Not found + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + when(mockDownloadManager.query(any())).thenReturn(null); + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(DOWNLOAD_ID)); + } + + @Test + public void getDownloadStatus_Success() { + // Not found + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + assertTrue( + modelFileDownloadService.getDownloadingModelStatusCode(DOWNLOAD_ID) + == DownloadManager.STATUS_SUCCESSFUL); + } + + @Test + public void maybeCheckDownloadingComplete_downloadComplete() throws Exception { + sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING); + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + when(mockDownloadManager.openDownloadedFile(anyLong())) + .thenReturn( + ParcelFileDescriptor.open(testTempModelFile, ParcelFileDescriptor.MODE_READ_ONLY)); + + when(mockFileManager.moveModelToDestinationFolder(any(), any())).thenReturn(testAppModelFile); + + modelFileDownloadService.maybeCheckDownloadingComplete(); + + assertEquals( + sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME), customModelDownloadComplete); + verify(mockDownloadManager, times(3)).query(any()); + } + + @Test + public void maybeCheckDownloadingComplete_downloadInprogress() throws Exception { + sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING); + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_RUNNING}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + + modelFileDownloadService.maybeCheckDownloadingComplete(); + assertEquals( + sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME), + CUSTOM_MODEL_DOWNLOADING); + verify(mockDownloadManager, times(2)).query(any()); + } + + @Test + public void maybeCheckDownloadingComplete_multipleDownloads() throws Exception { + sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING); + String secondModelName = "secondModelName"; + CustomModel downloading2 = new CustomModel(secondModelName, MODEL_HASH, 100, DOWNLOAD_ID + 1); + sharedPreferencesUtil.setDownloadingCustomModelDetails(downloading2); + + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + when(mockDownloadManager.openDownloadedFile(anyLong())) + .thenReturn( + ParcelFileDescriptor.open(testTempModelFile, ParcelFileDescriptor.MODE_READ_ONLY)); + + when(mockFileManager.moveModelToDestinationFolder(any(), any())).thenReturn(testAppModelFile); + + modelFileDownloadService.maybeCheckDownloadingComplete(); + + assertEquals( + sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME), customModelDownloadComplete); + assertEquals( + sharedPreferencesUtil.getCustomModelDetails(secondModelName), + new CustomModel(secondModelName, MODEL_HASH, 100, 0, testAppModelFile.getPath())); + verify(mockDownloadManager, times(5)).query(any()); + } + + @Test + public void maybeCheckDownloadingComplete_noDownloadsInProgress() throws Exception { + modelFileDownloadService.maybeCheckDownloadingComplete(); + verify(mockDownloadManager, never()).query(any()); + } + + @Test + public void loadNewlyDownloadedModelFile_successFilePresent() throws Exception { + // Not found + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + when(mockDownloadManager.openDownloadedFile(anyLong())) + .thenReturn( + ParcelFileDescriptor.open(testTempModelFile, ParcelFileDescriptor.MODE_READ_ONLY)); + + when(mockFileManager.moveModelToDestinationFolder(any(), any())).thenReturn(testAppModelFile); + + assertEquals( + modelFileDownloadService.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING), + testAppModelFile); + + CustomModel retrievedModel = sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME); + assertEquals(retrievedModel, customModelDownloadComplete); + } + + @Test + public void loadNewlyDownloadedModelFile_successNoFile() throws Exception { + // Not found + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + doThrow(new FileNotFoundException("File not found.")) + .when(mockDownloadManager) + .openDownloadedFile(anyLong()); + + assertNull(modelFileDownloadService.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING)); + assertNull(sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME)); + } + + @Test + public void loadNewlyDownloadedModelFile_Running() throws Exception { + // Not found + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_RUNNING}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + assertNull(modelFileDownloadService.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING)); + assertNull(sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME)); + } + + @Test + public void loadNewlyDownloadedModelFile_Failed() throws Exception { + // Not found + assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); + matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_FAILED}); + when(mockDownloadManager.query(any())).thenReturn(matrixCursor); + assertNull(modelFileDownloadService.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING)); + assertNull(sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME)); + } +} diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManagerTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManagerTest.java new file mode 100644 index 00000000000..c48111a4d59 --- /dev/null +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManagerTest.java @@ -0,0 +1,117 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.firebase.ml.modeldownloader.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import android.os.ParcelFileDescriptor; +import androidx.test.core.app.ApplicationProvider; +import com.google.firebase.FirebaseApp; +import com.google.firebase.FirebaseOptions; +import com.google.firebase.FirebaseOptions.Builder; +import com.google.firebase.ml.modeldownloader.CustomModel; +import java.io.File; +import java.io.IOException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.MockitoAnnotations; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public class ModelFileManagerTest { + + public static final String TEST_PROJECT_ID = "777777777777"; + public static final FirebaseOptions FIREBASE_OPTIONS = + new Builder() + .setApplicationId("1:123456789:android:abcdef") + .setProjectId(TEST_PROJECT_ID) + .build(); + + public static final String MODEL_NAME = "MODEL_NAME_1"; + public static final String MODEL_HASH = "dsf324"; + + final CustomModel CUSTOM_MODEL_NO_FILE = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); + + private File testModelFile; + + ModelFileManager fileManager; + String expectedDestinationFolder; + + @Before + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + FirebaseApp.clearInstancesForTest(); + FirebaseApp app = + FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); + + fileManager = new ModelFileManager(app); + + setUpTestingFiles(app); + } + + private void setUpTestingFiles(FirebaseApp app) throws IOException { + final File testDir = new File(app.getApplicationContext().getNoBackupFilesDir(), "tmpModels"); + testDir.mkdirs(); + // make sure the directory is empty. Doesn't recurse into subdirs, but that's OK since + // we're only using this directory for this test and we won't create any subdirs. + for (File f : testDir.listFiles()) { + if (f.isFile()) { + f.delete(); + } + } + + testModelFile = File.createTempFile("modelFile", "tflite"); + expectedDestinationFolder = + new File( + app.getApplicationContext().getNoBackupFilesDir(), + ModelFileManager.CUSTOM_MODEL_ROOT_PATH) + .getAbsolutePath() + + "/" + + app.getPersistenceKey() + + "/" + + MODEL_NAME; + } + + @After + public void teardown() { + testModelFile.deleteOnExit(); + } + + @Test + public void getDirImpl() throws Exception { + File modelDirectory = fileManager.getDirImpl(MODEL_NAME); + assertTrue(modelDirectory.getAbsolutePath().endsWith(MODEL_NAME)); + } + + @Test + public void getModelFileDestination_noExistingFiles() throws Exception { + File firstFile = fileManager.getModelFileDestination(CUSTOM_MODEL_NO_FILE); + assertTrue(firstFile.getAbsolutePath().endsWith(String.format("%s/0", MODEL_NAME))); + } + + @Test + public void moveModelToDestinationFolder() throws Exception { + ParcelFileDescriptor fd = + ParcelFileDescriptor.open(testModelFile, ParcelFileDescriptor.MODE_READ_ONLY); + + assertEquals( + fileManager.moveModelToDestinationFolder(CUSTOM_MODEL_NO_FILE, fd), + new File(expectedDestinationFolder + "/0")); + new File(expectedDestinationFolder + "/0").delete(); + } +} diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java index a9aa7883703..8d476a1cefa 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.MockitoAnnotations; import org.robolectric.RobolectricTestRunner; /** Tests for {@link SharedPreferencesUtil}. */ @@ -34,22 +35,19 @@ public class SharedPreferencesUtilTest { private static final String TEST_PROJECT_ID = "777777777777"; - private static final String MODEL_NAME = "ModelName"; private static final String MODEL_HASH = "dsf324"; private static final CustomModel CUSTOM_MODEL_DOWNLOAD_COMPLETE = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, "file/path/store/ModelName/1"); - private static final CustomModel CUSTOM_MODEL_UPDATE_IN_BACKGROUND = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 986, "file/path/store/ModelName/1"); - private static final CustomModel CUSTOM_MODEL_DOWNLOADING = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 986); - private SharedPreferencesUtil sharedPreferencesUtil; @Before public void setUp() { + MockitoAnnotations.initMocks(this); FirebaseApp.clearInstancesForTest(); FirebaseApp app = FirebaseApp.initializeApp( @@ -122,7 +120,7 @@ public void clearDownloadingModelDetails_keepsLocalModel() throws IllegalArgumen } @Test - public void listDownloadedModels_localModelFound() throws IllegalArgumentException { + public void listDownloadedModels_localModelFound() { sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE); Set retrievedModel = sharedPreferencesUtil.listDownloadedModels(); assertEquals(retrievedModel.size(), 1); @@ -130,18 +128,18 @@ public void listDownloadedModels_localModelFound() throws IllegalArgumentExcepti } @Test - public void listDownloadedModels_downloadingModelNotFound() throws IllegalArgumentException { + public void listDownloadedModels_downloadingModelNotFound() { sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING); assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0); } @Test - public void listDownloadedModels_noModels() throws IllegalArgumentException { + public void listDownloadedModels_noModels() { assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0); } @Test - public void listDownloadedModels_multipleModels() throws IllegalArgumentException { + public void listDownloadedModels_multipleModels() { sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE); CustomModel model2 =