From 4a2681ed8be79a4e66ee1a0b6cee09e07a756ea4 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 15 Apr 2024 10:53:31 +0800 Subject: [PATCH 1/2] fix(zmodel): member access from `auth()` is not properly resolved when the auth model is imported Fixes #1257 --- .../projects/t3-trpc-v10/prisma/schema.prisma | 30 +++++------ packages/schema/src/cli/cli-util.ts | 35 +++++++----- .../validator/expression-validator.ts | 17 +++++- .../validator/schema-validator.ts | 4 +- .../src/language-server/zmodel-linker.ts | 14 ++--- .../src/language-server/zmodel-scope.ts | 23 ++++---- packages/schema/src/utils/ast-utils.ts | 6 ++- .../validation/attribute-validation.test.ts | 2 +- .../tests/regression/issue-1257.test.ts | 53 +++++++++++++++++++ 9 files changed, 129 insertions(+), 55 deletions(-) create mode 100644 tests/integration/tests/regression/issue-1257.test.ts diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/prisma/schema.prisma b/packages/plugins/trpc/tests/projects/t3-trpc-v10/prisma/schema.prisma index 2a0b2142a..a28fea9fb 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/prisma/schema.prisma +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/prisma/schema.prisma @@ -4,28 +4,28 @@ ////////////////////////////////////////////////////////////////////////////////////////////// datasource db { - provider = "sqlite" - url = "file:./dev.db" + provider = "sqlite" + url = "file:./dev.db" } generator client { - provider = "prisma-client-js" + provider = "prisma-client-js" } model User { - id Int @id() @default(autoincrement()) - email String @unique() - posts Post[] + id Int @id() @default(autoincrement()) + email String @unique() + posts Post[] } model Post { - id Int @id() @default(autoincrement()) - name String - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt() - published Boolean @default(false) - author User @relation(fields: [authorId], references: [id]) - authorId Int + id Int @id() @default(autoincrement()) + name String + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt() + published Boolean @default(false) + author User @relation(fields: [authorId], references: [id]) + authorId Int - @@index([name]) -} \ No newline at end of file + @@index([name]) +} diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index 89d194fb9..e9db60fb6 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -64,23 +64,30 @@ export async function loadDocument(fileName: string): Promise { } ); - const validationErrors = langiumDocuments.all - .flatMap((d) => d.diagnostics ?? []) - .filter((e) => e.severity === 1) + const diagnostics = langiumDocuments.all + .flatMap((doc) => (doc.diagnostics ?? []).map((diag) => ({ doc, diag }))) + .filter(({ diag }) => diag.severity === 1 || diag.severity === 2) .toArray(); - if (validationErrors.length > 0) { - console.error(colors.red('Validation errors:')); - for (const validationError of validationErrors) { - console.error( - colors.red( - `line ${validationError.range.start.line + 1}: ${ - validationError.message - } [${document.textDocument.getText(validationError.range)}]` - ) - ); + let hasErrors = false; + + if (diagnostics.length > 0) { + for (const { doc, diag } of diagnostics) { + const message = `${path.relative(process.cwd(), doc.uri.fsPath)}:${diag.range.start.line + 1}:${ + diag.range.start.character + 1 + } - ${diag.message}`; + + if (diag.severity === 1) { + console.error(colors.red(message)); + hasErrors = true; + } else { + console.warn(colors.yellow(message)); + } + } + + if (hasErrors) { + throw new CliError('Schema contains validation errors'); } - throw new CliError('schema validation errors'); } const model = document.parseResult.value as Model; diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 59ad7edbb..cb42e4cb1 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -10,6 +10,7 @@ import { isLiteralExpr, isMemberAccessExpr, isNullExpr, + isReferenceExpr, isThisExpr, } from '@zenstackhq/language/ast'; import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; @@ -33,9 +34,21 @@ export default class ExpressionValidator implements AstValidator { { node: expr } ); } else { - accept('error', 'expression cannot be resolved', { - node: expr, + const hasReferenceResolutionError = streamAst(expr).some((node) => { + if (isMemberAccessExpr(node)) { + return !!node.member.error; + } + if (isReferenceExpr(node)) { + return !!node.target.error; + } + return false; }); + if (!hasReferenceResolutionError) { + // report silent errors not involving linker errors + accept('error', 'Expression cannot be resolved', { + node: expr, + }); + } } } diff --git a/packages/schema/src/language-server/validator/schema-validator.ts b/packages/schema/src/language-server/validator/schema-validator.ts index 9e0512547..d071324c1 100644 --- a/packages/schema/src/language-server/validator/schema-validator.ts +++ b/packages/schema/src/language-server/validator/schema-validator.ts @@ -1,7 +1,7 @@ import { Model, isDataModel, isDataSource } from '@zenstackhq/language/ast'; import { hasAttribute } from '@zenstackhq/sdk'; import { LangiumDocuments, ValidationAcceptor } from 'langium'; -import { getAllDeclarationsFromImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils'; +import { getAllDeclarationsIncludingImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../constants'; import { AstValidator } from '../types'; import { validateDuplicatedDeclarations } from './utils'; @@ -43,7 +43,7 @@ export default class SchemaValidator implements AstValidator { } private validateDataSources(model: Model, accept: ValidationAcceptor) { - const dataSources = getAllDeclarationsFromImports(this.documents, model).filter((d) => isDataSource(d)); + const dataSources = getAllDeclarationsIncludingImports(this.documents, model).filter((d) => isDataSource(d)); if (dataSources.length > 1) { accept('error', 'Multiple datasource declarations are not allowed', { node: dataSources[1] }); } diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index 56e2431d5..5a15f9336 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -36,9 +36,9 @@ import { isStringLiteral, } from '@zenstackhq/language/ast'; import { + getAuthModel, getContainingModel, getModelFieldsWithBases, - hasAttribute, isAuthInvocation, isFutureExpr, } from '@zenstackhq/sdk'; @@ -58,7 +58,7 @@ import { } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { getAllDeclarationsFromImports, getContainingDataModel } from '../utils/ast-utils'; +import { getAllDataModelsIncludingImports, getContainingDataModel } from '../utils/ast-utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; interface DefaultReference extends Reference { @@ -287,14 +287,8 @@ export class ZModelLinker extends DefaultLinker { const model = getContainingModel(node); if (model) { - let authModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => { - return isDataModel(d) && hasAttribute(d, '@@auth'); - }); - if (!authModel) { - authModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => { - return isDataModel(d) && d.name === 'User'; - }); - } + const allDataModels = getAllDataModelsIncludingImports(this.langiumDocuments(), model); + const authModel = getAuthModel(allDataModels); if (authModel) { node.$resolvedType = { decl: authModel, nullable: true }; } diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index 7dff9c8df..a10b71949 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -10,13 +10,7 @@ import { isReferenceExpr, isThisExpr, } from '@zenstackhq/language/ast'; -import { - getAuthModel, - getDataModels, - getModelFieldsWithBases, - getRecursiveBases, - isAuthInvocation, -} from '@zenstackhq/sdk'; +import { getAuthModel, getModelFieldsWithBases, getRecursiveBases, isAuthInvocation } from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -37,7 +31,12 @@ import { } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { isCollectionPredicate, isFutureInvocation, resolveImportUri } from '../utils/ast-utils'; +import { + getAllDataModelsIncludingImports, + isCollectionPredicate, + isFutureInvocation, + resolveImportUri, +} from '../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; /** @@ -88,7 +87,7 @@ export class ZModelScopeComputation extends DefaultScopeComputation { } export class ZModelScopeProvider extends DefaultScopeProvider { - constructor(services: LangiumServices) { + constructor(private readonly services: LangiumServices) { super(services); } @@ -222,7 +221,11 @@ export class ZModelScopeProvider extends DefaultScopeProvider { private createScopeForAuthModel(node: AstNode, globalScope: Scope) { const model = getContainerOfType(node, isModel); if (model) { - const authModel = getAuthModel(getDataModels(model, true)); + const allDataModels = getAllDataModelsIncludingImports( + this.services.shared.workspace.LangiumDocuments, + model + ); + const authModel = getAuthModel(allDataModels); if (authModel) { return this.createScopeForModel(authModel, globalScope); } diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index d33f27e71..3a255228e 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -216,11 +216,15 @@ export function resolveImport(documents: LangiumDocuments, imp: ModelImport): Mo return undefined; } -export function getAllDeclarationsFromImports(documents: LangiumDocuments, model: Model) { +export function getAllDeclarationsIncludingImports(documents: LangiumDocuments, model: Model) { const imports = resolveTransitiveImports(documents, model); return model.declarations.concat(...imports.map((imp) => imp.declarations)); } +export function getAllDataModelsIncludingImports(documents: LangiumDocuments, model: Model) { + return getAllDeclarationsIncludingImports(documents, model).filter(isDataModel); +} + export function isCollectionPredicate(node: AstNode): node is BinaryExpr { return isBinaryExpr(node) && ['?', '!', '^'].includes(node.operator); } diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index e3c1c597e..aca9e2674 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -1081,7 +1081,7 @@ describe('Attribute tests', () => { @@allow('all', auth().email != null) } `) - ).toContain(`expression cannot be resolved`); + ).toContain(`Could not resolve reference to DataModelField named 'email'.`); }); it('collection predicate expression check', async () => { diff --git a/tests/integration/tests/regression/issue-1257.test.ts b/tests/integration/tests/regression/issue-1257.test.ts new file mode 100644 index 000000000..a692d0464 --- /dev/null +++ b/tests/integration/tests/regression/issue-1257.test.ts @@ -0,0 +1,53 @@ +import { FILE_SPLITTER, loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1210', () => { + it('regression', async () => { + await loadSchema( + `schema.zmodel + import "./user" + import "./image" + + generator client { + provider = "prisma-client-js" + } + + datasource db { + provider = "postgresql" + url = env("DATABASE_URL") + } + + ${FILE_SPLITTER}base.zmodel + abstract model Base { + id Int @id @default(autoincrement()) + } + + ${FILE_SPLITTER}user.zmodel + import "./base" + import "./image" + + enum Role { + Admin + } + + model User extends Base { + email String @unique + role Role + @@auth + } + + ${FILE_SPLITTER}image.zmodel + import "./user" + import "./base" + + model Image extends Base { + width Int @default(0) + height Int @default(0) + + @@allow('read', true) + @@allow('all', auth().role == Admin) + } + `, + { addPrelude: false, pushDb: false } + ); + }); +}); From 8962e89ccd65498951bd91e9e64810bf43706d3b Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 15 Apr 2024 11:16:03 +0800 Subject: [PATCH 2/2] improve scope computation of member access expression --- packages/schema/src/language-server/zmodel-scope.ts | 4 ++-- tests/integration/tests/regression/issue-756.test.ts | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index a10b71949..e48a17621 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -144,9 +144,9 @@ export class ZModelScopeProvider extends DefaultScopeProvider { return EMPTY_SCOPE; }) .when(isMemberAccessExpr, (operand) => { - // operand is a member access, it must be resolved to a + // operand is a member access, it must be resolved to a non-array data model type const ref = operand.member.ref; - if (isDataModelField(ref)) { + if (isDataModelField(ref) && !ref.type.array) { const targetModel = ref.type.reference?.ref; return this.createScopeForModel(targetModel, globalScope); } diff --git a/tests/integration/tests/regression/issue-756.test.ts b/tests/integration/tests/regression/issue-756.test.ts index b10e60af2..9f6750ea9 100644 --- a/tests/integration/tests/regression/issue-756.test.ts +++ b/tests/integration/tests/regression/issue-756.test.ts @@ -28,6 +28,6 @@ describe('Regression: issue 756', () => { } ` ) - ).toContain('expression cannot be resolved'); + ).toContain(`Could not resolve reference to DataModelField named 'authorId'.`); }); });