From b99f14bf59c9f4bdc8e7ea42d9ede6e55d4a812c Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 12 Nov 2024 17:29:24 -0800 Subject: [PATCH 1/3] fix(runtime): intercepts `$extends` to reproxy its result to make sure enhancements persist --- .../runtime/src/enhancements/node/proxy.ts | 19 ++++ .../with-policy/client-extensions.test.ts | 63 ++++++-------- .../tests/issue-prisma-extension.test.ts | 87 +++++++++++++++++++ 3 files changed, 131 insertions(+), 38 deletions(-) create mode 100644 tests/regression/tests/issue-prisma-extension.test.ts diff --git a/packages/runtime/src/enhancements/node/proxy.ts b/packages/runtime/src/enhancements/node/proxy.ts index ae4105301..cfbc0eb7c 100644 --- a/packages/runtime/src/enhancements/node/proxy.ts +++ b/packages/runtime/src/enhancements/node/proxy.ts @@ -254,6 +254,25 @@ export function makeProxy<T extends PrismaProxyHandler>( } } + if (prop === '$extends') { + // Prisma's `$extends` API returns a new client instance, we need to recreate + // a proxy around it + const $extends = Reflect.get(target, prop, receiver); + if ($extends && typeof $extends === 'function') { + return (...args: any[]) => { + const result = $extends.bind(target)(...args); + if (!result[PRISMA_PROXY_ENHANCER]) { + return makeProxy(result, modelMeta, makeHandler, name + '$ext', errorTransformer); + } else { + // avoid double wrapping + return result; + } + }; + } else { + return $extends; + } + } + if (typeof prop !== 'string' || prop.startsWith('$') || !models.includes(prop.toLowerCase())) { // skip non-model fields return Reflect.get(target, prop, receiver); diff --git a/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts b/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts index 13f05aa51..1d907a4f2 100644 --- a/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts +++ b/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts @@ -44,13 +44,9 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.getAll()).resolves.toHaveLength(2); - - // FIXME: extending an enhanced client doesn't work for this case - // const db1 = enhance(prisma).$extends(ext); - // await expect(db1.model.getAll()).resolves.toHaveLength(2); + await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3); + await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2); + await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2); }); it('one model new method', async () => { @@ -84,9 +80,9 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.getAll()).resolves.toHaveLength(2); + await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3); + await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2); + await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2); }); it('add client method', async () => { @@ -115,8 +111,11 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - xprisma.$log('abc'); + enhanceRaw(prisma).$extends(ext).$log('abc'); + expect(logged).toBeTruthy(); + + logged = false; + enhanceRaw(prisma.$extends(ext)).$log('abc'); expect(logged).toBeTruthy(); }); @@ -143,7 +142,6 @@ describe('With Policy: client extensions', () => { query: { model: { async findMany({ args, query }: any) { - // take incoming `where` and set `age` args.where = { ...args.where, y: { lt: 300 } }; return query(args); }, @@ -152,9 +150,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1); }); it('query override all models', async () => { @@ -180,7 +177,6 @@ describe('With Policy: client extensions', () => { query: { $allModels: { async findMany({ args, query }: any) { - // take incoming `where` and set `age` args.where = { ...args.where, y: { lt: 300 } }; console.log('findMany args:', args); return query(args); @@ -190,9 +186,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1); }); it('query override all operations', async () => { @@ -218,7 +213,6 @@ describe('With Policy: client extensions', () => { query: { model: { async $allOperations({ operation, args, query }: any) { - // take incoming `where` and set `age` args.where = { ...args.where, y: { lt: 300 } }; console.log(`${operation} args:`, args); return query(args); @@ -228,9 +222,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1); }); it('query override everything', async () => { @@ -255,7 +248,6 @@ describe('With Policy: client extensions', () => { name: 'prisma-extension-queryOverride', query: { async $allOperations({ operation, args, query }: any) { - // take incoming `where` and set `age` args.where = { ...args.where, y: { lt: 300 } }; console.log(`${operation} args:`, args); return query(args); @@ -264,9 +256,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1); }); it('result mutation', async () => { @@ -301,11 +292,9 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - const r = await db.model.findMany(); - expect(r).toHaveLength(1); - expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ value: 2 })])); + const expected = [expect.objectContaining({ value: 2 })]; + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected); }); it('result custom fields', async () => { @@ -339,10 +328,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - const r = await db.model.findMany(); - expect(r).toHaveLength(1); - expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ doubleValue: 2 })])); + const expected = [expect.objectContaining({ doubleValue: 2 })]; + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected); }); }); diff --git a/tests/regression/tests/issue-prisma-extension.test.ts b/tests/regression/tests/issue-prisma-extension.test.ts new file mode 100644 index 000000000..5d70c963c --- /dev/null +++ b/tests/regression/tests/issue-prisma-extension.test.ts @@ -0,0 +1,87 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue prisma extension', () => { + it('extend enhanced client', async () => { + const { enhance, prisma } = await loadSchema( + ` + model Post { + id Int @id + title String + published Boolean + + @@allow('create', true) + @@allow('read', published) + } + ` + ); + + await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); + await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); + + const db = enhance(); + await expect(db.post.findMany()).resolves.toHaveLength(1); + + const extended = db.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await db.post.findMany(args) }; + }, + }, + }, + }); + + await expect(extended.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [{ id: 1, title: 'post1', published: true }], + }); + await expect(extended.post.findMany()).resolves.toHaveLength(1); + }); + + it('enhance extended client', async () => { + const { enhanceRaw, prisma } = await loadSchema( + ` + model Post { + id Int @id + title String + published Boolean + + @@allow('create', true) + @@allow('read', published) + } + ` + ); + + await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); + await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); + + const extended = prisma.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await prisma.post.findMany(args) }; + }, + }, + }, + }); + + await expect(extended.post.findMany()).resolves.toHaveLength(2); + await expect(extended.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + + const db = enhanceRaw(extended); + await expect(db.post.findMany()).resolves.toHaveLength(1); + await expect(db.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + }); +}); From e7f6b54a1370474215df817f4f2f672ad5be0dd0 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 12 Nov 2024 17:34:54 -0800 Subject: [PATCH 2/3] update --- .../tests/issue-prisma-extension.test.ts | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/regression/tests/issue-prisma-extension.test.ts b/tests/regression/tests/issue-prisma-extension.test.ts index 5d70c963c..fa041a18a 100644 --- a/tests/regression/tests/issue-prisma-extension.test.ts +++ b/tests/regression/tests/issue-prisma-extension.test.ts @@ -39,7 +39,7 @@ describe('issue prisma extension', () => { }); it('enhance extended client', async () => { - const { enhanceRaw, prisma } = await loadSchema( + const { enhanceRaw, prisma, prismaModule } = await loadSchema( ` model Post { id Int @id @@ -55,18 +55,20 @@ describe('issue prisma extension', () => { await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); - const extended = prisma.$extends({ - model: { - post: { - findManyListView: async (args: any) => { - return { view: true, data: await prisma.post.findMany(args) }; + const ext = prismaModule.defineExtension((_prisma: any) => { + return _prisma.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await prisma.post.findMany(args) }; + }, }, }, - }, + }); }); - await expect(extended.post.findMany()).resolves.toHaveLength(2); - await expect(extended.post.findManyListView()).resolves.toMatchObject({ + await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2); + await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({ view: true, data: [ { id: 1, title: 'post1', published: true }, @@ -74,9 +76,10 @@ describe('issue prisma extension', () => { ], }); - const db = enhanceRaw(extended); - await expect(db.post.findMany()).resolves.toHaveLength(1); - await expect(db.post.findManyListView()).resolves.toMatchObject({ + const enhanced = enhanceRaw(prisma.$extends(ext)); + await expect(enhanced.post.findMany()).resolves.toHaveLength(1); + // findManyListView internally uses the un-enhanced client + await expect(enhanced.post.findManyListView()).resolves.toMatchObject({ view: true, data: [ { id: 1, title: 'post1', published: true }, From 0ee16ccb5d6ac29e5c4bb4cadaf787036d487dd5 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:13:28 -0800 Subject: [PATCH 3/3] update test --- tests/regression/tests/issue-1859.test.ts | 90 +++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/regression/tests/issue-1859.test.ts diff --git a/tests/regression/tests/issue-1859.test.ts b/tests/regression/tests/issue-1859.test.ts new file mode 100644 index 000000000..2b9d4538b --- /dev/null +++ b/tests/regression/tests/issue-1859.test.ts @@ -0,0 +1,90 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1859', () => { + it('extend enhanced client', async () => { + const { enhance, prisma } = await loadSchema( + ` + model Post { + id Int @id + title String + published Boolean + + @@allow('create', true) + @@allow('read', published) + } + ` + ); + + await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); + await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); + + const db = enhance(); + await expect(db.post.findMany()).resolves.toHaveLength(1); + + const extended = db.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await db.post.findMany(args) }; + }, + }, + }, + }); + + await expect(extended.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [{ id: 1, title: 'post1', published: true }], + }); + await expect(extended.post.findMany()).resolves.toHaveLength(1); + }); + + it('enhance extended client', async () => { + const { enhanceRaw, prisma, prismaModule } = await loadSchema( + ` + model Post { + id Int @id + title String + published Boolean + + @@allow('create', true) + @@allow('read', published) + } + ` + ); + + await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); + await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); + + const ext = prismaModule.defineExtension((_prisma: any) => { + return _prisma.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await prisma.post.findMany(args) }; + }, + }, + }, + }); + }); + + await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2); + await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + + const enhanced = enhanceRaw(prisma.$extends(ext)); + await expect(enhanced.post.findMany()).resolves.toHaveLength(1); + // findManyListView internally uses the un-enhanced client + await expect(enhanced.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + }); +});