From 04f6e00a0200be6dc55ff72cb35ddfd06375497b Mon Sep 17 00:00:00 2001 From: Scott Ashton Date: Tue, 3 Jun 2025 11:48:49 -0600 Subject: [PATCH 1/4] Add web support --- .prettierrc | 3 + apps/example/index.html | 36 ++ apps/example/index.js | 24 +- apps/example/metro.config.js | 71 +++- apps/example/package.json | 8 +- apps/example/src/App.tsx | 5 +- apps/example/src/Home.tsx | 6 +- .../src/MNISTInference/MNISTInference.tsx | 105 +++--- apps/example/src/assets/helvetica.ttf | Bin 0 -> 50583 bytes packages/webgpu/src/Canvas.tsx | 59 +--- .../webgpu/src/NativeWebGPUModuleWrapper.ts | 1 + .../src/NativeWebGPUModuleWrapper.web.ts | 12 + .../webgpu/src/WebGPUViewNativeComponent.ts | 2 +- packages/webgpu/src/WebGPUWrapper.ts | 1 + packages/webgpu/src/WebGPUWrapper.web.ts | 80 +++++ packages/webgpu/src/hooks.tsx | 3 +- packages/webgpu/src/index.tsx | 198 +---------- packages/webgpu/src/main.tsx | 215 ++++++++++++ packages/webgpu/src/main.web.tsx | 7 + packages/webgpu/src/types.ts | 20 ++ packages/webgpu/src/utils.ts | 22 ++ packages/webgpu/src/utils.web.ts | 50 +++ yarn.lock | 318 +++++++++++------- 23 files changed, 818 insertions(+), 428 deletions(-) create mode 100644 .prettierrc create mode 100644 apps/example/index.html create mode 100644 apps/example/src/assets/helvetica.ttf create mode 100644 packages/webgpu/src/NativeWebGPUModuleWrapper.ts create mode 100644 packages/webgpu/src/NativeWebGPUModuleWrapper.web.ts create mode 100644 packages/webgpu/src/WebGPUWrapper.ts create mode 100644 packages/webgpu/src/WebGPUWrapper.web.ts create mode 100644 packages/webgpu/src/main.tsx create mode 100644 packages/webgpu/src/main.web.tsx create mode 100644 packages/webgpu/src/types.ts create mode 100644 packages/webgpu/src/utils.ts create mode 100644 packages/webgpu/src/utils.web.ts diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 000000000..bf357fbbc --- /dev/null +++ b/.prettierrc @@ -0,0 +1,3 @@ +{ + "trailingComma": "all" +} diff --git a/apps/example/index.html b/apps/example/index.html new file mode 100644 index 000000000..ef8728e71 --- /dev/null +++ b/apps/example/index.html @@ -0,0 +1,36 @@ + + + + + + + + + + Example + + + + +
+ + + diff --git a/apps/example/index.js b/apps/example/index.js index 69303b34d..9e609fbd1 100644 --- a/apps/example/index.js +++ b/apps/example/index.js @@ -2,8 +2,26 @@ * @format */ -import {AppRegistry} from 'react-native'; -import App from './src/App'; -import {name as appName} from './app.json'; +import { AppRegistry } from "react-native"; +import App from "./src/App"; +import { name as appName } from "./app.json"; AppRegistry.registerComponent(appName, () => App); + +if (Platform.OS === "web" && typeof window !== "undefined") { + const rootTag = document.getElementById("root"); + if (process.env.NODE_ENV !== "production") { + if (!rootTag) { + throw new Error( + 'Required HTML element with id "root" was not found in the document HTML.', + ); + } + } + + CanvasKitInit({ + locateFile: (file) => `https://unpkg.com/canvaskit-wasm/bin/full/${file}`, + }).then((CanvasKit) => { + window.CanvasKit = global.CanvasKit = CanvasKit; + AppRegistry.runApplication(appName, { rootTag }); + }); +} diff --git a/apps/example/metro.config.js b/apps/example/metro.config.js index 42ae30faf..d245cc25a 100644 --- a/apps/example/metro.config.js +++ b/apps/example/metro.config.js @@ -1,32 +1,69 @@ const { makeMetroConfig } = require("@rnx-kit/metro-config"); -const path = require('path'); +const path = require("path"); -const root = path.resolve(__dirname, '../..'); -const threePackagePath = path.resolve(root, 'node_modules/three'); +const root = path.resolve(__dirname, "../.."); +const threePackagePath = path.resolve(root, "node_modules/three"); +const r3fPath = path.resolve(root, "node_modules/@react-three/fiber"); + +const rnwPath = path.resolve(root, "node_modules/react-native-web"); + +const assetRegistryPath = path.resolve( + root, + "node_modules/react-native-web/dist/modules/AssetRegistry/index", +); + +const IS_WEB = !!process.env.IS_WEB_BUILD; const extraConfig = { watchFolders: [root], resolver: { extraNodeModules: { - 'three': threePackagePath, + three: threePackagePath, }, - resolveRequest: (context, moduleName, platform) => { - if (moduleName.startsWith('three/addons/')) { + platforms: ["ios", "android", "web"], + + resolveRequest: (contextRaw, moduleName, platform) => { + const context = IS_WEB + ? { + ...contextRaw, + preferNativePlatform: false, + } + : contextRaw; + + if (IS_WEB && moduleName === "react-native") { + return { + filePath: path.resolve(rnwPath, "dist/index.js"), + type: "sourceFile", + }; + } + + if (moduleName.startsWith("three/addons/")) { return { - filePath: path.resolve(threePackagePath, 'examples/jsm/' + moduleName.replace('three/addons/', '') + '.js'), - type: 'sourceFile', + filePath: path.resolve( + threePackagePath, + "examples/jsm/" + moduleName.replace("three/addons/", "") + ".js", + ), + type: "sourceFile", }; } - if (moduleName === 'three' || moduleName === 'three/webgpu') { + if (moduleName === "three" || moduleName === "three/webgpu") { return { - filePath: path.resolve(threePackagePath, 'build/three.webgpu.js'), - type: 'sourceFile', + filePath: path.resolve(threePackagePath, "build/three.webgpu.js"), + type: "sourceFile", }; } - if (moduleName === 'three/tsl') { + if (moduleName === "three/tsl") { return { - filePath: path.resolve(threePackagePath, 'build/three.tsl.js'), - type: 'sourceFile', + filePath: path.resolve(threePackagePath, "build/three.tsl.js"), + type: "sourceFile", + }; + } + + if (moduleName === "@react-three/fiber") { + //Do NOT use the stale react three fiber "native" version originally added for expo-gl + return { + filePath: path.resolve(r3fPath, "dist/react-three-fiber.esm.js"), + type: "sourceFile", }; } // Let Metro handle other modules @@ -45,7 +82,11 @@ const extraConfig = { }; const metroConfig = makeMetroConfig(extraConfig); -metroConfig.resolver.assetExts.push('glb', 'gltf', 'jpg', 'bin', 'hdr'); +metroConfig.resolver.assetExts.push("glb", "gltf", "jpg", "bin", "hdr"); + +if (IS_WEB) { + metroConfig.transformer.assetRegistryPath = assetRegistryPath; +} module.exports = metroConfig; diff --git a/apps/example/package.json b/apps/example/package.json index 7dce8612e..ac03c9a64 100644 --- a/apps/example/package.json +++ b/apps/example/package.json @@ -7,6 +7,7 @@ "tsc": "tsc --noEmit", "android": "react-native run-android", "ios": "react-native run-ios", + "web": "IS_WEB_BUILD=true react-native start", "start": "react-native start", "pod:install:ios": "pod install --project-directory=ios", "pod:install:macos": "pod install --project-directory=macos", @@ -19,18 +20,21 @@ "@callstack/react-native-visionos": "^0.74.0", "@react-navigation/native": "^6.1.17", "@react-navigation/stack": "^6.4.0", - "@react-three/fiber": "^8.17.6", + "@react-three/fiber": "^9.1.2", "@shopify/react-native-skia": "2.0.0", "@tensorflow/tfjs": "^4.22.0", "@tensorflow/tfjs-backend-webgpu": "^4.22.0", "@tensorflow/tfjs-vis": "^1.5.1", + "@types/react-dom": "^19.1.5", "fast-text-encoding": "^1.0.6", "react": "19.0.0", + "react-dom": "19.0.0", "react-native": "0.79.2", "react-native-gesture-handler": "^2.17.1", "react-native-macos": "^0.78.3", "react-native-reanimated": "^3.12.1", - "react-native-safe-area-context": "^5.4.0", + "react-native-safe-area-context": "^5.4.1", + "react-native-web": "^0.20.0", "react-native-wgpu": "*", "teapot": "^1.0.0", "three": "0.172.0", diff --git a/apps/example/src/App.tsx b/apps/example/src/App.tsx index 9e5a6facc..3b4ceb6b7 100644 --- a/apps/example/src/App.tsx +++ b/apps/example/src/App.tsx @@ -44,7 +44,10 @@ function App() { return ( - + (null); -const fontFamily = Platform.select({ ios: "Helvetica", default: "serif" }); -const fontStyle = { - fontFamily, - fontSize: 200, -}; -const font = matchFont(fontStyle); + const font = useFont(require("../assets/helvetica.ttf")); -const paint = Skia.Paint(); -paint.setColor(Skia.Color("black")); -paint.setStyle(PaintStyle.Stroke); -paint.setStrokeWidth(1); + // Lazy initialize skia derived constants + if (!skiaConstants.current) { + const { width } = Dimensions.get("window"); -const grid = Skia.Path.Make(); -const cellSize = width / SIZE; + const paint = Skia.Paint(); + paint.setColor(Skia.Color("black")); + paint.setStyle(PaintStyle.Stroke); + paint.setStrokeWidth(1); -grid.moveTo(0, 0); + const grid = Skia.Path.Make(); + const cellSize = width / SIZE; -// Draw vertical lines -for (let i = 0; i <= SIZE; i++) { - grid.moveTo(i * cellSize, 0); - grid.lineTo(i * cellSize, width); -} + grid.moveTo(0, 0); -// Draw horizontal lines -for (let i = 0; i <= SIZE; i++) { - grid.moveTo(0, i * cellSize); - grid.lineTo(width, i * cellSize); -} + // Draw vertical lines + for (let i = 0; i <= SIZE; i++) { + grid.moveTo(i * cellSize, 0); + grid.lineTo(i * cellSize, width); + } -const f = 1 / cellSize; + // Draw horizontal lines + for (let i = 0; i <= SIZE; i++) { + grid.moveTo(0, i * cellSize); + grid.lineTo(width, i * cellSize); + } + + const f = 1 / cellSize; + + skiaConstants.current = { + f, + paint, + grid, + width, + }; + } + + const { f, paint, grid, width } = skiaConstants.current; -export function MNISTInference() { const { device } = useDevice(); const network = useRef(); const text = useSharedValue(""); @@ -84,21 +103,26 @@ export function MNISTInference() { if (surface.value) { const canvas = surface.value.getCanvas(); canvas.drawPath(path.value, paint); + surface.value.flush(); image.value = surface.value!.makeImageSnapshot(); const pixels = image.value.readPixels(0, 0, { width: SIZE, height: SIZE, - alphaType: AlphaType.Opaque, - colorType: ColorType.Alpha_8, + colorType: ColorType.RGBA_8888, + alphaType: AlphaType.Unpremul, }); + + const gray = new Uint8Array(SIZE * SIZE); + for (let i = 0; i < SIZE * SIZE; i++) { + gray[i] = pixels![i * 4]; + } + runOnJS(runInference)( - centerData(pixels as Uint8Array).map( - (x) => (x / 255) * 3.24 - 0.42, - ), + centerData(gray).map((x) => (x / 255) * 3.24 - 0.42), ); } }); - }, [path, runInference, surface, image]); + }, [path, runInference, surface, image, f, paint]); useEffect(() => { (async () => { @@ -111,6 +135,11 @@ export function MNISTInference() { })(); })(); }, [device, network, surface]); + + if (!font) { + return null; + } + return (