package api import ( "bytes" "context" "crypto/rand" "encoding/binary" "encoding/json" "errors" "fmt" "io" "net" "net/http" "strconv" "sync" "time" "github.com/frenocorp/pop/internal/config" "github.com/frenocorp/pop/internal/auth" ) // Code represents a structured API error code returned by ProtonMail. type Code int const ( SuccessCode Code = 1000 MultiCode Code = 1001 InvalidValue Code = 2001 AppVersionMissingCode Code = 5001 AppVersionBadCode Code = 5003 UsernameInvalid Code = 6003 PasswordWrong Code = 8002 HumanVerificationRequired Code = 9001 PaidPlanRequired Code = 10004 SessionExpired Code = 10005 TokenExpired Code = 10006 QuotaExceeded Code = 10011 AuthRefreshTokenInvalid Code = 10013 AccountSuspended Code = 10050 HumanValidationInvalidToken Code = 12087 ) // Status tracks the connection state to the ProtonMail API. type Status int const ( StatusUp Status = iota StatusDown ) func (s Status) String() string { switch s { case StatusUp: return "up" case StatusDown: return "down" default: return "unknown" } } // StatusObserver is called when the connection status changes. type StatusObserver func(Status) // APIHVDetails contains information related to human verification requests. type APIHVDetails struct { Methods []string `json:"HumanVerificationMethods"` Token string `json:"HumanVerificationToken"` } // ErrDetails contains optional error details which are specific to each request. type ErrDetails []byte func (d ErrDetails) MarshalJSON() ([]byte, error) { return d, nil } func (d *ErrDetails) UnmarshalJSON(data []byte) error { *d = data return nil } // APIError represents an error returned by the ProtonMail API. type APIError struct { // HTTPStatus is the HTTP status code of the response. HTTPStatus int `json:"-"` // Code is the structured error code returned by the API. Code Code `json:"Code,omitempty"` // Message is the human-readable error message. Message string `json:"Message,omitempty"` // Details contains optional error details (serialized JSON). Details ErrDetails `json:"Details,omitempty"` } func (e *APIError) Error() string { return fmt.Sprintf("API error %d (code=%d): %s", e.HTTPStatus, e.Code, e.Message) } // IsHVError returns true if this error requires human verification. func (e *APIError) IsHVError() bool { return e.Code == HumanVerificationRequired } // GetHVDetails parses the Details field and returns structured HV information. func (e *APIError) GetHVDetails() (*APIHVDetails, error) { if !e.IsHVError() { return nil, fmt.Errorf("not an HV error (code=%d): %w", e.Code, ErrNotHVError) } var details APIHVDetails if err := json.Unmarshal(e.Details, &details); err != nil { return nil, fmt.Errorf("failed to parse HV details: %w", err) } return &details, nil } // ErrNotHVError is returned when GetHVDetails is called on a non-HV error. var ErrNotHVError = errors.New("not a human verification error") // NetError represents a network-level error when the API is unreachable. type NetError struct { // Cause is the underlying error that caused the network error. Cause error // Message describes the network error context. Message string } func NewNetError(cause error, message string) *NetError { return &NetError{Cause: cause, Message: message} } func (e *NetError) Error() string { return fmt.Sprintf("%s: %v", e.Message, e.Cause) } func (e *NetError) Unwrap() error { return e.Cause } func (e *NetError) Is(target error) bool { _, ok := target.(*NetError) return ok } // RetryConfig configures the retry behavior for API requests. type RetryConfig struct { // MaxRetries is the maximum number of retry attempts. MaxRetries int // MaxWaitTime is the maximum time to wait before a retry. MaxWaitTime time.Duration // BaseBackoff is the base delay for exponential backoff. BaseBackoff time.Duration } // DefaultRetryConfig returns sensible defaults matching the official library. func DefaultRetryConfig() RetryConfig { return RetryConfig{ MaxRetries: 3, MaxWaitTime: time.Minute, BaseBackoff: 500 * time.Millisecond, } } type ProtonMailClient struct { baseURL string httpClient *http.Client config *config.Config rateLimiter *RateLimiter authHeader string authMu sync.RWMutex sessionRefresher auth.SessionRefresher retryConfig RetryConfig // Connection status tracking status Status statusLock sync.Mutex statusObs []StatusObserver statusObsMu sync.RWMutex } // RateLimiter implements a sliding window rate limiter. type RateLimiter struct { mu sync.Mutex requests []time.Time limit int window time.Duration } func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient { return &ProtonMailClient{ baseURL: cfg.APIBaseURL, httpClient: &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second}, config: cfg, sessionRefresher: refresher, rateLimiter: &RateLimiter{ requests: make([]time.Time, 0, cfg.RateLimitReq), limit: cfg.RateLimitReq, window: time.Duration(cfg.RateLimitWin) * time.Second, }, retryConfig: DefaultRetryConfig(), status: StatusUp, } } // SetRetryConfig updates the retry configuration. func (c *ProtonMailClient) SetRetryConfig(rc RetryConfig) { c.retryConfig = rc } func (c *ProtonMailClient) SetAuthHeader(token string) { c.authMu.Lock() defer c.authMu.Unlock() c.authHeader = token } func (c *ProtonMailClient) getAuthHeader() string { c.authMu.RLock() defer c.authMu.RUnlock() return c.authHeader } func (c *ProtonMailClient) refreshAuth() error { if c.sessionRefresher == nil { return fmt.Errorf("no session refresher configured") } return c.sessionRefresher.RefreshToken() } func (c *ProtonMailClient) GetBaseURL() string { return c.baseURL } // AddStatusObserver registers a callback for connection status changes. func (c *ProtonMailClient) AddStatusObserver(observer StatusObserver) { c.statusObsMu.Lock() defer c.statusObsMu.Unlock() c.statusObs = append(c.statusObs, observer) } // GetStatus returns the current connection status. func (c *ProtonMailClient) GetStatus() Status { c.statusLock.Lock() defer c.statusLock.Unlock() return c.status } func (c *ProtonMailClient) onConnUp() { c.statusLock.Lock() defer c.statusLock.Unlock() if c.status == StatusUp { return } c.status = StatusUp c.statusObsMu.RLock() defer c.statusObsMu.RUnlock() for _, obs := range c.statusObs { obs(c.status) } } func (c *ProtonMailClient) onConnDown() { c.statusLock.Lock() defer c.statusLock.Unlock() if c.status == StatusDown { return } c.status = StatusDown c.statusObsMu.RLock() defer c.statusObsMu.RUnlock() for _, obs := range c.statusObs { obs(c.status) } } func (rl *RateLimiter) Wait() { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() windowStart := now.Add(-rl.window) // In-place filtering to reduce GC pressure valid := 0 for _, t := range rl.requests { if t.After(windowStart) { rl.requests[valid] = t valid++ } } rl.requests = rl.requests[:valid] if len(rl.requests) >= rl.limit { sleep := rl.requests[0].Add(rl.window).Sub(now) if sleep > 0 { time.Sleep(sleep) } } } func (c *ProtonMailClient) recordRequest() { c.rateLimiter.mu.Lock() c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now()) c.rateLimiter.mu.Unlock() } func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) { return c.DoWithContext(context.Background(), req) } func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) { c.rateLimiter.Wait() req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader())) req.Header.Set("Accept", "application/json") req = req.WithContext(ctx) resp, err := c.executeWithRetry(ctx, req) return resp, err } // executeWithRetry performs the HTTP request with exponential backoff and retry logic. func (c *ProtonMailClient) executeWithRetry(ctx context.Context, req *http.Request) (*http.Response, error) { var lastResp *http.Response var lastErr error // Capture request body once so it can be restored on retries. var bodyBytes []byte if req.Body != nil { bodyBytes, _ = io.ReadAll(req.Body) req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } for attempt := 0; attempt <= c.retryConfig.MaxRetries; attempt++ { // Restore body before each retry attempt if attempt > 0 && len(bodyBytes) > 0 { req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) } if attempt > 0 { delay := c.calculateBackoff(attempt, lastResp) select { case <-ctx.Done(): if lastResp != nil && lastResp.Body != nil { lastResp.Body.Close() } return lastResp, ctx.Err() case <-time.After(delay): } } resp, err := c.doSingleRequest(ctx, req) if err != nil { lastErr = err lastResp = nil if !c.shouldRetryError(err, resp) { c.onConnDown() return resp, err } continue } // Check for 401 and attempt token refresh (single shot, no retry loop) if resp.StatusCode == http.StatusUnauthorized { resp.Body.Close() if err := c.refreshAuth(); err != nil { return resp, fmt.Errorf("401 received and refresh failed: %w", err) } session, err := c.sessionRefresher.GetSession() if err != nil { return resp, fmt.Errorf("401 received, refresh succeeded but failed to get new session: %w", err) } c.SetAuthHeader(session.AccessToken) // Clone request for retry with new token retryReq := req.Clone(ctx) retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader())) resp, err = c.doSingleRequest(ctx, retryReq) if err != nil { c.onConnDown() if resp != nil && resp.Body != nil { resp.Body.Close() } return nil, err } } // Check if response should trigger retry if c.shouldRetryResponse(resp) { if lastResp != nil { lastResp.Body.Close() } lastResp = resp continue } c.onConnUp() c.recordRequest() // Check for API errors (4xx/5xx) if resp.StatusCode >= 400 { body, _ := io.ReadAll(resp.Body) resp.Body.Close() var apiErr APIError if err := json.Unmarshal(body, &apiErr); err == nil { apiErr.HTTPStatus = resp.StatusCode resp.Body = io.NopCloser(bytes.NewReader(body)) return resp, &apiErr } // Non-JSON error response: restore body and let caller handle resp.Body = io.NopCloser(bytes.NewReader(body)) return resp, nil } return resp, nil } // Exhausted all retries if lastResp != nil { c.onConnUp() c.recordRequest() body, _ := io.ReadAll(lastResp.Body) lastResp.Body.Close() var apiErr APIError if err := json.Unmarshal(body, &apiErr); err == nil { apiErr.HTTPStatus = lastResp.StatusCode lastResp.Body = io.NopCloser(bytes.NewReader(body)) return lastResp, &apiErr } lastResp.Body = io.NopCloser(bytes.NewReader(body)) if lastErr != nil { return lastResp, lastErr } return lastResp, &APIError{ HTTPStatus: lastResp.StatusCode, Code: 0, Message: fmt.Sprintf("retries exhausted after %d attempts", c.retryConfig.MaxRetries+1), } } c.onConnDown() return lastResp, lastErr } // doSingleRequest executes a single HTTP request and tracks connection status. func (c *ProtonMailClient) doSingleRequest(ctx context.Context, req *http.Request) (*http.Response, error) { resp, err := c.httpClient.Do(req) if err != nil { if resp != nil && resp.Body != nil { resp.Body.Close() } // Check if it's a network-level error if netErr := new(net.OpError); errors.As(err, &netErr) { return nil, NewNetError(netErr, "network error while communicating with API") } // Check for dial/connection errors return nil, err } if resp.StatusCode == 0 { c.onConnDown() return nil, NewNetError(errors.New("no response received"), "received no response from API") } return resp, nil } // shouldRetryError determines if an error condition warrants a retry. func (c *ProtonMailClient) shouldRetryError(err error, resp *http.Response) bool { if err == nil { return false } // Context errors are not retryable if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } // Network errors (NetError wraps net.OpError) are retryable if _, ok := errors.Unwrap(err).(*NetError); ok { return true } if _, ok := err.(*NetError); ok { return true } // Raw net.OpError from http.Client.Do are retryable if _, ok := err.(*net.OpError); ok { return true } return false } // shouldRetryResponse determines if a response status warrants a retry. func (c *ProtonMailClient) shouldRetryResponse(resp *http.Response) bool { if resp == nil { return false } return resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable } // calculateBackoff computes the retry delay using exponential backoff with jitter. // If the response contains a Retry-After header, that value is used as the base. func (c *ProtonMailClient) calculateBackoff(attempt int, resp *http.Response) time.Duration { var delay time.Duration // Check for Retry-After header first if resp != nil { retryAfter := c.parseRetryAfter(resp) if retryAfter > 0 { delay = retryAfter } } // Fall back to exponential backoff if delay == 0 { base := c.retryConfig.BaseBackoff delay = base * (1 << uint(attempt)) // Exponential: 0.5s, 1s, 2s, ... } // Cap at max wait time if delay > c.retryConfig.MaxWaitTime { delay = c.retryConfig.MaxWaitTime } // Add jitter (0-10 seconds) to avoid thundering herd jitter := time.Duration(c.randIntn(10)) * time.Second delay += jitter return delay } // randIntn returns a thread-safe random integer in [0, n) using crypto/rand. func (c *ProtonMailClient) randIntn(n int) int { b := make([]byte, 4) _, _ = rand.Read(b) return int(binary.BigEndian.Uint32(b) % uint32(n)) } // parseRetryAfter parses the Retry-After header and returns the duration. // Returns 0 if the header is missing or invalid. func (c *ProtonMailClient) parseRetryAfter(resp *http.Response) time.Duration { retryAfterStr := resp.Header.Get("Retry-After") if retryAfterStr == "" { return 0 } // Try parsing as seconds (integer) seconds, err := strconv.Atoi(retryAfterStr) if err != nil { // Try parsing as HTTP date t, err := time.Parse(time.RFC1123, retryAfterStr) if err != nil { return 0 } delay := t.Sub(time.Now()) if delay < 0 { delay = 0 } return delay } return time.Duration(seconds) * time.Second }