diff --git a/packages/runtime/src/enhancements/node/policy/handler.ts b/packages/runtime/src/enhancements/node/policy/handler.ts index 673665dd5..5c5fdd4ca 100644 --- a/packages/runtime/src/enhancements/node/policy/handler.ts +++ b/packages/runtime/src/enhancements/node/policy/handler.ts @@ -284,9 +284,16 @@ export class PolicyProxyHandler implements Pr if (context.field?.backLink) { const backLinkField = resolveField(this.modelMeta, model, context.field.backLink); if (backLinkField?.isRelationOwner) { - // the target side of relation owns the relation, - // check if it's updatable - await this.policyUtils.checkPolicyForUnique(model, args.where, 'update', db, args); + // "connect" is actually "update" to foreign keys, so we need to map the "connect" payload + // to "update" payload by translating pk to fks, and use that to check update policies + const fieldsToUpdate = Object.values(backLinkField.foreignKeyMapping ?? {}); + await this.policyUtils.checkPolicyForUnique( + model, + args.where, + 'update', + db, + fieldsToUpdate + ); } } @@ -319,9 +326,12 @@ export class PolicyProxyHandler implements Pr // check existence await this.policyUtils.checkExistence(db, model, args, true); - // the target side of relation owns the relation, - // check if it's updatable - await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, args); + // the target side of relation owns the relation, check if it's updatable + + // "connect" is actually "update" to foreign keys, so we need to map the "connect" payload + // to "update" payload by translating pk to fks, and use that to check update policies + const fieldsToUpdate = Object.values(backLinkField.foreignKeyMapping ?? {}); + await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, fieldsToUpdate); } } }, @@ -909,21 +919,11 @@ export class PolicyProxyHandler implements Pr } // update happens on the related model, require updatable, - // translate args to foreign keys so field-level policies can be checked - const checkArgs: any = {}; - if (args && typeof args === 'object' && backLinkField.foreignKeyMapping) { - for (const key of Object.keys(args)) { - const fk = backLinkField.foreignKeyMapping[key]; - if (fk) { - checkArgs[fk] = args[key]; - } - } - } - // `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist if (uniqueFilter) { - // check for update - await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, checkArgs); + // check for update, "connect" and "disconnect" are actually "update" to foreign keys + const fieldsToUpdate = Object.values(backLinkField.foreignKeyMapping ?? {}); + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, fieldsToUpdate); // register post-update check await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter); @@ -971,12 +971,18 @@ export class PolicyProxyHandler implements Pr this.policyUtils.tryReject(db, this.model, 'update'); // check pre-update guard - await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); + await this.policyUtils.checkPolicyForUnique( + model, + uniqueFilter, + 'update', + db, + this.queryUtils.getFieldsWithDefinedValues(updatePayload) + ); // handle the case where id fields are updated const _args: any = args; - const updatePayload = _args.data && typeof _args.data === 'object' ? _args.data : _args; - const postUpdateIds = this.calculatePostUpdateIds(model, existing, updatePayload); + const checkPayload = _args.data && typeof _args.data === 'object' ? _args.data : _args; + const postUpdateIds = this.calculatePostUpdateIds(model, existing, checkPayload); // register post-update check await _registerPostUpdateCheck(model, existing, postUpdateIds); @@ -1068,7 +1074,13 @@ export class PolicyProxyHandler implements Pr // update case // check pre-update guard - await this.policyUtils.checkPolicyForUnique(model, existing, 'update', db, args); + await this.policyUtils.checkPolicyForUnique( + model, + existing, + 'update', + db, + this.queryUtils.getFieldsWithDefinedValues(args.update) + ); // handle the case where id fields are updated const postUpdateIds = this.calculatePostUpdateIds(model, existing, args.update); @@ -1156,7 +1168,7 @@ export class PolicyProxyHandler implements Pr await this.policyUtils.checkExistence(db, model, uniqueFilter, true); // check delete guard - await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, args); + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, []); }, deleteMany: async (model, args, context) => { @@ -1526,7 +1538,7 @@ export class PolicyProxyHandler implements Pr await this.policyUtils.checkExistence(tx, this.model, args.where, true); // inject delete guard - await this.policyUtils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args); + await this.policyUtils.checkPolicyForUnique(this.model, args.where, 'delete', tx, []); // proceed with the deletion if (this.shouldLogQuery) { @@ -1773,7 +1785,7 @@ export class PolicyProxyHandler implements Pr private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: CrudContract) { await Promise.all( postWriteChecks.map(async ({ model, operation, uniqueFilter, preValue }) => - this.policyUtils.checkPolicyForUnique(model, uniqueFilter, operation, db, undefined, preValue) + this.policyUtils.checkPolicyForUnique(model, uniqueFilter, operation, db, [], preValue) ) ); } diff --git a/packages/runtime/src/enhancements/node/policy/policy-utils.ts b/packages/runtime/src/enhancements/node/policy/policy-utils.ts index 94c1f7f20..ef4285d78 100644 --- a/packages/runtime/src/enhancements/node/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/node/policy/policy-utils.ts @@ -451,7 +451,11 @@ export class PolicyUtil extends QueryUtils { if (operation === 'update' && args) { // merge field-level policy guards - const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); + const fieldUpdateGuard = this.getFieldUpdateGuards( + db, + model, + this.getFieldsWithDefinedValues(args.data ?? args) + ); if (fieldUpdateGuard.rejectedByField) { // rejected args.where = this.makeFalse(); @@ -834,7 +838,7 @@ export class PolicyUtil extends QueryUtils { uniqueFilter: any, operation: PolicyOperationKind, db: CrudContract, - args: any, + fieldsToUpdate: string[], preValue?: any ) { let guard = this.getAuthGuard(db, model, operation, preValue); @@ -849,9 +853,9 @@ export class PolicyUtil extends QueryUtils { let entityChecker: EntityChecker | undefined; - if (operation === 'update' && args) { + if (operation === 'update' && fieldsToUpdate.length > 0) { // merge field-level policy guards - const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); + const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, fieldsToUpdate); if (fieldUpdateGuard.rejectedByField) { // rejected throw this.deniedByPolicy( @@ -989,16 +993,12 @@ export class PolicyUtil extends QueryUtils { return this.and(...allFieldGuards); } - private getFieldUpdateGuards(db: CrudContract, model: string, args: any) { + private getFieldUpdateGuards(db: CrudContract, model: string, fieldsToUpdate: string[]) { const allFieldGuards = []; const allOverrideFieldGuards = []; let entityChecker: EntityChecker | undefined; - for (const [field, value] of Object.entries(args.data ?? args)) { - if (typeof value === 'undefined') { - continue; - } - + for (const field of fieldsToUpdate) { const fieldInfo = resolveField(this.modelMeta, model, field); if (fieldInfo?.isDataModel) { diff --git a/packages/runtime/src/enhancements/node/query-utils.ts b/packages/runtime/src/enhancements/node/query-utils.ts index c09fe1f95..75e729b0f 100644 --- a/packages/runtime/src/enhancements/node/query-utils.ts +++ b/packages/runtime/src/enhancements/node/query-utils.ts @@ -253,4 +253,16 @@ export class QueryUtils { return undefined; } + + /** + * Gets fields of object with defined values. + */ + getFieldsWithDefinedValues(data: object) { + if (!data) { + return []; + } + return Object.entries(data) + .filter(([, v]) => v !== undefined) + .map(([k]) => k); + } } diff --git a/tests/regression/tests/issue-2007.test.ts b/tests/regression/tests/issue-2007.test.ts new file mode 100644 index 000000000..4a4b9cbe6 --- /dev/null +++ b/tests/regression/tests/issue-2007.test.ts @@ -0,0 +1,93 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 2007', () => { + it('regression1', async () => { + const { enhance } = await loadSchema( + ` + model Page { + id String @id @default(cuid()) + title String + + images Image[] + + @@allow('all', true) + } + + model Image { + id String @id @default(cuid()) @deny('update', true) + url String + pageId String? + page Page? @relation(fields: [pageId], references: [id]) + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + const image = await db.image.create({ + data: { + url: 'https://example.com/image.png', + }, + }); + + await expect( + db.image.update({ + where: { id: image.id }, + data: { + page: { + create: { + title: 'Page 1', + }, + }, + }, + }) + ).toResolveTruthy(); + }); + + it('regression2', async () => { + const { enhance } = await loadSchema( + ` + model Page { + id String @id @default(cuid()) + title String + + images Image[] + + @@allow('all', true) + } + + model Image { + id String @id @default(cuid()) + url String + pageId String? @deny('update', true) + page Page? @relation(fields: [pageId], references: [id]) + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + const image = await db.image.create({ + data: { + url: 'https://example.com/image.png', + }, + }); + + await expect( + db.image.update({ + where: { id: image.id }, + data: { + page: { + create: { + title: 'Page 1', + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + }); +});