Skip to content

feat: add tenant-id to lambda context and structured log message #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def handle_event_request(
cognito_identity_json,
invoked_function_arn,
epoch_deadline_time_in_ms,
tenant_id,
log_sink,
):
error_result = None
Expand All @@ -168,6 +169,7 @@ def handle_event_request(
epoch_deadline_time_in_ms,
invoke_id,
invoked_function_arn,
tenant_id,
)
event = lambda_runtime_client.marshaller.unmarshal_request(
event_body, content_type
Expand Down Expand Up @@ -229,6 +231,7 @@ def create_lambda_context(
epoch_deadline_time_in_ms,
invoke_id,
invoked_function_arn,
tenant_id,
):
client_context = None
if client_context_json:
Expand All @@ -243,6 +246,7 @@ def create_lambda_context(
cognito_identity,
epoch_deadline_time_in_ms,
invoked_function_arn,
tenant_id,
)


Expand Down Expand Up @@ -337,6 +341,7 @@ def emit(self, record):
class LambdaLoggerFilter(logging.Filter):
def filter(self, record):
record.aws_request_id = _GLOBAL_AWS_REQUEST_ID or ""
record.tenant_id = _GLOBAL_TENANT_ID
return True


Expand Down Expand Up @@ -445,6 +450,7 @@ def create_log_sink():


_GLOBAL_AWS_REQUEST_ID = None
_GLOBAL_TENANT_ID = None


def _setup_logging(log_format, log_level, log_sink):
Expand Down Expand Up @@ -490,7 +496,7 @@ def run(app_root, handler, lambda_runtime_api_addr):

try:
_setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)
global _GLOBAL_AWS_REQUEST_ID
global _GLOBAL_AWS_REQUEST_ID, _GLOBAL_TENANT_ID

request_handler = _get_handler(handler)
except FaultException as e:
Expand All @@ -515,6 +521,7 @@ def run(app_root, handler, lambda_runtime_api_addr):
event_request = lambda_runtime_client.wait_next_invocation()

_GLOBAL_AWS_REQUEST_ID = event_request.invoke_id
_GLOBAL_TENANT_ID = event_request.tenant_id

update_xray_env_variable(event_request.x_amzn_trace_id)

Expand All @@ -528,5 +535,6 @@ def run(app_root, handler, lambda_runtime_api_addr):
event_request.cognito_identity,
event_request.invoked_function_arn,
event_request.deadline_time_in_ms,
event_request.tenant_id,
log_sink,
)
5 changes: 4 additions & 1 deletion awslambdaric/lambda_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
cognito_identity,
epoch_deadline_time_in_ms,
invoked_function_arn=None,
tenant_id=None,
):
self.aws_request_id = invoke_id
self.log_group_name = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME")
Expand All @@ -24,6 +25,7 @@ def __init__(
self.memory_limit_in_mb = os.environ.get("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
self.function_version = os.environ.get("AWS_LAMBDA_FUNCTION_VERSION")
self.invoked_function_arn = invoked_function_arn
self.tenant_id = tenant_id

self.client_context = make_obj_from_dict(ClientContext, client_context)
if self.client_context is not None:
Expand Down Expand Up @@ -65,7 +67,8 @@ def __repr__(self):
f"function_version={self.function_version},"
f"invoked_function_arn={self.invoked_function_arn},"
f"client_context={self.client_context},"
f"identity={self.identity}"
f"identity={self.identity},"
f"tenant_id={self.tenant_id}"
"])"
)

Expand Down
1 change: 1 addition & 0 deletions awslambdaric/lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def wait_next_invocation(self):
deadline_time_in_ms=headers.get("Lambda-Runtime-Deadline-Ms"),
client_context=headers.get("Lambda-Runtime-Client-Context"),
cognito_identity=headers.get("Lambda-Runtime-Cognito-Identity"),
tenant_id=headers.get("Lambda-Runtime-Aws-Tenant-Id"),
content_type=headers.get("Content-Type"),
event_body=response_body,
)
Expand Down
4 changes: 4 additions & 0 deletions awslambdaric/lambda_runtime_log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"processName",
"process",
"aws_request_id",
"tenant_id",
"_frame_type",
}

Expand Down Expand Up @@ -124,6 +125,9 @@ def format(self, record: logging.LogRecord) -> str:
"requestId": getattr(record, "aws_request_id", None),
"location": self.__format_location(record),
}
if hasattr(record, "tenant_id") and record.tenant_id is not None:
result["tenantId"] = record.tenant_id

result.update(
(key, value)
for key, value in record.__dict__.items()
Expand Down
6 changes: 4 additions & 2 deletions awslambdaric/runtime_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,19 @@ static PyObject *method_next(PyObject *self) {
auto client_context = response.client_context.c_str();
auto content_type = response.content_type.c_str();
auto cognito_id = response.cognito_identity.c_str();
auto tenant_id = response.tenant_id.c_str();

PyObject *payload_bytes = PyBytes_FromStringAndSize(payload.c_str(), payload.length());
PyObject *result = Py_BuildValue("(O,{s:s,s:s,s:s,s:l,s:s,s:s,s:s})",
PyObject *result = Py_BuildValue("(O,{s:s,s:s,s:s,s:l,s:s,s:s,s:s,s:s})",
payload_bytes, //Py_BuildValue() increments reference counter
"Lambda-Runtime-Aws-Request-Id", request_id,
"Lambda-Runtime-Trace-Id", NULL_IF_EMPTY(trace_id),
"Lambda-Runtime-Invoked-Function-Arn", function_arn,
"Lambda-Runtime-Deadline-Ms", deadline,
"Lambda-Runtime-Client-Context", NULL_IF_EMPTY(client_context),
"Content-Type", NULL_IF_EMPTY(content_type),
"Lambda-Runtime-Cognito-Identity", NULL_IF_EMPTY(cognito_id)
"Lambda-Runtime-Cognito-Identity", NULL_IF_EMPTY(cognito_id),
"Lambda-Runtime-Aws-Tenant-Id", NULL_IF_EMPTY(tenant_id)
);

Py_XDECREF(payload_bytes);
Expand Down
Binary file modified deps/aws-lambda-cpp-0.2.6.tar.gz
Binary file not shown.
39 changes: 39 additions & 0 deletions deps/patches/aws-lambda-cpp-add-tenant-id.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
diff --git a/include/aws/lambda-runtime/runtime.h b/include/aws/lambda-runtime/runtime.h
index 7812ff6..96be869 100644
--- a/include/aws/lambda-runtime/runtime.h
+++ b/include/aws/lambda-runtime/runtime.h
@@ -61,6 +61,11 @@ struct invocation_request {
*/
std::string content_type;

+ /**
+ * The Tenant ID of the current invocation.
+ */
+ std::string tenant_id;
+
/**
* Function execution deadline counted in milliseconds since the Unix epoch.
*/
diff --git a/src/runtime.cpp b/src/runtime.cpp
index e53b2b8..9763282 100644
--- a/src/runtime.cpp
+++ b/src/runtime.cpp
@@ -40,6 +40,7 @@ static constexpr auto CLIENT_CONTEXT_HEADER = "lambda-runtime-client-context";
static constexpr auto COGNITO_IDENTITY_HEADER = "lambda-runtime-cognito-identity";
static constexpr auto DEADLINE_MS_HEADER = "lambda-runtime-deadline-ms";
static constexpr auto FUNCTION_ARN_HEADER = "lambda-runtime-invoked-function-arn";
+static constexpr auto TENANT_ID_HEADER = "lambda-runtime-aws-tenant-id";

enum Endpoints {
INIT,
@@ -289,6 +290,10 @@ runtime::next_outcome runtime::get_next()
req.function_arn = resp.get_header(FUNCTION_ARN_HEADER);
}

+ if (resp.has_header(TENANT_ID_HEADER)) {
+ req.tenant_id = resp.get_header(TENANT_ID_HEADER);
+ }
+
if (resp.has_header(DEADLINE_MS_HEADER)) {
auto const& deadline_string = resp.get_header(DEADLINE_MS_HEADER);
constexpr int base = 10;
3 changes: 2 additions & 1 deletion scripts/update_deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ wget -c https://github.com/awslabs/aws-lambda-cpp/archive/v$AWS_LAMBDA_CPP_RELEA
patch -p1 < ../patches/aws-lambda-cpp-posting-init-errors.patch && \
patch -p1 < ../patches/aws-lambda-cpp-make-the-runtime-client-user-agent-overrideable.patch && \
patch -p1 < ../patches/aws-lambda-cpp-make-lto-optional.patch && \
patch -p1 < ../patches/aws-lambda-cpp-add-content-type.patch
patch -p1 < ../patches/aws-lambda-cpp-add-content-type.patch && \
patch -p1 < ../patches/aws-lambda-cpp-add-tenant-id.patch
)

## Pack again and remove the folder
Expand Down
44 changes: 44 additions & 0 deletions tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_handle_event_request_happy_case(self):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
self.lambda_runtime.post_invocation_result.assert_called_once_with(
Expand All @@ -111,6 +112,7 @@ def test_handle_event_request_invalid_client_context(self):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
args, _ = self.lambda_runtime.post_invocation_error.call_args
Expand Down Expand Up @@ -152,6 +154,7 @@ def test_handle_event_request_invalid_cognito_idenity(self):
"invalid_cognito_identity",
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
args, _ = self.lambda_runtime.post_invocation_error.call_args
Expand Down Expand Up @@ -194,6 +197,7 @@ def test_handle_event_request_invalid_event_body(self):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
args, _ = self.lambda_runtime.post_invocation_error.call_args
Expand Down Expand Up @@ -238,6 +242,7 @@ def invalid_json_response(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
args, _ = self.lambda_runtime.post_invocation_error.call_args
Expand Down Expand Up @@ -283,6 +288,7 @@ def __init__(self, message):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
args, _ = self.lambda_runtime.post_invocation_error.call_args
Expand Down Expand Up @@ -335,6 +341,7 @@ def __init__(self, message):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
args, _ = self.lambda_runtime.post_invocation_error.call_args
Expand Down Expand Up @@ -386,6 +393,7 @@ def unable_to_import_module(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
args, _ = self.lambda_runtime.post_invocation_error.call_args
Expand Down Expand Up @@ -425,6 +433,7 @@ def raise_exception_handler(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
args, _ = self.lambda_runtime.post_invocation_error.call_args
Expand Down Expand Up @@ -475,6 +484,7 @@ def raise_exception_handler(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)

Expand Down Expand Up @@ -514,6 +524,7 @@ def raise_exception_handler(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
error_logs = (
Expand Down Expand Up @@ -546,6 +557,7 @@ def raise_exception_handler(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
error_logs = (
Expand Down Expand Up @@ -578,6 +590,7 @@ def raise_exception_handler(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
error_logs = (
Expand Down Expand Up @@ -619,6 +632,7 @@ def raise_exception_handler(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)
error_logs = lambda_unhandled_exception_warning_message + "\n[ERROR]\r"
Expand Down Expand Up @@ -652,6 +666,7 @@ def raise_exception_handler(json_input, lambda_context):
{},
"invoked_function_arn",
0,
"tenant_id",
bootstrap.StandardLogSink(),
)

Expand Down Expand Up @@ -868,6 +883,7 @@ def test_application_json(self):
cognito_identity_json=None,
invoked_function_arn="invocation-arn",
epoch_deadline_time_in_ms=1415836801003,
tenant_id=None,
log_sink=bootstrap.StandardLogSink(),
)

Expand All @@ -887,6 +903,7 @@ def test_binary_request_binary_response(self):
cognito_identity_json=None,
invoked_function_arn="invocation-arn",
epoch_deadline_time_in_ms=1415836801003,
tenant_id=None,
log_sink=bootstrap.StandardLogSink(),
)

Expand All @@ -906,6 +923,7 @@ def test_json_request_binary_response(self):
cognito_identity_json=None,
invoked_function_arn="invocation-arn",
epoch_deadline_time_in_ms=1415836801003,
tenant_id=None,
log_sink=bootstrap.StandardLogSink(),
)

Expand All @@ -924,6 +942,7 @@ def test_binary_with_application_json(self):
cognito_identity_json=None,
invoked_function_arn="invocation-arn",
epoch_deadline_time_in_ms=1415836801003,
tenant_id=None,
log_sink=bootstrap.StandardLogSink(),
)

Expand Down Expand Up @@ -1357,6 +1376,31 @@ def test_json_formatter(self, mock_stderr):
)
self.assertEqual(mock_stderr.getvalue(), "")

@patch("awslambdaric.bootstrap._GLOBAL_TENANT_ID", "test-tenant-id")
@patch("sys.stderr", new_callable=StringIO)
def test_json_formatter_with_tenant_id(self, mock_stderr):
logger = logging.getLogger("a.b")
level = logging.INFO
message = "Test json formatting with tenant id"
expected = {
"level": "INFO",
"logger": "a.b",
"message": message,
"requestId": "",
"tenantId": "test-tenant-id",
}

with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
logger.log(level, message)

data = json.loads(mock_stdout.getvalue())
data.pop("timestamp")
self.assertEqual(
data,
expected,
)
self.assertEqual(mock_stderr.getvalue(), "")

@patch("sys.stdout", new_callable=StringIO)
@patch("sys.stderr", new_callable=StringIO)
def test_exception(self, mock_stderr, mock_stdout):
Expand Down
Loading