Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
NonLocalBinding,
PolyType,
ScopeId,
SourceLocation,
Type,
ValidatedIdentifier,
ValueKind,
Expand Down Expand Up @@ -126,11 +127,6 @@ const HookSchema = z.object({

export type Hook = z.infer<typeof HookSchema>;

export const ModuleTypeResolver = z
.function()
.args(z.string())
.returns(z.nullable(TypeSchema));

Comment on lines -129 to -133
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setting .returns() causes zod to return a function that internally validates the return type when you call it. that's cool but we want to control when the validation happens so i'm switching this to a plain function annotation

/*
* TODO(mofeiZ): User defined global types (with corresponding shapes).
* User defined global types should have inline ObjectShapes instead of directly
Expand All @@ -148,7 +144,7 @@ const EnvironmentConfigSchema = z.object({
* A function that, given the name of a module, can optionally return a description
* of that module's type signature.
*/
resolveModuleTypeSchema: z.nullable(ModuleTypeResolver).default(null),
moduleTypeProvider: z.nullable(z.function().args(z.string())).default(null),

/**
* A list of functions which the application compiles as macros, where
Expand Down Expand Up @@ -712,19 +708,27 @@ export class Environment {
return this.#outlinedFunctions;
}

#resolveModuleType(moduleName: string): Global | null {
if (this.config.resolveModuleTypeSchema == null) {
#resolveModuleType(moduleName: string, loc: SourceLocation): Global | null {
if (this.config.moduleTypeProvider == null) {
return null;
}
let moduleType = this.#moduleTypes.get(moduleName);
if (moduleType === undefined) {
const moduleConfig = this.config.resolveModuleTypeSchema(moduleName);
if (moduleConfig != null) {
const moduleTypes = TypeSchema.parse(moduleConfig);
const unparsedModuleConfig = this.config.moduleTypeProvider(moduleName);
if (unparsedModuleConfig != null) {
const parsedModuleConfig = TypeSchema.safeParse(unparsedModuleConfig);
if (!parsedModuleConfig.success) {
CompilerError.throwInvalidConfig({
reason: `Could not parse module type, the configured \`moduleTypeProvider\` function returned an invalid module description`,
description: parsedModuleConfig.error.toString(),
loc,
});
}
const moduleConfig = parsedModuleConfig.data;
moduleType = installTypeConfig(
this.#globals,
this.#shapes,
moduleTypes,
moduleConfig,
);
} else {
moduleType = null;
Expand All @@ -734,7 +738,10 @@ export class Environment {
return moduleType;
}

getGlobalDeclaration(binding: NonLocalBinding): Global | null {
getGlobalDeclaration(
binding: NonLocalBinding,
loc: SourceLocation,
): Global | null {
if (this.config.hookPattern != null) {
const match = new RegExp(this.config.hookPattern).exec(binding.name);
if (
Expand Down Expand Up @@ -772,7 +779,7 @@ export class Environment {
(isHookName(binding.imported) ? this.#getCustomHookType() : null)
);
} else {
const moduleType = this.#resolveModuleType(binding.module);
const moduleType = this.#resolveModuleType(binding.module, loc);
if (moduleType !== null) {
const importedType = this.getPropertyType(
moduleType,
Expand Down Expand Up @@ -805,10 +812,16 @@ export class Environment {
(isHookName(binding.name) ? this.#getCustomHookType() : null)
);
} else {
const moduleType = this.#resolveModuleType(binding.module);
const moduleType = this.#resolveModuleType(binding.module, loc);
if (moduleType !== null) {
// TODO: distinguish default/namespace cases
return moduleType;
if (binding.kind === 'ImportDefault') {
const defaultType = this.getPropertyType(moduleType, 'default');
if (defaultType !== null) {
return defaultType;
}
} else {
return moduleType;
}
}
return isHookName(binding.name) ? this.#getCustomHookType() : null;
}
Expand All @@ -819,9 +832,7 @@ export class Environment {
#isKnownReactModule(moduleName: string): boolean {
return (
moduleName.toLowerCase() === 'react' ||
moduleName.toLowerCase() === 'react-dom' ||
(this.config.enableSharedRuntime__testonly &&
moduleName === 'shared-runtime')
moduleName.toLowerCase() === 'react-dom'
);
}

Expand Down
23 changes: 23 additions & 0 deletions compiler/packages/babel-plugin-react-compiler/src/HIR/Globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,9 @@ export function installTypeConfig(
case 'Ref': {
return {kind: 'Object', shapeId: BuiltInUseRefId};
}
case 'Any': {
return {kind: 'Poly'};
}
default: {
assertExhaustive(
typeConfig.name,
Expand All @@ -566,6 +569,20 @@ export function installTypeConfig(
calleeEffect: typeConfig.calleeEffect,
returnType: installTypeConfig(globals, shapes, typeConfig.returnType),
returnValueKind: typeConfig.returnValueKind,
noAlias: typeConfig.noAlias === true,
mutableOnlyIfOperandsAreMutable:
typeConfig.mutableOnlyIfOperandsAreMutable === true,
});
}
case 'hook': {
return addHook(shapes, {
hookKind: 'Custom',
positionalParams: typeConfig.positionalParams ?? [],
restParam: typeConfig.restParam ?? Effect.Freeze,
calleeEffect: Effect.Read,
returnType: installTypeConfig(globals, shapes, typeConfig.returnType),
returnValueKind: typeConfig.returnValueKind ?? ValueKind.Frozen,
noAlias: typeConfig.noAlias === true,
});
}
case 'object': {
Expand All @@ -578,6 +595,12 @@ export function installTypeConfig(
]),
);
}
default: {
assertExhaustive(
typeConfig,
`Unexpected type kind '${(typeConfig as any).kind}'`,
);
}
}
}

Expand Down
9 changes: 9 additions & 0 deletions compiler/packages/babel-plugin-react-compiler/src/HIR/HIR.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,15 @@ export enum ValueKind {
Context = 'context',
}

export const ValueKindSchema = z.enum([
ValueKind.MaybeFrozen,
ValueKind.Frozen,
ValueKind.Primitive,
ValueKind.Global,
ValueKind.Mutable,
ValueKind.Context,
]);

// The effect with which a value is modified.
export enum Effect {
// Default value: not allowed after lifetime inference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import {isValidIdentifier} from '@babel/types';
import {z} from 'zod';
import {Effect, ValueKind} from '..';
import {EffectSchema} from './HIR';
import {EffectSchema, ValueKindSchema} from './HIR';

export type ObjectPropertiesConfig = {[key: string]: TypeConfig};
export const ObjectPropertiesSchema: z.ZodType<ObjectPropertiesConfig> = z
Expand All @@ -18,9 +18,9 @@ export const ObjectPropertiesSchema: z.ZodType<ObjectPropertiesConfig> = z
)
.refine(record => {
return Object.keys(record).every(
key => key === '*' || isValidIdentifier(key),
key => key === '*' || key === 'default' || isValidIdentifier(key),
);
}, 'Expected all "object" property names to be valid identifiers or `*` to match any property');
}, 'Expected all "object" property names to be valid identifier, `*` to match any property, of `default` to define a module default export');

export type ObjectTypeConfig = {
kind: 'object';
Expand All @@ -38,18 +38,45 @@ export type FunctionTypeConfig = {
calleeEffect: Effect;
returnType: TypeConfig;
returnValueKind: ValueKind;
noAlias?: boolean | null | undefined;
mutableOnlyIfOperandsAreMutable?: boolean | null | undefined;
};
export const FunctionTypeSchema: z.ZodType<FunctionTypeConfig> = z.object({
kind: z.literal('function'),
positionalParams: z.array(EffectSchema),
restParam: EffectSchema.nullable(),
calleeEffect: EffectSchema,
returnType: z.lazy(() => TypeSchema),
returnValueKind: z.nativeEnum(ValueKind),
returnValueKind: ValueKindSchema,
noAlias: z.boolean().nullable().optional(),
mutableOnlyIfOperandsAreMutable: z.boolean().nullable().optional(),
});

export type BuiltInTypeConfig = 'Ref' | 'Array' | 'Primitive' | 'MixedReadonly';
export type HookTypeConfig = {
kind: 'hook';
positionalParams?: Array<Effect> | null | undefined;
restParam?: Effect | null | undefined;
returnType: TypeConfig;
returnValueKind?: ValueKind | null | undefined;
noAlias?: boolean | null | undefined;
};
export const HookTypeSchema: z.ZodType<HookTypeConfig> = z.object({
kind: z.literal('hook'),
positionalParams: z.array(EffectSchema).nullable().optional(),
restParam: EffectSchema.nullable().optional(),
returnType: z.lazy(() => TypeSchema),
returnValueKind: ValueKindSchema.nullable().optional(),
noAlias: z.boolean().nullable().optional(),
});

export type BuiltInTypeConfig =
| 'Any'
| 'Ref'
| 'Array'
| 'Primitive'
| 'MixedReadonly';
export const BuiltInTypeSchema: z.ZodType<BuiltInTypeConfig> = z.union([
z.literal('Any'),
z.literal('Ref'),
z.literal('Array'),
z.literal('Primitive'),
Expand All @@ -68,9 +95,11 @@ export const TypeReferenceSchema: z.ZodType<TypeReferenceConfig> = z.object({
export type TypeConfig =
| ObjectTypeConfig
| FunctionTypeConfig
| HookTypeConfig
| TypeReferenceConfig;
export const TypeSchema: z.ZodType<TypeConfig> = z.union([
ObjectTypeSchema,
FunctionTypeSchema,
HookTypeSchema,
TypeReferenceSchema,
]);
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function collectTemporaries(
break;
}
case 'LoadGlobal': {
const global = env.getGlobalDeclaration(value.binding);
const global = env.getGlobalDeclaration(value.binding, value.loc);
const hookKind = global !== null ? getHookKindForType(env, global) : null;
const lvalId = instr.lvalue.identifier.id;
if (hookKind === 'useMemo' || hookKind === 'useCallback') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ function* generateInstructionTypes(
}

case 'LoadGlobal': {
const globalType = env.getGlobalDeclaration(value.binding);
const globalType = env.getGlobalDeclaration(value.binding, value.loc);
if (globalType) {
yield equation(left, globalType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
## Input

```javascript
import {useFragment} from 'shared-runtime';

function Component(props) {
const post = useFragment(
graphql`
Expand Down Expand Up @@ -36,6 +38,8 @@ function Component(props) {

```javascript
import { c as _c } from "react/compiler-runtime";
import { useFragment } from "shared-runtime";

function Component(props) {
const $ = _c(4);
const post = useFragment(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import {useFragment} from 'shared-runtime';

function Component(props) {
const post = useFragment(
graphql`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
## Input

```javascript
import {useFragment} from 'shared-runtime';

function Component(props) {
const item = useFragment(
graphql`
Expand All @@ -20,6 +22,8 @@ function Component(props) {

```javascript
import { c as _c } from "react/compiler-runtime";
import { useFragment } from "shared-runtime";

function Component(props) {
const $ = _c(2);
const item = useFragment(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import {useFragment} from 'shared-runtime';

function Component(props) {
const item = useFragment(
graphql`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
## Input

```javascript
import {useFragment} from 'shared-runtime';

function Component(props) {
const x = makeObject();
const user = useFragment(
Expand All @@ -28,6 +30,8 @@ function Component(props) {

```javascript
import { c as _c } from "react/compiler-runtime";
import { useFragment } from "shared-runtime";

function Component(props) {
const $ = _c(3);
const x = makeObject();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import {useFragment} from 'shared-runtime';

function Component(props) {
const x = makeObject();
const user = useFragment(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
## Input

```javascript
import {useFragment} from 'shared-runtime';

function Component(props) {
const user = useFragment(
graphql`
Expand All @@ -26,6 +28,8 @@ function Component(props) {

```javascript
import { c as _c } from "react/compiler-runtime";
import { useFragment } from "shared-runtime";

function Component(props) {
const $ = _c(5);
const user = useFragment(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import {useFragment} from 'shared-runtime';

function Component(props) {
const user = useFragment(
graphql`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
## Input

```javascript
import {useFragment} from 'shared-runtime';

function Component(props) {
const user = useFragment(
graphql`
Expand All @@ -19,6 +21,8 @@ function Component(props) {
## Code

```javascript
import { useFragment } from "shared-runtime";

function Component(props) {
const user = useFragment(
graphql`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import {useFragment} from 'shared-runtime';

function Component(props) {
const user = useFragment(
graphql`
Expand Down
Loading