From 21a0a2d3ec2a366a4d4c4f1c3938edb7fb8067fa Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 29 Apr 2024 09:46:01 +0800 Subject: [PATCH 01/13] WIP: permission checker --- packages/runtime/package.json | 6 + .../enhancements/policy/constraint-solver.ts | 277 +++++++++++++++++ .../src/enhancements/policy/handler.ts | 38 ++- .../src/enhancements/policy/logic-solver.d.ts | 45 +++ .../src/enhancements/policy/policy-utils.ts | 40 ++- packages/runtime/src/enhancements/types.ts | 22 +- packages/runtime/src/types.ts | 10 +- .../enhancer/policy/constraint-transformer.ts | 290 ++++++++++++++++++ .../enhancer/policy/policy-guard-generator.ts | 83 ++++- pnpm-lock.yaml | 13 +- .../enhancements/with-policy/checker.test.ts | 20 ++ 11 files changed, 834 insertions(+), 10 deletions(-) create mode 100644 packages/runtime/src/enhancements/policy/constraint-solver.ts create mode 100644 packages/runtime/src/enhancements/policy/logic-solver.d.ts create mode 100644 packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts create mode 100644 tests/integration/tests/enhancements/with-policy/checker.test.ts diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 8f418ab7e..753c61f8e 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -29,6 +29,10 @@ "types": "./enhancements/index.d.ts", "default": "./enhancements/index.js" }, + "./constraint-solver": { + "types": "./constraint-solver.d.ts", + "default": "./constraint-solver.js" + }, "./zod": { "types": "./zod/index.d.ts", "default": "./zod/index.js" @@ -79,12 +83,14 @@ "decimal.js": "^10.4.2", "deepcopy": "^2.1.0", "deepmerge": "^4.3.1", + "logic-solver": "^2.0.1", "lower-case-first": "^2.0.2", "pluralize": "^8.0.0", "safe-json-stringify": "^1.2.0", "semver": "^7.5.2", "superjson": "^1.11.0", "tiny-invariant": "^1.3.1", + "ts-pattern": "^4.3.0", "tslib": "^2.4.1", "upper-case-first": "^2.0.2", "uuid": "^9.0.0", diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts new file mode 100644 index 000000000..87c5ee0ec --- /dev/null +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -0,0 +1,277 @@ +import Logic, { Formula } from 'logic-solver'; +import { match } from 'ts-pattern'; +import type { CheckerConstraint, ComparisonTerm, ConstraintVariable } from '../types'; + +const MAGIC_NULL = 0x7fffffff; + +export class ConstraintSolver { + private stringTable: string[] = []; + private variables: Map = new Map(); + + solve(constraint: CheckerConstraint): boolean { + this.stringTable = []; + this.variables = new Map(); + + const formula = this.buildFormula(constraint); + const solver = new Logic.Solver(); + solver.require(formula); + const solution = solver.solve(); + if (solution) { + console.log('Solution:'); + this.variables.forEach((v, k) => console.log(`\t${k}=${solution?.evaluate(v)}`)); + } else { + console.log('No solution'); + } + return !!solution; + } + + private buildFormula(constraint: CheckerConstraint): Logic.Formula { + if ('value' in constraint) { + if (constraint.value === null) { + return Logic.constantBits(MAGIC_NULL); + } + + if (typeof constraint.value === 'boolean') { + return constraint.value === true ? Logic.TRUE : Logic.FALSE; + } + + if (typeof constraint.value === 'number') { + return Logic.constantBits(constraint.value); + } + + if (typeof constraint.value === 'string') { + const index = this.stringTable.indexOf(constraint.value); + if (index === -1) { + this.stringTable.push(constraint.value); + return Logic.constantBits(this.stringTable.length - 1); + } else { + return Logic.constantBits(index); + } + } + } + + if ('name' in constraint) { + // variable + return match(constraint.type) + .with('boolean', () => this.booleanVariable(constraint)) + .with('number', () => this.intVariable(constraint.name)) + .with('string', () => this.intVariable(constraint.name)) + .exhaustive(); + } + + if ('eq' in constraint) { + return this.transformEquality(constraint.eq.left, constraint.eq.right); + } + + if ('gt' in constraint) { + return this.transformComparison(constraint.gt.left, constraint.gt.right, (l, r) => Logic.greaterThan(l, r)); + } + + if ('gte' in constraint) { + return this.transformComparison(constraint.gte.left, constraint.gte.right, (l, r) => + Logic.greaterThanOrEqual(l, r) + ); + } + + if ('lt' in constraint) { + return this.transformComparison(constraint.lt.left, constraint.lt.right, (l, r) => Logic.lessThan(l, r)); + } + + if ('lte' in constraint) { + return this.transformComparison(constraint.lte.left, constraint.lte.right, (l, r) => + Logic.greaterThan(l, r) + ); + } + + if ('and' in constraint) { + return Logic.and(...constraint.and.map((c) => this.buildFormula(c))); + } + + if ('or' in constraint) { + return Logic.or(...constraint.or.map((c) => this.buildFormula(c))); + } + + if ('not' in constraint) { + return Logic.not(this.buildFormula(constraint.not)); + } + + throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`); + } + + private booleanVariable(constraint: ConstraintVariable): string { + this.variables.set(constraint.name, constraint.name); + return constraint.name; + } + + private intVariable(name: string): string { + const r = Logic.variableBits(name, 32); + this.variables.set(name, r); + return r; + } + + private transformEquality(left: ComparisonTerm, right: ComparisonTerm) { + if (left.type !== right.type) { + throw new Error(`Type mismatch in equality constraint: ${JSON.stringify(left)}, ${JSON.stringify(right)}`); + } + const leftConstraint = this.buildFormula(left); + const rightConstraint = this.buildFormula(right); + if (left.type === 'boolean' && right.type === 'boolean') { + return Logic.equiv(leftConstraint, rightConstraint); + } else { + return Logic.equalBits(leftConstraint, rightConstraint); + } + } + + private transformComparison( + left: ComparisonTerm, + right: ComparisonTerm, + func: (left: Logic.Formula, right: Logic.Formula) => Logic.Formula + ) { + const leftConstraint = this.buildFormula(left); + const rightConstraint = this.buildFormula(right); + return func(leftConstraint, rightConstraint); + } +} + +// export function solve(constraint: CheckerConstraint) { +// const stringTable: string[] = []; +// const formula = buildFormula(constraint, stringTable); +// const solver = new Logic.Solver(); +// solver.require(formula); +// const solution = solver.solve(); +// console.log('Solution:', solution?.getMap()); +// return !!solution; +// } + +// function buildFormula(constraint: CheckerConstraint, stringTable: string[]): Logic.Formula { +// if ('value' in constraint) { +// if (constraint.value === null) { +// return Logic.constantBits(MAGIC_NULL); +// } + +// if (typeof constraint.value === 'boolean') { +// return constraint.value === true ? Logic.TRUE : Logic.FALSE; +// } + +// if (typeof constraint.value === 'number') { +// return Logic.constantBits(constraint.value); +// } + +// if (typeof constraint.value === 'string') { +// const index = stringTable.indexOf(constraint.value); +// if (index === -1) { +// stringTable.push(constraint.value); +// return Logic.constantBits(stringTable.length - 1); +// } else { +// return Logic.constantBits(index); +// } +// } +// } + +// if ('name' in constraint) { +// // variable +// return match(constraint.type) +// .with('boolean', () => constraint.name) +// .with('number', () => Logic.variableBits(constraint.name, 32)) +// .with('string', () => Logic.variableBits(constraint.name, 32)) +// .exhaustive(); +// } + +// if ('eq' in constraint) { +// return transformEquality(constraint.eq.left, constraint.eq.right, stringTable); +// } + +// if ('gt' in constraint) { +// return transformComparison(constraint.gt.left, constraint.gt.right, stringTable, (l, r) => +// Logic.greaterThan(l, r) +// ); +// } + +// if ('gte' in constraint) { +// return transformComparison(constraint.gte.left, constraint.gte.right, stringTable, (l, r) => +// Logic.greaterThanOrEqual(l, r) +// ); +// } + +// if ('lt' in constraint) { +// return transformComparison(constraint.lt.left, constraint.lt.right, stringTable, (l, r) => +// Logic.lessThan(l, r) +// ); +// } + +// if ('lte' in constraint) { +// return transformComparison(constraint.lte.left, constraint.lte.right, stringTable, (l, r) => +// Logic.greaterThan(l, r) +// ); +// } + +// if ('and' in constraint) { +// return Logic.and(...constraint.and.map((c) => buildFormula(c, stringTable))); +// } + +// if ('or' in constraint) { +// return Logic.or(...constraint.or.map((c) => buildFormula(c, stringTable))); +// } + +// if ('not' in constraint) { +// return Logic.not(buildFormula(constraint.not, stringTable)); +// } + +// throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`); +// } + +// function transformEquality(left: ComparisonTerm, right: ComparisonTerm, stringTable: string[]) { +// if (left.type !== right.type) { +// throw new Error(`Type mismatch in equality constraint: ${JSON.stringify(left)}, ${JSON.stringify(right)}`); +// } +// const leftConstraint = buildFormula(left, stringTable); +// const rightConstraint = buildFormula(right, stringTable); +// if (left.type === 'boolean' && right.type === 'boolean') { +// return Logic.equiv(leftConstraint, rightConstraint); +// } else { +// return Logic.equalBits(leftConstraint, rightConstraint); +// } +// } + +// function transformComparison( +// left: ComparisonTerm, +// right: ComparisonTerm, +// stringTable: string[], +// func: (left: Logic.Formula, right: Logic.Formula) => Logic.Formula +// ): string { +// const leftConstraint = buildFormula(left, stringTable); +// const rightConstraint = buildFormula(right, stringTable); +// return func(leftConstraint, rightConstraint); +// } + +// // export type Constraint = Logic.Formula; + +// // export function TRUE(): Constraint { +// // return Logic.TRUE; +// // } + +// // export function FALSE(): Constraint { +// // return Logic.FALSE; +// // } + +// // export function variable(name: string): Constraint { +// // return name; +// // } + +// // export function and(...args: Constraint[]): Constraint { +// // return Logic.and(...args); +// // } + +// // export function or(...args: Constraint[]): Constraint { +// // return Logic.or(...args); +// // } + +// // export function not(arg: Constraint): Constraint { +// // return Logic.not(arg); +// // } + +// // export function checkSat(constraint: Constraint): boolean { +// // const solver = new Logic.Solver(); +// // solver.require(constraint); +// // return !!solver.solve(); +// // } diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index d6d893d4e..ee4c5fb42 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -16,13 +16,15 @@ import { type FieldInfo, type ModelMeta, } from '../../cross'; -import { PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types'; +import { PolicyCrudKind, PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; +import type { CheckerConstraint } from '../types'; import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; +import { ConstraintSolver } from './constraint-solver'; import { PolicyUtil } from './policy-utils'; // a record for post-write policy check @@ -1436,6 +1438,40 @@ export class PolicyProxyHandler implements Pr //#endregion + //#region Check + + async check( + operation: PolicyCrudKind, + fieldValues?: Record + ): Promise { + let constraint = this.policyUtils.getCheckerConstraint(this.model, operation); + if (typeof constraint === 'boolean') { + return constraint; + } + + if (fieldValues) { + const extraConstraints: CheckerConstraint[] = []; + for (const [field, value] of Object.entries(fieldValues)) { + if (value !== undefined) { + const valueType = typeof value; + if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') { + throw new Error(`invalid value type for field "${field}" is not supported`); + } + extraConstraints.push({ + eq: { left: { name: field, type: valueType }, right: { value, type: valueType } }, + }); + } + } + if (extraConstraints.length > 0) { + constraint = { and: [constraint, ...extraConstraints] }; + } + } + + return new ConstraintSolver().solve(constraint); + } + + //#endregion + //#region Utils private get shouldLogQuery() { diff --git a/packages/runtime/src/enhancements/policy/logic-solver.d.ts b/packages/runtime/src/enhancements/policy/logic-solver.d.ts new file mode 100644 index 000000000..b177c520e --- /dev/null +++ b/packages/runtime/src/enhancements/policy/logic-solver.d.ts @@ -0,0 +1,45 @@ +declare module 'logic-solver' { + type Term = string; + + type Formula = Term; + + const TRUE: Term; + + const FALSE: Term; + + export function equiv(operand1: Formula, operand2: Formula): Formula; + + export function equalBits(bits1: Formula, bits2: Formula): Formula; + + export function greaterThan(bits1: Formula, bits2: Formula): Formula; + + export function greaterThanOrEqual(bits1: Formula, bits2: Formula): Formula; + + export function lessThan(bits1: Formula, bits2: Formula): Formula; + + export function lessThanOrEqual(bits1: Formula, bits2: Formula): Formula; + + export function and(...args: Formula[]): Formula; + + export function or(...args: Formula[]): Formula; + + export function not(arg: Formula): Formula; + + export function variableBits(baseName: string, N: number): Formula; + + export function constantBits(wholeNumber: number): Formula; + + interface Solution { + getMap(): object; + + evaluate(formula: Formula): unknown; + } + + class Solver { + require(...args: Formula[]): void; + + forbid(...args: Formula[]): void; + + solve(): Solution; + } +} diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index bcb946877..9c3e8aecb 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -17,12 +17,12 @@ import { PrismaErrorCode, } from '../../constants'; import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; -import { AuthUser, CrudContract, DbClientContract, PolicyOperationKind } from '../../types'; +import { AuthUser, CrudContract, DbClientContract, PolicyCrudKind, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { QueryUtils } from '../query-utils'; -import type { InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; +import type { CheckerFunc, InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; import { formatObject, prismaClientKnownRequestError } from '../utils'; /** @@ -228,7 +228,7 @@ export class PolicyUtil extends QueryUtils { //#endregion - //# Auth guard + //#region Auth guard private readonly FULLY_OPEN_AUTH_GUARD = { create: true, @@ -267,7 +267,7 @@ export class PolicyUtil extends QueryUtils { } if (!provider) { - throw this.unknownError(`zenstack: unable to load authorization guard for ${model}`); + throw this.unknownError(`unable to load authorization guard for ${model}`); } const r = provider({ user: this.user, preValue }, db); return this.reduce(r); @@ -561,6 +561,38 @@ export class PolicyUtil extends QueryUtils { return true; } + //#endregion + + //#region Checker + + getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { + const checker = this.getModelChecker(model); + if (!checker) { + throw this.unknownError(`unable to load policy guard for ${model}`); + } + + const provider = checker[operation]; + if (typeof provider === 'boolean') { + return provider; + } + + if (!provider) { + throw this.unknownError(`unable to load authorization guard for ${model}`); + } + return provider({ user: this.user }); + } + + private getModelChecker(model: string): PolicyDef['checker']['string'] { + if (this.options.kinds && !this.options.kinds.includes('policy')) { + // policy enhancement not enabled, return a constant checker + return { create: true, read: true, update: true, delete: true }; + } else { + return this.options.policy.checker?.[lowerCaseFirst(model)]; + } + } + + //#endregion + /** * Gets unique constraints for the given model. */ diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 9fecc375e..8e2178b15 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -9,7 +9,7 @@ import { HAS_FIELD_LEVEL_POLICY_FLAG, PRE_UPDATE_VALUE_SELECTOR, } from '../constants'; -import type { CrudContract, PolicyOperationKind, QueryContext } from '../types'; +import type { CheckerContext, CrudContract, PolicyCrudKind, PolicyOperationKind, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -33,6 +33,24 @@ export interface CommonEnhancementOptions { */ export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; +export type CheckerFunc = (context: CheckerContext) => CheckerConstraint; + +export type ConstraintVariable = { name: string; type: 'boolean' | 'number' | 'string' }; +export type ConstraintValue = { value: number | boolean | string | null; type: 'boolean' | 'number' | 'string' }; +export type ComparisonTerm = ConstraintVariable | ConstraintValue; + +export type CheckerConstraint = + | ConstraintValue + | ConstraintVariable + | { eq: { left: ComparisonTerm; right: ComparisonTerm } } + | { gt: { left: ComparisonTerm; right: ComparisonTerm } } + | { gte: { left: ComparisonTerm; right: ComparisonTerm } } + | { lt: { left: ComparisonTerm; right: ComparisonTerm } } + | { lte: { left: ComparisonTerm; right: ComparisonTerm } } + | { and: CheckerConstraint[] } + | { or: CheckerConstraint[] } + | { not: CheckerConstraint }; + /** * Function for getting policy guard with a given context */ @@ -71,6 +89,8 @@ export type PolicyDef = { } >; + checker: Record>; + // tracks which models have data validation rules validation: Record; diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 4bcab85a1..b516b2cc3 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -30,10 +30,12 @@ export interface DbOperations { */ export type PolicyKind = 'allow' | 'deny'; +export type PolicyCrudKind = 'read' | 'create' | 'update' | 'delete'; + /** * Kinds of operations controlled by access policies */ -export type PolicyOperationKind = 'create' | 'update' | 'postUpdate' | 'read' | 'delete'; +export type PolicyOperationKind = PolicyCrudKind | 'postUpdate'; /** * Current login user info @@ -56,6 +58,12 @@ export type QueryContext = { preValue?: any; }; +export type CheckerContext = { + user?: AuthUser; + + fieldValues?: Record; +}; + /** * Prisma contract for CRUD operations. */ diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts new file mode 100644 index 000000000..5b6fbbba8 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -0,0 +1,290 @@ +import { ZModelCodeGenerator, getLiteral, isAuthInvocation } from '@zenstackhq/sdk'; +import { + BinaryExpr, + Expression, + ExpressionType, + LiteralExpr, + MemberAccessExpr, + ReferenceExpr, + UnaryExpr, + isBinaryExpr, + isDataModelField, + isLiteralExpr, + isMemberAccessExpr, + isReferenceExpr, + isThisExpr, + isUnaryExpr, +} from '@zenstackhq/sdk/ast'; +import { P, match } from 'ts-pattern'; + +export type ConstraintTransformerOptions = { + authAccessor: string; +}; + +export class ConstraintTransformer { + private varCounter = 0; + + constructor(private readonly options: ConstraintTransformerOptions) {} + + transformRules(allows: Expression[], denies: Expression[]): string { + this.varCounter = 0; + + if (allows.length === 0 && denies.length === 0) { + return `{ value: true, type: 'boolean' }`; + } + + if (allows.length === 0) { + return `{ value: false, type: 'boolean' }`; + } + + let result: string; + + const allowConstraints = allows.map((allow) => this.transformExpression(allow)); + if (allowConstraints.length > 1) { + result = this.and(...allowConstraints); + } else { + result = allowConstraints[0]; + } + + if (denies.length > 0) { + const denyConstraints = denies.map((deny) => this.transformExpression(deny)); + result = this.and(result, this.not(this.or(...denyConstraints))); + } + + return result; + } + + private and(...constraints: string[]) { + if (constraints.length === 0) { + throw new Error('No expressions to combine'); + } + return constraints.length === 1 ? `{ and: [ ${constraints.join(', ')} ] }` : constraints[0]; + } + + private or(...constraints: string[]) { + if (constraints.length === 0) { + throw new Error('No expressions to combine'); + } + return constraints.length === 1 ? `{ or: [ ${constraints.join(', ')} ] }` : constraints[0]; + } + + private not(constraint: string) { + return `{ not: ${constraint} }`; + } + + private transformExpression(expression: Expression) { + return ( + match(expression) + .when(isBinaryExpr, (expr) => this.transformBinary(expr)) + .when(isUnaryExpr, (expr) => this.transformUnary(expr)) + // top-level boolean literal + .when(isLiteralExpr, (expr) => this.transformLiteral(expr)) + // top-level boolean reference expr + .when(isReferenceExpr, (expr) => this.transformReference(expr)) + // top-level boolean member access expr + .when(isMemberAccessExpr, (expr) => this.transformMemberAccess(expr)) + .otherwise(() => this.nextVar()) + ); + } + + private transformLiteral(expr: LiteralExpr) { + const value = getLiteral(expr); + return value ? `{ value: true, type: 'boolean' }` : `{ value: false, type: 'boolean' }`; + } + + private transformReference(expr: ReferenceExpr) { + return `{ name: '${expr.target.$refText}', type: 'boolean' }`; + } + + private transformMemberAccess(expr: MemberAccessExpr) { + if (isThisExpr(expr.operand)) { + return `{ name: '${expr.member.$refText}', type: 'boolean' }`; + } + return this.nextVar(); + } + + private transformBinary(expr: BinaryExpr): string { + return match(expr.operator) + .with('&&', () => this.and(this.transformExpression(expr.left), this.transformExpression(expr.right))) + .with('||', () => this.or(this.transformExpression(expr.left), this.transformExpression(expr.right))) + .with(P.union('==', '!=', '<', '<=', '>', '>='), () => this.transformComparison(expr)) + .otherwise(() => this.nextVar()); + } + + private transformUnary(expr: UnaryExpr): string { + return match(expr.operator) + .with('!', () => this.not(this.transformExpression(expr.operand))) + .otherwise(() => this.nextVar()); + } + + private transformComparison(expr: BinaryExpr) { + const leftOperand = this.getComparisonOperand(expr.left); + const rightOperand = this.getComparisonOperand(expr.right); + + if (leftOperand === undefined || rightOperand === undefined) { + return this.nextVar(); + } + + const op = match(expr.operator) + .with('==', () => 'eq') + .with('!=', () => 'eq') + .with('<', () => 'lt') + .with('<=', () => 'lte') + .with('>', () => 'gt') + .with('>=', () => 'gte') + .otherwise(() => { + throw new Error(`Unsupported operator: ${expr.operator}`); + }); + + let result = `{ ${op}: { left: ${leftOperand}, right: ${rightOperand} } }`; + if (expr.operator === '!=') { + result = `{ not: ${result} }`; + } + + return result; + } + + private getComparisonOperand(expr: Expression) { + if (isLiteralExpr(expr)) { + const mappedType = this.mapType(expr.$resolvedType?.decl as ExpressionType); + if (mappedType) { + return `{ value: ${expr.value}, type: '${mappedType}' }`; + } else { + return undefined; + } + } + + const fieldAccess = this.getFieldAccess(expr); + if (fieldAccess) { + const fieldType = expr.$resolvedType?.decl; + if (!fieldType) { + return undefined; + } + + const mappedType = this.mapType(fieldType as ExpressionType); + if (mappedType) { + return `{ name: '${fieldAccess.name}', type: '${mappedType}' }`; + } else { + return undefined; + } + } + + const authAccess = this.getAuthAccess(expr); + if (authAccess) { + const fieldType = expr.$resolvedType?.decl; + if (!fieldType) { + return undefined; + } + + const mappedType = this.mapType(fieldType as ExpressionType); + if (mappedType) { + return `{ value: ${this.options.authAccessor}?.${authAccess.name}, type: '${mappedType}' }`; + } else { + return undefined; + } + } + + return undefined; + } + + private mapType(fieldType: ExpressionType) { + return match(fieldType) + .with('Boolean', () => 'boolean') + .with('Int', () => 'number') + .with('String', () => 'string') + .otherwise(() => undefined); + } + + // private transformEquality(left: Expression, right: Expression) { + // if (this.isFieldAccess(left) || this.isFieldAccess(right)) { + // const variable = this.isFieldAccess(left) ? left : right; + // const value = this.isFieldAccess(left) ? right : left; + + // const value = this.getExprValue(right); + // if (value !== undefined) { + // if (value === true) { + // return `{ var: '${this.getFieldName(left)}' }`; + // } else if (value === false) { + // return `{ not: { var: '${this.getFieldName(left)}' } }`; + // } else { + // return `{ eq: { ${this.getFieldName(left)}: ${this.encodeValue(value)} } }`; + // } + // } + // } + + // if (this.isFieldAccess(right)) { + // const value = this.getExprValue(left); + // if (value !== undefined) { + // return `'${this.getFieldName(right)} == ${value}'`; + // } + // } + + // if (this.isAuthAccess(left)) { + // const value = this.getExprValue(right); + // if (value !== undefined) { + // return `${this.options.authAccessor}?.${left.member.$refText} === ${value}`; + // } + // } + + // if (this.isAuthAccess(right)) { + // const value = this.getExprValue(left); + // if (value !== undefined) { + // return `${this.options.authAccessor}?.${right.member.$refText} === ${value}`; + // } + // } + + // return this.nextVar(); + // } + + // private encodeValue(value: string | number | boolean) { + // if (typeof value === 'number' || typeof value === 'boolean') { + // return value; + // } else { + // const index = this.stringTable.indexOf(value); + // if (index === -1) { + // this.stringTable.push(value); + // return this.stringTable.length - 1; + // } else { + // return index; + // } + // } + // } + + private getFieldAccess(expr: Expression) { + if (isReferenceExpr(expr)) { + return isDataModelField(expr.target.ref) ? { name: expr.target.$refText } : undefined; + } + if (isMemberAccessExpr(expr)) { + return isThisExpr(expr.operand) ? { name: expr.member.$refText } : undefined; + } + return undefined; + } + + // private getFieldName(expr: ReferenceExpr | MemberAccessExpr): string { + // return isReferenceExpr(expr) ? expr.target.$refText : expr.member.$refText; + // } + + // private getExprValue(expr: Expression): string | boolean | number | undefined { + // if (isLiteralExpr(expr)) { + // return expr.value; + // } + + // if (this.isAuthAccess(expr)) { + // return `${this.options.authAccessor}?.${expr.member.$refText}`; + // } + + // return undefined; + // } + + private getAuthAccess(expr: Expression) { + return isMemberAccessExpr(expr) && isAuthInvocation(expr.operand) ? { name: expr.member.$refText } : undefined; + } + + private nextVar() { + return `'__var${this.varCounter++}'`; + } + + private expressionVariable(expr: Expression) { + return new ZModelCodeGenerator().generate(expr); + } +} diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 753ef8f19..4dc6c27b2 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -54,6 +54,7 @@ import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, Writ import { name } from '..'; import { isCollectionPredicate } from '../../../utils/ast-utils'; import { ALL_OPERATION_KINDS } from '../../plugin-utils'; +import { ConstraintTransformer } from './constraint-transformer'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; /** @@ -70,6 +71,8 @@ export class PolicyGenerator { { name: 'type CrudContract' }, { name: 'allFieldsEqual' }, { name: 'type PolicyDef' }, + { name: 'type CheckerContext' }, + { name: 'type CheckerConstraint' }, ], moduleSpecifier: `${RUNTIME_PACKAGE}`, }); @@ -90,6 +93,11 @@ export class PolicyGenerator { policyMap[model.name] = await this.generateQueryGuardForModel(model, sf); } + const checkerMap: Record> = {}; + for (const model of models) { + checkerMap[model.name] = await this.generateCheckerForModel(model, sf); + } + const authSelector = this.generateAuthSelector(models); sf.addVariableStatement({ @@ -118,6 +126,19 @@ export class PolicyGenerator { }); writer.writeLine(','); + writer.write('checker:'); + writer.inlineBlock(() => { + for (const [model, map] of Object.entries(checkerMap)) { + writer.write(`${lowerCaseFirst(model)}:`); + writer.inlineBlock(() => { + Object.entries(map).forEach(([op, func]) => { + writer.write(`${op}: ${func},`); + }); + }); + } + }); + writer.writeLine(','); + writer.write('validation:'); writer.inlineBlock(() => { for (const model of models) { @@ -301,7 +322,6 @@ export class PolicyGenerator { } const guardFunc = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion result[kind] = guardFunc.getName()!; if (kind === 'postUpdate') { @@ -313,7 +333,6 @@ export class PolicyGenerator { if (kind === 'create' && this.canCheckCreateBasedOnInput(model, allows, denies)) { const inputCheckFunc = this.generateInputCheckFunction(sourceFile, model, kind, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion result[kind + '_input'] = inputCheckFunc.getName()!; } } @@ -847,4 +866,64 @@ export class PolicyGenerator { statements.push(`const user: any = context.user ?? null;`); } } + + private async generateCheckerForModel(model: DataModel, sourceFile: SourceFile) { + const result: Record = {}; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const policies = analyzePolicies(model); + + for (const kind of ['create', 'read', 'update', 'delete'] as const) { + if (policies[kind] === true || policies[kind] === false) { + result[kind] = policies[kind] as boolean; + continue; + } + + const denies = this.getPolicyExpressions(model, 'deny', kind); + const allows = this.getPolicyExpressions(model, 'allow', kind); + + if (kind === 'update' && allows.length === 0) { + // no allow rule for 'update', policy is constant based on if there's + // post-update counterpart + if (this.getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { + result[kind] = false; + continue; + } else { + result[kind] = true; + continue; + } + } + + const guardFunc = this.generateCheckerFunction(sourceFile, model, kind, allows, denies); + result[kind] = guardFunc.getName()!; + } + + return result; + } + + private generateCheckerFunction( + sourceFile: SourceFile, + model: DataModel, + kind: string, + allows: Expression[], + denies: Expression[] + ) { + const transformed = new ConstraintTransformer({ + authAccessor: 'context.user', + }).transformRules(allows, denies); + + const func = sourceFile.addFunction({ + name: `${model.name}Checker_${kind}`, + returnType: 'CheckerConstraint', + parameters: [ + { + name: 'context', + type: 'CheckerContext', + }, + ], + statements: [`return ${transformed};`], + }); + + return func; + } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index d08baac14..112683164 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -409,6 +409,9 @@ importers: deepmerge: specifier: ^4.3.1 version: 4.3.1 + logic-solver: + specifier: ^2.0.1 + version: 2.0.1 lower-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -427,6 +430,9 @@ importers: tiny-invariant: specifier: ^1.3.1 version: 1.3.1 + ts-pattern: + specifier: ^4.3.0 + version: 4.3.0 tslib: specifier: ^2.4.1 version: 2.4.1 @@ -10554,6 +10560,12 @@ packages: wrap-ansi: 8.1.0 dev: false + /logic-solver@2.0.1: + resolution: {integrity: sha512-F1oCywXUzvAF4Z98mMyXySUCpUU3hNyc+JfYV3g2x/4BupC/xv94iPJuHh9us2XX5UrvY5lnKUXNvjcJNQBJ/g==} + dependencies: + underscore: 1.13.6 + dev: false + /loose-envify@1.4.0: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true @@ -14441,7 +14453,6 @@ packages: /underscore@1.13.6: resolution: {integrity: sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A==} - dev: true /undici-types@5.26.5: resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts new file mode 100644 index 000000000..736402f1e --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -0,0 +1,20 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Permission checker', () => { + it('simple', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id String @id @default(uuid()) + value Int + @@allow('read', value == 1) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: 0 })).toResolveFalsy(); + await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); + }); +}); From 5ab8a44a1a08260dad7fbc4e37f069cc126b26b5 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 4 May 2024 11:03:22 +0800 Subject: [PATCH 02/13] WIP: progress --- .../enhancements/policy/constraint-solver.ts | 306 +++++------------ .../src/enhancements/policy/handler.ts | 47 ++- .../src/enhancements/policy/logic-solver.d.ts | 8 +- packages/runtime/src/enhancements/types.ts | 39 ++- .../enhancer/policy/constraint-transformer.ts | 164 +++------ .../enhancer/policy/policy-guard-generator.ts | 11 +- .../enhancements/with-policy/checker.test.ts | 313 +++++++++++++++++- 7 files changed, 529 insertions(+), 359 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts index 87c5ee0ec..3a282f580 100644 --- a/packages/runtime/src/enhancements/policy/constraint-solver.ts +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -1,8 +1,13 @@ import Logic, { Formula } from 'logic-solver'; import { match } from 'ts-pattern'; -import type { CheckerConstraint, ComparisonTerm, ConstraintVariable } from '../types'; - -const MAGIC_NULL = 0x7fffffff; +import type { + CheckerConstraint, + ComparisonConstraint, + ComparisonTerm, + LogicalConstraint, + ValueConstraint, + VariableConstraint, +} from '../types'; export class ConstraintSolver { private stringTable: string[] = []; @@ -26,84 +31,98 @@ export class ConstraintSolver { } private buildFormula(constraint: CheckerConstraint): Logic.Formula { - if ('value' in constraint) { - if (constraint.value === null) { - return Logic.constantBits(MAGIC_NULL); - } - - if (typeof constraint.value === 'boolean') { - return constraint.value === true ? Logic.TRUE : Logic.FALSE; - } - - if (typeof constraint.value === 'number') { - return Logic.constantBits(constraint.value); - } + return match(constraint) + .when( + (c): c is ValueConstraint => c.kind === 'value', + (c) => this.buildValueFormula(c) + ) + .when( + (c): c is VariableConstraint => c.kind === 'variable', + (c) => this.buildVariableFormula(c) + ) + .when( + (c): c is ComparisonConstraint => ['eq', 'gt', 'gte', 'lt', 'lte'].includes(c.kind), + (c) => this.buildComparisonFormula(c) + ) + .when( + (c): c is LogicalConstraint => ['and', 'or', 'not'].includes(c.kind), + (c) => this.buildLogicalFormula(c) + ) + .otherwise(() => { + throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`); + }); + } - if (typeof constraint.value === 'string') { - const index = this.stringTable.indexOf(constraint.value); - if (index === -1) { - this.stringTable.push(constraint.value); - return Logic.constantBits(this.stringTable.length - 1); - } else { - return Logic.constantBits(index); + private buildLogicalFormula(constraint: LogicalConstraint) { + return match(constraint.kind) + .with('and', () => Logic.and(...constraint.children.map((c) => this.buildFormula(c)))) + .with('or', () => Logic.or(...constraint.children.map((c) => this.buildFormula(c)))) + .with('not', () => { + if (constraint.children.length !== 1) { + throw new Error('"not" constraint must have exactly one child'); } - } - } - - if ('name' in constraint) { - // variable - return match(constraint.type) - .with('boolean', () => this.booleanVariable(constraint)) - .with('number', () => this.intVariable(constraint.name)) - .with('string', () => this.intVariable(constraint.name)) - .exhaustive(); - } - - if ('eq' in constraint) { - return this.transformEquality(constraint.eq.left, constraint.eq.right); - } - - if ('gt' in constraint) { - return this.transformComparison(constraint.gt.left, constraint.gt.right, (l, r) => Logic.greaterThan(l, r)); - } - - if ('gte' in constraint) { - return this.transformComparison(constraint.gte.left, constraint.gte.right, (l, r) => - Logic.greaterThanOrEqual(l, r) - ); - } - - if ('lt' in constraint) { - return this.transformComparison(constraint.lt.left, constraint.lt.right, (l, r) => Logic.lessThan(l, r)); - } - - if ('lte' in constraint) { - return this.transformComparison(constraint.lte.left, constraint.lte.right, (l, r) => - Logic.greaterThan(l, r) - ); - } - - if ('and' in constraint) { - return Logic.and(...constraint.and.map((c) => this.buildFormula(c))); - } + return Logic.not(this.buildFormula(constraint.children[0])); + }) + .exhaustive(); + } - if ('or' in constraint) { - return Logic.or(...constraint.or.map((c) => this.buildFormula(c))); - } + private buildComparisonFormula(constraint: ComparisonConstraint) { + return match(constraint.kind) + .with('eq', () => this.transformEquality(constraint.left, constraint.right)) + .with('gt', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThan(l, r)) + ) + .with('gte', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThanOrEqual(l, r)) + ) + .with('lt', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThan(l, r)) + ) + .with('lte', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThanOrEqual(l, r)) + ) + .exhaustive(); + } - if ('not' in constraint) { - return Logic.not(this.buildFormula(constraint.not)); - } + buildVariableFormula(constraint: VariableConstraint) { + return match(constraint.type) + .with('boolean', () => this.booleanVariable(constraint.name)) + .with('number', () => this.intVariable(constraint.name)) + .with('string', () => this.intVariable(constraint.name)) + .exhaustive(); + } - throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`); + private buildValueFormula(constraint: ValueConstraint) { + return match(constraint.value) + .when( + (v): v is boolean => typeof v === 'boolean', + (v) => (v === true ? Logic.TRUE : Logic.FALSE) + ) + .when( + (v): v is number => typeof v === 'number', + (v) => Logic.constantBits(v) + ) + .when( + (v): v is string => typeof v === 'string', + (v) => { + const index = this.stringTable.indexOf(v); + if (index === -1) { + this.stringTable.push(v); + return Logic.constantBits(this.stringTable.length - 1); + } else { + return Logic.constantBits(index); + } + } + ) + .exhaustive(); } - private booleanVariable(constraint: ConstraintVariable): string { - this.variables.set(constraint.name, constraint.name); - return constraint.name; + private booleanVariable(name: string) { + this.variables.set(name, name); + return name; } - private intVariable(name: string): string { + private intVariable(name: string) { const r = Logic.variableBits(name, 32); this.variables.set(name, r); return r; @@ -132,146 +151,3 @@ export class ConstraintSolver { return func(leftConstraint, rightConstraint); } } - -// export function solve(constraint: CheckerConstraint) { -// const stringTable: string[] = []; -// const formula = buildFormula(constraint, stringTable); -// const solver = new Logic.Solver(); -// solver.require(formula); -// const solution = solver.solve(); -// console.log('Solution:', solution?.getMap()); -// return !!solution; -// } - -// function buildFormula(constraint: CheckerConstraint, stringTable: string[]): Logic.Formula { -// if ('value' in constraint) { -// if (constraint.value === null) { -// return Logic.constantBits(MAGIC_NULL); -// } - -// if (typeof constraint.value === 'boolean') { -// return constraint.value === true ? Logic.TRUE : Logic.FALSE; -// } - -// if (typeof constraint.value === 'number') { -// return Logic.constantBits(constraint.value); -// } - -// if (typeof constraint.value === 'string') { -// const index = stringTable.indexOf(constraint.value); -// if (index === -1) { -// stringTable.push(constraint.value); -// return Logic.constantBits(stringTable.length - 1); -// } else { -// return Logic.constantBits(index); -// } -// } -// } - -// if ('name' in constraint) { -// // variable -// return match(constraint.type) -// .with('boolean', () => constraint.name) -// .with('number', () => Logic.variableBits(constraint.name, 32)) -// .with('string', () => Logic.variableBits(constraint.name, 32)) -// .exhaustive(); -// } - -// if ('eq' in constraint) { -// return transformEquality(constraint.eq.left, constraint.eq.right, stringTable); -// } - -// if ('gt' in constraint) { -// return transformComparison(constraint.gt.left, constraint.gt.right, stringTable, (l, r) => -// Logic.greaterThan(l, r) -// ); -// } - -// if ('gte' in constraint) { -// return transformComparison(constraint.gte.left, constraint.gte.right, stringTable, (l, r) => -// Logic.greaterThanOrEqual(l, r) -// ); -// } - -// if ('lt' in constraint) { -// return transformComparison(constraint.lt.left, constraint.lt.right, stringTable, (l, r) => -// Logic.lessThan(l, r) -// ); -// } - -// if ('lte' in constraint) { -// return transformComparison(constraint.lte.left, constraint.lte.right, stringTable, (l, r) => -// Logic.greaterThan(l, r) -// ); -// } - -// if ('and' in constraint) { -// return Logic.and(...constraint.and.map((c) => buildFormula(c, stringTable))); -// } - -// if ('or' in constraint) { -// return Logic.or(...constraint.or.map((c) => buildFormula(c, stringTable))); -// } - -// if ('not' in constraint) { -// return Logic.not(buildFormula(constraint.not, stringTable)); -// } - -// throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`); -// } - -// function transformEquality(left: ComparisonTerm, right: ComparisonTerm, stringTable: string[]) { -// if (left.type !== right.type) { -// throw new Error(`Type mismatch in equality constraint: ${JSON.stringify(left)}, ${JSON.stringify(right)}`); -// } -// const leftConstraint = buildFormula(left, stringTable); -// const rightConstraint = buildFormula(right, stringTable); -// if (left.type === 'boolean' && right.type === 'boolean') { -// return Logic.equiv(leftConstraint, rightConstraint); -// } else { -// return Logic.equalBits(leftConstraint, rightConstraint); -// } -// } - -// function transformComparison( -// left: ComparisonTerm, -// right: ComparisonTerm, -// stringTable: string[], -// func: (left: Logic.Formula, right: Logic.Formula) => Logic.Formula -// ): string { -// const leftConstraint = buildFormula(left, stringTable); -// const rightConstraint = buildFormula(right, stringTable); -// return func(leftConstraint, rightConstraint); -// } - -// // export type Constraint = Logic.Formula; - -// // export function TRUE(): Constraint { -// // return Logic.TRUE; -// // } - -// // export function FALSE(): Constraint { -// // return Logic.FALSE; -// // } - -// // export function variable(name: string): Constraint { -// // return name; -// // } - -// // export function and(...args: Constraint[]): Constraint { -// // return Logic.and(...args); -// // } - -// // export function or(...args: Constraint[]): Constraint { -// // return Logic.or(...args); -// // } - -// // export function not(arg: Constraint): Constraint { -// // return Logic.not(arg); -// // } - -// // export function checkSat(constraint: Constraint): boolean { -// // const solver = new Logic.Solver(); -// // solver.require(constraint); -// // return !!solver.solve(); -// // } diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index ee4c5fb42..194d8c4c5 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -2,6 +2,7 @@ import { lowerCaseFirst } from 'lower-case-first'; import invariant from 'tiny-invariant'; +import { P, match } from 'ts-pattern'; import { upperCaseFirst } from 'upper-case-first'; import { fromZodError } from 'zod-validation-error'; import { CrudFailureReason } from '../../constants'; @@ -1452,18 +1453,48 @@ export class PolicyProxyHandler implements Pr if (fieldValues) { const extraConstraints: CheckerConstraint[] = []; for (const [field, value] of Object.entries(fieldValues)) { - if (value !== undefined) { - const valueType = typeof value; - if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') { - throw new Error(`invalid value type for field "${field}" is not supported`); - } - extraConstraints.push({ - eq: { left: { name: field, type: valueType }, right: { value, type: valueType } }, + if (value === undefined) { + continue; + } + + if (value === null) { + throw new Error(`Using "null" as filter value is not supported yet`); + } + + const fieldInfo = requireField(this.modelMeta, this.model, field); + + if (fieldInfo.isDataModel || fieldInfo.isArray) { + throw new Error( + `Providing filter for field "${field}" is not supported. Only scalar fields are allowed.` + ); + } + + const fieldType = match(fieldInfo.type) + .with(P.union('Int', 'BigInt', 'Float', 'Decimal'), () => 'number') + .with('String', () => 'string') + .with('Boolean', () => 'boolean') + .otherwise(() => { + throw new Error( + `Providing filter for field "${field}" is not supported. Only number, string, and boolean fields are allowed.` + ); }); + + const valueType = typeof value; + if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') { + throw new Error( + `Invalid value for field "${field}". Only number, string, boolean, or null is allowed.` + ); } + + extraConstraints.push({ + kind: 'eq', + left: { kind: 'variable', name: field, type: fieldType }, + right: { kind: 'value', value, type: fieldType }, + }); } + if (extraConstraints.length > 0) { - constraint = { and: [constraint, ...extraConstraints] }; + constraint = { kind: 'and', children: [constraint, ...extraConstraints] }; } } diff --git a/packages/runtime/src/enhancements/policy/logic-solver.d.ts b/packages/runtime/src/enhancements/policy/logic-solver.d.ts index b177c520e..31eea6758 100644 --- a/packages/runtime/src/enhancements/policy/logic-solver.d.ts +++ b/packages/runtime/src/enhancements/policy/logic-solver.d.ts @@ -1,11 +1,9 @@ declare module 'logic-solver' { - type Term = string; + interface Formula {} - type Formula = Term; + const TRUE: Formula; - const TRUE: Term; - - const FALSE: Term; + const FALSE: Formula; export function equiv(operand1: Formula, operand2: Formula): Formula; diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 8e2178b15..17b1ac48e 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -35,21 +35,30 @@ export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; export type CheckerFunc = (context: CheckerContext) => CheckerConstraint; -export type ConstraintVariable = { name: string; type: 'boolean' | 'number' | 'string' }; -export type ConstraintValue = { value: number | boolean | string | null; type: 'boolean' | 'number' | 'string' }; -export type ComparisonTerm = ConstraintVariable | ConstraintValue; - -export type CheckerConstraint = - | ConstraintValue - | ConstraintVariable - | { eq: { left: ComparisonTerm; right: ComparisonTerm } } - | { gt: { left: ComparisonTerm; right: ComparisonTerm } } - | { gte: { left: ComparisonTerm; right: ComparisonTerm } } - | { lt: { left: ComparisonTerm; right: ComparisonTerm } } - | { lte: { left: ComparisonTerm; right: ComparisonTerm } } - | { and: CheckerConstraint[] } - | { or: CheckerConstraint[] } - | { not: CheckerConstraint }; +export type ConstraintValueTypes = 'boolean' | 'number' | 'string'; + +export type VariableConstraint = { kind: 'variable'; name: string; type: ConstraintValueTypes }; + +export type ValueConstraint = { + kind: 'value'; + value: number | boolean | string; + type: ConstraintValueTypes; +}; + +export type ComparisonTerm = VariableConstraint | ValueConstraint; + +export type ComparisonConstraint = { + kind: 'eq' | 'gt' | 'gte' | 'lt' | 'lte'; + left: ComparisonTerm; + right: ComparisonTerm; +}; + +export type LogicalConstraint = { + kind: 'and' | 'or' | 'not'; + children: CheckerConstraint[]; +}; + +export type CheckerConstraint = ValueConstraint | VariableConstraint | ComparisonConstraint | LogicalConstraint; /** * Function for getting policy guard with a given context diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts index 5b6fbbba8..2da7b5188 100644 --- a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -1,11 +1,14 @@ -import { ZModelCodeGenerator, getLiteral, isAuthInvocation } from '@zenstackhq/sdk'; +import { ZModelCodeGenerator, isAuthInvocation } from '@zenstackhq/sdk'; import { BinaryExpr, + BooleanLiteral, Expression, ExpressionType, LiteralExpr, MemberAccessExpr, + NumberLiteral, ReferenceExpr, + StringLiteral, UnaryExpr, isBinaryExpr, isDataModelField, @@ -29,12 +32,8 @@ export class ConstraintTransformer { transformRules(allows: Expression[], denies: Expression[]): string { this.varCounter = 0; - if (allows.length === 0 && denies.length === 0) { - return `{ value: true, type: 'boolean' }`; - } - if (allows.length === 0) { - return `{ value: false, type: 'boolean' }`; + return this.value('false', 'boolean'); } let result: string; @@ -51,6 +50,8 @@ export class ConstraintTransformer { result = this.and(result, this.not(this.or(...denyConstraints))); } + console.log(`Constraint transformation result:\n${JSON.stringify(result, null, 2)}`); + return result; } @@ -58,18 +59,18 @@ export class ConstraintTransformer { if (constraints.length === 0) { throw new Error('No expressions to combine'); } - return constraints.length === 1 ? `{ and: [ ${constraints.join(', ')} ] }` : constraints[0]; + return constraints.length === 1 ? constraints[0] : `{ kind: 'and', children: [ ${constraints.join(', ')} ] }`; } private or(...constraints: string[]) { if (constraints.length === 0) { throw new Error('No expressions to combine'); } - return constraints.length === 1 ? `{ or: [ ${constraints.join(', ')} ] }` : constraints[0]; + return constraints.length === 1 ? constraints[0] : `{ kind: 'or', children: [ ${constraints.join(', ')} ] }`; } private not(constraint: string) { - return `{ not: ${constraint} }`; + return `{ kind: 'not', children: [${constraint}] }`; } private transformExpression(expression: Expression) { @@ -88,17 +89,20 @@ export class ConstraintTransformer { } private transformLiteral(expr: LiteralExpr) { - const value = getLiteral(expr); - return value ? `{ value: true, type: 'boolean' }` : `{ value: false, type: 'boolean' }`; + return match(expr.$type) + .with(NumberLiteral, () => this.value(expr.value.toString(), 'number')) + .with(StringLiteral, () => this.value(`'${expr.value}'`, 'string')) + .with(BooleanLiteral, () => this.value(expr.value.toString(), 'boolean')) + .exhaustive(); } private transformReference(expr: ReferenceExpr) { - return `{ name: '${expr.target.$refText}', type: 'boolean' }`; + return this.variable(expr.target.$refText, 'boolean'); } private transformMemberAccess(expr: MemberAccessExpr) { if (isThisExpr(expr.operand)) { - return `{ name: '${expr.member.$refText}', type: 'boolean' }`; + return this.variable(expr.member.$refText, 'boolean'); } return this.nextVar(); } @@ -114,7 +118,7 @@ export class ConstraintTransformer { private transformUnary(expr: UnaryExpr): string { return match(expr.operator) .with('!', () => this.not(this.transformExpression(expr.operand))) - .otherwise(() => this.nextVar()); + .exhaustive(); } private transformComparison(expr: BinaryExpr) { @@ -136,9 +140,9 @@ export class ConstraintTransformer { throw new Error(`Unsupported operator: ${expr.operator}`); }); - let result = `{ ${op}: { left: ${leftOperand}, right: ${rightOperand} } }`; + let result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`; if (expr.operator === '!=') { - result = `{ not: ${result} }`; + result = `{ kind: 'not', children: [${result}] }`; } return result; @@ -146,24 +150,14 @@ export class ConstraintTransformer { private getComparisonOperand(expr: Expression) { if (isLiteralExpr(expr)) { - const mappedType = this.mapType(expr.$resolvedType?.decl as ExpressionType); - if (mappedType) { - return `{ value: ${expr.value}, type: '${mappedType}' }`; - } else { - return undefined; - } + return this.transformLiteral(expr); } const fieldAccess = this.getFieldAccess(expr); if (fieldAccess) { - const fieldType = expr.$resolvedType?.decl; - if (!fieldType) { - return undefined; - } - - const mappedType = this.mapType(fieldType as ExpressionType); + const mappedType = this.mapType(expr); if (mappedType) { - return `{ name: '${fieldAccess.name}', type: '${mappedType}' }`; + return this.variable(fieldAccess.name, mappedType); } else { return undefined; } @@ -171,14 +165,13 @@ export class ConstraintTransformer { const authAccess = this.getAuthAccess(expr); if (authAccess) { - const fieldType = expr.$resolvedType?.decl; - if (!fieldType) { - return undefined; - } - - const mappedType = this.mapType(fieldType as ExpressionType); + const fieldAccess = `${this.options.authAccessor}?.${authAccess}`; + const mappedType = this.mapType(expr); if (mappedType) { - return `{ value: ${this.options.authAccessor}?.${authAccess.name}, type: '${mappedType}' }`; + return `${fieldAccess} === undefined ? ${this.expressionVariable(expr, mappedType)} : ${this.value( + fieldAccess, + mappedType + )}`; } else { return undefined; } @@ -187,69 +180,14 @@ export class ConstraintTransformer { return undefined; } - private mapType(fieldType: ExpressionType) { - return match(fieldType) + private mapType(expression: Expression) { + return match(expression.$resolvedType?.decl as ExpressionType) .with('Boolean', () => 'boolean') .with('Int', () => 'number') .with('String', () => 'string') .otherwise(() => undefined); } - // private transformEquality(left: Expression, right: Expression) { - // if (this.isFieldAccess(left) || this.isFieldAccess(right)) { - // const variable = this.isFieldAccess(left) ? left : right; - // const value = this.isFieldAccess(left) ? right : left; - - // const value = this.getExprValue(right); - // if (value !== undefined) { - // if (value === true) { - // return `{ var: '${this.getFieldName(left)}' }`; - // } else if (value === false) { - // return `{ not: { var: '${this.getFieldName(left)}' } }`; - // } else { - // return `{ eq: { ${this.getFieldName(left)}: ${this.encodeValue(value)} } }`; - // } - // } - // } - - // if (this.isFieldAccess(right)) { - // const value = this.getExprValue(left); - // if (value !== undefined) { - // return `'${this.getFieldName(right)} == ${value}'`; - // } - // } - - // if (this.isAuthAccess(left)) { - // const value = this.getExprValue(right); - // if (value !== undefined) { - // return `${this.options.authAccessor}?.${left.member.$refText} === ${value}`; - // } - // } - - // if (this.isAuthAccess(right)) { - // const value = this.getExprValue(left); - // if (value !== undefined) { - // return `${this.options.authAccessor}?.${right.member.$refText} === ${value}`; - // } - // } - - // return this.nextVar(); - // } - - // private encodeValue(value: string | number | boolean) { - // if (typeof value === 'number' || typeof value === 'boolean') { - // return value; - // } else { - // const index = this.stringTable.indexOf(value); - // if (index === -1) { - // this.stringTable.push(value); - // return this.stringTable.length - 1; - // } else { - // return index; - // } - // } - // } - private getFieldAccess(expr: Expression) { if (isReferenceExpr(expr)) { return isDataModelField(expr.target.ref) ? { name: expr.target.$refText } : undefined; @@ -260,31 +198,33 @@ export class ConstraintTransformer { return undefined; } - // private getFieldName(expr: ReferenceExpr | MemberAccessExpr): string { - // return isReferenceExpr(expr) ? expr.target.$refText : expr.member.$refText; - // } - - // private getExprValue(expr: Expression): string | boolean | number | undefined { - // if (isLiteralExpr(expr)) { - // return expr.value; - // } + private getAuthAccess(expr: Expression): string | undefined { + if (!isMemberAccessExpr(expr)) { + return undefined; + } - // if (this.isAuthAccess(expr)) { - // return `${this.options.authAccessor}?.${expr.member.$refText}`; - // } + if (isAuthInvocation(expr.operand)) { + return expr.member.$refText; + } else { + const operand = this.getAuthAccess(expr.operand); + return operand ? `${operand}?.${expr.member.$refText}` : undefined; + } + } - // return undefined; - // } + private nextVar(type = 'boolean') { + return this.variable(`__var${this.varCounter++}`, type); + } - private getAuthAccess(expr: Expression) { - return isMemberAccessExpr(expr) && isAuthInvocation(expr.operand) ? { name: expr.member.$refText } : undefined; + private expressionVariable(expr: Expression, type: string) { + const name = new ZModelCodeGenerator().generate(expr); + return this.variable(name, type); } - private nextVar() { - return `'__var${this.varCounter++}'`; + private variable(name: string, type: string) { + return `{ kind: 'variable', name: '${name}', type: '${type}' }`; } - private expressionVariable(expr: Expression) { - return new ZModelCodeGenerator().generate(expr); + private value(value: string, type: string) { + return `{ kind: 'value', value: ${value}, type: '${type}' }`; } } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 4dc6c27b2..8e47c6614 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -135,6 +135,7 @@ export class PolicyGenerator { writer.write(`${op}: ${func},`); }); }); + writer.writeLine(','); } }); writer.writeLine(','); @@ -908,10 +909,16 @@ export class PolicyGenerator { allows: Expression[], denies: Expression[] ) { + const statements: string[] = []; + + this.generateNormalizedAuthRef(model, allows, denies, statements); + const transformed = new ConstraintTransformer({ - authAccessor: 'context.user', + authAccessor: 'user', }).transformRules(allows, denies); + statements.push(`return ${transformed};`); + const func = sourceFile.addFunction({ name: `${model.name}Checker_${kind}`, returnType: 'CheckerConstraint', @@ -921,7 +928,7 @@ export class PolicyGenerator { type: 'CheckerContext', }, ], - statements: [`return ${transformed};`], + statements, }); return func; diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts index 736402f1e..83791559c 100644 --- a/tests/integration/tests/enhancements/with-policy/checker.test.ts +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -1,13 +1,63 @@ import { loadSchema } from '@zenstackhq/testtools'; describe('Permission checker', () => { - it('simple', async () => { + it('empty rules', async () => { const { enhance } = await loadSchema( ` model Model { - id String @id @default(uuid()) + id Int @id @default(autoincrement()) + value Int + } + ` + ); + const db = enhance(); + await expect(db.model.check('read')).toResolveFalsy(); + await expect(db.model.check('read', { value: 1 })).toResolveFalsy(); + }); + + it('unconditional allow', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', true) + } + ` + ); + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: 0 })).toResolveTruthy(); + }); + + it('deny rule', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', value > 0) + @@deny('all', value == 1) + } + ` + ); + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: 0 })).toResolveFalsy(); + await expect(db.model.check('read', { value: 1 })).toResolveFalsy(); + await expect(db.model.check('read', { value: 2 })).toResolveTruthy(); + }); + + it('int field condition', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) value Int @@allow('read', value == 1) + @@allow('create', value != 1) + @@allow('update', value > 1) + @@allow('delete', value <= 1) } ` ); @@ -16,5 +66,264 @@ describe('Permission checker', () => { await expect(db.model.check('read')).toResolveTruthy(); await expect(db.model.check('read', { value: 0 })).toResolveFalsy(); await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); + + await expect(db.model.check('create')).toResolveTruthy(); + await expect(db.model.check('create', { value: 0 })).toResolveTruthy(); + await expect(db.model.check('create', { value: 1 })).toResolveFalsy(); + + await expect(db.model.check('update')).toResolveTruthy(); + await expect(db.model.check('update', { value: 1 })).toResolveFalsy(); + await expect(db.model.check('update', { value: 2 })).toResolveTruthy(); + + await expect(db.model.check('delete')).toResolveTruthy(); + await expect(db.model.check('delete', { value: 0 })).toResolveTruthy(); + await expect(db.model.check('delete', { value: 1 })).toResolveTruthy(); + await expect(db.model.check('delete', { value: 2 })).toResolveFalsy(); + }); + + it('boolean field toplevel condition', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Boolean + @@allow('read', value) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: false })).toResolveFalsy(); + await expect(db.model.check('read', { value: true })).toResolveTruthy(); + }); + + it('boolean field condition', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Boolean + @@allow('read', value == true) + @@allow('create', value == false) + @@allow('update', value != true) + @@allow('delete', value != false) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: false })).toResolveFalsy(); + await expect(db.model.check('read', { value: true })).toResolveTruthy(); + + await expect(db.model.check('create')).toResolveTruthy(); + await expect(db.model.check('create', { value: true })).toResolveFalsy(); + await expect(db.model.check('create', { value: false })).toResolveTruthy(); + + await expect(db.model.check('update')).toResolveTruthy(); + await expect(db.model.check('update', { value: true })).toResolveFalsy(); + await expect(db.model.check('update', { value: false })).toResolveTruthy(); + + await expect(db.model.check('delete')).toResolveTruthy(); + await expect(db.model.check('delete', { value: false })).toResolveFalsy(); + await expect(db.model.check('delete', { value: true })).toResolveTruthy(); + }); + + it('string field condition', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value String + @@allow('read', value == 'admin') + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: 'user' })).toResolveFalsy(); + await expect(db.model.check('read', { value: 'admin' })).toResolveTruthy(); + }); + + it('function noop', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value String + @@allow('read', startsWith(value, 'admin')) + @@allow('update', !startsWith(value, 'admin')) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: 'user' })).toResolveTruthy(); + await expect(db.model.check('read', { value: 'admin' })).toResolveTruthy(); + await expect(db.model.check('update')).toResolveTruthy(); + await expect(db.model.check('update', { value: 'user' })).toResolveTruthy(); + await expect(db.model.check('update', { value: 'admin' })).toResolveTruthy(); + }); + + it('relation noop', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value String + foo Foo? + + @@allow('read', foo.x > 0) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + modelId Int @unique + model Model @relation(fields: [modelId], references: [id]) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { foo: { x: 0 } })).rejects.toThrow('Providing filter for field "foo"'); + }); + + it('collection predicate noop', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value String + foo Foo[] + + @@allow('read', foo?[x > 0]) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + modelId Int + model Model @relation(fields: [modelId], references: [id]) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { foo: [{ x: 0 }] })).rejects.toThrow('Providing filter for field "foo"'); + }); + + it('field complex condition', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + x Int + y Int + @@allow('read', x > 0 && x > y) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { x: 0 })).toResolveFalsy(); + await expect(db.model.check('read', { x: 1 })).toResolveTruthy(); + await expect(db.model.check('read', { x: 1, y: 0 })).toResolveTruthy(); + await expect(db.model.check('read', { x: 1, y: 1 })).toResolveFalsy(); + }); + + it('field condition unsolvable', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + x Int + y Int + @@allow('read', x > 0 && x < y && y <= 1) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveFalsy(); + await expect(db.model.check('read', { x: 0 })).toResolveFalsy(); + await expect(db.model.check('read', { x: 1 })).toResolveFalsy(); + await expect(db.model.check('read', { x: 1, y: 2 })).toResolveFalsy(); + await expect(db.model.check('read', { x: 1, y: 1 })).toResolveFalsy(); + }); + + it('simple auth condition', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + level Int + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth().level > 0) + } + ` + ); + + await expect(enhance().model.check('read')).toResolveTruthy(); + await expect(enhance({ id: 1, level: 0 }).model.check('read')).toResolveFalsy(); + await expect(enhance({ id: 1, level: 1 }).model.check('read')).toResolveTruthy(); + }); + + it('auth with relation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + } + + model Profile { + id Int @id @default(autoincrement()) + level Int + user User @relation(fields: [userId], references: [id]) + userId Int @unique + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth().profile.level > 0) + } + ` + ); + + await expect(enhance().model.check('read')).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('read')).toResolveTruthy(); + await expect(enhance({ id: 1, profile: { level: 0 } }).model.check('read')).toResolveFalsy(); + await expect(enhance({ id: 1, profile: { level: 1 } }).model.check('read')).toResolveTruthy(); + }); + + it('nullable field', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int? + @@allow('read', value != null) + @@allow('create', value == null) + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); + await expect(db.model.check('create')).toResolveTruthy(); + await expect(db.model.check('create', { value: 1 })).toResolveTruthy(); }); }); From 19fea2c4025625eccb452fe481898588a883f1ca Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 5 May 2024 13:59:34 +0800 Subject: [PATCH 03/13] WIP: progress --- .../enhance/checker-type-generator.ts | 44 +++++++++++++++++++ .../src/plugins/enhancer/enhance/index.ts | 13 ++++-- .../enhancements/with-policy/checker.test.ts | 29 ++++++++++++ 3 files changed, 82 insertions(+), 4 deletions(-) create mode 100644 packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts diff --git a/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts new file mode 100644 index 000000000..f7986b322 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts @@ -0,0 +1,44 @@ +import { getDataModels } from '@zenstackhq/sdk'; +import type { DataModel, DataModelField, Model } from '@zenstackhq/sdk/ast'; +import { lowerCaseFirst } from 'lower-case-first'; +import { P, match } from 'ts-pattern'; + +/** + * Generates a `ModelCheckers` interface that contains a `check` method for each model in the schema. + */ +export function generateCheckerType(model: Model) { + return ` +type CheckerOperation = 'create' | 'read' | 'update' | 'delete'; + +export interface ModelCheckers { + ${getDataModels(model) + .map((dataModel) => `\t${lowerCaseFirst(dataModel.name)}: ${generateDataModelChecker(dataModel)}`) + .join(',\n')} +} +`; +} + +function generateDataModelChecker(dataModel: DataModel) { + return `{ + check(op: CheckerOperation, args?: ${generateDataModelArgs(dataModel)}): Promise + }`; +} + +function generateDataModelArgs(dataModel: DataModel) { + return `{ ${dataModel.fields + .filter((field) => isFieldFilterable(field)) + .map((field) => `${field.name}?: ${mapFieldType(field)}`) + .join('; ')} }`; +} + +function isFieldFilterable(field: DataModelField) { + return !!mapFieldType(field); +} + +function mapFieldType(field: DataModelField) { + return match(field.type.type) + .with('Boolean', () => 'boolean') + .with(P.union('BigInt', 'Int', 'Float', 'Decimal'), () => 'number') + .with('String', () => 'string') + .otherwise(() => undefined); +} diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 0425b76d0..81fe66c3c 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -41,6 +41,7 @@ import { trackPrismaSchemaError } from '../../prisma'; import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; import { isDefaultWithAuth } from '../enhancer-utils'; import { generateAuthType } from './auth-type-generator'; +import { generateCheckerType } from './checker-type-generator'; // information of delegate models and their sub models type DelegateInfo = [DataModel, DataModel[]][]; @@ -89,6 +90,8 @@ export class EnhancerGenerator { const authTypes = authModel ? generateAuthType(this.model, authModel) : ''; const authTypeParam = authModel ? `auth.${authModel.name}` : 'AuthUser'; + const checkerTypes = generateCheckerType(this.model); + const enhanceTs = this.project.createSourceFile( path.join(this.outDir, 'enhance.ts'), `import { type EnhancementContext, type EnhancementOptions, type ZodSchemas, type AuthUser } from '@zenstackhq/runtime'; @@ -105,6 +108,8 @@ ${ ${authTypes} +${checkerTypes} + ${ logicalPrismaClientDir ? this.createLogicalPrismaEnhanceFunction(authTypeParam) @@ -127,14 +132,14 @@ import type * as _P from '${prismaImport}'; private createSimplePrismaEnhanceFunction(authTypeParam: string) { return ` -export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions) { +export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DbClient & ModelCheckers { return createEnhancement(prisma, { modelMeta, policy, zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), prismaModule: Prisma, ...options - }, context); + }, context) as DbClient & ModelCheckers; } `; } @@ -157,12 +162,12 @@ import type { Prisma, PrismaClient } from '${logicalPrismaClientDir}/index-fixed // overload for plain PrismaClient export function enhance & InternalArgs>( prisma: _PrismaClient, - context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient; + context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient & ModelCheckers; // overload for extended PrismaClient export function enhance & InternalArgs>( prisma: DynamicClientExtensionThis, - context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis; + context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis & ModelCheckers; export function enhance(prisma: any, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): any { return createEnhancement(prisma, { diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts index 83791559c..52fd55afc 100644 --- a/tests/integration/tests/enhancements/with-policy/checker.test.ts +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -326,4 +326,33 @@ describe('Permission checker', () => { await expect(db.model.check('create')).toResolveTruthy(); await expect(db.model.check('create', { value: 1 })).toResolveTruthy(); }); + + it('compilation', async () => { + await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value == 1) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { PrismaClient } from '@prisma/client'; + import { enhance } from '.zenstack/enhance'; + + const prisma = new PrismaClient(); + const db = enhance(prisma); + db.model.check('read'); + db.model.check('read', { value: 1 }); + `, + }, + ], + } + ); + }); }); From 2ef2da0976a7c651f56c667182893756b6da75a7 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 5 May 2024 20:32:59 +0800 Subject: [PATCH 04/13] more tests --- .../enhancements/policy/constraint-solver.ts | 71 +++++++++++------- .../src/enhancements/policy/handler.ts | 21 +++++- .../enhancer/policy/constraint-transformer.ts | 10 ++- .../enhancements/with-policy/checker.test.ts | 73 +++++++++++++++++++ 4 files changed, 147 insertions(+), 28 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts index 3a282f580..cf70ddab0 100644 --- a/packages/runtime/src/enhancements/policy/constraint-solver.ts +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -1,4 +1,4 @@ -import Logic, { Formula } from 'logic-solver'; +import Logic from 'logic-solver'; import { match } from 'ts-pattern'; import type { CheckerConstraint, @@ -9,25 +9,41 @@ import type { VariableConstraint, } from '../types'; +/** + * A boolean constraint solver based on `logic-solver`. + */ export class ConstraintSolver { + // a table for internalizing string literals private stringTable: string[] = []; - private variables: Map = new Map(); - solve(constraint: CheckerConstraint): boolean { + // a map for storing variable names and their corresponding formulas + private variables: Map = new Map(); + + /** + * Check the satisfiability of the given constraint. + */ + checkSat(constraint: CheckerConstraint): boolean { + // reset state this.stringTable = []; - this.variables = new Map(); + this.variables = new Map(); + // convert the constraint to a "logic-solver" formula const formula = this.buildFormula(constraint); + + // solve the formula const solver = new Logic.Solver(); solver.require(formula); - const solution = solver.solve(); - if (solution) { - console.log('Solution:'); - this.variables.forEach((v, k) => console.log(`\t${k}=${solution?.evaluate(v)}`)); - } else { - console.log('No solution'); - } - return !!solution; + + // DEBUG: + // const solution = solver.solve(); + // if (solution) { + // console.log('Solution:'); + // this.variables.forEach((v, k) => console.log(`\t${k}=${solution?.evaluate(v)}`)); + // } else { + // console.log('No solution'); + // } + + return !!solver.solve(); } private buildFormula(constraint: CheckerConstraint): Logic.Formula { @@ -84,12 +100,15 @@ export class ConstraintSolver { .exhaustive(); } - buildVariableFormula(constraint: VariableConstraint) { - return match(constraint.type) - .with('boolean', () => this.booleanVariable(constraint.name)) - .with('number', () => this.intVariable(constraint.name)) - .with('string', () => this.intVariable(constraint.name)) - .exhaustive(); + private buildVariableFormula(constraint: VariableConstraint) { + return ( + match(constraint.type) + .with('boolean', () => this.booleanVariable(constraint.name)) + .with('number', () => this.intVariable(constraint.name)) + // strings are internalized and represented by their indices + .with('string', () => this.intVariable(constraint.name)) + .exhaustive() + ); } private buildValueFormula(constraint: ValueConstraint) { @@ -105,6 +124,7 @@ export class ConstraintSolver { .when( (v): v is string => typeof v === 'string', (v) => { + // internalize the string and use its index as formula representation const index = this.stringTable.indexOf(v); if (index === -1) { this.stringTable.push(v); @@ -132,12 +152,15 @@ export class ConstraintSolver { if (left.type !== right.type) { throw new Error(`Type mismatch in equality constraint: ${JSON.stringify(left)}, ${JSON.stringify(right)}`); } - const leftConstraint = this.buildFormula(left); - const rightConstraint = this.buildFormula(right); + + const leftFormula = this.buildFormula(left); + const rightFormula = this.buildFormula(right); if (left.type === 'boolean' && right.type === 'boolean') { - return Logic.equiv(leftConstraint, rightConstraint); + // logical equivalence + return Logic.equiv(leftFormula, rightFormula); } else { - return Logic.equalBits(leftConstraint, rightConstraint); + // integer equality + return Logic.equalBits(leftFormula, rightFormula); } } @@ -146,8 +169,6 @@ export class ConstraintSolver { right: ComparisonTerm, func: (left: Logic.Formula, right: Logic.Formula) => Logic.Formula ) { - const leftConstraint = this.buildFormula(left); - const rightConstraint = this.buildFormula(right); - return func(leftConstraint, rightConstraint); + return func(this.buildFormula(left), this.buildFormula(right)); } } diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 194d8c4c5..e8084ccf8 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1451,6 +1451,8 @@ export class PolicyProxyHandler implements Pr } if (fieldValues) { + // combine runtime filters with generated constraints + const extraConstraints: CheckerConstraint[] = []; for (const [field, value] of Object.entries(fieldValues)) { if (value === undefined) { @@ -1463,12 +1465,14 @@ export class PolicyProxyHandler implements Pr const fieldInfo = requireField(this.modelMeta, this.model, field); + // relation and array fields are not supported if (fieldInfo.isDataModel || fieldInfo.isArray) { throw new Error( `Providing filter for field "${field}" is not supported. Only scalar fields are allowed.` ); } + // map field type to constraint type const fieldType = match(fieldInfo.type) .with(P.union('Int', 'BigInt', 'Float', 'Decimal'), () => 'number') .with('String', () => 'string') @@ -1479,13 +1483,24 @@ export class PolicyProxyHandler implements Pr ); }); + // check value type const valueType = typeof value; if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') { throw new Error( - `Invalid value for field "${field}". Only number, string, boolean, or null is allowed.` + `Invalid value type for field "${field}". Only number, string or boolean is allowed.` ); } + if (fieldType !== valueType) { + throw new Error(`Invalid value type for field "${field}". Expected "${fieldType}".`); + } + + // check number validity + if (typeof value === 'number' && (!Number.isInteger(value) || value < 0)) { + throw new Error(`Invalid value for field "${field}". Only non-negative integers are allowed.`); + } + + // build a constraint extraConstraints.push({ kind: 'eq', left: { kind: 'variable', name: field, type: fieldType }, @@ -1494,11 +1509,13 @@ export class PolicyProxyHandler implements Pr } if (extraConstraints.length > 0) { + // combine the constraints constraint = { kind: 'and', children: [constraint, ...extraConstraints] }; } } - return new ConstraintSolver().solve(constraint); + // check satisfiability + return new ConstraintSolver().checkSat(constraint); } //#endregion diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts index 2da7b5188..e122a4bbc 100644 --- a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -90,7 +90,15 @@ export class ConstraintTransformer { private transformLiteral(expr: LiteralExpr) { return match(expr.$type) - .with(NumberLiteral, () => this.value(expr.value.toString(), 'number')) + .with(NumberLiteral, () => { + const parsed = parseFloat(expr.value as string); + if (isNaN(parsed) || parsed < 0 || !Number.isInteger(parsed)) { + // only non-negative integers are supported, for other cases, + // transform into a free variable + return this.nextVar('number'); + } + return this.value(expr.value.toString(), 'number'); + }) .with(StringLiteral, () => this.value(`'${expr.value}'`, 'string')) .with(BooleanLiteral, () => this.value(expr.value.toString(), 'boolean')) .exhaustive(); diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts index 52fd55afc..e49b8dc57 100644 --- a/tests/integration/tests/enhancements/with-policy/checker.test.ts +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -355,4 +355,77 @@ describe('Permission checker', () => { } ); }); + + it('invalid filter', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + foo Foo? + d DateTime + + @@allow('read', value == 1) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + model Model @relation(fields: [modelId], references: [id]) + modelId Int @unique + } + ` + ); + + const db = enhance(); + await expect(db.model.check('read', { foo: { x: 1 } })).rejects.toThrow( + `Providing filter for field "foo" is not supported. Only scalar fields are allowed.` + ); + await expect(db.model.check('read', { d: new Date() })).rejects.toThrow( + `Providing filter for field "d" is not supported. Only number, string, and boolean fields are allowed.` + ); + await expect(db.model.check('read', { value: null })).rejects.toThrow( + `Using "null" as filter value is not supported yet` + ); + await expect(db.model.check('read', { value: {} })).rejects.toThrow( + 'Invalid value type for field "value". Only number, string or boolean is allowed.' + ); + await expect(db.model.check('read', { value: 'abc' })).rejects.toThrow( + 'Invalid value type for field "value". Expected "number"' + ); + await expect(db.model.check('read', { value: -1 })).rejects.toThrow( + 'Invalid value for field "value". Only non-negative integers are allowed.' + ); + }); + + it('float field ignored', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Float + @@allow('read', value == 1.1) + } + ` + ); + const db = enhance(); + await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); + }); + + it('float value ignored', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value > 1.1) + } + ` + ); + const db = enhance(); + // await expect(db.model.check('read')).toResolveTruthy(); + await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); + await expect(db.model.check('read', { value: 2 })).toResolveTruthy(); + }); }); From b55740aa21c37004af9c4bb012f7067a52218cc3 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 5 May 2024 21:08:39 +0800 Subject: [PATCH 05/13] more fixes and tests --- .../enhancements/policy/constraint-solver.ts | 2 +- .../src/enhancements/policy/handler.ts | 5 ++ .../src/enhancements/policy/logic-solver.d.ts | 66 +++++++++++++++++++ .../src/enhancements/policy/policy-utils.ts | 9 ++- packages/runtime/src/enhancements/types.ts | 24 +++++++ packages/runtime/src/types.ts | 11 +++- .../enhance/checker-type-generator.ts | 11 ++++ .../enhancer/policy/constraint-transformer.ts | 41 ++++++++++-- .../enhancer/policy/policy-guard-generator.ts | 2 + .../enhancements/with-policy/checker.test.ts | 13 ++++ 10 files changed, 174 insertions(+), 10 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts index cf70ddab0..5b8b484de 100644 --- a/packages/runtime/src/enhancements/policy/constraint-solver.ts +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -10,7 +10,7 @@ import type { } from '../types'; /** - * A boolean constraint solver based on `logic-solver`. + * A boolean constraint solver based on `logic-solver`. Only boolean and integer types are supported. */ export class ConstraintSolver { // a table for internalizing string literals diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index e8084ccf8..43f19d1bf 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1441,6 +1441,11 @@ export class PolicyProxyHandler implements Pr //#region Check + /** + * Checks if the given operation is possibly allowed by the policy, without querying the database. + * @param operation The CRUD operation. + * @param fieldValues Extra field value filters to be combined with the policy constraints. + */ async check( operation: PolicyCrudKind, fieldValues?: Record diff --git a/packages/runtime/src/enhancements/policy/logic-solver.d.ts b/packages/runtime/src/enhancements/policy/logic-solver.d.ts index 31eea6758..d10e688f6 100644 --- a/packages/runtime/src/enhancements/policy/logic-solver.d.ts +++ b/packages/runtime/src/enhancements/policy/logic-solver.d.ts @@ -1,43 +1,109 @@ +/** + * Type definitions for the `logic-solver` npm package. + */ declare module 'logic-solver' { + /** + * A boolean formula. + */ interface Formula {} + /** + * The `TRUE` formula. + */ const TRUE: Formula; + /** + * The `FALSE` formula. + */ const FALSE: Formula; + /** + * Boolean equivalence. + */ export function equiv(operand1: Formula, operand2: Formula): Formula; + /** + * Bits equality. + */ export function equalBits(bits1: Formula, bits2: Formula): Formula; + /** + * Bits greater-than. + */ export function greaterThan(bits1: Formula, bits2: Formula): Formula; + /** + * Bits greater-than-or-equal. + */ export function greaterThanOrEqual(bits1: Formula, bits2: Formula): Formula; + /** + * Bits less-than. + */ export function lessThan(bits1: Formula, bits2: Formula): Formula; + /** + * Bits less-than-or-equal. + */ export function lessThanOrEqual(bits1: Formula, bits2: Formula): Formula; + /** + * Logical AND. + */ export function and(...args: Formula[]): Formula; + /** + * Logical OR. + */ export function or(...args: Formula[]): Formula; + /** + * Logical NOT. + */ export function not(arg: Formula): Formula; + /** + * Creates a bits variable with the given name and bit length. + */ export function variableBits(baseName: string, N: number): Formula; + /** + * Creates a constant bits formula from the given whole number. + */ export function constantBits(wholeNumber: number): Formula; + /** + * A solution to a constraint. + */ interface Solution { + /** + * Returns a map of variable assignments. + */ getMap(): object; + /** + * Evaluates the given formula against the solution. + */ evaluate(formula: Formula): unknown; } + /** + * A constraint solver. + */ class Solver { + /** + * Adds constraints to the solver. + */ require(...args: Formula[]): void; + /** + * Adds negated constraints from the solver. + */ forbid(...args: Formula[]): void; + /** + * Solves the constraints. + */ solve(): Solution; } } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 9c3e8aecb..1b6c2c02c 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -565,6 +565,9 @@ export class PolicyUtil extends QueryUtils { //#region Checker + /** + * Gets checker constraints for the given model and operation. + */ getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { const checker = this.getModelChecker(model); if (!checker) { @@ -576,9 +579,11 @@ export class PolicyUtil extends QueryUtils { return provider; } - if (!provider) { - throw this.unknownError(`unable to load authorization guard for ${model}`); + if (typeof provider !== 'function') { + throw this.unknownError(`unable to ${operation} checker for ${model}`); } + + // call checker function return provider({ user: this.user }); } diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 17b1ac48e..c2e90fa94 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -33,31 +33,55 @@ export interface CommonEnhancementOptions { */ export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; +/** + * Function for checking if an operation is possibly allowed. + */ export type CheckerFunc = (context: CheckerContext) => CheckerConstraint; +/** + * Supported checker constraint checking value types. + */ export type ConstraintValueTypes = 'boolean' | 'number' | 'string'; +/** + * Free variable constraint + */ export type VariableConstraint = { kind: 'variable'; name: string; type: ConstraintValueTypes }; +/** + * Constant value constraint + */ export type ValueConstraint = { kind: 'value'; value: number | boolean | string; type: ConstraintValueTypes; }; +/** + * Terms for comparison constraints + */ export type ComparisonTerm = VariableConstraint | ValueConstraint; +/** + * Comparison constraint + */ export type ComparisonConstraint = { kind: 'eq' | 'gt' | 'gte' | 'lt' | 'lte'; left: ComparisonTerm; right: ComparisonTerm; }; +/** + * Logical constraint + */ export type LogicalConstraint = { kind: 'and' | 'or' | 'not'; children: CheckerConstraint[]; }; +/** + * Operation allowability checking constraint + */ export type CheckerConstraint = ValueConstraint | VariableConstraint | ComparisonConstraint | LogicalConstraint; /** diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index b516b2cc3..0e1c86e93 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -58,10 +58,19 @@ export type QueryContext = { preValue?: any; }; +/** + * Context for checking operation allowability. + */ export type CheckerContext = { + /** + * Current user + */ user?: AuthUser; - fieldValues?: Record; + /** + * Extra field value filters. + */ + fieldValues?: Record; }; /** diff --git a/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts index f7986b322..bda5dd9da 100644 --- a/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts +++ b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts @@ -5,6 +5,17 @@ import { P, match } from 'ts-pattern'; /** * Generates a `ModelCheckers` interface that contains a `check` method for each model in the schema. + * + * E.g.: + * + * ```ts + * type CheckerOperation = 'create' | 'read' | 'update' | 'delete'; + * + * export interface ModelCheckers { + * user: { check(op: CheckerOperation, args?: { email?: string; age?: number; }): Promise }, + * ... + * } + * ``` */ export function generateCheckerType(model: Model) { return ` diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts index e122a4bbc..a3b7f956a 100644 --- a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -20,24 +20,37 @@ import { } from '@zenstackhq/sdk/ast'; import { P, match } from 'ts-pattern'; +/** + * Options for {@link ConstraintTransformer}. + */ export type ConstraintTransformerOptions = { authAccessor: string; }; +/** + * Transform a set of allow and deny rules into a single constraint expression. + */ export class ConstraintTransformer { + // a counter for generating unique variable names private varCounter = 0; constructor(private readonly options: ConstraintTransformerOptions) {} + /** + * Transforms a set of allow and deny rules into a single constraint expression. + */ transformRules(allows: Expression[], denies: Expression[]): string { + // reset state this.varCounter = 0; if (allows.length === 0) { + // unconditionally deny return this.value('false', 'boolean'); } let result: string; + // transform allow rules const allowConstraints = allows.map((allow) => this.transformExpression(allow)); if (allowConstraints.length > 1) { result = this.and(...allowConstraints); @@ -45,12 +58,14 @@ export class ConstraintTransformer { result = allowConstraints[0]; } + // transform deny rules and compose if (denies.length > 0) { const denyConstraints = denies.map((deny) => this.transformExpression(deny)); result = this.and(result, this.not(this.or(...denyConstraints))); } - console.log(`Constraint transformation result:\n${JSON.stringify(result, null, 2)}`); + // DEBUG: + // console.log(`Constraint transformation result:\n${JSON.stringify(result, null, 2)}`); return result; } @@ -105,22 +120,30 @@ export class ConstraintTransformer { } private transformReference(expr: ReferenceExpr) { + // top-level reference is transformed into a named variable return this.variable(expr.target.$refText, 'boolean'); } private transformMemberAccess(expr: MemberAccessExpr) { if (isThisExpr(expr.operand)) { + // "this.x" is transformed into a named variable return this.variable(expr.member.$refText, 'boolean'); } + + // other member access expressions are not supported and thus + // transformed into a free variable return this.nextVar(); } private transformBinary(expr: BinaryExpr): string { - return match(expr.operator) - .with('&&', () => this.and(this.transformExpression(expr.left), this.transformExpression(expr.right))) - .with('||', () => this.or(this.transformExpression(expr.left), this.transformExpression(expr.right))) - .with(P.union('==', '!=', '<', '<=', '>', '>='), () => this.transformComparison(expr)) - .otherwise(() => this.nextVar()); + return ( + match(expr.operator) + .with('&&', () => this.and(this.transformExpression(expr.left), this.transformExpression(expr.right))) + .with('||', () => this.or(this.transformExpression(expr.left), this.transformExpression(expr.right))) + .with(P.union('==', '!=', '<', '<=', '>', '>='), () => this.transformComparison(expr)) + // unsupported operators (e.g., collection predicate) are transformed into a free variable + .otherwise(() => this.nextVar()) + ); } private transformUnary(expr: UnaryExpr): string { @@ -134,6 +157,7 @@ export class ConstraintTransformer { const rightOperand = this.getComparisonOperand(expr.right); if (leftOperand === undefined || rightOperand === undefined) { + // if either operand is not supported, transform into a free variable return this.nextVar(); } @@ -150,6 +174,7 @@ export class ConstraintTransformer { let result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`; if (expr.operator === '!=') { + // transform "!=" into "not eq" result = `{ kind: 'not', children: [${result}] }`; } @@ -163,6 +188,7 @@ export class ConstraintTransformer { const fieldAccess = this.getFieldAccess(expr); if (fieldAccess) { + // model field access is transformed into a named variable const mappedType = this.mapType(expr); if (mappedType) { return this.variable(fieldAccess.name, mappedType); @@ -173,6 +199,9 @@ export class ConstraintTransformer { const authAccess = this.getAuthAccess(expr); if (authAccess) { + // `auth().` access is transformed into a runtime boolean value if it + // doesn't evaluate to undefined (due to ?. chaining), otherwise into + // a named variable const fieldAccess = `${this.options.authAccessor}?.${authAccess}`; const mappedType = this.mapType(expr); if (mappedType) { diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 8e47c6614..f8952d6fa 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -88,11 +88,13 @@ export class PolicyGenerator { const models = getDataModels(model); + // policy guard functions const policyMap: Record> = {}; for (const model of models) { policyMap[model.name] = await this.generateQueryGuardForModel(model, sf); } + // CRUD checker functions const checkerMap: Record> = {}; for (const model of models) { checkerMap[model.name] = await this.generateCheckerForModel(model, sf); diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts index e49b8dc57..5e1580af2 100644 --- a/tests/integration/tests/enhancements/with-policy/checker.test.ts +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -226,6 +226,8 @@ describe('Permission checker', () => { x Int y Int @@allow('read', x > 0 && x > y) + @@allow('create', x > 1 || x > y) + @@allow('update', !(x >= y)) } ` ); @@ -236,6 +238,17 @@ describe('Permission checker', () => { await expect(db.model.check('read', { x: 1 })).toResolveTruthy(); await expect(db.model.check('read', { x: 1, y: 0 })).toResolveTruthy(); await expect(db.model.check('read', { x: 1, y: 1 })).toResolveFalsy(); + + await expect(db.model.check('create')).toResolveTruthy(); + await expect(db.model.check('create', { x: 0 })).toResolveFalsy(); // numbers are non-negative + await expect(db.model.check('create', { x: 1 })).toResolveTruthy(); + await expect(db.model.check('create', { x: 1, y: 0 })).toResolveTruthy(); + await expect(db.model.check('create', { x: 1, y: 1 })).toResolveFalsy(); + + await expect(db.model.check('update')).toResolveTruthy(); + await expect(db.model.check('update', { x: 0 })).toResolveTruthy(); + await expect(db.model.check('update', { y: 0 })).toResolveFalsy(); // numbers are non-negative + await expect(db.model.check('update', { x: 1, y: 1 })).toResolveFalsy(); }); it('field condition unsolvable', async () => { From ebae535710bd8a9bdf225af567b87e4a1dad5c01 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 5 May 2024 21:21:17 +0800 Subject: [PATCH 06/13] fix: error wording --- packages/runtime/src/enhancements/policy/policy-utils.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 1b6c2c02c..f910df264 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -580,7 +580,7 @@ export class PolicyUtil extends QueryUtils { } if (typeof provider !== 'function') { - throw this.unknownError(`unable to ${operation} checker for ${model}`); + throw this.unknownError(`unable to load ${operation} checker for ${model}`); } // call checker function From ae96914b9cce12b7bb377a6481a84b6e9c2bcaef Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 5 May 2024 21:29:56 +0800 Subject: [PATCH 07/13] change return type to a deferred promise --- packages/runtime/src/enhancements/policy/handler.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 43f19d1bf..bf6df3e7b 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1450,6 +1450,10 @@ export class PolicyProxyHandler implements Pr operation: PolicyCrudKind, fieldValues?: Record ): Promise { + return createDeferredPromise(() => this.doCheck(operation, fieldValues)); + } + + private async doCheck(operation: PolicyCrudKind, fieldValues?: Record) { let constraint = this.policyUtils.getCheckerConstraint(this.model, operation); if (typeof constraint === 'boolean') { return constraint; From 7c0277d9f5cb17483fedffdeaa0da09723d83233 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 6 May 2024 22:56:34 +0800 Subject: [PATCH 08/13] fix: handle nullable `auth()` access --- .../enhancements/policy/constraint-solver.ts | 13 +++ .../enhancer/policy/constraint-transformer.ts | 81 ++++++++++++++----- .../enhancements/with-policy/checker.test.ts | 50 +++++++++++- 3 files changed, 123 insertions(+), 21 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts index 5b8b484de..9ecc1dd02 100644 --- a/packages/runtime/src/enhancements/policy/constraint-solver.ts +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -83,6 +83,19 @@ export class ConstraintSolver { } private buildComparisonFormula(constraint: ComparisonConstraint) { + if (constraint.left.kind === 'value' && constraint.right.kind === 'value') { + // constant comparison + const left: ValueConstraint = constraint.left; + const right: ValueConstraint = constraint.right; + return match(constraint.kind) + .with('eq', () => (left.value === right.value ? Logic.TRUE : Logic.FALSE)) + .with('gt', () => (left.value > right.value ? Logic.TRUE : Logic.FALSE)) + .with('gte', () => (left.value >= right.value ? Logic.TRUE : Logic.FALSE)) + .with('lt', () => (left.value < right.value ? Logic.TRUE : Logic.FALSE)) + .with('lte', () => (left.value <= right.value ? Logic.TRUE : Logic.FALSE)) + .exhaustive(); + } + return match(constraint.kind) .with('eq', () => this.transformEquality(constraint.left, constraint.right)) .with('gt', () => diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts index a3b7f956a..9ab49512b 100644 --- a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -14,6 +14,7 @@ import { isDataModelField, isLiteralExpr, isMemberAccessExpr, + isNullExpr, isReferenceExpr, isThisExpr, isUnaryExpr, @@ -125,13 +126,19 @@ export class ConstraintTransformer { } private transformMemberAccess(expr: MemberAccessExpr) { + // "this.x" is transformed into a named variable if (isThisExpr(expr.operand)) { - // "this.x" is transformed into a named variable return this.variable(expr.member.$refText, 'boolean'); } - // other member access expressions are not supported and thus - // transformed into a free variable + // top-level auth access + const authAccess = this.getAuthAccess(expr); + if (authAccess) { + return this.value(`${authAccess} ?? false`, 'boolean'); + } + + // other top-level member access expressions are not supported + // and thus transformed into a free variable return this.nextVar(); } @@ -153,14 +160,19 @@ export class ConstraintTransformer { } private transformComparison(expr: BinaryExpr) { - const leftOperand = this.getComparisonOperand(expr.left); - const rightOperand = this.getComparisonOperand(expr.right); + if (this.isAuthEqualNull(expr)) { + // `auth() == null` => `user === null` + return this.value(`${this.options.authAccessor} === null`, 'boolean'); + } - if (leftOperand === undefined || rightOperand === undefined) { - // if either operand is not supported, transform into a free variable - return this.nextVar(); + if (this.isAuthNotEqualNull(expr)) { + // `auth() != null` => `user !== null` + return this.value(`${this.options.authAccessor} !== null`, 'boolean'); } + const leftOperand = this.getComparisonOperand(expr.left); + const rightOperand = this.getComparisonOperand(expr.right); + const op = match(expr.operator) .with('==', () => 'eq') .with('!=', () => 'eq') @@ -175,12 +187,52 @@ export class ConstraintTransformer { let result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`; if (expr.operator === '!=') { // transform "!=" into "not eq" - result = `{ kind: 'not', children: [${result}] }`; + result = this.not(result); + } + + // `auth()` access can be undefined, when that happens, we assume a false condition + // for the comparison, unless we're directly comparing `auth() != null` + + const leftAuthAccess = this.getAuthAccess(expr.left); + const rightAuthAccess = this.getAuthAccess(expr.right); + + if (leftAuthAccess) { + // `auth().f op x` => `auth().f !== undefined && auth().f op x` + return this.and(this.value(`${this.normalizeToNull(leftAuthAccess)} !== null`, 'boolean'), result); + } else if (rightAuthAccess) { + // `x op auth().f` => `auth().f !== undefined && x op auth().f` + return this.and(this.value(`${this.normalizeToNull(rightAuthAccess)} !== null`, 'boolean'), result); + } + + if (leftOperand === undefined || rightOperand === undefined) { + // if either operand is not supported, transform into a free variable + return this.nextVar(); } return result; } + // normalize `auth()` access undefined value to null + private normalizeToNull(expr: string) { + return `(${expr} ?? null)`; + } + + private isAuthEqualNull(expr: BinaryExpr) { + return ( + expr.operator === '==' && + ((isAuthInvocation(expr.left) && isNullExpr(expr.right)) || + (isAuthInvocation(expr.right) && isNullExpr(expr.left))) + ); + } + + private isAuthNotEqualNull(expr: BinaryExpr) { + return ( + expr.operator === '!=' && + ((isAuthInvocation(expr.left) && isNullExpr(expr.right)) || + (isAuthInvocation(expr.right) && isNullExpr(expr.left))) + ); + } + private getComparisonOperand(expr: Expression) { if (isLiteralExpr(expr)) { return this.transformLiteral(expr); @@ -199,16 +251,9 @@ export class ConstraintTransformer { const authAccess = this.getAuthAccess(expr); if (authAccess) { - // `auth().` access is transformed into a runtime boolean value if it - // doesn't evaluate to undefined (due to ?. chaining), otherwise into - // a named variable - const fieldAccess = `${this.options.authAccessor}?.${authAccess}`; const mappedType = this.mapType(expr); if (mappedType) { - return `${fieldAccess} === undefined ? ${this.expressionVariable(expr, mappedType)} : ${this.value( - fieldAccess, - mappedType - )}`; + return `${this.value(authAccess, mappedType)}`; } else { return undefined; } @@ -241,7 +286,7 @@ export class ConstraintTransformer { } if (isAuthInvocation(expr.operand)) { - return expr.member.$refText; + return `${this.options.authAccessor}?.${expr.member.$refText}`; } else { const operand = this.getAuthAccess(expr.operand); return operand ? `${operand}?.${expr.member.$refText}` : undefined; diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts index 5e1580af2..951febdce 100644 --- a/tests/integration/tests/enhancements/with-policy/checker.test.ts +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -277,19 +277,63 @@ describe('Permission checker', () => { model User { id Int @id @default(autoincrement()) level Int + admin Boolean } model Model { id Int @id @default(autoincrement()) value Int @@allow('read', auth().level > 0) + @@allow('create', auth().admin) + @@allow('update', !auth().admin) } ` ); - await expect(enhance().model.check('read')).toResolveTruthy(); + await expect(enhance().model.check('read')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('read')).toResolveFalsy(); await expect(enhance({ id: 1, level: 0 }).model.check('read')).toResolveFalsy(); await expect(enhance({ id: 1, level: 1 }).model.check('read')).toResolveTruthy(); + + await expect(enhance().model.check('create')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('create')).toResolveFalsy(); + await expect(enhance({ id: 1, admin: false }).model.check('create')).toResolveFalsy(); + await expect(enhance({ id: 1, admin: true }).model.check('create')).toResolveTruthy(); + + await expect(enhance().model.check('update')).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('update')).toResolveTruthy(); + await expect(enhance({ id: 1, admin: true }).model.check('update')).toResolveFalsy(); + await expect(enhance({ id: 1, admin: false }).model.check('update')).toResolveTruthy(); + }); + + it('auth null check', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + level Int + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth() != null) + @@allow('create', auth() == null) + @@allow('update', auth().level > 0) + } + ` + ); + + await expect(enhance().model.check('read')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('read')).toResolveTruthy(); + + await expect(enhance().model.check('create')).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('create')).toResolveFalsy(); + + await expect(enhance().model.check('update')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('update')).toResolveFalsy(); + await expect(enhance({ id: 1, level: 0 }).model.check('update')).toResolveFalsy(); + await expect(enhance({ id: 1, level: 1 }).model.check('update')).toResolveTruthy(); }); it('auth with relation', async () => { @@ -315,8 +359,8 @@ describe('Permission checker', () => { ` ); - await expect(enhance().model.check('read')).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('read')).toResolveTruthy(); + await expect(enhance().model.check('read')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('read')).toResolveFalsy(); await expect(enhance({ id: 1, profile: { level: 0 } }).model.check('read')).toResolveFalsy(); await expect(enhance({ id: 1, profile: { level: 1 } }).model.check('read')).toResolveTruthy(); }); From 3c6e5ee96f4dbcabbf4ef3ec1bd10af78bf16652 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 7 May 2024 15:59:11 +0800 Subject: [PATCH 09/13] more fixes --- .../enhancements/policy/constraint-solver.ts | 50 +++++-- packages/runtime/src/enhancements/types.ts | 2 +- .../enhancer/policy/constraint-transformer.ts | 124 +++++++++++++----- .../enhancer/policy/policy-guard-generator.ts | 2 +- packages/sdk/src/utils.ts | 26 +++- .../enhancements/with-policy/checker.test.ts | 44 ++++++- tests/regression/tests/issue-961.test.ts | 6 +- 7 files changed, 202 insertions(+), 52 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts index 9ecc1dd02..c87a528e7 100644 --- a/packages/runtime/src/enhancements/policy/constraint-solver.ts +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -57,7 +57,7 @@ export class ConstraintSolver { (c) => this.buildVariableFormula(c) ) .when( - (c): c is ComparisonConstraint => ['eq', 'gt', 'gte', 'lt', 'lte'].includes(c.kind), + (c): c is ComparisonConstraint => ['eq', 'ne', 'gt', 'gte', 'lt', 'lte'].includes(c.kind), (c) => this.buildComparisonFormula(c) ) .when( @@ -71,17 +71,43 @@ export class ConstraintSolver { private buildLogicalFormula(constraint: LogicalConstraint) { return match(constraint.kind) - .with('and', () => Logic.and(...constraint.children.map((c) => this.buildFormula(c)))) - .with('or', () => Logic.or(...constraint.children.map((c) => this.buildFormula(c)))) - .with('not', () => { - if (constraint.children.length !== 1) { - throw new Error('"not" constraint must have exactly one child'); - } - return Logic.not(this.buildFormula(constraint.children[0])); - }) + .with('and', () => this.buildAndFormula(constraint)) + .with('or', () => this.buildOrFormula(constraint)) + .with('not', () => this.buildNotFormula(constraint)) .exhaustive(); } + private buildAndFormula(constraint: LogicalConstraint): Logic.Formula { + if (constraint.children.some((c) => this.isFalse(c))) { + // short-circuit + return Logic.FALSE; + } + return Logic.and(...constraint.children.map((c) => this.buildFormula(c))); + } + + private buildOrFormula(constraint: LogicalConstraint): Logic.Formula { + if (constraint.children.some((c) => this.isTrue(c))) { + // short-circuit + return Logic.TRUE; + } + return Logic.or(...constraint.children.map((c) => this.buildFormula(c))); + } + + private buildNotFormula(constraint: LogicalConstraint) { + if (constraint.children.length !== 1) { + throw new Error('"not" constraint must have exactly one child'); + } + return Logic.not(this.buildFormula(constraint.children[0])); + } + + private isTrue(constraint: CheckerConstraint): unknown { + return constraint.kind === 'value' && constraint.value === true; + } + + private isFalse(constraint: CheckerConstraint): unknown { + return constraint.kind === 'value' && constraint.value === false; + } + private buildComparisonFormula(constraint: ComparisonConstraint) { if (constraint.left.kind === 'value' && constraint.right.kind === 'value') { // constant comparison @@ -89,6 +115,7 @@ export class ConstraintSolver { const right: ValueConstraint = constraint.right; return match(constraint.kind) .with('eq', () => (left.value === right.value ? Logic.TRUE : Logic.FALSE)) + .with('ne', () => (left.value !== right.value ? Logic.TRUE : Logic.FALSE)) .with('gt', () => (left.value > right.value ? Logic.TRUE : Logic.FALSE)) .with('gte', () => (left.value >= right.value ? Logic.TRUE : Logic.FALSE)) .with('lt', () => (left.value < right.value ? Logic.TRUE : Logic.FALSE)) @@ -98,6 +125,7 @@ export class ConstraintSolver { return match(constraint.kind) .with('eq', () => this.transformEquality(constraint.left, constraint.right)) + .with('ne', () => this.transformInequality(constraint.left, constraint.right)) .with('gt', () => this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThan(l, r)) ) @@ -177,6 +205,10 @@ export class ConstraintSolver { } } + private transformInequality(left: ComparisonTerm, right: ComparisonTerm) { + return Logic.not(this.transformEquality(left, right)); + } + private transformComparison( left: ComparisonTerm, right: ComparisonTerm, diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index c2e90fa94..bb508a1d1 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -66,7 +66,7 @@ export type ComparisonTerm = VariableConstraint | ValueConstraint; * Comparison constraint */ export type ComparisonConstraint = { - kind: 'eq' | 'gt' | 'gte' | 'lt' | 'lte'; + kind: 'eq' | 'ne' | 'gt' | 'gte' | 'lt' | 'lte'; left: ComparisonTerm; right: ComparisonTerm; }; diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts index 9ab49512b..08d72f945 100644 --- a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -1,7 +1,8 @@ -import { ZModelCodeGenerator, isAuthInvocation } from '@zenstackhq/sdk'; +import { ZModelCodeGenerator, getRelationKeyPairs, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk'; import { BinaryExpr, BooleanLiteral, + DataModelField, Expression, ExpressionType, LiteralExpr, @@ -160,46 +161,28 @@ export class ConstraintTransformer { } private transformComparison(expr: BinaryExpr) { - if (this.isAuthEqualNull(expr)) { - // `auth() == null` => `user === null` - return this.value(`${this.options.authAccessor} === null`, 'boolean'); - } - - if (this.isAuthNotEqualNull(expr)) { - // `auth() != null` => `user !== null` - return this.value(`${this.options.authAccessor} !== null`, 'boolean'); + if (isAuthInvocation(expr.left) || isAuthInvocation(expr.right)) { + // handle the case if any operand is `auth()` invocation + const authComparison = this.transformAuthComparison(expr); + return authComparison ?? this.nextVar(); } const leftOperand = this.getComparisonOperand(expr.left); const rightOperand = this.getComparisonOperand(expr.right); - const op = match(expr.operator) - .with('==', () => 'eq') - .with('!=', () => 'eq') - .with('<', () => 'lt') - .with('<=', () => 'lte') - .with('>', () => 'gt') - .with('>=', () => 'gte') - .otherwise(() => { - throw new Error(`Unsupported operator: ${expr.operator}`); - }); - - let result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`; - if (expr.operator === '!=') { - // transform "!=" into "not eq" - result = this.not(result); - } + const op = this.mapOperatorToConstraintKind(expr.operator); + const result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`; - // `auth()` access can be undefined, when that happens, we assume a false condition - // for the comparison, unless we're directly comparing `auth() != null` + // `auth()` member access can be undefined, when that happens, we assume a false condition + // for the comparison const leftAuthAccess = this.getAuthAccess(expr.left); const rightAuthAccess = this.getAuthAccess(expr.right); - if (leftAuthAccess) { + if (leftAuthAccess && rightOperand) { // `auth().f op x` => `auth().f !== undefined && auth().f op x` return this.and(this.value(`${this.normalizeToNull(leftAuthAccess)} !== null`, 'boolean'), result); - } else if (rightAuthAccess) { + } else if (rightAuthAccess && leftOperand) { // `x op auth().f` => `auth().f !== undefined && x op auth().f` return this.and(this.value(`${this.normalizeToNull(rightAuthAccess)} !== null`, 'boolean'), result); } @@ -212,6 +195,64 @@ export class ConstraintTransformer { return result; } + private transformAuthComparison(expr: BinaryExpr) { + if (this.isAuthEqualNull(expr)) { + // `auth() == null` => `user === null` + return this.value(`${this.options.authAccessor} === null`, 'boolean'); + } + + if (this.isAuthNotEqualNull(expr)) { + // `auth() != null` => `user !== null` + return this.value(`${this.options.authAccessor} !== null`, 'boolean'); + } + + // auth() equality check against a relation, translate to id-fk comparison + const operand = isAuthInvocation(expr.left) ? expr.right : expr.left; + if (!isDataModelFieldReference(operand)) { + return undefined; + } + + // get id-fk field pairs from the relation field + const relationField = operand.target.ref as DataModelField; + const idFkPairs = getRelationKeyPairs(relationField); + + // build id-fk field comparison constraints + const fieldConstraints: string[] = []; + + idFkPairs.forEach(({ id, foreignKey }) => { + const idFieldType = this.mapType(id.type.type as ExpressionType); + if (!idFieldType) { + return; + } + const fkFieldType = this.mapType(foreignKey.type.type as ExpressionType); + if (!fkFieldType) { + return; + } + + const op = this.mapOperatorToConstraintKind(expr.operator); + const authIdAccess = `${this.options.authAccessor}?.${id.name}`; + + fieldConstraints.push( + this.and( + // `auth()?.id != null` guard + this.value(`${this.normalizeToNull(authIdAccess)} !== null`, 'boolean'), + // `auth()?.id [op] fkField` + `{ kind: '${op}', left: ${this.value(authIdAccess, idFieldType)}, right: ${this.variable( + foreignKey.name, + fkFieldType + )} }` + ) + ); + }); + + // combine field constraints + if (fieldConstraints.length > 0) { + return this.and(...fieldConstraints); + } + + return undefined; + } + // normalize `auth()` access undefined value to null private normalizeToNull(expr: string) { return `(${expr} ?? null)`; @@ -241,7 +282,7 @@ export class ConstraintTransformer { const fieldAccess = this.getFieldAccess(expr); if (fieldAccess) { // model field access is transformed into a named variable - const mappedType = this.mapType(expr); + const mappedType = this.mapExpressionType(expr); if (mappedType) { return this.variable(fieldAccess.name, mappedType); } else { @@ -251,7 +292,7 @@ export class ConstraintTransformer { const authAccess = this.getAuthAccess(expr); if (authAccess) { - const mappedType = this.mapType(expr); + const mappedType = this.mapExpressionType(expr); if (mappedType) { return `${this.value(authAccess, mappedType)}`; } else { @@ -262,14 +303,31 @@ export class ConstraintTransformer { return undefined; } - private mapType(expression: Expression) { - return match(expression.$resolvedType?.decl as ExpressionType) + private mapExpressionType(expression: Expression) { + return this.mapType(expression.$resolvedType?.decl as ExpressionType); + } + + private mapType(type: ExpressionType) { + return match(type) .with('Boolean', () => 'boolean') .with('Int', () => 'number') .with('String', () => 'string') .otherwise(() => undefined); } + private mapOperatorToConstraintKind(operator: BinaryExpr['operator']) { + return match(operator) + .with('==', () => 'eq') + .with('!=', () => 'ne') + .with('<', () => 'lt') + .with('<=', () => 'lte') + .with('>', () => 'gt') + .with('>=', () => 'gte') + .otherwise(() => { + throw new Error(`Unsupported operator: ${operator}`); + }); + } + private getFieldAccess(expr: Expression) { if (isReferenceExpr(expr)) { return isDataModelField(expr.target.ref) ? { name: expr.target.$refText } : undefined; diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index f8952d6fa..a7bd0e846 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -922,7 +922,7 @@ export class PolicyGenerator { statements.push(`return ${transformed};`); const func = sourceFile.addFunction({ - name: `${model.name}Checker_${kind}`, + name: `${model.name}$checker$${kind}`, returnType: 'CheckerConstraint', parameters: [ { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 6617983aa..1bed512c8 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -298,9 +298,9 @@ export function isForeignKeyField(field: DataModelField) { } /** - * Gets the foreign key fields of the given relation field. + * Gets the foreign key-id field pairs from the given relation field. */ -export function getForeignKeyFields(relationField: DataModelField) { +export function getRelationKeyPairs(relationField: DataModelField) { if (!isRelationshipField(relationField)) { return []; } @@ -309,11 +309,31 @@ export function getForeignKeyFields(relationField: DataModelField) { if (relAttr) { // find "fields" arg const fieldsArg = getAttributeArg(relAttr, 'fields'); + let fkFields: DataModelField[]; if (fieldsArg && isArrayExpr(fieldsArg)) { - return fieldsArg.items + fkFields = fieldsArg.items .filter((item): item is ReferenceExpr => isReferenceExpr(item)) .map((item) => item.target.ref as DataModelField); + } else { + return []; } + + // find "references" arg + const referencesArg = getAttributeArg(relAttr, 'references'); + let idFields: DataModelField[]; + if (referencesArg && isArrayExpr(referencesArg)) { + idFields = referencesArg.items + .filter((item): item is ReferenceExpr => isReferenceExpr(item)) + .map((item) => item.target.ref as DataModelField); + } else { + return []; + } + + if (idFields.length !== fkFields.length) { + throw new Error(`Relation's references arg and fields are must have equal length`); + } + + return idFields.map((idField, i) => ({ id: idField, foreignKey: fkFields[i] })); } return []; diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts index 951febdce..9d3af35fe 100644 --- a/tests/integration/tests/enhancements/with-policy/checker.test.ts +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -306,6 +306,48 @@ describe('Permission checker', () => { await expect(enhance({ id: 1, admin: false }).model.check('update')).toResolveTruthy(); }); + it('auth compared with relation field', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + models Model[] + } + + model Model { + id Int @id @default(autoincrement()) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + @@allow('read', auth().id == ownerId) + @@allow('create', auth().id != ownerId) + @@allow('update', auth() == owner) + @@allow('delete', auth() != owner) + } + `, + { preserveTsFiles: true } + ); + + await expect(enhance().model.check('read')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('read')).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('read', { ownerId: 1 })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('read', { ownerId: 2 })).toResolveFalsy(); + + await expect(enhance().model.check('create')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('create')).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('create', { ownerId: 1 })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('create', { ownerId: 2 })).toResolveTruthy(); + + await expect(enhance().model.check('update')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('update')).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('update', { ownerId: 1 })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('update', { ownerId: 2 })).toResolveFalsy(); + + await expect(enhance().model.check('delete')).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('delete')).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check('delete', { ownerId: 1 })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check('delete', { ownerId: 2 })).toResolveTruthy(); + }); + it('auth null check', async () => { const { enhance } = await loadSchema( ` @@ -336,7 +378,7 @@ describe('Permission checker', () => { await expect(enhance({ id: 1, level: 1 }).model.check('update')).toResolveTruthy(); }); - it('auth with relation', async () => { + it('auth with relation access', async () => { const { enhance } = await loadSchema( ` model User { diff --git a/tests/regression/tests/issue-961.test.ts b/tests/regression/tests/issue-961.test.ts index 7bc42071b..f6dc3a135 100644 --- a/tests/regression/tests/issue-961.test.ts +++ b/tests/regression/tests/issue-961.test.ts @@ -123,10 +123,8 @@ describe('Regression: issue 961', () => { await expect(db.userColumn.findMany()).resolves.toHaveLength(1); }); - // disabled because of Prisma V4 bug: https://github.com/prisma/prisma/issues/18371 - // eslint-disable-next-line jest/no-disabled-tests - it.skip('updateMany', async () => { - const { prisma, enhance } = await loadSchema(schema, { logPrismaQuery: true }); + it('updateMany', async () => { + const { prisma, enhance } = await loadSchema(schema); const user = await prisma.user.create({ data: { From 8af5cddaa44dc63e114828504f40a08dafff77a7 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 7 May 2024 16:55:14 +0800 Subject: [PATCH 10/13] guard checker generation with an option flag --- .../src/enhancements/policy/policy-utils.ts | 11 ++- packages/runtime/src/enhancements/types.ts | 2 +- .../src/plugins/enhancer/enhance/index.ts | 19 +++-- .../enhancer/policy/policy-guard-generator.ts | 34 ++++---- .../enhancements/with-policy/checker.test.ts | 80 ++++++++++++++----- 5 files changed, 101 insertions(+), 45 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index f910df264..e4207e384 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -587,12 +587,17 @@ export class PolicyUtil extends QueryUtils { return provider({ user: this.user }); } - private getModelChecker(model: string): PolicyDef['checker']['string'] { + private getModelChecker(model: string) { if (this.options.kinds && !this.options.kinds.includes('policy')) { - // policy enhancement not enabled, return a constant checker + // policy enhancement not enabled, return a constant true checker return { create: true, read: true, update: true, delete: true }; } else { - return this.options.policy.checker?.[lowerCaseFirst(model)]; + let result = this.options.policy.checker?.[lowerCaseFirst(model)]; + if (!result) { + // checker generation not enabled, return constant false checker + result = { create: false, read: false, update: false, delete: false }; + } + return result; } } diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index bb508a1d1..89d5ce9f6 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -122,7 +122,7 @@ export type PolicyDef = { } >; - checker: Record>; + checker?: Record>; // tracks which models have data validation rules validation: Record; diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 81fe66c3c..62b1d03e6 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -90,7 +90,7 @@ export class EnhancerGenerator { const authTypes = authModel ? generateAuthType(this.model, authModel) : ''; const authTypeParam = authModel ? `auth.${authModel.name}` : 'AuthUser'; - const checkerTypes = generateCheckerType(this.model); + const checkerTypes = this.generatePermissionChecker ? generateCheckerType(this.model) : ''; const enhanceTs = this.project.createSourceFile( path.join(this.outDir, 'enhance.ts'), @@ -131,15 +131,16 @@ import type * as _P from '${prismaImport}'; } private createSimplePrismaEnhanceFunction(authTypeParam: string) { + const returnType = `DbClient${this.generatePermissionChecker ? ' & ModelCheckers' : ''}`; return ` -export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DbClient & ModelCheckers { +export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): ${returnType} { return createEnhancement(prisma, { modelMeta, policy, zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), prismaModule: Prisma, ...options - }, context) as DbClient & ModelCheckers; + }, context) as ${returnType}; } `; } @@ -162,12 +163,16 @@ import type { Prisma, PrismaClient } from '${logicalPrismaClientDir}/index-fixed // overload for plain PrismaClient export function enhance & InternalArgs>( prisma: _PrismaClient, - context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient & ModelCheckers; + context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient${ + this.generatePermissionChecker ? ' & ModelCheckers' : '' + }; // overload for extended PrismaClient export function enhance & InternalArgs>( prisma: DynamicClientExtensionThis, - context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis & ModelCheckers; + context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis${ + this.generatePermissionChecker ? ' & ModelCheckers' : '' + }; export function enhance(prisma: any, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): any { return createEnhancement(prisma, { @@ -627,4 +632,8 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara await sf.save(); } } + + private get generatePermissionChecker() { + return this.options.generatePermissionChecker === true; + } } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index a7bd0e846..a36a52126 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -94,10 +94,14 @@ export class PolicyGenerator { policyMap[model.name] = await this.generateQueryGuardForModel(model, sf); } + const generatePermissionChecker = options.generatePermissionChecker === true; + // CRUD checker functions const checkerMap: Record> = {}; - for (const model of models) { - checkerMap[model.name] = await this.generateCheckerForModel(model, sf); + if (generatePermissionChecker) { + for (const model of models) { + checkerMap[model.name] = await this.generateCheckerForModel(model, sf); + } } const authSelector = this.generateAuthSelector(models); @@ -128,19 +132,21 @@ export class PolicyGenerator { }); writer.writeLine(','); - writer.write('checker:'); - writer.inlineBlock(() => { - for (const [model, map] of Object.entries(checkerMap)) { - writer.write(`${lowerCaseFirst(model)}:`); - writer.inlineBlock(() => { - Object.entries(map).forEach(([op, func]) => { - writer.write(`${op}: ${func},`); + if (generatePermissionChecker) { + writer.write('checker:'); + writer.inlineBlock(() => { + for (const [model, map] of Object.entries(checkerMap)) { + writer.write(`${lowerCaseFirst(model)}:`); + writer.inlineBlock(() => { + Object.entries(map).forEach(([op, func]) => { + writer.write(`${op}: ${func},`); + }); }); - }); - writer.writeLine(','); - } - }); - writer.writeLine(','); + writer.writeLine(','); + } + }); + writer.writeLine(','); + } writer.write('validation:'); writer.inlineBlock(() => { diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts index 9d3af35fe..ed44c5665 100644 --- a/tests/integration/tests/enhancements/with-policy/checker.test.ts +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -1,8 +1,44 @@ -import { loadSchema } from '@zenstackhq/testtools'; +import { SchemaLoadOptions, loadSchema } from '@zenstackhq/testtools'; describe('Permission checker', () => { - it('empty rules', async () => { + const PRELUDE = ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + generator js { + provider = 'prisma-client-js' + } + + plugin enhancer { + provider = '@core/enhancer' + generatePermissionChecker = true + } + `; + + const load = (schema: string, options?: SchemaLoadOptions) => + loadSchema(`${PRELUDE}\n${schema}`, { + ...options, + addPrelude: false, + }); + + it('checker generation not enabled', async () => { const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', true) + } + ` + ); + const db = enhance(); + await expect(db.model.check('read')).toResolveFalsy(); + }); + + it('empty rules', async () => { + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -16,7 +52,7 @@ describe('Permission checker', () => { }); it('unconditional allow', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -31,7 +67,7 @@ describe('Permission checker', () => { }); it('deny rule', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -49,7 +85,7 @@ describe('Permission checker', () => { }); it('int field condition', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -82,7 +118,7 @@ describe('Permission checker', () => { }); it('boolean field toplevel condition', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -99,7 +135,7 @@ describe('Permission checker', () => { }); it('boolean field condition', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -131,7 +167,7 @@ describe('Permission checker', () => { }); it('string field condition', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -148,7 +184,7 @@ describe('Permission checker', () => { }); it('function noop', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -169,7 +205,7 @@ describe('Permission checker', () => { }); it('relation noop', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -194,7 +230,7 @@ describe('Permission checker', () => { }); it('collection predicate noop', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -219,7 +255,7 @@ describe('Permission checker', () => { }); it('field complex condition', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -252,7 +288,7 @@ describe('Permission checker', () => { }); it('field condition unsolvable', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -272,7 +308,7 @@ describe('Permission checker', () => { }); it('simple auth condition', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model User { id Int @id @default(autoincrement()) @@ -307,7 +343,7 @@ describe('Permission checker', () => { }); it('auth compared with relation field', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model User { id Int @id @default(autoincrement()) @@ -349,7 +385,7 @@ describe('Permission checker', () => { }); it('auth null check', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model User { id Int @id @default(autoincrement()) @@ -379,7 +415,7 @@ describe('Permission checker', () => { }); it('auth with relation access', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model User { id Int @id @default(autoincrement()) @@ -408,7 +444,7 @@ describe('Permission checker', () => { }); it('nullable field', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -427,7 +463,7 @@ describe('Permission checker', () => { }); it('compilation', async () => { - await loadSchema( + await load( ` model Model { id Int @id @default(autoincrement()) @@ -456,7 +492,7 @@ describe('Permission checker', () => { }); it('invalid filter', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -498,7 +534,7 @@ describe('Permission checker', () => { }); it('float field ignored', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) @@ -513,7 +549,7 @@ describe('Permission checker', () => { }); it('float value ignored', async () => { - const { enhance } = await loadSchema( + const { enhance } = await load( ` model Model { id Int @id @default(autoincrement()) From 3665391bbf4ad5c38dd87c8e29633a211d5f7e05 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 7 May 2024 20:52:27 +0800 Subject: [PATCH 11/13] multiple fixes and features - enum support - multiple allow conditions - rpc server adapter support --- packages/testtools/src/schema.ts | 2 + .../enhancements/with-policy/checker.test.ts | 360 +++++++++++------- 2 files changed, 225 insertions(+), 137 deletions(-) diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 7249e6c4a..c2109e579 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -102,6 +102,7 @@ generator js { plugin enhancer { provider = '@core/enhancer' ${options.preserveTsFiles ? 'preserveTsFiles = true' : ''} + ${options.generatePermissionChecker ? 'generatePermissionChecker = true' : ''} } plugin zod { @@ -131,6 +132,7 @@ export type SchemaLoadOptions = { extraSourceFiles?: { name: string; content: string }[]; projectDir?: string; preserveTsFiles?: boolean; + generatePermissionChecker?: boolean; }; const defaultOptions: SchemaLoadOptions = { diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts index ed44c5665..4a2c0193e 100644 --- a/tests/integration/tests/enhancements/with-policy/checker.test.ts +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -1,4 +1,4 @@ -import { SchemaLoadOptions, loadSchema } from '@zenstackhq/testtools'; +import { SchemaLoadOptions, createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools'; describe('Permission checker', () => { const PRELUDE = ` @@ -18,9 +18,9 @@ describe('Permission checker', () => { `; const load = (schema: string, options?: SchemaLoadOptions) => - loadSchema(`${PRELUDE}\n${schema}`, { + loadSchema(schema, { ...options, - addPrelude: false, + generatePermissionChecker: true, }); it('checker generation not enabled', async () => { @@ -34,7 +34,7 @@ describe('Permission checker', () => { ` ); const db = enhance(); - await expect(db.model.check('read')).toResolveFalsy(); + await expect(db.model.check({ operation: 'read' })).rejects.toThrow('Generated permission checkers not found'); }); it('empty rules', async () => { @@ -47,8 +47,8 @@ describe('Permission checker', () => { ` ); const db = enhance(); - await expect(db.model.check('read')).toResolveFalsy(); - await expect(db.model.check('read', { value: 1 })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read' })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveFalsy(); }); it('unconditional allow', async () => { @@ -62,8 +62,26 @@ describe('Permission checker', () => { ` ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: 0 })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 0 } })).toResolveTruthy(); + }); + + it('multiple allow rules', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', value == 1) + @@allow('all', value == 2) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 2 } })).toResolveTruthy(); }); it('deny rule', async () => { @@ -78,10 +96,10 @@ describe('Permission checker', () => { ` ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: 0 })).toResolveFalsy(); - await expect(db.model.check('read', { value: 1 })).toResolveFalsy(); - await expect(db.model.check('read', { value: 2 })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 2 } })).toResolveTruthy(); }); it('int field condition', async () => { @@ -99,22 +117,22 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: 0 })).toResolveFalsy(); - await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); - - await expect(db.model.check('create')).toResolveTruthy(); - await expect(db.model.check('create', { value: 0 })).toResolveTruthy(); - await expect(db.model.check('create', { value: 1 })).toResolveFalsy(); - - await expect(db.model.check('update')).toResolveTruthy(); - await expect(db.model.check('update', { value: 1 })).toResolveFalsy(); - await expect(db.model.check('update', { value: 2 })).toResolveTruthy(); - - await expect(db.model.check('delete')).toResolveTruthy(); - await expect(db.model.check('delete', { value: 0 })).toResolveTruthy(); - await expect(db.model.check('delete', { value: 1 })).toResolveTruthy(); - await expect(db.model.check('delete', { value: 2 })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { value: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { value: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { value: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'update', filter: { value: 2 } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', filter: { value: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', filter: { value: 2 } })).toResolveFalsy(); }); it('boolean field toplevel condition', async () => { @@ -129,9 +147,9 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: false })).toResolveFalsy(); - await expect(db.model.check('read', { value: true })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: true } })).toResolveTruthy(); }); it('boolean field condition', async () => { @@ -149,21 +167,21 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: false })).toResolveFalsy(); - await expect(db.model.check('read', { value: true })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: true } })).toResolveTruthy(); - await expect(db.model.check('create')).toResolveTruthy(); - await expect(db.model.check('create', { value: true })).toResolveFalsy(); - await expect(db.model.check('create', { value: false })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { value: true } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'create', filter: { value: false } })).toResolveTruthy(); - await expect(db.model.check('update')).toResolveTruthy(); - await expect(db.model.check('update', { value: true })).toResolveFalsy(); - await expect(db.model.check('update', { value: false })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { value: true } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'update', filter: { value: false } })).toResolveTruthy(); - await expect(db.model.check('delete')).toResolveTruthy(); - await expect(db.model.check('delete', { value: false })).toResolveFalsy(); - await expect(db.model.check('delete', { value: true })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', filter: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'delete', filter: { value: true } })).toResolveTruthy(); }); it('string field condition', async () => { @@ -178,9 +196,57 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: 'user' })).toResolveFalsy(); - await expect(db.model.check('read', { value: 'admin' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 'user' } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { value: 'admin' } })).toResolveTruthy(); + }); + + it('enum', async () => { + const dbUrl = await createPostgresDb('permission-checker-enum'); + let prisma: any; + try { + const r = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = '${dbUrl}' + } + + generator js { + provider = 'prisma-client-js' + } + + plugin enhancer { + provider = '@core/enhancer' + generatePermissionChecker = true + } + + enum Role { + USER + ADMIN + } + model User { + id Int @id @default(autoincrement()) + role Role + } + model Model { + id Int @id @default(autoincrement()) + @@allow('read', auth().role == ADMIN) + } + `, + { addPrelude: false, generatePermissionChecker: true } + ); + + prisma = r.prisma; + const enhance = r.enhance; + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, role: 'USER' }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, role: 'ADMIN' }).model.check({ operation: 'read' })).toResolveTruthy(); + } finally { + await prisma.$disconnect(); + await dropPostgresDb('permission-checker-enum'); + } }); it('function noop', async () => { @@ -196,12 +262,12 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: 'user' })).toResolveTruthy(); - await expect(db.model.check('read', { value: 'admin' })).toResolveTruthy(); - await expect(db.model.check('update')).toResolveTruthy(); - await expect(db.model.check('update', { value: 'user' })).toResolveTruthy(); - await expect(db.model.check('update', { value: 'admin' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 'user' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 'admin' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { value: 'user' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { value: 'admin' } })).toResolveTruthy(); }); it('relation noop', async () => { @@ -225,8 +291,10 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { foo: { x: 0 } })).rejects.toThrow('Providing filter for field "foo"'); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { foo: { x: 0 } } })).rejects.toThrow( + 'Providing filter for field "foo"' + ); }); it('collection predicate noop', async () => { @@ -250,8 +318,10 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { foo: [{ x: 0 }] })).rejects.toThrow('Providing filter for field "foo"'); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { foo: [{ x: 0 }] } })).rejects.toThrow( + 'Providing filter for field "foo"' + ); }); it('field complex condition', async () => { @@ -269,22 +339,22 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { x: 0 })).toResolveFalsy(); - await expect(db.model.check('read', { x: 1 })).toResolveTruthy(); - await expect(db.model.check('read', { x: 1, y: 0 })).toResolveTruthy(); - await expect(db.model.check('read', { x: 1, y: 1 })).toResolveFalsy(); - - await expect(db.model.check('create')).toResolveTruthy(); - await expect(db.model.check('create', { x: 0 })).toResolveFalsy(); // numbers are non-negative - await expect(db.model.check('create', { x: 1 })).toResolveTruthy(); - await expect(db.model.check('create', { x: 1, y: 0 })).toResolveTruthy(); - await expect(db.model.check('create', { x: 1, y: 1 })).toResolveFalsy(); - - await expect(db.model.check('update')).toResolveTruthy(); - await expect(db.model.check('update', { x: 0 })).toResolveTruthy(); - await expect(db.model.check('update', { y: 0 })).toResolveFalsy(); // numbers are non-negative - await expect(db.model.check('update', { x: 1, y: 1 })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { x: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1, y: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1, y: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { x: 0 } })).toResolveFalsy(); // numbers are non-negative + await expect(db.model.check({ operation: 'create', filter: { x: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { x: 1, y: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { x: 1, y: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { x: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', filter: { y: 0 } })).toResolveFalsy(); // numbers are non-negative + await expect(db.model.check({ operation: 'update', filter: { x: 1, y: 1 } })).toResolveFalsy(); }); it('field condition unsolvable', async () => { @@ -300,11 +370,11 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveFalsy(); - await expect(db.model.check('read', { x: 0 })).toResolveFalsy(); - await expect(db.model.check('read', { x: 1 })).toResolveFalsy(); - await expect(db.model.check('read', { x: 1, y: 2 })).toResolveFalsy(); - await expect(db.model.check('read', { x: 1, y: 1 })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read' })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1, y: 2 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', filter: { x: 1, y: 1 } })).toResolveFalsy(); }); it('simple auth condition', async () => { @@ -326,20 +396,20 @@ describe('Permission checker', () => { ` ); - await expect(enhance().model.check('read')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('read')).toResolveFalsy(); - await expect(enhance({ id: 1, level: 0 }).model.check('read')).toResolveFalsy(); - await expect(enhance({ id: 1, level: 1 }).model.check('read')).toResolveTruthy(); + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 0 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); - await expect(enhance().model.check('create')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('create')).toResolveFalsy(); - await expect(enhance({ id: 1, admin: false }).model.check('create')).toResolveFalsy(); - await expect(enhance({ id: 1, admin: true }).model.check('create')).toResolveTruthy(); + await expect(enhance().model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: false }).model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: true }).model.check({ operation: 'create' })).toResolveTruthy(); - await expect(enhance().model.check('update')).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('update')).toResolveTruthy(); - await expect(enhance({ id: 1, admin: true }).model.check('update')).toResolveFalsy(); - await expect(enhance({ id: 1, admin: false }).model.check('update')).toResolveTruthy(); + await expect(enhance().model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1, admin: true }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: false }).model.check({ operation: 'update' })).toResolveTruthy(); }); it('auth compared with relation field', async () => { @@ -363,25 +433,25 @@ describe('Permission checker', () => { { preserveTsFiles: true } ); - await expect(enhance().model.check('read')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('read')).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('read', { ownerId: 1 })).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('read', { ownerId: 2 })).toResolveFalsy(); - - await expect(enhance().model.check('create')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('create')).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('create', { ownerId: 1 })).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('create', { ownerId: 2 })).toResolveTruthy(); - - await expect(enhance().model.check('update')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('update')).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('update', { ownerId: 1 })).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('update', { ownerId: 2 })).toResolveFalsy(); - - await expect(enhance().model.check('delete')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('delete')).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('delete', { ownerId: 1 })).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('delete', { ownerId: 2 })).toResolveTruthy(); + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read', filter: { ownerId: 1 } })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read', filter: { ownerId: 2 } })).toResolveFalsy(); + + await expect(enhance().model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create', filter: { ownerId: 1 } })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create', filter: { ownerId: 2 } })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update', filter: { ownerId: 1 } })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update', filter: { ownerId: 2 } })).toResolveFalsy(); + + await expect(enhance().model.check({ operation: 'delete' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete', filter: { ownerId: 1 } })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete', filter: { ownerId: 2 } })).toResolveTruthy(); }); it('auth null check', async () => { @@ -402,16 +472,16 @@ describe('Permission checker', () => { ` ); - await expect(enhance().model.check('read')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('read')).toResolveTruthy(); + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); - await expect(enhance().model.check('create')).toResolveTruthy(); - await expect(enhance({ id: 1 }).model.check('create')).toResolveFalsy(); + await expect(enhance().model.check({ operation: 'create' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveFalsy(); - await expect(enhance().model.check('update')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('update')).toResolveFalsy(); - await expect(enhance({ id: 1, level: 0 }).model.check('update')).toResolveFalsy(); - await expect(enhance({ id: 1, level: 1 }).model.check('update')).toResolveTruthy(); + await expect(enhance().model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 0 }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); }); it('auth with relation access', async () => { @@ -437,10 +507,10 @@ describe('Permission checker', () => { ` ); - await expect(enhance().model.check('read')).toResolveFalsy(); - await expect(enhance({ id: 1 }).model.check('read')).toResolveFalsy(); - await expect(enhance({ id: 1, profile: { level: 0 } }).model.check('read')).toResolveFalsy(); - await expect(enhance({ id: 1, profile: { level: 1 } }).model.check('read')).toResolveTruthy(); + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, profile: { level: 0 } }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, profile: { level: 1 } }).model.check({ operation: 'read' })).toResolveTruthy(); }); it('nullable field', async () => { @@ -456,10 +526,10 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); - await expect(db.model.check('create')).toResolveTruthy(); - await expect(db.model.check('create', { value: 1 })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', filter: { value: 1 } })).toResolveTruthy(); }); it('compilation', async () => { @@ -482,8 +552,8 @@ describe('Permission checker', () => { const prisma = new PrismaClient(); const db = enhance(prisma); - db.model.check('read'); - db.model.check('read', { value: 1 }); + db.model.check({ operation: 'read' }); + db.model.check({ operation: 'read', filter: { value: 1 }}); `, }, ], @@ -513,22 +583,22 @@ describe('Permission checker', () => { ); const db = enhance(); - await expect(db.model.check('read', { foo: { x: 1 } })).rejects.toThrow( + await expect(db.model.check({ operation: 'read', filter: { foo: { x: 1 } } })).rejects.toThrow( `Providing filter for field "foo" is not supported. Only scalar fields are allowed.` ); - await expect(db.model.check('read', { d: new Date() })).rejects.toThrow( + await expect(db.model.check({ operation: 'read', filter: { d: new Date() } })).rejects.toThrow( `Providing filter for field "d" is not supported. Only number, string, and boolean fields are allowed.` ); - await expect(db.model.check('read', { value: null })).rejects.toThrow( + await expect(db.model.check({ operation: 'read', filter: { value: null } })).rejects.toThrow( `Using "null" as filter value is not supported yet` ); - await expect(db.model.check('read', { value: {} })).rejects.toThrow( + await expect(db.model.check({ operation: 'read', filter: { value: {} } })).rejects.toThrow( 'Invalid value type for field "value". Only number, string or boolean is allowed.' ); - await expect(db.model.check('read', { value: 'abc' })).rejects.toThrow( + await expect(db.model.check({ operation: 'read', filter: { value: 'abc' } })).rejects.toThrow( 'Invalid value type for field "value". Expected "number"' ); - await expect(db.model.check('read', { value: -1 })).rejects.toThrow( + await expect(db.model.check({ operation: 'read', filter: { value: -1 } })).rejects.toThrow( 'Invalid value for field "value". Only non-negative integers are allowed.' ); }); @@ -544,8 +614,8 @@ describe('Permission checker', () => { ` ); const db = enhance(); - await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); }); it('float value ignored', async () => { @@ -559,8 +629,24 @@ describe('Permission checker', () => { ` ); const db = enhance(); - // await expect(db.model.check('read')).toResolveTruthy(); - await expect(db.model.check('read', { value: 1 })).toResolveTruthy(); - await expect(db.model.check('read', { value: 2 })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 2 } })).toResolveTruthy(); + }); + + it('negative value ignored', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value >-1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', filter: { value: 2 } })).toResolveTruthy(); }); }); From a68e06860692af8b3bc36517abc24d67066fb1fb Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 7 May 2024 20:52:46 +0800 Subject: [PATCH 12/13] multiple fixes --- .../src/enhancements/policy/handler.ts | 21 +++++++----- .../src/enhancements/policy/policy-utils.ts | 6 ++-- packages/runtime/src/types.ts | 1 + .../enhance/checker-type-generator.ts | 4 +-- .../enhancer/policy/constraint-transformer.ts | 27 ++++++++++----- packages/server/src/api/rpc/index.ts | 1 + packages/server/tests/api/rpc.test.ts | 33 ++++++++++++++++++- 7 files changed, 70 insertions(+), 23 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index bf6df3e7b..f84af001a 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -38,6 +38,12 @@ type PostWriteCheckRecord = { type FindOperations = 'findUnique' | 'findUniqueOrThrow' | 'findFirst' | 'findFirstOrThrow' | 'findMany'; +// input arg type for `check` API +type PermissionCheckArgs = { + operation: PolicyCrudKind; + filter?: Record; +}; + /** * Prisma proxy handler for injecting access policy check. */ @@ -1446,24 +1452,21 @@ export class PolicyProxyHandler implements Pr * @param operation The CRUD operation. * @param fieldValues Extra field value filters to be combined with the policy constraints. */ - async check( - operation: PolicyCrudKind, - fieldValues?: Record - ): Promise { - return createDeferredPromise(() => this.doCheck(operation, fieldValues)); + async check(args: PermissionCheckArgs): Promise { + return createDeferredPromise(() => this.doCheck(args)); } - private async doCheck(operation: PolicyCrudKind, fieldValues?: Record) { - let constraint = this.policyUtils.getCheckerConstraint(this.model, operation); + private async doCheck(args: PermissionCheckArgs) { + let constraint = this.policyUtils.getCheckerConstraint(this.model, args.operation); if (typeof constraint === 'boolean') { return constraint; } - if (fieldValues) { + if (args.filter) { // combine runtime filters with generated constraints const extraConstraints: CheckerConstraint[] = []; - for (const [field, value] of Object.entries(fieldValues)) { + for (const [field, value] of Object.entries(args.filter)) { if (value === undefined) { continue; } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index e4207e384..2cddaae5e 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -592,10 +592,12 @@ export class PolicyUtil extends QueryUtils { // policy enhancement not enabled, return a constant true checker return { create: true, read: true, update: true, delete: true }; } else { - let result = this.options.policy.checker?.[lowerCaseFirst(model)]; + const result = this.options.policy.checker?.[lowerCaseFirst(model)]; if (!result) { // checker generation not enabled, return constant false checker - result = { create: false, read: false, update: false, delete: false }; + throw new Error( + `Generated permission checkers not found. Please make sure the "generatePermissionChecker" option is set to true in the "@core/enhancer" plugin.` + ); } return result; } diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 0e1c86e93..4c32480ba 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -22,6 +22,7 @@ export interface DbOperations { groupBy(args: unknown): Promise; count(args?: unknown): Promise; subscribe(args?: unknown): Promise; + check(args: unknown): Promise; fields: Record; } diff --git a/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts index bda5dd9da..fe6c415cc 100644 --- a/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts +++ b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts @@ -19,7 +19,7 @@ import { P, match } from 'ts-pattern'; */ export function generateCheckerType(model: Model) { return ` -type CheckerOperation = 'create' | 'read' | 'update' | 'delete'; +import type { PolicyCrudKind } from '@zenstackhq/runtime'; export interface ModelCheckers { ${getDataModels(model) @@ -31,7 +31,7 @@ export interface ModelCheckers { function generateDataModelChecker(dataModel: DataModel) { return `{ - check(op: CheckerOperation, args?: ${generateDataModelArgs(dataModel)}): Promise + check(args: { operation: PolicyCrudKind, filter?: ${generateDataModelArgs(dataModel)} }): Promise }`; } diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts index 08d72f945..a0b1c1dd2 100644 --- a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -1,4 +1,9 @@ -import { ZModelCodeGenerator, getRelationKeyPairs, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk'; +import { + getRelationKeyPairs, + isAuthInvocation, + isDataModelFieldReference, + isEnumFieldReference, +} from '@zenstackhq/sdk'; import { BinaryExpr, BooleanLiteral, @@ -13,6 +18,7 @@ import { UnaryExpr, isBinaryExpr, isDataModelField, + isEnum, isLiteralExpr, isMemberAccessExpr, isNullExpr, @@ -55,7 +61,7 @@ export class ConstraintTransformer { // transform allow rules const allowConstraints = allows.map((allow) => this.transformExpression(allow)); if (allowConstraints.length > 1) { - result = this.and(...allowConstraints); + result = this.or(...allowConstraints); } else { result = allowConstraints[0]; } @@ -63,7 +69,7 @@ export class ConstraintTransformer { // transform deny rules and compose if (denies.length > 0) { const denyConstraints = denies.map((deny) => this.transformExpression(deny)); - result = this.and(result, this.not(this.or(...denyConstraints))); + result = this.and(result, ...denyConstraints.map((c) => this.not(c))); } // DEBUG: @@ -279,6 +285,10 @@ export class ConstraintTransformer { return this.transformLiteral(expr); } + if (isEnumFieldReference(expr)) { + return this.value(`'${expr.target.$refText}'`, 'string'); + } + const fieldAccess = this.getFieldAccess(expr); if (fieldAccess) { // model field access is transformed into a named variable @@ -304,7 +314,11 @@ export class ConstraintTransformer { } private mapExpressionType(expression: Expression) { - return this.mapType(expression.$resolvedType?.decl as ExpressionType); + if (isEnum(expression.$resolvedType?.decl)) { + return 'string'; + } else { + return this.mapType(expression.$resolvedType?.decl as ExpressionType); + } } private mapType(type: ExpressionType) { @@ -355,11 +369,6 @@ export class ConstraintTransformer { return this.variable(`__var${this.varCounter++}`, type); } - private expressionVariable(expr: Expression, type: string) { - const name = new ZModelCodeGenerator().generate(expr); - return this.variable(name, type); - } - private variable(name: string, type: string) { return `{ kind: 'variable', name: '${name}', type: '${type}' }`; } diff --git a/packages/server/src/api/rpc/index.ts b/packages/server/src/api/rpc/index.ts index a7fb44d72..a8882b8be 100644 --- a/packages/server/src/api/rpc/index.ts +++ b/packages/server/src/api/rpc/index.ts @@ -81,6 +81,7 @@ class RequestHandler extends APIHandlerBase { case 'aggregate': case 'groupBy': case 'count': + case 'check': if (method !== 'GET') { return { status: 400, diff --git a/packages/server/tests/api/rpc.test.ts b/packages/server/tests/api/rpc.test.ts index 432abec2c..b24a7c108 100644 --- a/packages/server/tests/api/rpc.test.ts +++ b/packages/server/tests/api/rpc.test.ts @@ -15,7 +15,7 @@ describe('RPC API Handler Tests', () => { let zodSchemas: any; beforeAll(async () => { - const params = await loadSchema(schema, { fullZod: true }); + const params = await loadSchema(schema, { fullZod: true, generatePermissionChecker: true }); prisma = params.prisma; enhance = params.enhance; modelMeta = params.modelMeta; @@ -131,6 +131,37 @@ describe('RPC API Handler Tests', () => { expect(r.data.count).toBe(1); }); + it('check', async () => { + const handleRequest = makeHandler(); + + let r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read' }) }, + prisma: enhance(), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(true); + + r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read', filter: { published: false } }) }, + prisma: enhance(), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(false); + + r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read', filter: { authorId: '1', published: false } }) }, + prisma: enhance({ id: '1' }), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(true); + }); + it('policy violation', async () => { await prisma.user.create({ data: { From 64e0c98acc4cd018c5169f3b325a3d4b617290ef Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 7 May 2024 21:19:34 +0800 Subject: [PATCH 13/13] fix: add input check validation --- .../src/enhancements/policy/handler.ts | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index f84af001a..cc0ea4f03 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1457,6 +1457,10 @@ export class PolicyProxyHandler implements Pr } private async doCheck(args: PermissionCheckArgs) { + if (!['create', 'read', 'update', 'delete'].includes(args.operation)) { + throw prismaClientValidationError(this.prisma, this.prismaModule, `Invalid "operation" ${args.operation}`); + } + let constraint = this.policyUtils.getCheckerConstraint(this.model, args.operation); if (typeof constraint === 'boolean') { return constraint; @@ -1472,14 +1476,20 @@ export class PolicyProxyHandler implements Pr } if (value === null) { - throw new Error(`Using "null" as filter value is not supported yet`); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Using "null" as filter value is not supported yet` + ); } const fieldInfo = requireField(this.modelMeta, this.model, field); // relation and array fields are not supported if (fieldInfo.isDataModel || fieldInfo.isArray) { - throw new Error( + throw prismaClientValidationError( + this.prisma, + this.prismaModule, `Providing filter for field "${field}" is not supported. Only scalar fields are allowed.` ); } @@ -1490,7 +1500,9 @@ export class PolicyProxyHandler implements Pr .with('String', () => 'string') .with('Boolean', () => 'boolean') .otherwise(() => { - throw new Error( + throw prismaClientValidationError( + this.prisma, + this.prismaModule, `Providing filter for field "${field}" is not supported. Only number, string, and boolean fields are allowed.` ); }); @@ -1498,18 +1510,28 @@ export class PolicyProxyHandler implements Pr // check value type const valueType = typeof value; if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') { - throw new Error( + throw prismaClientValidationError( + this.prisma, + this.prismaModule, `Invalid value type for field "${field}". Only number, string or boolean is allowed.` ); } if (fieldType !== valueType) { - throw new Error(`Invalid value type for field "${field}". Expected "${fieldType}".`); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Invalid value type for field "${field}". Expected "${fieldType}".` + ); } // check number validity if (typeof value === 'number' && (!Number.isInteger(value) || value < 0)) { - throw new Error(`Invalid value for field "${field}". Only non-negative integers are allowed.`); + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Invalid value for field "${field}". Only non-negative integers are allowed.` + ); } // build a constraint