Skip to content

Commit 077843f

Browse files
committed
[oidc] encode and validate state params
Using JWT tokens for encoding/decoding/validation of state params carried throughout the OIDC/OAuth2 flow. Validating of integrity is crucial, as this piece of information contains the ID of the OIDC client to continue with when Gitpod receives the callback from a 3rd party. Tests should show that expiration time is checked and signature validation is effective.
1 parent 901dcd2 commit 077843f

File tree

9 files changed

+183
-32
lines changed

9 files changed

+183
-32
lines changed

components/public-api-server/go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ require (
4242
github.com/felixge/httpsnoop v1.0.1 // indirect
4343
github.com/go-jose/go-jose/v3 v3.0.0 // indirect
4444
github.com/go-sql-driver/mysql v1.6.0 // indirect
45+
github.com/golang-jwt/jwt/v4 v4.4.3 // indirect
4546
github.com/golang/protobuf v1.5.2 // indirect
4647
github.com/gorilla/websocket v1.5.0 // indirect
4748
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect

components/public-api-server/go.sum

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

components/public-api-server/pkg/oidc/oauth2.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (s *Service) OAuth2Middleware(next http.Handler) http.Handler {
6666
return
6767
}
6868

69-
state, err := decodeStateParam(stateParam)
69+
state, err := s.decodeStateParam(stateParam)
7070
if err != nil {
7171
http.Error(rw, "bad state param", http.StatusBadRequest)
7272
return

components/public-api-server/pkg/oidc/router_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestRoute_start(t *testing.T) {
2222
idpUrl := newFakeIdP(t)
2323

2424
// setup test server with client routes
25-
baseUrl, _, configId := newTestServer(t, testServerParams{
25+
baseUrl, _, configId, _ := newTestServer(t, testServerParams{
2626
issuer: idpUrl,
2727
returnToURL: "",
2828
})
@@ -50,11 +50,11 @@ func TestRoute_callback(t *testing.T) {
5050
idpUrl := newFakeIdP(t)
5151

5252
// setup test server with client routes
53-
baseUrl, stateParam, _ := newTestServer(t, testServerParams{
53+
baseUrl, stateParam, _, service := newTestServer(t, testServerParams{
5454
issuer: idpUrl,
5555
returnToURL: "/relative/url/to/some/page",
5656
})
57-
state, err := encodeStateParam(*stateParam)
57+
state, err := service.encodeStateParam(*stateParam)
5858
require.NoError(t, err)
5959

6060
// hit the /callback endpoint
@@ -92,7 +92,7 @@ type testServerParams struct {
9292
clientID string
9393
}
9494

95-
func newTestServer(t *testing.T, params testServerParams) (url string, state *StateParam, configId string) {
95+
func newTestServer(t *testing.T, params testServerParams) (url string, state *StateParam, configId string, oidcService *Service) {
9696
router := chi.NewRouter()
9797
oidcService, dbConn := setupOIDCServiceForTests(t)
9898
router.Mount("/oidc", Router(oidcService))
@@ -124,5 +124,5 @@ func newTestServer(t *testing.T, params testServerParams) (url string, state *St
124124
ReturnToURL: params.returnToURL,
125125
}
126126

127-
return url, stateParam, configId
127+
return url, stateParam, configId, oidcService
128128
}

components/public-api-server/pkg/oidc/service.go

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"io"
1515
"io/ioutil"
1616
"net/http"
17-
"strings"
1817

1918
"github.com/coreos/go-oidc/v3/oidc"
2019
goidc "github.com/coreos/go-oidc/v3/oidc"
@@ -28,8 +27,9 @@ import (
2827
)
2928

3029
type Service struct {
31-
dbConn *gorm.DB
32-
cipher db.Cipher
30+
dbConn *gorm.DB
31+
cipher db.Cipher
32+
stateJWT *StateJWT
3333

3434
verifierByIssuer map[string]*goidc.IDTokenVerifier
3535
sessionServiceAddress string
@@ -57,18 +57,20 @@ type AuthFlowResult struct {
5757
Claims map[string]interface{} `json:"claims"`
5858
}
5959

60-
func NewService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher) *Service {
60+
func NewService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher, stateJWT *StateJWT) *Service {
6161
return &Service{
6262
verifierByIssuer: map[string]*goidc.IDTokenVerifier{},
6363
sessionServiceAddress: sessionServiceAddress,
6464

6565
dbConn: dbConn,
6666
cipher: cipher,
67+
68+
stateJWT: stateJWT,
6769
}
6870
}
6971

70-
func newTestService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher) *Service {
71-
service := NewService(sessionServiceAddress, dbConn, cipher)
72+
func newTestService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher, stateJWT *StateJWT) *Service {
73+
service := NewService(sessionServiceAddress, dbConn, cipher, stateJWT)
7274
service.skipVerifyIdToken = true
7375
return service
7476
}
@@ -82,7 +84,7 @@ func (s *Service) GetStartParams(config *ClientConfig, redirectURL string) (*Sta
8284
// TODO(at) read a relative URL from `returnTo` query param of the start request
8385
ReturnToURL: "/",
8486
}
85-
state, err := encodeStateParam(stateParam)
87+
state, err := s.encodeStateParam(stateParam)
8688
if err != nil {
8789
return nil, fmt.Errorf("failed to encode state")
8890
}
@@ -104,21 +106,23 @@ func (s *Service) GetStartParams(config *ClientConfig, redirectURL string) (*Sta
104106
}, nil
105107
}
106108

107-
// TODO(at) state should be a JWT encoding a redirect location
108-
// For now, just use base64
109-
func encodeStateParam(state StateParam) (string, error) {
110-
b, err := json.Marshal(state)
111-
if err != nil {
112-
return "", fmt.Errorf("failed to marshal state to json: %w", err)
113-
}
114-
115-
return base64.StdEncoding.EncodeToString(b), nil
109+
func (s *Service) encodeStateParam(state StateParam) (string, error) {
110+
encodedState, err := s.stateJWT.Encode(StateClaims{
111+
ClientConfigID: state.ClientConfigID,
112+
ReturnToURL: state.ReturnToURL,
113+
})
114+
return encodedState, err
116115
}
117116

118-
func decodeStateParam(encoded string) (StateParam, error) {
119-
var result StateParam
120-
err := json.NewDecoder(base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))).Decode(&result)
121-
return result, err
117+
func (s *Service) decodeStateParam(encodedToken string) (StateParam, error) {
118+
claims, err := s.stateJWT.Decode(encodedToken)
119+
if err != nil {
120+
return StateParam{}, err
121+
}
122+
return StateParam{
123+
ClientConfigID: claims.ClientConfigID,
124+
ReturnToURL: claims.ReturnToURL,
125+
}, nil
122126
}
123127

124128
func randString(size int) (string, error) {
@@ -152,7 +156,7 @@ func (s *Service) GetClientConfigFromCallbackRequest(r *http.Request) (*ClientCo
152156
return nil, fmt.Errorf("missing state parameter")
153157
}
154158

155-
state, err := decodeStateParam(stateParam)
159+
state, err := s.decodeStateParam(stateParam)
156160
if err != nil {
157161
return nil, fmt.Errorf("bad state param")
158162
}

components/public-api-server/pkg/oidc/service_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,13 @@ func TestGetClientConfigFromCallbackRequest(t *testing.T) {
120120
OAuth2Config: &oauth2.Config{},
121121
})
122122

123-
state, err := encodeStateParam(StateParam{
123+
state, err := service.encodeStateParam(StateParam{
124124
ClientConfigID: configID,
125125
ReturnToURL: "",
126126
})
127127
require.NoError(t, err, "failed encode state param")
128128

129-
state_unknown, err := encodeStateParam(StateParam{
129+
state_unknown, err := service.encodeStateParam(StateParam{
130130
ClientConfigID: "UNKNOWN",
131131
ReturnToURL: "",
132132
})
@@ -209,10 +209,11 @@ func setupOIDCServiceForTests(t *testing.T) (*Service, *gorm.DB) {
209209

210210
dbConn := dbtest.ConnectForTests(t)
211211
cipher := dbtest.CipherSet(t)
212+
stateJWT := newTestStateJWT([]byte("ANY KEY"), 5*time.Minute)
212213

213214
sessionServerAddress := newFakeSessionServer(t)
214215

215-
service := newTestService(sessionServerAddress, dbConn, cipher)
216+
service := newTestService(sessionServerAddress, dbConn, cipher, stateJWT)
216217
return service, dbConn
217218
}
218219

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package oidc
6+
7+
import (
8+
"time"
9+
10+
"github.com/golang-jwt/jwt/v4"
11+
)
12+
13+
type StateJWT struct {
14+
key []byte
15+
expiresIn time.Duration
16+
}
17+
18+
func NewStateJWT(key []byte) *StateJWT {
19+
return &StateJWT{
20+
key: key,
21+
expiresIn: 5 * time.Minute,
22+
}
23+
}
24+
25+
func newTestStateJWT(key []byte, expiresIn time.Duration) *StateJWT {
26+
thing := NewStateJWT(key)
27+
thing.expiresIn = expiresIn
28+
return thing
29+
}
30+
31+
type StateClaims struct {
32+
// Internal client ID
33+
ClientConfigID string `json:"clientId"`
34+
ReturnToURL string `json:"returnTo"`
35+
36+
jwt.RegisteredClaims
37+
}
38+
39+
func (s *StateJWT) Encode(claims StateClaims) (string, error) {
40+
41+
expirationTime := time.Now().Add(s.expiresIn)
42+
claims.ExpiresAt = jwt.NewNumericDate(expirationTime)
43+
44+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
45+
encodedToken, err := token.SignedString(s.key)
46+
47+
return encodedToken, err
48+
}
49+
50+
func (s *StateJWT) Decode(tokenString string) (*StateClaims, error) {
51+
claims := &StateClaims{}
52+
_, err := jwt.ParseWithClaims(
53+
tokenString,
54+
claims,
55+
func(token *jwt.Token) (interface{}, error) {
56+
return []byte(s.key), nil
57+
},
58+
)
59+
60+
return claims, err
61+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package oidc
6+
7+
import (
8+
"strings"
9+
"testing"
10+
"time"
11+
12+
"github.com/golang-jwt/jwt/v4"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
func Test_Encode(t *testing.T) {
17+
stateJWT := NewStateJWT([]byte("ANY KEY"))
18+
encodedState, err := stateJWT.Encode(StateClaims{
19+
ClientConfigID: "test-id",
20+
ReturnToURL: "test-url",
21+
})
22+
require.NoError(t, err)
23+
// check for header: { "alg": "HS256", "typ": "JWT" }
24+
require.Contains(t, encodedState, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.", "")
25+
}
26+
27+
func Test_Decode(t *testing.T) {
28+
29+
testCases := []struct {
30+
Label string
31+
Key4Encode string
32+
expiresIn time.Duration
33+
Key4Decode string
34+
ExpectedError string
35+
}{
36+
{
37+
Label: "happy path",
38+
Key4Encode: "ANY KEY",
39+
expiresIn: 5 * time.Minute,
40+
Key4Decode: "ANY KEY",
41+
ExpectedError: "",
42+
},
43+
{
44+
Label: "expired state token",
45+
Key4Encode: "ANY KEY",
46+
expiresIn: 0 * time.Second,
47+
Key4Decode: "ANY KEY",
48+
ExpectedError: "token is expired",
49+
},
50+
{
51+
Label: "signature is invalid",
52+
Key4Encode: "OTHER KEY",
53+
expiresIn: 5 * time.Minute,
54+
Key4Decode: "ANY KEY",
55+
ExpectedError: jwt.ErrSignatureInvalid.Error(),
56+
},
57+
}
58+
59+
for _, tc := range testCases {
60+
t.Run(tc.Label, func(t *testing.T) {
61+
encoder := newTestStateJWT([]byte(tc.Key4Encode), tc.expiresIn)
62+
decoder := NewStateJWT([]byte(tc.Key4Decode))
63+
encodedState, err := encoder.Encode(StateClaims{
64+
ClientConfigID: "test-id",
65+
ReturnToURL: "test-url",
66+
})
67+
if err != nil && tc.ExpectedError == "" {
68+
require.FailNowf(t, "Unexpected error on `Encode`.", "Error: %", err)
69+
}
70+
_, err = decoder.Decode(encodedState)
71+
if err != nil && tc.ExpectedError == "" {
72+
require.FailNowf(t, "Unexpected error on `Decode`.", "Error: %", err)
73+
}
74+
if err != nil && !strings.Contains(err.Error(), tc.ExpectedError) {
75+
require.FailNowf(t, "Unmatched error.", "Got error: %", err.Error())
76+
}
77+
})
78+
}
79+
80+
}

components/public-api-server/pkg/server/server.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,13 @@ func Start(logger *logrus.Entry, version string, cfg *config.Configuration) erro
7575
}
7676
}
7777

78+
var stateJWT *oidc.StateJWT
7879
if cfg.OIDCClientJWTSigningSecretPath != "" {
79-
_, err := readSecretFromFile(cfg.OIDCClientJWTSigningSecretPath)
80+
oidcClientJWTSigningSecret, err := readSecretFromFile(cfg.OIDCClientJWTSigningSecretPath)
8081
if err != nil {
8182
return fmt.Errorf("failed to read JWT signing secret for OIDC flows: %w", err)
8283
}
84+
stateJWT = oidc.NewStateJWT([]byte(oidcClientJWTSigningSecret))
8385
} else {
8486
log.Info("No JWT signing secret for OIDC flows is configured.")
8587
}
@@ -109,7 +111,7 @@ func Start(logger *logrus.Entry, version string, cfg *config.Configuration) erro
109111

110112
srv.HTTPMux().Handle("/stripe/invoices/webhook", handlers.ContentTypeHandler(stripeWebhookHandler, "application/json"))
111113

112-
oidcService := oidc.NewService(cfg.SessionServiceAddress, dbConn, cipherSet)
114+
oidcService := oidc.NewService(cfg.SessionServiceAddress, dbConn, cipherSet, stateJWT)
113115

114116
if registerErr := register(srv, &registerDependencies{
115117
connPool: connPool,

0 commit comments

Comments
 (0)