Skip to content

Commit 12db695

Browse files
authored
grpc: restrict status codes from control plane (gRFC A54) (grpc#5653)
1 parent 202d355 commit 12db695

File tree

7 files changed

+296
-20
lines changed

7 files changed

+296
-20
lines changed

credentials/credentials.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,16 @@ import (
3636
// PerRPCCredentials defines the common interface for the credentials which need to
3737
// attach security information to every RPC (e.g., oauth2).
3838
type PerRPCCredentials interface {
39-
// GetRequestMetadata gets the current request metadata, refreshing
40-
// tokens if required. This should be called by the transport layer on
41-
// each request, and the data should be populated in headers or other
42-
// context. If a status code is returned, it will be used as the status
43-
// for the RPC. uri is the URI of the entry point for the request.
44-
// When supported by the underlying implementation, ctx can be used for
45-
// timeout and cancellation. Additionally, RequestInfo data will be
46-
// available via ctx to this call.
47-
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
48-
// it as an arbitrary string.
39+
// GetRequestMetadata gets the current request metadata, refreshing tokens
40+
// if required. This should be called by the transport layer on each
41+
// request, and the data should be populated in headers or other
42+
// context. If a status code is returned, it will be used as the status for
43+
// the RPC (restricted to an allowable set of codes as defined by gRFC
44+
// A54). uri is the URI of the entry point for the request. When supported
45+
// by the underlying implementation, ctx can be used for timeout and
46+
// cancellation. Additionally, RequestInfo data will be available via ctx
47+
// to this call. TODO(zhaoq): Define the set of the qualified keys instead
48+
// of leaving it as an arbitrary string.
4949
GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
5050
// RequireTransportSecurity indicates whether the credentials requires
5151
// transport security.

internal/status/status.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,13 @@ func (e *Error) Is(target error) bool {
164164
}
165165
return proto.Equal(e.s.s, tse.s.s)
166166
}
167+
168+
// IsRestrictedControlPlaneCode returns whether the status includes a code
169+
// restricted for control plane usage as defined by gRFC A54.
170+
func IsRestrictedControlPlaneCode(s *Status) bool {
171+
switch s.Code() {
172+
case codes.InvalidArgument, codes.NotFound, codes.AlreadyExists, codes.FailedPrecondition, codes.Aborted, codes.OutOfRange, codes.DataLoss:
173+
return true
174+
}
175+
return false
176+
}

internal/transport/http2_client.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import (
4040
icredentials "google.golang.org/grpc/internal/credentials"
4141
"google.golang.org/grpc/internal/grpcutil"
4242
imetadata "google.golang.org/grpc/internal/metadata"
43+
istatus "google.golang.org/grpc/internal/status"
4344
"google.golang.org/grpc/internal/syscall"
4445
"google.golang.org/grpc/internal/transport/networktype"
4546
"google.golang.org/grpc/keepalive"
@@ -589,7 +590,11 @@ func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[s
589590
for _, c := range t.perRPCCreds {
590591
data, err := c.GetRequestMetadata(ctx, audience)
591592
if err != nil {
592-
if _, ok := status.FromError(err); ok {
593+
if st, ok := status.FromError(err); ok {
594+
// Restrict the code to the list allowed by gRFC A54.
595+
if istatus.IsRestrictedControlPlaneCode(st) {
596+
err = status.Errorf(codes.Internal, "transport: received per-RPC creds error with illegal status: %v", err)
597+
}
593598
return nil, err
594599
}
595600

@@ -618,7 +623,14 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
618623
}
619624
data, err := callCreds.GetRequestMetadata(ctx, audience)
620625
if err != nil {
621-
return nil, status.Errorf(codes.Internal, "transport: %v", err)
626+
if st, ok := status.FromError(err); ok {
627+
// Restrict the code to the list allowed by gRFC A54.
628+
if istatus.IsRestrictedControlPlaneCode(st) {
629+
err = status.Errorf(codes.Internal, "transport: received per-RPC creds error with illegal status: %v", err)
630+
}
631+
return nil, err
632+
}
633+
return nil, status.Errorf(codes.Internal, "transport: per-RPC creds failed due to error: %v", err)
622634
}
623635
callAuthData = make(map[string]string, len(data))
624636
for k, v := range data {

picker_wrapper.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"google.golang.org/grpc/balancer"
2727
"google.golang.org/grpc/codes"
2828
"google.golang.org/grpc/internal/channelz"
29+
istatus "google.golang.org/grpc/internal/status"
2930
"google.golang.org/grpc/internal/transport"
3031
"google.golang.org/grpc/status"
3132
)
@@ -129,8 +130,12 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
129130
if err == balancer.ErrNoSubConnAvailable {
130131
continue
131132
}
132-
if _, ok := status.FromError(err); ok {
133+
if st, ok := status.FromError(err); ok {
133134
// Status error: end the RPC unconditionally with this status.
135+
// First restrict the code to the list allowed by gRFC A54.
136+
if istatus.IsRestrictedControlPlaneCode(st) {
137+
err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err)
138+
}
134139
return nil, nil, dropError{error: err}
135140
}
136141
// For all other errors, wait for ready RPCs should block and other

stream.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import (
3939
imetadata "google.golang.org/grpc/internal/metadata"
4040
iresolver "google.golang.org/grpc/internal/resolver"
4141
"google.golang.org/grpc/internal/serviceconfig"
42+
istatus "google.golang.org/grpc/internal/status"
4243
"google.golang.org/grpc/internal/transport"
4344
"google.golang.org/grpc/metadata"
4445
"google.golang.org/grpc/peer"
@@ -195,6 +196,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
195196
rpcInfo := iresolver.RPCInfo{Context: ctx, Method: method}
196197
rpcConfig, err := cc.safeConfigSelector.SelectConfig(rpcInfo)
197198
if err != nil {
199+
if st, ok := status.FromError(err); ok {
200+
// Restrict the code to the list allowed by gRFC A54.
201+
if istatus.IsRestrictedControlPlaneCode(st) {
202+
err = status.Errorf(codes.Internal, "config selector returned illegal status: %v", err)
203+
}
204+
return nil, err
205+
}
198206
return nil, toRPCErr(err)
199207
}
200208

test/control_plane_status_test.go

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
/*
2+
*
3+
* Copyright 2022 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
package test
20+
21+
import (
22+
"context"
23+
"strings"
24+
"testing"
25+
"time"
26+
27+
"google.golang.org/grpc"
28+
"google.golang.org/grpc/balancer"
29+
"google.golang.org/grpc/balancer/base"
30+
"google.golang.org/grpc/codes"
31+
"google.golang.org/grpc/connectivity"
32+
"google.golang.org/grpc/internal/balancer/stub"
33+
iresolver "google.golang.org/grpc/internal/resolver"
34+
"google.golang.org/grpc/internal/stubserver"
35+
"google.golang.org/grpc/resolver"
36+
"google.golang.org/grpc/resolver/manual"
37+
"google.golang.org/grpc/status"
38+
testpb "google.golang.org/grpc/test/grpc_testing"
39+
)
40+
41+
func (s) TestConfigSelectorStatusCodes(t *testing.T) {
42+
testCases := []struct {
43+
name string
44+
csErr error
45+
want error
46+
}{{
47+
name: "legal status code",
48+
csErr: status.Errorf(codes.Unavailable, "this error is fine"),
49+
want: status.Errorf(codes.Unavailable, "this error is fine"),
50+
}, {
51+
name: "illegal status code",
52+
csErr: status.Errorf(codes.NotFound, "this error is bad"),
53+
want: status.Errorf(codes.Internal, "this error is bad"),
54+
}}
55+
56+
for _, tc := range testCases {
57+
t.Run(tc.name, func(t *testing.T) {
58+
ss := &stubserver.StubServer{
59+
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
60+
return &testpb.Empty{}, nil
61+
},
62+
}
63+
ss.R = manual.NewBuilderWithScheme("confSel")
64+
65+
if err := ss.Start(nil); err != nil {
66+
t.Fatalf("Error starting endpoint server: %v", err)
67+
}
68+
defer ss.Stop()
69+
70+
state := iresolver.SetConfigSelector(resolver.State{
71+
Addresses: []resolver.Address{{Addr: ss.Address}},
72+
ServiceConfig: parseServiceConfig(t, ss.R, "{}"),
73+
}, funcConfigSelector{
74+
f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) {
75+
return nil, tc.csErr
76+
},
77+
})
78+
ss.R.UpdateState(state) // Blocks until config selector is applied
79+
80+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
81+
defer cancel()
82+
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) {
83+
t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want)
84+
}
85+
})
86+
}
87+
}
88+
89+
func (s) TestPickerStatusCodes(t *testing.T) {
90+
testCases := []struct {
91+
name string
92+
pickerErr error
93+
want error
94+
}{{
95+
name: "legal status code",
96+
pickerErr: status.Errorf(codes.Unavailable, "this error is fine"),
97+
want: status.Errorf(codes.Unavailable, "this error is fine"),
98+
}, {
99+
name: "illegal status code",
100+
pickerErr: status.Errorf(codes.NotFound, "this error is bad"),
101+
want: status.Errorf(codes.Internal, "this error is bad"),
102+
}}
103+
104+
for _, tc := range testCases {
105+
t.Run(tc.name, func(t *testing.T) {
106+
ss := &stubserver.StubServer{
107+
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
108+
return &testpb.Empty{}, nil
109+
},
110+
}
111+
112+
if err := ss.Start(nil); err != nil {
113+
t.Fatalf("Error starting endpoint server: %v", err)
114+
}
115+
defer ss.Stop()
116+
117+
// Create a stub balancer that creates a picker that always returns
118+
// an error.
119+
sbf := stub.BalancerFuncs{
120+
UpdateClientConnState: func(d *stub.BalancerData, _ balancer.ClientConnState) error {
121+
d.ClientConn.UpdateState(balancer.State{
122+
ConnectivityState: connectivity.TransientFailure,
123+
Picker: base.NewErrPicker(tc.pickerErr),
124+
})
125+
return nil
126+
},
127+
}
128+
stub.Register("testPickerStatusCodesBalancer", sbf)
129+
130+
ss.NewServiceConfig(`{"loadBalancingConfig": [{"testPickerStatusCodesBalancer":{}}] }`)
131+
132+
// Make calls until pickerErr is received.
133+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
134+
defer cancel()
135+
136+
var lastErr error
137+
for ctx.Err() == nil {
138+
if _, lastErr = ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(lastErr) == status.Code(tc.want) && strings.Contains(lastErr.Error(), status.Convert(tc.want).Message()) {
139+
// Success!
140+
return
141+
}
142+
time.Sleep(time.Millisecond)
143+
}
144+
145+
t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", lastErr, tc.want)
146+
})
147+
}
148+
}
149+
150+
func (s) TestCallCredsFromDialOptionsStatusCodes(t *testing.T) {
151+
testCases := []struct {
152+
name string
153+
credsErr error
154+
want error
155+
}{{
156+
name: "legal status code",
157+
credsErr: status.Errorf(codes.Unavailable, "this error is fine"),
158+
want: status.Errorf(codes.Unavailable, "this error is fine"),
159+
}, {
160+
name: "illegal status code",
161+
credsErr: status.Errorf(codes.NotFound, "this error is bad"),
162+
want: status.Errorf(codes.Internal, "this error is bad"),
163+
}}
164+
165+
for _, tc := range testCases {
166+
t.Run(tc.name, func(t *testing.T) {
167+
ss := &stubserver.StubServer{
168+
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
169+
return &testpb.Empty{}, nil
170+
},
171+
}
172+
173+
errChan := make(chan error, 1)
174+
creds := &testPerRPCCredentials{errChan: errChan}
175+
176+
if err := ss.Start(nil, grpc.WithPerRPCCredentials(creds)); err != nil {
177+
t.Fatalf("Error starting endpoint server: %v", err)
178+
}
179+
defer ss.Stop()
180+
181+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
182+
defer cancel()
183+
184+
errChan <- tc.credsErr
185+
186+
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) {
187+
t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want)
188+
}
189+
})
190+
}
191+
}
192+
193+
func (s) TestCallCredsFromCallOptionsStatusCodes(t *testing.T) {
194+
testCases := []struct {
195+
name string
196+
credsErr error
197+
want error
198+
}{{
199+
name: "legal status code",
200+
credsErr: status.Errorf(codes.Unavailable, "this error is fine"),
201+
want: status.Errorf(codes.Unavailable, "this error is fine"),
202+
}, {
203+
name: "illegal status code",
204+
credsErr: status.Errorf(codes.NotFound, "this error is bad"),
205+
want: status.Errorf(codes.Internal, "this error is bad"),
206+
}}
207+
208+
for _, tc := range testCases {
209+
t.Run(tc.name, func(t *testing.T) {
210+
ss := &stubserver.StubServer{
211+
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
212+
return &testpb.Empty{}, nil
213+
},
214+
}
215+
216+
errChan := make(chan error, 1)
217+
creds := &testPerRPCCredentials{errChan: errChan}
218+
219+
if err := ss.Start(nil); err != nil {
220+
t.Fatalf("Error starting endpoint server: %v", err)
221+
}
222+
defer ss.Stop()
223+
224+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
225+
defer cancel()
226+
227+
errChan <- tc.credsErr
228+
229+
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(creds)); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) {
230+
t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want)
231+
}
232+
})
233+
}
234+
}

0 commit comments

Comments
 (0)