diff --git a/packages/runtime/src/enhancements/edge/encrypted.ts b/packages/runtime/src/enhancements/edge/encrypted.ts new file mode 120000 index 000000000..96d88b82d --- /dev/null +++ b/packages/runtime/src/enhancements/edge/encrypted.ts @@ -0,0 +1 @@ +../node/encrypted.ts \ No newline at end of file diff --git a/packages/runtime/src/enhancements/node/create-enhancement.ts b/packages/runtime/src/enhancements/node/create-enhancement.ts index adec1fdf2..871f8a1b4 100644 --- a/packages/runtime/src/enhancements/node/create-enhancement.ts +++ b/packages/runtime/src/enhancements/node/create-enhancement.ts @@ -14,13 +14,14 @@ import { withJsonProcessor } from './json-processor'; import { Logger } from './logger'; import { withOmit } from './omit'; import { withPassword } from './password'; +import { withEncrypted } from './encrypted'; import { policyProcessIncludeRelationPayload, withPolicy } from './policy'; import type { PolicyDef } from './types'; /** * All enhancement kinds */ -const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate']; +const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encrypted']; /** * Options for {@link createEnhancement} @@ -100,6 +101,7 @@ export function createEnhancement<DbClient extends object>( } const hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); + const hasEncrypted = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@encrypted')); const hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); const hasDefaultAuth = allFields.some((field) => field.defaultValueProvider); const hasTypeDefField = allFields.some((field) => field.isTypeDef); @@ -120,13 +122,22 @@ export function createEnhancement<DbClient extends object>( } } - // password enhancement must be applied prior to policy because it changes then length of the field + // password and encrypted enhancement must be applied prior to policy because it changes then length of the field // and can break validation rules like `@length` if (hasPassword && kinds.includes('password')) { // @password proxy result = withPassword(result, options); } + if (hasEncrypted && kinds.includes('encrypted')) { + if (!options.encryption) { + throw new Error('Encryption options are required for @encrypted enhancement'); + } + + // @encrypted proxy + result = withEncrypted(result, options); + } + // 'policy' and 'validation' enhancements are both enabled by `withPolicy` if (kinds.includes('policy') || kinds.includes('validation')) { result = withPolicy(result, options, context); diff --git a/packages/runtime/src/enhancements/node/encrypted.ts b/packages/runtime/src/enhancements/node/encrypted.ts new file mode 100644 index 000000000..c6d6fc873 --- /dev/null +++ b/packages/runtime/src/enhancements/node/encrypted.ts @@ -0,0 +1,175 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable @typescript-eslint/no-unused-vars */ + +import { + FieldInfo, + NestedWriteVisitor, + enumerate, + getModelFields, + resolveField, + type PrismaWriteActionType, +} from '../../cross'; +import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types'; +import { InternalEnhancementOptions } from './create-enhancement'; +import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; +import { QueryUtils } from './query-utils'; + +/** + * Gets an enhanced Prisma client that supports `@encrypted` attribute. + * + * @private + */ +export function withEncrypted<DbClient extends object = any>( + prisma: DbClient, + options: InternalEnhancementOptions +): DbClient { + return makeProxy( + prisma, + options.modelMeta, + (_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options), + 'encrypted' + ); +} + +class EncryptedHandler extends DefaultPrismaProxyHandler { + private queryUtils: QueryUtils; + private encoder = new TextEncoder(); + private decoder = new TextDecoder(); + + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { + super(prisma, model, options); + + this.queryUtils = new QueryUtils(prisma, options); + + if (!options.encryption) throw new Error('Encryption options must be provided'); + + if (this.isCustomEncryption(options.encryption!)) { + if (!options.encryption.encrypt || !options.encryption.decrypt) + throw new Error('Custom encryption must provide encrypt and decrypt functions'); + } else { + if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided'); + if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes'); + } + } + + private async getKey(secret: Uint8Array): Promise<CryptoKey> { + return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']); + } + + private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption { + return 'encrypt' in encryption && 'decrypt' in encryption; + } + + private async encrypt(field: FieldInfo, data: string): Promise<string> { + if (this.isCustomEncryption(this.options.encryption!)) { + return this.options.encryption.encrypt(this.model, field, data); + } + + const key = await this.getKey(this.options.encryption!.encryptionKey); + const iv = crypto.getRandomValues(new Uint8Array(12)); + + const encrypted = await crypto.subtle.encrypt( + { + name: 'AES-GCM', + iv, + }, + key, + this.encoder.encode(data) + ); + + // Combine IV and encrypted data into a single array of bytes + const bytes = [...iv, ...new Uint8Array(encrypted)]; + + // Convert bytes to base64 string + return btoa(String.fromCharCode(...bytes)); + } + + private async decrypt(field: FieldInfo, data: string): Promise<string> { + if (this.isCustomEncryption(this.options.encryption!)) { + return this.options.encryption.decrypt(this.model, field, data); + } + + const key = await this.getKey(this.options.encryption!.encryptionKey); + + // Convert base64 back to bytes + const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0)); + + // First 12 bytes are IV, rest is encrypted data + const decrypted = await crypto.subtle.decrypt( + { + name: 'AES-GCM', + iv: bytes.slice(0, 12), + }, + key, + bytes.slice(12) + ); + + return this.decoder.decode(decrypted); + } + + // base override + protected async preprocessArgs(action: PrismaProxyActions, args: any) { + const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; + if (args && args.data && actionsOfInterest.includes(action)) { + await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + } + return args; + } + + // base override + protected async processResultEntity<T>(method: PrismaProxyActions, data: T): Promise<T> { + if (!data || typeof data !== 'object') { + return data; + } + + for (const value of enumerate(data)) { + await this.doPostProcess(value, this.model); + } + + return data; + } + + private async doPostProcess(entityData: any, model: string) { + const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData); + + for (const field of getModelFields(entityData)) { + const fieldInfo = await resolveField(this.options.modelMeta, realModel, field); + + if (!fieldInfo) { + continue; + } + + const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted'); + if (shouldDecrypt) { + // Don't decrypt null, undefined or empty string values + if (!entityData[field]) continue; + + try { + entityData[field] = await this.decrypt(fieldInfo, entityData[field]); + } catch (error) { + console.warn('Decryption failed, keeping original value:', error); + } + } + } + } + + private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + field: async (field, _action, data, context) => { + // Don't encrypt null, undefined or empty string values + if (!data) return; + + const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted'); + if (encAttr && field.type === 'String') { + try { + context.parent[field.name] = await this.encrypt(field, data); + } catch (error) { + throw new Error(`Encryption failed for field ${field.name}: ${error}`); + } + } + }, + }); + + await visitor.visit(model, action, args); + } +} diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 7c4df97c1..e691fc32c 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import type { z } from 'zod'; +import { FieldInfo } from './cross'; export type PrismaPromise<T> = Promise<T> & Record<string, (args?: any) => PrismaPromise<any>>; @@ -133,6 +134,11 @@ export type EnhancementOptions = { * The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. */ transactionIsolationLevel?: TransactionIsolationLevel; + + /** + * The encryption options for using the `encrypted` enhancement. + */ + encryption?: SimpleEncryption | CustomEncryption; }; /** @@ -145,7 +151,7 @@ export type EnhancementContext<User extends AuthUser = AuthUser> = { /** * Kinds of enhancements to `PrismaClient` */ -export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate'; +export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted'; /** * Function for transforming errors. @@ -166,3 +172,10 @@ export type ZodSchemas = { */ input?: Record<string, Record<string, z.ZodSchema>>; }; + +export type CustomEncryption = { + encrypt: (model: string, field: FieldInfo, plain: string) => Promise<string>; + decrypt: (model: string, field: FieldInfo, cipher: string) => Promise<string>; +}; + +export type SimpleEncryption = { encryptionKey: Uint8Array }; diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 3316a90a9..3f1ba0efc 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -552,6 +552,14 @@ attribute @@auth() @@@supportTypeDef */ attribute @password(saltLength: Int?, salt: String?) @@@targetField([StringField]) + +/** + * Indicates that the field is encrypted when storing in the DB and should be decrypted when read + * + * ZenStack uses the Web Crypto API to encrypt and decrypt the field. + */ +attribute @encrypted() @@@targetField([StringField]) + /** * Indicates that the field should be omitted when read from the generated services. */ diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts new file mode 100644 index 000000000..1e0544c0b --- /dev/null +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -0,0 +1,108 @@ +import { FieldInfo } from '@zenstackhq/runtime'; +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; + +describe('Encrypted test', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(async () => { + process.chdir(origDir); + }); + + it('Simple encryption test', async () => { + const { enhance } = await loadSchema(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + + @@allow('all', true) + }`); + + const sudoDb = enhance(undefined, { kinds: [] }); + const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64')); + + const db = enhance(undefined, { + kinds: ['encrypted'], + encryption: { encryptionKey }, + }); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + const sudoRead = await sudoDb.user.findUnique({ + where: { + id: '1', + }, + }); + + expect(create.encrypted_value).toBe('abc123'); + expect(read.encrypted_value).toBe('abc123'); + expect(sudoRead.encrypted_value).not.toBe('abc123'); + }); + + it('Custom encryption test', async () => { + const { enhance } = await loadSchema(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + + @@allow('all', true) + }`); + + const sudoDb = enhance(undefined, { kinds: [] }); + const db = enhance(undefined, { + kinds: ['encrypted'], + encryption: { + encrypt: async (model: string, field: FieldInfo, data: string) => { + // Add _enc to the end of the input + return data + '_enc'; + }, + decrypt: async (model: string, field: FieldInfo, cipher: string) => { + // Remove _enc from the end of the input explicitly + if (cipher.endsWith('_enc')) { + return cipher.slice(0, -4); // Remove last 4 characters (_enc) + } + + return cipher; + }, + }, + }); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + const sudoRead = await sudoDb.user.findUnique({ + where: { + id: '1', + }, + }); + + expect(create.encrypted_value).toBe('abc123'); + expect(read.encrypted_value).toBe('abc123'); + expect(sudoRead.encrypted_value).toBe('abc123_enc'); + }); +});