From fb71f2d2aa416dc3e8a1ab79ac6f0059e940003d Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 26 May 2024 21:20:25 +0800 Subject: [PATCH] refactor: refactor policy definition and generation --- packages/runtime/src/constants.ts | 35 - .../src/enhancements/policy/policy-utils.ts | 225 ++- packages/runtime/src/enhancements/types.ts | 163 ++- .../src/plugins/enhancer/policy/index.ts | 2 +- .../enhancer/policy/policy-guard-generator.ts | 1211 +++++++---------- .../src/plugins/enhancer/policy/utils.ts | 386 ++++++ packages/sdk/src/policy.ts | 2 + packages/testtools/src/schema.ts | 7 +- .../with-delegate/enhanced-client.test.ts | 14 +- .../with-delegate/policy-interaction.test.ts | 12 +- .../with-policy/fluent-api.test.ts | 3 +- .../with-policy/nested-to-one.test.ts | 3 +- .../with-policy/post-update.test.ts | 3 +- .../with-policy/prisma-omit.test.ts | 2 +- .../enhancements/with-policy/refactor.test.ts | 1 - .../with-policy/subscription.test.ts | 4 - .../integration/tests/plugins/policy.test.ts | 43 +- tests/regression/tests/issue-1014.test.ts | 3 +- tests/regression/tests/issue-1080.test.ts | 3 +- tests/regression/tests/issue-1241.test.ts | 3 +- tests/regression/tests/issue-1271.test.ts | 3 +- tests/regression/tests/issue-1435.test.ts | 2 +- tests/regression/tests/issue-1451.test.ts | 3 +- tests/regression/tests/issue-961.test.ts | 2 +- tests/regression/tests/issues.test.ts | 3 +- 25 files changed, 1161 insertions(+), 977 deletions(-) create mode 100644 packages/schema/src/plugins/enhancer/policy/utils.ts diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index a85392887..5fd8c2901 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -63,41 +63,6 @@ export const PRISMA_PROXY_ENHANCER = '$__zenstack_enhancer'; */ export const PRISMA_MINIMUM_VERSION = '5.0.0'; -/** - * Selector function name for fetching pre-update entity values. - */ -export const PRE_UPDATE_VALUE_SELECTOR = 'preValueSelect'; - -/** - * Prefix for field-level read checker function name - */ -export const FIELD_LEVEL_READ_CHECKER_PREFIX = 'readFieldCheck$'; - -/** - * Field-level access control evaluation selector function name - */ -export const FIELD_LEVEL_READ_CHECKER_SELECTOR = 'readFieldSelect'; - -/** - * Prefix for field-level override read guard function name - */ -export const FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX = 'readFieldGuardOverride$'; - -/** - * Prefix for field-level update guard function name - */ -export const FIELD_LEVEL_UPDATE_GUARD_PREFIX = 'updateFieldGuard$'; - -/** - * Prefix for field-level override update guard function name - */ -export const FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX = 'updateFieldGuardOverride$'; - -/** - * Flag that indicates if the model has field-level access control - */ -export const HAS_FIELD_LEVEL_POLICY_FLAG = 'hasFieldLevelPolicy'; - /** * Prefix for auxiliary relation field generated for delegated models */ diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 1af05b03e..02bf87ebf 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -5,24 +5,14 @@ import { lowerCaseFirst } from 'lower-case-first'; import { upperCaseFirst } from 'upper-case-first'; import { ZodError } from 'zod'; import { fromZodError } from 'zod-validation-error'; -import { - CrudFailureReason, - FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, - FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX, - FIELD_LEVEL_READ_CHECKER_PREFIX, - FIELD_LEVEL_READ_CHECKER_SELECTOR, - FIELD_LEVEL_UPDATE_GUARD_PREFIX, - HAS_FIELD_LEVEL_POLICY_FLAG, - PRE_UPDATE_VALUE_SELECTOR, - PrismaErrorCode, -} from '../../constants'; +import { CrudFailureReason, PrismaErrorCode } from '../../constants'; import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; import { AuthUser, CrudContract, DbClientContract, PolicyCrudKind, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { QueryUtils } from '../query-utils'; -import type { CheckerFunc, InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; +import type { CheckerFunc, ModelPolicyDef, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; import { formatObject, prismaClientKnownRequestError } from '../utils'; /** @@ -230,23 +220,32 @@ export class PolicyUtil extends QueryUtils { //#region Auth guard - private readonly FULLY_OPEN_AUTH_GUARD = { - create: true, - read: true, - update: true, - delete: true, - postUpdate: true, - create_input: true, - update_input: true, + private readonly FULL_OPEN_MODEL_POLICY: ModelPolicyDef = { + modelLevel: { + read: { guard: true }, + create: { guard: true, inputChecker: true }, + update: { guard: true }, + delete: { guard: true }, + postUpdate: { guard: true }, + }, }; - private getModelAuthGuard(model: string): PolicyDef['guard']['string'] { + private getModelPolicyDef(model: string): ModelPolicyDef { if (this.options.kinds && !this.options.kinds.includes('policy')) { // policy enhancement not enabled, return an fully open guard - return this.FULLY_OPEN_AUTH_GUARD; - } else { - return this.policy.guard[lowerCaseFirst(model)]; + return this.FULL_OPEN_MODEL_POLICY; + } + + const def = this.policy.policy[lowerCaseFirst(model)]; + if (!def) { + throw this.unknownError(`unable to load policy guard for ${model}`); } + return def; + } + + private getModelGuardForOperation(model: string, operation: PolicyOperationKind): PolicyFunc | boolean { + const def = this.getModelPolicyDef(model); + return def.modelLevel[operation].guard ?? true; } /** @@ -256,20 +255,15 @@ export class PolicyUtil extends QueryUtils { * otherwise returns a guard object */ getAuthGuard(db: CrudContract, model: string, operation: PolicyOperationKind, preValue?: any) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } + const guard = this.getModelGuardForOperation(model, operation); - const provider = guard[operation]; - if (typeof provider === 'boolean') { - return this.reduce(provider); + // constant guard + if (typeof guard === 'boolean') { + return this.reduce(guard); } - if (!provider) { - throw this.unknownError(`unable to load authorization guard for ${model}`); - } - const r = provider({ user: this.user, preValue }, db); + // invoke guard function + const r = guard({ user: this.user, preValue }, db); return this.reduce(r); } @@ -277,19 +271,19 @@ export class PolicyUtil extends QueryUtils { * Get field-level read auth guard that overrides the model-level */ getFieldOverrideReadAuthGuard(db: CrudContract, model: string, field: string) { - const guard = this.requireGuard(model); + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.read?.overrideGuard?.[field]; - const provider = guard[`${FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${field}`]; - if (provider === undefined) { + if (guard === undefined) { // field access is denied by default in override mode return this.makeFalse(); } - if (typeof provider === 'boolean') { - return this.reduce(provider); + if (typeof guard === 'boolean') { + return this.reduce(guard); } - const r = provider({ user: this.user }, db); + const r = guard({ user: this.user }, db); return this.reduce(r); } @@ -297,19 +291,19 @@ export class PolicyUtil extends QueryUtils { * Get field-level update auth guard */ getFieldUpdateAuthGuard(db: CrudContract, model: string, field: string) { - const guard = this.requireGuard(model); + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.update?.guard?.[field]; - const provider = guard[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field}`]; - if (provider === undefined) { + if (guard === undefined) { // field access is allowed by default return this.makeTrue(); } - if (typeof provider === 'boolean') { - return this.reduce(provider); + if (typeof guard === 'boolean') { + return this.reduce(guard); } - const r = provider({ user: this.user }, db); + const r = guard({ user: this.user }, db); return this.reduce(r); } @@ -317,19 +311,19 @@ export class PolicyUtil extends QueryUtils { * Get field-level update auth guard that overrides the model-level */ getFieldOverrideUpdateAuthGuard(db: CrudContract, model: string, field: string) { - const guard = this.requireGuard(model); + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.update?.overrideGuard?.[field]; - const provider = guard[`${FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${field}`]; - if (provider === undefined) { + if (guard === undefined) { // field access is denied by default in override mode return this.makeFalse(); } - if (typeof provider === 'boolean') { - return this.reduce(provider); + if (typeof guard === 'boolean') { + return this.reduce(guard); } - const r = provider({ user: this.user }, db); + const r = guard({ user: this.user }, db); return this.reduce(r); } @@ -337,27 +331,20 @@ export class PolicyUtil extends QueryUtils { * Checks if the given model has a policy guard for the given operation. */ hasAuthGuard(model: string, operation: PolicyOperationKind) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - return false; - } - const provider = guard[operation]; - return typeof provider !== 'boolean' || provider !== true; + const guard = this.getModelGuardForOperation(model, operation); + return typeof guard !== 'boolean' || guard !== true; } /** * Checks if the given model has any field-level override policy guard for the given operation. */ hasOverrideAuthGuard(model: string, operation: PolicyOperationKind) { - const guard = this.requireGuard(model); - switch (operation) { - case 'read': - return Object.keys(guard).some((k) => k.startsWith(FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX)); - case 'update': - return Object.keys(guard).some((k) => k.startsWith(FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX)); - default: - return false; + if (operation !== 'read' && operation !== 'update') { + return false; } + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.[operation]?.overrideGuard; + return guard && Object.keys(guard).length > 0; } /** @@ -366,22 +353,18 @@ export class PolicyUtil extends QueryUtils { * @returns boolean if static analysis is enough to determine the result, undefined if not */ checkInputGuard(model: string, args: any, operation: 'create'): boolean | undefined { - const guard = this.getModelAuthGuard(model); - if (!guard) { - return undefined; - } + const def = this.getModelPolicyDef(model); - const provider: InputCheckFunc | boolean | undefined = guard[`${operation}_input` as const]; - - if (typeof provider === 'boolean') { - return provider; + const guard = def.modelLevel[operation].inputChecker; + if (guard === undefined) { + return undefined; } - if (!provider) { - return undefined; + if (typeof guard === 'boolean') { + return guard; } - return provider(args, { user: this.user }); + return guard(args, { user: this.user }); } /** @@ -569,34 +552,29 @@ export class PolicyUtil extends QueryUtils { * Gets checker constraints for the given model and operation. */ getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { - const checker = this.getModelChecker(model); - const provider = checker[operation]; - if (typeof provider === 'boolean') { - return provider; + if (this.options.kinds && !this.options.kinds.includes('policy')) { + // policy enhancement not enabled, return a constant true checker result + return true; } - if (typeof provider !== 'function') { - throw this.unknownError(`invalid ${operation} checker function for ${model}`); + const def = this.getModelPolicyDef(model); + const checker = def.modelLevel[operation].permissionChecker; + if (checker === undefined) { + throw new Error( + `Generated permission checkers not found. Please make sure the "generatePermissionChecker" option is set to true in the "@core/enhancer" plugin.` + ); } - // call checker function - return provider({ user: this.user }); - } + if (typeof checker === 'boolean') { + return checker; + } - private getModelChecker(model: string) { - if (this.options.kinds && !this.options.kinds.includes('policy')) { - // policy enhancement not enabled, return a constant true checker - return { create: true, read: true, update: true, delete: true }; - } else { - const result = this.options.policy.checker?.[lowerCaseFirst(model)]; - if (!result) { - // checker generation not enabled, return constant false checker - throw new Error( - `Generated permission checkers not found. Please make sure the "generatePermissionChecker" option is set to true in the "@core/enhancer" plugin.` - ); - } - return result; + if (typeof checker !== 'function') { + throw this.unknownError(`invalid ${operation} checker function for ${model}`); } + + // call checker function + return checker({ user: this.user }); } //#endregion @@ -974,7 +952,7 @@ export class PolicyUtil extends QueryUtils { if (this.hasFieldLevelPolicy(model)) { // recursively inject selection for fields needed for field-level read checks - const readFieldSelect = this.getReadFieldSelect(model); + const readFieldSelect = this.getFieldReadCheckSelector(model); if (readFieldSelect) { this.doInjectReadCheckSelect(model, args, { select: readFieldSelect }); } @@ -1091,32 +1069,24 @@ export class PolicyUtil extends QueryUtils { /** * Gets field selection for fetching pre-update entity values for the given model. */ - getPreValueSelect(model: string): object | undefined { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - return guard[PRE_UPDATE_VALUE_SELECTOR]; + getPreValueSelect(model: string) { + const def = this.getModelPolicyDef(model); + return def.modelLevel.postUpdate.preUpdateSelector; } - private getReadFieldSelect(model: string): object | undefined { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - return guard[FIELD_LEVEL_READ_CHECKER_SELECTOR]; + private getFieldReadCheckSelector(model: string) { + const def = this.getModelPolicyDef(model); + return def.fieldLevel?.read?.selector; } private checkReadField(model: string, field: string, entity: any) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - const func = guard[`${FIELD_LEVEL_READ_CHECKER_PREFIX}${field}`] as ReadFieldCheckFunc | undefined; - if (!func) { + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.read?.checker?.[field]; + + if (guard === undefined) { return true; } else { - return func(entity, { user: this.user }); + return guard(entity, { user: this.user }); } } @@ -1125,11 +1095,8 @@ export class PolicyUtil extends QueryUtils { } private hasFieldLevelPolicy(model: string) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - return !!guard[HAS_FIELD_LEVEL_POLICY_FLAG]; + const def = this.getModelPolicyDef(model); + return !!def.fieldLevel?.read?.checker; } /** @@ -1305,14 +1272,6 @@ export class PolicyUtil extends QueryUtils { } } - private requireGuard(model: string) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - return guard; - } - /** * Given an entity data, returns an object only containing id fields. */ diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 89d5ce9f6..aa14555b8 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -1,15 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from 'zod'; -import { - FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, - FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX, - FIELD_LEVEL_READ_CHECKER_PREFIX, - FIELD_LEVEL_READ_CHECKER_SELECTOR, - FIELD_LEVEL_UPDATE_GUARD_PREFIX, - HAS_FIELD_LEVEL_POLICY_FLAG, - PRE_UPDATE_VALUE_SELECTOR, -} from '../constants'; -import type { CheckerContext, CrudContract, PolicyCrudKind, PolicyOperationKind, QueryContext } from '../types'; +import type { CheckerContext, CrudContract, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -98,31 +89,8 @@ export type ReadFieldCheckFunc = (input: any, context: QueryContext) => boolean; * Policy definition */ export type PolicyDef = { - // Prisma query guards - guard: Record< - string, - // policy operation guard functions - Partial> & - // 'create_input' checker function - Partial> & - // field-level read checker functions or update guard functions - Record<`${typeof FIELD_LEVEL_READ_CHECKER_PREFIX}${string}`, ReadFieldCheckFunc> & - Record< - | `${typeof FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${string}` - | `${typeof FIELD_LEVEL_UPDATE_GUARD_PREFIX}${string}` - | `${typeof FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${string}`, - PolicyFunc - > & { - // pre-update value selector - [PRE_UPDATE_VALUE_SELECTOR]?: object; - // field-level read checker selector - [FIELD_LEVEL_READ_CHECKER_SELECTOR]?: object; - // flag that indicates if the model has field-level access control - [HAS_FIELD_LEVEL_POLICY_FLAG]?: boolean; - } - >; - - checker?: Record>; + // policy definitions for each model + policy: Record; // tracks which models have data validation rules validation: Record; @@ -131,10 +99,135 @@ export type PolicyDef = { authSelector?: object; }; +type ModelName = string; +type FieldName = string; + +/** + * Policy definition for a model + */ +export type ModelPolicyDef = { + /** + * Model-level CRUD policies + */ + modelLevel: ModelCrudDef; + + /** + * Field-level CRUD policies + */ + fieldLevel?: FieldCrudDef; +}; + +/** + * CRUD policy definitions for a model + */ +export type ModelCrudDef = { + read: ModelReadDef; + create: ModelCreateDef; + update: ModelUpdateDef; + delete: ModelDeleteDef; + postUpdate: ModelPostUpdateDef; +}; + +/** + * Common policy definition for a CRUD operation + */ +type ModelCrudCommon = { + /** + * Prisma query guard or a constant condition + */ + guard: PolicyFunc | boolean; + + /** + * Permission checker function or a constant condition + */ + permissionChecker?: CheckerFunc | boolean; +}; + +/** + * Policy definition for reading a model + */ +type ModelReadDef = ModelCrudCommon; + +/** + * Policy definition for creating a model + */ +type ModelCreateDef = ModelCrudCommon & { + /** + * Create input validation function. Only generated when a create + * can be approved or denied based on input values. + */ + inputChecker?: InputCheckFunc | boolean; +}; + +/** + * Policy definition for updating a model + */ +type ModelUpdateDef = ModelCrudCommon; + +/** + * Policy definition for deleting a model + */ +type ModelDeleteDef = ModelCrudCommon; + +/** + * Policy definition for post-update checking a model + */ +type ModelPostUpdateDef = { + guard: PolicyFunc | boolean; + preUpdateSelector?: object; +}; + +/** + * CRUD policy definitions for a field + */ +type FieldCrudDef = { + /** + * Field-level read policy + */ + read?: { + /** + * Selector for reading fields needed for evaluating the policy + */ + selector?: object; + + /** + * Field-level Prisma query guard + */ + checker?: Record; + + /** + * Field-level read override Prisma query guard + */ + overrideGuard?: Record; + }; + + /** + * Field-level update policy + */ + update?: { + /** + * Field-level update Prisma query guard + */ + guard?: Record; + + /** + * Field-level update override Prisma query guard + */ + overrideGuard?: Record; + }; +}; + /** * Zod schemas for validation */ export type ZodSchemas = { + /** + * Zod schema for each model + */ models: Record; + + /** + * Zod schema for Prisma input types for each model + */ input?: Record>; }; diff --git a/packages/schema/src/plugins/enhancer/policy/index.ts b/packages/schema/src/plugins/enhancer/policy/index.ts index 8eaf1d00b..918bfba8c 100644 --- a/packages/schema/src/plugins/enhancer/policy/index.ts +++ b/packages/schema/src/plugins/enhancer/policy/index.ts @@ -4,5 +4,5 @@ import type { Project } from 'ts-morph'; import { PolicyGenerator } from './policy-guard-generator'; export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { - return new PolicyGenerator().generate(project, model, options, outDir); + return new PolicyGenerator(options).generate(project, model, outDir); } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index a36a52126..619543e44 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -1,111 +1,57 @@ import { DataModel, DataModelField, - Enum, Expression, Model, isDataModel, isDataModelField, isEnum, - isExpression, - isInvocationExpr, isMemberAccessExpr, isReferenceExpr, isThisExpr, } from '@zenstackhq/language/ast'; -import { - FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, - FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX, - FIELD_LEVEL_READ_CHECKER_PREFIX, - FIELD_LEVEL_READ_CHECKER_SELECTOR, - FIELD_LEVEL_UPDATE_GUARD_PREFIX, - HAS_FIELD_LEVEL_POLICY_FLAG, - PRE_UPDATE_VALUE_SELECTOR, - type PolicyKind, - type PolicyOperationKind, -} from '@zenstackhq/runtime'; +import { PolicyCrudKind, type PolicyOperationKind } from '@zenstackhq/runtime'; import { ExpressionContext, - PluginError, PluginOptions, + PolicyAnalysisResult, RUNTIME_PACKAGE, TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, analyzePolicies, - getAttributeArg, - getAuthModel, getDataModels, - getIdFields, - getLiteral, hasAttribute, hasValidationAttributes, isAuthInvocation, - isEnumFieldReference, isForeignKeyField, - isFromStdlib, - isFutureExpr, - resolved, } from '@zenstackhq/sdk'; import { getPrismaClientImportSpec } from '@zenstackhq/sdk/prisma'; -import { streamAllContents, streamAst, streamContents } from 'langium'; +import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; -import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; -import { name } from '..'; -import { isCollectionPredicate } from '../../../utils/ast-utils'; -import { ALL_OPERATION_KINDS } from '../../plugin-utils'; +import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; import { ConstraintTransformer } from './constraint-transformer'; -import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; +import { + generateNormalizedAuthRef, + generateQueryGuardFunction, + generateSelectForRules, + getPolicyExpressions, + isEnumReferenced, +} from './utils'; /** * Generates source file that contains Prisma query guard objects used for injecting database queries */ export class PolicyGenerator { - async generate(project: Project, model: Model, options: PluginOptions, output: string) { + constructor(private options: PluginOptions) {} + + async generate(project: Project, model: Model, output: string) { const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true }); sf.addStatements('/* eslint-disable */'); - sf.addImportDeclaration({ - namedImports: [ - { name: 'type QueryContext' }, - { name: 'type CrudContract' }, - { name: 'allFieldsEqual' }, - { name: 'type PolicyDef' }, - { name: 'type CheckerContext' }, - { name: 'type CheckerConstraint' }, - ], - moduleSpecifier: `${RUNTIME_PACKAGE}`, - }); - - // import enums - const prismaImport = getPrismaClientImportSpec(output, options); - for (const e of model.declarations.filter((d) => isEnum(d) && this.isEnumReferenced(model, d))) { - sf.addImportDeclaration({ - namedImports: [{ name: e.name }], - moduleSpecifier: prismaImport, - }); - } + this.writeImports(model, output, sf); const models = getDataModels(model); - // policy guard functions - const policyMap: Record> = {}; - for (const model of models) { - policyMap[model.name] = await this.generateQueryGuardForModel(model, sf); - } - - const generatePermissionChecker = options.generatePermissionChecker === true; - - // CRUD checker functions - const checkerMap: Record> = {}; - if (generatePermissionChecker) { - for (const model of models) { - checkerMap[model.name] = await this.generateCheckerForModel(model, sf); - } - } - - const authSelector = this.generateAuthSelector(models); - sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, declarations: [ @@ -114,55 +60,9 @@ export class PolicyGenerator { type: 'PolicyDef', initializer: (writer) => { writer.block(() => { - writer.write('guard:'); - writer.inlineBlock(() => { - for (const [model, map] of Object.entries(policyMap)) { - writer.write(`${lowerCaseFirst(model)}:`); - writer.inlineBlock(() => { - for (const [op, func] of Object.entries(map)) { - if (typeof func === 'object') { - writer.write(`${op}: ${JSON.stringify(func)},`); - } else { - writer.write(`${op}: ${func},`); - } - } - }); - writer.write(','); - } - }); - writer.writeLine(','); - - if (generatePermissionChecker) { - writer.write('checker:'); - writer.inlineBlock(() => { - for (const [model, map] of Object.entries(checkerMap)) { - writer.write(`${lowerCaseFirst(model)}:`); - writer.inlineBlock(() => { - Object.entries(map).forEach(([op, func]) => { - writer.write(`${op}: ${func},`); - }); - }); - writer.writeLine(','); - } - }); - writer.writeLine(','); - } - - writer.write('validation:'); - writer.inlineBlock(() => { - for (const model of models) { - writer.write(`${lowerCaseFirst(model.name)}:`); - writer.inlineBlock(() => { - writer.write(`hasValidation: ${hasValidationAttributes(model)}`); - }); - writer.writeLine(','); - } - }); - - if (authSelector) { - writer.writeLine(','); - writer.write(`authSelector: ${JSON.stringify(authSelector)}`); - } + this.writePolicy(writer, models, sf); + this.writeValidationMeta(writer, models); + this.writeAuthSelector(models, writer); }); }, }, @@ -172,242 +72,172 @@ export class PolicyGenerator { sf.addStatements('export default policy'); // save ts files if requested explicitly or the user provided - const preserveTsFiles = options.preserveTsFiles === true || !!options.output; + const preserveTsFiles = this.options.preserveTsFiles === true || !!this.options.output; if (preserveTsFiles) { await sf.save(); } } - // Generates a { select: ... } object to select `auth()` fields used in policy rules - private generateAuthSelector(models: DataModel[]) { - const authRules: Expression[] = []; + private writeImports(model: Model, output: string, sf: SourceFile) { + sf.addImportDeclaration({ + namedImports: [ + { name: 'type QueryContext' }, + { name: 'type CrudContract' }, + { name: 'allFieldsEqual' }, + { name: 'type PolicyDef' }, + { name: 'type CheckerContext' }, + { name: 'type CheckerConstraint' }, + ], + moduleSpecifier: `${RUNTIME_PACKAGE}`, + }); - models.forEach((model) => { - // model-level rules - const modelPolicyAttrs = model.attributes.filter((attr) => - ['@@allow', '@@deny'].includes(attr.decl.$refText) - ); + // import enums + const prismaImport = getPrismaClientImportSpec(output, this.options); + for (const e of model.declarations.filter((d) => isEnum(d) && isEnumReferenced(model, d))) { + sf.addImportDeclaration({ + namedImports: [{ name: e.name }], + moduleSpecifier: prismaImport, + }); + } + } - // field-level rules - const fieldPolicyAttrs = model.fields - .flatMap((f) => f.attributes) - .filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText)); + private writePolicy(writer: CodeBlockWriter, models: DataModel[], sourceFile: SourceFile) { + writer.write('policy:'); + writer.inlineBlock(() => { + for (const model of models) { + writer.write(`${lowerCaseFirst(model.name)}:`); - // all rule expression - const allExpressions = [...modelPolicyAttrs, ...fieldPolicyAttrs] - .filter((attr) => attr.args.length > 1) - .map((attr) => attr.args[1].value); + writer.block(() => { + // model-level guards + this.writeModelLevelDefs(model, writer, sourceFile); - // collect `auth()` member access - allExpressions.forEach((rule) => { - streamAst(rule).forEach((node) => { - if (isMemberAccessExpr(node) && isAuthInvocation(node.operand)) { - authRules.push(node); - } + // field-level guards + this.writeFieldLevelDefs(model, writer, sourceFile); }); - }); - }); - - if (authRules.length > 0) { - return this.generateSelectForRules(authRules, true); - } else { - return undefined; - } - } - private isEnumReferenced(model: Model, decl: Enum): unknown { - return streamAllContents(model).some((node) => { - if (isDataModelField(node) && node.type.reference?.ref === decl) { - // referenced as field type - return true; + writer.writeLine(','); } - if (isEnumFieldReference(node) && node.target.ref?.$container === decl) { - // enum field is referenced - return true; - } - return false; }); + writer.writeLine(','); } - private getPolicyExpressions( - target: DataModel | DataModelField, - kind: PolicyKind, - operation: PolicyOperationKind, - override = false - ) { - const attributes = target.attributes; - const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; - const attrs = attributes.filter((attr) => { - if (attr.decl.ref?.name !== attrName) { - return false; - } + // #region Model-level definitions - if (override) { - const overrideArg = getAttributeArg(attr, 'override'); - return overrideArg && getLiteral(overrideArg) === true; - } else { - return true; - } + // writes model-level policy def for each operation kind for a model + // `[modelName]: { [operationKind]: [funcName] },` + private writeModelLevelDefs(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + const policies = analyzePolicies(model); + writer.write('modelLevel:'); + writer.inlineBlock(() => { + this.writeModelReadDef(model, policies, writer, sourceFile); + this.writeModelCreateDef(model, policies, writer, sourceFile); + this.writeModelUpdateDef(model, policies, writer, sourceFile); + this.writeModelPostUpdateDef(model, policies, writer, sourceFile); + this.writeModelDeleteDef(model, policies, writer, sourceFile); }); + writer.writeLine(','); + } - const checkOperation = operation === 'postUpdate' ? 'update' : operation; - - let result = attrs - .filter((attr) => { - const opsValue = getLiteral(attr.args[0].value); - if (!opsValue) { - return false; - } - const ops = opsValue.split(',').map((s) => s.trim()); - return ops.includes(checkOperation) || ops.includes('all'); - }) - .map((attr) => attr.args[1].value); - - if (operation === 'update') { - result = this.processUpdatePolicies(result, false); - } else if (operation === 'postUpdate') { - result = this.processUpdatePolicies(result, true); - } - - return result; + // writes `read: ...` for a given model + private writeModelReadDef( + model: DataModel, + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + writer.write(`read:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'read', policies, writer, sourceFile); + }); + writer.writeLine(','); } - private processUpdatePolicies(expressions: Expression[], postUpdate: boolean) { - const hasFutureReference = expressions.some((expr) => this.hasFutureReference(expr)); - if (postUpdate) { - // when compiling post-update rules, if any rule contains `future()` reference, - // we include all as post-update rules - return hasFutureReference ? expressions : []; - } else { - // when compiling pre-update rules, if any rule contains `future()` reference, - // we completely skip pre-update check and defer them to post-update - return hasFutureReference ? [] : expressions; - } + // writes `create: ...` for a given model + private writeModelCreateDef( + model: DataModel, + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + writer.write(`create:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'create', policies, writer, sourceFile); + + // create policy has an additional input checker for validating the payload + this.writeCreateInputChecker(model, writer, sourceFile); + }); + writer.writeLine(','); } - private hasFutureReference(expr: Expression) { - for (const node of streamAst(expr)) { - if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { - return true; - } + // writes `inputChecker: [funcName]` for a given model + private writeCreateInputChecker(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + const allows = getPolicyExpressions(model, 'allow', 'create'); + const denies = getPolicyExpressions(model, 'deny', 'create'); + if (this.canCheckCreateBasedOnInput(model, allows, denies)) { + const inputCheckFunc = this.generateCreateInputCheckerFunction(model, allows, denies, sourceFile); + writer.write(`inputChecker: ${inputCheckFunc.getName()!},`); } - return false; } - private async generateQueryGuardForModel(model: DataModel, sourceFile: SourceFile) { - const result: Record = {}; - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const policies: any = analyzePolicies(model); - - for (const kind of ALL_OPERATION_KINDS) { - if (policies[kind] === true || policies[kind] === false) { - result[kind] = policies[kind]; - if (kind === 'create') { - result[kind + '_input'] = policies[kind]; - } - continue; - } - - const denies = this.getPolicyExpressions(model, 'deny', kind); - const allows = this.getPolicyExpressions(model, 'allow', kind); - - if (kind === 'update' && allows.length === 0) { - // no allow rule for 'update', policy is constant based on if there's - // post-update counterpart - if (this.getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { - result[kind] = false; - continue; - } else { - result[kind] = true; - continue; + private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { + return [...allows, ...denies].every((rule) => { + return streamAst(rule).every((expr) => { + if (isThisExpr(expr)) { + return false; } - } - - if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) { - // no rule 'postUpdate', always allow - result[kind] = true; - continue; - } + if (isReferenceExpr(expr)) { + if (isDataModel(expr.$resolvedType?.decl)) { + // if policy rules uses relation fields, + // we can't check based on create input + return false; + } - const guardFunc = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies); - result[kind] = guardFunc.getName()!; + if ( + isDataModelField(expr.target.ref) && + expr.target.ref.$container === model && + hasAttribute(expr.target.ref, '@default') + ) { + // reference to field of current model + // if it has default value, we can't check + // based on create input + return false; + } - if (kind === 'postUpdate') { - const preValueSelect = this.generateSelectForRules([...allows, ...denies]); - if (preValueSelect) { - result[PRE_UPDATE_VALUE_SELECTOR] = preValueSelect; + if (isDataModelField(expr.target.ref) && isForeignKeyField(expr.target.ref)) { + // reference to foreign key field + // we can't check based on create input + return false; + } } - } - - if (kind === 'create' && this.canCheckCreateBasedOnInput(model, allows, denies)) { - const inputCheckFunc = this.generateInputCheckFunction(sourceFile, model, kind, allows, denies); - result[kind + '_input'] = inputCheckFunc.getName()!; - } - } - - // generate field read checkers - this.generateReadFieldsCheckers(model, sourceFile, result); - - // generate field read override guards - this.generateReadFieldsOverrideGuards(model, sourceFile, result); - // generate field update guards - this.generateUpdateFieldsGuards(model, sourceFile, result); - - return result; + return true; + }); + }); } - private generateReadFieldsCheckers( + // generates a function for checking "create" input + private generateCreateInputCheckerFunction( model: DataModel, - sourceFile: SourceFile, - result: Record - ) { - const allFieldsAllows: Expression[] = []; - const allFieldsDenies: Expression[] = []; - - for (const field of model.fields) { - const allows = this.getPolicyExpressions(field, 'allow', 'read'); - const denies = this.getPolicyExpressions(field, 'deny', 'read'); - if (denies.length === 0 && allows.length === 0) { - continue; - } - - allFieldsAllows.push(...allows); - allFieldsDenies.push(...denies); - - const guardFunc = this.generateReadFieldCheckerFunction(sourceFile, field, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[`${FIELD_LEVEL_READ_CHECKER_PREFIX}${field.name}`] = guardFunc.getName()!; - } - - if (allFieldsAllows.length > 0 || allFieldsDenies.length > 0) { - result[HAS_FIELD_LEVEL_POLICY_FLAG] = true; - const readFieldCheckSelect = this.generateSelectForRules([...allFieldsAllows, ...allFieldsDenies]); - if (readFieldCheckSelect) { - result[FIELD_LEVEL_READ_CHECKER_SELECTOR] = readFieldCheckSelect; - } - } - } - - private generateReadFieldCheckerFunction( - sourceFile: SourceFile, - field: DataModelField, allows: Expression[], - denies: Expression[] + denies: Expression[], + sourceFile: SourceFile ) { const statements: (string | WriterFunction)[] = []; - this.generateNormalizedAuthRef(field.$container as DataModel, allows, denies, statements); + generateNormalizedAuthRef(model, allows, denies, statements); - // compile rules down to typescript expressions statements.push((writer) => { + if (allows.length === 0) { + writer.write('return false;'); + return; + } + const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, fieldReferenceContext: 'input', }); - const denyStmt = + let expr = denies.length > 0 ? '!(' + denies @@ -418,34 +248,18 @@ export class PolicyGenerator { ')' : undefined; - const allowStmt = - allows.length > 0 - ? '(' + - allows - .map((allow) => { - return transformer.transform(allow); - }) - .join(' || ') + - ')' - : undefined; - - let expr: string | undefined; - - if (denyStmt && allowStmt) { - expr = `${denyStmt} && ${allowStmt}`; - } else if (denyStmt) { - expr = denyStmt; - } else if (allowStmt) { - expr = allowStmt; - } else { - throw new Error('should not happen'); - } + const allowStmt = allows + .map((allow) => { + return transformer.transform(allow); + }) + .join(' || '); + expr = expr ? `${expr} && (${allowStmt})` : allowStmt; writer.write('return ' + expr); }); const func = sourceFile.addFunction({ - name: `${field.$container.name}$${field.name}_read`, + name: model.name + '_create_input', returnType: 'boolean', parameters: [ { @@ -463,323 +277,177 @@ export class PolicyGenerator { return func; } - private generateReadFieldsOverrideGuards( + // writes `update: ...` for a given model + private writeModelUpdateDef( model: DataModel, - sourceFile: SourceFile, - result: Record + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile ) { - for (const field of model.fields) { - const overrideAllows = this.getPolicyExpressions(field, 'allow', 'read', true); - if (overrideAllows.length > 0) { - const denies = this.getPolicyExpressions(field, 'deny', 'read'); - const overrideGuardFunc = this.generateQueryGuardFunction( - sourceFile, - model, - 'read', - overrideAllows, - denies, - field, - true - ); - result[`${FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${field.name}`] = overrideGuardFunc.getName()!; - } + writer.write(`update:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'update', policies, writer, sourceFile); + }); + writer.writeLine(','); + } + + // writes `postUpdate: ...` for a given model + private writeModelPostUpdateDef( + model: DataModel, + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + writer.write(`postUpdate:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'postUpdate', policies, writer, sourceFile); + + // post-update policy has an additional selector for reading the pre-update entity data + this.writePostUpdatePreValueSelector(model, writer); + }); + writer.writeLine(','); + } + + private writePostUpdatePreValueSelector(model: DataModel, writer: CodeBlockWriter) { + const allows = getPolicyExpressions(model, 'allow', 'postUpdate'); + const denies = getPolicyExpressions(model, 'deny', 'postUpdate'); + const preValueSelect = generateSelectForRules([...allows, ...denies]); + if (preValueSelect) { + writer.writeLine(`preUpdateSelector: ${JSON.stringify(preValueSelect)},`); } } - private generateUpdateFieldsGuards( + // writes `delete: ...` for a given model + private writeModelDeleteDef( model: DataModel, - sourceFile: SourceFile, - result: Record + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile ) { - for (const field of model.fields) { - const allows = this.getPolicyExpressions(field, 'allow', 'update'); - const denies = this.getPolicyExpressions(field, 'deny', 'update'); + writer.write(`delete:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'delete', policies, writer, sourceFile); + }); + } - if (denies.length === 0 && allows.length === 0) { - continue; - } + // writes `[kind]: ...` for a given model + private writeCommonModelDef( + model: DataModel, + kind: PolicyOperationKind, + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + const allows = getPolicyExpressions(model, 'allow', kind); + const denies = getPolicyExpressions(model, 'deny', kind); - const guardFunc = this.generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field.name}`] = guardFunc.getName()!; + // policy guard + this.writePolicyGuard(model, kind, policies, allows, denies, writer, sourceFile); - const overrideAllows = this.getPolicyExpressions(field, 'allow', 'update', true); - if (overrideAllows.length > 0) { - const overrideGuardFunc = this.generateQueryGuardFunction( - sourceFile, - model, - 'update', - overrideAllows, - denies, - field, - true - ); - result[`${FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${field.name}`] = overrideGuardFunc.getName()!; - } + // permission checker + if (kind !== 'postUpdate') { + this.writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile); } } - private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { - return [...allows, ...denies].every((rule) => { - return streamAst(rule).every((expr) => { - if (isThisExpr(expr)) { - return false; - } - if (isReferenceExpr(expr)) { - if (isDataModel(expr.$resolvedType?.decl)) { - // if policy rules uses relation fields, - // we can't check based on create input - return false; - } - - if ( - isDataModelField(expr.target.ref) && - expr.target.ref.$container === model && - hasAttribute(expr.target.ref, '@default') - ) { - // reference to field of current model - // if it has default value, we can't check - // based on create input - return false; - } - - if (isDataModelField(expr.target.ref) && isForeignKeyField(expr.target.ref)) { - // reference to foreign key field - // we can't check based on create input - return false; - } - } - - return true; - }); - }); - } - - // generates a "select" object that contains (recursively) fields referenced by the - // given policy rules - private generateSelectForRules(rules: Expression[], forAuthContext = false): object { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const result: any = {}; - const addPath = (path: string[]) => { - let curr = result; - path.forEach((seg, i) => { - if (i === path.length - 1) { - curr[seg] = true; - } else { - if (!curr[seg]) { - curr[seg] = { select: {} }; - } - curr = curr[seg].select; - } - }); - }; - - // visit a reference or member access expression to build a - // selection path - const visit = (node: Expression): string[] | undefined => { - if (isThisExpr(node)) { - return []; + // writes `guard: ...` for a given policy operation kind + private writePolicyGuard( + model: DataModel, + kind: PolicyOperationKind, + policies: ReturnType, + allows: Expression[], + denies: Expression[], + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + if (kind === 'update' && allows.length === 0) { + // no allow rule for 'update', policy is constant based on if there's + // post-update counterpart + if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { + writer.write(`guard: false,`); + } else { + writer.write(`guard: true,`); } + return; + } - if (isReferenceExpr(node)) { - const target = resolved(node.target); - if (isDataModelField(target)) { - // a field selection, it's a terminal - return [target.name]; - } - } + if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) { + // no 'postUpdate' rule, always allow + writer.write(`guard: true,`); + return; + } - if (isMemberAccessExpr(node)) { - if (forAuthContext && isAuthInvocation(node.operand)) { - return [node.member.$refText]; - } + if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') { + // constant policy + writer.write(`guard: ${policies[kind as keyof typeof policies]},`); + return; + } - if (isFutureExpr(node.operand)) { - // future().field is not subject to pre-update select - return undefined; - } + // generate a policy function that evaluates a partial prisma query + const guardFunc = generateQueryGuardFunction(sourceFile, model, kind, allows, denies); + writer.write(`guard: ${guardFunc.getName()!},`); + } - // build a selection path inside-out for chained member access - const inner = visit(node.operand); - if (inner) { - return [...inner, node.member.$refText]; - } - } + // writes `permissionChecker: ...` for a given policy operation kind + private writePermissionChecker( + model: DataModel, + kind: PolicyCrudKind, + policies: PolicyAnalysisResult, + allows: Expression[], + denies: Expression[], + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + if (this.options.generatePermissionChecker !== true) { + return; + } - return undefined; - }; - - // collect selection paths from the given expression - const collectReferencePaths = (expr: Expression): string[][] => { - if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) { - // a standalone `this` expression, include all id fields - const model = expr.$resolvedType?.decl as DataModel; - const idFields = getIdFields(model); - return idFields.map((field) => [field.name]); - } + if (policies[kind] === true || policies[kind] === false) { + // constant policy + writer.write(`permissionChecker: ${policies[kind]},`); + return; + } - if (isMemberAccessExpr(expr) || isReferenceExpr(expr)) { - const path = visit(expr); - if (path) { - if (isDataModel(expr.$resolvedType?.decl)) { - // member selection ended at a data model field, include its id fields - const idFields = getIdFields(expr.$resolvedType?.decl as DataModel); - return idFields.map((field) => [...path, field.name]); - } else { - return [path]; - } - } else { - return []; - } - } else if (isCollectionPredicate(expr)) { - const path = visit(expr.left); - if (path) { - // recurse into RHS - const rhs = collectReferencePaths(expr.right); - // combine path of LHS and RHS - return rhs.map((r) => [...path, ...r]); - } else { - return []; - } - } else if (isInvocationExpr(expr)) { - // recurse into function arguments - return expr.args.flatMap((arg) => collectReferencePaths(arg.value)); + if (kind === 'update' && allows.length === 0) { + // no allow rule for 'update', policy is constant based on if there's + // post-update counterpart + if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { + writer.write(`permissionChecker: false,`); + return; } else { - // recurse - const children = streamContents(expr) - .filter((child): child is Expression => isExpression(child)) - .toArray(); - return children.flatMap((child) => collectReferencePaths(child)); + writer.write(`permissionChecker: true,`); + return; } - }; - - for (const rule of rules) { - const paths = collectReferencePaths(rule); - paths.forEach((p) => addPath(p)); } - return Object.keys(result).length === 0 ? undefined : result; + const guardFunc = this.generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile); + writer.write(`permissionChecker: ${guardFunc.getName()!},`); } - private generateQueryGuardFunction( - sourceFile: SourceFile, + private generatePermissionCheckerFunction( model: DataModel, - kind: PolicyOperationKind, + kind: string, allows: Expression[], denies: Expression[], - forField?: DataModelField, - fieldOverride = false + sourceFile: SourceFile ) { - const statements: (string | WriterFunction)[] = []; + const statements: string[] = []; - this.generateNormalizedAuthRef(model, allows, denies, statements); - - const hasFieldAccess = [...denies, ...allows].some((rule) => - streamAst(rule).some( - (child) => - // this.??? - isThisExpr(child) || - // future().??? - isFutureExpr(child) || - // field reference - (isReferenceExpr(child) && isDataModelField(child.target.ref)) - ) - ); - - if (!hasFieldAccess) { - // none of the rules reference model fields, we can compile down to a plain boolean - // function in this case (so we can skip doing SQL queries when validating) - statements.push((writer) => { - const transformer = new TypeScriptExpressionTransformer({ - context: ExpressionContext.AccessPolicy, - isPostGuard: kind === 'postUpdate', - }); - try { - denies.forEach((rule) => { - writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); - }); - allows.forEach((rule) => { - writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); - }); - } catch (err) { - if (err instanceof TypeScriptExpressionTransformerError) { - throw new PluginError(name, err.message); - } else { - throw err; - } - } + generateNormalizedAuthRef(model, allows, denies, statements); - if (forField) { - if (allows.length === 0) { - // if there's no allow rule, for field-level rules, by default we allow - writer.write(`return ${TRUE};`); - } else { - // if there's any allow rule, we deny unless any allow rule evaluates to true - writer.write(`return ${FALSE};`); - } - } else { - // for model-level rules, the default is always deny - writer.write(`return ${FALSE};`); - } - }); - } else { - statements.push((writer) => { - writer.write('return '); - const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); - const writeDenies = () => { - writer.conditionalWrite(denies.length > 1, '{ AND: ['); - denies.forEach((expr, i) => { - writer.inlineBlock(() => { - writer.write('NOT: '); - exprWriter.write(expr); - }); - writer.conditionalWrite(i !== denies.length - 1, ','); - }); - writer.conditionalWrite(denies.length > 1, ']}'); - }; - - const writeAllows = () => { - writer.conditionalWrite(allows.length > 1, '{ OR: ['); - allows.forEach((expr, i) => { - exprWriter.write(expr); - writer.conditionalWrite(i !== allows.length - 1, ','); - }); - writer.conditionalWrite(allows.length > 1, ']}'); - }; - - if (allows.length > 0 && denies.length > 0) { - // include both allow and deny rules - writer.write('{ AND: ['); - writeDenies(); - writer.write(','); - writeAllows(); - writer.write(']}'); - } else if (denies.length > 0) { - // only deny rules - writeDenies(); - } else if (allows.length > 0) { - // only allow rules - writeAllows(); - } else { - // disallow any operation - writer.write(`{ OR: [] }`); - } - writer.write(';'); - }); - } + const transformed = new ConstraintTransformer({ + authAccessor: 'user', + }).transformRules(allows, denies); + + statements.push(`return ${transformed};`); const func = sourceFile.addFunction({ - name: `${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, - returnType: 'any', + name: `${model.name}$checker$${kind}`, + returnType: 'CheckerConstraint', parameters: [ { name: 'context', - type: 'QueryContext', - }, - { - // for generating field references used by field comparison in the same model - name: 'db', - type: 'CrudContract', + type: 'CheckerContext', }, ], statements, @@ -788,29 +456,166 @@ export class PolicyGenerator { return func; } - private generateInputCheckFunction( + // #endregion + + // #region Field-level definitions + + private writeFieldLevelDefs(model: DataModel, writer: CodeBlockWriter, sf: SourceFile) { + writer.write('fieldLevel:'); + writer.inlineBlock(() => { + this.writeFieldReadDef(model, writer, sf); + this.writeFieldUpdateDef(model, writer, sf); + }); + writer.writeLine(','); + } + + private writeFieldReadDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + const fieldCheckers: Record = {}; + const overrideGuards: Record = {}; + const allFieldsAllows: Expression[] = []; + const allFieldsDenies: Expression[] = []; + + // generate field read checkers + for (const field of model.fields) { + const allows = getPolicyExpressions(field, 'allow', 'read'); + const denies = getPolicyExpressions(field, 'deny', 'read'); + if (denies.length === 0 && allows.length === 0) { + continue; + } + + allFieldsAllows.push(...allows); + allFieldsDenies.push(...denies); + + const guardFunc = this.generateFieldReadCheckerFunction(sourceFile, field, allows, denies); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + fieldCheckers[field.name] = guardFunc.getName()!; + + const overrideAllows = getPolicyExpressions(field, 'allow', 'read', true); + if (overrideAllows.length > 0) { + const denies = getPolicyExpressions(field, 'deny', 'read'); + const overrideGuardFunc = generateQueryGuardFunction( + sourceFile, + model, + 'read', + overrideAllows, + denies, + field, + true + ); + overrideGuards[field.name] = overrideGuardFunc.getName()!; + } + } + + if (Object.keys(fieldCheckers).length > 0 || Object.keys(overrideGuards).length > 0) { + writer.write('read:'); + writer.block(() => { + if (Object.keys(fieldCheckers).length > 0) { + writer.write('checker:'); + + // write checkers + writer.inlineBlock(() => { + Object.entries(fieldCheckers).forEach(([fieldName, funcName]) => { + writer.write(`${fieldName}: ${funcName},`); + }); + }); + writer.writeLine(','); + + // write field selector + const readFieldCheckSelect = generateSelectForRules([...allFieldsAllows, ...allFieldsDenies]); + if (readFieldCheckSelect) { + writer.write(`selector: ${JSON.stringify(readFieldCheckSelect)},`); + } + } + + if (Object.keys(overrideGuards).length > 0) { + // write override guards + writer.write('overrideGuard:'); + writer.inlineBlock(() => { + Object.entries(overrideGuards).forEach(([fieldName, funcName]) => { + writer.write(`${fieldName}: ${funcName},`); + }); + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); + } + } + + private writeFieldUpdateDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + const guards: Record = {}; + const overrideGuards: Record = {}; + + for (const field of model.fields) { + const allows = getPolicyExpressions(field, 'allow', 'update'); + const denies = getPolicyExpressions(field, 'deny', 'update'); + + if (denies.length === 0 && allows.length === 0) { + continue; + } + + const guardFunc = generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); + guards[field.name] = guardFunc.getName()!; + + const overrideAllows = getPolicyExpressions(field, 'allow', 'update', true); + if (overrideAllows.length > 0) { + const overrideGuardFunc = generateQueryGuardFunction( + sourceFile, + model, + 'update', + overrideAllows, + denies, + field, + true + ); + overrideGuards[field.name] = overrideGuardFunc.getName()!; + } + } + + if (Object.keys(guards).length > 0 || Object.keys(overrideGuards).length > 0) { + writer.write('update:'); + writer.block(() => { + if (Object.keys(guards).length > 0) { + writer.write('guard:'); + writer.inlineBlock(() => { + Object.entries(guards).forEach(([fieldName, funcName]) => { + writer.write(`${fieldName}: ${funcName},`); + }); + }); + writer.writeLine(','); + } + + if (Object.keys(overrideGuards).length > 0) { + writer.write('overrideGuard:'); + writer.inlineBlock(() => { + Object.entries(overrideGuards).forEach(([fieldName, funcName]) => { + writer.write(`${fieldName}: ${funcName},`); + }); + }); + writer.writeLine(','); + } + }); + } + } + + private generateFieldReadCheckerFunction( sourceFile: SourceFile, - model: DataModel, - kind: 'create' | 'update', + field: DataModelField, allows: Expression[], denies: Expression[] - ): FunctionDeclaration { + ) { const statements: (string | WriterFunction)[] = []; - this.generateNormalizedAuthRef(model, allows, denies, statements); + generateNormalizedAuthRef(field.$container as DataModel, allows, denies, statements); + // compile rules down to typescript expressions statements.push((writer) => { - if (allows.length === 0) { - writer.write('return false;'); - return; - } - const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, fieldReferenceContext: 'input', }); - let expr = + const denyStmt = denies.length > 0 ? '!(' + denies @@ -821,18 +626,34 @@ export class PolicyGenerator { ')' : undefined; - const allowStmt = allows - .map((allow) => { - return transformer.transform(allow); - }) - .join(' || '); + const allowStmt = + allows.length > 0 + ? '(' + + allows + .map((allow) => { + return transformer.transform(allow); + }) + .join(' || ') + + ')' + : undefined; + + let expr: string | undefined; + + if (denyStmt && allowStmt) { + expr = `${denyStmt} && ${allowStmt}`; + } else if (denyStmt) { + expr = denyStmt; + } else if (allowStmt) { + expr = allowStmt; + } else { + throw new Error('should not happen'); + } - expr = expr ? `${expr} && (${allowStmt})` : allowStmt; writer.write('return ' + expr); }); const func = sourceFile.addFunction({ - name: model.name + '_' + kind + '_input', + name: `${field.$container.name}$${field.name}_read`, returnType: 'boolean', parameters: [ { @@ -850,95 +671,71 @@ export class PolicyGenerator { return func; } - private generateNormalizedAuthRef( - model: DataModel, - allows: Expression[], - denies: Expression[], - statements: (string | WriterFunction)[] - ) { - // check if any allow or deny rule contains 'auth()' invocation - const hasAuthRef = [...allows, ...denies].some((rule) => - streamAst(rule).some((child) => isAuthInvocation(child)) - ); - - if (hasAuthRef) { - const authModel = getAuthModel(getDataModels(model.$container, true)); - if (!authModel) { - throw new PluginError(name, 'Auth model not found'); - } - const userIdFields = getIdFields(authModel); - if (!userIdFields || userIdFields.length === 0) { - throw new PluginError(name, 'User model does not have an id field'); - } + // #endregion + + //#region Auth selector - // normalize user to null to avoid accidentally use undefined in filter - statements.push(`const user: any = context.user ?? null;`); + private writeAuthSelector(models: DataModel[], writer: CodeBlockWriter) { + const authSelector = this.generateAuthSelector(models); + if (authSelector) { + writer.write(`authSelector: ${JSON.stringify(authSelector)},`); } } - private async generateCheckerForModel(model: DataModel, sourceFile: SourceFile) { - const result: Record = {}; + // Generates a { select: ... } object to select `auth()` fields used in policy rules + private generateAuthSelector(models: DataModel[]) { + const authRules: Expression[] = []; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const policies = analyzePolicies(model); + models.forEach((model) => { + // model-level rules + const modelPolicyAttrs = model.attributes.filter((attr) => + ['@@allow', '@@deny'].includes(attr.decl.$refText) + ); - for (const kind of ['create', 'read', 'update', 'delete'] as const) { - if (policies[kind] === true || policies[kind] === false) { - result[kind] = policies[kind] as boolean; - continue; - } + // field-level rules + const fieldPolicyAttrs = model.fields + .flatMap((f) => f.attributes) + .filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText)); - const denies = this.getPolicyExpressions(model, 'deny', kind); - const allows = this.getPolicyExpressions(model, 'allow', kind); - - if (kind === 'update' && allows.length === 0) { - // no allow rule for 'update', policy is constant based on if there's - // post-update counterpart - if (this.getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { - result[kind] = false; - continue; - } else { - result[kind] = true; - continue; - } - } + // all rule expression + const allExpressions = [...modelPolicyAttrs, ...fieldPolicyAttrs] + .filter((attr) => attr.args.length > 1) + .map((attr) => attr.args[1].value); - const guardFunc = this.generateCheckerFunction(sourceFile, model, kind, allows, denies); - result[kind] = guardFunc.getName()!; - } + // collect `auth()` member access + allExpressions.forEach((rule) => { + streamAst(rule).forEach((node) => { + if (isMemberAccessExpr(node) && isAuthInvocation(node.operand)) { + authRules.push(node); + } + }); + }); + }); - return result; + if (authRules.length > 0) { + return generateSelectForRules(authRules, true); + } else { + return undefined; + } } - private generateCheckerFunction( - sourceFile: SourceFile, - model: DataModel, - kind: string, - allows: Expression[], - denies: Expression[] - ) { - const statements: string[] = []; - - this.generateNormalizedAuthRef(model, allows, denies, statements); - - const transformed = new ConstraintTransformer({ - authAccessor: 'user', - }).transformRules(allows, denies); + // #endregion - statements.push(`return ${transformed};`); + // #region Validation meta - const func = sourceFile.addFunction({ - name: `${model.name}$checker$${kind}`, - returnType: 'CheckerConstraint', - parameters: [ - { - name: 'context', - type: 'CheckerContext', - }, - ], - statements, + private writeValidationMeta(writer: CodeBlockWriter, models: DataModel[]) { + writer.write('validation:'); + writer.inlineBlock(() => { + for (const model of models) { + writer.write(`${lowerCaseFirst(model.name)}:`); + writer.inlineBlock(() => { + writer.write(`hasValidation: ${hasValidationAttributes(model)}`); + }); + writer.writeLine(','); + } }); - - return func; + writer.writeLine(','); } + + // #endregion } diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts new file mode 100644 index 000000000..c8b75ffd8 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -0,0 +1,386 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime'; +import { + ExpressionContext, + PluginError, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, + getAttributeArg, + getAuthModel, + getDataModels, + getIdFields, + getLiteral, + isAuthInvocation, + isEnumFieldReference, + isFromStdlib, + isFutureExpr, + resolved, +} from '@zenstackhq/sdk'; +import { + Enum, + Model, + isDataModel, + isDataModelField, + isExpression, + isInvocationExpr, + isMemberAccessExpr, + isReferenceExpr, + isThisExpr, + type DataModel, + type DataModelField, + type Expression, +} from '@zenstackhq/sdk/ast'; +import { streamAllContents, streamAst, streamContents } from 'langium'; +import { SourceFile, WriterFunction } from 'ts-morph'; +import { name } from '..'; +import { isCollectionPredicate } from '../../../utils/ast-utils'; +import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; + +/** + * Get policy expressions for the given model or field and operation kind + */ +export function getPolicyExpressions( + target: DataModel | DataModelField, + kind: PolicyKind, + operation: PolicyOperationKind, + override = false +) { + const attributes = target.attributes; + const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; + const attrs = attributes.filter((attr) => { + if (attr.decl.ref?.name !== attrName) { + return false; + } + + if (override) { + const overrideArg = getAttributeArg(attr, 'override'); + return overrideArg && getLiteral(overrideArg) === true; + } else { + return true; + } + }); + + const checkOperation = operation === 'postUpdate' ? 'update' : operation; + + let result = attrs + .filter((attr) => { + const opsValue = getLiteral(attr.args[0].value); + if (!opsValue) { + return false; + } + const ops = opsValue.split(',').map((s) => s.trim()); + return ops.includes(checkOperation) || ops.includes('all'); + }) + .map((attr) => attr.args[1].value); + + if (operation === 'update') { + result = processUpdatePolicies(result, false); + } else if (operation === 'postUpdate') { + result = processUpdatePolicies(result, true); + } + + return result; +} + +function hasFutureReference(expr: Expression) { + for (const node of streamAst(expr)) { + if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { + return true; + } + } + return false; +} + +function processUpdatePolicies(expressions: Expression[], postUpdate: boolean) { + const hasFutureRef = expressions.some(hasFutureReference); + if (postUpdate) { + // when compiling post-update rules, if any rule contains `future()` reference, + // we include all as post-update rules + return hasFutureRef ? expressions : []; + } else { + // when compiling pre-update rules, if any rule contains `future()` reference, + // we completely skip pre-update check and defer them to post-update + return hasFutureRef ? [] : expressions; + } +} + +/** + * Generates a "select" object that contains (recursively) fields referenced by the + * given policy rules + */ +export function generateSelectForRules(rules: Expression[], forAuthContext = false): object { + const result: any = {}; + const addPath = (path: string[]) => { + let curr = result; + path.forEach((seg, i) => { + if (i === path.length - 1) { + curr[seg] = true; + } else { + if (!curr[seg]) { + curr[seg] = { select: {} }; + } + curr = curr[seg].select; + } + }); + }; + + // visit a reference or member access expression to build a + // selection path + const visit = (node: Expression): string[] | undefined => { + if (isThisExpr(node)) { + return []; + } + + if (isReferenceExpr(node)) { + const target = resolved(node.target); + if (isDataModelField(target)) { + // a field selection, it's a terminal + return [target.name]; + } + } + + if (isMemberAccessExpr(node)) { + if (forAuthContext && isAuthInvocation(node.operand)) { + return [node.member.$refText]; + } + + if (isFutureExpr(node.operand)) { + // future().field is not subject to pre-update select + return undefined; + } + + // build a selection path inside-out for chained member access + const inner = visit(node.operand); + if (inner) { + return [...inner, node.member.$refText]; + } + } + + return undefined; + }; + + // collect selection paths from the given expression + const collectReferencePaths = (expr: Expression): string[][] => { + if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) { + // a standalone `this` expression, include all id fields + const model = expr.$resolvedType?.decl as DataModel; + const idFields = getIdFields(model); + return idFields.map((field) => [field.name]); + } + + if (isMemberAccessExpr(expr) || isReferenceExpr(expr)) { + const path = visit(expr); + if (path) { + if (isDataModel(expr.$resolvedType?.decl)) { + // member selection ended at a data model field, include its id fields + const idFields = getIdFields(expr.$resolvedType?.decl as DataModel); + return idFields.map((field) => [...path, field.name]); + } else { + return [path]; + } + } else { + return []; + } + } else if (isCollectionPredicate(expr)) { + const path = visit(expr.left); + if (path) { + // recurse into RHS + const rhs = collectReferencePaths(expr.right); + // combine path of LHS and RHS + return rhs.map((r) => [...path, ...r]); + } else { + return []; + } + } else if (isInvocationExpr(expr)) { + // recurse into function arguments + return expr.args.flatMap((arg) => collectReferencePaths(arg.value)); + } else { + // recurse + const children = streamContents(expr) + .filter((child): child is Expression => isExpression(child)) + .toArray(); + return children.flatMap((child) => collectReferencePaths(child)); + } + }; + + for (const rule of rules) { + const paths = collectReferencePaths(rule); + paths.forEach((p) => addPath(p)); + } + + return Object.keys(result).length === 0 ? undefined : result; +} + +/** + * Generates a query guard function that returns a partial Prisma query for the given model or field + */ +export function generateQueryGuardFunction( + sourceFile: SourceFile, + model: DataModel, + kind: PolicyOperationKind, + allows: Expression[], + denies: Expression[], + forField?: DataModelField, + fieldOverride = false +) { + const statements: (string | WriterFunction)[] = []; + + generateNormalizedAuthRef(model, allows, denies, statements); + + const hasFieldAccess = [...denies, ...allows].some((rule) => + streamAst(rule).some( + (child) => + // this.??? + isThisExpr(child) || + // future().??? + isFutureExpr(child) || + // field reference + (isReferenceExpr(child) && isDataModelField(child.target.ref)) + ) + ); + + if (!hasFieldAccess) { + // none of the rules reference model fields, we can compile down to a plain boolean + // function in this case (so we can skip doing SQL queries when validating) + statements.push((writer) => { + const transformer = new TypeScriptExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + isPostGuard: kind === 'postUpdate', + }); + try { + denies.forEach((rule) => { + writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); + }); + allows.forEach((rule) => { + writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); + }); + } catch (err) { + if (err instanceof TypeScriptExpressionTransformerError) { + throw new PluginError(name, err.message); + } else { + throw err; + } + } + + if (forField) { + if (allows.length === 0) { + // if there's no allow rule, for field-level rules, by default we allow + writer.write(`return ${TRUE};`); + } else { + // if there's any allow rule, we deny unless any allow rule evaluates to true + writer.write(`return ${FALSE};`); + } + } else { + // for model-level rules, the default is always deny + writer.write(`return ${FALSE};`); + } + }); + } else { + statements.push((writer) => { + writer.write('return '); + const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); + const writeDenies = () => { + writer.conditionalWrite(denies.length > 1, '{ AND: ['); + denies.forEach((expr, i) => { + writer.inlineBlock(() => { + writer.write('NOT: '); + exprWriter.write(expr); + }); + writer.conditionalWrite(i !== denies.length - 1, ','); + }); + writer.conditionalWrite(denies.length > 1, ']}'); + }; + + const writeAllows = () => { + writer.conditionalWrite(allows.length > 1, '{ OR: ['); + allows.forEach((expr, i) => { + exprWriter.write(expr); + writer.conditionalWrite(i !== allows.length - 1, ','); + }); + writer.conditionalWrite(allows.length > 1, ']}'); + }; + + if (allows.length > 0 && denies.length > 0) { + // include both allow and deny rules + writer.write('{ AND: ['); + writeDenies(); + writer.write(','); + writeAllows(); + writer.write(']}'); + } else if (denies.length > 0) { + // only deny rules + writeDenies(); + } else if (allows.length > 0) { + // only allow rules + writeAllows(); + } else { + // disallow any operation + writer.write(`{ OR: [] }`); + } + writer.write(';'); + }); + } + + const func = sourceFile.addFunction({ + name: `${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, + returnType: 'any', + parameters: [ + { + name: 'context', + type: 'QueryContext', + }, + { + // for generating field references used by field comparison in the same model + name: 'db', + type: 'CrudContract', + }, + ], + statements, + }); + + return func; +} + +/** + * Generates a normalized auth reference for the given policy rules + */ +export function generateNormalizedAuthRef( + model: DataModel, + allows: Expression[], + denies: Expression[], + statements: (string | WriterFunction)[] +) { + // check if any allow or deny rule contains 'auth()' invocation + const hasAuthRef = [...allows, ...denies].some((rule) => streamAst(rule).some((child) => isAuthInvocation(child))); + + if (hasAuthRef) { + const authModel = getAuthModel(getDataModels(model.$container, true)); + if (!authModel) { + throw new PluginError(name, 'Auth model not found'); + } + const userIdFields = getIdFields(authModel); + if (!userIdFields || userIdFields.length === 0) { + throw new PluginError(name, 'User model does not have an id field'); + } + + // normalize user to null to avoid accidentally use undefined in filter + statements.push(`const user: any = context.user ?? null;`); + } +} + +/** + * Check if the given enum is referenced in the model + */ +export function isEnumReferenced(model: Model, decl: Enum): unknown { + return streamAllContents(model).some((node) => { + if (isDataModelField(node) && node.type.reference?.ref === decl) { + // referenced as field type + return true; + } + if (isEnumFieldReference(node) && node.target.ref?.$container === decl) { + // enum field is referenced + return true; + } + return false; + }); +} diff --git a/packages/sdk/src/policy.ts b/packages/sdk/src/policy.ts index ccd3e851f..c9eea9865 100644 --- a/packages/sdk/src/policy.ts +++ b/packages/sdk/src/policy.ts @@ -2,6 +2,8 @@ import type { DataModel, DataModelAttribute } from './ast'; import { getLiteral } from './utils'; import { hasValidationAttributes } from './validation'; +export type PolicyAnalysisResult = ReturnType; + export function analyzePolicies(dataModel: DataModel) { const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow'); const denies = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@deny'); diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 4495ddf14..fb90fac4b 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -3,6 +3,7 @@ import type { Model } from '@zenstackhq/language/ast'; import { DEFAULT_RUNTIME_LOAD_PATH, + PolicyDef, type AuthUser, type CrudContract, type EnhancementKind, @@ -43,14 +44,12 @@ export type FullDbClientContract = CrudContract & { export function run(cmd: string, env?: Record, cwd?: string) { try { - const start = Date.now(); execSync(cmd, { stdio: 'pipe', encoding: 'utf-8', env: { ...process.env, DO_NOT_TRACK: '1', ...env }, cwd, }); - console.log('Execution took', Date.now() - start, 'ms', '-', cmd); } catch (err) { console.error('Command failed:', cmd, err); throw err; @@ -299,7 +298,7 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { projectDir, enhance: undefined as any, enhanceRaw: undefined as any, - policy: undefined as any, + policy: undefined as unknown as PolicyDef, modelMeta: undefined as any, zodSchemas: undefined as any, }; @@ -311,7 +310,7 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { : path.join(projectDir, opt.output) : path.join(projectDir, 'node_modules', DEFAULT_RUNTIME_LOAD_PATH); - const policy = require(path.join(outputPath, 'policy')).default; + const policy: PolicyDef = require(path.join(outputPath, 'policy')).default; const modelMeta = require(path.join(outputPath, 'model-meta')).default; let zodSchemas: any; diff --git a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts index 6a31540d7..8acc832c6 100644 --- a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts @@ -6,7 +6,7 @@ describe('Polymorphism Test', () => { const schema = POLYMORPHIC_SCHEMA; async function setup() { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); const user = await db.user.create({ data: { id: 1 } }); @@ -21,7 +21,7 @@ describe('Polymorphism Test', () => { } it('create hierarchy', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); const user = await db.user.create({ data: { id: 1 } }); @@ -100,7 +100,7 @@ describe('Polymorphism Test', () => { name String } `, - { logPrismaQuery: true, enhancements: ['delegate'] } + { enhancements: ['delegate'] } ); const db = enhance(); @@ -109,7 +109,7 @@ describe('Polymorphism Test', () => { }); it('create with nesting', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); // nested create a relation from base @@ -122,7 +122,7 @@ describe('Polymorphism Test', () => { }); it('create many polymorphic model', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); await expect( @@ -140,7 +140,7 @@ describe('Polymorphism Test', () => { }); it('create many polymorphic relation', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); const video1 = await db.ratedVideo.create({ @@ -898,7 +898,7 @@ describe('Polymorphism Test', () => { }); it('deleteMany', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); const user = await db.user.create({ data: { id: 1 } }); diff --git a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts index d0316595d..c8e5bd432 100644 --- a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts @@ -89,14 +89,10 @@ describe('Polymorphic Policy Test', () => { for (const schema of [booleanCondition, booleanExpression]) { const { enhanceRaw: enhance, prisma } = await loadSchema(schema); - const fullDb = enhance(prisma, undefined, { kinds: ['delegate'], logPrismaQuery: true }); + const fullDb = enhance(prisma, undefined, { kinds: ['delegate'] }); const user = await fullDb.user.create({ data: { id: 1 } }); - const userDb = enhance( - prisma, - { user: { id: user.id } }, - { kinds: ['delegate', 'policy'], logPrismaQuery: true } - ); + const userDb = enhance(prisma, { user: { id: user.id } }, { kinds: ['delegate', 'policy'] }); // violating Asset create await expect( @@ -189,9 +185,7 @@ describe('Polymorphic Policy Test', () => { } `; - const { enhance } = await loadSchema(schema, { - logPrismaQuery: true, - }); + const { enhance } = await loadSchema(schema); const db = enhance(); const user = await db.user.create({ data: { id: 1 } }); diff --git a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts index 9dd247d65..c910ff4f1 100644 --- a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts +++ b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts @@ -41,8 +41,7 @@ model Post { secret String @default("secret") @allow('read', published == false, true) @@allow('read', published) -}`, - { logPrismaQuery: true } +}` ); await prisma.user.create({ diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts index e215a917b..59c968fb5 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts @@ -290,8 +290,7 @@ describe('With Policy:nested to-one', () => { @@allow('create', value > 0) @@allow('update', value > 1) } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/integration/tests/enhancements/with-policy/post-update.test.ts b/tests/integration/tests/enhancements/with-policy/post-update.test.ts index b101356cd..d43804787 100644 --- a/tests/integration/tests/enhancements/with-policy/post-update.test.ts +++ b/tests/integration/tests/enhancements/with-policy/post-update.test.ts @@ -110,8 +110,7 @@ describe('With Policy: post update', () => { @@allow('create,read', true) @@allow('update', x > 0 && startsWith(future().value, 'hello')) } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts b/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts index d46c31245..a9a1b49d2 100644 --- a/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts +++ b/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts @@ -21,7 +21,7 @@ describe('prisma omit', () => { @@allow('all', level > 1) } `, - { previewFeatures: ['omitApi'], logPrismaQuery: true } + { previewFeatures: ['omitApi'] } ); await prisma.user.create({ diff --git a/tests/integration/tests/enhancements/with-policy/refactor.test.ts b/tests/integration/tests/enhancements/with-policy/refactor.test.ts index 3c725697d..6ee5c2343 100644 --- a/tests/integration/tests/enhancements/with-policy/refactor.test.ts +++ b/tests/integration/tests/enhancements/with-policy/refactor.test.ts @@ -26,7 +26,6 @@ describe('With Policy: refactor tests', () => { { provider: 'postgresql', dbUrl, - logPrismaQuery: true, } ); getDb = enhance; diff --git a/tests/integration/tests/enhancements/with-policy/subscription.test.ts b/tests/integration/tests/enhancements/with-policy/subscription.test.ts index a4dccf807..aa93706d8 100644 --- a/tests/integration/tests/enhancements/with-policy/subscription.test.ts +++ b/tests/integration/tests/enhancements/with-policy/subscription.test.ts @@ -34,7 +34,6 @@ describe.skip('With Policy: subscription test', () => { provider: 'postgresql', dbUrl: DB_URL, pulseApiKey: PULSE_API_KEY, - logPrismaQuery: true, } ); @@ -88,7 +87,6 @@ describe.skip('With Policy: subscription test', () => { provider: 'postgresql', dbUrl: DB_URL, pulseApiKey: PULSE_API_KEY, - logPrismaQuery: true, } ); @@ -143,7 +141,6 @@ describe.skip('With Policy: subscription test', () => { provider: 'postgresql', dbUrl: DB_URL, pulseApiKey: PULSE_API_KEY, - logPrismaQuery: true, } ); @@ -198,7 +195,6 @@ describe.skip('With Policy: subscription test', () => { provider: 'postgresql', dbUrl: DB_URL, pulseApiKey: PULSE_API_KEY, - logPrismaQuery: true, } ); diff --git a/tests/integration/tests/plugins/policy.test.ts b/tests/integration/tests/plugins/policy.test.ts index 5158584f4..3d9e75f98 100644 --- a/tests/integration/tests/plugins/policy.test.ts +++ b/tests/integration/tests/plugins/policy.test.ts @@ -36,18 +36,20 @@ model M { const { policy } = await loadSchema(model); - expect(policy.guard.m.read({ user: undefined })).toEqual(FALSE); - expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(TRUE); - - expect(policy.guard.m.create({ user: undefined })).toEqual(FALSE); - expect(policy.guard.m.create({ user: { id: '1' } })).toEqual(FALSE); - expect(policy.guard.m.create({ user: { id: '1', value: 0 } })).toEqual(FALSE); - expect(policy.guard.m.create({ user: { id: '1', value: 1 } })).toEqual(TRUE); - - expect(policy.guard.m.update({ user: undefined })).toEqual(FALSE); - expect(policy.guard.m.update({ user: { id: '1' } })).toEqual(FALSE); - expect(policy.guard.m.update({ user: { id: '1', value: 0 } })).toEqual(FALSE); - expect(policy.guard.m.update({ user: { id: '1', value: 1 } })).toEqual(TRUE); + const m = policy.policy.m.modelLevel; + + expect((m.read.guard as Function)({ user: undefined })).toEqual(FALSE); + expect((m.read.guard as Function)({ user: { id: '1' } })).toEqual(TRUE); + + expect((m.create.guard as Function)({ user: undefined })).toEqual(FALSE); + expect((m.create.guard as Function)({ user: { id: '1' } })).toEqual(FALSE); + expect((m.create.guard as Function)({ user: { id: '1', value: 0 } })).toEqual(FALSE); + expect((m.create.guard as Function)({ user: { id: '1', value: 1 } })).toEqual(TRUE); + + expect((m.update.guard as Function)({ user: undefined })).toEqual(FALSE); + expect((m.update.guard as Function)({ user: { id: '1' } })).toEqual(FALSE); + expect((m.update.guard as Function)({ user: { id: '1', value: 0 } })).toEqual(FALSE); + expect((m.update.guard as Function)({ user: { id: '1', value: 1 } })).toEqual(TRUE); }); it('no short-circuit', async () => { @@ -66,13 +68,14 @@ model M { const { policy } = await loadSchema(model); - expect(policy.guard.m.read({ user: undefined })).toEqual( + expect((policy.policy.m.modelLevel.read.guard as Function)({ user: undefined })).toEqual( expect.objectContaining({ AND: [{ OR: [] }, { value: { gt: 0 } }] }) ); - expect(policy.guard.m.read({ user: { id: '1' } })).toEqual( + expect((policy.policy.m.modelLevel.read.guard as Function)({ user: { id: '1' } })).toEqual( expect.objectContaining({ AND: [{ AND: [] }, { value: { gt: 0 } }] }) ); }); + it('auth() multiple level member access', async () => { const model = ` model User { @@ -97,12 +100,12 @@ model M { `; const { policy } = await loadSchema(model); - expect(policy.guard.task.read({ user: { cart: { tasks: [{ id: 1 }] } } })).toEqual( - expect.objectContaining({ AND: [{ OR: [] }, { value: { gt: 10 } }] }) - ); + expect( + (policy.policy.task.modelLevel.read.guard as Function)({ user: { cart: { tasks: [{ id: 1 }] } } }) + ).toEqual(expect.objectContaining({ AND: [{ OR: [] }, { value: { gt: 10 } }] })); - expect(policy.guard.task.read({ user: { cart: { tasks: [{ id: 123 }] } } })).toEqual( - expect.objectContaining({ AND: [{ AND: [] }, { value: { gt: 10 } }] }) - ); + expect( + (policy.policy.task.modelLevel.read.guard as Function)({ user: { cart: { tasks: [{ id: 123 }] } } }) + ).toEqual(expect.objectContaining({ AND: [{ AND: [] }, { value: { gt: 10 } }] })); }); }); diff --git a/tests/regression/tests/issue-1014.test.ts b/tests/regression/tests/issue-1014.test.ts index ad862db42..66caa1b11 100644 --- a/tests/regression/tests/issue-1014.test.ts +++ b/tests/regression/tests/issue-1014.test.ts @@ -37,8 +37,7 @@ describe('issue 1014', () => { title String @allow('read', true, true) content String } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/regression/tests/issue-1080.test.ts b/tests/regression/tests/issue-1080.test.ts index 17ce998c2..69408fdf0 100644 --- a/tests/regression/tests/issue-1080.test.ts +++ b/tests/regression/tests/issue-1080.test.ts @@ -19,8 +19,7 @@ describe('issue 1080', () => { @@allow('all', true) } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/regression/tests/issue-1241.test.ts b/tests/regression/tests/issue-1241.test.ts index e5d94c9b7..a555fcb8d 100644 --- a/tests/regression/tests/issue-1241.test.ts +++ b/tests/regression/tests/issue-1241.test.ts @@ -38,8 +38,7 @@ describe('issue 1241', () => { @@allow('all', true) } - `, - { logPrismaQuery: true } + ` ); const user = await prisma.user.create({ diff --git a/tests/regression/tests/issue-1271.test.ts b/tests/regression/tests/issue-1271.test.ts index d25cabb3b..9798664cb 100644 --- a/tests/regression/tests/issue-1271.test.ts +++ b/tests/regression/tests/issue-1271.test.ts @@ -39,8 +39,7 @@ describe('issue 1271', () => { @@allow("all", true) } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/regression/tests/issue-1435.test.ts b/tests/regression/tests/issue-1435.test.ts index 0093aff8b..d539b778f 100644 --- a/tests/regression/tests/issue-1435.test.ts +++ b/tests/regression/tests/issue-1435.test.ts @@ -83,7 +83,7 @@ describe('issue 1435', () => { reference String @id } `, - { provider: 'postgresql', dbUrl, logPrismaQuery: true } + { provider: 'postgresql', dbUrl } ); prisma = r.prisma; diff --git a/tests/regression/tests/issue-1451.test.ts b/tests/regression/tests/issue-1451.test.ts index f54a0ca4f..fb105561d 100644 --- a/tests/regression/tests/issue-1451.test.ts +++ b/tests/regression/tests/issue-1451.test.ts @@ -29,8 +29,7 @@ describe('issue 1452', () => { @@id([userId, spaceId]) @@allow('all', true) } - `, - { logPrismaQuery: true } + ` ); await prisma.user.create({ diff --git a/tests/regression/tests/issue-961.test.ts b/tests/regression/tests/issue-961.test.ts index f6dc3a135..1f622059e 100644 --- a/tests/regression/tests/issue-961.test.ts +++ b/tests/regression/tests/issue-961.test.ts @@ -35,7 +35,7 @@ describe('Regression: issue 961', () => { `; it('deleteMany', async () => { - const { prisma, enhance } = await loadSchema(schema, { logPrismaQuery: true }); + const { prisma, enhance } = await loadSchema(schema); const user = await prisma.user.create({ data: { diff --git a/tests/regression/tests/issues.test.ts b/tests/regression/tests/issues.test.ts index 318682aad..1418a309a 100644 --- a/tests/regression/tests/issues.test.ts +++ b/tests/regression/tests/issues.test.ts @@ -531,8 +531,7 @@ model tenant { model Equipment extends BaseEntityWithTenant { a String } -`, - { logPrismaQuery: true } +` ); await prisma.tenant.create({