Files
pop/internal/api/client.go
Michael Freno d28834831a FRE-4764: Fix response body leaks, race conditions, and thread-safety issues
- P1.2: Close lastResp.Body on context cancellation during retry backoff
- P1.1: Close original response body after io.ReadAll on error paths
  to return TCP connections to the pool
- P2.3: Close response body in doSingleRequest on error paths (http.Client.Do
  can return non-nil resp with non-nil err)
- P2.3: Defensive body close on auth refresh retry failure
- P2: Simplify shouldRetryError with explicit type checks
- P2: RateLimiter in-place filtering to reduce GC pressure
- P3.6: Replace math/rand with crypto/rand for thread-safe jitter
- P3.7: Add missing error code constants (SessionExpired, TokenExpired,
  QuotaExceeded, AccountSuspended)

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-05-14 00:40:24 -04:00

582 lines
14 KiB
Go

package api
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"sync"
"time"
"github.com/frenocorp/pop/internal/config"
"github.com/frenocorp/pop/internal/auth"
)
// Code is a structured result/error code carried in ProtonMail API responses.
type Code int

// Known Code values. The numeric values are the ones the API sends on the
// wire, so they must never be renumbered; the names describe the condition
// each code is used for in this package.
const (
	SuccessCode                 Code = 1000
	MultiCode                   Code = 1001
	InvalidValue                Code = 2001
	AppVersionMissingCode       Code = 5001
	AppVersionBadCode           Code = 5003
	UsernameInvalid             Code = 6003
	PasswordWrong               Code = 8002
	HumanVerificationRequired   Code = 9001
	PaidPlanRequired            Code = 10004
	SessionExpired              Code = 10005
	TokenExpired                Code = 10006
	QuotaExceeded               Code = 10011
	AuthRefreshTokenInvalid     Code = 10013
	AccountSuspended            Code = 10050
	HumanValidationInvalidToken Code = 12087
)
// Status is the client's current view of whether the ProtonMail API is
// reachable.
type Status int

const (
	// StatusUp means the client last observed the API as reachable.
	StatusUp Status = iota
	// StatusDown means the client last observed the API as unreachable.
	StatusDown
)

// String returns "up" or "down", and "unknown" for any value outside the
// declared constants.
func (s Status) String() string {
	if s == StatusUp {
		return "up"
	}
	if s == StatusDown {
		return "down"
	}
	return "unknown"
}

// StatusObserver is a callback that receives the new Status whenever the
// client's connection status changes.
type StatusObserver func(Status)
// APIHVDetails is the decoded Details payload of an APIError whose code is
// HumanVerificationRequired.
type APIHVDetails struct {
	// Methods holds the "HumanVerificationMethods" list from the response.
	Methods []string `json:"HumanVerificationMethods"`
	// Token holds the "HumanVerificationToken" value from the response.
	Token string `json:"HumanVerificationToken"`
}
// ErrDetails holds the raw JSON of an API error's optional, request-specific
// "Details" field. It is kept undecoded so callers (e.g. GetHVDetails) can
// decode it into whatever shape the error code implies.
type ErrDetails []byte

// MarshalJSON emits the stored raw JSON unchanged. An empty value is encoded
// as null; returning zero bytes would make json.Marshal fail with
// "unexpected end of JSON input".
func (d ErrDetails) MarshalJSON() ([]byte, error) {
	if len(d) == 0 {
		return []byte("null"), nil
	}
	return d, nil
}

// UnmarshalJSON stores a private copy of data. The encoding/json contract
// says the input slice may be reused after UnmarshalJSON returns, so keeping
// a reference to it (instead of copying) would let later writes to the
// decoder's buffer silently corrupt the stored details.
func (d *ErrDetails) UnmarshalJSON(data []byte) error {
	if d == nil {
		return errors.New("api: UnmarshalJSON on nil *ErrDetails")
	}
	buf := make(ErrDetails, len(data))
	copy(buf, data)
	*d = buf
	return nil
}
// ErrNotHVError is returned (wrapped) by GetHVDetails when the APIError is
// not a human-verification error.
var ErrNotHVError = errors.New("not a human verification error")

// APIError is an error reported by the ProtonMail API in a 4xx/5xx response
// body, together with the HTTP status the response carried.
type APIError struct {
	// HTTPStatus is the HTTP status code of the response (not serialized).
	HTTPStatus int `json:"-"`
	// Code is the structured error code returned by the API.
	Code Code `json:"Code,omitempty"`
	// Message is the human-readable error message.
	Message string `json:"Message,omitempty"`
	// Details contains optional error details as raw JSON.
	Details ErrDetails `json:"Details,omitempty"`
}

// Error formats the HTTP status, the API code and the message.
func (e *APIError) Error() string {
	return fmt.Sprintf("API error %d (code=%d): %s", e.HTTPStatus, e.Code, e.Message)
}

// IsHVError reports whether the API demands human verification.
func (e *APIError) IsHVError() bool { return e.Code == HumanVerificationRequired }

// GetHVDetails decodes Details as APIHVDetails. It fails with an error
// wrapping ErrNotHVError when IsHVError is false, and with a wrapped JSON
// error when Details cannot be decoded (including when it is empty).
func (e *APIError) GetHVDetails() (*APIHVDetails, error) {
	if !e.IsHVError() {
		return nil, fmt.Errorf("not an HV error (code=%d): %w", e.Code, ErrNotHVError)
	}
	details := new(APIHVDetails)
	if err := json.Unmarshal(e.Details, details); err != nil {
		return nil, fmt.Errorf("failed to parse HV details: %w", err)
	}
	return details, nil
}
// NetError reports that the API could not be reached at the network level.
// It wraps the underlying cause, and any *NetError matches any other under
// errors.Is, so callers can write errors.Is(err, &NetError{}).
type NetError struct {
	// Cause is the underlying error that caused the network error.
	Cause error
	// Message describes the network error context.
	Message string
}

// NewNetError returns a *NetError wrapping cause with a context message.
func NewNetError(cause error, message string) *NetError {
	return &NetError{
		Message: message,
		Cause:   cause,
	}
}

// Error renders "<message>: <cause>".
func (e *NetError) Error() string {
	return fmt.Sprintf("%s: %v", e.Message, e.Cause)
}

// Unwrap exposes Cause to errors.Is and errors.As.
func (e *NetError) Unwrap() error { return e.Cause }

// Is reports whether target is a *NetError of any value, which makes
// NetError matchable by type through errors.Is.
func (e *NetError) Is(target error) bool {
	switch target.(type) {
	case *NetError:
		return true
	default:
		return false
	}
}
// RetryConfig controls how often and how long a failed request is retried.
type RetryConfig struct {
	// MaxRetries is the number of attempts made after the first one.
	MaxRetries int
	// MaxWaitTime caps the delay before a single retry.
	MaxWaitTime time.Duration
	// BaseBackoff is the starting delay for exponential backoff.
	BaseBackoff time.Duration
}

// DefaultRetryConfig returns the default policy: 3 retries, at most one
// minute of waiting before any single retry, and a 500ms base backoff.
func DefaultRetryConfig() RetryConfig {
	var rc RetryConfig
	rc.MaxRetries = 3
	rc.MaxWaitTime = time.Minute
	rc.BaseBackoff = 500 * time.Millisecond
	return rc
}
// ProtonMailClient is an HTTP client for the ProtonMail API. It adds bearer
// authentication, client-side rate limiting, retries with backoff, token
// refresh on 401 and connection-status notifications.
type ProtonMailClient struct {
	baseURL    string
	httpClient *http.Client
	config     *config.Config

	// rateLimiter throttles outgoing requests.
	rateLimiter *RateLimiter

	// authHeader is the bearer token sent with every request; guarded by authMu.
	authHeader string
	authMu     sync.RWMutex

	// sessionRefresher obtains a new token after a 401; it may be nil.
	sessionRefresher auth.SessionRefresher

	// retryConfig is read by the request path without locking.
	retryConfig RetryConfig

	// Connection status tracking: status is guarded by statusLock and the
	// observer list by statusObsMu.
	status      Status
	statusLock  sync.Mutex
	statusObs   []StatusObserver
	statusObsMu sync.RWMutex
}
// RateLimiter is a sliding-window limiter intended to allow at most limit
// requests in any window-long interval. All fields are guarded by mu.
type RateLimiter struct {
	mu sync.Mutex

	// requests holds the send times of recent requests, appended in order.
	requests []time.Time
	limit    int
	window   time.Duration
}
// NewProtonMailClient builds a client for cfg.APIBaseURL.
//
// The HTTP client timeout is cfg.TimeoutSec seconds; per net/http, zero means
// no timeout. The rate limiter is configured with cfg.RateLimitReq requests
// per cfg.RateLimitWin seconds. refresher is used to obtain a new token when
// the API answers 401. It may be nil, in which case token refreshes fail.
func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient {
	// make panics on a negative capacity. Clamp only the preallocation; the
	// configured limit itself is stored unchanged.
	capacity := cfg.RateLimitReq
	if capacity < 0 {
		capacity = 0
	}
	return &ProtonMailClient{
		baseURL:          cfg.APIBaseURL,
		httpClient:       &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second},
		config:           cfg,
		sessionRefresher: refresher,
		rateLimiter: &RateLimiter{
			requests: make([]time.Time, 0, capacity),
			limit:    cfg.RateLimitReq,
			window:   time.Duration(cfg.RateLimitWin) * time.Second,
		},
		retryConfig: DefaultRetryConfig(),
		status:      StatusUp,
	}
}
// SetRetryConfig replaces the retry policy used by subsequent requests.
//
// NOTE(review): this write is not synchronized, and the request path reads
// retryConfig without a lock. Call it before issuing requests from other
// goroutines.
func (c *ProtonMailClient) SetRetryConfig(rc RetryConfig) {
	c.retryConfig = rc
}
// SetAuthHeader stores the bearer token sent with subsequent requests.
// It is safe for concurrent use.
func (c *ProtonMailClient) SetAuthHeader(token string) {
	c.authMu.Lock()
	c.authHeader = token
	c.authMu.Unlock()
}

// getAuthHeader returns the current bearer token under the read lock.
func (c *ProtonMailClient) getAuthHeader() string {
	c.authMu.RLock()
	token := c.authHeader
	c.authMu.RUnlock()
	return token
}
// refreshAuth asks the configured SessionRefresher to refresh the token.
// It returns an error when no refresher was supplied.
func (c *ProtonMailClient) refreshAuth() error {
	if r := c.sessionRefresher; r != nil {
		return r.RefreshToken()
	}
	return fmt.Errorf("no session refresher configured")
}
// GetBaseURL reports the API base URL the client was constructed with.
func (c *ProtonMailClient) GetBaseURL() string {
	base := c.baseURL
	return base
}
// AddStatusObserver registers observer to be called whenever the connection
// status changes. It is safe for concurrent use. It must not be called from
// inside an observer callback.
func (c *ProtonMailClient) AddStatusObserver(observer StatusObserver) {
	c.statusObsMu.Lock()
	c.statusObs = append(c.statusObs, observer)
	c.statusObsMu.Unlock()
}
// GetStatus reports the most recently recorded connection status.
func (c *ProtonMailClient) GetStatus() Status {
	c.statusLock.Lock()
	current := c.status
	c.statusLock.Unlock()
	return current
}
// onConnUp marks the API as reachable. If the status was not already
// StatusUp, every registered observer is called with StatusUp.
//
// NOTE(review): observers run while statusLock and a read lock on
// statusObsMu are held. An observer that calls GetStatus or
// AddStatusObserver on this client will therefore deadlock.
func (c *ProtonMailClient) onConnUp() {
	c.statusLock.Lock()
	defer c.statusLock.Unlock()

	if c.status != StatusUp {
		c.status = StatusUp

		c.statusObsMu.RLock()
		defer c.statusObsMu.RUnlock()
		for i := range c.statusObs {
			c.statusObs[i](c.status)
		}
	}
}

// onConnDown marks the API as unreachable. If the status was not already
// StatusDown, every registered observer is called with StatusDown.
//
// NOTE(review): the same reentrancy hazard as onConnUp applies.
func (c *ProtonMailClient) onConnDown() {
	c.statusLock.Lock()
	defer c.statusLock.Unlock()

	if c.status != StatusDown {
		c.status = StatusDown

		c.statusObsMu.RLock()
		defer c.statusObsMu.RUnlock()
		for i := range c.statusObs {
			c.statusObs[i](c.status)
		}
	}
}
// Wait blocks until sending one more request would keep the number of
// requests recorded inside the sliding window below limit. A non-positive
// limit disables rate limiting.
//
// Wait only waits; it does not record the request (recordRequest does).
// The mutex is held while sleeping, so concurrent callers queue behind one
// another.
func (rl *RateLimiter) Wait() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Without this guard, an empty history would be indexed below and panic.
	if rl.limit <= 0 {
		return
	}

	now := time.Now()
	windowStart := now.Add(-rl.window)

	// Drop timestamps that fell out of the window, reusing the backing array.
	kept := rl.requests[:0]
	for _, t := range rl.requests {
		if t.After(windowStart) {
			kept = append(kept, t)
		}
	}
	rl.requests = kept

	// To get back under the limit, the oldest excess+1 entries must expire.
	// Entries are appended in time order, so the last of those is at index
	// excess.
	if excess := len(rl.requests) - rl.limit; excess >= 0 {
		if sleep := rl.requests[excess].Add(rl.window).Sub(now); sleep > 0 {
			time.Sleep(sleep)
		}
	}
}
// recordRequest appends the current time to the rate limiter's history.
// The timestamp is taken while the limiter's mutex is held, so the history
// stays in chronological order.
func (c *ProtonMailClient) recordRequest() {
	rl := c.rateLimiter
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.requests = append(rl.requests, time.Now())
}
// Do sends req using the context already attached to it (req.Context()).
// Cancellation and deadlines set via http.NewRequestWithContext are
// therefore honored. Previously, context.Background() replaced them.
func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
	return c.DoWithContext(req.Context(), req)
}
// DoWithContext sends req bound to ctx.
//
// It first blocks on the client-side rate limiter, which does not observe
// ctx. It then sets the bearer token and an "Accept: application/json"
// header on req.Header (the caller's header map) and delegates to
// executeWithRetry.
func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) {
	c.rateLimiter.Wait()

	hdr := req.Header
	hdr.Set("Authorization", "Bearer "+c.getAuthHeader())
	hdr.Set("Accept", "application/json")

	return c.executeWithRetry(ctx, req.WithContext(ctx))
}
// executeWithRetry sends req, retrying with backoff (calculateBackoff) when
// shouldRetryError accepts a transport error or shouldRetryResponse accepts a
// response status. A 401 triggers at most one session refresh per call.
// After the refresh, the request is replayed with the new bearer token, and
// any later retries in the same call also use that token.
//
// Return contract:
//   - status < 400: the live response; the caller must close resp.Body.
//   - status >= 400: resp.Body is buffered in memory and re-attached. err is
//     an *APIError when the body decodes as JSON; otherwise it is nil.
//   - retries exhausted on a retryable status: the last response (body
//     buffered) and an *APIError.
//   - transport failure, cancellation or body-read failure: nil and the error.
//   - 401 whose refresh fails: the 401 response (body already closed) and a
//     wrapped error.
func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Request) (*http.Response, error) {
	// Buffer the request body once. http.Client.Do consumes and closes it, so
	// every attempt (and any 307/308 redirect via GetBody) needs a fresh
	// reader over the same bytes.
	var bodyBytes []byte
	hasBody := req.Body != nil
	if hasBody {
		data, err := io.ReadAll(req.Body)
		req.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("reading request body: %w", err)
		}
		bodyBytes = data
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(bodyBytes)), nil
		}
	}
	rewindBody := func(r *http.Request) {
		if hasBody {
			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}
	}

	// lastResp holds the most recent retryable (429/503) response, so its
	// Retry-After header can drive the next backoff. It must be closed
	// whenever it is superseded or abandoned, so the connection returns to
	// the pool.
	var lastResp *http.Response
	var lastErr error
	discardLast := func() {
		if lastResp != nil && lastResp.Body != nil {
			lastResp.Body.Close()
		}
		lastResp = nil
	}

	// Always make at least one attempt. With a negative MaxRetries the loop
	// would never run and the function would return (nil, nil).
	maxRetries := c.retryConfig.MaxRetries
	if maxRetries < 0 {
		maxRetries = 0
	}
	refreshed := false

	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			delay := c.calculateBackoff(attempt, lastResp)
			select {
			case <-ctx.Done():
				discardLast()
				return nil, ctx.Err()
			case <-time.After(delay):
			}
		}
		rewindBody(req)

		resp, err := c.doSingleRequest(ctx, req)
		if err != nil {
			discardLast()
			lastErr = err
			if !c.shouldRetryError(err, resp) {
				c.onConnDown()
				return nil, err
			}
			continue
		}

		// Token refresh: single shot per call.
		if resp.StatusCode == http.StatusUnauthorized && !refreshed {
			refreshed = true
			resp.Body.Close()
			if err := c.refreshAuth(); err != nil {
				discardLast()
				return resp, fmt.Errorf("401 received and refresh failed: %w", err)
			}
			session, err := c.sessionRefresher.GetSession()
			if err != nil {
				discardLast()
				return resp, fmt.Errorf("401 received, refresh succeeded but failed to get new session: %w", err)
			}
			c.SetAuthHeader(session.AccessToken)

			// Replace req with a clone carrying the new token, so any further
			// retries in this loop use it too. The 401 attempt consumed the
			// body, so re-arm it on the clone.
			req = req.Clone(ctx)
			req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
			rewindBody(req)
			resp, err = c.doSingleRequest(ctx, req)
			if err != nil {
				discardLast()
				c.onConnDown()
				return nil, err
			}
		}

		if c.shouldRetryResponse(resp) {
			discardLast()
			lastResp = resp
			continue
		}
		// A terminal response supersedes any buffered retryable one.
		discardLast()

		c.onConnUp()
		c.recordRequest()

		if resp.StatusCode >= 400 {
			// Best effort: even a partially read body leaves the status usable.
			body, _ := io.ReadAll(resp.Body)
			resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(body))
			var apiErr APIError
			if err := json.Unmarshal(body, &apiErr); err == nil {
				apiErr.HTTPStatus = resp.StatusCode
				return resp, &apiErr
			}
			// Non-JSON error body: let the caller inspect resp directly.
			return resp, nil
		}
		return resp, nil
	}

	// Retries exhausted. The last attempt either failed at the transport
	// level (lastResp == nil, lastErr set) or returned a retryable status.
	if lastResp == nil {
		c.onConnDown()
		return nil, lastErr
	}

	c.onConnUp()
	c.recordRequest()
	body, _ := io.ReadAll(lastResp.Body)
	lastResp.Body.Close()
	lastResp.Body = io.NopCloser(bytes.NewReader(body))
	var apiErr APIError
	if err := json.Unmarshal(body, &apiErr); err == nil {
		apiErr.HTTPStatus = lastResp.StatusCode
		return lastResp, &apiErr
	}
	return lastResp, &APIError{
		HTTPStatus: lastResp.StatusCode,
		Code:       0,
		Message:    fmt.Sprintf("retries exhausted after %d attempts", maxRetries+1),
	}
}
// doSingleRequest performs one HTTP round trip with no retries.
//
// Transport errors that contain a *net.OpError are reported as *NetError
// (wrapping the OpError); other errors are returned unchanged. A response
// with status 0 is treated as "no response": its body is closed, the client
// is marked down and a *NetError is returned. On any error the returned
// response is nil. ctx is currently unused; the request carries its own
// context.
func (c *ProtonMailClient) doSingleRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
	resp, err := c.httpClient.Do(req)
	if err != nil {
		// net/http documents that a non-nil Response with an error only occurs
		// on CheckRedirect failure, with the body already closed. Closing again
		// is harmless and keeps this path leak-free regardless.
		if resp != nil && resp.Body != nil {
			resp.Body.Close()
		}
		var opErr *net.OpError
		if errors.As(err, &opErr) {
			return nil, NewNetError(opErr, "network error while communicating with API")
		}
		return nil, err
	}
	if resp.StatusCode == 0 {
		// The response is discarded here, so release its connection.
		if resp.Body != nil {
			resp.Body.Close()
		}
		c.onConnDown()
		return nil, NewNetError(errors.New("no response received"), "received no response from API")
	}
	return resp, nil
}
// shouldRetryError reports whether a transport error is worth retrying.
//
// Context cancellation and deadline errors are never retried. Otherwise, an
// error is retryable when a *NetError or a *net.OpError appears anywhere in
// its wrap chain. This includes the *url.Error wrapping that http.Client.Do
// applies, which the previous single-level unwrap missed. resp is unused and
// kept for call-site compatibility.
func (c *ProtonMailClient) shouldRetryError(err error, resp *http.Response) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false
	}
	var netErr *NetError
	if errors.As(err, &netErr) {
		return true
	}
	var opErr *net.OpError
	return errors.As(err, &opErr)
}
// shouldRetryResponse reports whether resp carries a status that should be
// retried: 429 Too Many Requests or 503 Service Unavailable. A nil response
// is never retried.
func (c *ProtonMailClient) shouldRetryResponse(resp *http.Response) bool {
	if resp == nil {
		return false
	}
	switch resp.StatusCode {
	case http.StatusTooManyRequests, http.StatusServiceUnavailable:
		return true
	default:
		return false
	}
}
// calculateBackoff returns how long to wait before retry number attempt
// (attempts are numbered from 1).
//
// The base delay is computed as follows:
//   - A positive Retry-After value on resp (see parseRetryAfter) is used as
//     the base delay.
//   - Otherwise the base is BaseBackoff * 2^attempt, saturating instead of
//     overflowing.
//
// Whole-second jitter in [0, 9] is then added. When MaxWaitTime is positive,
// the total (jitter included) is capped at MaxWaitTime, so it is a real upper
// bound. A non-positive MaxWaitTime disables the cap.
func (c *ProtonMailClient) calculateBackoff(attempt int, resp *http.Response) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)

	var delay time.Duration
	if resp != nil {
		delay = c.parseRetryAfter(resp)
	}

	if delay <= 0 {
		// Double step by step. A plain base*(1<<attempt) overflows, and for
		// shifts >= 64 it silently collapses to zero.
		delay = c.retryConfig.BaseBackoff
		for i := 0; i < attempt && delay > 0; i++ {
			if delay > maxDuration/2 {
				delay = maxDuration
				break
			}
			delay *= 2
		}
	}

	// Jitter spreads out clients that would otherwise retry in lockstep.
	jitter := time.Duration(c.randIntn(10)) * time.Second
	if delay > maxDuration-jitter {
		delay = maxDuration
	} else {
		delay += jitter
	}

	if limit := c.retryConfig.MaxWaitTime; limit > 0 && delay > limit {
		delay = limit
	}
	return delay
}
// randIntn returns a random integer in [0, n) drawn from crypto/rand, which
// is safe for concurrent use. A non-positive n yields 0; the previous
// version panicked with a divide-by-zero. If the random source reports an
// error, it also falls back to 0. That is acceptable for backoff jitter.
// The modulo bias from a 64-bit draw is negligible for the small n used here.
func (c *ProtonMailClient) randIntn(n int) int {
	if n <= 0 {
		return 0
	}
	var buf [8]byte
	if _, err := rand.Read(buf[:]); err != nil {
		return 0
	}
	// uint64 avoids the truncation a uint32(n) conversion had for large n.
	return int(binary.BigEndian.Uint64(buf[:]) % uint64(n))
}
// parseRetryAfter converts resp's Retry-After header into a wait duration.
//
// Both header forms are accepted:
//   - delay-seconds, e.g. "120";
//   - an HTTP-date, parsed with http.ParseTime (IMF-fixdate, RFC 850 or ANSI
//     C asctime), with a time.RFC1123 fallback for non-GMT zone names.
//
// It returns 0 in these cases:
//   - resp is nil;
//   - the header is missing or malformed;
//   - the seconds value is not positive;
//   - the date is not in the future.
//
// Absurdly large second counts saturate rather than overflow.
func (c *ProtonMailClient) parseRetryAfter(resp *http.Response) time.Duration {
	if resp == nil {
		return 0
	}
	value := resp.Header.Get("Retry-After")
	if value == "" {
		return 0
	}

	if seconds, err := strconv.Atoi(value); err == nil {
		if seconds <= 0 {
			return 0
		}
		const maxSeconds = int64(time.Duration(1<<63-1) / time.Second)
		if int64(seconds) > maxSeconds {
			return time.Duration(maxSeconds) * time.Second
		}
		return time.Duration(seconds) * time.Second
	}

	at, err := http.ParseTime(value)
	if err != nil {
		if at, err = time.Parse(time.RFC1123, value); err != nil {
			return 0
		}
	}
	if d := time.Until(at); d > 0 {
		return d
	}
	return 0
}