diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 10da4f5c9..6b049861d 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -29,6 +29,10 @@ "types": "./enhancements/index.d.ts", "default": "./enhancements/index.js" }, + "./constraint-solver": { + "types": "./constraint-solver.d.ts", + "default": "./constraint-solver.js" + }, "./zod": { "types": "./zod/index.d.ts", "default": "./zod/index.js" @@ -79,12 +83,14 @@ "decimal.js": "^10.4.2", "deepcopy": "^2.1.0", "deepmerge": "^4.3.1", + "logic-solver": "^2.0.1", "lower-case-first": "^2.0.2", "pluralize": "^8.0.0", "safe-json-stringify": "^1.2.0", "semver": "^7.5.2", "superjson": "^1.11.0", "tiny-invariant": "^1.3.1", + "ts-pattern": "^4.3.0", "tslib": "^2.4.1", "upper-case-first": "^2.0.2", "uuid": "^9.0.0", diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts new file mode 100644 index 000000000..c87a528e7 --- /dev/null +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -0,0 +1,219 @@ +import Logic from 'logic-solver'; +import { match } from 'ts-pattern'; +import type { + CheckerConstraint, + ComparisonConstraint, + ComparisonTerm, + LogicalConstraint, + ValueConstraint, + VariableConstraint, +} from '../types'; + +/** + * A boolean constraint solver based on `logic-solver`. Only boolean and integer types are supported. + */ +export class ConstraintSolver { + // a table for internalizing string literals + private stringTable: string[] = []; + + // a map for storing variable names and their corresponding formulas + private variables: Map = new Map(); + + /** + * Check the satisfiability of the given constraint. + */ + checkSat(constraint: CheckerConstraint): boolean { + // reset state + this.stringTable = []; + this.variables = new Map(); + + // convert the constraint to a "logic-solver" formula + const formula = this.buildFormula(constraint); + + // solve the formula + const solver = new Logic.Solver(); + solver.require(formula); + + // DEBUG: + // const solution = solver.solve(); + // if (solution) { + // console.log('Solution:'); + // this.variables.forEach((v, k) => console.log(`\t${k}=${solution?.evaluate(v)}`)); + // } else { + // console.log('No solution'); + // } + + return !!solver.solve(); + } + + private buildFormula(constraint: CheckerConstraint): Logic.Formula { + return match(constraint) + .when( + (c): c is ValueConstraint => c.kind === 'value', + (c) => this.buildValueFormula(c) + ) + .when( + (c): c is VariableConstraint => c.kind === 'variable', + (c) => this.buildVariableFormula(c) + ) + .when( + (c): c is ComparisonConstraint => ['eq', 'ne', 'gt', 'gte', 'lt', 'lte'].includes(c.kind), + (c) => this.buildComparisonFormula(c) + ) + .when( + (c): c is LogicalConstraint => ['and', 'or', 'not'].includes(c.kind), + (c) => this.buildLogicalFormula(c) + ) + .otherwise(() => { + throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`); + }); + } + + private buildLogicalFormula(constraint: LogicalConstraint) { + return match(constraint.kind) + .with('and', () => this.buildAndFormula(constraint)) + .with('or', () => this.buildOrFormula(constraint)) + .with('not', () => this.buildNotFormula(constraint)) + .exhaustive(); + } + + private buildAndFormula(constraint: LogicalConstraint): Logic.Formula { + if (constraint.children.some((c) => this.isFalse(c))) { + // short-circuit + return Logic.FALSE; + } + return Logic.and(...constraint.children.map((c) => this.buildFormula(c))); + } + + private buildOrFormula(constraint: LogicalConstraint): Logic.Formula { + if (constraint.children.some((c) => this.isTrue(c))) { + // short-circuit + return Logic.TRUE; + } + return Logic.or(...constraint.children.map((c) => this.buildFormula(c))); + } + + private buildNotFormula(constraint: LogicalConstraint) { + if (constraint.children.length !== 1) { + throw new Error('"not" constraint must have exactly one child'); + } + return Logic.not(this.buildFormula(constraint.children[0])); + } + + private isTrue(constraint: CheckerConstraint): unknown { + return constraint.kind === 'value' && constraint.value === true; + } + + private isFalse(constraint: CheckerConstraint): unknown { + return constraint.kind === 'value' && constraint.value === false; + } + + private buildComparisonFormula(constraint: ComparisonConstraint) { + if (constraint.left.kind === 'value' && constraint.right.kind === 'value') { + // constant comparison + const left: ValueConstraint = constraint.left; + const right: ValueConstraint = constraint.right; + return match(constraint.kind) + .with('eq', () => (left.value === right.value ? Logic.TRUE : Logic.FALSE)) + .with('ne', () => (left.value !== right.value ? Logic.TRUE : Logic.FALSE)) + .with('gt', () => (left.value > right.value ? Logic.TRUE : Logic.FALSE)) + .with('gte', () => (left.value >= right.value ? Logic.TRUE : Logic.FALSE)) + .with('lt', () => (left.value < right.value ? Logic.TRUE : Logic.FALSE)) + .with('lte', () => (left.value <= right.value ? Logic.TRUE : Logic.FALSE)) + .exhaustive(); + } + + return match(constraint.kind) + .with('eq', () => this.transformEquality(constraint.left, constraint.right)) + .with('ne', () => this.transformInequality(constraint.left, constraint.right)) + .with('gt', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThan(l, r)) + ) + .with('gte', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThanOrEqual(l, r)) + ) + .with('lt', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThan(l, r)) + ) + .with('lte', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThanOrEqual(l, r)) + ) + .exhaustive(); + } + + private buildVariableFormula(constraint: VariableConstraint) { + return ( + match(constraint.type) + .with('boolean', () => this.booleanVariable(constraint.name)) + .with('number', () => this.intVariable(constraint.name)) + // strings are internalized and represented by their indices + .with('string', () => this.intVariable(constraint.name)) + .exhaustive() + ); + } + + private buildValueFormula(constraint: ValueConstraint) { + return match(constraint.value) + .when( + (v): v is boolean => typeof v === 'boolean', + (v) => (v === true ? Logic.TRUE : Logic.FALSE) + ) + .when( + (v): v is number => typeof v === 'number', + (v) => Logic.constantBits(v) + ) + .when( + (v): v is string => typeof v === 'string', + (v) => { + // internalize the string and use its index as formula representation + const index = this.stringTable.indexOf(v); + if (index === -1) { + this.stringTable.push(v); + return Logic.constantBits(this.stringTable.length - 1); + } else { + return Logic.constantBits(index); + } + } + ) + .exhaustive(); + } + + private booleanVariable(name: string) { + this.variables.set(name, name); + return name; + } + + private intVariable(name: string) { + const r = Logic.variableBits(name, 32); + this.variables.set(name, r); + return r; + } + + private transformEquality(left: ComparisonTerm, right: ComparisonTerm) { + if (left.type !== right.type) { + throw new Error(`Type mismatch in equality constraint: ${JSON.stringify(left)}, ${JSON.stringify(right)}`); + } + + const leftFormula = this.buildFormula(left); + const rightFormula = this.buildFormula(right); + if (left.type === 'boolean' && right.type === 'boolean') { + // logical equivalence + return Logic.equiv(leftFormula, rightFormula); + } else { + // integer equality + return Logic.equalBits(leftFormula, rightFormula); + } + } + + private transformInequality(left: ComparisonTerm, right: ComparisonTerm) { + return Logic.not(this.transformEquality(left, right)); + } + + private transformComparison( + left: ComparisonTerm, + right: ComparisonTerm, + func: (left: Logic.Formula, right: Logic.Formula) => Logic.Formula + ) { + return func(this.buildFormula(left), this.buildFormula(right)); + } +} diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index d6d893d4e..cc0ea4f03 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -2,6 +2,7 @@ import { lowerCaseFirst } from 'lower-case-first'; import invariant from 'tiny-invariant'; +import { P, match } from 'ts-pattern'; import { upperCaseFirst } from 'upper-case-first'; import { fromZodError } from 'zod-validation-error'; import { CrudFailureReason } from '../../constants'; @@ -16,13 +17,15 @@ import { type FieldInfo, type ModelMeta, } from '../../cross'; -import { PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types'; +import { PolicyCrudKind, PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; +import type { CheckerConstraint } from '../types'; import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; +import { ConstraintSolver } from './constraint-solver'; import { PolicyUtil } from './policy-utils'; // a record for post-write policy check @@ -35,6 +38,12 @@ type PostWriteCheckRecord = { type FindOperations = 'findUnique' | 'findUniqueOrThrow' | 'findFirst' | 'findFirstOrThrow' | 'findMany'; +// input arg type for `check` API +type PermissionCheckArgs = { + operation: PolicyCrudKind; + filter?: Record; +}; + /** * Prisma proxy handler for injecting access policy check. */ @@ -1436,6 +1445,115 @@ export class PolicyProxyHandler implements Pr //#endregion + //#region Check + + /** + * Checks if the given operation is possibly allowed by the policy, without querying the database. + * @param operation The CRUD operation. + * @param fieldValues Extra field value filters to be combined with the policy constraints. + */ + async check(args: PermissionCheckArgs): Promise { + return createDeferredPromise(() => this.doCheck(args)); + } + + private async doCheck(args: PermissionCheckArgs) { + if (!['create', 'read', 'update', 'delete'].includes(args.operation)) { + throw prismaClientValidationError(this.prisma, this.prismaModule, `Invalid "operation" ${args.operation}`); + } + + let constraint = this.policyUtils.getCheckerConstraint(this.model, args.operation); + if (typeof constraint === 'boolean') { + return constraint; + } + + if (args.filter) { + // combine runtime filters with generated constraints + + const extraConstraints: CheckerConstraint[] = []; + for (const [field, value] of Object.entries(args.filter)) { + if (value === undefined) { + continue; + } + + if (value === null) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Using "null" as filter value is not supported yet` + ); + } + + const fieldInfo = requireField(this.modelMeta, this.model, field); + + // relation and array fields are not supported + if (fieldInfo.isDataModel || fieldInfo.isArray) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Providing filter for field "${field}" is not supported. Only scalar fields are allowed.` + ); + } + + // map field type to constraint type + const fieldType = match(fieldInfo.type) + .with(P.union('Int', 'BigInt', 'Float', 'Decimal'), () => 'number') + .with('String', () => 'string') + .with('Boolean', () => 'boolean') + .otherwise(() => { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Providing filter for field "${field}" is not supported. Only number, string, and boolean fields are allowed.` + ); + }); + + // check value type + const valueType = typeof value; + if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Invalid value type for field "${field}". Only number, string or boolean is allowed.` + ); + } + + if (fieldType !== valueType) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Invalid value type for field "${field}". Expected "${fieldType}".` + ); + } + + // check number validity + if (typeof value === 'number' && (!Number.isInteger(value) || value < 0)) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Invalid value for field "${field}". Only non-negative integers are allowed.` + ); + } + + // build a constraint + extraConstraints.push({ + kind: 'eq', + left: { kind: 'variable', name: field, type: fieldType }, + right: { kind: 'value', value, type: fieldType }, + }); + } + + if (extraConstraints.length > 0) { + // combine the constraints + constraint = { kind: 'and', children: [constraint, ...extraConstraints] }; + } + } + + // check satisfiability + return new ConstraintSolver().checkSat(constraint); + } + + //#endregion + //#region Utils private get shouldLogQuery() { diff --git a/packages/runtime/src/enhancements/policy/logic-solver.d.ts b/packages/runtime/src/enhancements/policy/logic-solver.d.ts new file mode 100644 index 000000000..d10e688f6 --- /dev/null +++ b/packages/runtime/src/enhancements/policy/logic-solver.d.ts @@ -0,0 +1,109 @@ +/** + * Type definitions for the `logic-solver` npm package. + */ +declare module 'logic-solver' { + /** + * A boolean formula. + */ + interface Formula {} + + /** + * The `TRUE` formula. + */ + const TRUE: Formula; + + /** + * The `FALSE` formula. + */ + const FALSE: Formula; + + /** + * Boolean equivalence. + */ + export function equiv(operand1: Formula, operand2: Formula): Formula; + + /** + * Bits equality. + */ + export function equalBits(bits1: Formula, bits2: Formula): Formula; + + /** + * Bits greater-than. + */ + export function greaterThan(bits1: Formula, bits2: Formula): Formula; + + /** + * Bits greater-than-or-equal. + */ + export function greaterThanOrEqual(bits1: Formula, bits2: Formula): Formula; + + /** + * Bits less-than. + */ + export function lessThan(bits1: Formula, bits2: Formula): Formula; + + /** + * Bits less-than-or-equal. + */ + export function lessThanOrEqual(bits1: Formula, bits2: Formula): Formula; + + /** + * Logical AND. + */ + export function and(...args: Formula[]): Formula; + + /** + * Logical OR. + */ + export function or(...args: Formula[]): Formula; + + /** + * Logical NOT. + */ + export function not(arg: Formula): Formula; + + /** + * Creates a bits variable with the given name and bit length. + */ + export function variableBits(baseName: string, N: number): Formula; + + /** + * Creates a constant bits formula from the given whole number. + */ + export function constantBits(wholeNumber: number): Formula; + + /** + * A solution to a constraint. + */ + interface Solution { + /** + * Returns a map of variable assignments. + */ + getMap(): object; + + /** + * Evaluates the given formula against the solution. + */ + evaluate(formula: Formula): unknown; + } + + /** + * A constraint solver. + */ + class Solver { + /** + * Adds constraints to the solver. + */ + require(...args: Formula[]): void; + + /** + * Adds negated constraints from the solver. + */ + forbid(...args: Formula[]): void; + + /** + * Solves the constraints. + */ + solve(): Solution; + } +} diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index bcb946877..2cddaae5e 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -17,12 +17,12 @@ import { PrismaErrorCode, } from '../../constants'; import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; -import { AuthUser, CrudContract, DbClientContract, PolicyOperationKind } from '../../types'; +import { AuthUser, CrudContract, DbClientContract, PolicyCrudKind, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { QueryUtils } from '../query-utils'; -import type { InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; +import type { CheckerFunc, InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; import { formatObject, prismaClientKnownRequestError } from '../utils'; /** @@ -228,7 +228,7 @@ export class PolicyUtil extends QueryUtils { //#endregion - //# Auth guard + //#region Auth guard private readonly FULLY_OPEN_AUTH_GUARD = { create: true, @@ -267,7 +267,7 @@ export class PolicyUtil extends QueryUtils { } if (!provider) { - throw this.unknownError(`zenstack: unable to load authorization guard for ${model}`); + throw this.unknownError(`unable to load authorization guard for ${model}`); } const r = provider({ user: this.user, preValue }, db); return this.reduce(r); @@ -561,6 +561,50 @@ export class PolicyUtil extends QueryUtils { return true; } + //#endregion + + //#region Checker + + /** + * Gets checker constraints for the given model and operation. + */ + getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { + const checker = this.getModelChecker(model); + if (!checker) { + throw this.unknownError(`unable to load policy guard for ${model}`); + } + + const provider = checker[operation]; + if (typeof provider === 'boolean') { + return provider; + } + + if (typeof provider !== 'function') { + throw this.unknownError(`unable to load ${operation} checker for ${model}`); + } + + // call checker function + return provider({ user: this.user }); + } + + private getModelChecker(model: string) { + if (this.options.kinds && !this.options.kinds.includes('policy')) { + // policy enhancement not enabled, return a constant true checker + return { create: true, read: true, update: true, delete: true }; + } else { + const result = this.options.policy.checker?.[lowerCaseFirst(model)]; + if (!result) { + // checker generation not enabled, return constant false checker + throw new Error( + `Generated permission checkers not found. Please make sure the "generatePermissionChecker" option is set to true in the "@core/enhancer" plugin.` + ); + } + return result; + } + } + + //#endregion + /** * Gets unique constraints for the given model. */ diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 9fecc375e..89d5ce9f6 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -9,7 +9,7 @@ import { HAS_FIELD_LEVEL_POLICY_FLAG, PRE_UPDATE_VALUE_SELECTOR, } from '../constants'; -import type { CrudContract, PolicyOperationKind, QueryContext } from '../types'; +import type { CheckerContext, CrudContract, PolicyCrudKind, PolicyOperationKind, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -33,6 +33,57 @@ export interface CommonEnhancementOptions { */ export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; +/** + * Function for checking if an operation is possibly allowed. + */ +export type CheckerFunc = (context: CheckerContext) => CheckerConstraint; + +/** + * Supported checker constraint checking value types. + */ +export type ConstraintValueTypes = 'boolean' | 'number' | 'string'; + +/** + * Free variable constraint + */ +export type VariableConstraint = { kind: 'variable'; name: string; type: ConstraintValueTypes }; + +/** + * Constant value constraint + */ +export type ValueConstraint = { + kind: 'value'; + value: number | boolean | string; + type: ConstraintValueTypes; +}; + +/** + * Terms for comparison constraints + */ +export type ComparisonTerm = VariableConstraint | ValueConstraint; + +/** + * Comparison constraint + */ +export type ComparisonConstraint = { + kind: 'eq' | 'ne' | 'gt' | 'gte' | 'lt' | 'lte'; + left: ComparisonTerm; + right: ComparisonTerm; +}; + +/** + * Logical constraint + */ +export type LogicalConstraint = { + kind: 'and' | 'or' | 'not'; + children: CheckerConstraint[]; +}; + +/** + * Operation allowability checking constraint + */ +export type CheckerConstraint = ValueConstraint | VariableConstraint | ComparisonConstraint | LogicalConstraint; + /** * Function for getting policy guard with a given context */ @@ -71,6 +122,8 @@ export type PolicyDef = { } >; + checker?: Record>; + // tracks which models have data validation rules validation: Record; diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 4bcab85a1..4c32480ba 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -22,6 +22,7 @@ export interface DbOperations { groupBy(args: unknown): Promise; count(args?: unknown): Promise; subscribe(args?: unknown): Promise; + check(args: unknown): Promise; fields: Record; } @@ -30,10 +31,12 @@ export interface DbOperations { */ export type PolicyKind = 'allow' | 'deny'; +export type PolicyCrudKind = 'read' | 'create' | 'update' | 'delete'; + /** * Kinds of operations controlled by access policies */ -export type PolicyOperationKind = 'create' | 'update' | 'postUpdate' | 'read' | 'delete'; +export type PolicyOperationKind = PolicyCrudKind | 'postUpdate'; /** * Current login user info @@ -56,6 +59,21 @@ export type QueryContext = { preValue?: any; }; +/** + * Context for checking operation allowability. + */ +export type CheckerContext = { + /** + * Current user + */ + user?: AuthUser; + + /** + * Extra field value filters. + */ + fieldValues?: Record; +}; + /** * Prisma contract for CRUD operations. */ diff --git a/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts new file mode 100644 index 000000000..fe6c415cc --- /dev/null +++ b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts @@ -0,0 +1,55 @@ +import { getDataModels } from '@zenstackhq/sdk'; +import type { DataModel, DataModelField, Model } from '@zenstackhq/sdk/ast'; +import { lowerCaseFirst } from 'lower-case-first'; +import { P, match } from 'ts-pattern'; + +/** + * Generates a `ModelCheckers` interface that contains a `check` method for each model in the schema. + * + * E.g.: + * + * ```ts + * type CheckerOperation = 'create' | 'read' | 'update' | 'delete'; + * + * export interface ModelCheckers { + * user: { check(op: CheckerOperation, args?: { email?: string; age?: number; }): Promise }, + * ... + * } + * ``` + */ +export function generateCheckerType(model: Model) { + return ` +import type { PolicyCrudKind } from '@zenstackhq/runtime'; + +export interface ModelCheckers { + ${getDataModels(model) + .map((dataModel) => `\t${lowerCaseFirst(dataModel.name)}: ${generateDataModelChecker(dataModel)}`) + .join(',\n')} +} +`; +} + +function generateDataModelChecker(dataModel: DataModel) { + return `{ + check(args: { operation: PolicyCrudKind, filter?: ${generateDataModelArgs(dataModel)} }): Promise + }`; +} + +function generateDataModelArgs(dataModel: DataModel) { + return `{ ${dataModel.fields + .filter((field) => isFieldFilterable(field)) + .map((field) => `${field.name}?: ${mapFieldType(field)}`) + .join('; ')} }`; +} + +function isFieldFilterable(field: DataModelField) { + return !!mapFieldType(field); +} + +function mapFieldType(field: DataModelField) { + return match(field.type.type) + .with('Boolean', () => 'boolean') + .with(P.union('BigInt', 'Int', 'Float', 'Decimal'), () => 'number') + .with('String', () => 'string') + .otherwise(() => undefined); +} diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 0425b76d0..62b1d03e6 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -41,6 +41,7 @@ import { trackPrismaSchemaError } from '../../prisma'; import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; import { isDefaultWithAuth } from '../enhancer-utils'; import { generateAuthType } from './auth-type-generator'; +import { generateCheckerType } from './checker-type-generator'; // information of delegate models and their sub models type DelegateInfo = [DataModel, DataModel[]][]; @@ -89,6 +90,8 @@ export class EnhancerGenerator { const authTypes = authModel ? generateAuthType(this.model, authModel) : ''; const authTypeParam = authModel ? `auth.${authModel.name}` : 'AuthUser'; + const checkerTypes = this.generatePermissionChecker ? generateCheckerType(this.model) : ''; + const enhanceTs = this.project.createSourceFile( path.join(this.outDir, 'enhance.ts'), `import { type EnhancementContext, type EnhancementOptions, type ZodSchemas, type AuthUser } from '@zenstackhq/runtime'; @@ -105,6 +108,8 @@ ${ ${authTypes} +${checkerTypes} + ${ logicalPrismaClientDir ? this.createLogicalPrismaEnhanceFunction(authTypeParam) @@ -126,15 +131,16 @@ import type * as _P from '${prismaImport}'; } private createSimplePrismaEnhanceFunction(authTypeParam: string) { + const returnType = `DbClient${this.generatePermissionChecker ? ' & ModelCheckers' : ''}`; return ` -export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions) { +export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): ${returnType} { return createEnhancement(prisma, { modelMeta, policy, zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), prismaModule: Prisma, ...options - }, context); + }, context) as ${returnType}; } `; } @@ -157,12 +163,16 @@ import type { Prisma, PrismaClient } from '${logicalPrismaClientDir}/index-fixed // overload for plain PrismaClient export function enhance & InternalArgs>( prisma: _PrismaClient, - context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient; + context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient${ + this.generatePermissionChecker ? ' & ModelCheckers' : '' + }; // overload for extended PrismaClient export function enhance & InternalArgs>( prisma: DynamicClientExtensionThis, - context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis; + context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis${ + this.generatePermissionChecker ? ' & ModelCheckers' : '' + }; export function enhance(prisma: any, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): any { return createEnhancement(prisma, { @@ -622,4 +632,8 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara await sf.save(); } } + + private get generatePermissionChecker() { + return this.options.generatePermissionChecker === true; + } } diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts new file mode 100644 index 000000000..a0b1c1dd2 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -0,0 +1,379 @@ +import { + getRelationKeyPairs, + isAuthInvocation, + isDataModelFieldReference, + isEnumFieldReference, +} from '@zenstackhq/sdk'; +import { + BinaryExpr, + BooleanLiteral, + DataModelField, + Expression, + ExpressionType, + LiteralExpr, + MemberAccessExpr, + NumberLiteral, + ReferenceExpr, + StringLiteral, + UnaryExpr, + isBinaryExpr, + isDataModelField, + isEnum, + isLiteralExpr, + isMemberAccessExpr, + isNullExpr, + isReferenceExpr, + isThisExpr, + isUnaryExpr, +} from '@zenstackhq/sdk/ast'; +import { P, match } from 'ts-pattern'; + +/** + * Options for {@link ConstraintTransformer}. + */ +export type ConstraintTransformerOptions = { + authAccessor: string; +}; + +/** + * Transform a set of allow and deny rules into a single constraint expression. + */ +export class ConstraintTransformer { + // a counter for generating unique variable names + private varCounter = 0; + + constructor(private readonly options: ConstraintTransformerOptions) {} + + /** + * Transforms a set of allow and deny rules into a single constraint expression. + */ + transformRules(allows: Expression[], denies: Expression[]): string { + // reset state + this.varCounter = 0; + + if (allows.length === 0) { + // unconditionally deny + return this.value('false', 'boolean'); + } + + let result: string; + + // transform allow rules + const allowConstraints = allows.map((allow) => this.transformExpression(allow)); + if (allowConstraints.length > 1) { + result = this.or(...allowConstraints); + } else { + result = allowConstraints[0]; + } + + // transform deny rules and compose + if (denies.length > 0) { + const denyConstraints = denies.map((deny) => this.transformExpression(deny)); + result = this.and(result, ...denyConstraints.map((c) => this.not(c))); + } + + // DEBUG: + // console.log(`Constraint transformation result:\n${JSON.stringify(result, null, 2)}`); + + return result; + } + + private and(...constraints: string[]) { + if (constraints.length === 0) { + throw new Error('No expressions to combine'); + } + return constraints.length === 1 ? constraints[0] : `{ kind: 'and', children: [ ${constraints.join(', ')} ] }`; + } + + private or(...constraints: string[]) { + if (constraints.length === 0) { + throw new Error('No expressions to combine'); + } + return constraints.length === 1 ? constraints[0] : `{ kind: 'or', children: [ ${constraints.join(', ')} ] }`; + } + + private not(constraint: string) { + return `{ kind: 'not', children: [${constraint}] }`; + } + + private transformExpression(expression: Expression) { + return ( + match(expression) + .when(isBinaryExpr, (expr) => this.transformBinary(expr)) + .when(isUnaryExpr, (expr) => this.transformUnary(expr)) + // top-level boolean literal + .when(isLiteralExpr, (expr) => this.transformLiteral(expr)) + // top-level boolean reference expr + .when(isReferenceExpr, (expr) => this.transformReference(expr)) + // top-level boolean member access expr + .when(isMemberAccessExpr, (expr) => this.transformMemberAccess(expr)) + .otherwise(() => this.nextVar()) + ); + } + + private transformLiteral(expr: LiteralExpr) { + return match(expr.$type) + .with(NumberLiteral, () => { + const parsed = parseFloat(expr.value as string); + if (isNaN(parsed) || parsed < 0 || !Number.isInteger(parsed)) { + // only non-negative integers are supported, for other cases, + // transform into a free variable + return this.nextVar('number'); + } + return this.value(expr.value.toString(), 'number'); + }) + .with(StringLiteral, () => this.value(`'${expr.value}'`, 'string')) + .with(BooleanLiteral, () => this.value(expr.value.toString(), 'boolean')) + .exhaustive(); + } + + private transformReference(expr: ReferenceExpr) { + // top-level reference is transformed into a named variable + return this.variable(expr.target.$refText, 'boolean'); + } + + private transformMemberAccess(expr: MemberAccessExpr) { + // "this.x" is transformed into a named variable + if (isThisExpr(expr.operand)) { + return this.variable(expr.member.$refText, 'boolean'); + } + + // top-level auth access + const authAccess = this.getAuthAccess(expr); + if (authAccess) { + return this.value(`${authAccess} ?? false`, 'boolean'); + } + + // other top-level member access expressions are not supported + // and thus transformed into a free variable + return this.nextVar(); + } + + private transformBinary(expr: BinaryExpr): string { + return ( + match(expr.operator) + .with('&&', () => this.and(this.transformExpression(expr.left), this.transformExpression(expr.right))) + .with('||', () => this.or(this.transformExpression(expr.left), this.transformExpression(expr.right))) + .with(P.union('==', '!=', '<', '<=', '>', '>='), () => this.transformComparison(expr)) + // unsupported operators (e.g., collection predicate) are transformed into a free variable + .otherwise(() => this.nextVar()) + ); + } + + private transformUnary(expr: UnaryExpr): string { + return match(expr.operator) + .with('!', () => this.not(this.transformExpression(expr.operand))) + .exhaustive(); + } + + private transformComparison(expr: BinaryExpr) { + if (isAuthInvocation(expr.left) || isAuthInvocation(expr.right)) { + // handle the case if any operand is `auth()` invocation + const authComparison = this.transformAuthComparison(expr); + return authComparison ?? this.nextVar(); + } + + const leftOperand = this.getComparisonOperand(expr.left); + const rightOperand = this.getComparisonOperand(expr.right); + + const op = this.mapOperatorToConstraintKind(expr.operator); + const result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`; + + // `auth()` member access can be undefined, when that happens, we assume a false condition + // for the comparison + + const leftAuthAccess = this.getAuthAccess(expr.left); + const rightAuthAccess = this.getAuthAccess(expr.right); + + if (leftAuthAccess && rightOperand) { + // `auth().f op x` => `auth().f !== undefined && auth().f op x` + return this.and(this.value(`${this.normalizeToNull(leftAuthAccess)} !== null`, 'boolean'), result); + } else if (rightAuthAccess && leftOperand) { + // `x op auth().f` => `auth().f !== undefined && x op auth().f` + return this.and(this.value(`${this.normalizeToNull(rightAuthAccess)} !== null`, 'boolean'), result); + } + + if (leftOperand === undefined || rightOperand === undefined) { + // if either operand is not supported, transform into a free variable + return this.nextVar(); + } + + return result; + } + + private transformAuthComparison(expr: BinaryExpr) { + if (this.isAuthEqualNull(expr)) { + // `auth() == null` => `user === null` + return this.value(`${this.options.authAccessor} === null`, 'boolean'); + } + + if (this.isAuthNotEqualNull(expr)) { + // `auth() != null` => `user !== null` + return this.value(`${this.options.authAccessor} !== null`, 'boolean'); + } + + // auth() equality check against a relation, translate to id-fk comparison + const operand = isAuthInvocation(expr.left) ? expr.right : expr.left; + if (!isDataModelFieldReference(operand)) { + return undefined; + } + + // get id-fk field pairs from the relation field + const relationField = operand.target.ref as DataModelField; + const idFkPairs = getRelationKeyPairs(relationField); + + // build id-fk field comparison constraints + const fieldConstraints: string[] = []; + + idFkPairs.forEach(({ id, foreignKey }) => { + const idFieldType = this.mapType(id.type.type as ExpressionType); + if (!idFieldType) { + return; + } + const fkFieldType = this.mapType(foreignKey.type.type as ExpressionType); + if (!fkFieldType) { + return; + } + + const op = this.mapOperatorToConstraintKind(expr.operator); + const authIdAccess = `${this.options.authAccessor}?.${id.name}`; + + fieldConstraints.push( + this.and( + // `auth()?.id != null` guard + this.value(`${this.normalizeToNull(authIdAccess)} !== null`, 'boolean'), + // `auth()?.id [op] fkField` + `{ kind: '${op}', left: ${this.value(authIdAccess, idFieldType)}, right: ${this.variable( + foreignKey.name, + fkFieldType + )} }` + ) + ); + }); + + // combine field constraints + if (fieldConstraints.length > 0) { + return this.and(...fieldConstraints); + } + + return undefined; + } + + // normalize `auth()` access undefined value to null + private normalizeToNull(expr: string) { + return `(${expr} ?? null)`; + } + + private isAuthEqualNull(expr: BinaryExpr) { + return ( + expr.operator === '==' && + ((isAuthInvocation(expr.left) && isNullExpr(expr.right)) || + (isAuthInvocation(expr.right) && isNullExpr(expr.left))) + ); + } + + private isAuthNotEqualNull(expr: BinaryExpr) { + return ( + expr.operator === '!=' && + ((isAuthInvocation(expr.left) && isNullExpr(expr.right)) || + (isAuthInvocation(expr.right) && isNullExpr(expr.left))) + ); + } + + private getComparisonOperand(expr: Expression) { + if (isLiteralExpr(expr)) { + return this.transformLiteral(expr); + } + + if (isEnumFieldReference(expr)) { + return this.value(`'${expr.target.$refText}'`, 'string'); + } + + const fieldAccess = this.getFieldAccess(expr); + if (fieldAccess) { + // model field access is transformed into a named variable + const mappedType = this.mapExpressionType(expr); + if (mappedType) { + return this.variable(fieldAccess.name, mappedType); + } else { + return undefined; + } + } + + const authAccess = this.getAuthAccess(expr); + if (authAccess) { + const mappedType = this.mapExpressionType(expr); + if (mappedType) { + return `${this.value(authAccess, mappedType)}`; + } else { + return undefined; + } + } + + return undefined; + } + + private mapExpressionType(expression: Expression) { + if (isEnum(expression.$resolvedType?.decl)) { + return 'string'; + } else { + return this.mapType(expression.$resolvedType?.decl as ExpressionType); + } + } + + private mapType(type: ExpressionType) { + return match(type) + .with('Boolean', () => 'boolean') + .with('Int', () => 'number') + .with('String', () => 'string') + .otherwise(() => undefined); + } + + private mapOperatorToConstraintKind(operator: BinaryExpr['operator']) { + return match(operator) + .with('==', () => 'eq') + .with('!=', () => 'ne') + .with('<', () => 'lt') + .with('<=', () => 'lte') + .with('>', () => 'gt') + .with('>=', () => 'gte') + .otherwise(() => { + throw new Error(`Unsupported operator: ${operator}`); + }); + } + + private getFieldAccess(expr: Expression) { + if (isReferenceExpr(expr)) { + return isDataModelField(expr.target.ref) ? { name: expr.target.$refText } : undefined; + } + if (isMemberAccessExpr(expr)) { + return isThisExpr(expr.operand) ? { name: expr.member.$refText } : undefined; + } + return undefined; + } + + private getAuthAccess(expr: Expression): string | undefined { + if (!isMemberAccessExpr(expr)) { + return undefined; + } + + if (isAuthInvocation(expr.operand)) { + return `${this.options.authAccessor}?.${expr.member.$refText}`; + } else { + const operand = this.getAuthAccess(expr.operand); + return operand ? `${operand}?.${expr.member.$refText}` : undefined; + } + } + + private nextVar(type = 'boolean') { + return this.variable(`__var${this.varCounter++}`, type); + } + + private variable(name: string, type: string) { + return `{ kind: 'variable', name: '${name}', type: '${type}' }`; + } + + private value(value: string, type: string) { + return `{ kind: 'value', value: ${value}, type: '${type}' }`; + } +} diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 753ef8f19..a36a52126 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -54,6 +54,7 @@ import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, Writ import { name } from '..'; import { isCollectionPredicate } from '../../../utils/ast-utils'; import { ALL_OPERATION_KINDS } from '../../plugin-utils'; +import { ConstraintTransformer } from './constraint-transformer'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; /** @@ -70,6 +71,8 @@ export class PolicyGenerator { { name: 'type CrudContract' }, { name: 'allFieldsEqual' }, { name: 'type PolicyDef' }, + { name: 'type CheckerContext' }, + { name: 'type CheckerConstraint' }, ], moduleSpecifier: `${RUNTIME_PACKAGE}`, }); @@ -85,11 +88,22 @@ export class PolicyGenerator { const models = getDataModels(model); + // policy guard functions const policyMap: Record> = {}; for (const model of models) { policyMap[model.name] = await this.generateQueryGuardForModel(model, sf); } + const generatePermissionChecker = options.generatePermissionChecker === true; + + // CRUD checker functions + const checkerMap: Record> = {}; + if (generatePermissionChecker) { + for (const model of models) { + checkerMap[model.name] = await this.generateCheckerForModel(model, sf); + } + } + const authSelector = this.generateAuthSelector(models); sf.addVariableStatement({ @@ -118,6 +132,22 @@ export class PolicyGenerator { }); writer.writeLine(','); + if (generatePermissionChecker) { + writer.write('checker:'); + writer.inlineBlock(() => { + for (const [model, map] of Object.entries(checkerMap)) { + writer.write(`${lowerCaseFirst(model)}:`); + writer.inlineBlock(() => { + Object.entries(map).forEach(([op, func]) => { + writer.write(`${op}: ${func},`); + }); + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); + } + writer.write('validation:'); writer.inlineBlock(() => { for (const model of models) { @@ -301,7 +331,6 @@ export class PolicyGenerator { } const guardFunc = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion result[kind] = guardFunc.getName()!; if (kind === 'postUpdate') { @@ -313,7 +342,6 @@ export class PolicyGenerator { if (kind === 'create' && this.canCheckCreateBasedOnInput(model, allows, denies)) { const inputCheckFunc = this.generateInputCheckFunction(sourceFile, model, kind, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion result[kind + '_input'] = inputCheckFunc.getName()!; } } @@ -847,4 +875,70 @@ export class PolicyGenerator { statements.push(`const user: any = context.user ?? null;`); } } + + private async generateCheckerForModel(model: DataModel, sourceFile: SourceFile) { + const result: Record = {}; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const policies = analyzePolicies(model); + + for (const kind of ['create', 'read', 'update', 'delete'] as const) { + if (policies[kind] === true || policies[kind] === false) { + result[kind] = policies[kind] as boolean; + continue; + } + + const denies = this.getPolicyExpressions(model, 'deny', kind); + const allows = this.getPolicyExpressions(model, 'allow', kind); + + if (kind === 'update' && allows.length === 0) { + // no allow rule for 'update', policy is constant based on if there's + // post-update counterpart + if (this.getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { + result[kind] = false; + continue; + } else { + result[kind] = true; + continue; + } + } + + const guardFunc = this.generateCheckerFunction(sourceFile, model, kind, allows, denies); + result[kind] = guardFunc.getName()!; + } + + return result; + } + + private generateCheckerFunction( + sourceFile: SourceFile, + model: DataModel, + kind: string, + allows: Expression[], + denies: Expression[] + ) { + const statements: string[] = []; + + this.generateNormalizedAuthRef(model, allows, denies, statements); + + const transformed = new ConstraintTransformer({ + authAccessor: 'user', + }).transformRules(allows, denies); + + statements.push(`return ${transformed};`); + + const func = sourceFile.addFunction({ + name: `${model.name}$checker$${kind}`, + returnType: 'CheckerConstraint', + parameters: [ + { + name: 'context', + type: 'CheckerContext', + }, + ], + statements, + }); + + return func; + } } diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 6617983aa..1bed512c8 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -298,9 +298,9 @@ export function isForeignKeyField(field: DataModelField) { } /** - * Gets the foreign key fields of the given relation field. + * Gets the foreign key-id field pairs from the given relation field. */ -export function getForeignKeyFields(relationField: DataModelField) { +export function getRelationKeyPairs(relationField: DataModelField) { if (!isRelationshipField(relationField)) { return []; } @@ -309,11 +309,31 @@ export function getForeignKeyFields(relationField: DataModelField) { if (relAttr) { // find "fields" arg const fieldsArg = getAttributeArg(relAttr, 'fields'); + let fkFields: DataModelField[]; if (fieldsArg && isArrayExpr(fieldsArg)) { - return fieldsArg.items + fkFields = fieldsArg.items .filter((item): item is ReferenceExpr => isReferenceExpr(item)) .map((item) => item.target.ref as DataModelField); + } else { + return []; } + + // find "references" arg + const referencesArg = getAttributeArg(relAttr, 'references'); + let idFields: DataModelField[]; + if (referencesArg && isArrayExpr(referencesArg)) { + idFields = referencesArg.items + .filter((item): item is ReferenceExpr => isReferenceExpr(item)) + .map((item) => item.target.ref as DataModelField); + } else { + return []; + } + + if (idFields.length !== fkFields.length) { + throw new Error(`Relation's references arg and fields are must have equal length`); + } + + return idFields.map((idField, i) => ({ id: idField, foreignKey: fkFields[i] })); } return []; diff --git a/packages/server/src/api/rpc/index.ts b/packages/server/src/api/rpc/index.ts index a7fb44d72..a8882b8be 100644 --- a/packages/server/src/api/rpc/index.ts +++ b/packages/server/src/api/rpc/index.ts @@ -81,6 +81,7 @@ class RequestHandler extends APIHandlerBase { case 'aggregate': case 'groupBy': case 'count': + case 'check': if (method !== 'GET') { return { status: 400, diff --git a/packages/server/tests/api/rpc.test.ts b/packages/server/tests/api/rpc.test.ts index 432abec2c..b24a7c108 100644 --- a/packages/server/tests/api/rpc.test.ts +++ b/packages/server/tests/api/rpc.test.ts @@ -15,7 +15,7 @@ describe('RPC API Handler Tests', () => { let zodSchemas: any; beforeAll(async () => { - const params = await loadSchema(schema, { fullZod: true }); + const params = await loadSchema(schema, { fullZod: true, generatePermissionChecker: true }); prisma = params.prisma; enhance = params.enhance; modelMeta = params.modelMeta; @@ -131,6 +131,37 @@ describe('RPC API Handler Tests', () => { expect(r.data.count).toBe(1); }); + it('check', async () => { + const handleRequest = makeHandler(); + + let r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read' }) }, + prisma: enhance(), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(true); + + r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read', filter: { published: false } }) }, + prisma: enhance(), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(false); + + r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read', filter: { authorId: '1', published: false } }) }, + prisma: enhance({ id: '1' }), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(true); + }); + it('policy violation', async () => { await prisma.user.create({ data: { diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 7249e6c4a..c2109e579 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -102,6 +102,7 @@ generator js { plugin enhancer { provider = '@core/enhancer' ${options.preserveTsFiles ? 'preserveTsFiles = true' : ''} + ${options.generatePermissionChecker ? 'generatePermissionChecker = true' : ''} } plugin zod { @@ -131,6 +132,7 @@ export type SchemaLoadOptions = { extraSourceFiles?: { name: string; content: string }[]; projectDir?: string; preserveTsFiles?: boolean; + generatePermissionChecker?: boolean; }; const defaultOptions: SchemaLoadOptions = { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 5ce4ac030..12e3f6606 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -409,6 +409,9 @@ importers: deepmerge: specifier: ^4.3.1 version: 4.3.1 + logic-solver: + specifier: ^2.0.1 + version: 2.0.1 lower-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -427,6 +430,9 @@ importers: tiny-invariant: specifier: ^1.3.1 version: 1.3.1 + ts-pattern: + specifier: ^4.3.0 + version: 4.3.0 tslib: specifier: ^2.4.1 version: 2.4.1 @@ -10495,6 +10501,12 @@ packages: wrap-ansi: 8.1.0 dev: false + /logic-solver@2.0.1: + resolution: {integrity: sha512-F1oCywXUzvAF4Z98mMyXySUCpUU3hNyc+JfYV3g2x/4BupC/xv94iPJuHh9us2XX5UrvY5lnKUXNvjcJNQBJ/g==} + dependencies: + underscore: 1.13.6 + dev: false + /loose-envify@1.4.0: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true @@ -14372,7 +14384,6 @@ packages: /underscore@1.13.6: resolution: {integrity: sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A==} - dev: true /undici-types@5.26.5: resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts new file mode 100644 index 000000000..4a2c0193e --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -0,0 +1,652 @@ +import { SchemaLoadOptions, createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools'; + +describe('Permission checker', () => { + const PRELUDE = ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + generator js { + provider = 'prisma-client-js' + } + + plugin enhancer { + provider = '@core/enhancer' + generatePermissionChecker = true + } + `; + + const load = (schema: string, options?: SchemaLoadOptions) => + loadSchema(schema, { + ...options, + generatePermissionChecker: true, + }); + + it('checker generation not enabled', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', true) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).rejects.toThrow('Generated permission checkers not found'); + }); + + it('empty rules', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveFalsy(); + }); + + it('unconditional allow', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', true) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 0 } })).toResolveTruthy(); + }); + + it('multiple allow rules', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', value == 1) + @@allow('all', value == 2) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 2 } })).toResolveTruthy(); + }); + + it('deny rule', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', value > 0) + @@deny('all', value == 1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 2 } })).toResolveTruthy(); + }); + + it('int field condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value == 1) + @@allow('create', value != 1) + @@allow('update', value > 1) + @@allow('delete', value <= 1) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { value: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { value: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { value: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'update', filter: { value: 2 } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', filter: { value: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', filter: { value: 2 } })).toResolveFalsy(); + }); + + it('boolean field toplevel condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Boolean + @@allow('read', value) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: true } })).toResolveTruthy(); + }); + + it('boolean field condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Boolean + @@allow('read', value == true) + @@allow('create', value == false) + @@allow('update', value != true) + @@allow('delete', value != false) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: true } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { value: true } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'create', filter: { value: false } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { value: true } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'update', filter: { value: false } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', filter: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'delete', filter: { value: true } })).toResolveTruthy(); + }); + + it('string field condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value String + @@allow('read', value == 'admin') + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 'user' } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 'admin' } })).toResolveTruthy(); + }); + + it('enum', async () => { + const dbUrl = await createPostgresDb('permission-checker-enum'); + let prisma: any; + try { + const r = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = '${dbUrl}' + } + + generator js { + provider = 'prisma-client-js' + } + + plugin enhancer { + provider = '@core/enhancer' + generatePermissionChecker = true + } + + enum Role { + USER + ADMIN + } + model User { + id Int @id @default(autoincrement()) + role Role + } + model Model { + id Int @id @default(autoincrement()) + @@allow('read', auth().role == ADMIN) + } + `, + { addPrelude: false, generatePermissionChecker: true } + ); + + prisma = r.prisma; + const enhance = r.enhance; + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, role: 'USER' }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, role: 'ADMIN' }).model.check({ operation: 'read' })).toResolveTruthy(); + } finally { + await prisma.$disconnect(); + await dropPostgresDb('permission-checker-enum'); + } + }); + + it('function noop', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value String + @@allow('read', startsWith(value, 'admin')) + @@allow('update', !startsWith(value, 'admin')) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 'user' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 'admin' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { value: 'user' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { value: 'admin' } })).toResolveTruthy(); + }); + + it('relation noop', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value String + foo Foo? + + @@allow('read', foo.x > 0) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + modelId Int @unique + model Model @relation(fields: [modelId], references: [id]) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { foo: { x: 0 } } })).rejects.toThrow( + 'Providing filter for field "foo"' + ); + }); + + it('collection predicate noop', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value String + foo Foo[] + + @@allow('read', foo?[x > 0]) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + modelId Int + model Model @relation(fields: [modelId], references: [id]) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { foo: [{ x: 0 }] } })).rejects.toThrow( + 'Providing filter for field "foo"' + ); + }); + + it('field complex condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + x Int + y Int + @@allow('read', x > 0 && x > y) + @@allow('create', x > 1 || x > y) + @@allow('update', !(x >= y)) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { x: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1, y: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1, y: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { x: 0 } })).toResolveFalsy(); // numbers are non-negative + await expect(db.model.check({ operation: 'create', filter: { x: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { x: 1, y: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { x: 1, y: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { x: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { y: 0 } })).toResolveFalsy(); // numbers are non-negative + await expect(db.model.check({ operation: 'update', filter: { x: 1, y: 1 } })).toResolveFalsy(); + }); + + it('field condition unsolvable', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + x Int + y Int + @@allow('read', x > 0 && x < y && y <= 1) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1, y: 2 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1, y: 1 } })).toResolveFalsy(); + }); + + it('simple auth condition', async () => { + const { enhance } = await load( + ` + model User { + id Int @id @default(autoincrement()) + level Int + admin Boolean + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth().level > 0) + @@allow('create', auth().admin) + @@allow('update', !auth().admin) + } + ` + ); + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 0 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: false }).model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: true }).model.check({ operation: 'create' })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1, admin: true }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: false }).model.check({ operation: 'update' })).toResolveTruthy(); + }); + + it('auth compared with relation field', async () => { + const { enhance } = await load( + ` + model User { + id Int @id @default(autoincrement()) + models Model[] + } + + model Model { + id Int @id @default(autoincrement()) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + @@allow('read', auth().id == ownerId) + @@allow('create', auth().id != ownerId) + @@allow('update', auth() == owner) + @@allow('delete', auth() != owner) + } + `, + { preserveTsFiles: true } + ); + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read', filter: { ownerId: 1 } })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read', filter: { ownerId: 2 } })).toResolveFalsy(); + + await expect(enhance().model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create', filter: { ownerId: 1 } })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create', filter: { ownerId: 2 } })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update', filter: { ownerId: 1 } })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update', filter: { ownerId: 2 } })).toResolveFalsy(); + + await expect(enhance().model.check({ operation: 'delete' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete', filter: { ownerId: 1 } })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete', filter: { ownerId: 2 } })).toResolveTruthy(); + }); + + it('auth null check', async () => { + const { enhance } = await load( + ` + model User { + id Int @id @default(autoincrement()) + level Int + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth() != null) + @@allow('create', auth() == null) + @@allow('update', auth().level > 0) + } + ` + ); + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'create' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveFalsy(); + + await expect(enhance().model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 0 }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); + }); + + it('auth with relation access', async () => { + const { enhance } = await load( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + } + + model Profile { + id Int @id @default(autoincrement()) + level Int + user User @relation(fields: [userId], references: [id]) + userId Int @unique + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth().profile.level > 0) + } + ` + ); + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, profile: { level: 0 } }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, profile: { level: 1 } }).model.check({ operation: 'read' })).toResolveTruthy(); + }); + + it('nullable field', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int? + @@allow('read', value != null) + @@allow('create', value == null) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { value: 1 } })).toResolveTruthy(); + }); + + it('compilation', async () => { + await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value == 1) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { PrismaClient } from '@prisma/client'; + import { enhance } from '.zenstack/enhance'; + + const prisma = new PrismaClient(); + const db = enhance(prisma); + db.model.check({ operation: 'read' }); + db.model.check({ operation: 'read', filter: { value: 1 }}); + `, + }, + ], + } + ); + }); + + it('invalid filter', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + foo Foo? + d DateTime + + @@allow('read', value == 1) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + model Model @relation(fields: [modelId], references: [id]) + modelId Int @unique + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read', filter: { foo: { x: 1 } } })).rejects.toThrow( + `Providing filter for field "foo" is not supported. Only scalar fields are allowed.` + ); + await expect(db.model.check({ operation: 'read', filter: { d: new Date() } })).rejects.toThrow( + `Providing filter for field "d" is not supported. Only number, string, and boolean fields are allowed.` + ); + await expect(db.model.check({ operation: 'read', filter: { value: null } })).rejects.toThrow( + `Using "null" as filter value is not supported yet` + ); + await expect(db.model.check({ operation: 'read', filter: { value: {} } })).rejects.toThrow( + 'Invalid value type for field "value". Only number, string or boolean is allowed.' + ); + await expect(db.model.check({ operation: 'read', filter: { value: 'abc' } })).rejects.toThrow( + 'Invalid value type for field "value". Expected "number"' + ); + await expect(db.model.check({ operation: 'read', filter: { value: -1 } })).rejects.toThrow( + 'Invalid value for field "value". Only non-negative integers are allowed.' + ); + }); + + it('float field ignored', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Float + @@allow('read', value == 1.1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + }); + + it('float value ignored', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value > 1.1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 2 } })).toResolveTruthy(); + }); + + it('negative value ignored', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value >-1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 2 } })).toResolveTruthy(); + }); +}); diff --git a/tests/regression/tests/issue-961.test.ts b/tests/regression/tests/issue-961.test.ts index 7bc42071b..f6dc3a135 100644 --- a/tests/regression/tests/issue-961.test.ts +++ b/tests/regression/tests/issue-961.test.ts @@ -123,10 +123,8 @@ describe('Regression: issue 961', () => { await expect(db.userColumn.findMany()).resolves.toHaveLength(1); }); - // disabled because of Prisma V4 bug: https://github.com/prisma/prisma/issues/18371 - // eslint-disable-next-line jest/no-disabled-tests - it.skip('updateMany', async () => { - const { prisma, enhance } = await loadSchema(schema, { logPrismaQuery: true }); + it('updateMany', async () => { + const { prisma, enhance } = await loadSchema(schema); const user = await prisma.user.create({ data: {