diff --git a/packages/schema/src/plugins/enhancer/delegate/index.ts b/packages/schema/src/plugins/enhancer/delegate/index.ts index 5e4cffdfa..d3f85576d 100644 --- a/packages/schema/src/plugins/enhancer/delegate/index.ts +++ b/packages/schema/src/plugins/enhancer/delegate/index.ts @@ -5,8 +5,8 @@ import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; import path from 'path'; export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { - const prismaGenerator = new PrismaSchemaGenerator(); - await prismaGenerator.generate(model, { + const prismaGenerator = new PrismaSchemaGenerator(model); + await prismaGenerator.generate({ provider: '@internal', schemaPath: options.schemaPath, output: path.join(outDir, 'delegate.prisma'), diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 1d42b5912..06caf6950 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -6,19 +6,24 @@ import { isDelegateModel, type PluginOptions, } from '@zenstackhq/sdk'; -import { DataModelField, isDataModel, isReferenceExpr, type DataModel, type Model } from '@zenstackhq/sdk/ast'; +import { DataModel, DataModelField, isDataModel, isReferenceExpr, type Model } from '@zenstackhq/sdk/ast'; import path from 'path'; import { - ForEachDescendantTraversalControl, - MethodSignature, + FunctionDeclarationStructure, + InterfaceDeclaration, + ModuleDeclaration, Node, Project, - PropertySignature, + SourceFile, SyntaxKind, TypeAliasDeclaration, + VariableStatement, } from 'ts-morph'; import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; +// information of delegate models and their sub models +type DelegateInfo = [DataModel, DataModel[]][]; + export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { const outFile = path.join(outDir, 'enhance.ts'); let logicalPrismaClientDir: string | undefined; @@ -34,7 +39,11 @@ import modelMeta from './model-meta'; import policy from './policy'; ${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} import { Prisma } from '${getPrismaClientImportSpec(model, outDir)}'; -${logicalPrismaClientDir ? `import { PrismaClient as EnhancedPrismaClient } from '${logicalPrismaClientDir}';` : ''} +${ + logicalPrismaClientDir + ? `import type { PrismaClient as EnhancedPrismaClient } from '${logicalPrismaClientDir}/index-fixed';` + : '' +} export function enhance(prisma: DbClient, context?: EnhancementContext, options?: EnhancementOptions) { return createEnhancement(prisma, { @@ -58,29 +67,33 @@ function hasDelegateModel(model: Model) { } async function generateLogicalPrisma(model: Model, options: PluginOptions, outDir: string) { - const prismaGenerator = new PrismaSchemaGenerator(); + const prismaGenerator = new PrismaSchemaGenerator(model); const prismaClientOutDir = './.delegate'; - await prismaGenerator.generate(model, { - provider: '@internal', + await prismaGenerator.generate({ + provider: '@internal', // doesn't matter schemaPath: options.schemaPath, output: path.join(outDir, 'delegate.prisma'), overrideClientGenerationPath: prismaClientOutDir, mode: 'logical', }); + // make a bunch of typing fixes to the generated prisma client await processClientTypes(model, path.join(outDir, prismaClientOutDir)); + return prismaClientOutDir; } async function processClientTypes(model: Model, prismaClientDir: string) { + // make necessary updates to the generated `index.d.ts` file and save it as `index-fixed.d.ts` const project = new Project(); const sf = project.addSourceFileAtPath(path.join(prismaClientDir, 'index.d.ts')); - const delegateModels: [DataModel, DataModel[]][] = []; + // build a map of delegate models and their sub models + const delegateInfo: DelegateInfo = []; model.declarations .filter((d): d is DataModel => isDelegateModel(d)) .forEach((dm) => { - delegateModels.push([ + delegateInfo.push([ dm, model.declarations.filter( (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm) @@ -88,154 +101,231 @@ async function processClientTypes(model: Model, prismaClientDir: string) { ]); }); - const toRemove: (PropertySignature | MethodSignature)[] = []; - const toReplaceText: [TypeAliasDeclaration, string][] = []; - - sf.forEachDescendant((desc, traversal) => { - removeAuxRelationFields(desc, toRemove, traversal); - fixDelegateUnionType(desc, delegateModels, toReplaceText, traversal); - removeCreateFromDelegateInputTypes(desc, delegateModels, toRemove, traversal); - removeDelegateToplevelCreates(desc, delegateModels, toRemove, traversal); - removeDiscriminatorFromConcreteInputTypes(desc, delegateModels, toRemove); + const sfNew = project.createSourceFile(path.join(prismaClientDir, 'index-fixed.d.ts'), undefined, { + overwrite: true, }); - - toRemove.forEach((n) => n.remove()); - toReplaceText.forEach(([node, text]) => node.replaceWithText(text)); + transform(sf, sfNew, delegateInfo); + sfNew.formatText(); await project.save(); } -function removeAuxRelationFields( - desc: Node, - toRemove: (PropertySignature | MethodSignature)[], - traversal: ForEachDescendantTraversalControl -) { - if (desc.isKind(SyntaxKind.PropertySignature) || desc.isKind(SyntaxKind.MethodSignature)) { - // remove aux fields - const name = desc.getName(); +function transform(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) { + // copy toplevel imports + sfNew.addImportDeclarations(sf.getImportDeclarations().map((n) => n.getStructure())); - if (name.startsWith(DELEGATE_AUX_RELATION_PREFIX)) { - toRemove.push(desc); - traversal.skip(); - } - } + // copy toplevel import equals + sfNew.addStatements(sf.getChildrenOfKind(SyntaxKind.ImportEqualsDeclaration).map((n) => n.getFullText())); + + // copy toplevel exports + sfNew.addExportAssignments(sf.getExportAssignments().map((n) => n.getStructure())); + + // copy toplevel type aliases + sfNew.addTypeAliases(sf.getTypeAliases().map((n) => n.getStructure())); + + // copy toplevel classes + sfNew.addClasses(sf.getClasses().map((n) => n.getStructure())); + + // copy toplevel variables + sfNew.addVariableStatements(sf.getVariableStatements().map((n) => n.getStructure())); + + // copy toplevel namespaces except for `Prisma` + sfNew.addModules( + sf + .getModules() + .filter((n) => n.getName() !== 'Prisma') + .map((n) => n.getStructure()) + ); + + // transform the `Prisma` namespace + const prismaModule = sf.getModuleOrThrow('Prisma'); + const newPrismaModule = sfNew.addModule({ name: 'Prisma', isExported: true }); + transformPrismaModule(prismaModule, newPrismaModule, delegateModels); } -function fixDelegateUnionType( - desc: Node, - delegateModels: [DataModel, DataModel[]][], - toReplaceText: [TypeAliasDeclaration, string][], - traversal: ForEachDescendantTraversalControl +function transformPrismaModule( + prismaModule: ModuleDeclaration, + newPrismaModule: ModuleDeclaration, + delegateInfo: DelegateInfo ) { - if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { - return; - } + // module block is the direct container of declarations inside a namespace + const moduleBlock = prismaModule.getFirstChildByKindOrThrow(SyntaxKind.ModuleBlock); - const name = desc.getName(); - delegateModels.forEach(([delegate, concreteModels]) => { - if (name === `$${delegate.name}Payload`) { - const discriminator = getDiscriminatorField(delegate); - if (discriminator) { - toReplaceText.push([ - desc, - `export type ${name} = - ${concreteModels - .map((m) => `($${m.name}Payload & { scalars: { ${discriminator.name}: '${m.name}' } })`) - .join(' | ')};`, - ]); - traversal.skip(); - } - } - }); + // most of the toplevel constructs should be copied over + // here we use ts-morph batch operations for optimal performance + + // copy imports + newPrismaModule.addStatements( + moduleBlock.getChildrenOfKind(SyntaxKind.ImportEqualsDeclaration).map((n) => n.getFullText()) + ); + + // copy classes + newPrismaModule.addClasses(moduleBlock.getClasses().map((n) => n.getStructure())); + + // copy functions + newPrismaModule.addFunctions( + moduleBlock.getFunctions().map((n) => n.getStructure() as FunctionDeclarationStructure) + ); + + // copy nested namespaces + newPrismaModule.addModules(moduleBlock.getModules().map((n) => n.getStructure())); + + // transform variables + const newVariables = moduleBlock.getVariableStatements().map((variable) => transformVariableStatement(variable)); + newPrismaModule.addVariableStatements(newVariables); + + // transform interfaces + const newInterfaces = moduleBlock.getInterfaces().map((iface) => transformInterface(iface, delegateInfo)); + newPrismaModule.addInterfaces(newInterfaces); + + // transform type aliases + const newTypeAliases = moduleBlock.getTypeAliases().map((typeAlias) => transformTypeAlias(typeAlias, delegateInfo)); + newPrismaModule.addTypeAliases(newTypeAliases); } -function removeCreateFromDelegateInputTypes( - desc: Node, - delegateModels: [DataModel, DataModel[]][], - toRemove: (PropertySignature | MethodSignature)[], - traversal: ForEachDescendantTraversalControl -) { - if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { - return; - } +function transformVariableStatement(variable: VariableStatement) { + const structure = variable.getStructure(); - const name = desc.getName(); - delegateModels.forEach(([delegate]) => { - // remove create related sub-payload from delegate's input types since they cannot be created directly - const regex = new RegExp(`\\${delegate.name}(Unchecked)?(Create|Update).*Input`); - if (regex.test(name)) { - desc.forEachDescendant((d, innerTraversal) => { - if ( - d.isKind(SyntaxKind.PropertySignature) && - ['create', 'upsert', 'connectOrCreate'].includes(d.getName()) - ) { - toRemove.push(d); - innerTraversal.skip(); - } + // remove `delegate_aux_*` fields from the variable's typing + const auxFields = findAuxDecls(variable); + if (auxFields.length > 0) { + structure.declarations.forEach((variable) => { + let source = variable.type?.toString(); + auxFields.forEach((f) => { + source = source?.replace(f.getText(), ''); }); - traversal.skip(); - } - }); + variable.type = source; + }); + } + + return structure; } -function removeDiscriminatorFromConcreteInputTypes( - desc: Node, - delegateModels: [DataModel, DataModel[]][], - toRemove: (PropertySignature | MethodSignature)[] -) { - if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { - return; +function transformInterface(iface: InterfaceDeclaration, delegateInfo: DelegateInfo) { + const structure = iface.getStructure(); + + // filter out aux fields + structure.properties = structure.properties?.filter((p) => !p.name.startsWith(DELEGATE_AUX_RELATION_PREFIX)); + + // filter out aux methods + structure.methods = structure.methods?.filter((m) => !m.name.startsWith(DELEGATE_AUX_RELATION_PREFIX)); + + if (delegateInfo.some(([delegate]) => `${delegate.name}Delegate` === iface.getName())) { + // delegate models cannot be created directly, remove create/createMany/upsert + structure.methods = structure.methods?.filter((m) => !['create', 'createMany', 'upsert'].includes(m.name)); } - const name = desc.getName(); - delegateModels.forEach(([delegate, concretes]) => { - const discriminator = getDiscriminatorField(delegate); - if (!discriminator) { - return; + return structure; +} + +function transformTypeAlias(typeAlias: TypeAliasDeclaration, delegateInfo: DelegateInfo) { + const structure = typeAlias.getStructure(); + let source = structure.type as string; + + // remove aux fields + source = removeAuxFieldsFromTypeAlias(typeAlias, source); + + // remove discriminator field from concrete input types + source = removeDiscriminatorFromConcreteInput(typeAlias, delegateInfo, source); + + // remove create/connectOrCreate/upsert fields from delegate's input types + source = removeCreateFromDelegateInput(typeAlias, delegateInfo, source); + + // fix delegate payload union type + source = fixDelegatePayloadType(typeAlias, delegateInfo, source); + + structure.type = source; + return structure; +} + +function fixDelegatePayloadType(typeAlias: TypeAliasDeclaration, delegateInfo: DelegateInfo, source: string) { + // change the type of `$Payload` type of delegate model to a union of concrete types + const typeName = typeAlias.getName(); + const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName); + if (payloadRecord) { + const discriminatorDecl = getDiscriminatorField(payloadRecord[0]); + if (discriminatorDecl) { + source = `${payloadRecord[1] + .map( + (concrete) => + `($${concrete.name}Payload & { scalars: { ${discriminatorDecl.name}: '${concrete.name}' } })` + ) + .join(' | ')}`; } + } + return source; +} - concretes.forEach((concrete) => { - // remove discriminator field from the create/update input of concrete models - const regex = new RegExp(`\\${concrete.name}(Unchecked)?(Create|Update).*Input`); - if (regex.test(name)) { - desc.forEachDescendant((d, innerTraversal) => { - if (d.isKind(SyntaxKind.PropertySignature)) { - if (d.getName() === discriminator.name) { - toRemove.push(d); - } - innerTraversal.skip(); - } - }); - } +function removeCreateFromDelegateInput(typeAlias: TypeAliasDeclaration, delegateModels: DelegateInfo, source: string) { + // remove create/connectOrCreate/upsert fields from delegate's input types because + // delegate models cannot be created directly + const typeName = typeAlias.getName(); + const delegateModelNames = delegateModels.map(([delegate]) => delegate.name); + const delegateCreateUpdateInputRegex = new RegExp( + `\\${delegateModelNames.join('|')}(Unchecked)?(Create|Update).*Input` + ); + if (delegateCreateUpdateInputRegex.test(typeName)) { + const toRemove = typeAlias + .getDescendantsOfKind(SyntaxKind.PropertySignature) + .filter((p) => ['create', 'connectOrCreate', 'upsert'].includes(p.getName())); + toRemove.forEach((r) => { + source = source.replace(r.getText(), ''); }); - }); + } + return source; } -function removeDelegateToplevelCreates( - desc: Node, - delegateModels: [DataModel, DataModel[]][], - toRemove: (PropertySignature | MethodSignature)[], - traversal: ForEachDescendantTraversalControl +function removeDiscriminatorFromConcreteInput( + typeAlias: TypeAliasDeclaration, + delegateInfo: DelegateInfo, + source: string ) { - if (desc.isKind(SyntaxKind.InterfaceDeclaration)) { - // remove create and upsert methods from delegate interfaces since they cannot be created directly - const name = desc.getName(); - if (delegateModels.map(([dm]) => `${dm.name}Delegate`).includes(name)) { - const createMethod = desc.getMethod('create'); - if (createMethod) { - toRemove.push(createMethod); - } - const createManyMethod = desc.getMethod('createMany'); - if (createManyMethod) { - toRemove.push(createManyMethod); - } - const upsertMethod = desc.getMethod('upsert'); - if (upsertMethod) { - toRemove.push(upsertMethod); - } - traversal.skip(); + // remove discriminator field from the create/update input of concrete models because + // discriminator cannot be set directly + const typeName = typeAlias.getName(); + const concreteModelNames = delegateInfo.map(([, concretes]) => concretes.map((c) => c.name)).flatMap((c) => c); + const concreteCreateUpdateInputRegex = new RegExp( + `(${concreteModelNames.join('|')})(Unchecked)?(Create|Update).*Input` + ); + + const match = typeName.match(concreteCreateUpdateInputRegex); + if (match) { + const modelName = match[1]; + const record = delegateInfo.find(([, concretes]) => concretes.some((c) => c.name === modelName)); + if (record) { + // remove all discriminator fields recursively + const delegateOfConcrete = record[0]; + const discriminators = getDiscriminatorFieldsRecursively(delegateOfConcrete); + discriminators.forEach((discriminatorDecl) => { + const discriminatorNode = findNamedProperty(typeAlias, discriminatorDecl.name); + if (discriminatorNode) { + source = source.replace(discriminatorNode.getText(), ''); + } + }); } } + return source; +} + +function removeAuxFieldsFromTypeAlias(typeAlias: TypeAliasDeclaration, source: string) { + // remove `delegate_aux_*` fields from the type alias + const auxDecls = findAuxDecls(typeAlias); + if (auxDecls.length > 0) { + auxDecls.forEach((d) => { + source = source.replace(d.getText(), ''); + }); + } + return source; +} + +function findNamedProperty(typeAlias: TypeAliasDeclaration, name: string) { + return typeAlias.getFirstDescendant((d) => d.isKind(SyntaxKind.PropertySignature) && d.getName() === name); +} + +function findAuxDecls(node: Node) { + return node + .getDescendantsOfKind(SyntaxKind.PropertySignature) + .filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX)); } function getDiscriminatorField(delegate: DataModel) { @@ -246,3 +336,19 @@ function getDiscriminatorField(delegate: DataModel) { const arg = delegateAttr.args[0]?.value; return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; } + +function getDiscriminatorFieldsRecursively(delegate: DataModel, result: DataModelField[] = []) { + if (isDelegateModel(delegate)) { + const discriminator = getDiscriminatorField(delegate); + if (discriminator) { + result.push(discriminator); + } + + for (const superType of delegate.superTypes) { + if (superType.ref) { + result.push(...getDiscriminatorFieldsRecursively(superType.ref, result)); + } + } + } + return result; +} diff --git a/packages/schema/src/plugins/prisma/index.ts b/packages/schema/src/plugins/prisma/index.ts index b27624cd7..5aa64c145 100644 --- a/packages/schema/src/plugins/prisma/index.ts +++ b/packages/schema/src/plugins/prisma/index.ts @@ -5,7 +5,7 @@ export const name = 'Prisma'; export const description = 'Generating Prisma schema'; const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => { - return new PrismaSchemaGenerator().generate(model, options); + return new PrismaSchemaGenerator(model).generate(options); }; export default run; diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 9a8ccb0eb..4ac78c6e3 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -17,6 +17,7 @@ import { InvocationExpr, isArrayExpr, isDataModel, + isDataSource, isInvocationExpr, isLiteralExpr, isNullExpr, @@ -79,6 +80,7 @@ import { const MODEL_PASSTHROUGH_ATTR = '@@prisma.passthrough'; const FIELD_PASSTHROUGH_ATTR = '@prisma.passthrough'; +const PROVIDERS_SUPPORTING_NAMED_CONSTRAINTS = ['postgresql', 'mysql', 'cockroachdb']; /** * Generates Prisma schema file @@ -95,7 +97,9 @@ export class PrismaSchemaGenerator { private mode: 'logical' | 'physical' = 'physical'; - async generate(model: Model, options: PluginOptions) { + constructor(private readonly zmodel: Model) {} + + async generate(options: PluginOptions) { const warnings: string[] = []; if (options.mode) { this.mode = options.mode as 'logical' | 'physical'; @@ -110,7 +114,7 @@ export class PrismaSchemaGenerator { const prisma = new PrismaModel(); - for (const decl of model.declarations) { + for (const decl of this.zmodel.declarations) { switch (decl.$type) { case DataSource: this.generateDataSource(prisma, decl as DataSource); @@ -151,7 +155,7 @@ export class PrismaSchemaGenerator { const generateClient = options.generateClient !== false; if (generateClient) { - let generateCmd = `prisma generate --schema "${outFile}"`; + let generateCmd = `prisma generate --schema "${outFile}"${this.mode === 'logical' ? ' --no-engine' : ''}`; if (typeof options.generateArgs === 'string') { generateCmd += ` ${options.generateArgs}`; } @@ -452,17 +456,23 @@ export class PrismaSchemaGenerator { new AttributeArgValue('FieldReference', new PrismaFieldReference(idField.name)) ) ); - relationField.attributes.push( - new PrismaFieldAttribute('@relation', [ - new PrismaAttributeArg('fields', args), - new PrismaAttributeArg('references', args), + + const addedRel = new PrismaFieldAttribute('@relation', [ + new PrismaAttributeArg('fields', args), + new PrismaAttributeArg('references', args), + ]); + + if (this.supportNamedConstraints) { + addedRel.args.push( // generate a `map` argument for foreign key constraint disambiguation new PrismaAttributeArg( 'map', new PrismaAttributeArgValue('String', `${relationField.name}_fk`) - ), - ]) - ); + ) + ); + } + + relationField.attributes.push(addedRel); } else { relationField.attributes.push(this.makeFieldAttribute(relAttr as DataModelFieldAttribute)); } @@ -471,6 +481,21 @@ export class PrismaSchemaGenerator { }); } + private get supportNamedConstraints() { + const ds = this.zmodel.declarations.find(isDataSource); + if (!ds) { + return false; + } + + const provider = ds.fields.find((f) => f.name === 'provider'); + if (!provider) { + return false; + } + + const value = getStringLiteral(provider.value); + return value && PROVIDERS_SUPPORTING_NAMED_CONSTRAINTS.includes(value); + } + private isPrismaAttribute(attr: DataModelAttribute | DataModelFieldAttribute) { if (!attr.decl.ref) { return false; diff --git a/packages/schema/tests/generator/prisma-generator.test.ts b/packages/schema/tests/generator/prisma-generator.test.ts index 67ba27f99..35d68fb28 100644 --- a/packages/schema/tests/generator/prisma-generator.test.ts +++ b/packages/schema/tests/generator/prisma-generator.test.ts @@ -49,7 +49,7 @@ describe('Prisma generator test', () => { } `); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -90,7 +90,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -128,7 +128,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -162,7 +162,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -194,7 +194,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -230,7 +230,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -270,7 +270,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -321,7 +321,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -357,7 +357,7 @@ describe('Prisma generator test', () => { } `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -380,7 +380,7 @@ describe('Prisma generator test', () => { const model = await loadDocument(path.join(__dirname, './zmodel/schema.zmodel')); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -430,7 +430,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -461,7 +461,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', @@ -496,7 +496,7 @@ describe('Prisma generator test', () => { `); const { name } = tmp.fileSync({ postfix: '.prisma' }); - await new PrismaSchemaGenerator().generate(model, { + await new PrismaSchemaGenerator(model).generate({ name: 'Prisma', provider: '@core/prisma', schemaPath: 'schema.zmodel', diff --git a/tests/integration/tests/enhancements/with-delegate/regressions.test.ts b/tests/integration/tests/enhancements/with-delegate/issue-1058.test.ts similarity index 94% rename from tests/integration/tests/enhancements/with-delegate/regressions.test.ts rename to tests/integration/tests/enhancements/with-delegate/issue-1058.test.ts index 77166e275..cd566c71f 100644 --- a/tests/integration/tests/enhancements/with-delegate/regressions.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/issue-1058.test.ts @@ -1,7 +1,7 @@ import { loadSchema } from '@zenstackhq/testtools'; -describe('Regression tests', () => { - it('FK Constraint Ambiguity', async () => { +describe('Regression for issue 1058', () => { + it('test', async () => { const schema = ` model User { id String @id @default(cuid()) diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1064.test.ts b/tests/integration/tests/enhancements/with-delegate/issue-1064.test.ts new file mode 100644 index 000000000..a8505f507 --- /dev/null +++ b/tests/integration/tests/enhancements/with-delegate/issue-1064.test.ts @@ -0,0 +1,291 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Regression for issue 1064', () => { + it('test', async () => { + const schema = ` + model Account { + id String @id @default(cuid()) + userId String + type String + provider String + providerAccountId String + refresh_token String? // @db.Text + access_token String? // @db.Text + expires_at Int? + token_type String? + scope String? + id_token String? // @db.Text + session_state String? + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + @@allow('all', auth().id == userId) + @@unique([provider, providerAccountId]) + } + + model Session { + id String @id @default(cuid()) + sessionToken String @unique + userId String + expires DateTime + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + @@allow('all', auth().id == userId) + } + + model VerificationToken { + identifier String + token String @unique + expires DateTime + + @@allow('all', true) + @@unique([identifier, token]) + } + + model User { + id String @id @default(cuid()) + name String? + email String? @unique + emailVerified DateTime? + image String + accounts Account[] + sessions Session[] + + username String @unique @length(min: 4, max: 20) + about String? @length(max: 500) + location String? @length(max: 100) + + role String @default("USER") @deny(operation: "update", auth().role != "ADMIN") + + inserted_at DateTime @default(now()) + updated_at DateTime @updatedAt() @default(now()) + + editComments EditComment[] + + posts Post[] + rankings UserRanking[] + ratings UserRating[] + favorites UserFavorite[] + + people Person[] + studios Studio[] + edits Edit[] + attachments Attachment[] + galleries Gallery[] + + uploads UserUpload[] + + maxUploadsPerDay Int @default(10) + maxEditsPerDay Int @default(10) + + // everyone can signup, and user profile is also publicly readable + @@allow('create,read', true) + // only the user can update or delete their own profile + @@allow('update,delete', auth() == this) + } + + abstract model UserEntityRelation { + entityId String? + entity Entity? @relation(fields: [entityId], references: [id], onUpdate: NoAction) + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction) + + + // everyone can read + @@allow('read', true) + @@allow('create,update,delete', auth().id == this.userId) + + @@unique([userId,entityId]) + } + + model UserUpload { + timestamp DateTime @default(now()) + + key String @id + url String @unique + size Int + + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction) + + @@allow('create', auth().id == userId) + @@allow('all', auth().role == "ADMIN") + } + + model Post { + id Int @id @default(autoincrement()) + title String @length(max: 100) + body String @length(max: 1000) + createdAt DateTime @default(now()) + + authorId String + author User @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: NoAction) + + @@allow('read', true) + @@allow('create,update,delete', auth().id == authorId && auth().role == "ADMIN") + } + + model Edit extends UserEntityRelation { + id String @id @default(cuid()) + status String @default("PENDING") @allow('update', auth().role in ["ADMIN", "MODERATOR"]) + type String @allow('update', false) + timestamp DateTime @default(now()) + note String? @length(max: 300) + // for creates - createPayload & updates - data before diff is applied + data String? + // for updates + diff String? + + comments EditComment[] + } + + model EditComment { + id Int @id @default(autoincrement()) + timestamp DateTime @default(now()) + content String @length(max: 300) + editId String + edit Edit @relation(fields: [editId], references: [id], onUpdate: Cascade) + authorId String + author User @relation(fields: [authorId], references: [id], onUpdate: Cascade) + + // everyone can read + @@allow('read', true) + @@allow('create,update,delete', auth().id == this.authorId || auth().role in ["ADMIN", "MODERATOR"]) + } + + model MetadataIdentifier { + id Int @default(autoincrement()) @id + + identifier String + + metadataSource String + MetadataSource MetadataSource @relation(fields: [metadataSource], references: [slug], onUpdate: Cascade) + + entities Entity[] + + @@unique([identifier, metadataSource]) + + @@allow('read', true) + @@allow('create,update,delete', auth().role in ["ADMIN", "MODERATOR"]) + } + + model MetadataSource { + slug String @id + name String @unique + identifierRegex String + desc String? + url String + icon String + identifiers MetadataIdentifier[] + + @@allow('all', auth().role == "ADMIN") + } + + model Attachment extends UserEntityRelation { + id String @id @default(cuid()) + createdAt DateTime @default(now()) + key String @unique + url String @unique + galleries Gallery[] + @@allow('delete', auth().role in ["ADMIN", "MODERATOR"]) + } + + model Entity { + id String @id @default(cuid()) + name String + desc String? + + attachments Attachment[] + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt @default(now()) + + type String + + status String @default("PENDING") // PENDING ON INITIAL CREATION + verified Boolean @default(false) + + edits Edit[] + userRankings UserRanking[] + userFavorites UserFavorite[] + userRatings UserRating[] + metaIdentifiers MetadataIdentifier[] + + @@delegate(type) + + @@allow('read', true) + @@allow('create', auth() != null) + @@allow('update', auth().role in ["ADMIN", "MODERATOR"]) + @@allow('delete', auth().role == "ADMIN") + } + + model Person extends Entity { + studios Studio[] + owners User[] + clips Clip[] + events Event[] + galleries Gallery[] + } + + model Studio extends Entity { + people Person[] + owners User[] + clips Clip[] + events Event[] + galleries Gallery[] + } + + model Clip extends Entity { + url String? + people Person[] + studios Studio[] + galleries Gallery[] + } + + model UserRanking extends UserEntityRelation { + id String @id @default(cuid()) + rank Int @gte(1) @lte(100) + note String? @length(max: 300) + } + + model UserFavorite extends UserEntityRelation { + id String @id @default(cuid()) + favoritedAt DateTime @default(now()) + } + + model UserRating extends UserEntityRelation { + id String @id @default(cuid()) + rating Int @gte(1) @lte(5) + note String? @length(max: 500) + ratedAt DateTime @default(now()) + } + + model Event { + id Int @id @default(autoincrement()) + name String @length(max: 100) + desc String? @length(max: 500) + location String? @length(max: 100) + date DateTime? + people Person[] + studios Studio[] + + @@allow('read', true) + @@allow('create,update,delete', auth().role == "ADMIN") + } + + model Gallery { + id String @id @default(cuid()) + studioId String? + personId String? + timestamp DateTime @default(now()) + authorId String + author User @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: NoAction) + people Person[] + studios Studio[] + clips Clip[] + attachments Attachment[] + + @@allow('read', true) + @@allow('create,update,delete', auth().id == this.authorId && auth().role == "ADMIN") + } + `; + + await loadSchema(schema); + }); +}); diff --git a/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts b/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts index 0d0b24ca2..31976fbce 100644 --- a/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts @@ -1,5 +1,7 @@ -import { loadSchema } from '@zenstackhq/testtools'; import { PrismaErrorCode } from '@zenstackhq/runtime'; +import { loadSchema, run } from '@zenstackhq/testtools'; +import fs from 'fs'; +import path from 'path'; describe('Polymorphism Test', () => { const schema = ` @@ -1012,4 +1014,77 @@ model Gallery { 'groupBy with fields from base type is not supported yet' ); }); + + it('typescript compilation', async () => { + const { projectDir } = await loadSchema(schema, { enhancements: ['delegate'] }); + const src = ` + import { PrismaClient } from '@prisma/client'; + import { enhance } from '.zenstack/enhance'; + + const prisma = new PrismaClient(); + + async function main() { + await prisma.user.deleteMany(); + const db = enhance(prisma); + + const user1 = await db.user.create({ data: { } }); + + await db.ratedVideo.create({ + data: { + owner: { connect: { id: user1.id } }, + duration: 100, + url: 'abc', + rating: 10, + }, + }); + + await db.image.create({ + data: { + owner: { connect: { id: user1.id } }, + format: 'webp', + }, + }); + + const video = await db.video.findFirst({ include: { owner: true } }); + console.log(video?.duration); + console.log(video?.viewCount); + + const asset = await db.asset.findFirstOrThrow(); + console.log(asset.assetType); + console.log(asset.viewCount); + + if (asset.assetType === 'Video') { + console.log('Video: duration', asset.duration); + } else { + console.log('Image: format', asset.format); + } + } + + main() + .then(async () => { + await prisma.$disconnect(); + }) + .catch(async (e) => { + console.error(e); + await prisma.$disconnect(); + process.exit(1); + }); + `; + + fs.writeFileSync(path.join(projectDir, 'script.ts'), src); + fs.writeFileSync( + path.join(projectDir, 'tsconfig.json'), + JSON.stringify({ + compilerOptions: { + outDir: 'dist', + strict: true, + lib: ['esnext'], + esModuleInterop: true, + }, + }) + ); + + run('npm i -D @types/node', undefined, projectDir); + run('npx tsc --noEmit --skipLibCheck script.ts', undefined, projectDir); + }); });