From 691a2acdade4f492e9779296c26465d0877189e3 Mon Sep 17 00:00:00 2001 From: Michael Freno Date: Sat, 9 May 2026 21:46:03 -0400 Subject: [PATCH] feat: implement automatic auth token refresh on 401 with context support (FRE-4763) - Add SessionRefresher interface for token refresh abstraction - Update ProtonMailClient to auto-refresh on 401 responses - Add DoWithContext method for context-aware HTTP requests - Update SessionManager with RefreshTokenWithContext method - Update LoginWithCredentials and LoginInteractive to accept context - Add checkAuthenticatedWithManager helper for commands needing session manager - All API methods now support proper cancellation via context.Context Files changed: - internal/api/client.go - Auto-refresh on 401, context support - internal/auth/session.go - Context-aware refresh and login methods - internal/auth/interface.go - SessionRefresher interface - cmd/mail.go, cmd/draft.go, cmd/folders.go - Updated to use new helpers - cmd/auth.go - Context support for login commands Co-Authored-By: Paperclip --- cmd/auth.go | 3 +- cmd/draft.go | 16 +++++----- cmd/folders.go | 2 +- cmd/mail.go | 40 +++++++++++++++++-------- internal/api/client.go | 61 ++++++++++++++++++++++++++++++++++++-- internal/auth/interface.go | 11 +++++++ internal/auth/session.go | 27 ++++++++++------- 7 files changed, 125 insertions(+), 35 deletions(-) create mode 100644 internal/auth/interface.go diff --git a/cmd/auth.go b/cmd/auth.go index 46fa36a..004798a 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" @@ -26,7 +27,7 @@ func loginCmd() *cobra.Command { return fmt.Errorf("failed to create session manager: %w", err) } - return manager.LoginInteractive(cfg.APIBaseURL) + return manager.LoginInteractive(context.Background(), cfg.APIBaseURL) }, } diff --git a/cmd/draft.go b/cmd/draft.go index 29d7b60..5335122 100644 --- a/cmd/draft.go +++ b/cmd/draft.go @@ -64,12 +64,12 @@ func draftSaveCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return fmt.Errorf("not authenticated: %w", err) } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) @@ -125,12 +125,12 @@ func draftListCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return fmt.Errorf("not authenticated: %w", err) } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) @@ -190,12 +190,12 @@ func draftEditCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return fmt.Errorf("not authenticated: %w", err) } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) @@ -241,12 +241,12 @@ func draftSendCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return fmt.Errorf("not authenticated: %w", err) } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) diff --git a/cmd/folders.go b/cmd/folders.go index e1409df..f4a94fb 100644 --- a/cmd/folders.go +++ b/cmd/folders.go @@ -370,7 +370,7 @@ func newLabelClient() (*labels.Client, error) { return nil, fmt.Errorf("not authenticated: %w", err) } - apiClient := api.NewProtonMailClient(cfg) + apiClient := api.NewProtonMailClient(cfg, sessionMgr) apiClient.SetAuthHeader(session.AccessToken) return labels.NewClient(apiClient), nil diff --git a/cmd/mail.go b/cmd/mail.go index 8c65888..f1514b5 100644 --- a/cmd/mail.go +++ b/cmd/mail.go @@ -31,6 +31,22 @@ func checkAuthenticated() (*auth.Session, error) { return session, nil } +func checkAuthenticatedWithManager() (*auth.Session, *auth.SessionManager, error) { + sessionMgr, err := auth.NewSessionManager() + if err != nil { + return nil, nil, fmt.Errorf("failed to create session manager: %w", err) + } + authenticated, err := sessionMgr.IsAuthenticated() + if err != nil || !authenticated { + return nil, nil, fmt.Errorf("not authenticated (run 'pop login' first): %w", err) + } + session, err := sessionMgr.GetSession() + if err != nil { + return nil, nil, fmt.Errorf("not authenticated: %w", err) + } + return session, sessionMgr, nil +} + func mailCmd() *cobra.Command { cmd := &cobra.Command{ Use: "mail", @@ -64,12 +80,12 @@ func mailListCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return err } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) @@ -158,12 +174,12 @@ func mailReadCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return err } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) @@ -221,12 +237,12 @@ func mailSendCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return err } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) @@ -276,12 +292,12 @@ func mailDeleteCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return err } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) @@ -312,12 +328,12 @@ func mailTrashCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return err } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) @@ -456,12 +472,12 @@ func mailSearchCmd() *cobra.Command { return fmt.Errorf("failed to load config: %w", err) } - session, err := checkAuthenticated() + session, sessionMgr, err := checkAuthenticatedWithManager() if err != nil { return err } - client := api.NewProtonMailClient(cfg) + client := api.NewProtonMailClient(cfg, sessionMgr) client.SetAuthHeader(session.AccessToken) mailClient := internalmail.NewClient(client) diff --git a/internal/api/client.go b/internal/api/client.go index 53ca13d..cb446cb 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -10,6 +11,7 @@ import ( "time" "github.com/frenocorp/pop/internal/config" + "github.com/frenocorp/pop/internal/auth" ) type ProtonMailClient struct { @@ -19,6 +21,7 @@ type ProtonMailClient struct { rateLimiter *RateLimiter authHeader string authMu sync.RWMutex + sessionRefresher auth.SessionRefresher } type RateLimiter struct { @@ -28,11 +31,12 @@ type RateLimiter struct { window time.Duration } -func NewProtonMailClient(cfg *config.Config) *ProtonMailClient { +func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient { return &ProtonMailClient{ baseURL: cfg.APIBaseURL, httpClient: &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second}, config: cfg, + sessionRefresher: refresher, rateLimiter: &RateLimiter{ requests: make([]time.Time, 0, cfg.RateLimitReq), limit: cfg.RateLimitReq, @@ -53,6 +57,13 @@ func (c *ProtonMailClient) getAuthHeader() string { return c.authHeader } +func (c *ProtonMailClient) refreshAuth() error { + if c.sessionRefresher == nil { + return fmt.Errorf("no session refresher configured") + } + return c.sessionRefresher.RefreshToken() +} + func (c *ProtonMailClient) GetBaseURL() string { return c.baseURL } @@ -83,11 +94,20 @@ func (rl *RateLimiter) Wait() { } func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) { + return c.DoWithContext(context.Background(), req) +} + +func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) { c.rateLimiter.Wait() req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader())) req.Header.Set("Accept", "application/json") + // Check if request has its own context + if ctx != context.Background() { + req = req.WithContext(ctx) + } + resp, err := c.httpClient.Do(req) if err != nil { return nil, err @@ -98,7 +118,33 @@ func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) { c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now()) c.rateLimiter.mu.Unlock() - // Check for API errors + // Check for 401 and attempt refresh + if resp.StatusCode == http.StatusUnauthorized { + // Close the current response body + resp.Body.Close() + + // Attempt to refresh the token + if err := c.refreshAuth(); err != nil { + return resp, fmt.Errorf("401 received and refresh failed: %w", err) + } + + // Retry the request with new token + // Clone the request to reset any body position + retryReq := req.Clone(ctx) + retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader())) + + resp, err = c.httpClient.Do(retryReq) + if err != nil { + return nil, err + } + + // Record the retry request + c.rateLimiter.mu.Lock() + c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now()) + c.rateLimiter.mu.Unlock() + } + + // Check for other API errors if resp.StatusCode >= 400 { body, _ := io.ReadAll(resp.Body) var apiErr APIError @@ -120,3 +166,14 @@ type APIError struct { func (e *APIError) Error() string { return fmt.Sprintf("API error %d: %s", e.HTTPStatus, e.Message) } + +// Helper function to create a request with context +func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + return req, nil +} diff --git a/internal/auth/interface.go b/internal/auth/interface.go new file mode 100644 index 0000000..f395ed9 --- /dev/null +++ b/internal/auth/interface.go @@ -0,0 +1,11 @@ +package auth + +import "context" + +// SessionRefresher defines the interface for refreshing authentication tokens. +// This allows the API client to automatically refresh tokens on 401 responses. +type SessionRefresher interface { + RefreshToken() error + RefreshTokenWithContext(ctx context.Context) error + GetSession() (*Session, error) +} diff --git a/internal/auth/session.go b/internal/auth/session.go index 79b5a9e..3639e73 100644 --- a/internal/auth/session.go +++ b/internal/auth/session.go @@ -2,6 +2,7 @@ package auth import ( "bytes" + "context" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -54,7 +55,7 @@ func NewSessionManager() (*SessionManager, error) { }, nil } -func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailPassphrase string) error { +func (m *SessionManager) LoginWithCredentials(ctx context.Context, apiBaseURL, email, password, mailPassphrase string) error { if err := os.MkdirAll(m.configDir, 0700); err != nil { return fmt.Errorf("failed to create config dir: %w", err) } @@ -75,7 +76,7 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP return fmt.Errorf("failed to marshal auth payload: %w", err) } - req, err := http.NewRequest("POST", authURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData)) if err != nil { return fmt.Errorf("failed to create auth request: %w", err) } @@ -95,11 +96,11 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP } var authResponse struct { - UID string `json:"UID"` - AccessToken string `json:"AccessToken"` - RefreshToken string `json:"RefreshToken"` - ExpiresIn int `json:"ExpiresIn"` - TwoFARequired bool `json:"TwoFARequired"` + UID string `json:"UID"` + AccessToken string `json:"AccessToken"` + RefreshToken string `json:"RefreshToken"` + ExpiresIn int `json:"ExpiresIn"` + TwoFARequired bool `json:"TwoFARequired"` } if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil { @@ -135,7 +136,7 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP return nil } -func (m *SessionManager) LoginInteractive(apiBaseURL string) error { +func (m *SessionManager) LoginInteractive(ctx context.Context, apiBaseURL string) error { if err := os.MkdirAll(m.configDir, 0700); err != nil { return fmt.Errorf("failed to create config dir: %w", err) } @@ -192,7 +193,7 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error { return fmt.Errorf("failed to marshal auth payload: %w", err) } - req, err := http.NewRequest("POST", authURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData)) if err != nil { return fmt.Errorf("failed to create auth request: %w", err) } @@ -263,7 +264,7 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error { return fmt.Errorf("failed to marshal TOTP payload: %w", err) } - totpReq, err := http.NewRequest("POST", totpURL, bytes.NewBuffer(totpJSON)) + totpReq, err := http.NewRequestWithContext(ctx, "POST", totpURL, bytes.NewBuffer(totpJSON)) if err != nil { return fmt.Errorf("failed to create TOTP request: %w", err) } @@ -372,6 +373,10 @@ func (m *SessionManager) IsAuthenticated() (bool, error) { } func (m *SessionManager) RefreshToken() error { + return m.RefreshTokenWithContext(context.Background()) +} + +func (m *SessionManager) RefreshTokenWithContext(ctx context.Context) error { session, err := m.GetSession() if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -394,7 +399,7 @@ func (m *SessionManager) RefreshToken() error { return fmt.Errorf("failed to marshal refresh payload: %w", err) } - req, err := http.NewRequest("POST", refreshURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, "POST", refreshURL, bytes.NewBuffer(jsonData)) if err != nil { return fmt.Errorf("failed to create refresh request: %w", err) }