diff --git a/iris/index.html b/iris/index.html index 338b9f745..a2c8fa26c 100644 --- a/iris/index.html +++ b/iris/index.html @@ -84,8 +84,23 @@ color: white; } - .create-model { - display: inline-block; + .region { + margin-top: 6px; + margin-bottom: 6px; + padding-top: 3px; + padding-bottom: 3px; + border-style: dashed; + border-width: 1px; + border-color: #888; + } + + .region-title { + font-weight: bold; + } + + .load-save-section { + padding-top: 3px; + padding-bottom: 3px; } .logit-span { @@ -104,21 +119,33 @@

TensorFlow.js Layers: Iris Demo

-
-
- Train Epochs: - -
-
- Learning Rate: - +
+
Train Model
+
+
+ Train Epochs: + +
+
+ Learning Rate: + +
+
-
-
- - +
+
Save/Load Model
+
+ +
+ +
+ + + + Status unavailable. +
diff --git a/iris/index.js b/iris/index.js index ba9750562..557d8d89d 100644 --- a/iris/index.js +++ b/iris/index.js @@ -128,7 +128,6 @@ async function evaluateModelOnTestData(model, xTest, yTest) { predictOnManualInput(model); } -const LOCAL_MODEL_JSON_URL = 'http://localhost:1235/resources/model.json'; const HOSTED_MODEL_JSON_URL = 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json'; @@ -138,10 +137,15 @@ const HOSTED_MODEL_JSON_URL = async function iris() { const [xTrain, yTrain, xTest, yTest] = data.getIrisData(0.15); + const localLoadButton = document.getElementById('load-local'); + const localSaveButton = document.getElementById('save-local'); + const localRemoveButton = document.getElementById('remove-local'); + document.getElementById('train-from-scratch') .addEventListener('click', async () => { model = await trainModel(xTrain, yTrain, xTest, yTest); - evaluateModelOnTestData(model, xTest, yTest); + await evaluateModelOnTestData(model, xTest, yTest); + localSaveButton.disabled = false; }); if (await loader.urlExists(HOSTED_MODEL_JSON_URL)) { @@ -150,23 +154,27 @@ async function iris() { button.addEventListener('click', async () => { ui.clearEvaluateTable(); model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL); - predictOnManualInput(model); + await predictOnManualInput(model); + localSaveButton.disabled = false; }); - // button.style.visibility = 'visible'; - button.style.display = 'inline-block'; } - if (await loader.urlExists(LOCAL_MODEL_JSON_URL)) { - ui.status('Model available: ' + LOCAL_MODEL_JSON_URL); - const button = document.getElementById('load-pretrained-local'); - button.addEventListener('click', async () => { - ui.clearEvaluateTable(); - model = await loader.loadHostedPretrainedModel(LOCAL_MODEL_JSON_URL); - predictOnManualInput(model); - }); - // button.style.visibility = 'visible'; - button.style.display = 'inline-block'; - } + localLoadButton.addEventListener('click', async () => { + model = await loader.loadModelLocally(); + await predictOnManualInput(model); + }); + + localSaveButton.addEventListener('click', async () => { + await loader.saveModelLocally(model); + await loader.updateLocalModelStatus(); + }); + + localRemoveButton.addEventListener('click', async () => { + await loader.removeModelLocally(); + await loader.updateLocalModelStatus(); + }); + + await loader.updateLocalModelStatus(); ui.status('Standing by.'); ui.wireUpEvaluateTableCallbacks(() => predictOnManualInput(model)); diff --git a/iris/loader.js b/iris/loader.js index 51296e4e9..d07a64d41 100644 --- a/iris/loader.js +++ b/iris/loader.js @@ -41,13 +41,49 @@ export async function loadHostedPretrainedModel(url) { try { const model = await tf.loadModel(url); ui.status('Done loading pretrained model.'); - // We can't load a model twice due to - // https://github.com/tensorflow/tfjs/issues/34 - // Therefore we remove the load buttons to avoid user confusion. - ui.disableLoadModelButtons(); return model; } catch (err) { console.error(err); ui.status('Loading pretrained model failed.'); } } + +// The URL-like path that identifies the client-side location where downloaded +// or locally trained models can be stored. +const LOCAL_MODEL_URL = 'indexeddb://tfjs-iris-demo-model/v1'; + +export async function saveModelLocally(model) { + const saveResult = await model.save(LOCAL_MODEL_URL); +} + +export async function loadModelLocally(model) { + return await tf.loadModel(LOCAL_MODEL_URL); +} + +export async function removeModelLocally(model) { + return await tf.io.removeModel(LOCAL_MODEL_URL); +} + +/** + * Check the presence and status of locally saved models (e.g., in IndexedDB). + * + * Update the UI control states accordingly. + */ +export async function updateLocalModelStatus() { + const localModelStatus = document.getElementById('local-model-status'); + const localLoadButton = document.getElementById('load-local'); + const localRemoveButton = document.getElementById('remove-local'); + + const modelsInfo = await tf.io.listModels(); + if (LOCAL_MODEL_URL in modelsInfo) { + localModelStatus.textContent = + 'Found locally-stored model saved at ' + + modelsInfo[LOCAL_MODEL_URL].dateSaved; + localLoadButton.disabled = false; + localRemoveButton.disabled = false; + } else { + localModelStatus.textContent = 'No locally-stored model is found.'; + localLoadButton.disabled = true; + localRemoveButton.disabled = true; + } +} diff --git a/iris/package.json b/iris/package.json index 6d2f3e477..376022aec 100644 --- a/iris/package.json +++ b/iris/package.json @@ -9,12 +9,12 @@ "node": ">=8.9.0" }, "dependencies": { - "@tensorflow/tfjs": "0.10.0", + "@tensorflow/tfjs": "0.11.2", "vega-embed": "^3.0.0" }, "scripts": { - "watch": "./serve.sh", - "build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./" + "build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./", + "watch": "./serve.sh" }, "devDependencies": { "babel-plugin-transform-runtime": "~6.23.0", diff --git a/iris/serve.sh b/iris/serve.sh index e5ede63ed..928d5cf95 100755 --- a/iris/serve.sh +++ b/iris/serve.sh @@ -35,8 +35,7 @@ node_modules/http-server/bin/http-server dist --cors -p "${RESOURCE_PORT}" > /de echo Starting the example html/js server... # This uses port 1234 by default. -node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html +node_modules/parcel-bundler/bin/cli.js serve -d dist --open --no-hmr --public-url / index.html -p 1236 # When the Parcel server exits, kill the http-server too. kill $HTTP_SERVER_PID - diff --git a/iris/ui.js b/iris/ui.js index 044ef8200..63c09d0e1 100644 --- a/iris/ui.js +++ b/iris/ui.js @@ -50,7 +50,8 @@ export function plotLosses(lossValues, epoch, newTrainLoss, newValidationLoss) { 'x': {'field': 'epoch', 'type': 'ordinal'}, 'y': {'field': 'loss', 'type': 'quantitative'}, 'color': {'field': 'set', 'type': 'nominal'}, - } + }, + 'width': 500, }, {}); } @@ -78,7 +79,8 @@ export function plotAccuracies( 'x': {'field': 'epoch', 'type': 'ordinal'}, 'y': {'field': 'accuracy', 'type': 'quantitative'}, 'color': {'field': 'set', 'type': 'nominal'}, - } + }, + 'width': 500, }, {}); } @@ -230,8 +232,3 @@ export function status(statusText) { console.log(statusText); document.getElementById('demo-status').textContent = statusText; } - -export function disableLoadModelButtons() { - document.getElementById('load-pretrained-remote').style.display = 'none'; - document.getElementById('load-pretrained-local').style.display = 'none'; -} diff --git a/iris/yarn.lock b/iris/yarn.lock index a1c7f3fca..a4bda3bdb 100644 --- a/iris/yarn.lock +++ b/iris/yarn.lock @@ -2,27 +2,87 @@ # yarn lockfile v1 -"@tensorflow/tfjs-core@0.8.1": - version "0.8.1" - resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.8.1.tgz#d93e87302e29620906003b697d2d8596410ad712" +"@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2": + version "1.1.2" + resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf" + +"@protobufjs/base64@^1.1.2": + version "1.1.2" + resolved "https://registry.yarnpkg.com/@protobufjs/base64/-/base64-1.1.2.tgz#4c85730e59b9a1f1f349047dbf24296034bb2735" + +"@protobufjs/codegen@^2.0.4": + version "2.0.4" + resolved "https://registry.yarnpkg.com/@protobufjs/codegen/-/codegen-2.0.4.tgz#7ef37f0d010fb028ad1ad59722e506d9262815cb" + +"@protobufjs/eventemitter@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz#355cbc98bafad5978f9ed095f397621f1d066b70" + +"@protobufjs/fetch@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/fetch/-/fetch-1.1.0.tgz#ba99fb598614af65700c1619ff06d454b0d84c45" + dependencies: + "@protobufjs/aspromise" "^1.1.1" + "@protobufjs/inquire" "^1.1.0" + +"@protobufjs/float@^1.0.2": + version "1.0.2" + resolved "https://registry.yarnpkg.com/@protobufjs/float/-/float-1.0.2.tgz#5e9e1abdcb73fc0a7cb8b291df78c8cbd97b87d1" + +"@protobufjs/inquire@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/inquire/-/inquire-1.1.0.tgz#ff200e3e7cf2429e2dcafc1140828e8cc638f089" + +"@protobufjs/path@^1.1.2": + version "1.1.2" + resolved "https://registry.yarnpkg.com/@protobufjs/path/-/path-1.1.2.tgz#6cc2b20c5c9ad6ad0dccfd21ca7673d8d7fbf68d" + +"@protobufjs/pool@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/pool/-/pool-1.1.0.tgz#09fd15f2d6d3abfa9b65bc366506d6ad7846ff54" + +"@protobufjs/utf8@^1.1.0": + version "1.1.0" + resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570" + +"@tensorflow/tfjs-converter@0.4.0": + version "0.4.0" + resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-0.4.0.tgz#bf038475417a37a8b58db1b3d3ba6dea8be2e65d" + dependencies: + "@types/long" "~3.0.32" + protobufjs "~6.8.0" + url "^0.11.0" + +"@tensorflow/tfjs-core@0.11.1": + version "0.11.1" + resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.11.1.tgz#d26808f912529668d0a41228da37566b6b2f4f08" dependencies: seedrandom "~2.4.3" -"@tensorflow/tfjs-layers@0.5.0": - version "0.5.0" - resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-0.5.0.tgz#0ec8d07b46863a162d3e0c60fae3d1087d8aa3ce" +"@tensorflow/tfjs-layers@0.6.2": + version "0.6.2" + resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-0.6.2.tgz#7d152763c99acf5f86d6a735dbb9d5ee6af04e22" -"@tensorflow/tfjs@0.10.0": - version "0.10.0" - resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-0.10.0.tgz#c64f8740b46c2ba734dc0564d37e6a2c33055511" +"@tensorflow/tfjs@0.11.2": + version "0.11.2" + resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-0.11.2.tgz#908cf8898d0a52a5b448a894ba60722d642da230" dependencies: - "@tensorflow/tfjs-core" "0.8.1" - "@tensorflow/tfjs-layers" "0.5.0" + "@tensorflow/tfjs-converter" "0.4.0" + "@tensorflow/tfjs-core" "0.11.1" + "@tensorflow/tfjs-layers" "0.6.2" "@types/json-stable-stringify@^1.0.32": version "1.0.32" resolved "https://registry.yarnpkg.com/@types/json-stable-stringify/-/json-stable-stringify-1.0.32.tgz#121f6917c4389db3923640b2e68de5fa64dda88e" +"@types/long@^3.0.32", "@types/long@~3.0.32": + version "3.0.32" + resolved "https://registry.yarnpkg.com/@types/long/-/long-3.0.32.tgz#f4e5af31e9e9b196d8e5fca8a5e2e20aa3d60b69" + +"@types/node@^8.9.4": + version "8.10.17" + resolved "https://registry.yarnpkg.com/@types/node/-/node-8.10.17.tgz#d48cf10f0dc6dcf59f827f5a3fc7a4a6004318d3" + abbrev@1: version "1.1.1" resolved "https://registry.yarnpkg.com/abbrev/-/abbrev-1.1.1.tgz#f8f2c887ad10bf67f634f005b6987fed3179aac8" @@ -2480,6 +2540,10 @@ lodash@^4.17.4: version "4.17.5" resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.5.tgz#99a92d65c0272debe8c96b6057bc8fbfa3bed511" +long@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28" + loose-envify@^1.0.0: version "1.3.1" resolved "https://registry.yarnpkg.com/loose-envify/-/loose-envify-1.3.1.tgz#d1a8ad33fa9ce0e713d65fdd0ac8b748d478c848" @@ -3319,6 +3383,24 @@ proto-list@~1.2.1: version "1.2.4" resolved "https://registry.yarnpkg.com/proto-list/-/proto-list-1.2.4.tgz#212d5bfe1318306a420f6402b8e26ff39647a849" +protobufjs@~6.8.0: + version "6.8.6" + resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-6.8.6.tgz#ce3cf4fff9625b62966c455fc4c15e4331a11ca2" + dependencies: + "@protobufjs/aspromise" "^1.1.2" + "@protobufjs/base64" "^1.1.2" + "@protobufjs/codegen" "^2.0.4" + "@protobufjs/eventemitter" "^1.1.0" + "@protobufjs/fetch" "^1.1.0" + "@protobufjs/float" "^1.0.2" + "@protobufjs/inquire" "^1.1.0" + "@protobufjs/path" "^1.1.2" + "@protobufjs/pool" "^1.1.0" + "@protobufjs/utf8" "^1.1.0" + "@types/long" "^3.0.32" + "@types/node" "^8.9.4" + long "^4.0.0" + prr@~1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/prr/-/prr-1.0.1.tgz#d3fc114ba06995a45ec6893f484ceb1d78f5f476"