- Add SessionRefresher interface for token refresh abstraction
- Update ProtonMailClient to auto-refresh on 401 responses
- Add DoWithContext method for context-aware HTTP requests
- Update SessionManager with RefreshTokenWithContext method
- Update LoginWithCredentials and LoginInteractive to accept context
- Add checkAuthenticatedWithManager helper for commands needing a session manager
- All API methods now support proper cancellation via context.Context

Files changed:
- internal/api/client.go - Auto-refresh on 401, context support
- internal/auth/session.go - Context-aware refresh and login methods
- internal/auth/interface.go - SessionRefresher interface
- cmd/mail.go, cmd/draft.go, cmd/folders.go - Updated to use new helpers
- cmd/auth.go - Context support for login commands

Co-Authored-By: Paperclip <noreply@paperclip.ing>
180 lines
4.3 KiB
Go
180 lines
4.3 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/frenocorp/pop/internal/config"
|
|
"github.com/frenocorp/pop/internal/auth"
|
|
)
|
|
|
|
// ProtonMailClient sends HTTP requests to the API rooted at baseURL.
// Requests made through Do/DoWithContext are throttled by rateLimiter and
// carry the current token as a bearer Authorization header. A 401 response
// triggers refreshAuth (backed by sessionRefresher) followed by a single retry.
//
// It contains a sync.RWMutex field, so it must not be copied after use;
// always handle it as *ProtonMailClient.
type ProtonMailClient struct {
	baseURL string

	// httpClient performs the actual round trips; NewProtonMailClient sets
	// its Timeout from configuration.
	httpClient *http.Client

	config *config.Config

	// rateLimiter throttles outgoing requests: DoWithContext calls its Wait
	// and appends a timestamp to it after each response.
	rateLimiter *RateLimiter

	// authHeader is the raw token without the "Bearer " prefix (DoWithContext
	// adds the prefix). Guarded by authMu; use SetAuthHeader/getAuthHeader.
	authHeader string

	authMu sync.RWMutex

	// sessionRefresher is called by refreshAuth when the API returns 401.
	// When nil, refreshAuth returns an error and the request is not retried.
	sessionRefresher auth.SessionRefresher
}
|
|
|
|
// RateLimiter throttles requests with a sliding window. Its Wait method
// delays the caller when limit or more recorded timestamps fall within the
// trailing window. Wait only reads and prunes timestamps; new ones are
// appended by the client after each response (see
// ProtonMailClient.DoWithContext).
//
// It contains a sync.Mutex and must not be copied; use it via *RateLimiter.
type RateLimiter struct {
	// mu guards requests.
	mu sync.Mutex

	// requests holds the timestamps of recorded requests; they are oldest
	// first as long as they are appended in time order.
	requests []time.Time

	// limit is the maximum number of requests allowed per window.
	limit int

	// window is the length of the sliding window.
	window time.Duration
}
|
|
|
|
func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient {
|
|
return &ProtonMailClient{
|
|
baseURL: cfg.APIBaseURL,
|
|
httpClient: &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second},
|
|
config: cfg,
|
|
sessionRefresher: refresher,
|
|
rateLimiter: &RateLimiter{
|
|
requests: make([]time.Time, 0, cfg.RateLimitReq),
|
|
limit: cfg.RateLimitReq,
|
|
window: time.Duration(cfg.RateLimitWin) * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (c *ProtonMailClient) SetAuthHeader(token string) {
|
|
c.authMu.Lock()
|
|
defer c.authMu.Unlock()
|
|
c.authHeader = token
|
|
}
|
|
|
|
func (c *ProtonMailClient) getAuthHeader() string {
|
|
c.authMu.RLock()
|
|
defer c.authMu.RUnlock()
|
|
return c.authHeader
|
|
}
|
|
|
|
func (c *ProtonMailClient) refreshAuth() error {
|
|
if c.sessionRefresher == nil {
|
|
return fmt.Errorf("no session refresher configured")
|
|
}
|
|
return c.sessionRefresher.RefreshToken()
|
|
}
|
|
|
|
func (c *ProtonMailClient) GetBaseURL() string {
|
|
return c.baseURL
|
|
}
|
|
|
|
func (rl *RateLimiter) Wait() {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
windowStart := now.Add(-rl.window)
|
|
|
|
// Remove old requests outside the window
|
|
validRequests := make([]time.Time, 0, rl.limit)
|
|
for _, t := range rl.requests {
|
|
if t.After(windowStart) {
|
|
validRequests = append(validRequests, t)
|
|
}
|
|
}
|
|
rl.requests = validRequests
|
|
|
|
// Wait if at limit
|
|
if len(rl.requests) >= rl.limit {
|
|
sleep := rl.requests[0].Add(rl.window).Sub(now)
|
|
if sleep > 0 {
|
|
time.Sleep(sleep)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
|
|
return c.DoWithContext(context.Background(), req)
|
|
}
|
|
|
|
func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) {
|
|
c.rateLimiter.Wait()
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
// Check if request has its own context
|
|
if ctx != context.Background() {
|
|
req = req.WithContext(ctx)
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Record the request
|
|
c.rateLimiter.mu.Lock()
|
|
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
|
|
c.rateLimiter.mu.Unlock()
|
|
|
|
// Check for 401 and attempt refresh
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
// Close the current response body
|
|
resp.Body.Close()
|
|
|
|
// Attempt to refresh the token
|
|
if err := c.refreshAuth(); err != nil {
|
|
return resp, fmt.Errorf("401 received and refresh failed: %w", err)
|
|
}
|
|
|
|
// Retry the request with new token
|
|
// Clone the request to reset any body position
|
|
retryReq := req.Clone(ctx)
|
|
retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
|
|
|
|
resp, err = c.httpClient.Do(retryReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Record the retry request
|
|
c.rateLimiter.mu.Lock()
|
|
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
|
|
c.rateLimiter.mu.Unlock()
|
|
}
|
|
|
|
// Check for other API errors
|
|
if resp.StatusCode >= 400 {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
var apiErr APIError
|
|
if err := json.Unmarshal(body, &apiErr); err == nil {
|
|
resp.Body = io.NopCloser(io.MultiReader(io.NopCloser(bytes.NewReader(body)), bytes.NewReader(body)))
|
|
return resp, &apiErr
|
|
}
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// APIError describes a failed API call: the Code and Message fields of the
// JSON error body, plus the HTTP status of the response.
type APIError struct {
	// HTTPStatus is the response status code. Its `json:"-"` tag excludes it
	// from JSON decoding, so it must be filled in from the response itself.
	HTTPStatus int `json:"-"`

	// Code is the "Code" field of the JSON error body, if present.
	Code int `json:"Code,omitempty"`

	// Message is the "Message" field of the JSON error body, if present.
	Message string `json:"Message,omitempty"`
}
|
|
|
|
func (e *APIError) Error() string {
|
|
return fmt.Sprintf("API error %d: %s", e.HTTPStatus, e.Message)
|
|
}
|
|
|
|
// NewRequestWithContext wraps http.NewRequestWithContext and pre-sets the
// Content-Type and Accept headers to "application/json". On error it returns
// a nil request together with the error from net/http.
func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, body)
	if err == nil {
		for _, name := range [...]string{"Content-Type", "Accept"} {
			req.Header.Set(name, "application/json")
		}
		return req, nil
	}
	return nil, err
}
|