Skip to content

Commit a9f3fcc

Browse files
feat(core): Expose context authn methods (#1812)
### Proposed Changes * enable setting context info and retrieving token/jwk from context from additional services added on platform extension ### Checklist - [x] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions
1 parent 7828aeb commit a9f3fcc

File tree

6 files changed

+150
-66
lines changed

6 files changed

+150
-66
lines changed

service/internal/auth/authn.go

Lines changed: 3 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,10 @@ import (
2424

2525
sdkAudit "github.com/opentdf/platform/sdk/audit"
2626
"github.com/opentdf/platform/service/logger"
27-
)
2827

29-
const (
30-
authnContextKey = authContextKey("dpop-jwk")
28+
ctxAuth "github.com/opentdf/platform/service/pkg/auth"
3129
)
3230

33-
type authContextKey string
34-
35-
type authContext struct {
36-
key jwk.Key
37-
accessToken jwt.Token
38-
rawToken string
39-
}
40-
4131
var (
4232
// Set of allowed public endpoints that do not require authentication
4333
allowedPublicEndpoints = [...]string{
@@ -394,61 +384,18 @@ func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpo
394384
if !tokenHasCNF && !a.enforceDPoP {
395385
// this condition is not quite tight because it's possible that the `cnf` claim may
396386
// come from token introspection
397-
ctx = ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw)
387+
ctx = ctxAuth.ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw)
398388
return accessToken, ctx, nil
399389
}
400390
key, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
401391
if err != nil {
402392
a.logger.Warn("failed to validate dpop", slog.String("token", tokenRaw), slog.Any("err", err))
403393
return nil, nil, err
404394
}
405-
ctx = ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw)
395+
ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw)
406396
return accessToken, ctx, nil
407397
}
408398

409-
func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Token, raw string) context.Context {
410-
return context.WithValue(ctx, authnContextKey, &authContext{
411-
key,
412-
accessToken,
413-
raw,
414-
})
415-
}
416-
417-
func getContextDetails(ctx context.Context, l *logger.Logger) *authContext {
418-
key := ctx.Value(authnContextKey)
419-
if key == nil {
420-
return nil
421-
}
422-
if c, ok := key.(*authContext); ok {
423-
return c
424-
}
425-
426-
// We should probably return an error here?
427-
l.ErrorContext(ctx, "invalid authContext")
428-
return nil
429-
}
430-
431-
func GetJWKFromContext(ctx context.Context, l *logger.Logger) jwk.Key {
432-
if c := getContextDetails(ctx, l); c != nil {
433-
return c.key
434-
}
435-
return nil
436-
}
437-
438-
func GetAccessTokenFromContext(ctx context.Context, l *logger.Logger) jwt.Token {
439-
if c := getContextDetails(ctx, l); c != nil {
440-
return c.accessToken
441-
}
442-
return nil
443-
}
444-
445-
func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string {
446-
if c := getContextDetails(ctx, l); c != nil {
447-
return c.rawToken
448-
}
449-
return ""
450-
}
451-
452399
func (a Authentication) validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo receiverInfo, headers []string) (jwk.Key, error) {
453400
if len(headers) != 1 {
454401
return nil, fmt.Errorf("got %d dpop headers, should have 1", len(headers))

service/internal/auth/authn_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
sdkauth "github.com/opentdf/platform/sdk/auth"
3232
"github.com/opentdf/platform/service/internal/server/memhttp"
3333
"github.com/opentdf/platform/service/logger"
34+
ctxAuth "github.com/opentdf/platform/service/pkg/auth"
3435
"github.com/stretchr/testify/assert"
3536
"github.com/stretchr/testify/require"
3637
"github.com/stretchr/testify/suite"
@@ -69,7 +70,7 @@ func (f *FakeAccessServiceServer) LegacyPublicKey(_ context.Context, _ *connect.
6970

7071
func (f *FakeAccessServiceServer) Rewrap(ctx context.Context, req *connect.Request[kas.RewrapRequest]) (*connect.Response[kas.RewrapResponse], error) {
7172
f.accessToken = req.Header()["Authorization"]
72-
f.dpopKey = GetJWKFromContext(ctx, logger.CreateTestLogger())
73+
f.dpopKey = ctxAuth.GetJWKFromContext(ctx, logger.CreateTestLogger())
7374

7475
return &connect.Response[kas.RewrapResponse]{Msg: &kas.RewrapResponse{}}, nil
7576
}
@@ -512,7 +513,7 @@ func (s *AuthSuite) TestDPoPEndToEnd_HTTP() {
512513
timeout <- ""
513514
}()
514515
server := httptest.NewServer(s.auth.MuxHandler(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
515-
jwkChan <- GetJWKFromContext(req.Context(), logger.CreateTestLogger())
516+
jwkChan <- ctxAuth.GetJWKFromContext(req.Context(), logger.CreateTestLogger())
516517
})))
517518
defer server.Close()
518519

@@ -638,7 +639,7 @@ func (s *AuthSuite) Test_Allowing_Auth_With_No_DPoP() {
638639

639640
_, ctx, err := auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
640641
s.Require().NoError(err)
641-
s.Require().Nil(GetJWKFromContext(ctx, logger.CreateTestLogger()))
642+
s.Require().Nil(ctxAuth.GetJWKFromContext(ctx, logger.CreateTestLogger()))
642643
}
643644

644645
func (s *AuthSuite) Test_PublicPath_Matches() {

service/kas/access/rewrap.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ import (
3030

3131
kaspb "github.com/opentdf/platform/protocol/go/kas"
3232
"github.com/opentdf/platform/sdk"
33-
"github.com/opentdf/platform/service/internal/auth"
3433
"github.com/opentdf/platform/service/internal/security"
3534
"github.com/opentdf/platform/service/logger"
3635
"github.com/opentdf/platform/service/logger/audit"
36+
ctxAuth "github.com/opentdf/platform/service/pkg/auth"
3737
"google.golang.org/grpc/codes"
3838
"google.golang.org/grpc/status"
3939
)
@@ -128,7 +128,7 @@ func extractSRTBody(ctx context.Context, headers http.Header, in *kaspb.RewrapRe
128128
}
129129

130130
// get dpop public key from context
131-
dpopJWK := auth.GetJWKFromContext(ctx, &logger)
131+
dpopJWK := ctxAuth.GetJWKFromContext(ctx, &logger)
132132

133133
var err error
134134
var rbString string
@@ -247,7 +247,7 @@ func verifyAndParsePolicy(ctx context.Context, requestBody *RequestBody, k []byt
247247
func getEntityInfo(ctx context.Context, logger *logger.Logger) (*entityInfo, error) {
248248
info := new(entityInfo)
249249

250-
token := auth.GetAccessTokenFromContext(ctx, logger)
250+
token := ctxAuth.GetAccessTokenFromContext(ctx, logger)
251251
if token == nil {
252252
return nil, err401("missing access token")
253253
}
@@ -263,7 +263,7 @@ func getEntityInfo(ctx context.Context, logger *logger.Logger) (*entityInfo, err
263263
logger.WarnContext(ctx, "missing sub")
264264
}
265265

266-
info.Token = auth.GetRawAccessTokenFromContext(ctx, logger)
266+
info.Token = ctxAuth.GetRawAccessTokenFromContext(ctx, logger)
267267

268268
return info, nil
269269
}

service/kas/access/rewrap_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ import (
1717
"github.com/lestrrat-go/jwx/v2/jws"
1818
"github.com/lestrrat-go/jwx/v2/jwt"
1919
"github.com/opentdf/platform/lib/ocrypto"
20-
"github.com/opentdf/platform/service/internal/auth"
2120
"github.com/opentdf/platform/service/logger"
21+
ctxAuth "github.com/opentdf/platform/service/pkg/auth"
2222
"github.com/stretchr/testify/assert"
2323
"github.com/stretchr/testify/require"
2424

@@ -328,7 +328,7 @@ func TestParseAndVerifyRequest(t *testing.T) {
328328
require.NoError(t, err, "couldn't get JWK from key")
329329
err = key.Set(jwk.AlgorithmKey, jwa.RS256) // Check the error return value
330330
require.NoError(t, err, "failed to set algorithm key")
331-
ctx = auth.ContextWithAuthNInfo(ctx, key, mockJWT(t), bearer)
331+
ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, mockJWT(t), bearer)
332332
}
333333

334334
md := metadata.New(map[string]string{"token": bearer})
@@ -370,7 +370,7 @@ func Test_SignedRequestBody_When_Bad_Signature_Expect_Failure(t *testing.T) {
370370

371371
err = key.Set(jwk.AlgorithmKey, jwa.NoSignature)
372372
require.NoError(t, err, "failed to set algorithm key")
373-
ctx = auth.ContextWithAuthNInfo(ctx, key, mockJWT(t), string(jwtStandard(t)))
373+
ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, mockJWT(t), string(jwtStandard(t)))
374374

375375
md := metadata.New(map[string]string{"token": string(jwtWrongKey(t))})
376376
ctx = metadata.NewIncomingContext(ctx, md)

service/pkg/auth/context_auth.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
6+
"github.com/lestrrat-go/jwx/v2/jwk"
7+
"github.com/lestrrat-go/jwx/v2/jwt"
8+
"github.com/opentdf/platform/service/logger"
9+
)
10+
11+
var (
12+
authnContextKey = authContextKey{}
13+
)
14+
15+
type authContextKey struct{}
16+
17+
type authContext struct {
18+
key jwk.Key
19+
accessToken jwt.Token
20+
rawToken string
21+
}
22+
23+
func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Token, raw string) context.Context {
24+
return context.WithValue(ctx, authnContextKey, &authContext{
25+
key,
26+
accessToken,
27+
raw,
28+
})
29+
}
30+
31+
func getContextDetails(ctx context.Context, l *logger.Logger) *authContext {
32+
key := ctx.Value(authnContextKey)
33+
if key == nil {
34+
return nil
35+
}
36+
if c, ok := key.(*authContext); ok {
37+
return c
38+
}
39+
40+
// We should probably return an error here?
41+
l.ErrorContext(ctx, "invalid authContext")
42+
return nil
43+
}
44+
45+
func GetJWKFromContext(ctx context.Context, l *logger.Logger) jwk.Key {
46+
if c := getContextDetails(ctx, l); c != nil {
47+
return c.key
48+
}
49+
return nil
50+
}
51+
52+
func GetAccessTokenFromContext(ctx context.Context, l *logger.Logger) jwt.Token {
53+
if c := getContextDetails(ctx, l); c != nil {
54+
return c.accessToken
55+
}
56+
return nil
57+
}
58+
59+
func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string {
60+
if c := getContextDetails(ctx, l); c != nil {
61+
return c.rawToken
62+
}
63+
return ""
64+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/lestrrat-go/jwx/v2/jwk"
8+
"github.com/lestrrat-go/jwx/v2/jwt"
9+
"github.com/opentdf/platform/service/logger"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestContextWithAuthNInfo(t *testing.T) {
14+
// Create mock JWK, JWT, and raw token
15+
mockJWK, _ := jwk.FromRaw([]byte("mockKey"))
16+
mockJWT, _ := jwt.NewBuilder().Build()
17+
rawToken := "mockRawToken"
18+
19+
// Initialize context
20+
ctx := context.Background()
21+
newCtx := ContextWithAuthNInfo(ctx, mockJWK, mockJWT, rawToken)
22+
23+
// Assert that the context contains the correct values
24+
value := newCtx.Value(authnContextKey)
25+
testAuthContext, ok := value.(*authContext)
26+
assert.True(t, ok)
27+
assert.NotNil(t, testAuthContext)
28+
assert.Equal(t, mockJWK, testAuthContext.key, "JWK should match")
29+
assert.Equal(t, mockJWT, testAuthContext.accessToken, "JWT should match")
30+
assert.Equal(t, rawToken, testAuthContext.rawToken, "Raw token should match")
31+
}
32+
33+
func TestGetJWKFromContext(t *testing.T) {
34+
// Create mock context with JWK
35+
mockJWK, _ := jwk.FromRaw([]byte("mockKey"))
36+
ctx := ContextWithAuthNInfo(context.Background(), mockJWK, nil, "")
37+
38+
// Retrieve the JWK and assert
39+
retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger())
40+
assert.NotNil(t, retrievedJWK, "JWK should not be nil")
41+
assert.Equal(t, mockJWK, retrievedJWK, "Retrieved JWK should match the mock JWK")
42+
}
43+
44+
func TestGetAccessTokenFromContext(t *testing.T) {
45+
// Create mock context with JWT
46+
mockJWT, _ := jwt.NewBuilder().Build()
47+
ctx := ContextWithAuthNInfo(context.Background(), nil, mockJWT, "")
48+
49+
// Retrieve the JWT and assert
50+
retrievedJWT := GetAccessTokenFromContext(ctx, logger.CreateTestLogger())
51+
assert.NotNil(t, retrievedJWT, "Access token should not be nil")
52+
assert.Equal(t, mockJWT, retrievedJWT, "Retrieved JWT should match the mock JWT")
53+
}
54+
55+
func TestGetRawAccessTokenFromContext(t *testing.T) {
56+
// Create mock context with raw token
57+
rawToken := "mockRawToken"
58+
ctx := ContextWithAuthNInfo(context.Background(), nil, nil, rawToken)
59+
60+
// Retrieve the raw token and assert
61+
retrievedRawToken := GetRawAccessTokenFromContext(ctx, logger.CreateTestLogger())
62+
assert.Equal(t, rawToken, retrievedRawToken, "Retrieved raw token should match the mock raw token")
63+
}
64+
65+
func TestGetContextDetailsInvalidType(t *testing.T) {
66+
// Create a context with an invalid type
67+
ctx := context.WithValue(context.Background(), authnContextKey, "invalidType")
68+
69+
// Assert that GetJWKFromContext handles the invalid type correctly
70+
retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger())
71+
assert.Nil(t, retrievedJWK, "JWK should be nil when context value is invalid")
72+
}

0 commit comments

Comments
 (0)