FRE-4764: Fix response body leaks, race conditions, and thread-safety issues
- P1.2: Close lastResp.Body on context cancellation during retry backoff - P1.1: Close original response body after io.ReadAll on error paths to return TCP connections to the pool - P2.3: Close response body in doSingleRequest on error paths (http.Client.Do can return non-nil resp with non-nil err) - P2.3: Defensive body close on auth refresh retry failure - P2: Simplify shouldRetryError with explicit type checks - P2: RateLimiter in-place filtering to reduce GC pressure - P3.6: Replace math/rand with crypto/rand for thread-safe jitter - P3.7: Add missing error code constants (SessionExpired, TokenExpired, QuotaExceeded, AccountSuspended) Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -3,11 +3,12 @@ package api
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -31,7 +32,11 @@ const (
|
||||
PasswordWrong Code = 8002
|
||||
HumanVerificationRequired Code = 9001
|
||||
PaidPlanRequired Code = 10004
|
||||
SessionExpired Code = 10005
|
||||
TokenExpired Code = 10006
|
||||
QuotaExceeded Code = 10011
|
||||
AuthRefreshTokenInvalid Code = 10013
|
||||
AccountSuspended Code = 10050
|
||||
HumanValidationInvalidToken Code = 12087
|
||||
)
|
||||
|
||||
@@ -281,13 +286,15 @@ func (rl *RateLimiter) Wait() {
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-rl.window)
|
||||
|
||||
validRequests := make([]time.Time, 0, rl.limit)
|
||||
// In-place filtering to reduce GC pressure
|
||||
valid := 0
|
||||
for _, t := range rl.requests {
|
||||
if t.After(windowStart) {
|
||||
validRequests = append(validRequests, t)
|
||||
rl.requests[valid] = t
|
||||
valid++
|
||||
}
|
||||
}
|
||||
rl.requests = validRequests
|
||||
rl.requests = rl.requests[:valid]
|
||||
|
||||
if len(rl.requests) >= rl.limit {
|
||||
sleep := rl.requests[0].Add(rl.window).Sub(now)
|
||||
@@ -341,6 +348,9 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
|
||||
delay := c.calculateBackoff(attempt, lastResp)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if lastResp != nil && lastResp.Body != nil {
|
||||
lastResp.Body.Close()
|
||||
}
|
||||
return lastResp, ctx.Err()
|
||||
case <-time.After(delay):
|
||||
}
|
||||
@@ -381,6 +391,9 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
|
||||
resp, err = c.doSingleRequest(ctx, retryReq)
|
||||
if err != nil {
|
||||
c.onConnDown()
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -400,6 +413,7 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
|
||||
// Check for API errors (4xx/5xx)
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
var apiErr APIError
|
||||
if err := json.Unmarshal(body, &apiErr); err == nil {
|
||||
apiErr.HTTPStatus = resp.StatusCode
|
||||
@@ -419,6 +433,7 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
|
||||
c.onConnUp()
|
||||
c.recordRequest()
|
||||
body, _ := io.ReadAll(lastResp.Body)
|
||||
lastResp.Body.Close()
|
||||
var apiErr APIError
|
||||
if err := json.Unmarshal(body, &apiErr); err == nil {
|
||||
apiErr.HTTPStatus = lastResp.StatusCode
|
||||
@@ -445,6 +460,9 @@ func (c *ProtonMailClient) doSingleRequest(ctx context.Context, req *http.Reques
|
||||
resp, err := c.httpClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
// Check if it's a network-level error
|
||||
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
||||
return nil, NewNetError(netErr, "network error while communicating with API")
|
||||
@@ -467,21 +485,24 @@ func (c *ProtonMailClient) shouldRetryError(err error, resp *http.Response) bool
|
||||
return false
|
||||
}
|
||||
|
||||
// Network errors are retryable
|
||||
if netErr := new(NetError); errors.As(err, &netErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Op errors (dial, connection) are retryable
|
||||
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Context errors are not retryable
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Network errors (NetError wraps net.OpError) are retryable
|
||||
if _, ok := errors.Unwrap(err).(*NetError); ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := err.(*NetError); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Raw net.OpError from http.Client.Do are retryable
|
||||
if _, ok := err.(*net.OpError); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -520,12 +541,19 @@ func (c *ProtonMailClient) calculateBackoff(attempt int, resp *http.Response) ti
|
||||
}
|
||||
|
||||
// Add jitter (0-10 seconds) to avoid thundering herd
|
||||
jitter := time.Duration(rand.Intn(10)) * time.Second
|
||||
jitter := time.Duration(c.randIntn(10)) * time.Second
|
||||
delay += jitter
|
||||
|
||||
return delay
|
||||
}
|
||||
|
||||
// randIntn returns a thread-safe random integer in [0, n) using crypto/rand.
|
||||
func (c *ProtonMailClient) randIntn(n int) int {
|
||||
b := make([]byte, 4)
|
||||
_, _ = rand.Read(b)
|
||||
return int(binary.BigEndian.Uint32(b) % uint32(n))
|
||||
}
|
||||
|
||||
// parseRetryAfter parses the Retry-After header and returns the duration.
|
||||
// Returns 0 if the header is missing or invalid.
|
||||
func (c *ProtonMailClient) parseRetryAfter(resp *http.Response) time.Duration {
|
||||
|
||||
Reference in New Issue
Block a user