feat: implement automatic auth token refresh on 401 with context support (FRE-4763)

- Add SessionRefresher interface for token refresh abstraction
- Update ProtonMailClient to auto-refresh on 401 responses
- Add DoWithContext method for context-aware HTTP requests
- Update SessionManager with RefreshTokenWithContext method
- Update LoginWithCredentials and LoginInteractive to accept context
- Add checkAuthenticatedWithManager helper for commands needing session manager
- All API methods now support proper cancellation via context.Context

Files changed:
- internal/api/client.go - Auto-refresh on 401, context support
- internal/auth/session.go - Context-aware refresh and login methods
- internal/auth/interface.go - SessionRefresher interface
- cmd/mail.go, cmd/draft.go, cmd/folders.go - Updated to use new helpers
- cmd/auth.go - Context support for login commands

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
2026-05-09 21:46:03 -04:00
parent 19a9e2a3df
commit 691a2acdad
7 changed files with 125 additions and 35 deletions

View File

@@ -2,6 +2,7 @@ package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -10,6 +11,7 @@ import (
"time"
"github.com/frenocorp/pop/internal/config"
"github.com/frenocorp/pop/internal/auth"
)
type ProtonMailClient struct {
@@ -19,6 +21,7 @@ type ProtonMailClient struct {
rateLimiter *RateLimiter
authHeader string
authMu sync.RWMutex
sessionRefresher auth.SessionRefresher
}
type RateLimiter struct {
@@ -28,11 +31,12 @@ type RateLimiter struct {
window time.Duration
}
func NewProtonMailClient(cfg *config.Config) *ProtonMailClient {
func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient {
return &ProtonMailClient{
baseURL: cfg.APIBaseURL,
httpClient: &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second},
config: cfg,
sessionRefresher: refresher,
rateLimiter: &RateLimiter{
requests: make([]time.Time, 0, cfg.RateLimitReq),
limit: cfg.RateLimitReq,
@@ -53,6 +57,13 @@ func (c *ProtonMailClient) getAuthHeader() string {
return c.authHeader
}
func (c *ProtonMailClient) refreshAuth() error {
if c.sessionRefresher == nil {
return fmt.Errorf("no session refresher configured")
}
return c.sessionRefresher.RefreshToken()
}
func (c *ProtonMailClient) GetBaseURL() string {
return c.baseURL
}
@@ -83,11 +94,20 @@ func (rl *RateLimiter) Wait() {
}
func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
return c.DoWithContext(context.Background(), req)
}
func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) {
c.rateLimiter.Wait()
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
req.Header.Set("Accept", "application/json")
// Check if request has its own context
if ctx != context.Background() {
req = req.WithContext(ctx)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
@@ -98,7 +118,33 @@ func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
c.rateLimiter.mu.Unlock()
// Check for API errors
// Check for 401 and attempt refresh
if resp.StatusCode == http.StatusUnauthorized {
// Close the current response body
resp.Body.Close()
// Attempt to refresh the token
if err := c.refreshAuth(); err != nil {
return resp, fmt.Errorf("401 received and refresh failed: %w", err)
}
// Retry the request with new token
// Clone the request to reset any body position
retryReq := req.Clone(ctx)
retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
resp, err = c.httpClient.Do(retryReq)
if err != nil {
return nil, err
}
// Record the retry request
c.rateLimiter.mu.Lock()
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
c.rateLimiter.mu.Unlock()
}
// Check for other API errors
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
var apiErr APIError
@@ -120,3 +166,14 @@ type APIError struct {
func (e *APIError) Error() string {
return fmt.Sprintf("API error %d: %s", e.HTTPStatus, e.Message)
}
// Helper function to create a request with context
func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
return req, nil
}

View File

@@ -0,0 +1,11 @@
package auth
import "context"
// SessionRefresher defines the interface for refreshing authentication tokens.
// This allows the API client to automatically refresh tokens on 401 responses.
type SessionRefresher interface {
RefreshToken() error
RefreshTokenWithContext(ctx context.Context) error
GetSession() (*Session, error)
}

View File

@@ -2,6 +2,7 @@ package auth
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
@@ -54,7 +55,7 @@ func NewSessionManager() (*SessionManager, error) {
}, nil
}
func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailPassphrase string) error {
func (m *SessionManager) LoginWithCredentials(ctx context.Context, apiBaseURL, email, password, mailPassphrase string) error {
if err := os.MkdirAll(m.configDir, 0700); err != nil {
return fmt.Errorf("failed to create config dir: %w", err)
}
@@ -75,7 +76,7 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
return fmt.Errorf("failed to marshal auth payload: %w", err)
}
req, err := http.NewRequest("POST", authURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create auth request: %w", err)
}
@@ -95,11 +96,11 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
}
var authResponse struct {
UID string `json:"UID"`
AccessToken string `json:"AccessToken"`
RefreshToken string `json:"RefreshToken"`
ExpiresIn int `json:"ExpiresIn"`
TwoFARequired bool `json:"TwoFARequired"`
UID string `json:"UID"`
AccessToken string `json:"AccessToken"`
RefreshToken string `json:"RefreshToken"`
ExpiresIn int `json:"ExpiresIn"`
TwoFARequired bool `json:"TwoFARequired"`
}
if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil {
@@ -135,7 +136,7 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
return nil
}
func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
func (m *SessionManager) LoginInteractive(ctx context.Context, apiBaseURL string) error {
if err := os.MkdirAll(m.configDir, 0700); err != nil {
return fmt.Errorf("failed to create config dir: %w", err)
}
@@ -192,7 +193,7 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
return fmt.Errorf("failed to marshal auth payload: %w", err)
}
req, err := http.NewRequest("POST", authURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create auth request: %w", err)
}
@@ -263,7 +264,7 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
return fmt.Errorf("failed to marshal TOTP payload: %w", err)
}
totpReq, err := http.NewRequest("POST", totpURL, bytes.NewBuffer(totpJSON))
totpReq, err := http.NewRequestWithContext(ctx, "POST", totpURL, bytes.NewBuffer(totpJSON))
if err != nil {
return fmt.Errorf("failed to create TOTP request: %w", err)
}
@@ -372,6 +373,10 @@ func (m *SessionManager) IsAuthenticated() (bool, error) {
}
func (m *SessionManager) RefreshToken() error {
return m.RefreshTokenWithContext(context.Background())
}
func (m *SessionManager) RefreshTokenWithContext(ctx context.Context) error {
session, err := m.GetSession()
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -394,7 +399,7 @@ func (m *SessionManager) RefreshToken() error {
return fmt.Errorf("failed to marshal refresh payload: %w", err)
}
req, err := http.NewRequest("POST", refreshURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, "POST", refreshURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create refresh request: %w", err)
}