feat: implement automatic auth token refresh on 401 with context support (FRE-4763)
- Add SessionRefresher interface for token refresh abstraction - Update ProtonMailClient to auto-refresh on 401 responses - Add DoWithContext method for context-aware HTTP requests - Update SessionManager with RefreshTokenWithContext method - Update LoginWithCredentials and LoginInteractive to accept context - Add checkAuthenticatedWithManager helper for commands needing session manager - All API methods now support proper cancellation via context.Context Files changed: - internal/api/client.go - Auto-refresh on 401, context support - internal/auth/session.go - Context-aware refresh and login methods - internal/auth/interface.go - SessionRefresher interface - cmd/mail.go, cmd/draft.go, cmd/folders.go - Updated to use new helpers - cmd/auth.go - Context support for login commands Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/frenocorp/pop/internal/config"
|
||||
"github.com/frenocorp/pop/internal/auth"
|
||||
)
|
||||
|
||||
type ProtonMailClient struct {
|
||||
@@ -19,6 +21,7 @@ type ProtonMailClient struct {
|
||||
rateLimiter *RateLimiter
|
||||
authHeader string
|
||||
authMu sync.RWMutex
|
||||
sessionRefresher auth.SessionRefresher
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
@@ -28,11 +31,12 @@ type RateLimiter struct {
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
func NewProtonMailClient(cfg *config.Config) *ProtonMailClient {
|
||||
func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient {
|
||||
return &ProtonMailClient{
|
||||
baseURL: cfg.APIBaseURL,
|
||||
httpClient: &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second},
|
||||
config: cfg,
|
||||
sessionRefresher: refresher,
|
||||
rateLimiter: &RateLimiter{
|
||||
requests: make([]time.Time, 0, cfg.RateLimitReq),
|
||||
limit: cfg.RateLimitReq,
|
||||
@@ -53,6 +57,13 @@ func (c *ProtonMailClient) getAuthHeader() string {
|
||||
return c.authHeader
|
||||
}
|
||||
|
||||
func (c *ProtonMailClient) refreshAuth() error {
|
||||
if c.sessionRefresher == nil {
|
||||
return fmt.Errorf("no session refresher configured")
|
||||
}
|
||||
return c.sessionRefresher.RefreshToken()
|
||||
}
|
||||
|
||||
func (c *ProtonMailClient) GetBaseURL() string {
|
||||
return c.baseURL
|
||||
}
|
||||
@@ -83,11 +94,20 @@ func (rl *RateLimiter) Wait() {
|
||||
}
|
||||
|
||||
func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
|
||||
return c.DoWithContext(context.Background(), req)
|
||||
}
|
||||
|
||||
func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
c.rateLimiter.Wait()
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// Check if request has its own context
|
||||
if ctx != context.Background() {
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -98,7 +118,33 @@ func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
|
||||
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
|
||||
c.rateLimiter.mu.Unlock()
|
||||
|
||||
// Check for API errors
|
||||
// Check for 401 and attempt refresh
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
// Close the current response body
|
||||
resp.Body.Close()
|
||||
|
||||
// Attempt to refresh the token
|
||||
if err := c.refreshAuth(); err != nil {
|
||||
return resp, fmt.Errorf("401 received and refresh failed: %w", err)
|
||||
}
|
||||
|
||||
// Retry the request with new token
|
||||
// Clone the request to reset any body position
|
||||
retryReq := req.Clone(ctx)
|
||||
retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
|
||||
|
||||
resp, err = c.httpClient.Do(retryReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Record the retry request
|
||||
c.rateLimiter.mu.Lock()
|
||||
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
|
||||
c.rateLimiter.mu.Unlock()
|
||||
}
|
||||
|
||||
// Check for other API errors
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var apiErr APIError
|
||||
@@ -120,3 +166,14 @@ type APIError struct {
|
||||
func (e *APIError) Error() string {
|
||||
return fmt.Sprintf("API error %d: %s", e.HTTPStatus, e.Message)
|
||||
}
|
||||
|
||||
// Helper function to create a request with context
|
||||
func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user