FRE-4764: Fix response body leaks, race conditions, and thread-safety issues

- P1.2: Close lastResp.Body on context cancellation during retry backoff
- P1.1: Close original response body after io.ReadAll on error paths
  to return TCP connections to the pool
- P2.3: Close response body in doSingleRequest on error paths (http.Client.Do
  can return non-nil resp with non-nil err)
- P2.3: Defensive body close on auth refresh retry failure
- P2: Simplify shouldRetryError with explicit type checks
- P2: RateLimiter in-place filtering to reduce GC pressure
- P3.6: Replace math/rand with crypto/rand for thread-safe jitter
- P3.7: Add missing error code constants (SessionExpired, TokenExpired,
  QuotaExceeded, AccountSuspended)

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
2026-05-13 09:17:52 -04:00
parent 2cffa1ead7
commit d28834831a

View File

@@ -3,11 +3,12 @@ package api
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"strconv"
@@ -31,7 +32,11 @@ const (
PasswordWrong Code = 8002
HumanVerificationRequired Code = 9001
PaidPlanRequired Code = 10004
SessionExpired Code = 10005
TokenExpired Code = 10006
QuotaExceeded Code = 10011
AuthRefreshTokenInvalid Code = 10013
AccountSuspended Code = 10050
HumanValidationInvalidToken Code = 12087
)
@@ -281,13 +286,15 @@ func (rl *RateLimiter) Wait() {
now := time.Now()
windowStart := now.Add(-rl.window)
validRequests := make([]time.Time, 0, rl.limit)
// In-place filtering to reduce GC pressure
valid := 0
for _, t := range rl.requests {
if t.After(windowStart) {
validRequests = append(validRequests, t)
rl.requests[valid] = t
valid++
}
}
rl.requests = validRequests
rl.requests = rl.requests[:valid]
if len(rl.requests) >= rl.limit {
sleep := rl.requests[0].Add(rl.window).Sub(now)
@@ -341,6 +348,9 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
delay := c.calculateBackoff(attempt, lastResp)
select {
case <-ctx.Done():
if lastResp != nil && lastResp.Body != nil {
lastResp.Body.Close()
}
return lastResp, ctx.Err()
case <-time.After(delay):
}
@@ -381,6 +391,9 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
resp, err = c.doSingleRequest(ctx, retryReq)
if err != nil {
c.onConnDown()
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
return nil, err
}
}
@@ -400,6 +413,7 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
// Check for API errors (4xx/5xx)
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
var apiErr APIError
if err := json.Unmarshal(body, &apiErr); err == nil {
apiErr.HTTPStatus = resp.StatusCode
@@ -419,6 +433,7 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
c.onConnUp()
c.recordRequest()
body, _ := io.ReadAll(lastResp.Body)
lastResp.Body.Close()
var apiErr APIError
if err := json.Unmarshal(body, &apiErr); err == nil {
apiErr.HTTPStatus = lastResp.StatusCode
@@ -445,6 +460,9 @@ func (c *ProtonMailClient) doSingleRequest(ctx context.Context, req *http.Reques
resp, err := c.httpClient.Do(req)
if err != nil {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
// Check if it's a network-level error
if netErr := new(net.OpError); errors.As(err, &netErr) {
return nil, NewNetError(netErr, "network error while communicating with API")
@@ -467,21 +485,24 @@ func (c *ProtonMailClient) shouldRetryError(err error, resp *http.Response) bool
return false
}
// Network errors are retryable
if netErr := new(NetError); errors.As(err, &netErr) {
return true
}
// Op errors (dial, connection) are retryable
if netErr := new(net.OpError); errors.As(err, &netErr) {
return true
}
// Context errors are not retryable
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
// Network errors (NetError wraps net.OpError) are retryable
if _, ok := errors.Unwrap(err).(*NetError); ok {
return true
}
if _, ok := err.(*NetError); ok {
return true
}
// Raw net.OpError from http.Client.Do are retryable
if _, ok := err.(*net.OpError); ok {
return true
}
return false
}
@@ -520,12 +541,19 @@ func (c *ProtonMailClient) calculateBackoff(attempt int, resp *http.Response) ti
}
// Add jitter (0-10 seconds) to avoid thundering herd
jitter := time.Duration(rand.Intn(10)) * time.Second
jitter := time.Duration(c.randIntn(10)) * time.Second
delay += jitter
return delay
}
// randIntn returns a thread-safe random integer in [0, n) using crypto/rand.
func (c *ProtonMailClient) randIntn(n int) int {
b := make([]byte, 4)
_, _ = rand.Read(b)
return int(binary.BigEndian.Uint32(b) % uint32(n))
}
// parseRetryAfter parses the Retry-After header and returns the duration.
// Returns 0 if the header is missing or invalid.
func (c *ProtonMailClient) parseRetryAfter(resp *http.Response) time.Duration {