diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 54794043b..7338094cb 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -33,6 +33,7 @@ import ( openapi_v2 "github.com/googleapis/gnostic/openapiv2" "github.com/munnerz/goautoneg" "gopkg.in/yaml.v2" + klog "k8s.io/klog/v2" "k8s.io/kube-openapi/pkg/builder" "k8s.io/kube-openapi/pkg/common" "k8s.io/kube-openapi/pkg/validation/spec" @@ -55,13 +56,40 @@ type OpenAPIService struct { lastModified time.Time - specBytes []byte - specPb []byte - specPbGz []byte + jsonCache cache + protoCache cache +} + +type cache struct { + BuildCache func() ([]byte, error) + once sync.Once + bytes []byte + etag string + err error +} - specBytesETag string - specPbETag string - specPbGzETag string +func (c *cache) Get() ([]byte, string, error) { + c.once.Do(func() { + bytes, err := c.BuildCache() + // if there is an error updating the cache, there can be situations where + // c.bytes contains a valid value (carried over from the previous update) + // but c.err is also not nil; the cache user is expected to check for this + c.err = err + if c.err == nil { + // don't override previous spec if we had an error + c.bytes = bytes + c.etag = computeETag(c.bytes) + } + }) + return c.bytes, c.etag, c.err +} + +func (c *cache) New(cacheBuilder func() ([]byte, error)) cache { + return cache{ + bytes: c.bytes, + etag: c.etag, + BuildCache: cacheBuilder, + } } func init() { @@ -71,6 +99,9 @@ func init() { } func computeETag(data []byte) string { + if data == nil { + return "" + } return fmt.Sprintf("\"%X\"", sha512.Sum512(data)) } @@ -83,51 +114,40 @@ func NewOpenAPIService(spec *spec.Swagger) (*OpenAPIService, error) { return o, nil } -func (o *OpenAPIService) getSwaggerBytes() ([]byte, string, time.Time) { - o.rwMutex.RLock() - defer o.rwMutex.RUnlock() - return o.specBytes, o.specBytesETag, o.lastModified -} - -func (o *OpenAPIService) getSwaggerPbBytes() ([]byte, string, time.Time) { +func (o *OpenAPIService) getSwaggerBytes() ([]byte, string, time.Time, error) { o.rwMutex.RLock() defer o.rwMutex.RUnlock() - return o.specPb, o.specPbETag, o.lastModified + specBytes, etag, err := o.jsonCache.Get() + if err != nil { + return nil, "", time.Time{}, err + } + return specBytes, etag, o.lastModified, nil } -func (o *OpenAPIService) getSwaggerPbGzBytes() ([]byte, string, time.Time) { +func (o *OpenAPIService) getSwaggerPbBytes() ([]byte, string, time.Time, error) { o.rwMutex.RLock() defer o.rwMutex.RUnlock() - return o.specPbGz, o.specPbGzETag, o.lastModified -} - -func (o *OpenAPIService) UpdateSpec(openapiSpec *spec.Swagger) (err error) { - specBytes, err := json.Marshal(openapiSpec) + specPb, etag, err := o.protoCache.Get() if err != nil { - return err + return nil, "", time.Time{}, err } - specPb, err := ToProtoBinary(specBytes) - if err != nil { - return err - } - specPbGz := toGzip(specPb) - - specBytesETag := computeETag(specBytes) - specPbETag := computeETag(specPb) - specPbGzETag := computeETag(specPbGz) - - lastModified := time.Now() + return specPb, etag, o.lastModified, nil +} +func (o *OpenAPIService) UpdateSpec(openapiSpec *spec.Swagger) (err error) { o.rwMutex.Lock() defer o.rwMutex.Unlock() - - o.specBytes = specBytes - o.specPb = specPb - o.specPbGz = specPbGz - o.specBytesETag = specBytesETag - o.specPbETag = specPbETag - o.specPbGzETag = specPbGzETag - o.lastModified = lastModified + o.jsonCache = o.jsonCache.New(func() ([]byte, error) { + return json.Marshal(openapiSpec) + }) + o.protoCache = o.protoCache.New(func() ([]byte, error) { + json, _, err := o.jsonCache.Get() + if err != nil { + return nil, err + } + return ToProtoBinary(json) + }) + o.lastModified = time.Now() return nil } @@ -206,7 +226,7 @@ func (o *OpenAPIService) RegisterOpenAPIVersionedService(servePath string, handl accepted := []struct { Type string SubType string - GetDataAndETag func() ([]byte, string, time.Time) + GetDataAndETag func() ([]byte, string, time.Time, error) }{ {"application", "json", o.getSwaggerBytes}, {"application", "com.github.proto-openapi.spec.v2@v1.0+protobuf", o.getSwaggerPbBytes}, @@ -230,7 +250,15 @@ func (o *OpenAPIService) RegisterOpenAPIVersionedService(servePath string, handl } // serve the first matching media type in the sorted clause list - data, etag, lastModified := accepts.GetDataAndETag() + data, etag, lastModified, err := accepts.GetDataAndETag() + if err != nil { + klog.Errorf("Error in OpenAPI handler: %s", err) + // only return a 503 if we have no older cache data to serve + if data == nil { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + } w.Header().Set("Etag", etag) // ServeContent will take care of caching using eTag. http.ServeContent(w, r, servePath, lastModified, bytes.NewReader(data)) diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index eaea1e320..35325ce10 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -2,6 +2,7 @@ package handler import ( json "encoding/json" + "errors" "io/ioutil" "math" "net/http" @@ -177,3 +178,86 @@ func TestToProtoBinary(t *testing.T) { } // TODO: add some kind of roundtrip test here } + +func TestCache(t *testing.T) { + calledCount := 0 + expectedBytes := []byte("ABC") + cacheObj := cache{ + BuildCache: func() ([]byte, error) { + calledCount++ + return expectedBytes, nil + }, + } + bytes, _, _ := cacheObj.Get() + if string(bytes) != string(expectedBytes) { + t.Fatalf("got value of %q from cache (expected %q)", bytes, expectedBytes) + } + cacheObj.Get() + if calledCount != 1 { + t.Fatalf("expected BuildCache to be called once (called %d times)", calledCount) + } +} + +func TestCacheError(t *testing.T) { + cacheObj := cache{ + BuildCache: func() ([]byte, error) { + return nil, errors.New("cache error") + }, + } + _, _, err := cacheObj.Get() + if err == nil { + t.Fatalf("expected non-nil err from cache.Get()") + } +} + +func TestCacheRefresh(t *testing.T) { + // check that returning an error while having no prior cached value results in a nil value from cache.Get() + cacheObj := (&cache{}).New(func() ([]byte, error) { + return nil, errors.New("returning nil bytes") + }) + // make multiple calls to Get() to ensure errors are preserved across subsequent calls + for i := 0; i < 4; i++ { + value, _, err := cacheObj.Get() + if value != nil { + t.Fatalf("expected nil bytes (got %s)", value) + } + if err == nil { + t.Fatalf("expected non-nil err from cache.Get()") + } + } + // check that we can call New() multiple times and get the last known cache value while also returning any errors + lastGoodVal := []byte("last good value") + cacheObj = cacheObj.New(func() ([]byte, error) { + return lastGoodVal, nil + }) + // call Get() once, so lastGoodVal is cached + _, lastGoodEtag, _ := cacheObj.Get() + for i := 0; i < 4; i++ { + cacheObj = cacheObj.New(func() ([]byte, error) { + return nil, errors.New("check that c.bytes is preserved across New() calls") + }) + value, newEtag, err := cacheObj.Get() + if err == nil { + t.Fatalf("expected non-nil err from cache.Get()") + } + if string(value) != string(lastGoodVal) { + t.Fatalf("expected previous value for cache to be returned (got %s, expected %s)", value, lastGoodVal) + } + // check that etags carry over between calls to cache.New() + if lastGoodEtag != newEtag { + t.Fatalf("expected etags to match (got %s, expected %s", newEtag, lastGoodEtag) + } + } + // check that if we successfully renew the cache the old last known value is flushed + newVal := []byte("new good value") + cacheObj = cacheObj.New(func() ([]byte, error) { + return newVal, nil + }) + value, _, err := cacheObj.Get() + if err != nil { + t.Fatalf("expected nil err from cache.Get()") + } + if string(value) != string(newVal) { + t.Fatalf("got value of %s from cache (expected %s)", value, newVal) + } +}