From b99f14bf59c9f4bdc8e7ea42d9ede6e55d4a812c Mon Sep 17 00:00:00 2001
From: ymc9 <104139426+ymc9@users.noreply.github.com>
Date: Tue, 12 Nov 2024 17:29:24 -0800
Subject: [PATCH 1/3] fix(runtime): intercepts `$extends` to reproxy its result
 to make sure enhancements persist

---
 .../runtime/src/enhancements/node/proxy.ts    | 19 ++++
 .../with-policy/client-extensions.test.ts     | 63 ++++++--------
 .../tests/issue-prisma-extension.test.ts      | 87 +++++++++++++++++++
 3 files changed, 131 insertions(+), 38 deletions(-)
 create mode 100644 tests/regression/tests/issue-prisma-extension.test.ts

diff --git a/packages/runtime/src/enhancements/node/proxy.ts b/packages/runtime/src/enhancements/node/proxy.ts
index ae4105301..cfbc0eb7c 100644
--- a/packages/runtime/src/enhancements/node/proxy.ts
+++ b/packages/runtime/src/enhancements/node/proxy.ts
@@ -254,6 +254,25 @@ export function makeProxy<T extends PrismaProxyHandler>(
                 }
             }
 
+            if (prop === '$extends') {
+                // Prisma's `$extends` API returns a new client instance, we need to recreate
+                // a proxy around it
+                const $extends = Reflect.get(target, prop, receiver);
+                if ($extends && typeof $extends === 'function') {
+                    return (...args: any[]) => {
+                        const result = $extends.bind(target)(...args);
+                        if (!result[PRISMA_PROXY_ENHANCER]) {
+                            return makeProxy(result, modelMeta, makeHandler, name + '$ext', errorTransformer);
+                        } else {
+                            // avoid double wrapping
+                            return result;
+                        }
+                    };
+                } else {
+                    return $extends;
+                }
+            }
+
             if (typeof prop !== 'string' || prop.startsWith('$') || !models.includes(prop.toLowerCase())) {
                 // skip non-model fields
                 return Reflect.get(target, prop, receiver);
diff --git a/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts b/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts
index 13f05aa51..1d907a4f2 100644
--- a/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts
+++ b/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts
@@ -44,13 +44,9 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        const db = enhanceRaw(xprisma);
-        await expect(db.model.getAll()).resolves.toHaveLength(2);
-
-        // FIXME: extending an enhanced client doesn't work for this case
-        // const db1 = enhance(prisma).$extends(ext);
-        // await expect(db1.model.getAll()).resolves.toHaveLength(2);
+        await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3);
+        await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2);
+        await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2);
     });
 
     it('one model new method', async () => {
@@ -84,9 +80,9 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        const db = enhanceRaw(xprisma);
-        await expect(db.model.getAll()).resolves.toHaveLength(2);
+        await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3);
+        await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2);
+        await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2);
     });
 
     it('add client method', async () => {
@@ -115,8 +111,11 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        xprisma.$log('abc');
+        enhanceRaw(prisma).$extends(ext).$log('abc');
+        expect(logged).toBeTruthy();
+
+        logged = false;
+        enhanceRaw(prisma.$extends(ext)).$log('abc');
         expect(logged).toBeTruthy();
     });
 
@@ -143,7 +142,6 @@ describe('With Policy: client extensions', () => {
                 query: {
                     model: {
                         async findMany({ args, query }: any) {
-                            // take incoming `where` and set `age`
                             args.where = { ...args.where, y: { lt: 300 } };
                             return query(args);
                         },
@@ -152,9 +150,8 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        const db = enhanceRaw(xprisma);
-        await expect(db.model.findMany()).resolves.toHaveLength(1);
+        await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
+        await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
     });
 
     it('query override all models', async () => {
@@ -180,7 +177,6 @@ describe('With Policy: client extensions', () => {
                 query: {
                     $allModels: {
                         async findMany({ args, query }: any) {
-                            // take incoming `where` and set `age`
                             args.where = { ...args.where, y: { lt: 300 } };
                             console.log('findMany args:', args);
                             return query(args);
@@ -190,9 +186,8 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        const db = enhanceRaw(xprisma);
-        await expect(db.model.findMany()).resolves.toHaveLength(1);
+        await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
+        await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
     });
 
     it('query override all operations', async () => {
@@ -218,7 +213,6 @@ describe('With Policy: client extensions', () => {
                 query: {
                     model: {
                         async $allOperations({ operation, args, query }: any) {
-                            // take incoming `where` and set `age`
                             args.where = { ...args.where, y: { lt: 300 } };
                             console.log(`${operation} args:`, args);
                             return query(args);
@@ -228,9 +222,8 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        const db = enhanceRaw(xprisma);
-        await expect(db.model.findMany()).resolves.toHaveLength(1);
+        await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
+        await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
     });
 
     it('query override everything', async () => {
@@ -255,7 +248,6 @@ describe('With Policy: client extensions', () => {
                 name: 'prisma-extension-queryOverride',
                 query: {
                     async $allOperations({ operation, args, query }: any) {
-                        // take incoming `where` and set `age`
                         args.where = { ...args.where, y: { lt: 300 } };
                         console.log(`${operation} args:`, args);
                         return query(args);
@@ -264,9 +256,8 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        const db = enhanceRaw(xprisma);
-        await expect(db.model.findMany()).resolves.toHaveLength(1);
+        await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
+        await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
     });
 
     it('result mutation', async () => {
@@ -301,11 +292,9 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        const db = enhanceRaw(xprisma);
-        const r = await db.model.findMany();
-        expect(r).toHaveLength(1);
-        expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ value: 2 })]));
+        const expected = [expect.objectContaining({ value: 2 })];
+        await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected);
+        await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected);
     });
 
     it('result custom fields', async () => {
@@ -339,10 +328,8 @@ describe('With Policy: client extensions', () => {
             });
         });
 
-        const xprisma = prisma.$extends(ext);
-        const db = enhanceRaw(xprisma);
-        const r = await db.model.findMany();
-        expect(r).toHaveLength(1);
-        expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ doubleValue: 2 })]));
+        const expected = [expect.objectContaining({ doubleValue: 2 })];
+        await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected);
+        await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected);
     });
 });
diff --git a/tests/regression/tests/issue-prisma-extension.test.ts b/tests/regression/tests/issue-prisma-extension.test.ts
new file mode 100644
index 000000000..5d70c963c
--- /dev/null
+++ b/tests/regression/tests/issue-prisma-extension.test.ts
@@ -0,0 +1,87 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('issue prisma extension', () => {
+    it('extend enhanced client', async () => {
+        const { enhance, prisma } = await loadSchema(
+            `
+            model Post {
+                id Int @id
+                title String
+                published Boolean
+
+                @@allow('create', true)
+                @@allow('read', published)
+            }
+            `
+        );
+
+        await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
+        await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
+
+        const db = enhance();
+        await expect(db.post.findMany()).resolves.toHaveLength(1);
+
+        const extended = db.$extends({
+            model: {
+                post: {
+                    findManyListView: async (args: any) => {
+                        return { view: true, data: await db.post.findMany(args) };
+                    },
+                },
+            },
+        });
+
+        await expect(extended.post.findManyListView()).resolves.toMatchObject({
+            view: true,
+            data: [{ id: 1, title: 'post1', published: true }],
+        });
+        await expect(extended.post.findMany()).resolves.toHaveLength(1);
+    });
+
+    it('enhance extended client', async () => {
+        const { enhanceRaw, prisma } = await loadSchema(
+            `
+            model Post {
+                id Int @id
+                title String
+                published Boolean
+
+                @@allow('create', true)
+                @@allow('read', published)
+            }
+            `
+        );
+
+        await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
+        await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
+
+        const extended = prisma.$extends({
+            model: {
+                post: {
+                    findManyListView: async (args: any) => {
+                        return { view: true, data: await prisma.post.findMany(args) };
+                    },
+                },
+            },
+        });
+
+        await expect(extended.post.findMany()).resolves.toHaveLength(2);
+        await expect(extended.post.findManyListView()).resolves.toMatchObject({
+            view: true,
+            data: [
+                { id: 1, title: 'post1', published: true },
+                { id: 2, title: 'post2', published: false },
+            ],
+        });
+
+        const db = enhanceRaw(extended);
+        await expect(db.post.findMany()).resolves.toHaveLength(1);
+        await expect(db.post.findManyListView()).resolves.toMatchObject({
+            view: true,
+            data: [
+                { id: 1, title: 'post1', published: true },
+                { id: 2, title: 'post2', published: false },
+            ],
+        });
+    });
+});

From e7f6b54a1370474215df817f4f2f672ad5be0dd0 Mon Sep 17 00:00:00 2001
From: ymc9 <104139426+ymc9@users.noreply.github.com>
Date: Tue, 12 Nov 2024 17:34:54 -0800
Subject: [PATCH 2/3] update

---
 .../tests/issue-prisma-extension.test.ts      | 27 ++++++++++---------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/tests/regression/tests/issue-prisma-extension.test.ts b/tests/regression/tests/issue-prisma-extension.test.ts
index 5d70c963c..fa041a18a 100644
--- a/tests/regression/tests/issue-prisma-extension.test.ts
+++ b/tests/regression/tests/issue-prisma-extension.test.ts
@@ -39,7 +39,7 @@ describe('issue prisma extension', () => {
     });
 
     it('enhance extended client', async () => {
-        const { enhanceRaw, prisma } = await loadSchema(
+        const { enhanceRaw, prisma, prismaModule } = await loadSchema(
             `
             model Post {
                 id Int @id
@@ -55,18 +55,20 @@ describe('issue prisma extension', () => {
         await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
         await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
 
-        const extended = prisma.$extends({
-            model: {
-                post: {
-                    findManyListView: async (args: any) => {
-                        return { view: true, data: await prisma.post.findMany(args) };
+        const ext = prismaModule.defineExtension((_prisma: any) => {
+            return _prisma.$extends({
+                model: {
+                    post: {
+                        findManyListView: async (args: any) => {
+                            return { view: true, data: await prisma.post.findMany(args) };
+                        },
                     },
                 },
-            },
+            });
         });
 
-        await expect(extended.post.findMany()).resolves.toHaveLength(2);
-        await expect(extended.post.findManyListView()).resolves.toMatchObject({
+        await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2);
+        await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({
             view: true,
             data: [
                 { id: 1, title: 'post1', published: true },
@@ -74,9 +76,10 @@ describe('issue prisma extension', () => {
             ],
         });
 
-        const db = enhanceRaw(extended);
-        await expect(db.post.findMany()).resolves.toHaveLength(1);
-        await expect(db.post.findManyListView()).resolves.toMatchObject({
+        const enhanced = enhanceRaw(prisma.$extends(ext));
+        await expect(enhanced.post.findMany()).resolves.toHaveLength(1);
+        // findManyListView internally uses the un-enhanced client
+        await expect(enhanced.post.findManyListView()).resolves.toMatchObject({
             view: true,
             data: [
                 { id: 1, title: 'post1', published: true },

From 0ee16ccb5d6ac29e5c4bb4cadaf787036d487dd5 Mon Sep 17 00:00:00 2001
From: ymc9 <104139426+ymc9@users.noreply.github.com>
Date: Fri, 15 Nov 2024 22:13:28 -0800
Subject: [PATCH 3/3] update test

---
 tests/regression/tests/issue-1859.test.ts | 90 +++++++++++++++++++++++
 1 file changed, 90 insertions(+)
 create mode 100644 tests/regression/tests/issue-1859.test.ts

diff --git a/tests/regression/tests/issue-1859.test.ts b/tests/regression/tests/issue-1859.test.ts
new file mode 100644
index 000000000..2b9d4538b
--- /dev/null
+++ b/tests/regression/tests/issue-1859.test.ts
@@ -0,0 +1,90 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('issue 1859', () => {
+    it('extend enhanced client', async () => {
+        const { enhance, prisma } = await loadSchema(
+            `
+            model Post {
+                id Int @id
+                title String
+                published Boolean
+
+                @@allow('create', true)
+                @@allow('read', published)
+            }
+            `
+        );
+
+        await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
+        await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
+
+        const db = enhance();
+        await expect(db.post.findMany()).resolves.toHaveLength(1);
+
+        const extended = db.$extends({
+            model: {
+                post: {
+                    findManyListView: async (args: any) => {
+                        return { view: true, data: await db.post.findMany(args) };
+                    },
+                },
+            },
+        });
+
+        await expect(extended.post.findManyListView()).resolves.toMatchObject({
+            view: true,
+            data: [{ id: 1, title: 'post1', published: true }],
+        });
+        await expect(extended.post.findMany()).resolves.toHaveLength(1);
+    });
+
+    it('enhance extended client', async () => {
+        const { enhanceRaw, prisma, prismaModule } = await loadSchema(
+            `
+            model Post {
+                id Int @id
+                title String
+                published Boolean
+
+                @@allow('create', true)
+                @@allow('read', published)
+            }
+            `
+        );
+
+        await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
+        await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });
+
+        const ext = prismaModule.defineExtension((_prisma: any) => {
+            return _prisma.$extends({
+                model: {
+                    post: {
+                        findManyListView: async (args: any) => {
+                            return { view: true, data: await prisma.post.findMany(args) };
+                        },
+                    },
+                },
+            });
+        });
+
+        await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2);
+        await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({
+            view: true,
+            data: [
+                { id: 1, title: 'post1', published: true },
+                { id: 2, title: 'post2', published: false },
+            ],
+        });
+
+        const enhanced = enhanceRaw(prisma.$extends(ext));
+        await expect(enhanced.post.findMany()).resolves.toHaveLength(1);
+        // findManyListView internally uses the un-enhanced client
+        await expect(enhanced.post.findManyListView()).resolves.toMatchObject({
+            view: true,
+            data: [
+                { id: 1, title: 'post1', published: true },
+                { id: 2, title: 'post2', published: false },
+            ],
+        });
+    });
+});