- P1.2: Close lastResp.Body on context cancellation during retry backoff
- P1.1: Close original response body after io.ReadAll on error paths to return TCP connections to the pool
- P2.3: Close response body in doSingleRequest on error paths (http.Client.Do can return non-nil resp with non-nil err)
- P2.3: Defensive body close on auth refresh retry failure
- P2: Simplify shouldRetryError with explicit type checks
- P2: RateLimiter in-place filtering to reduce GC pressure
- P3.6: Replace math/rand with crypto/rand for thread-safe jitter
- P3.7: Add missing error code constants (SessionExpired, TokenExpired, QuotaExceeded, AccountSuspended)

Co-Authored-By: Paperclip <noreply@paperclip.ing>
582 lines
14 KiB
Go
582 lines
14 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/frenocorp/pop/internal/config"
|
|
"github.com/frenocorp/pop/internal/auth"
|
|
)
|
|
|
|
// Code is the structured, numeric result code that the ProtonMail API places
// in the "Code" field of a response body (see APIError).
type Code int

// Known API result codes. The numeric values are fixed by the remote API and
// must not be renumbered.
const (
	SuccessCode                 Code = 1000
	MultiCode                   Code = 1001
	InvalidValue                Code = 2001
	AppVersionMissingCode       Code = 5001
	AppVersionBadCode           Code = 5003
	UsernameInvalid             Code = 6003
	PasswordWrong               Code = 8002
	HumanVerificationRequired   Code = 9001
	PaidPlanRequired            Code = 10004
	SessionExpired              Code = 10005
	TokenExpired                Code = 10006
	QuotaExceeded               Code = 10011
	AuthRefreshTokenInvalid     Code = 10013
	AccountSuspended            Code = 10050
	HumanValidationInvalidToken Code = 12087
)
|
|
|
|
// Status is the client's view of whether the ProtonMail API is reachable.
type Status int

const (
	// StatusUp is the zero value of Status.
	StatusUp Status = iota
	// StatusDown is the state entered via onConnDown.
	StatusDown
)
|
|
|
|
func (s Status) String() string {
|
|
switch s {
|
|
case StatusUp:
|
|
return "up"
|
|
case StatusDown:
|
|
return "down"
|
|
default:
|
|
return "unknown"
|
|
}
|
|
}
|
|
|
|
// StatusObserver is called when the connection status changes.
|
|
type StatusObserver func(Status)
|
|
|
|
// APIHVDetails is the payload decoded from the details of a human
// verification (HV) error.
type APIHVDetails struct {
	// Methods is decoded from the "HumanVerificationMethods" JSON field.
	Methods []string `json:"HumanVerificationMethods"`
	// Token is decoded from the "HumanVerificationToken" JSON field.
	Token string `json:"HumanVerificationToken"`
}
|
|
|
|
// ErrDetails holds the raw, still-encoded JSON of the optional, request-specific
// "Details" field of an API error.
type ErrDetails []byte
|
|
|
|
func (d ErrDetails) MarshalJSON() ([]byte, error) {
|
|
return d, nil
|
|
}
|
|
|
|
func (d *ErrDetails) UnmarshalJSON(data []byte) error {
|
|
*d = data
|
|
return nil
|
|
}
|
|
|
|
// APIError represents an error returned by the ProtonMail API.
|
|
type APIError struct {
|
|
// HTTPStatus is the HTTP status code of the response.
|
|
HTTPStatus int `json:"-"`
|
|
// Code is the structured error code returned by the API.
|
|
Code Code `json:"Code,omitempty"`
|
|
// Message is the human-readable error message.
|
|
Message string `json:"Message,omitempty"`
|
|
// Details contains optional error details (serialized JSON).
|
|
Details ErrDetails `json:"Details,omitempty"`
|
|
}
|
|
|
|
func (e *APIError) Error() string {
|
|
return fmt.Sprintf("API error %d (code=%d): %s", e.HTTPStatus, e.Code, e.Message)
|
|
}
|
|
|
|
// IsHVError returns true if this error requires human verification.
|
|
func (e *APIError) IsHVError() bool {
|
|
return e.Code == HumanVerificationRequired
|
|
}
|
|
|
|
// GetHVDetails parses the Details field and returns structured HV information.
|
|
func (e *APIError) GetHVDetails() (*APIHVDetails, error) {
|
|
if !e.IsHVError() {
|
|
return nil, fmt.Errorf("not an HV error (code=%d): %w", e.Code, ErrNotHVError)
|
|
}
|
|
|
|
var details APIHVDetails
|
|
if err := json.Unmarshal(e.Details, &details); err != nil {
|
|
return nil, fmt.Errorf("failed to parse HV details: %w", err)
|
|
}
|
|
|
|
return &details, nil
|
|
}
|
|
|
|
// ErrNotHVError is the sentinel wrapped by GetHVDetails when it is invoked on
// an error that is not a human verification error. Match it with errors.Is.
var ErrNotHVError = errors.New("not a human verification error")
|
|
|
|
// NetError is the error type used for network-level failures, i.e. when the
// API could not be reached rather than answering with an error.
type NetError struct {
	// Cause is the underlying error; it is what Unwrap returns.
	Cause error
	// Message gives the context in which the network failure happened.
	Message string
}
|
|
|
|
func NewNetError(cause error, message string) *NetError {
|
|
return &NetError{Cause: cause, Message: message}
|
|
}
|
|
|
|
func (e *NetError) Error() string {
|
|
return fmt.Sprintf("%s: %v", e.Message, e.Cause)
|
|
}
|
|
|
|
func (e *NetError) Unwrap() error {
|
|
return e.Cause
|
|
}
|
|
|
|
func (e *NetError) Is(target error) bool {
|
|
_, ok := target.(*NetError)
|
|
return ok
|
|
}
|
|
|
|
// RetryConfig holds the tunables of the client's request retry logic.
type RetryConfig struct {
	// MaxRetries is the maximum number of retries made after the first
	// attempt.
	MaxRetries int
	// MaxWaitTime is the longest the client waits before a retry.
	MaxWaitTime time.Duration
	// BaseBackoff is the starting delay of the exponential backoff.
	BaseBackoff time.Duration
}
|
|
|
|
// DefaultRetryConfig returns sensible defaults matching the official library.
|
|
func DefaultRetryConfig() RetryConfig {
|
|
return RetryConfig{
|
|
MaxRetries: 3,
|
|
MaxWaitTime: time.Minute,
|
|
BaseBackoff: 500 * time.Millisecond,
|
|
}
|
|
}
|
|
|
|
type ProtonMailClient struct {
|
|
baseURL string
|
|
httpClient *http.Client
|
|
config *config.Config
|
|
rateLimiter *RateLimiter
|
|
authHeader string
|
|
authMu sync.RWMutex
|
|
sessionRefresher auth.SessionRefresher
|
|
retryConfig RetryConfig
|
|
|
|
// Connection status tracking
|
|
status Status
|
|
statusLock sync.Mutex
|
|
statusObs []StatusObserver
|
|
statusObsMu sync.RWMutex
|
|
}
|
|
|
|
// RateLimiter is a sliding-window rate limiter: it remembers the timestamps
// of recent requests and compares their count against a limit per window.
type RateLimiter struct {
	// mu guards requests.
	mu sync.Mutex
	// requests holds the timestamps of recorded requests.
	requests []time.Time
	// limit is the number of requests allowed per window.
	limit int
	// window is the length of the sliding window.
	window time.Duration
}
|
|
|
|
func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient {
|
|
return &ProtonMailClient{
|
|
baseURL: cfg.APIBaseURL,
|
|
httpClient: &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second},
|
|
config: cfg,
|
|
sessionRefresher: refresher,
|
|
rateLimiter: &RateLimiter{
|
|
requests: make([]time.Time, 0, cfg.RateLimitReq),
|
|
limit: cfg.RateLimitReq,
|
|
window: time.Duration(cfg.RateLimitWin) * time.Second,
|
|
},
|
|
retryConfig: DefaultRetryConfig(),
|
|
status: StatusUp,
|
|
}
|
|
}
|
|
|
|
// SetRetryConfig updates the retry configuration.
|
|
func (c *ProtonMailClient) SetRetryConfig(rc RetryConfig) {
|
|
c.retryConfig = rc
|
|
}
|
|
|
|
func (c *ProtonMailClient) SetAuthHeader(token string) {
|
|
c.authMu.Lock()
|
|
defer c.authMu.Unlock()
|
|
c.authHeader = token
|
|
}
|
|
|
|
func (c *ProtonMailClient) getAuthHeader() string {
|
|
c.authMu.RLock()
|
|
defer c.authMu.RUnlock()
|
|
return c.authHeader
|
|
}
|
|
|
|
func (c *ProtonMailClient) refreshAuth() error {
|
|
if c.sessionRefresher == nil {
|
|
return fmt.Errorf("no session refresher configured")
|
|
}
|
|
return c.sessionRefresher.RefreshToken()
|
|
}
|
|
|
|
func (c *ProtonMailClient) GetBaseURL() string {
|
|
return c.baseURL
|
|
}
|
|
|
|
// AddStatusObserver registers a callback for connection status changes.
|
|
func (c *ProtonMailClient) AddStatusObserver(observer StatusObserver) {
|
|
c.statusObsMu.Lock()
|
|
defer c.statusObsMu.Unlock()
|
|
c.statusObs = append(c.statusObs, observer)
|
|
}
|
|
|
|
// GetStatus returns the current connection status.
|
|
func (c *ProtonMailClient) GetStatus() Status {
|
|
c.statusLock.Lock()
|
|
defer c.statusLock.Unlock()
|
|
return c.status
|
|
}
|
|
|
|
func (c *ProtonMailClient) onConnUp() {
|
|
c.statusLock.Lock()
|
|
defer c.statusLock.Unlock()
|
|
|
|
if c.status == StatusUp {
|
|
return
|
|
}
|
|
|
|
c.status = StatusUp
|
|
|
|
c.statusObsMu.RLock()
|
|
defer c.statusObsMu.RUnlock()
|
|
for _, obs := range c.statusObs {
|
|
obs(c.status)
|
|
}
|
|
}
|
|
|
|
func (c *ProtonMailClient) onConnDown() {
|
|
c.statusLock.Lock()
|
|
defer c.statusLock.Unlock()
|
|
|
|
if c.status == StatusDown {
|
|
return
|
|
}
|
|
|
|
c.status = StatusDown
|
|
|
|
c.statusObsMu.RLock()
|
|
defer c.statusObsMu.RUnlock()
|
|
for _, obs := range c.statusObs {
|
|
obs(c.status)
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) Wait() {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
windowStart := now.Add(-rl.window)
|
|
|
|
// In-place filtering to reduce GC pressure
|
|
valid := 0
|
|
for _, t := range rl.requests {
|
|
if t.After(windowStart) {
|
|
rl.requests[valid] = t
|
|
valid++
|
|
}
|
|
}
|
|
rl.requests = rl.requests[:valid]
|
|
|
|
if len(rl.requests) >= rl.limit {
|
|
sleep := rl.requests[0].Add(rl.window).Sub(now)
|
|
if sleep > 0 {
|
|
time.Sleep(sleep)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *ProtonMailClient) recordRequest() {
|
|
c.rateLimiter.mu.Lock()
|
|
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
|
|
c.rateLimiter.mu.Unlock()
|
|
}
|
|
|
|
func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
|
|
return c.DoWithContext(context.Background(), req)
|
|
}
|
|
|
|
func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) {
|
|
c.rateLimiter.Wait()
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
|
|
req.Header.Set("Accept", "application/json")
|
|
req = req.WithContext(ctx)
|
|
|
|
resp, err := c.executeWithRetry(ctx, req)
|
|
|
|
return resp, err
|
|
}
|
|
|
|
// executeWithRetry performs the HTTP request with exponential backoff and retry logic.
|
|
func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Request) (*http.Response, error) {
|
|
var lastResp *http.Response
|
|
var lastErr error
|
|
|
|
// Capture request body once so it can be restored on retries.
|
|
var bodyBytes []byte
|
|
if req.Body != nil {
|
|
bodyBytes, _ = io.ReadAll(req.Body)
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
}
|
|
|
|
for attempt := 0; attempt <= c.retryConfig.MaxRetries; attempt++ {
|
|
// Restore body before each retry attempt
|
|
if attempt > 0 && len(bodyBytes) > 0 {
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
}
|
|
|
|
if attempt > 0 {
|
|
delay := c.calculateBackoff(attempt, lastResp)
|
|
select {
|
|
case <-ctx.Done():
|
|
if lastResp != nil && lastResp.Body != nil {
|
|
lastResp.Body.Close()
|
|
}
|
|
return lastResp, ctx.Err()
|
|
case <-time.After(delay):
|
|
}
|
|
}
|
|
|
|
resp, err := c.doSingleRequest(ctx, req)
|
|
|
|
if err != nil {
|
|
lastErr = err
|
|
lastResp = nil
|
|
|
|
if !c.shouldRetryError(err, resp) {
|
|
c.onConnDown()
|
|
return resp, err
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
// Check for 401 and attempt token refresh (single shot, no retry loop)
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
resp.Body.Close()
|
|
|
|
if err := c.refreshAuth(); err != nil {
|
|
return resp, fmt.Errorf("401 received and refresh failed: %w", err)
|
|
}
|
|
|
|
session, err := c.sessionRefresher.GetSession()
|
|
if err != nil {
|
|
return resp, fmt.Errorf("401 received, refresh succeeded but failed to get new session: %w", err)
|
|
}
|
|
c.SetAuthHeader(session.AccessToken)
|
|
|
|
// Clone request for retry with new token
|
|
retryReq := req.Clone(ctx)
|
|
retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
|
|
|
|
resp, err = c.doSingleRequest(ctx, retryReq)
|
|
if err != nil {
|
|
c.onConnDown()
|
|
if resp != nil && resp.Body != nil {
|
|
resp.Body.Close()
|
|
}
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Check if response should trigger retry
|
|
if c.shouldRetryResponse(resp) {
|
|
if lastResp != nil {
|
|
lastResp.Body.Close()
|
|
}
|
|
lastResp = resp
|
|
continue
|
|
}
|
|
|
|
c.onConnUp()
|
|
c.recordRequest()
|
|
|
|
// Check for API errors (4xx/5xx)
|
|
if resp.StatusCode >= 400 {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
var apiErr APIError
|
|
if err := json.Unmarshal(body, &apiErr); err == nil {
|
|
apiErr.HTTPStatus = resp.StatusCode
|
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
|
return resp, &apiErr
|
|
}
|
|
// Non-JSON error response: restore body and let caller handle
|
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
|
return resp, nil
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// Exhausted all retries
|
|
if lastResp != nil {
|
|
c.onConnUp()
|
|
c.recordRequest()
|
|
body, _ := io.ReadAll(lastResp.Body)
|
|
lastResp.Body.Close()
|
|
var apiErr APIError
|
|
if err := json.Unmarshal(body, &apiErr); err == nil {
|
|
apiErr.HTTPStatus = lastResp.StatusCode
|
|
lastResp.Body = io.NopCloser(bytes.NewReader(body))
|
|
return lastResp, &apiErr
|
|
}
|
|
lastResp.Body = io.NopCloser(bytes.NewReader(body))
|
|
if lastErr != nil {
|
|
return lastResp, lastErr
|
|
}
|
|
return lastResp, &APIError{
|
|
HTTPStatus: lastResp.StatusCode,
|
|
Code: 0,
|
|
Message: fmt.Sprintf("retries exhausted after %d attempts", c.retryConfig.MaxRetries+1),
|
|
}
|
|
}
|
|
|
|
c.onConnDown()
|
|
return lastResp, lastErr
|
|
}
|
|
|
|
// doSingleRequest executes a single HTTP request and tracks connection status.
|
|
func (c *ProtonMailClient) doSingleRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
|
resp, err := c.httpClient.Do(req)
|
|
|
|
if err != nil {
|
|
if resp != nil && resp.Body != nil {
|
|
resp.Body.Close()
|
|
}
|
|
// Check if it's a network-level error
|
|
if netErr := new(net.OpError); errors.As(err, &netErr) {
|
|
return nil, NewNetError(netErr, "network error while communicating with API")
|
|
}
|
|
// Check for dial/connection errors
|
|
return nil, err
|
|
}
|
|
|
|
if resp.StatusCode == 0 {
|
|
c.onConnDown()
|
|
return nil, NewNetError(errors.New("no response received"), "received no response from API")
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// shouldRetryError determines if an error condition warrants a retry.
|
|
func (c *ProtonMailClient) shouldRetryError(err error, resp *http.Response) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
|
|
// Context errors are not retryable
|
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
return false
|
|
}
|
|
|
|
// Network errors (NetError wraps net.OpError) are retryable
|
|
if _, ok := errors.Unwrap(err).(*NetError); ok {
|
|
return true
|
|
}
|
|
if _, ok := err.(*NetError); ok {
|
|
return true
|
|
}
|
|
|
|
// Raw net.OpError from http.Client.Do are retryable
|
|
if _, ok := err.(*net.OpError); ok {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// shouldRetryResponse determines if a response status warrants a retry.
|
|
func (c *ProtonMailClient) shouldRetryResponse(resp *http.Response) bool {
|
|
if resp == nil {
|
|
return false
|
|
}
|
|
|
|
return resp.StatusCode == http.StatusTooManyRequests ||
|
|
resp.StatusCode == http.StatusServiceUnavailable
|
|
}
|
|
|
|
// calculateBackoff computes the retry delay using exponential backoff with jitter.
|
|
// If the response contains a Retry-After header, that value is used as the base.
|
|
func (c *ProtonMailClient) calculateBackoff(attempt int, resp *http.Response) time.Duration {
|
|
var delay time.Duration
|
|
|
|
// Check for Retry-After header first
|
|
if resp != nil {
|
|
retryAfter := c.parseRetryAfter(resp)
|
|
if retryAfter > 0 {
|
|
delay = retryAfter
|
|
}
|
|
}
|
|
|
|
// Fall back to exponential backoff
|
|
if delay == 0 {
|
|
base := c.retryConfig.BaseBackoff
|
|
delay = base * (1 << uint(attempt)) // Exponential: 0.5s, 1s, 2s, ...
|
|
}
|
|
|
|
// Cap at max wait time
|
|
if delay > c.retryConfig.MaxWaitTime {
|
|
delay = c.retryConfig.MaxWaitTime
|
|
}
|
|
|
|
// Add jitter (0-10 seconds) to avoid thundering herd
|
|
jitter := time.Duration(c.randIntn(10)) * time.Second
|
|
delay += jitter
|
|
|
|
return delay
|
|
}
|
|
|
|
// randIntn returns a thread-safe random integer in [0, n) using crypto/rand.
|
|
func (c *ProtonMailClient) randIntn(n int) int {
|
|
b := make([]byte, 4)
|
|
_, _ = rand.Read(b)
|
|
return int(binary.BigEndian.Uint32(b) % uint32(n))
|
|
}
|
|
|
|
// parseRetryAfter parses the Retry-After header and returns the duration.
|
|
// Returns 0 if the header is missing or invalid.
|
|
func (c *ProtonMailClient) parseRetryAfter(resp *http.Response) time.Duration {
|
|
retryAfterStr := resp.Header.Get("Retry-After")
|
|
if retryAfterStr == "" {
|
|
return 0
|
|
}
|
|
|
|
// Try parsing as seconds (integer)
|
|
seconds, err := strconv.Atoi(retryAfterStr)
|
|
if err != nil {
|
|
// Try parsing as HTTP date
|
|
t, err := time.Parse(time.RFC1123, retryAfterStr)
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
delay := t.Sub(time.Now())
|
|
if delay < 0 {
|
|
delay = 0
|
|
}
|
|
return delay
|
|
}
|
|
|
|
return time.Duration(seconds) * time.Second
|
|
}
|