Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .prettierrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is not related right?

"trailingComma": "all"
}
34 changes: 34 additions & 0 deletions apps/example/getWebMetroConfig.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
const path = require("path");
const root = path.resolve(__dirname, "../..");
const r3fPath = path.resolve(root, "node_modules/@react-three/fiber");
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this import is not used here right?

const rnwPath = path.resolve(root, "node_modules/react-native-web");
const assetRegistryPath = path.resolve(
root,
"node_modules/react-native-web/dist/modules/AssetRegistry/index",
);

module.exports = function(metroConfig){
metroConfig.resolver.platforms = ["ios", "android", "web"];
const origResolveRequest = metroConfig.resolver.resolveRequest;
metroConfig.resolver.resolveRequest = (contextRaw, moduleName, platform) => {
const context = {
...contextRaw,
preferNativePlatform: false,
};

if (moduleName === "react-native") {
return {
filePath: path.resolve(rnwPath, "dist/index.js"),
type: "sourceFile",
};
}

// Let default config handle other modules
return origResolveRequest(context, moduleName, platform);
};

metroConfig.transformer.assetRegistryPath = assetRegistryPath;

return metroConfig
}

36 changes: 36 additions & 0 deletions apps/example/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta httpEquiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no" />
<link rel="icon" href="/src/assets/react.png" type="image/png">
<script src="https://unpkg.com/canvaskit-wasm/bin/full/canvaskit.js"></script>

<title>Example</title>
<!-- The `react-native-web` recommended style reset: https://necolas.github.io/react-native-web/docs/setup/#root-element -->
<style id="react-native-web-reset">
/* These styles make the body full-height */
html,
body {
height: 100%;
}
/* These styles disable body scrolling if you are using <ScrollView> */
body {
overflow: hidden;
}
/* These styles make the root element full-height */
#root {
display: flex;
height: 100%;
flex: 1;
}
</style>
<body>
<noscript>
You need to enable JavaScript to run this app.
</noscript>
<div id="root"></div>
<script src="index.bundle?platform=web&dev=true&hot=false&lazy=true&transform.engine=hermes&transform.routerRoot=app&unstable_transformProfile=hermes-stable" defer></script>
</body>
</html>
21 changes: 21 additions & 0 deletions apps/example/index.web.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { AppRegistry } from "react-native";
import App from "./src/App";
import { name as appName } from "./app.json";

AppRegistry.registerComponent(appName, () => App);

const rootTag = document.getElementById("root");
if (process.env.NODE_ENV !== "production") {
if (!rootTag) {
throw new Error(
'Required HTML element with id "root" was not found in the document HTML.',
);
}
}

CanvasKitInit({
locateFile: (file) => `https://unpkg.com/canvaskit-wasm/bin/full/${file}`,
}).then((CanvasKit) => {
window.CanvasKit = global.CanvasKit = CanvasKit;
AppRegistry.runApplication(appName, { rootTag });
});
16 changes: 12 additions & 4 deletions apps/example/metro.config.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const { makeMetroConfig } = require("@rnx-kit/metro-config");
const path = require('path');
const getWebMetroConfig = require('./getWebMetroConfig');

const root = path.resolve(__dirname, '../..');
const threePackagePath = path.resolve(root, 'node_modules/three');
Expand All @@ -17,18 +18,26 @@ const extraConfig = {
type: 'sourceFile',
};
}
if (moduleName === 'three' || moduleName === 'three/webgpu') {
if (moduleName === 'three' || moduleName === 'three/webgpu') {
return {
filePath: path.resolve(threePackagePath, 'build/three.webgpu.js'),
type: 'sourceFile',
};
}
if (moduleName === 'three/tsl') {
if (moduleName === 'three/tsl') {
return {
filePath: path.resolve(threePackagePath, 'build/three.tsl.js'),
type: 'sourceFile',
};
}

if (moduleName === "@react-three/fiber") {
//Just use the vanilla web build of react three fiber, not the stale "native" code path which has not been kept up to date.
return {
filePath: path.resolve(r3fPath, "dist/react-three-fiber.esm.js"),
type: "sourceFile",
};
}
// Let Metro handle other modules
return context.resolveRequest(context, moduleName, platform);
},
Expand All @@ -47,5 +56,4 @@ const extraConfig = {
const metroConfig = makeMetroConfig(extraConfig);
metroConfig.resolver.assetExts.push('glb', 'gltf', 'jpg', 'bin', 'hdr');


module.exports = metroConfig;
module.exports = !!process.env.IS_WEB_BUILD ? getWebMetroConfig(metroConfig) : metroConfig;
8 changes: 6 additions & 2 deletions apps/example/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"tsc": "tsc --noEmit",
"android": "react-native run-android",
"ios": "react-native run-ios",
"web": "IS_WEB_BUILD=true react-native start",
"start": "react-native start",
"pod:install:ios": "pod install --project-directory=ios",
"pod:install:macos": "pod install --project-directory=macos",
Expand All @@ -19,18 +20,21 @@
"@callstack/react-native-visionos": "^0.74.0",
"@react-navigation/native": "^6.1.17",
"@react-navigation/stack": "^6.4.0",
"@react-three/fiber": "^8.17.6",
"@react-three/fiber": "^9.1.2",
"@shopify/react-native-skia": "2.0.0",
"@tensorflow/tfjs": "^4.22.0",
"@tensorflow/tfjs-backend-webgpu": "^4.22.0",
"@tensorflow/tfjs-vis": "^1.5.1",
"@types/react-dom": "^19.1.5",
"fast-text-encoding": "^1.0.6",
"react": "19.0.0",
"react-dom": "19.0.0",
"react-native": "0.79.2",
"react-native-gesture-handler": "^2.17.1",
"react-native-macos": "^0.78.3",
"react-native-reanimated": "^3.12.1",
"react-native-safe-area-context": "^5.4.0",
"react-native-safe-area-context": "^5.4.1",
"react-native-web": "^0.20.0",
"react-native-wgpu": "*",
"teapot": "^1.0.0",
"three": "0.172.0",
Expand Down
5 changes: 4 additions & 1 deletion apps/example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ function App() {
return (
<GestureHandlerRootView style={{ flex: 1 }}>
<NavigationContainer>
<Stack.Navigator initialRouteName="Home">
<Stack.Navigator
initialRouteName="Home"
screenOptions={{ cardStyle: { flex: 1 } }}
>
<Stack.Screen name="Home" component={Home} />
<Stack.Screen name="HelloTriangle" component={HelloTriangle} />
<Stack.Screen
Expand Down
6 changes: 4 additions & 2 deletions apps/example/src/Home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ export const examples = [
] as const;

const styles = StyleSheet.create({
container: {},
container: {
flex: 1,
},
content: {
paddingBottom: 32,
marginBottom: 32,
},
thumbnail: {
backgroundColor: "white",
Expand Down
105 changes: 65 additions & 40 deletions apps/example/src/MNISTInference/MNISTInference.tsx
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import React, { useCallback, useEffect, useMemo, useRef } from "react";
import { Button, Dimensions, Platform, StyleSheet, View } from "react-native";
import type { SkImage, SkSurface } from "@shopify/react-native-skia";
import { Button, Dimensions, StyleSheet, View } from "react-native";
import type {
SkImage,
SkPaint,
SkPath,
SkSurface,
} from "@shopify/react-native-skia";
import {
useFont,
Canvas,
Fill,
Skia,
PaintStyle,
Path,
ColorType,
AlphaType,
matchFont,
Text,
Image,
FilterMode,
Expand All @@ -21,40 +26,54 @@ import { useDevice } from "react-native-wgpu";
import type { Network } from "./Lib";
import { createDemo, centerData, SIZE } from "./Lib";

const { width } = Dimensions.get("window");
export function MNISTInference() {
const skiaConstants = useRef<null | {
paint: SkPaint;
grid: SkPath;
width: number;
f: number;
}>(null);

const fontFamily = Platform.select({ ios: "Helvetica", default: "serif" });
const fontStyle = {
fontFamily,
fontSize: 200,
};
const font = matchFont(fontStyle);
const font = useFont(require("../assets/helvetica.ttf"));

const paint = Skia.Paint();
paint.setColor(Skia.Color("black"));
paint.setStyle(PaintStyle.Stroke);
paint.setStrokeWidth(1);
// Lazy initialize skia derived constants
if (!skiaConstants.current) {
const { width } = Dimensions.get("window");

const grid = Skia.Path.Make();
const cellSize = width / SIZE;
const paint = Skia.Paint();
paint.setColor(Skia.Color("black"));
paint.setStyle(PaintStyle.Stroke);
paint.setStrokeWidth(1);

grid.moveTo(0, 0);
const grid = Skia.Path.Make();
const cellSize = width / SIZE;

// Draw vertical lines
for (let i = 0; i <= SIZE; i++) {
grid.moveTo(i * cellSize, 0);
grid.lineTo(i * cellSize, width);
}
grid.moveTo(0, 0);

// Draw horizontal lines
for (let i = 0; i <= SIZE; i++) {
grid.moveTo(0, i * cellSize);
grid.lineTo(width, i * cellSize);
}
// Draw vertical lines
for (let i = 0; i <= SIZE; i++) {
grid.moveTo(i * cellSize, 0);
grid.lineTo(i * cellSize, width);
}

const f = 1 / cellSize;
// Draw horizontal lines
for (let i = 0; i <= SIZE; i++) {
grid.moveTo(0, i * cellSize);
grid.lineTo(width, i * cellSize);
}

const f = 1 / cellSize;

skiaConstants.current = {
f,
paint,
grid,
width,
};
}

const { f, paint, grid, width } = skiaConstants.current;

export function MNISTInference() {
const { device } = useDevice();
const network = useRef<Network>();
const text = useSharedValue("");
Expand Down Expand Up @@ -84,21 +103,26 @@ export function MNISTInference() {
if (surface.value) {
const canvas = surface.value.getCanvas();
canvas.drawPath(path.value, paint);
surface.value.flush();
image.value = surface.value!.makeImageSnapshot();
const pixels = image.value.readPixels(0, 0, {
width: SIZE,
height: SIZE,
alphaType: AlphaType.Opaque,
colorType: ColorType.Alpha_8,
colorType: ColorType.RGBA_8888,
alphaType: AlphaType.Unpremul,
});

const gray = new Uint8Array(SIZE * SIZE);
for (let i = 0; i < SIZE * SIZE; i++) {
gray[i] = pixels![i * 4];
}

runOnJS(runInference)(
centerData(pixels as Uint8Array).map(
(x) => (x / 255) * 3.24 - 0.42,
),
centerData(gray).map((x) => (x / 255) * 3.24 - 0.42),
);
}
});
}, [path, runInference, surface, image]);
}, [path, runInference, surface, image, f, paint]);

useEffect(() => {
(async () => {
Expand All @@ -111,6 +135,11 @@ export function MNISTInference() {
})();
})();
}, [device, network, surface]);

if (!font) {
return null;
}

return (
<View style={style.container}>
<Button
Expand All @@ -123,7 +152,7 @@ export function MNISTInference() {
title="Reset"
/>
<GestureDetector gesture={gesture}>
<Canvas style={style.canvas}>
<Canvas style={{ width, height: width * 2 }}>
<Fill color="rgb(239, 239, 248)" />
<Path
path={grid}
Expand Down Expand Up @@ -151,8 +180,4 @@ const style = StyleSheet.create({
container: {
flex: 1,
},
canvas: {
width,
height: width * 2,
},
});
Binary file added apps/example/src/assets/helvetica.ttf
Binary file not shown.
Loading