diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index b4d7e6d82..1705f019b 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -98,6 +98,14 @@ export function isTypeDeclaration(item: unknown): item is TypeDeclaration { return reflection.isInstance(item, TypeDeclaration); } +export type TypeDefFieldTypes = Enum | TypeDef; + +export const TypeDefFieldTypes = 'TypeDefFieldTypes'; + +export function isTypeDefFieldTypes(item: unknown): item is TypeDefFieldTypes { + return reflection.isInstance(item, TypeDefFieldTypes); +} + export interface Argument extends AstNode { readonly $container: InvocationExpr; readonly $type: 'Argument'; @@ -654,7 +662,7 @@ export interface TypeDefFieldType extends AstNode { readonly $type: 'TypeDefFieldType'; array: boolean optional: boolean - reference?: Reference + reference?: Reference type?: BuiltinType } @@ -738,6 +746,7 @@ export type ZModelAstType = { TypeDef: TypeDef TypeDefField: TypeDefField TypeDefFieldType: TypeDefFieldType + TypeDefFieldTypes: TypeDefFieldTypes UnaryExpr: UnaryExpr UnsupportedFieldType: UnsupportedFieldType } @@ -745,7 +754,7 @@ export type ZModelAstType = { export class ZModelAstReflection extends AbstractAstReflection { getAllTypes(): string[] { - return ['AbstractDeclaration', 'Argument', 'ArrayExpr', 'Attribute', 'AttributeArg', 'AttributeParam', 'AttributeParamType', 'BinaryExpr', 'BooleanLiteral', 'ConfigArrayExpr', 'ConfigExpr', 'ConfigField', 'ConfigInvocationArg', 'ConfigInvocationExpr', 'DataModel', 'DataModelAttribute', 'DataModelField', 'DataModelFieldAttribute', 'DataModelFieldType', 'DataSource', 'Enum', 'EnumField', 'Expression', 'FieldInitializer', 'FunctionDecl', 'FunctionParam', 'FunctionParamType', 'GeneratorDecl', 'InternalAttribute', 'InvocationExpr', 'LiteralExpr', 'MemberAccessExpr', 'Model', 'ModelImport', 'NullExpr', 'NumberLiteral', 'ObjectExpr', 'Plugin', 'PluginField', 'ReferenceArg', 'ReferenceExpr', 'ReferenceTarget', 'StringLiteral', 'ThisExpr', 'TypeDeclaration', 'TypeDef', 'TypeDefField', 'TypeDefFieldType', 'UnaryExpr', 'UnsupportedFieldType']; + return ['AbstractDeclaration', 'Argument', 'ArrayExpr', 'Attribute', 'AttributeArg', 'AttributeParam', 'AttributeParamType', 'BinaryExpr', 'BooleanLiteral', 'ConfigArrayExpr', 'ConfigExpr', 'ConfigField', 'ConfigInvocationArg', 'ConfigInvocationExpr', 'DataModel', 'DataModelAttribute', 'DataModelField', 'DataModelFieldAttribute', 'DataModelFieldType', 'DataSource', 'Enum', 'EnumField', 'Expression', 'FieldInitializer', 'FunctionDecl', 'FunctionParam', 'FunctionParamType', 'GeneratorDecl', 'InternalAttribute', 'InvocationExpr', 'LiteralExpr', 'MemberAccessExpr', 'Model', 'ModelImport', 'NullExpr', 'NumberLiteral', 'ObjectExpr', 'Plugin', 'PluginField', 'ReferenceArg', 'ReferenceExpr', 'ReferenceTarget', 'StringLiteral', 'ThisExpr', 'TypeDeclaration', 'TypeDef', 'TypeDefField', 'TypeDefFieldType', 'TypeDefFieldTypes', 'UnaryExpr', 'UnsupportedFieldType']; } protected override computeIsSubtype(subtype: string, supertype: string): boolean { @@ -775,9 +784,7 @@ export class ZModelAstReflection extends AbstractAstReflection { case ConfigArrayExpr: { return this.isSubtype(ConfigExpr, supertype); } - case DataModel: - case Enum: - case TypeDef: { + case DataModel: { return this.isSubtype(AbstractDeclaration, supertype) || this.isSubtype(TypeDeclaration, supertype); } case DataModelField: @@ -785,6 +792,10 @@ export class ZModelAstReflection extends AbstractAstReflection { case FunctionParam: { return this.isSubtype(ReferenceTarget, supertype); } + case Enum: + case TypeDef: { + return this.isSubtype(AbstractDeclaration, supertype) || this.isSubtype(TypeDeclaration, supertype) || this.isSubtype(TypeDefFieldTypes, supertype); + } case InvocationExpr: case LiteralExpr: { return this.isSubtype(ConfigExpr, supertype) || this.isSubtype(Expression, supertype); @@ -821,7 +832,7 @@ export class ZModelAstReflection extends AbstractAstReflection { return ReferenceTarget; } case 'TypeDefFieldType:reference': { - return TypeDef; + return TypeDefFieldTypes; } default: { throw new Error(`${referenceId} is not a valid reference id.`); diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index 01c492bd5..43beb12a2 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -2165,7 +2165,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/types@1" + "$ref": "#/types@2" }, "terminal": { "$type": "RuleCall", @@ -2267,7 +2267,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel }, "arguments": [] }, - "cardinality": "+" + "cardinality": "*" }, { "$type": "Keyword", @@ -2375,7 +2375,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@40" + "$ref": "#/types@1" }, "terminal": { "$type": "RuleCall", @@ -2827,7 +2827,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/types@1" + "$ref": "#/types@2" }, "terminal": { "$type": "RuleCall", @@ -3255,7 +3255,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/types@1" + "$ref": "#/types@2" }, "terminal": { "$type": "RuleCall", @@ -3838,6 +3838,27 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel ] } }, + { + "$type": "Type", + "name": "TypeDefFieldTypes", + "type": { + "$type": "UnionType", + "types": [ + { + "$type": "SimpleType", + "typeRef": { + "$ref": "#/rules@40" + } + }, + { + "$type": "SimpleType", + "typeRef": { + "$ref": "#/rules@44" + } + } + ] + } + }, { "$type": "Type", "name": "TypeDeclaration", diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index d66ea5f32..ef5a1f883 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -183,15 +183,17 @@ TypeDef: (comments+=TRIPLE_SLASH_COMMENT)* 'type' name=RegularID '{' ( fields+=TypeDefField - )+ + )* '}'; +type TypeDefFieldTypes = TypeDef | Enum; + TypeDefField: - (comments+=TRIPLE_SLASH_COMMENT)* + (comments+=TRIPLE_SLASH_COMMENT)* name=RegularIDWithTypeNames type=TypeDefFieldType (attributes+=DataModelFieldAttribute)*; TypeDefFieldType: - (type=BuiltinType | reference=[TypeDef:RegularID]) (array?='[' ']')? (optional?='?')?; + (type=BuiltinType | reference=[TypeDefFieldTypes:RegularID]) (array?='[' ']')? (optional?='?')?; UnsupportedFieldType: 'Unsupported' '(' (value=LiteralExpr) ')'; diff --git a/packages/schema/src/plugins/enhancer/enhance/model-typedef-generator.ts b/packages/schema/src/plugins/enhancer/enhance/model-typedef-generator.ts index 40bc3e5b4..21ede7a64 100644 --- a/packages/schema/src/plugins/enhancer/enhance/model-typedef-generator.ts +++ b/packages/schema/src/plugins/enhancer/enhance/model-typedef-generator.ts @@ -1,5 +1,5 @@ -import { PluginError } from '@zenstackhq/sdk'; -import { BuiltinType, TypeDef, TypeDefFieldType } from '@zenstackhq/sdk/ast'; +import { getDataModels, PluginError } from '@zenstackhq/sdk'; +import { BuiltinType, Enum, isEnum, TypeDef, TypeDefFieldType } from '@zenstackhq/sdk/ast'; import { SourceFile } from 'ts-morph'; import { match } from 'ts-pattern'; import { name } from '..'; @@ -36,7 +36,11 @@ function zmodelTypeToTsType(type: TypeDefFieldType) { if (type.type) { result = builtinTypeToTsType(type.type); } else if (type.reference?.ref) { - result = type.reference.ref.name; + if (isEnum(type.reference.ref)) { + result = makeEnumTypeReference(type.reference.ref); + } else { + result = type.reference.ref.name; + } } else { throw new PluginError(name, `Unsupported field type: ${type}`); } @@ -61,3 +65,17 @@ function builtinTypeToTsType(type: BuiltinType) { .with('Json', () => 'unknown') .exhaustive(); } + +function makeEnumTypeReference(enumDecl: Enum) { + const zmodel = enumDecl.$container; + const models = getDataModels(zmodel); + + if (models.some((model) => model.fields.some((field) => field.type.reference?.ref === enumDecl))) { + // if the enum is referenced by any data model, Prisma already generates its type, + // we just need to reference it + return enumDecl.name; + } else { + // otherwise, we need to inline the enum + return enumDecl.fields.map((field) => `'${field.name}'`).join(' | '); + } +} diff --git a/packages/schema/src/plugins/zod/transformer.ts b/packages/schema/src/plugins/zod/transformer.ts index 6b83e5723..081dcaf4a 100644 --- a/packages/schema/src/plugins/zod/transformer.ts +++ b/packages/schema/src/plugins/zod/transformer.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ import { indentString, isDiscriminatorField, type PluginOptions } from '@zenstackhq/sdk'; -import { DataModel, isDataModel, isTypeDef, type Model } from '@zenstackhq/sdk/ast'; +import { DataModel, Enum, isDataModel, isEnum, isTypeDef, type Model } from '@zenstackhq/sdk/ast'; import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers'; import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma'; import path from 'path'; @@ -53,6 +53,9 @@ export default class Transformer { } async generateEnumSchemas() { + const generated: string[] = []; + + // generate for enums in DMMF for (const enumType of this.enumTypes) { const name = upperCaseFirst(enumType.name); const filePath = path.join(Transformer.outputPath, `enums/${name}.schema.ts`); @@ -61,14 +64,26 @@ export default class Transformer { `z.enum(${JSON.stringify(enumType.values)})` )}`; this.sourceFiles.push(this.project.createSourceFile(filePath, content, { overwrite: true })); + generated.push(enumType.name); + } + + // enums not referenced by data models are not in DMMF, deal with them separately + const extraEnums = this.zmodel.declarations.filter((d): d is Enum => isEnum(d) && !generated.includes(d.name)); + for (const enumDecl of extraEnums) { + const name = upperCaseFirst(enumDecl.name); + const filePath = path.join(Transformer.outputPath, `enums/${name}.schema.ts`); + const content = `/* eslint-disable */\n${this.generateImportZodStatement()}\n${this.generateExportSchemaStatement( + `${name}`, + `z.enum(${JSON.stringify(enumDecl.fields.map((f) => f.name))})` + )}`; + this.sourceFiles.push(this.project.createSourceFile(filePath, content, { overwrite: true })); + generated.push(enumDecl.name); } this.sourceFiles.push( this.project.createSourceFile( path.join(Transformer.outputPath, `enums/index.ts`), - this.enumTypes - .map((enumType) => `export * from './${upperCaseFirst(enumType.name)}.schema';`) - .join('\n'), + generated.map((name) => `export * from './${upperCaseFirst(name)}.schema';`).join('\n'), { overwrite: true } ) ); diff --git a/tests/integration/tests/enhancements/json/crud.test.ts b/tests/integration/tests/enhancements/json/crud.test.ts index af3705a95..9cd7ff8a4 100644 --- a/tests/integration/tests/enhancements/json/crud.test.ts +++ b/tests/integration/tests/enhancements/json/crud.test.ts @@ -191,6 +191,83 @@ describe('Json field CRUD', () => { ).toResolveTruthy(); }); + it('respects enums used by data models', async () => { + const params = await loadSchema( + ` + enum Role { + USER + ADMIN + } + + type Profile { + role Role + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', true) + } + + model Foo { + id Int @id @default(autoincrement()) + role Role + } + `, + { + provider: 'postgresql', + dbUrl, + } + ); + + prisma = params.prisma; + const db = params.enhance(); + + await expect(db.user.create({ data: { profile: { role: 'MANAGER' } } })).toBeRejectedByPolicy(); + await expect(db.user.create({ data: { profile: { role: 'ADMIN' } } })).resolves.toMatchObject({ + profile: { role: 'ADMIN' }, + }); + await expect(db.user.findFirst()).resolves.toMatchObject({ + profile: { role: 'ADMIN' }, + }); + }); + + it('respects enums unused by data models', async () => { + const params = await loadSchema( + ` + enum Role { + USER + ADMIN + } + + type Profile { + role Role + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', true) + } + `, + { + provider: 'postgresql', + dbUrl, + } + ); + + prisma = params.prisma; + const db = params.enhance(); + + await expect(db.user.create({ data: { profile: { role: 'MANAGER' } } })).toBeRejectedByPolicy(); + await expect(db.user.create({ data: { profile: { role: 'ADMIN' } } })).resolves.toMatchObject({ + profile: { role: 'ADMIN' }, + }); + await expect(db.user.findFirst()).resolves.toMatchObject({ + profile: { role: 'ADMIN' }, + }); + }); + it('respects @default', async () => { const params = await loadSchema( ` diff --git a/tests/integration/tests/enhancements/json/typing.test.ts b/tests/integration/tests/enhancements/json/typing.test.ts index 9681bf015..a73e04f03 100644 --- a/tests/integration/tests/enhancements/json/typing.test.ts +++ b/tests/integration/tests/enhancements/json/typing.test.ts @@ -179,6 +179,108 @@ async function main() { ); }); + it('works with enums used in models', async () => { + await loadSchema( + ` + enum Role { + USER + ADMIN + } + + type Profile { + role Role + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', true) + } + + model Foo { + id Int @id @default(autoincrement()) + role Role + } + `, + { + provider: 'postgresql', + pushDb: false, + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` +import type { Profile } from '.zenstack/models'; +import { enhance } from '.zenstack/enhance'; +import { PrismaClient } from '@prisma/client'; +import { Role } from '@prisma/client'; +const prisma = new PrismaClient(); +const db = enhance(prisma); + +async function main() { + const profile: Profile = { + role: Role.ADMIN, + } + + await db.user.create({ data: { profile: { role: Role.ADMIN } } }); + const user = await db.user.findFirstOrThrow(); + console.log(user.profile.role === Role.ADMIN); +} +`, + }, + ], + } + ); + }); + + it('works with enums unused in models', async () => { + await loadSchema( + ` + enum Role { + USER + ADMIN + } + + type Profile { + role Role + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', true) + } + `, + { + provider: 'postgresql', + pushDb: false, + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` +import type { Profile } from '.zenstack/models'; +import { enhance } from '.zenstack/enhance'; +import { PrismaClient } from '@prisma/client'; +const prisma = new PrismaClient(); +const db = enhance(prisma); + +async function main() { + const profile: Profile = { + role: 'ADMIN', + } + + await db.user.create({ data: { profile: { role: 'ADMIN' } } }); + const user = await db.user.findFirstOrThrow(); + console.log(user.profile.role === 'ADMIN'); +} +`, + }, + ], + } + ); + }); + it('type coverage', async () => { await loadSchema( `