Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { PrismaClientValidationError } from '@prisma/client/runtime';
import { CrudFailureReason } from '@zenstackhq/sdk';
import { AuthUser, DbClientContract, PolicyOperationKind } from '../../types';
import { BatchResult, PrismaProxyHandler } from '../proxy';
import { ModelMeta, PolicyDef } from '../types';
import { prismaClientValidationError } from '../utils';
import { Logger } from './logger';
import { PolicyUtil } from './policy-utils';

Expand Down Expand Up @@ -32,10 +32,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async findUnique(args: any) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}
if (!args.where) {
throw new PrismaClientValidationError('where field is required in query argument');
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
}

const entities = await this.utils.readWithCheck(this.model, args);
Expand Down Expand Up @@ -69,10 +69,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async create(args: any) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}
if (!args.data) {
throw new PrismaClientValidationError('data field is required in query argument');
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
}

await this.tryReject('create');
Expand All @@ -96,10 +96,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async createMany(args: any, skipDuplicates?: boolean) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}
if (!args.data) {
throw new PrismaClientValidationError('data field is required and must be an array');
throw prismaClientValidationError(this.prisma, 'data field is required and must be an array');
}

await this.tryReject('create');
Expand All @@ -117,13 +117,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async update(args: any) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}
if (!args.where) {
throw new PrismaClientValidationError('where field is required in query argument');
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
}
if (!args.data) {
throw new PrismaClientValidationError('data field is required in query argument');
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
}

await this.tryReject('update');
Expand All @@ -146,10 +146,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async updateMany(args: any) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}
if (!args.data) {
throw new PrismaClientValidationError('data field is required in query argument');
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
}

await this.tryReject('update');
Expand All @@ -167,16 +167,16 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async upsert(args: any) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}
if (!args.where) {
throw new PrismaClientValidationError('where field is required in query argument');
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
}
if (!args.create) {
throw new PrismaClientValidationError('create field is required in query argument');
throw prismaClientValidationError(this.prisma, 'create field is required in query argument');
}
if (!args.update) {
throw new PrismaClientValidationError('update field is required in query argument');
throw prismaClientValidationError(this.prisma, 'update field is required in query argument');
}

const origArgs = args;
Expand All @@ -201,10 +201,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async delete(args: any) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}
if (!args.where) {
throw new PrismaClientValidationError('where field is required in query argument');
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
}

await this.tryReject('delete');
Expand Down Expand Up @@ -250,7 +250,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async aggregate(args: any) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}

await this.tryReject('read');
Expand All @@ -262,7 +262,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async groupBy(args: any) {
if (!args) {
throw new PrismaClientValidationError('query argument is required');
throw prismaClientValidationError(this.prisma, 'query argument is required');
}

await this.tryReject('read');
Expand Down
12 changes: 6 additions & 6 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { createId } from '@paralleldrive/cuid2';
import { PrismaClientKnownRequestError, PrismaClientUnknownRequestError } from '@prisma/client/runtime';
import { AUXILIARY_FIELDS, CrudFailureReason, GUARD_FIELD_NAME, TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk';
import { lowerCaseFirst } from 'lower-case-first';
import deepcopy from 'deepcopy';
import { lowerCaseFirst } from 'lower-case-first';
import pluralize from 'pluralize';
import { fromZodError } from 'zod-validation-error';
import {
Expand All @@ -19,7 +18,7 @@ import { getVersion } from '../../version';
import { resolveField } from '../model-meta';
import { NestedWriteVisitor, VisitorContext } from '../nested-write-vistor';
import { ModelMeta, PolicyDef, PolicyFunc } from '../types';
import { enumerate, getModelFields } from '../utils';
import { enumerate, getModelFields, prismaClientKnownRequestError, prismaClientUnknownRequestError } from '../utils';
import { Logger } from './logger';

/**
Expand Down Expand Up @@ -707,21 +706,22 @@ export class PolicyUtil {
}

deniedByPolicy(model: string, operation: PolicyOperationKind, extra?: string, reason?: CrudFailureReason) {
return new PrismaClientKnownRequestError(
return prismaClientKnownRequestError(
this.db,
`denied by policy: ${model} entities failed '${operation}' check${extra ? ', ' + extra : ''}`,
{ clientVersion: getVersion(), code: 'P2004', meta: { reason } }
);
}

notFound(model: string) {
return new PrismaClientKnownRequestError(`entity not found for model ${model}`, {
return prismaClientKnownRequestError(this.db, `entity not found for model ${model}`, {
clientVersion: getVersion(),
code: 'P2025',
});
}

unknownError(message: string) {
return new PrismaClientUnknownRequestError(message, {
return prismaClientUnknownRequestError(this.db, message, {
clientVersion: getVersion(),
});
}
Expand Down
47 changes: 47 additions & 0 deletions packages/runtime/src/enhancements/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
/* eslint-disable @typescript-eslint/no-var-requires */

import { AUXILIARY_FIELDS } from '@zenstackhq/sdk';
import path from 'path';
import * as util from 'util';
import { DbClientContract } from '../types';

/**
* Wraps a value into array if it's not already one
Expand Down Expand Up @@ -34,3 +38,46 @@ export function enumerate<T>(x: Enumerable<T>) {
export function formatObject(value: unknown) {
return util.formatWithOptions({ depth: 10 }, value);
}

let _PrismaClientValidationError: new (...args: unknown[]) => Error;
let _PrismaClientKnownRequestError: new (...args: unknown[]) => Error;
let _PrismaClientUnknownRequestError: new (...args: unknown[]) => Error;

/* eslint-disable @typescript-eslint/no-explicit-any */
function loadPrismaModule(prisma: any) {
// https://github.com/prisma/prisma/discussions/17832
if (prisma._engineConfig?.datamodelPath) {
const loadPath = path.dirname(prisma._engineConfig.datamodelPath);
try {
return require(loadPath).Prisma;
} catch {
return require('@prisma/client/runtime');
}
} else {
return require('@prisma/client/runtime');
}
}

export function prismaClientValidationError(prisma: DbClientContract, ...args: unknown[]) {
if (!_PrismaClientValidationError) {
const _prisma = loadPrismaModule(prisma);
_PrismaClientValidationError = _prisma.PrismaClientValidationError;
}
throw new _PrismaClientValidationError(...args);
}

export function prismaClientKnownRequestError(prisma: DbClientContract, ...args: unknown[]) {
if (!_PrismaClientKnownRequestError) {
const _prisma = loadPrismaModule(prisma);
_PrismaClientKnownRequestError = _prisma.PrismaClientKnownRequestError;
}
return new _PrismaClientKnownRequestError(...args);
}

export function prismaClientUnknownRequestError(prisma: DbClientContract, ...args: unknown[]) {
if (!_PrismaClientUnknownRequestError) {
const _prisma = loadPrismaModule(prisma);
_PrismaClientUnknownRequestError = _prisma.PrismaClientUnknownRequestError;
}
throw new _PrismaClientUnknownRequestError(...args);
}
12 changes: 12 additions & 0 deletions tests/integration/test-run/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions tests/integration/utils/jest-ext.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { PrismaClientKnownRequestError } from '@prisma/client/runtime';
import { format } from 'util';
import { isPrismaClientKnownRequestError } from '@zenstackhq/runtime/error';

export const toBeRejectedByPolicy = async function (received: Promise<unknown>, expectedMessages?: string[]) {
if (!(received instanceof Promise)) {
Expand Down Expand Up @@ -133,7 +133,13 @@ export const toResolveNull = async function (received: Promise<unknown>) {
};

function expectPrismaCode(err: any, code: string) {
const errCode = (err as PrismaClientKnownRequestError).code;
if (!isPrismaClientKnownRequestError(err)) {
return {
message: () => `expected PrismaClientKnownRequestError', got ${err}`,
pass: false,
};
}
const errCode = err.code;
if (errCode !== code) {
return {
message: () => `expected PrismaClientKnownRequestError.code 'P2004', got ${errCode ?? err}`,
Expand Down