feat: implement automatic auth token refresh on 401 with context support (FRE-4763)

- Add SessionRefresher interface for token refresh abstraction
- Update ProtonMailClient to auto-refresh on 401 responses
- Add DoWithContext method for context-aware HTTP requests
- Update SessionManager with RefreshTokenWithContext method
- Update LoginWithCredentials and LoginInteractive to accept context
- Add checkAuthenticatedWithManager helper for commands needing session manager
- All API methods now support proper cancellation via context.Context

Files changed:
- internal/api/client.go - Auto-refresh on 401, context support
- internal/auth/session.go - Context-aware refresh and login methods
- internal/auth/interface.go - SessionRefresher interface
- cmd/mail.go, cmd/draft.go, cmd/folders.go - Updated to use new helpers
- cmd/auth.go - Context support for login commands

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
2026-05-09 21:46:03 -04:00
parent 19a9e2a3df
commit 691a2acdad
7 changed files with 125 additions and 35 deletions

View File

@@ -0,0 +1,11 @@
package auth
import "context"
// SessionRefresher defines the interface for refreshing authentication tokens.
// This allows the API client to automatically refresh tokens on 401 responses.
type SessionRefresher interface {
RefreshToken() error
RefreshTokenWithContext(ctx context.Context) error
GetSession() (*Session, error)
}

View File

@@ -2,6 +2,7 @@ package auth
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
@@ -54,7 +55,7 @@ func NewSessionManager() (*SessionManager, error) {
}, nil
}
func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailPassphrase string) error {
func (m *SessionManager) LoginWithCredentials(ctx context.Context, apiBaseURL, email, password, mailPassphrase string) error {
if err := os.MkdirAll(m.configDir, 0700); err != nil {
return fmt.Errorf("failed to create config dir: %w", err)
}
@@ -75,7 +76,7 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
return fmt.Errorf("failed to marshal auth payload: %w", err)
}
req, err := http.NewRequest("POST", authURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create auth request: %w", err)
}
@@ -95,11 +96,11 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
}
var authResponse struct {
UID string `json:"UID"`
AccessToken string `json:"AccessToken"`
RefreshToken string `json:"RefreshToken"`
ExpiresIn int `json:"ExpiresIn"`
TwoFARequired bool `json:"TwoFARequired"`
UID string `json:"UID"`
AccessToken string `json:"AccessToken"`
RefreshToken string `json:"RefreshToken"`
ExpiresIn int `json:"ExpiresIn"`
TwoFARequired bool `json:"TwoFARequired"`
}
if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil {
@@ -135,7 +136,7 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
return nil
}
func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
func (m *SessionManager) LoginInteractive(ctx context.Context, apiBaseURL string) error {
if err := os.MkdirAll(m.configDir, 0700); err != nil {
return fmt.Errorf("failed to create config dir: %w", err)
}
@@ -192,7 +193,7 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
return fmt.Errorf("failed to marshal auth payload: %w", err)
}
req, err := http.NewRequest("POST", authURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create auth request: %w", err)
}
@@ -263,7 +264,7 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
return fmt.Errorf("failed to marshal TOTP payload: %w", err)
}
totpReq, err := http.NewRequest("POST", totpURL, bytes.NewBuffer(totpJSON))
totpReq, err := http.NewRequestWithContext(ctx, "POST", totpURL, bytes.NewBuffer(totpJSON))
if err != nil {
return fmt.Errorf("failed to create TOTP request: %w", err)
}
@@ -372,6 +373,10 @@ func (m *SessionManager) IsAuthenticated() (bool, error) {
}
func (m *SessionManager) RefreshToken() error {
return m.RefreshTokenWithContext(context.Background())
}
func (m *SessionManager) RefreshTokenWithContext(ctx context.Context) error {
session, err := m.GetSession()
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -394,7 +399,7 @@ func (m *SessionManager) RefreshToken() error {
return fmt.Errorf("failed to marshal refresh payload: %w", err)
}
req, err := http.NewRequest("POST", refreshURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, "POST", refreshURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create refresh request: %w", err)
}