Skip to content

Forward headers to AsyncAPI #2329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 19, 2021
10 changes: 2 additions & 8 deletions pkg/async-gateway/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,14 @@ func (e *Endpoint) CreateWorkload(w http.ResponseWriter, r *http.Request) {
return
}

contentType := r.Header.Get("Content-Type")
if contentType == "" {
respondPlainText(w, http.StatusBadRequest, "error: missing Content-Type key in request header")
return
}

body := r.Body
defer func() {
_ = r.Body.Close()
}()

log := e.logger.With(zap.String("id", requestID), zap.String("contentType", contentType))
log := e.logger.With(zap.String("id", requestID))

id, err := e.service.CreateWorkload(requestID, body, contentType)
id, err := e.service.CreateWorkload(requestID, body, r.Header)
if err != nil {
respondPlainText(w, http.StatusInternalServerError, fmt.Sprintf("error: %v", err))
logErrorWithTelemetry(log, errors.Wrap(err, "failed to create workload"))
Expand Down
29 changes: 22 additions & 7 deletions pkg/async-gateway/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@ limitations under the License.
package gateway

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/cortexlabs/cortex/pkg/lib/errors"
"github.com/cortexlabs/cortex/pkg/types/async"
"go.uber.org/zap"
)

// Service provides an interface to the async-gateway business logic
type Service interface {
CreateWorkload(id string, payload io.Reader, contentType string) (string, error)
CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error)
GetWorkload(id string) (GetWorkloadResponse, error)
}

Expand All @@ -52,25 +55,37 @@ func NewService(clusterUID, apiName string, queue Queue, storage Storage, logger
}

// CreateWorkload enqueues an async workload request and uploads the request payload to S3
func (s *service) CreateWorkload(id string, payload io.Reader, contentType string) (string, error) {
func (s *service) CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error) {
prefix := async.StoragePath(s.clusterUID, s.apiName)
log := s.logger.With(zap.String("id", id), zap.String("contentType", contentType))
log := s.logger.With(zap.String("id", id))

buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(headers); err != nil {
return "", errors.Wrap(err, "failed to dump headers")
}

headersPath := async.HeadersPath(prefix, id)
log.Debugw("uploading headers", zap.String("path", headersPath))
if err := s.storage.Upload(headersPath, buf, "application/json"); err != nil {
return "", errors.Wrap(err, "failed to upload headers")
}

contentType := headers.Get("Content-Type")
payloadPath := async.PayloadPath(prefix, id)
log.Debug("uploading payload", zap.String("path", payloadPath))
log.Debugw("uploading payload", zap.String("path", payloadPath))
if err := s.storage.Upload(payloadPath, payload, contentType); err != nil {
return "", err
return "", errors.Wrap(err, "failed to upload payload")
}

log.Debug("sending message to queue")
if err := s.queue.SendMessage(id, id); err != nil {
return "", err
return "", errors.Wrap(err, "failed to send message to queue")
}

statusPath := fmt.Sprintf("%s/%s/status/%s", prefix, id, async.StatusInQueue)
log.Debug(fmt.Sprintf("setting status to %s", async.StatusInQueue))
if err := s.storage.Upload(statusPath, strings.NewReader(""), "text/plain"); err != nil {
return "", err
return "", errors.Wrap(err, "failed to upload workload status")
}

return id, nil
Expand Down
51 changes: 30 additions & 21 deletions pkg/dequeuer/async_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ type AsyncMessageHandlerConfig struct {
TargetURL string
}

type userPayload struct {
Body io.ReadCloser
ContentType string
}

func NewAsyncMessageHandler(config AsyncMessageHandlerConfig, awsClient *awslib.Client, eventHandler RequestEventHandler, logger *zap.SugaredLogger) *AsyncMessageHandler {
return &AsyncMessageHandler{
config: config,
Expand Down Expand Up @@ -104,9 +99,21 @@ func (h *AsyncMessageHandler) handleMessage(requestID string) error {
}
return errors.Wrap(err, "failed to get payload")
}
defer h.deletePayload(requestID)
defer func() {
h.deletePayload(requestID)
_ = payload.Close()
}()

result, err := h.submitRequest(payload, requestID)
headers, err := h.getHeaders(requestID)
if err != nil {
updateStatusErr := h.updateStatus(requestID, async.StatusFailed)
if updateStatusErr != nil {
h.log.Errorw("failed to update status after failure to get headers", "id", requestID, "error", updateStatusErr)
}
return errors.Wrap(err, "failed to get payload")
}

result, err := h.submitRequest(payload, headers, requestID)
if err != nil {
h.log.Errorw("failed to submit request to user container", "id", requestID, "error", err)
updateStatusErr := h.updateStatus(requestID, async.StatusFailed)
Expand Down Expand Up @@ -138,7 +145,7 @@ func (h *AsyncMessageHandler) updateStatus(requestID string, status async.Status
return h.aws.UploadStringToS3("", h.config.Bucket, key)
}

func (h *AsyncMessageHandler) getPayload(requestID string) (*userPayload, error) {
func (h *AsyncMessageHandler) getPayload(requestID string) (io.ReadCloser, error) {
key := async.PayloadPath(h.storagePath, requestID)
output, err := h.aws.S3().GetObject(
&s3.GetObjectInput{
Expand All @@ -149,16 +156,7 @@ func (h *AsyncMessageHandler) getPayload(requestID string) (*userPayload, error)
if err != nil {
return nil, errors.WithStack(err)
}

contentType := "application/octet-stream"
if output.ContentType != nil {
contentType = *output.ContentType
}

return &userPayload{
Body: output.Body,
ContentType: contentType,
}, nil
return output.Body, nil
}

func (h *AsyncMessageHandler) deletePayload(requestID string) {
Expand All @@ -170,13 +168,13 @@ func (h *AsyncMessageHandler) deletePayload(requestID string) {
}
}

func (h *AsyncMessageHandler) submitRequest(payload *userPayload, requestID string) (interface{}, error) {
req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload.Body)
func (h *AsyncMessageHandler) submitRequest(payload io.Reader, headers http.Header, requestID string) (interface{}, error) {
req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload)
if err != nil {
return nil, errors.WithStack(err)
}

req.Header.Set("Content-Type", payload.ContentType)
req.Header = headers
req.Header.Set(CortexRequestIDHeader, requestID)

startTime := time.Now()
Expand Down Expand Up @@ -216,3 +214,14 @@ func (h *AsyncMessageHandler) uploadResult(requestID string, result interface{})
key := async.ResultPath(h.storagePath, requestID)
return h.aws.UploadJSONToS3(result, h.config.Bucket, key)
}

func (h *AsyncMessageHandler) getHeaders(requestID string) (http.Header, error) {
key := async.HeadersPath(h.storagePath, requestID)

var headers http.Header
if err := h.aws.ReadJSONFromS3(&headers, h.config.Bucket, key); err != nil {
return nil, err
}

return headers, nil
}
5 changes: 4 additions & 1 deletion pkg/dequeuer/async_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ func TestAsyncMessageHandler_Handle(t *testing.T) {
})
require.NoError(t, err)

err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, fmt.Sprintf("%s/%s/payload", asyncHandler.storagePath, requestID))
err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, async.PayloadPath(asyncHandler.storagePath, requestID))
require.NoError(t, err)

err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, async.HeadersPath(asyncHandler.storagePath, requestID))
require.NoError(t, err)

err = asyncHandler.Handle(&sqs.Message{
Expand Down
4 changes: 4 additions & 0 deletions pkg/types/async/s3_paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func PayloadPath(storagePath string, requestID string) string {
return fmt.Sprintf("%s/%s/payload", storagePath, requestID)
}

func HeadersPath(storagePath string, requestID string) string {
return fmt.Sprintf("%s/%s/headers.json", storagePath, requestID)
}

func ResultPath(storagePath string, requestID string) string {
return fmt.Sprintf("%s/%s/result.json", storagePath, requestID)
}
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/tests/aws/test_autoscaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


@pytest.mark.usefixtures("client")
@pytest.mark.parametrize("apis", TEST_APIS)
@pytest.mark.parametrize("apis", TEST_APIS, ids=[api["primary"] for api in TEST_APIS])
def test_autoscaling(printer: Callable, config: Dict, client: cx.Client, apis: Dict[str, Any]):
skip_autoscaling_test = config["global"].get("skip_autoscaling", False)
if skip_autoscaling_test:
Expand Down
4 changes: 1 addition & 3 deletions test/e2e/tests/aws/test_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,8 @@ def test_realtime_api(printer: Callable, config: Dict, client: cx.Client, api: D


@pytest.mark.usefixtures("client")
@pytest.mark.parametrize("api", TEST_APIS_ARM)
@pytest.mark.parametrize("api", TEST_APIS_ARM, ids=[api["name"] for api in TEST_APIS_ARM])
def test_realtime_api_arm(printer: Callable, config: Dict, client: cx.Client, api: Dict[str, str]):

printer(f"testing {api['name']}")
e2e.tests.test_realtime_api(
printer=printer,
client=client,
Expand Down