Skip to content

Commit 68a0eb3

Browse files
authored
fix(runtime): intercepts $extends to reproxy its result to make sure enhancements persist (#1847)
1 parent f377441 commit 68a0eb3

File tree

4 files changed

+224
-38
lines changed

4 files changed

+224
-38
lines changed

packages/runtime/src/enhancements/node/proxy.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,25 @@ export function makeProxy<T extends PrismaProxyHandler>(
254254
}
255255
}
256256

257+
if (prop === '$extends') {
258+
// Prisma's `$extends` API returns a new client instance, we need to recreate
259+
// a proxy around it
260+
const $extends = Reflect.get(target, prop, receiver);
261+
if ($extends && typeof $extends === 'function') {
262+
return (...args: any[]) => {
263+
const result = $extends.bind(target)(...args);
264+
if (!result[PRISMA_PROXY_ENHANCER]) {
265+
return makeProxy(result, modelMeta, makeHandler, name + '$ext', errorTransformer);
266+
} else {
267+
// avoid double wrapping
268+
return result;
269+
}
270+
};
271+
} else {
272+
return $extends;
273+
}
274+
}
275+
257276
if (typeof prop !== 'string' || prop.startsWith('$') || !models.includes(prop.toLowerCase())) {
258277
// skip non-model fields
259278
return Reflect.get(target, prop, receiver);

tests/integration/tests/enhancements/with-policy/client-extensions.test.ts

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,9 @@ describe('With Policy: client extensions', () => {
4444
});
4545
});
4646

47-
const xprisma = prisma.$extends(ext);
48-
const db = enhanceRaw(xprisma);
49-
await expect(db.model.getAll()).resolves.toHaveLength(2);
50-
51-
// FIXME: extending an enhanced client doesn't work for this case
52-
// const db1 = enhance(prisma).$extends(ext);
53-
// await expect(db1.model.getAll()).resolves.toHaveLength(2);
47+
await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3);
48+
await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2);
49+
await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2);
5450
});
5551

5652
it('one model new method', async () => {
@@ -84,9 +80,9 @@ describe('With Policy: client extensions', () => {
8480
});
8581
});
8682

87-
const xprisma = prisma.$extends(ext);
88-
const db = enhanceRaw(xprisma);
89-
await expect(db.model.getAll()).resolves.toHaveLength(2);
83+
await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3);
84+
await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2);
85+
await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2);
9086
});
9187

9288
it('add client method', async () => {
@@ -115,8 +111,11 @@ describe('With Policy: client extensions', () => {
115111
});
116112
});
117113

118-
const xprisma = prisma.$extends(ext);
119-
xprisma.$log('abc');
114+
enhanceRaw(prisma).$extends(ext).$log('abc');
115+
expect(logged).toBeTruthy();
116+
117+
logged = false;
118+
enhanceRaw(prisma.$extends(ext)).$log('abc');
120119
expect(logged).toBeTruthy();
121120
});
122121

@@ -143,7 +142,6 @@ describe('With Policy: client extensions', () => {
143142
query: {
144143
model: {
145144
async findMany({ args, query }: any) {
146-
// take incoming `where` and set `age`
147145
args.where = { ...args.where, y: { lt: 300 } };
148146
return query(args);
149147
},
@@ -152,9 +150,8 @@ describe('With Policy: client extensions', () => {
152150
});
153151
});
154152

155-
const xprisma = prisma.$extends(ext);
156-
const db = enhanceRaw(xprisma);
157-
await expect(db.model.findMany()).resolves.toHaveLength(1);
153+
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
154+
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
158155
});
159156

160157
it('query override all models', async () => {
@@ -180,7 +177,6 @@ describe('With Policy: client extensions', () => {
180177
query: {
181178
$allModels: {
182179
async findMany({ args, query }: any) {
183-
// take incoming `where` and set `age`
184180
args.where = { ...args.where, y: { lt: 300 } };
185181
console.log('findMany args:', args);
186182
return query(args);
@@ -190,9 +186,8 @@ describe('With Policy: client extensions', () => {
190186
});
191187
});
192188

193-
const xprisma = prisma.$extends(ext);
194-
const db = enhanceRaw(xprisma);
195-
await expect(db.model.findMany()).resolves.toHaveLength(1);
189+
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
190+
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
196191
});
197192

198193
it('query override all operations', async () => {
@@ -218,7 +213,6 @@ describe('With Policy: client extensions', () => {
218213
query: {
219214
model: {
220215
async $allOperations({ operation, args, query }: any) {
221-
// take incoming `where` and set `age`
222216
args.where = { ...args.where, y: { lt: 300 } };
223217
console.log(`${operation} args:`, args);
224218
return query(args);
@@ -228,9 +222,8 @@ describe('With Policy: client extensions', () => {
228222
});
229223
});
230224

231-
const xprisma = prisma.$extends(ext);
232-
const db = enhanceRaw(xprisma);
233-
await expect(db.model.findMany()).resolves.toHaveLength(1);
225+
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
226+
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
234227
});
235228

236229
it('query override everything', async () => {
@@ -255,7 +248,6 @@ describe('With Policy: client extensions', () => {
255248
name: 'prisma-extension-queryOverride',
256249
query: {
257250
async $allOperations({ operation, args, query }: any) {
258-
// take incoming `where` and set `age`
259251
args.where = { ...args.where, y: { lt: 300 } };
260252
console.log(`${operation} args:`, args);
261253
return query(args);
@@ -264,9 +256,8 @@ describe('With Policy: client extensions', () => {
264256
});
265257
});
266258

267-
const xprisma = prisma.$extends(ext);
268-
const db = enhanceRaw(xprisma);
269-
await expect(db.model.findMany()).resolves.toHaveLength(1);
259+
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
260+
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
270261
});
271262

272263
it('result mutation', async () => {
@@ -301,11 +292,9 @@ describe('With Policy: client extensions', () => {
301292
});
302293
});
303294

304-
const xprisma = prisma.$extends(ext);
305-
const db = enhanceRaw(xprisma);
306-
const r = await db.model.findMany();
307-
expect(r).toHaveLength(1);
308-
expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ value: 2 })]));
295+
const expected = [expect.objectContaining({ value: 2 })];
296+
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected);
297+
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected);
309298
});
310299

311300
it('result custom fields', async () => {
@@ -339,10 +328,8 @@ describe('With Policy: client extensions', () => {
339328
});
340329
});
341330

342-
const xprisma = prisma.$extends(ext);
343-
const db = enhanceRaw(xprisma);
344-
const r = await db.model.findMany();
345-
expect(r).toHaveLength(1);
346-
expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ doubleValue: 2 })]));
331+
const expected = [expect.objectContaining({ doubleValue: 2 })];
332+
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected);
333+
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected);
347334
});
348335
});
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('issue 1859', () => {
4+
it('extend enhanced client', async () => {
5+
const { enhance, prisma } = await loadSchema(
6+
`
7+
model Post {
8+
id Int @id
9+
title String
10+
published Boolean
11+
12+
@@allow('create', true)
13+
@@allow('read', published)
14+
}
15+
`
16+
);
17+
18+
await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
19+
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
20+
21+
const db = enhance();
22+
await expect(db.post.findMany()).resolves.toHaveLength(1);
23+
24+
const extended = db.$extends({
25+
model: {
26+
post: {
27+
findManyListView: async (args: any) => {
28+
return { view: true, data: await db.post.findMany(args) };
29+
},
30+
},
31+
},
32+
});
33+
34+
await expect(extended.post.findManyListView()).resolves.toMatchObject({
35+
view: true,
36+
data: [{ id: 1, title: 'post1', published: true }],
37+
});
38+
await expect(extended.post.findMany()).resolves.toHaveLength(1);
39+
});
40+
41+
it('enhance extended client', async () => {
42+
const { enhanceRaw, prisma, prismaModule } = await loadSchema(
43+
`
44+
model Post {
45+
id Int @id
46+
title String
47+
published Boolean
48+
49+
@@allow('create', true)
50+
@@allow('read', published)
51+
}
52+
`
53+
);
54+
55+
await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
56+
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
57+
58+
const ext = prismaModule.defineExtension((_prisma: any) => {
59+
return _prisma.$extends({
60+
model: {
61+
post: {
62+
findManyListView: async (args: any) => {
63+
return { view: true, data: await prisma.post.findMany(args) };
64+
},
65+
},
66+
},
67+
});
68+
});
69+
70+
await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2);
71+
await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({
72+
view: true,
73+
data: [
74+
{ id: 1, title: 'post1', published: true },
75+
{ id: 2, title: 'post2', published: false },
76+
],
77+
});
78+
79+
const enhanced = enhanceRaw(prisma.$extends(ext));
80+
await expect(enhanced.post.findMany()).resolves.toHaveLength(1);
81+
// findManyListView internally uses the un-enhanced client
82+
await expect(enhanced.post.findManyListView()).resolves.toMatchObject({
83+
view: true,
84+
data: [
85+
{ id: 1, title: 'post1', published: true },
86+
{ id: 2, title: 'post2', published: false },
87+
],
88+
});
89+
});
90+
});
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('issue prisma extension', () => {
4+
it('extend enhanced client', async () => {
5+
const { enhance, prisma } = await loadSchema(
6+
`
7+
model Post {
8+
id Int @id
9+
title String
10+
published Boolean
11+
12+
@@allow('create', true)
13+
@@allow('read', published)
14+
}
15+
`
16+
);
17+
18+
await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
19+
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
20+
21+
const db = enhance();
22+
await expect(db.post.findMany()).resolves.toHaveLength(1);
23+
24+
const extended = db.$extends({
25+
model: {
26+
post: {
27+
findManyListView: async (args: any) => {
28+
return { view: true, data: await db.post.findMany(args) };
29+
},
30+
},
31+
},
32+
});
33+
34+
await expect(extended.post.findManyListView()).resolves.toMatchObject({
35+
view: true,
36+
data: [{ id: 1, title: 'post1', published: true }],
37+
});
38+
await expect(extended.post.findMany()).resolves.toHaveLength(1);
39+
});
40+
41+
it('enhance extended client', async () => {
42+
const { enhanceRaw, prisma, prismaModule } = await loadSchema(
43+
`
44+
model Post {
45+
id Int @id
46+
title String
47+
published Boolean
48+
49+
@@allow('create', true)
50+
@@allow('read', published)
51+
}
52+
`
53+
);
54+
55+
await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
56+
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
57+
58+
const ext = prismaModule.defineExtension((_prisma: any) => {
59+
return _prisma.$extends({
60+
model: {
61+
post: {
62+
findManyListView: async (args: any) => {
63+
return { view: true, data: await prisma.post.findMany(args) };
64+
},
65+
},
66+
},
67+
});
68+
});
69+
70+
await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2);
71+
await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({
72+
view: true,
73+
data: [
74+
{ id: 1, title: 'post1', published: true },
75+
{ id: 2, title: 'post2', published: false },
76+
],
77+
});
78+
79+
const enhanced = enhanceRaw(prisma.$extends(ext));
80+
await expect(enhanced.post.findMany()).resolves.toHaveLength(1);
81+
// findManyListView internally uses the un-enhanced client
82+
await expect(enhanced.post.findManyListView()).resolves.toMatchObject({
83+
view: true,
84+
data: [
85+
{ id: 1, title: 'post1', published: true },
86+
{ id: 2, title: 'post2', published: false },
87+
],
88+
});
89+
});
90+
});

0 commit comments

Comments
 (0)