Skip to content

Lazy marshaling for OpenAPI v2 spec #251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 70 additions & 42 deletions pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
openapi_v2 "github.com/googleapis/gnostic/openapiv2"
"github.com/munnerz/goautoneg"
"gopkg.in/yaml.v2"
klog "k8s.io/klog/v2"
"k8s.io/kube-openapi/pkg/builder"
"k8s.io/kube-openapi/pkg/common"
"k8s.io/kube-openapi/pkg/validation/spec"
Expand All @@ -55,13 +56,40 @@ type OpenAPIService struct {

lastModified time.Time

specBytes []byte
specPb []byte
specPbGz []byte
jsonCache cache
protoCache cache
}

type cache struct {
BuildCache func() ([]byte, error)
once sync.Once
bytes []byte
etag string
err error
}

specBytesETag string
specPbETag string
specPbGzETag string
func (c *cache) Get() ([]byte, string, error) {
c.once.Do(func() {
bytes, err := c.BuildCache()
// if there is an error updating the cache, there can be situations where
// c.bytes contains a valid value (carried over from the previous update)
// but c.err is also not nil; the cache user is expected to check for this
c.err = err
if c.err == nil {
// don't override previous spec if we had an error
c.bytes = bytes
c.etag = computeETag(c.bytes)
}
})
return c.bytes, c.etag, c.err
}

func (c *cache) New(cacheBuilder func() ([]byte, error)) cache {
return cache{
bytes: c.bytes,
etag: c.etag,
BuildCache: cacheBuilder,
}
}

func init() {
Expand All @@ -71,6 +99,9 @@ func init() {
}

func computeETag(data []byte) string {
if data == nil {
return ""
}
return fmt.Sprintf("\"%X\"", sha512.Sum512(data))
}

Expand All @@ -83,51 +114,40 @@ func NewOpenAPIService(spec *spec.Swagger) (*OpenAPIService, error) {
return o, nil
}

func (o *OpenAPIService) getSwaggerBytes() ([]byte, string, time.Time) {
o.rwMutex.RLock()
defer o.rwMutex.RUnlock()
return o.specBytes, o.specBytesETag, o.lastModified
}

func (o *OpenAPIService) getSwaggerPbBytes() ([]byte, string, time.Time) {
func (o *OpenAPIService) getSwaggerBytes() ([]byte, string, time.Time, error) {
o.rwMutex.RLock()
defer o.rwMutex.RUnlock()
return o.specPb, o.specPbETag, o.lastModified
specBytes, etag, err := o.jsonCache.Get()
if err != nil {
return nil, "", time.Time{}, err
}
return specBytes, etag, o.lastModified, nil
}

func (o *OpenAPIService) getSwaggerPbGzBytes() ([]byte, string, time.Time) {
func (o *OpenAPIService) getSwaggerPbBytes() ([]byte, string, time.Time, error) {
o.rwMutex.RLock()
defer o.rwMutex.RUnlock()
return o.specPbGz, o.specPbGzETag, o.lastModified
}

func (o *OpenAPIService) UpdateSpec(openapiSpec *spec.Swagger) (err error) {
specBytes, err := json.Marshal(openapiSpec)
specPb, etag, err := o.protoCache.Get()
if err != nil {
return err
return nil, "", time.Time{}, err
}
specPb, err := ToProtoBinary(specBytes)
if err != nil {
return err
}
specPbGz := toGzip(specPb)

specBytesETag := computeETag(specBytes)
specPbETag := computeETag(specPb)
specPbGzETag := computeETag(specPbGz)

lastModified := time.Now()
return specPb, etag, o.lastModified, nil
}

func (o *OpenAPIService) UpdateSpec(openapiSpec *spec.Swagger) (err error) {
o.rwMutex.Lock()
defer o.rwMutex.Unlock()

o.specBytes = specBytes
o.specPb = specPb
o.specPbGz = specPbGz
o.specBytesETag = specBytesETag
o.specPbETag = specPbETag
o.specPbGzETag = specPbGzETag
o.lastModified = lastModified
o.jsonCache = o.jsonCache.New(func() ([]byte, error) {
return json.Marshal(openapiSpec)
})
o.protoCache = o.protoCache.New(func() ([]byte, error) {
json, _, err := o.jsonCache.Get()
if err != nil {
return nil, err
}
return ToProtoBinary(json)
})
o.lastModified = time.Now()

return nil
}
Expand Down Expand Up @@ -206,7 +226,7 @@ func (o *OpenAPIService) RegisterOpenAPIVersionedService(servePath string, handl
accepted := []struct {
Type string
SubType string
GetDataAndETag func() ([]byte, string, time.Time)
GetDataAndETag func() ([]byte, string, time.Time, error)
}{
{"application", "json", o.getSwaggerBytes},
{"application", "[email protected]+protobuf", o.getSwaggerPbBytes},
Expand All @@ -230,7 +250,15 @@ func (o *OpenAPIService) RegisterOpenAPIVersionedService(servePath string, handl
}

// serve the first matching media type in the sorted clause list
data, etag, lastModified := accepts.GetDataAndETag()
data, etag, lastModified, err := accepts.GetDataAndETag()
if err != nil {
klog.Errorf("Error in OpenAPI handler: %s", err)
// only return a 503 if we have no older cache data to serve
if data == nil {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
}
w.Header().Set("Etag", etag)
// ServeContent will take care of caching using eTag.
http.ServeContent(w, r, servePath, lastModified, bytes.NewReader(data))
Expand Down
84 changes: 84 additions & 0 deletions pkg/handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handler

import (
json "encoding/json"
"errors"
"io/ioutil"
"math"
"net/http"
Expand Down Expand Up @@ -177,3 +178,86 @@ func TestToProtoBinary(t *testing.T) {
}
// TODO: add some kind of roundtrip test here
}

func TestCache(t *testing.T) {
calledCount := 0
expectedBytes := []byte("ABC")
cacheObj := cache{
BuildCache: func() ([]byte, error) {
calledCount++
return expectedBytes, nil
},
}
bytes, _, _ := cacheObj.Get()
if string(bytes) != string(expectedBytes) {
t.Fatalf("got value of %q from cache (expected %q)", bytes, expectedBytes)
}
cacheObj.Get()
if calledCount != 1 {
t.Fatalf("expected BuildCache to be called once (called %d times)", calledCount)
}
}

func TestCacheError(t *testing.T) {
cacheObj := cache{
BuildCache: func() ([]byte, error) {
return nil, errors.New("cache error")
},
}
_, _, err := cacheObj.Get()
if err == nil {
t.Fatalf("expected non-nil err from cache.Get()")
}
}

func TestCacheRefresh(t *testing.T) {
// check that returning an error while having no prior cached value results in a nil value from cache.Get()
cacheObj := (&cache{}).New(func() ([]byte, error) {
return nil, errors.New("returning nil bytes")
})
// make multiple calls to Get() to ensure errors are preserved across subsequent calls
for i := 0; i < 4; i++ {
value, _, err := cacheObj.Get()
if value != nil {
t.Fatalf("expected nil bytes (got %s)", value)
}
if err == nil {
t.Fatalf("expected non-nil err from cache.Get()")
}
}
// check that we can call New() multiple times and get the last known cache value while also returning any errors
lastGoodVal := []byte("last good value")
cacheObj = cacheObj.New(func() ([]byte, error) {
return lastGoodVal, nil
})
// call Get() once, so lastGoodVal is cached
_, lastGoodEtag, _ := cacheObj.Get()
for i := 0; i < 4; i++ {
cacheObj = cacheObj.New(func() ([]byte, error) {
return nil, errors.New("check that c.bytes is preserved across New() calls")
})
value, newEtag, err := cacheObj.Get()
if err == nil {
t.Fatalf("expected non-nil err from cache.Get()")
}
if string(value) != string(lastGoodVal) {
t.Fatalf("expected previous value for cache to be returned (got %s, expected %s)", value, lastGoodVal)
}
// check that etags carry over between calls to cache.New()
if lastGoodEtag != newEtag {
t.Fatalf("expected etags to match (got %s, expected %s", newEtag, lastGoodEtag)
}
}
// check that if we successfully renew the cache the old last known value is flushed
newVal := []byte("new good value")
cacheObj = cacheObj.New(func() ([]byte, error) {
return newVal, nil
})
value, _, err := cacheObj.Get()
if err != nil {
t.Fatalf("expected nil err from cache.Get()")
}
if string(value) != string(newVal) {
t.Fatalf("got value of %s from cache (expected %s)", value, newVal)
}
}