Skip to content

Commit 5fe85ff

Browse files
authored
fix: post-update rule for id field is not effective if id is updated (#1237)
1 parent e3fb73a commit 5fe85ff

File tree

7 files changed

+117
-38
lines changed

7 files changed

+117
-38
lines changed

packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ export default function createRouter<Config extends BaseConfig>(
2323
.input($Schema.PostInputSchema.aggregate)
2424
.query(({ ctx, input }) => checkRead(db(ctx).post.aggregate(input as any))),
2525

26+
createMany: procedure
27+
.input($Schema.PostInputSchema.createMany)
28+
.mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.createMany(input as any))),
29+
2630
create: procedure
2731
.input($Schema.PostInputSchema.create)
2832
.mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.create(input as any))),
@@ -88,6 +92,29 @@ export interface ClientType<AppRouter extends AnyRouter, Context = AppRouter['_d
8892
opts?: UseTRPCInfiniteQueryOptions<string, T, Prisma.GetPostAggregateType<T>, Error>,
8993
) => UseTRPCInfiniteQueryResult<Prisma.GetPostAggregateType<T>, TRPCClientErrorLike<AppRouter>>;
9094
};
95+
createMany: {
96+
useMutation: <T extends Prisma.PostCreateManyArgs>(
97+
opts?: UseTRPCMutationOptions<
98+
Prisma.PostCreateManyArgs,
99+
TRPCClientErrorLike<AppRouter>,
100+
Prisma.BatchPayload,
101+
Context
102+
>,
103+
) => Omit<
104+
UseTRPCMutationResult<
105+
Prisma.BatchPayload,
106+
TRPCClientErrorLike<AppRouter>,
107+
Prisma.SelectSubset<T, Prisma.PostCreateManyArgs>,
108+
Context
109+
>,
110+
'mutateAsync'
111+
> & {
112+
mutateAsync: <T extends Prisma.PostCreateManyArgs>(
113+
variables: T,
114+
opts?: UseTRPCMutationOptions<T, TRPCClientErrorLike<AppRouter>, Prisma.BatchPayload, Context>,
115+
) => Promise<Prisma.BatchPayload>;
116+
};
117+
};
91118
create: {
92119
useMutation: <T extends Prisma.PostCreateArgs>(
93120
opts?: UseTRPCMutationOptions<

packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ export default function createRouter<Config extends BaseConfig>(
2323
.input($Schema.UserInputSchema.aggregate)
2424
.query(({ ctx, input }) => checkRead(db(ctx).user.aggregate(input as any))),
2525

26+
createMany: procedure
27+
.input($Schema.UserInputSchema.createMany)
28+
.mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.createMany(input as any))),
29+
2630
create: procedure
2731
.input($Schema.UserInputSchema.create)
2832
.mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.create(input as any))),
@@ -88,6 +92,29 @@ export interface ClientType<AppRouter extends AnyRouter, Context = AppRouter['_d
8892
opts?: UseTRPCInfiniteQueryOptions<string, T, Prisma.GetUserAggregateType<T>, Error>,
8993
) => UseTRPCInfiniteQueryResult<Prisma.GetUserAggregateType<T>, TRPCClientErrorLike<AppRouter>>;
9094
};
95+
createMany: {
96+
useMutation: <T extends Prisma.UserCreateManyArgs>(
97+
opts?: UseTRPCMutationOptions<
98+
Prisma.UserCreateManyArgs,
99+
TRPCClientErrorLike<AppRouter>,
100+
Prisma.BatchPayload,
101+
Context
102+
>,
103+
) => Omit<
104+
UseTRPCMutationResult<
105+
Prisma.BatchPayload,
106+
TRPCClientErrorLike<AppRouter>,
107+
Prisma.SelectSubset<T, Prisma.UserCreateManyArgs>,
108+
Context
109+
>,
110+
'mutateAsync'
111+
> & {
112+
mutateAsync: <T extends Prisma.UserCreateManyArgs>(
113+
variables: T,
114+
opts?: UseTRPCMutationOptions<T, TRPCClientErrorLike<AppRouter>, Prisma.BatchPayload, Context>,
115+
) => Promise<Prisma.BatchPayload>;
116+
};
117+
};
91118
create: {
92119
useMutation: <T extends Prisma.UserCreateArgs>(
93120
opts?: UseTRPCMutationOptions<

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -690,16 +690,25 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
690690
const postWriteChecks: PostWriteCheckRecord[] = [];
691691

692692
// registers a post-update check task
693-
const _registerPostUpdateCheck = async (model: string, uniqueFilter: any) => {
693+
const _registerPostUpdateCheck = async (
694+
model: string,
695+
preUpdateLookupFilter: any,
696+
postUpdateLookupFilter: any
697+
) => {
694698
// both "post-update" rules and Zod schemas require a post-update check
695699
if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) {
696700
// select pre-update field values
697701
let preValue: any;
698702
const preValueSelect = this.utils.getPreValueSelect(model);
699703
if (preValueSelect && Object.keys(preValueSelect).length > 0) {
700-
preValue = await db[model].findFirst({ where: uniqueFilter, select: preValueSelect });
704+
preValue = await db[model].findFirst({ where: preUpdateLookupFilter, select: preValueSelect });
701705
}
702-
postWriteChecks.push({ model, operation: 'postUpdate', uniqueFilter, preValue });
706+
postWriteChecks.push({
707+
model,
708+
operation: 'postUpdate',
709+
uniqueFilter: postUpdateLookupFilter,
710+
preValue,
711+
});
703712
}
704713
};
705714

@@ -826,7 +835,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
826835
await this.utils.checkPolicyForUnique(model, args, 'update', db, checkArgs);
827836

828837
// register post-update check
829-
await _registerPostUpdateCheck(model, args);
838+
await _registerPostUpdateCheck(model, args, args);
830839
}
831840
}
832841
};
@@ -873,20 +882,20 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
873882
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);
874883

875884
// handles the case where id fields are updated
876-
const ids = this.utils.clone(existing);
885+
const postUpdateIds = this.utils.clone(existing);
877886
for (const key of Object.keys(existing)) {
878887
const updateValue = (args as any).data ? (args as any).data[key] : (args as any)[key];
879888
if (
880889
typeof updateValue === 'string' ||
881890
typeof updateValue === 'number' ||
882891
typeof updateValue === 'bigint'
883892
) {
884-
ids[key] = updateValue;
893+
postUpdateIds[key] = updateValue;
885894
}
886895
}
887896

888897
// register post-update check
889-
await _registerPostUpdateCheck(model, ids);
898+
await _registerPostUpdateCheck(model, existing, postUpdateIds);
890899
}
891900
},
892901

@@ -978,7 +987,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
978987
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);
979988

980989
// register post-update check
981-
await _registerPostUpdateCheck(model, uniqueFilter);
990+
await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter);
982991

983992
// convert upsert to update
984993
const convertedUpdate = {

packages/schema/src/plugins/access-policy/expression-writer.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ export class ExpressionWriter {
7070
this.plainExprBuilder = new TypeScriptExpressionTransformer({
7171
context: ExpressionContext.AccessPolicy,
7272
isPostGuard: this.isPostGuard,
73+
// in post-guard context, `this` references pre-update value
74+
thisExprContext: this.isPostGuard ? 'context.preValue' : undefined,
7375
});
7476
}
7577

packages/schema/src/plugins/access-policy/policy-guard-generator.ts

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import {
66
Enum,
77
Expression,
88
Model,
9-
isBinaryExpr,
109
isDataModel,
1110
isDataModelField,
1211
isEnum,
@@ -15,7 +14,6 @@ import {
1514
isMemberAccessExpr,
1615
isReferenceExpr,
1716
isThisExpr,
18-
isUnaryExpr,
1917
} from '@zenstackhq/language/ast';
2018
import {
2119
FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX,
@@ -281,30 +279,6 @@ export default class PolicyGenerator {
281279
}
282280
}
283281

284-
private visitPolicyExpression(expr: Expression, postUpdate: boolean): Expression | undefined {
285-
if (isBinaryExpr(expr) && (expr.operator === '&&' || expr.operator === '||')) {
286-
const left = this.visitPolicyExpression(expr.left, postUpdate);
287-
const right = this.visitPolicyExpression(expr.right, postUpdate);
288-
if (!left) return right;
289-
if (!right) return left;
290-
return { ...expr, left, right };
291-
}
292-
293-
if (isUnaryExpr(expr) && expr.operator === '!') {
294-
const operand = this.visitPolicyExpression(expr.operand, postUpdate);
295-
if (!operand) return undefined;
296-
return { ...expr, operand };
297-
}
298-
299-
if (postUpdate && !this.hasFutureReference(expr)) {
300-
return undefined;
301-
} else if (!postUpdate && this.hasFutureReference(expr)) {
302-
return undefined;
303-
}
304-
305-
return expr;
306-
}
307-
308282
private hasFutureReference(expr: Expression) {
309283
for (const node of streamAst(expr)) {
310284
if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) {
@@ -599,13 +573,19 @@ export default class PolicyGenerator {
599573
// visit a reference or member access expression to build a
600574
// selection path
601575
const visit = (node: Expression): string[] | undefined => {
576+
if (isThisExpr(node)) {
577+
return [];
578+
}
579+
602580
if (isReferenceExpr(node)) {
603581
const target = resolved(node.target);
604582
if (isDataModelField(target)) {
605583
// a field selection, it's a terminal
606584
return [target.name];
607585
}
608-
} else if (isMemberAccessExpr(node)) {
586+
}
587+
588+
if (isMemberAccessExpr(node)) {
609589
if (forAuthContext && isAuthInvocation(node.operand)) {
610590
return [node.member.$refText];
611591
}
@@ -621,6 +601,7 @@ export default class PolicyGenerator {
621601
return [...inner, node.member.$refText];
622602
}
623603
}
604+
624605
return undefined;
625606
};
626607

packages/schema/src/utils/typescript-expression-transformer.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ export class TypeScriptExpressionTransformer {
112112
throw new TypeScriptExpressionTransformerError(`Unresolved MemberAccessExpr`);
113113
}
114114

115-
if (isThisExpr(expr.operand)) {
116-
return expr.member.ref.name;
117-
} else if (isFutureExpr(expr.operand)) {
115+
if (isFutureExpr(expr.operand)) {
118116
if (this.options?.isPostGuard !== true) {
119117
throw new TypeScriptExpressionTransformerError(`future() is only supported in postUpdate rules`);
120118
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('issue 1235', () => {
4+
it('regression1', async () => {
5+
const { enhance } = await loadSchema(
6+
`
7+
model Post {
8+
id Int @id @default(autoincrement())
9+
@@deny("update", future().id != id)
10+
@@allow("all", true)
11+
}
12+
`
13+
);
14+
15+
const db = enhance();
16+
const post = await db.post.create({ data: {} });
17+
await expect(db.post.update({ data: { id: post.id + 1 }, where: { id: post.id } })).toBeRejectedByPolicy();
18+
});
19+
20+
it('regression2', async () => {
21+
const { enhance } = await loadSchema(
22+
`
23+
model Post {
24+
id Int @id @default(autoincrement())
25+
@@deny("update", future().id != this.id)
26+
@@allow("all", true)
27+
}
28+
`
29+
);
30+
31+
const db = enhance();
32+
const post = await db.post.create({ data: {} });
33+
await expect(db.post.update({ data: { id: post.id + 1 }, where: { id: post.id } })).toBeRejectedByPolicy();
34+
});
35+
});

0 commit comments

Comments
 (0)