From fdcd41ab491343c143ea259010761d51ea27ca86 Mon Sep 17 00:00:00 2001 From: Li Yazhou Date: Tue, 14 Jan 2025 13:16:03 +0800 Subject: [PATCH] chroe --- apierrors.go | 4 +-- client.go | 70 ++++++++++++++++++++++++++++------------------------ 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/apierrors.go b/apierrors.go index 0bcd4fc..790147d 100644 --- a/apierrors.go +++ b/apierrors.go @@ -24,8 +24,8 @@ import ( var ( ProvisionWarehouseTimeout = "ProvisionWarehouseTimeout" - ErrDoRequest = errors.New("DoReqeustFailed") - ErrReadResponse = errors.New("ReadResponseFailed") + ErrDoRequest = errors.New("failed to do request") + ErrReadResponse = errors.New("failed to read response") ) type APIErrorResponseBody struct { diff --git a/client.go b/client.go index 05aba44..6460bda 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "database/sql/driver" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "math/rand" @@ -18,7 +19,6 @@ import ( "github.com/avast/retry-go" "github.com/google/uuid" - "github.com/pkg/errors" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) @@ -229,40 +229,46 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte if req != nil { reqBody, err = json.Marshal(req) if err != nil { - return errors.Wrap(err, "failed to marshal request body") + return fmt.Errorf("failed to marshal request body: %w", err) } } url := c.makeURL(path) httpReq, err := http.NewRequest(method, url, bytes.NewBuffer(reqBody)) if err != nil { - return errors.Wrap(err, "failed to create http request") + return fmt.Errorf("failed to create http request: %w", err) } httpReq = httpReq.WithContext(ctx) maxRetries := 2 for i := 1; i <= maxRetries; i++ { + // do not retry if context is canceled + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + headers, err := c.makeHeaders(ctx) + if err != nil { + return fmt.Errorf("failed to make request headers: %w", err) + } if needSticky && len(c.NodeID) != 0 { headers.Set(DatabendQueryStickyNode, c.NodeID) } - if err != nil { - return errors.Wrap(err, "failed to make request headers") - } if method == "GET" && len(c.NodeID) != 0 { headers.Set(DatabendQueryIDNode, c.NodeID) } headers.Set(contentType, jsonContentType) headers.Set(accept, jsonContentType) httpReq.Header = headers - if len(c.host) > 0 { httpReq.Host = c.host } httpResp, err := c.cli.Do(httpReq) if err != nil { - return errors.Wrap(ErrDoRequest, err.Error()) + return errors.Join(ErrDoRequest, err) } defer func() { _ = httpResp.Body.Close() @@ -270,7 +276,7 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte httpRespBody, err := io.ReadAll(httpResp.Body) if err != nil { - return errors.Wrap(ErrReadResponse, err.Error()) + return errors.Join(ErrReadResponse, err) } if httpResp.StatusCode == http.StatusUnauthorized { @@ -292,7 +298,7 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte contentType := httpResp.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { if err := json.Unmarshal(httpRespBody, &resp); err != nil { - return errors.Wrap(err, "failed to unmarshal response body") + return fmt.Errorf("failed to unmarshal response body: %w", err) } } } @@ -301,7 +307,7 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte } return nil } - return errors.Errorf("failed to do request after %d retries", maxRetries) + return fmt.Errorf("failed to do request after %d retries", maxRetries) } func (c *APIClient) trackStats(resp *QueryResponse) { @@ -355,11 +361,11 @@ func (c *APIClient) makeHeaders(ctx context.Context) (http.Header, error) { case AuthMethodAccessToken: accessToken, err := c.accessTokenLoader.LoadAccessToken(context.TODO(), false) if err != nil { - return nil, errors.Wrap(err, "failed to load access token") + return nil, fmt.Errorf("failed to load access token: %w", err) } headers.Set(Authorization, fmt.Sprintf("Bearer %s", accessToken)) default: - return nil, errors.New("no user password or access token") + return nil, fmt.Errorf("no user password or access token") } return headers, nil @@ -426,7 +432,7 @@ func (c *APIClient) PollUntilQueryEnd(ctx context.Context, resp *QueryResponse) return nil, err } if resp.Error != nil { - return nil, errors.Wrap(resp.Error, "query page has error") + return nil, fmt.Errorf("query page has error: %w", resp.Error) } resp.Data = append(data, resp.Data...) } @@ -437,7 +443,7 @@ func buildQuery(query string, params []driver.Value) (string, error) { if len(params) > 0 && params[0] != nil { result, err := interpolateParams(query, params) if err != nil { - return result, errors.Wrap(err, "buildRequest: failed to interpolate params") + return result, fmt.Errorf("buildRequest: failed to interpolate params: %w", err) } return result, nil } @@ -508,7 +514,7 @@ func (c *APIClient) startQueryRequest(ctx context.Context, request *QueryRequest }, Query, ) if err != nil { - return nil, errors.Wrap(err, "failed to do query request") + return nil, fmt.Errorf("failed to do query request: %w", err) } if len(resp.NodeID) != 0 { @@ -551,7 +557,7 @@ func (c *APIClient) PollQuery(ctx context.Context, nextURI string) (*QueryRespon c.applySessionState(&result) c.trackStats(&result) if err != nil { - return nil, errors.Wrap(err, "failed to query page") + return nil, fmt.Errorf("failed to query page: %w", err) } return &result, nil } @@ -608,7 +614,7 @@ func (c *APIClient) InsertWithStage(ctx context.Context, sql string, stage *Stag _ = c.CloseQuery(ctx, resp) }() if resp.Error != nil { - return nil, errors.Wrap(resp.Error, "query error:") + return nil, fmt.Errorf("query error: %w", resp.Error) } return c.PollUntilQueryEnd(ctx, resp) } @@ -625,20 +631,20 @@ func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) ( presignUploadSQL := fmt.Sprintf("PRESIGN UPLOAD %s", stage) resp, err := c.QuerySync(ctx, presignUploadSQL, nil) if err != nil { - return nil, errors.Wrap(err, "failed to query presign url") + return nil, fmt.Errorf("failed to query presign url: %w", err) } if len(resp.Data) < 1 || len(resp.Data[0]) < 2 { - return nil, errors.Errorf("generate presign url invalid response: %+v", resp.Data) + return nil, fmt.Errorf("generate presign url invalid response: %+v", resp.Data) } if resp.Data[0][0] == nil || resp.Data[0][1] == nil || resp.Data[0][2] == nil { - return nil, errors.Errorf("generate presign url invalid response: %+v", resp.Data) + return nil, fmt.Errorf("generate presign url invalid response: %+v", resp.Data) } method := *resp.Data[0][0] url := *resp.Data[0][2] headers := map[string]string{} err = json.Unmarshal([]byte(*resp.Data[0][1]), &headers) if err != nil { - return nil, errors.Wrap(err, "failed to unmarshal headers") + return nil, fmt.Errorf("failed to unmarshal headers: %w", err) } result := &PresignedResponse{ Method: method, @@ -651,7 +657,7 @@ func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) ( func (c *APIClient) UploadToStageByPresignURL(ctx context.Context, stage *StageLocation, input *bufio.Reader, size int64) error { presigned, err := c.GetPresignedURL(ctx, stage) if err != nil { - return errors.Wrap(err, "failed to get presigned url") + return fmt.Errorf("failed to get presigned url: %w", err) } req, err := http.NewRequest("PUT", presigned.URL, input) @@ -668,7 +674,7 @@ func (c *APIClient) UploadToStageByPresignURL(ctx context.Context, stage *StageL } resp, err := httpClient.Do(req) if err != nil { - return errors.Wrap(err, "failed to upload to stage by presigned url") + return fmt.Errorf("failed to upload to stage by presigned url: %w", err) } defer func() { _ = resp.Body.Close() @@ -678,7 +684,7 @@ func (c *APIClient) UploadToStageByPresignURL(ctx context.Context, stage *StageL return err } if resp.StatusCode >= 400 { - return errors.Errorf("failed to upload to stage by presigned url, status code: %d, body: %s", resp.StatusCode, string(respBody)) + return fmt.Errorf("failed to upload to stage by presigned url, status code: %d, body: %s", resp.StatusCode, string(respBody)) } return nil } @@ -688,28 +694,28 @@ func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation writer := multipart.NewWriter(body) part, err := writer.CreateFormFile("upload", stage.Path) if err != nil { - return errors.Wrap(err, "failed to create multipart writer form file") + return fmt.Errorf("failed to create multipart writer form file: %w", err) } // TODO: do async upload _, err = io.Copy(part, input) if err != nil { - return errors.Wrap(err, "failed to copy file to multipart writer form file") + return fmt.Errorf("failed to copy file to multipart writer form file: %w", err) } err = writer.Close() if err != nil { - return errors.Wrap(err, "failed to close multipart writer") + return fmt.Errorf("failed to close multipart writer: %w", err) } path := "/v1/upload_to_stage" url := c.makeURL(path) req, err := http.NewRequest("PUT", url, body) if err != nil { - return errors.Wrap(err, "failed to create http request") + return fmt.Errorf("failed to create http request: %w", err) } req.Header, err = c.makeHeaders(ctx) if err != nil { - return errors.Wrap(err, "failed to make headers") + return fmt.Errorf("failed to make headers: %w", err) } if len(c.host) > 0 { req.Host = c.host @@ -723,7 +729,7 @@ func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation } resp, err := httpClient.Do(req) if err != nil { - return errors.Wrap(err, "failed http do request") + return fmt.Errorf("failed http do request: %w", err) } defer func() { _ = resp.Body.Close() @@ -731,7 +737,7 @@ func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation respBody, err := io.ReadAll(resp.Body) if err != nil { - return errors.Wrap(err, "failed to read http response body") + return fmt.Errorf("failed to read http response body: %w", err) } if resp.StatusCode == http.StatusUnauthorized {