FRE-4764: Fix response body leaks, race conditions, and thread-safety issues
- P1.2: Close lastResp.Body on context cancellation during retry backoff - P1.1: Close original response body after io.ReadAll on error paths to return TCP connections to the pool - P2.3: Close response body in doSingleRequest on error paths (http.Client.Do can return non-nil resp with non-nil err) - P2.3: Defensive body close on auth refresh retry failure - P2: Simplify shouldRetryError with explicit type checks - P2: RateLimiter in-place filtering to reduce GC pressure - P3.6: Replace math/rand with crypto/rand for thread-safe jitter - P3.7: Add missing error code constants (SessionExpired, TokenExpired, QuotaExceeded, AccountSuspended) Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -3,11 +3,12 @@ package api
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -31,7 +32,11 @@ const (
|
|||||||
PasswordWrong Code = 8002
|
PasswordWrong Code = 8002
|
||||||
HumanVerificationRequired Code = 9001
|
HumanVerificationRequired Code = 9001
|
||||||
PaidPlanRequired Code = 10004
|
PaidPlanRequired Code = 10004
|
||||||
|
SessionExpired Code = 10005
|
||||||
|
TokenExpired Code = 10006
|
||||||
|
QuotaExceeded Code = 10011
|
||||||
AuthRefreshTokenInvalid Code = 10013
|
AuthRefreshTokenInvalid Code = 10013
|
||||||
|
AccountSuspended Code = 10050
|
||||||
HumanValidationInvalidToken Code = 12087
|
HumanValidationInvalidToken Code = 12087
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -281,13 +286,15 @@ func (rl *RateLimiter) Wait() {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
windowStart := now.Add(-rl.window)
|
windowStart := now.Add(-rl.window)
|
||||||
|
|
||||||
validRequests := make([]time.Time, 0, rl.limit)
|
// In-place filtering to reduce GC pressure
|
||||||
|
valid := 0
|
||||||
for _, t := range rl.requests {
|
for _, t := range rl.requests {
|
||||||
if t.After(windowStart) {
|
if t.After(windowStart) {
|
||||||
validRequests = append(validRequests, t)
|
rl.requests[valid] = t
|
||||||
|
valid++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rl.requests = validRequests
|
rl.requests = rl.requests[:valid]
|
||||||
|
|
||||||
if len(rl.requests) >= rl.limit {
|
if len(rl.requests) >= rl.limit {
|
||||||
sleep := rl.requests[0].Add(rl.window).Sub(now)
|
sleep := rl.requests[0].Add(rl.window).Sub(now)
|
||||||
@@ -341,6 +348,9 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
|
|||||||
delay := c.calculateBackoff(attempt, lastResp)
|
delay := c.calculateBackoff(attempt, lastResp)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
if lastResp != nil && lastResp.Body != nil {
|
||||||
|
lastResp.Body.Close()
|
||||||
|
}
|
||||||
return lastResp, ctx.Err()
|
return lastResp, ctx.Err()
|
||||||
case <-time.After(delay):
|
case <-time.After(delay):
|
||||||
}
|
}
|
||||||
@@ -381,6 +391,9 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
|
|||||||
resp, err = c.doSingleRequest(ctx, retryReq)
|
resp, err = c.doSingleRequest(ctx, retryReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.onConnDown()
|
c.onConnDown()
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -400,6 +413,7 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
|
|||||||
// Check for API errors (4xx/5xx)
|
// Check for API errors (4xx/5xx)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
var apiErr APIError
|
var apiErr APIError
|
||||||
if err := json.Unmarshal(body, &apiErr); err == nil {
|
if err := json.Unmarshal(body, &apiErr); err == nil {
|
||||||
apiErr.HTTPStatus = resp.StatusCode
|
apiErr.HTTPStatus = resp.StatusCode
|
||||||
@@ -419,6 +433,7 @@ func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Reque
|
|||||||
c.onConnUp()
|
c.onConnUp()
|
||||||
c.recordRequest()
|
c.recordRequest()
|
||||||
body, _ := io.ReadAll(lastResp.Body)
|
body, _ := io.ReadAll(lastResp.Body)
|
||||||
|
lastResp.Body.Close()
|
||||||
var apiErr APIError
|
var apiErr APIError
|
||||||
if err := json.Unmarshal(body, &apiErr); err == nil {
|
if err := json.Unmarshal(body, &apiErr); err == nil {
|
||||||
apiErr.HTTPStatus = lastResp.StatusCode
|
apiErr.HTTPStatus = lastResp.StatusCode
|
||||||
@@ -445,6 +460,9 @@ func (c *ProtonMailClient) doSingleRequest(ctx context.Context, req *http.Reques
|
|||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
// Check if it's a network-level error
|
// Check if it's a network-level error
|
||||||
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
||||||
return nil, NewNetError(netErr, "network error while communicating with API")
|
return nil, NewNetError(netErr, "network error while communicating with API")
|
||||||
@@ -467,21 +485,24 @@ func (c *ProtonMailClient) shouldRetryError(err error, resp *http.Response) bool
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network errors are retryable
|
|
||||||
if netErr := new(NetError); errors.As(err, &netErr) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Op errors (dial, connection) are retryable
|
|
||||||
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Context errors are not retryable
|
// Context errors are not retryable
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Network errors (NetError wraps net.OpError) are retryable
|
||||||
|
if _, ok := errors.Unwrap(err).(*NetError); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := err.(*NetError); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw net.OpError from http.Client.Do are retryable
|
||||||
|
if _, ok := err.(*net.OpError); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -520,12 +541,19 @@ func (c *ProtonMailClient) calculateBackoff(attempt int, resp *http.Response) ti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add jitter (0-10 seconds) to avoid thundering herd
|
// Add jitter (0-10 seconds) to avoid thundering herd
|
||||||
jitter := time.Duration(rand.Intn(10)) * time.Second
|
jitter := time.Duration(c.randIntn(10)) * time.Second
|
||||||
delay += jitter
|
delay += jitter
|
||||||
|
|
||||||
return delay
|
return delay
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// randIntn returns a thread-safe random integer in [0, n) using crypto/rand.
|
||||||
|
func (c *ProtonMailClient) randIntn(n int) int {
|
||||||
|
b := make([]byte, 4)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return int(binary.BigEndian.Uint32(b) % uint32(n))
|
||||||
|
}
|
||||||
|
|
||||||
// parseRetryAfter parses the Retry-After header and returns the duration.
|
// parseRetryAfter parses the Retry-After header and returns the duration.
|
||||||
// Returns 0 if the header is missing or invalid.
|
// Returns 0 if the header is missing or invalid.
|
||||||
func (c *ProtonMailClient) parseRetryAfter(resp *http.Response) time.Duration {
|
func (c *ProtonMailClient) parseRetryAfter(resp *http.Response) time.Duration {
|
||||||
|
|||||||
Reference in New Issue
Block a user