package api

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/frenocorp/pop/internal/auth"
	"github.com/frenocorp/pop/internal/config"
)

// ProtonMailClient sends authenticated JSON requests to the API at baseURL.
// Outgoing requests are throttled by rateLimiter, and a 401 response triggers
// one session refresh via sessionRefresher followed by a single retry.
type ProtonMailClient struct {
	baseURL          string
	httpClient       *http.Client
	config           *config.Config
	rateLimiter      *RateLimiter
	authHeader       string       // bearer token sent in the Authorization header
	authMu           sync.RWMutex // guards authHeader
	sessionRefresher auth.SessionRefresher
}

// RateLimiter is a sliding-window limiter that admits at most limit requests
// in any window-long period. A non-positive limit disables limiting.
type RateLimiter struct {
	mu       sync.Mutex
	requests []time.Time // recorded request times, oldest first
	limit    int
	window   time.Duration
}

// NewProtonMailClient builds a client from cfg: the API base URL, an HTTP
// timeout of cfg.TimeoutSec seconds, and a limit of cfg.RateLimitReq requests
// per cfg.RateLimitWin seconds. refresher may be nil, in which case a 401
// response is returned with an error instead of being retried.
func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient {
	capacity := cfg.RateLimitReq
	if capacity < 0 {
		// make panics on a negative capacity; the limiter treats a
		// non-positive limit as "unlimited" anyway.
		capacity = 0
	}
	return &ProtonMailClient{
		baseURL:          cfg.APIBaseURL,
		httpClient:       &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second},
		config:           cfg,
		sessionRefresher: refresher,
		rateLimiter: &RateLimiter{
			requests: make([]time.Time, 0, capacity),
			limit:    cfg.RateLimitReq,
			window:   time.Duration(cfg.RateLimitWin) * time.Second,
		},
	}
}

// SetAuthHeader stores the bearer token sent with subsequent requests.
// It is safe for concurrent use.
func (c *ProtonMailClient) SetAuthHeader(token string) {
	c.authMu.Lock()
	defer c.authMu.Unlock()
	c.authHeader = token
}

// getAuthHeader returns the current bearer token under the read lock.
func (c *ProtonMailClient) getAuthHeader() string {
	c.authMu.RLock()
	defer c.authMu.RUnlock()
	return c.authHeader
}

// refreshAuth asks the configured session refresher for a new token.
// NOTE(review): the retry in DoWithContext re-reads getAuthHeader(); this
// assumes the refresher calls SetAuthHeader with the new token — confirm.
func (c *ProtonMailClient) refreshAuth() error {
	if c.sessionRefresher == nil {
		return errors.New("no session refresher configured")
	}
	return c.sessionRefresher.RefreshToken()
}

// GetBaseURL returns the API base URL taken from the config.
func (c *ProtonMailClient) GetBaseURL() string {
	return c.baseURL
}

// Wait blocks until fewer than limit requests have been recorded within the
// current window. It does not record a request itself; use Acquire to wait
// and record atomically.
func (rl *RateLimiter) Wait() {
	_ = rl.wait(context.Background(), false) // Background is never done, so no error is possible
}

// Acquire blocks until a slot is free within the window and records the
// request in that slot under the same lock, so concurrent callers cannot
// overshoot the limit. It returns ctx.Err() if ctx is done before a slot
// becomes free.
func (rl *RateLimiter) Acquire(ctx context.Context) error {
	return rl.wait(ctx, true)
}

// wait retries reserve until it succeeds, sleeping between attempts without
// holding rl.mu so other goroutines are not blocked while this one waits.
func (rl *RateLimiter) wait(ctx context.Context, record bool) error {
	for {
		delay, ok := rl.reserve(record)
		if ok {
			return nil
		}
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
	}
}

// reserve drops timestamps that have left the window. If a slot is free it
// optionally records the current time and reports ok; otherwise it returns
// how long until the oldest recorded request leaves the window.
func (rl *RateLimiter) reserve(record bool) (time.Duration, bool) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if rl.limit <= 0 {
		return 0, true
	}

	// Read the clock under the lock so recorded timestamps stay in order.
	now := time.Now()
	windowStart := now.Add(-rl.window)

	kept := rl.requests[:0] // filter in place, reusing the backing array
	for _, t := range rl.requests {
		if t.After(windowStart) {
			kept = append(kept, t)
		}
	}
	rl.requests = kept

	if len(rl.requests) < rl.limit {
		if record {
			rl.requests = append(rl.requests, now)
		}
		return 0, true
	}
	// rl.requests[0] is inside the window, so this delay is positive.
	return rl.requests[0].Add(rl.window).Sub(now), false
}

// Do sends req using the request's own context. See DoWithContext.
func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
	return c.DoWithContext(context.Background(), req)
}

// DoWithContext sends req with the client's bearer token and an
// "Accept: application/json" header, after waiting for a rate-limit slot.
//
// If ctx is neither nil nor context.Background(), it replaces req's context;
// otherwise req's own context is kept (this is how Do calls it).
//
// On a 401 response the session is refreshed once and the request is retried.
// This requires a replayable body: nil, http.NoBody, or a request with GetBody
// set, as http.NewRequest does for *bytes.Buffer, *bytes.Reader and
// *strings.Reader.
//
// For a final status >= 400 whose body decodes as JSON, the response is
// returned together with an *APIError. In every case where a response is
// returned, its body is still readable, and the caller should close it.
func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) {
	if ctx != nil && ctx != context.Background() {
		req = req.WithContext(ctx)
	}
	req.Header.Set("Authorization", "Bearer "+c.getAuthHeader())
	req.Header.Set("Accept", "application/json")

	resp, err := c.send(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode == http.StatusUnauthorized {
		resp, err = c.retryAfterRefresh(req, resp)
		if err != nil {
			return resp, err
		}
	}

	return apiErrorFromResponse(resp)
}

// send waits for a rate-limit slot (honouring req's context) and performs req.
func (c *ProtonMailClient) send(req *http.Request) (*http.Response, error) {
	if err := c.rateLimiter.Acquire(req.Context()); err != nil {
		return nil, err
	}
	return c.httpClient.Do(req)
}

// retryAfterRefresh handles a 401 response to req. It buffers and closes the
// 401 body, refreshes the session, and resends a replayable copy of req with
// the new token. If the refresh fails or req cannot be replayed, it returns
// the 401 response (with its body still readable) and an error.
func (c *ProtonMailClient) retryAfterRefresh(req *http.Request, unauthorized *http.Response) (*http.Response, error) {
	// Best effort: the 401 body is kept only for callers inspecting a failure.
	_, _ = bufferBody(unauthorized)

	if err := c.refreshAuth(); err != nil {
		return unauthorized, fmt.Errorf("401 received and refresh failed: %w", err)
	}

	retryReq, err := cloneForRetry(req)
	if err != nil {
		return unauthorized, fmt.Errorf("401 received; cannot retry request: %w", err)
	}
	retryReq.Header.Set("Authorization", "Bearer "+c.getAuthHeader())
	return c.send(retryReq)
}

// cloneForRetry returns a copy of req whose body can be sent again. The
// transport consumes and closes the original body, so it must be rebuilt
// from GetBody; Clone alone shares the spent body.
func cloneForRetry(req *http.Request) (*http.Request, error) {
	retry := req.Clone(req.Context())
	if req.Body == nil || req.Body == http.NoBody {
		return retry, nil
	}
	if req.GetBody == nil {
		return nil, errors.New("request body cannot be replayed (GetBody is nil)")
	}
	body, err := req.GetBody()
	if err != nil {
		return nil, fmt.Errorf("rewinding request body: %w", err)
	}
	retry.Body = body
	return retry, nil
}

// apiErrorFromResponse leaves responses with status < 400 untouched. For
// error statuses it buffers the body and, if the body decodes as JSON,
// returns an *APIError carrying the HTTP status alongside the response.
func apiErrorFromResponse(resp *http.Response) (*http.Response, error) {
	if resp.StatusCode < 400 {
		return resp, nil
	}
	body, err := bufferBody(resp)
	if err != nil {
		return resp, fmt.Errorf("reading HTTP %d response body: %w", resp.StatusCode, err)
	}
	var apiErr APIError
	if json.Unmarshal(body, &apiErr) != nil {
		// Not a JSON error payload: return the raw response with no error.
		return resp, nil
	}
	apiErr.HTTPStatus = resp.StatusCode
	return resp, &apiErr
}

// bufferBody reads resp.Body to the end and closes it, so the connection can
// be reused. It then replaces the body with an in-memory reader over the same
// bytes, so callers can still read it. On a read error, the bytes read so far
// are kept.
func bufferBody(resp *http.Response) ([]byte, error) {
	body, err := io.ReadAll(resp.Body)
	resp.Body.Close()
	resp.Body = io.NopCloser(bytes.NewReader(body))
	return body, err
}

// APIError is the JSON error payload returned by the API for status >= 400,
// plus the HTTP status code of the response that carried it.
type APIError struct {
	HTTPStatus int    `json:"-"`
	Code       int    `json:"Code,omitempty"`
	Message    string `json:"Message,omitempty"`
}

// Error formats the HTTP status and the API-provided message.
func (e *APIError) Error() string {
	return fmt.Sprintf("API error %d: %s", e.HTTPStatus, e.Message)
}

// NewRequestWithContext creates a request bound to ctx, with JSON
// Content-Type and Accept headers preset.
func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, body)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	return req, nil
}