Skip to content

Implement multi-namespace search attribute translation #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion interceptor/access_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ func createNamespaceAccessControl(access *auth.AccessControl) stringMatcher {
}

func isNamespaceAccessAllowed(obj any, access *auth.AccessControl) (bool, error) {
notAllowed, err := visitNamespace(obj, createNamespaceAccessControl(access))
v := NewNamespaceVisitor(createNamespaceAccessControl(access))
notAllowed, err := v.Visit(obj)
if err != nil {
return false, err
}
Expand Down
17 changes: 13 additions & 4 deletions interceptor/namespace_translator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,19 @@ type (
)

func TestTranslateNamespaceName(t *testing.T) {
testTranslateObj(t, visitNamespace, generateNamespaceObjCases(), require.Equal)
testTranslateObj(t, generateNamespaceObjCases(), require.Equal,
func(m map[string]string) Visitor {
return NewNamespaceVisitor(createStringMatcher(m))
},
)
}

func TestTranslateNamespaceReplicationMessages(t *testing.T) {
testTranslateObj(t, visitNamespace, generateNamespaceReplicationMessages(), require.EqualExportedValues)
testTranslateObj(t, generateNamespaceReplicationMessages(), require.EqualExportedValues,
func(m map[string]string) Visitor {
return NewNamespaceVisitor(createStringMatcher(m))
},
)
}

func generateNamespaceObjCases() []objCase {
Expand Down Expand Up @@ -497,9 +505,9 @@ func generateNamespaceReplicationMessages() []objCase {
// handle pointer cycles.
func testTranslateObj(
t *testing.T,
visitor visitor,
objCases []objCase,
equalityAssertion func(t require.TestingT, exp, actual any, extra ...any),
createVisitor func(map[string]string) Visitor,
) {
testcases := []struct {
testName string
Expand Down Expand Up @@ -540,7 +548,8 @@ func testTranslateObj(
expOutput := c.makeType(ts.outputName)
expChanged := ts.inputName != ts.outputName

changed, err := visitor(input, createStringMatcher(ts.mapping))
visitor := createVisitor(ts.mapping)
changed, err := visitor.Visit(input)
if len(c.expError) != 0 {
require.ErrorContains(t, err, c.expError)
} else {
Expand Down
107 changes: 81 additions & 26 deletions interceptor/reflection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"go.temporal.io/server/common/persistence/serialization"
)

const (
namespaceIDFieldName = "NamespaceId"
)

var (
serializer = serialization.NewSerializer()

Expand Down Expand Up @@ -40,19 +44,47 @@ var (
}
)

// stringMatcher returns 2 values:
// 1. new name. If there is no change, new name equals to input name
// 2. whether or not the input name matches the defined rule(s).
type stringMatcher func(name string) (string, bool)
type (
// Visitor will visits an object's fields recursively. It returns an
// implementation-specific bool and error, which typicall indicate if it
// matched anything and if it encountered an unrecoverable error.
Visitor interface {
Visit(any) (bool, error)
}

// visitNamespace uses reflection to recursively visit all fields in the
// given object. When it finds namespace string fields, it invokes the match
// function.
nsVisitor struct {
match stringMatcher
}

// saVisitor uses reflection to recursively visit search attribute fields in the given object.
// It translates search attribute fields according to per-namespace search attribute mappings.
//
// This is not concurrent safe. You must create a separate struct each time.
saVisitor struct {
getNamespaceSAMatcher getSAMatcher

// visitor visits each field in obj matching the matcher.
// It returns whether anything was matched and any error it encountered.
type visitor func(obj any, match stringMatcher) (bool, error)
// currentNamespaceId is internal-state to remember the namespace id set in some parent
// field as the visitor descends recursively into child fields.
currentNamespaceId string
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switch to a Visitor struct so that we can track currentNamespaceId as we recursively descend into struct.

}

// stringMatcher returns 2 values:
// 1. new name. If there is no change, new name equals to input name
// 2. whether or not the input name matches the defined rule(s).
stringMatcher func(name string) (string, bool)

// getSAMatcher returns a string matcher for a given namespace's search attribute mapping
getSAMatcher func(nsId string) stringMatcher
)

func NewNamespaceVisitor(match stringMatcher) Visitor {
return &nsVisitor{match: match}
}

// visitNamespace uses reflection to recursively visit all fields
// in the given object. When it finds namespace string fields, it invokes
// the provided match function.
func visitNamespace(obj any, match stringMatcher) (bool, error) {
func (v *nsVisitor) Visit(obj any) (bool, error) {
var matched bool

// The visitor function can return Skip, Stop, or Continue to control recursion.
Expand All @@ -65,7 +97,7 @@ func visitNamespace(obj any, match stringMatcher) (bool, error) {

if info, ok := vwp.Interface().(*namespace.NamespaceInfo); ok && info != nil {
// Handle NamespaceInfo.Name in any message.
newName, ok := match(info.Name)
newName, ok := v.match(info.Name)
if !ok {
return visit.Continue, nil
}
Expand All @@ -74,7 +106,7 @@ func visitNamespace(obj any, match stringMatcher) (bool, error) {
}
matched = matched || ok
} else if dataBlobFieldNames[fieldType.Name] {
changed, err := visitDataBlobs(vwp, match, visitNamespace)
changed, err := visitDataBlobs(vwp, v)
matched = matched || changed
if err != nil {
return visit.Stop, err
Expand All @@ -84,7 +116,7 @@ func visitNamespace(obj any, match stringMatcher) (bool, error) {
if !ok {
return visit.Continue, nil
}
newName, ok := match(name)
newName, ok := v.match(name)
if !ok {
return visit.Continue, nil
}
Expand All @@ -101,10 +133,11 @@ func visitNamespace(obj any, match stringMatcher) (bool, error) {
return matched, err
}

// visitSearchAttributes uses reflection to recursively visit all fields
// in the given object. When it finds namespace string fields, it invokes
// the provided match function.
func visitSearchAttributes(obj any, match stringMatcher) (bool, error) {
func MakeSearchAttributeVisitor(getNsSearchAttr getSAMatcher) saVisitor {
return saVisitor{getNamespaceSAMatcher: getNsSearchAttr}
}

func (v *saVisitor) Visit(obj any) (bool, error) {
var matched bool

// The visitor function can return Skip, Stop, or Continue to control recursion.
Expand All @@ -115,13 +148,24 @@ func visitSearchAttributes(obj any, match stringMatcher) (bool, error) {
return action, nil
}

nsId := discoverNamespaceId(vwp)
if nsId != "" {
v.currentNamespaceId = nsId
Comment on lines +151 to +153
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check for and remember namespace id on each type as we descend.

}

if dataBlobFieldNames[fieldType.Name] {
changed, err := visitDataBlobs(vwp, match, visitSearchAttributes)
changed, err := visitDataBlobs(vwp, v)
matched = matched || changed
if err != nil {
return visit.Stop, err
}
} else if searchAttributeFieldNames[fieldType.Name] {
// Get the per-namespace search attribute mapping
match := v.getNamespaceSAMatcher(v.currentNamespaceId)
if match == nil {
return visit.Continue, nil
}

// This could be *common.SearchAttributes, or it could be map[string]*common.Payload (indexed fields)
var changed bool
switch attrs := vwp.Interface().(type) {
Expand All @@ -148,6 +192,17 @@ func visitSearchAttributes(obj any, match stringMatcher) (bool, error) {
return matched, err
}

func discoverNamespaceId(vwp visit.ValueWithParent) string {
parent := vwp.Parent
if parent.Kind() == reflect.Struct {
typ, ok := parent.Type().FieldByName(namespaceIDFieldName)
if ok && typ.Type.Kind() == reflect.String {
return parent.FieldByName(namespaceIDFieldName).String()
}
}
return ""
}

func translateIndexedFields(fields map[string]*common.Payload, match stringMatcher) (map[string]*common.Payload, bool) {
if fields == nil {
return fields, false
Expand Down Expand Up @@ -178,10 +233,10 @@ func getParentFieldType(vwp visit.ValueWithParent) (result reflect.StructField,
return fieldType, action
}

func visitDataBlobs(vwp visit.ValueWithParent, match stringMatcher, visitor visitor) (bool, error) {
func visitDataBlobs(vwp visit.ValueWithParent, v Visitor) (bool, error) {
switch evt := vwp.Interface().(type) {
case []*common.DataBlob:
newEvts, matched, err := translateDataBlobs(match, visitor, evt...)
newEvts, matched, err := translateDataBlobs(v, evt...)
if err != nil {
return matched, err
}
Expand All @@ -192,7 +247,7 @@ func visitDataBlobs(vwp visit.ValueWithParent, match stringMatcher, visitor visi
}
return matched, nil
case *common.DataBlob:
newEvt, matched, err := translateOneDataBlob(match, visitor, evt)
newEvt, matched, err := translateOneDataBlob(v, evt)
if err != nil {
return matched, err
}
Expand All @@ -207,10 +262,10 @@ func visitDataBlobs(vwp visit.ValueWithParent, match stringMatcher, visitor visi
}
}

func translateDataBlobs(match stringMatcher, visitor visitor, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) {
func translateDataBlobs(visitor Visitor, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) {
var anyChanged bool
for i, blob := range blobs {
newBlob, changed, err := translateOneDataBlob(match, visitor, blob)
newBlob, changed, err := translateOneDataBlob(visitor, blob)
anyChanged = anyChanged || changed
if err != nil {
return blobs, anyChanged, err
Expand All @@ -220,7 +275,7 @@ func translateDataBlobs(match stringMatcher, visitor visitor, blobs ...*common.D
return blobs, anyChanged, nil
}

func translateOneDataBlob(match stringMatcher, visitor visitor, blob *common.DataBlob) (*common.DataBlob, bool, error) {
func translateOneDataBlob(visitor Visitor, blob *common.DataBlob) (*common.DataBlob, bool, error) {
if blob == nil || len(blob.Data) == 0 {
return blob, false, nil

Expand All @@ -230,7 +285,7 @@ func translateOneDataBlob(match stringMatcher, visitor visitor, blob *common.Dat
return blob, false, err
}

changed, err := visitor(evt, match)
changed, err := visitor.Visit(evt)
if err != nil || !changed {
return blob, changed, err
}
Expand Down
4 changes: 2 additions & 2 deletions interceptor/reflection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ func BenchmarkVisitNamespace(b *testing.B) {
for _, c := range cases {
b.Run(c.objName, func(b *testing.B) {
for _, variant := range variants {
translator := createStringMatcher(variant.mapping)
visitor := NewNamespaceVisitor(createStringMatcher(variant.mapping))
b.Run(variant.testName, func(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
input := c.makeType(variant.inputNSName)

b.StartTimer()
_, _ = visitNamespace(input, translator)
_, _ = visitor.Visit(input)
}
})
}
Expand Down
27 changes: 14 additions & 13 deletions interceptor/search_attribute_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package interceptor
import (
"strings"

"go.temporal.io/server/api/adminservice/v1"
"go.temporal.io/server/common/api"
)

Expand Down Expand Up @@ -31,27 +32,27 @@ func (s *saTranslator) MatchMethod(m string) bool {
}

func (s *saTranslator) TranslateRequest(req any) (bool, error) {
return visitSearchAttributes(req, s.getNamespaceReqMatcher(""))
v := MakeSearchAttributeVisitor(s.getNamespaceReqMatcher)
return v.Visit(req)
}

func (s *saTranslator) TranslateResponse(resp any) (bool, error) {
return visitSearchAttributes(resp, s.getNamespaceRespMatcher(""))
func (s *saTranslator) TranslateResponse(req, resp any) (bool, error) {
// Detect namespace id in GetWorkflowExecutionRawHistoryV2Request.
// Use that namespace id to translate search attributes in the response type.
v := MakeSearchAttributeVisitor(s.getNamespaceRespMatcher)
switch val := req.(type) {
case *adminservice.GetWorkflowExecutionRawHistoryV2Request:
v.currentNamespaceId = val.NamespaceId
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Special handling for GetWorkflowExecutionRawHistoryV2Response. Carry over the namespace id found in the corresponding request.

}
return v.Visit(resp)
}

func (s *saTranslator) getNamespaceReqMatcher(namespaceId string) stringMatcher {
// Placeholder: Just return the first one (only support one namespace mapping)
for _, matcher := range s.reqMap {
return matcher
}
return createStringMatcher(nil)
return s.reqMap[namespaceId]
}

func (s *saTranslator) getNamespaceRespMatcher(namespaceId string) stringMatcher {
// Placeholder: Just return the first one (only support one namespace mappping)
for _, matcher := range s.respMap {
return matcher
}
return createStringMatcher(nil)
return s.respMap[namespaceId]
}

func createStringMatchers(nsMappings map[string]map[string]string) map[string]stringMatcher {
Expand Down
21 changes: 17 additions & 4 deletions interceptor/search_attribute_translator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,23 @@ import (
)

func TestTranslateSearchAttribute(t *testing.T) {
testTranslateObj(t, visitSearchAttributes, generateSearchAttributeObjs(), require.EqualExportedValues)
namespaceId := "ns-1234"
testTranslateObj(t, generateSearchAttributeObjs(namespaceId), require.EqualExportedValues,
func(mapping map[string]string) Visitor {
v := MakeSearchAttributeVisitor(
func(nsId string) stringMatcher {
if nsId != namespaceId {
return nil
}
return createStringMatcher(mapping)
},
)
return &v
},
)
}

func generateSearchAttributeObjs() []objCase {
func generateSearchAttributeObjs(nsId string) []objCase {
return []objCase{
{
objName: "HistoryTaskAttributes",
Expand All @@ -32,7 +45,7 @@ func generateSearchAttributeObjs() []objCase {
{
Attributes: &replicationspb.ReplicationTask_HistoryTaskAttributes{
HistoryTaskAttributes: &replicationspb.HistoryTaskAttributes{
NamespaceId: "some-ns-id",
NamespaceId: nsId,
WorkflowId: "some-wf-id",
RunId: "some-run-id",
Events: makeHistoryEventsBlobWithSearchAttribute(name),
Expand Down Expand Up @@ -62,7 +75,7 @@ func generateSearchAttributeObjs() []objCase {
SyncWorkflowStateMutationAttributes: &replicationspb.SyncWorkflowStateMutationAttributes{
StateMutation: &persistence.WorkflowMutableStateMutation{
ExecutionInfo: &persistence.WorkflowExecutionInfo{
NamespaceId: "some-ns",
NamespaceId: nsId,
WorkflowId: "some-wf",
SearchAttributes: makeTestIndexedFieldMap(name),
Memo: map[string]*common.Payload{
Expand Down
4 changes: 2 additions & 2 deletions interceptor/translation_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (i *TranslationInterceptor) Intercept(

for _, tr := range i.translators {
if tr.MatchMethod(info.FullMethod) {
changed, trErr := tr.TranslateResponse(resp)
changed, trErr := tr.TranslateResponse(req, resp)
logTranslateResult(i.logger, changed, trErr, methodName+"Response", resp)
}
}
Expand Down Expand Up @@ -98,7 +98,7 @@ func (w *streamTranslator) RecvMsg(m any) error {
func (w *streamTranslator) SendMsg(m any) error {
w.logger.Debug("Intercept SendMsg", tag.NewStringTag("type", fmt.Sprintf("%T", m)), tag.NewAnyTag("message", m))
for _, tr := range w.translators {
changed, trErr := tr.TranslateResponse(m)
changed, trErr := tr.TranslateResponse(nil, m)
logTranslateResult(w.logger, changed, trErr, "SendMsg", m)
}
return w.ServerStream.SendMsg(m)
Expand Down
Loading