diff --git a/README.md b/README.md index 5c1941d1..e8e96b86 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ Some examples require extra dependencies. See each sample's directory for specif * [custom_metric](custom_metric) - Custom metric to record the workflow type in the activity schedule to start latency. * [dsl](dsl) - DSL workflow that executes steps defined in a YAML file. * [encryption](encryption) - Apply end-to-end encryption for all input/output. +* [encryption_jwt](encryption_jwt) - Apply end-to-end encryption for all input/output using a KMS and per-namespace JWT-based auth. * [gevent_async](gevent_async) - Combine gevent and Temporal. * [langchain](langchain) - Orchestrate workflows for LangChain. * [message_passing/introduction](message_passing/introduction/) - Introduction to queries, signals, and updates. diff --git a/encryption_jwt/.gitignore b/encryption_jwt/.gitignore new file mode 100644 index 00000000..f8e13ab1 --- /dev/null +++ b/encryption_jwt/.gitignore @@ -0,0 +1 @@ +_certs diff --git a/encryption_jwt/README.md b/encryption_jwt/README.md new file mode 100644 index 00000000..83ec1d7c --- /dev/null +++ b/encryption_jwt/README.md @@ -0,0 +1,144 @@ +# Encryption with Temporal user role access + +This sample demonstrates: + +- CORS settings to allow connections to a codec server +- using a KMS key to encrypt/decrypt payloads +- extracting data from a JWT +- controlling decryption based on a user's Temporal Cloud role + +The Codec Server uses the [Operations API](https://docs.temporal.io/ops) to get user information. It would be helpful to be familiar with the API's requirements. This API is currently a beta release and may change in the future. + +## Install + +For this sample, the optional `encryption_jwt` and `bedrock` dependency groups must be included. 
To include, run: + +```sh +poetry install --with encryption_jwt,bedrock +``` + +## Setup + +> [!WARNING] +> You must connect your Worker(s) to Temporal Cloud to see decryption working in the Web UI. + +### Key management + +This example uses the [AWS Key Management Service](https://aws.amazon.com/kms/) (KMS). You will need +to create a "Customer managed key" with its Alias set to your Temporal Namespace (replace `.`s with `_`s). +Alternatively, replace the key management portion with your own implementation. + +### Self-signed certificates + +The codec server will need to use HTTPS; self-signed certificates will work in the development +environment. Run the following command in a `_certs` directory that's a subdirectory of this one. +It will create certificate files that are good for 10 years. + +```sh +openssl req -x509 -newkey rsa:4096 -sha256 -days 3650 -nodes -keyout localhost.key -out localhost.pem -subj "/CN=localhost" +``` + +In the project, you can access the files using the following relative paths. + +- `./_certs/localhost.pem` +- `./_certs/localhost.key` + +## Run + +### Worker + +To run, first see the [repo README.md](../README.md) for prerequisites. + +Before starting the worker, open a terminal and add the following environment variables with +appropriate values: + +```sh +export TEMPORAL_ADDRESS= +export TEMPORAL_TLS_CERT= +export TEMPORAL_TLS_KEY= +export AWS_ACCESS_KEY_ID= +export AWS_SECRET_ACCESS_KEY= +export AWS_SESSION_TOKEN= +``` + +In the same terminal start the worker: + +```sh +poetry run python worker.py <namespace> +``` + +> [!Note] +> You will need to run at least one Worker per namespace. + +### Codec server + +The codec server allows you to see the decrypted contents of workflow payloads in the Web UI. The server +must be started with secure connections (HTTPS); you will need the paths to a pem (crt) and key +file. [Self-signed certificates](#self-signed-certificates) will work just fine. 
+ +You will also need a [Temporal API Key](https://docs.temporal.io/cloud/api-keys#generate-an-api-key). Its value is set using the `TEMPORAL_API_KEY` env var. + +Open a new terminal and add the following environment variables with values: + +```sh +export TEMPORAL_TLS_CERT= +export TEMPORAL_TLS_KEY= +export TEMPORAL_API_KEY= # see https://docs.temporal.io/cloud/tcld/apikey#create +export TEMPORAL_OPS_ADDRESS=saas-api.tmprl.cloud:443 # uses "saas-api.tmprl.cloud:443" if not provided +export TEMPORAL_OPS_API_VERSION=2024-05-13-00 +export AWS_ACCESS_KEY_ID= +export AWS_SECRET_ACCESS_KEY= +export AWS_SESSION_TOKEN= +export SSL_PEM= +export SSL_KEY= +``` + +In the same terminal start the codec server: + +```sh +poetry run python codec_server.py +``` + +### Execute workflow + +In a third terminal, add the environment variables: + +```txt +export TEMPORAL_ADDRESS= +export TEMPORAL_TLS_CERT= +export TEMPORAL_TLS_KEY= +``` + +Then run the command to execute the workflow: + +```sh +poetry run python starter.py <namespace> +``` + +The workflow should complete with the hello result. To view the workflow, use [temporal](https://docs.temporal.io/cli): + +```sh +temporal workflow show --workflow-id encryption-workflow-id +``` + +Note how the result looks (with wrapping removed): + +```txt +Output:[encoding binary/encrypted: payload encoding is not supported] +``` + +This is because the data is encrypted and not visible. + +## Temporal Web UI + +Open the Web UI and select a workflow; you'll only see encrypted results. To see decrypted results: + +- You must have the Temporal role of "admin" +- The codec server must be running +- Set the "Remote Codec Endpoint" in the web UI to the codec server domain: `https://localhost:8081` + - Both the "Pass the user access token" and "Include cross-origin credentials" must be enabled + +Once those requirements are met you can then see the unencrypted results. 
This is possible because +CORS settings in the codec server allow the browser to access the codec server directly over +localhost. Decrypted data never leaves your local machine. See [Codec +Server](https://docs.temporal.io/production-deployment/data-encryption) diff --git a/encryption_jwt/__init__.py b/encryption_jwt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/encryption_jwt/codec.py b/encryption_jwt/codec.py new file mode 100644 index 00000000..64df452d --- /dev/null +++ b/encryption_jwt/codec.py @@ -0,0 +1,37 @@ +from typing import Iterable, List + +from temporalio.api.common.v1 import Payload +from temporalio.converter import PayloadCodec + +from encryption_jwt.encryptor import KMSEncryptor + + +class EncryptionCodec(PayloadCodec): + def __init__(self, namespace: str): + self._encryptor = KMSEncryptor(namespace) + + async def encode(self, payloads: Iterable[Payload]) -> List[Payload]: + # We blindly encode all payloads with the key and set the metadata with the key that was + # used (base64 encoded). 
+ + async def encrypt_payload(p: Payload): + data, key = await self._encryptor.encrypt(p.SerializeToString()) + return Payload( + metadata={ + "encoding": b"binary/encrypted", + "data_key_encrypted": key, + }, + data=data, + ) + + # return list(map(encrypt_payload, payloads)) + return [await encrypt_payload(payload) for payload in payloads] + + async def decode(self, payloads: Iterable[Payload]) -> List[Payload]: + async def decrypt_payload(p: Payload): + data_key_encrypted_base64 = p.metadata.get("data_key_encrypted", b"") + data = await self._encryptor.decrypt(data_key_encrypted_base64, p.data) + return Payload.FromString(data) + + # return list(map(decrypt_payload, payloads)) + return [await decrypt_payload(payload) for payload in payloads] diff --git a/encryption_jwt/codec_server.py b/encryption_jwt/codec_server.py new file mode 100644 index 00000000..cfc54d02 --- /dev/null +++ b/encryption_jwt/codec_server.py @@ -0,0 +1,165 @@ +import logging +import os +import ssl + +import jwt +import requests +from aiohttp import hdrs, web +from google.protobuf import json_format +from jwt import PyJWK +from jwt.algorithms import RSAAlgorithm +from temporalio.api.cloud.cloudservice.v1 import GetUsersRequest +from temporalio.api.common.v1 import Payloads +from temporalio.client import CloudOperationsClient + +from encryption_jwt.codec import EncryptionCodec + +AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["owner", "admin"] +AUTHORIZED_NAMESPACE_ACCESS_ROLES = ["read", "write", "admin"] + +TEMPORAL_CLIENT_CLOUD_API_VERSION = "2024-05-13-00" + +temporal_ops_address = ( + os.environ.get("TEMPORAL_OPS_ADDRESS") or "saas-api.tmprl.cloud:443" +) + + +def build_codec_server() -> web.Application: + # Cors handler + async def cors_options(req: web.Request) -> web.Response: + resp = web.Response() + + if req.headers.get(hdrs.ORIGIN) == "http://localhost:8080": + logger.info("Setting CORS headers for localhost") + resp.headers[hdrs.ACCESS_CONTROL_ALLOW_ORIGIN] = "http://localhost:8080" + + elif 
req.headers.get(hdrs.ORIGIN) == "https://cloud.temporal.io": + logger.info("Setting CORS headers for cloud.temporal.io") + resp.headers[hdrs.ACCESS_CONTROL_ALLOW_ORIGIN] = "https://cloud.temporal.io" + + allow_headers = "content-type,x-namespace" + if req.scheme.lower() == "https": + allow_headers += ",authorization" + resp.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true" + + # common + resp.headers[hdrs.ACCESS_CONTROL_ALLOW_METHODS] = "POST" + resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] = allow_headers + + return resp + + async def decryption_authorized(email: str, namespace: str) -> bool: + client = await CloudOperationsClient.connect( + api_key=os.environ.get("TEMPORAL_API_KEY"), + version=TEMPORAL_CLIENT_CLOUD_API_VERSION, + ) + + response = await client.cloud_service.get_users( + GetUsersRequest(namespace=namespace) + ) + + for user in response.users: + if user.spec.email.lower() == email.lower(): + if ( + user.spec.access.account_access.role + in AUTHORIZED_ACCOUNT_ACCESS_ROLES + ): + return True + else: + if namespace in user.spec.access.namespace_accesses: + if ( + user.spec.access.namespace_accesses[namespace].permission + in AUTHORIZED_NAMESPACE_ACCESS_ROLES + ): + return True + + return False + + def make_handler(fn: str): + async def handler(req: web.Request): + namespace = req.headers.get("x-namespace") or "default" + auth_header = req.headers.get("Authorization") or "" + _bearer, encoded = auth_header.split(" ") + + # Extract the kid from the Auth header + jwt_dict = jwt.get_unverified_header(encoded) + kid = jwt_dict["kid"] + algorithm = jwt_dict["alg"] + + # Fetch Temporal Cloud JWKS + jwks_url = "https://login.tmprl.cloud/.well-known/jwks.json" + jwks = requests.get(jwks_url).json() + + # Extract Temporal Cloud's public key + pyjwk = None + for key in jwks["keys"]: + if key["kid"] == kid: + # Convert JWKS key to PEM format + pyjwk = PyJWK.from_dict(key) + break + + if pyjwk is None: + raise ValueError("Public key not found in JWKS") + + 
# Decode the jwt, verifying against Temporal Cloud's public key + decoded = jwt.decode( + encoded, + pyjwk.key, + algorithms=[algorithm], + audience=[ + "https://saas-api.tmprl.cloud", + "https://prod-tmprl.us.auth0.com/userinfo", + ], + ) + + # Use the email to determine if the user is authorized to decrypt the payload + authorized = await decryption_authorized( + decoded["https://saas-api.tmprl.cloud/user/email"], namespace + ) + + if authorized: + # Read payloads as JSON + assert req.content_type == "application/json" + payloads = json_format.Parse(await req.read(), Payloads()) + encryptionCodec = EncryptionCodec(namespace) + payloads = Payloads( + payloads=await getattr(encryptionCodec, fn)(payloads.payloads) + ) + + # Apply CORS and return JSON + resp = await cors_options(req) + resp.content_type = "application/json" + resp.text = json_format.MessageToJson(payloads) + return resp + + return handler + + # Build app + app = web.Application() + # set up logger + logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger(__name__) + app.add_routes( + [ + web.post("/encode", make_handler("encode")), + web.post("/decode", make_handler("decode")), + web.options("/decode", cors_options), + ] + ) + + return app + + +if __name__ == "__main__": + # pylint: disable=C0103 + ssl_context = None + if os.environ.get("SSL_PEM") and os.environ.get("SSL_KEY"): + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.check_hostname = False + ssl_context.load_cert_chain( + os.environ.get("SSL_PEM") or "", os.environ.get("SSL_KEY") or "" + ) + + web.run_app( + build_codec_server(), host="0.0.0.0", port=8081, ssl_context=ssl_context + ) diff --git a/encryption_jwt/encryptor.py b/encryption_jwt/encryptor.py new file mode 100644 index 00000000..a0e324d9 --- /dev/null +++ b/encryption_jwt/encryptor.py @@ -0,0 +1,89 @@ +import asyncio +import base64 +import logging +import os +import typing + +from botocore.exceptions import ClientError +from 
cryptography.hazmat.primitives.ciphers.aead import AESGCM +from temporalio import workflow + +with workflow.unsafe.imports_passed_through(): + import aioboto3 + + +class KMSEncryptor: + """Encrypts and decrypts using keys from AWS KMS.""" + + def __init__(self, namespace: str): + self._namespace = namespace + self._boto_session = None + + @property + def boto_session(self): + """Get the lazily created aioboto3 session.""" + if not self._boto_session: + session = aioboto3.Session() + self._boto_session = session + + return self._boto_session + + async def encrypt(self, data: bytes) -> typing.Tuple[bytes, bytes]: + """Encrypt data using a key from KMS.""" + # The keys are rotated automatically by KMS, so fetch a new key to encrypt the data. + data_key_encrypted, data_key_plaintext = await self.__create_data_key( + self._namespace + ) + + if data_key_encrypted is None: + raise ValueError("No data key!") + + nonce = os.urandom(12) + encryptor = AESGCM(data_key_plaintext) + encrypted = asyncio.get_running_loop().run_in_executor( + None, encryptor.encrypt, nonce, data, None + ) + return nonce + await encrypted, base64.b64encode(data_key_encrypted) + + async def decrypt(self, data_key_encrypted_base64, data: bytes) -> bytes: + """Decrypt data using a key from KMS.""" + data_key_encrypted = base64.b64decode(data_key_encrypted_base64) + data_key_plaintext = await self.__decrypt_data_key(data_key_encrypted) + encryptor = AESGCM(data_key_plaintext) + decrypted = await asyncio.get_running_loop().run_in_executor( + None, encryptor.decrypt, data[:12], data[12:], None + ) + return decrypted + + async def __create_data_key(self, namespace: str): + """Get a set of keys from AWS KMS that can be used to encrypt data.""" + + # Create data key + alias_name = "alias/" + namespace.replace(".", "_") + async with self.boto_session.client("kms") as kms_client: + response = await kms_client.describe_key(KeyId=alias_name) + cmk_id = response["KeyMetadata"]["Arn"] + key_spec = "AES_256" + try: + response 
= await kms_client.generate_data_key( + KeyId=cmk_id, KeySpec=key_spec + ) + except ClientError as e: + logging.error(e) + return None, None + + # Return the encrypted and plaintext data key + return response["CiphertextBlob"], response["Plaintext"] + + async def __decrypt_data_key(self, data_key_encrypted): + """Use AWS KMS to exchange an encrypted key for its plaintext value.""" + + async with self.boto_session.client("kms") as kms_client: + # Decrypt the data key + try: + response = await kms_client.decrypt(CiphertextBlob=data_key_encrypted) + except ClientError as e: + logging.error(e) + return None + + return response["Plaintext"] diff --git a/encryption_jwt/starter.py b/encryption_jwt/starter.py new file mode 100644 index 00000000..1c5cdcae --- /dev/null +++ b/encryption_jwt/starter.py @@ -0,0 +1,64 @@ +import argparse +import asyncio +import dataclasses +import os + +import temporalio.converter +from temporalio.client import Client, TLSConfig + +from encryption_jwt.codec import EncryptionCodec +from encryption_jwt.worker import GreetingWorkflow + +temporal_address = "localhost:7233" +if os.environ.get("TEMPORAL_ADDRESS"): + temporal_address = os.environ["TEMPORAL_ADDRESS"] + +temporal_tls_cert = None +if os.environ.get("TEMPORAL_TLS_CERT"): + temporal_tls_cert_path = os.environ["TEMPORAL_TLS_CERT"] + with open(temporal_tls_cert_path, "rb") as f: + temporal_tls_cert = f.read() + +temporal_tls_key = None +if os.environ.get("TEMPORAL_TLS_KEY"): + temporal_tls_key_path = os.environ["TEMPORAL_TLS_KEY"] + with open(temporal_tls_key_path, "rb") as f: + temporal_tls_key = f.read() + + +async def main(namespace: str): + # Connect client + client = await Client.connect( + temporal_address, + # Use the default converter, but change the codec + data_converter=dataclasses.replace( + temporalio.converter.default(), payload_codec=EncryptionCodec(namespace) + ), + namespace=namespace, + tls=TLSConfig( + client_cert=temporal_tls_cert, + client_private_key=temporal_tls_key, + 
) + if temporal_tls_cert and temporal_tls_key + else False, + ) + + # Run workflow + result = await client.execute_workflow( + GreetingWorkflow.run, + "Temporal", + id=f"encryption-workflow-id", + task_queue="encryption-task-queue", + ) + print(f"Workflow result: {result}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run Temporal workflow with a specific namespace." + ) + parser.add_argument( + "namespace", type=str, help="The namespace to pass to the EncryptionCodec" + ) + args = parser.parse_args() + asyncio.run(main(args.namespace)) diff --git a/encryption_jwt/worker.py b/encryption_jwt/worker.py new file mode 100644 index 00000000..d33193be --- /dev/null +++ b/encryption_jwt/worker.py @@ -0,0 +1,82 @@ +import argparse +import asyncio +import dataclasses +import os + +import temporalio.converter +from temporalio import workflow +from temporalio.client import Client, TLSConfig +from temporalio.worker import Worker + +from encryption_jwt.codec import EncryptionCodec + +temporal_address = "localhost:7233" +if os.environ.get("TEMPORAL_ADDRESS"): + temporal_address = os.environ["TEMPORAL_ADDRESS"] + +temporal_tls_cert = None +if os.environ.get("TEMPORAL_TLS_CERT"): + temporal_tls_cert_path = os.environ["TEMPORAL_TLS_CERT"] + with open(temporal_tls_cert_path, "rb") as f: + temporal_tls_cert = f.read() + +temporal_tls_key = None +if os.environ.get("TEMPORAL_TLS_KEY"): + temporal_tls_key_path = os.environ["TEMPORAL_TLS_KEY"] + with open(temporal_tls_key_path, "rb") as f: + temporal_tls_key = f.read() + + +@workflow.defn(name="Workflow") +class GreetingWorkflow: + @workflow.run + async def run(self, name: str) -> str: + return f"Hello, {name}" + + +interrupt_event = asyncio.Event() + + +async def main(namespace: str): + # Connect client + client = await Client.connect( + temporal_address, + # Use the default converter, but change the codec + data_converter=dataclasses.replace( + temporalio.converter.default(), 
payload_codec=EncryptionCodec(namespace) + ), + namespace=namespace, + tls=TLSConfig( + client_cert=temporal_tls_cert, + client_private_key=temporal_tls_key, + ) + if temporal_tls_cert and temporal_tls_key + else False, + ) + + # Run a worker for the workflow + async with Worker( + client, + task_queue="encryption-task-queue", + workflows=[GreetingWorkflow], + ): + # Wait until interrupted + print("Worker started, ctrl+c to exit") + await interrupt_event.wait() + print("Shutting down") + + +if __name__ == "__main__": + loop = asyncio.new_event_loop() + parser = argparse.ArgumentParser( + description="Run Temporal workflow with a specific namespace." + ) + parser.add_argument( + "namespace", type=str, help="The namespace to pass to the EncryptionCodec" + ) + args = parser.parse_args() + try: + loop.run_until_complete(main(args.namespace)) + except KeyboardInterrupt: + interrupt_event.set() + loop.run_until_complete(loop.shutdown_asyncgens()) diff --git a/pyproject.toml b/pyproject.toml index ea908cd3..2bad7ab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,14 @@ encryption = [ "cryptography>=38.0.1,<39", "aiohttp>=3.8.1,<4", ] +encryption_jwt = [ + "cryptography>=38.0.1,<39", + "aiohttp>=3.8.1,<4", + "pyjwt>=2.9.0,<3", + "aioboto3>=13.1.1,<14", + "requests>=2.32.3,<3", + "types-requests>=2.31.0.6,<3" +] gevent = ["gevent==25.4.2 ; python_version >= '3.8'"] langchain = [ "langchain>=0.1.7,<0.2 ; python_version >= '3.8.1' and python_version < '4.0'", @@ -66,6 +74,7 @@ default-groups = [ "bedrock", "dsl", "encryption", + "encryption_jwt", "gevent", "langchain", "open-telemetry", @@ -89,6 +98,7 @@ packages = [ "custom_metric", "dsl", "encryption", + "encryption_jwt", "gevent_async", "hello", "langchain",