Compare commits
2 Commits
19a9e2a3df
...
5dc4a1b742
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5dc4a1b742 | ||
| 691a2acdad |
@@ -1,6 +1,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -26,7 +27,7 @@ func loginCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to create session manager: %w", err)
|
||||
}
|
||||
|
||||
return manager.LoginInteractive(cfg.APIBaseURL)
|
||||
return manager.LoginInteractive(context.Background(), cfg.APIBaseURL)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
16
cmd/draft.go
16
cmd/draft.go
@@ -64,12 +64,12 @@ func draftSaveCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("not authenticated: %w", err)
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
@@ -125,12 +125,12 @@ func draftListCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("not authenticated: %w", err)
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
@@ -190,12 +190,12 @@ func draftEditCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("not authenticated: %w", err)
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
@@ -241,12 +241,12 @@ func draftSendCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("not authenticated: %w", err)
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
|
||||
@@ -370,7 +370,7 @@ func newLabelClient() (*labels.Client, error) {
|
||||
return nil, fmt.Errorf("not authenticated: %w", err)
|
||||
}
|
||||
|
||||
apiClient := api.NewProtonMailClient(cfg)
|
||||
apiClient := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
apiClient.SetAuthHeader(session.AccessToken)
|
||||
|
||||
return labels.NewClient(apiClient), nil
|
||||
|
||||
40
cmd/mail.go
40
cmd/mail.go
@@ -31,6 +31,22 @@ func checkAuthenticated() (*auth.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func checkAuthenticatedWithManager() (*auth.Session, *auth.SessionManager, error) {
|
||||
sessionMgr, err := auth.NewSessionManager()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create session manager: %w", err)
|
||||
}
|
||||
authenticated, err := sessionMgr.IsAuthenticated()
|
||||
if err != nil || !authenticated {
|
||||
return nil, nil, fmt.Errorf("not authenticated (run 'pop login' first): %w", err)
|
||||
}
|
||||
session, err := sessionMgr.GetSession()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("not authenticated: %w", err)
|
||||
}
|
||||
return session, sessionMgr, nil
|
||||
}
|
||||
|
||||
func mailCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "mail",
|
||||
@@ -64,12 +80,12 @@ func mailListCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
@@ -158,12 +174,12 @@ func mailReadCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
@@ -221,12 +237,12 @@ func mailSendCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
@@ -276,12 +292,12 @@ func mailDeleteCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
@@ -312,12 +328,12 @@ func mailTrashCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
@@ -456,12 +472,12 @@ func mailSearchCmd() *cobra.Command {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
session, err := checkAuthenticated()
|
||||
session, sessionMgr, err := checkAuthenticatedWithManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := api.NewProtonMailClient(cfg)
|
||||
client := api.NewProtonMailClient(cfg, sessionMgr)
|
||||
client.SetAuthHeader(session.AccessToken)
|
||||
mailClient := internalmail.NewClient(client)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/frenocorp/pop/internal/config"
|
||||
"github.com/frenocorp/pop/internal/auth"
|
||||
)
|
||||
|
||||
type ProtonMailClient struct {
|
||||
@@ -19,6 +21,7 @@ type ProtonMailClient struct {
|
||||
rateLimiter *RateLimiter
|
||||
authHeader string
|
||||
authMu sync.RWMutex
|
||||
sessionRefresher auth.SessionRefresher
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
@@ -28,11 +31,12 @@ type RateLimiter struct {
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
func NewProtonMailClient(cfg *config.Config) *ProtonMailClient {
|
||||
func NewProtonMailClient(cfg *config.Config, refresher auth.SessionRefresher) *ProtonMailClient {
|
||||
return &ProtonMailClient{
|
||||
baseURL: cfg.APIBaseURL,
|
||||
httpClient: &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second},
|
||||
config: cfg,
|
||||
sessionRefresher: refresher,
|
||||
rateLimiter: &RateLimiter{
|
||||
requests: make([]time.Time, 0, cfg.RateLimitReq),
|
||||
limit: cfg.RateLimitReq,
|
||||
@@ -53,6 +57,13 @@ func (c *ProtonMailClient) getAuthHeader() string {
|
||||
return c.authHeader
|
||||
}
|
||||
|
||||
func (c *ProtonMailClient) refreshAuth() error {
|
||||
if c.sessionRefresher == nil {
|
||||
return fmt.Errorf("no session refresher configured")
|
||||
}
|
||||
return c.sessionRefresher.RefreshToken()
|
||||
}
|
||||
|
||||
func (c *ProtonMailClient) GetBaseURL() string {
|
||||
return c.baseURL
|
||||
}
|
||||
@@ -83,11 +94,20 @@ func (rl *RateLimiter) Wait() {
|
||||
}
|
||||
|
||||
func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
|
||||
return c.DoWithContext(context.Background(), req)
|
||||
}
|
||||
|
||||
func (c *ProtonMailClient) DoWithContext(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
c.rateLimiter.Wait()
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// Check if request has its own context
|
||||
if ctx != context.Background() {
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -98,7 +118,33 @@ func (c *ProtonMailClient) Do(req *http.Request) (*http.Response, error) {
|
||||
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
|
||||
c.rateLimiter.mu.Unlock()
|
||||
|
||||
// Check for API errors
|
||||
// Check for 401 and attempt refresh
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
// Close the current response body
|
||||
resp.Body.Close()
|
||||
|
||||
// Attempt to refresh the token
|
||||
if err := c.refreshAuth(); err != nil {
|
||||
return resp, fmt.Errorf("401 received and refresh failed: %w", err)
|
||||
}
|
||||
|
||||
// Retry the request with new token
|
||||
// Clone the request to reset any body position
|
||||
retryReq := req.Clone(ctx)
|
||||
retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.getAuthHeader()))
|
||||
|
||||
resp, err = c.httpClient.Do(retryReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Record the retry request
|
||||
c.rateLimiter.mu.Lock()
|
||||
c.rateLimiter.requests = append(c.rateLimiter.requests, time.Now())
|
||||
c.rateLimiter.mu.Unlock()
|
||||
}
|
||||
|
||||
// Check for other API errors
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var apiErr APIError
|
||||
@@ -120,3 +166,14 @@ type APIError struct {
|
||||
func (e *APIError) Error() string {
|
||||
return fmt.Sprintf("API error %d: %s", e.HTTPStatus, e.Message)
|
||||
}
|
||||
|
||||
// Helper function to create a request with context
|
||||
func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
11
internal/auth/interface.go
Normal file
11
internal/auth/interface.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
|
||||
// SessionRefresher defines the interface for refreshing authentication tokens.
|
||||
// This allows the API client to automatically refresh tokens on 401 responses.
|
||||
type SessionRefresher interface {
|
||||
RefreshToken() error
|
||||
RefreshTokenWithContext(ctx context.Context) error
|
||||
GetSession() (*Session, error)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
@@ -54,7 +55,7 @@ func NewSessionManager() (*SessionManager, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailPassphrase string) error {
|
||||
func (m *SessionManager) LoginWithCredentials(ctx context.Context, apiBaseURL, email, password, mailPassphrase string) error {
|
||||
if err := os.MkdirAll(m.configDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create config dir: %w", err)
|
||||
}
|
||||
@@ -75,7 +76,7 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
|
||||
return fmt.Errorf("failed to marshal auth payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", authURL, bytes.NewBuffer(jsonData))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth request: %w", err)
|
||||
}
|
||||
@@ -95,11 +96,11 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
|
||||
}
|
||||
|
||||
var authResponse struct {
|
||||
UID string `json:"UID"`
|
||||
AccessToken string `json:"AccessToken"`
|
||||
RefreshToken string `json:"RefreshToken"`
|
||||
ExpiresIn int `json:"ExpiresIn"`
|
||||
TwoFARequired bool `json:"TwoFARequired"`
|
||||
UID string `json:"UID"`
|
||||
AccessToken string `json:"AccessToken"`
|
||||
RefreshToken string `json:"RefreshToken"`
|
||||
ExpiresIn int `json:"ExpiresIn"`
|
||||
TwoFARequired bool `json:"TwoFARequired"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil {
|
||||
@@ -135,7 +136,7 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailP
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
|
||||
func (m *SessionManager) LoginInteractive(ctx context.Context, apiBaseURL string) error {
|
||||
if err := os.MkdirAll(m.configDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create config dir: %w", err)
|
||||
}
|
||||
@@ -192,7 +193,7 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
|
||||
return fmt.Errorf("failed to marshal auth payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", authURL, bytes.NewBuffer(jsonData))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth request: %w", err)
|
||||
}
|
||||
@@ -263,7 +264,7 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
|
||||
return fmt.Errorf("failed to marshal TOTP payload: %w", err)
|
||||
}
|
||||
|
||||
totpReq, err := http.NewRequest("POST", totpURL, bytes.NewBuffer(totpJSON))
|
||||
totpReq, err := http.NewRequestWithContext(ctx, "POST", totpURL, bytes.NewBuffer(totpJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TOTP request: %w", err)
|
||||
}
|
||||
@@ -372,6 +373,10 @@ func (m *SessionManager) IsAuthenticated() (bool, error) {
|
||||
}
|
||||
|
||||
func (m *SessionManager) RefreshToken() error {
|
||||
return m.RefreshTokenWithContext(context.Background())
|
||||
}
|
||||
|
||||
func (m *SessionManager) RefreshTokenWithContext(ctx context.Context) error {
|
||||
session, err := m.GetSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get session: %w", err)
|
||||
@@ -394,7 +399,7 @@ func (m *SessionManager) RefreshToken() error {
|
||||
return fmt.Errorf("failed to marshal refresh payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", refreshURL, bytes.NewBuffer(jsonData))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", refreshURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func newTestClient(t *testing.T, srv *mockServer) *Client {
|
||||
RateLimitReq: 100,
|
||||
RateLimitWin: 60,
|
||||
}
|
||||
apiClient := api.NewProtonMailClient(cfg)
|
||||
apiClient := api.NewProtonMailClient(cfg, nil)
|
||||
apiClient.SetAuthHeader("test-token")
|
||||
return NewClient(apiClient)
|
||||
}
|
||||
@@ -294,8 +294,8 @@ func TestListMessages_APIError(t *testing.T) {
|
||||
client := newTestClient(t, srv)
|
||||
|
||||
srv.Handle("POST /api/messages", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprintf(w, `{"Code":401,"Message":"invalid token"}`)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
fmt.Fprintf(w, `{"Code":403,"Message":"invalid token"}`)
|
||||
})
|
||||
|
||||
_, err := client.ListMessages(ListMessagesRequest{
|
||||
@@ -304,7 +304,7 @@ func TestListMessages_APIError(t *testing.T) {
|
||||
Passphrase: "pass",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 401 response")
|
||||
t.Fatal("expected error for 403 response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid token") {
|
||||
t.Errorf("expected 'invalid token' in error, got: %s", err.Error())
|
||||
@@ -404,6 +404,9 @@ func TestGetMessage_NotFound(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 404")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "message not found") {
|
||||
t.Errorf("expected 'message not found' in error, got: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMessage_DecryptBody(t *testing.T) {
|
||||
@@ -417,7 +420,7 @@ func TestGetMessage_DecryptBody(t *testing.T) {
|
||||
RateLimitReq: 100,
|
||||
RateLimitWin: 60,
|
||||
}
|
||||
apiClient := api.NewProtonMailClient(cfg)
|
||||
apiClient := api.NewProtonMailClient(cfg, nil)
|
||||
apiClient.SetAuthHeader("test-token")
|
||||
client := NewClient(apiClient)
|
||||
client.SetPGPService(svc)
|
||||
@@ -488,7 +491,7 @@ func TestSend_WithPGP(t *testing.T) {
|
||||
RateLimitReq: 100,
|
||||
RateLimitWin: 60,
|
||||
}
|
||||
apiClient := api.NewProtonMailClient(cfg)
|
||||
apiClient := api.NewProtonMailClient(cfg, nil)
|
||||
apiClient.SetAuthHeader("test-token")
|
||||
client := NewClient(apiClient)
|
||||
client.SetPGPService(svc)
|
||||
@@ -1089,7 +1092,7 @@ func TestAuthHeader_Propagated(t *testing.T) {
|
||||
RateLimitReq: 100,
|
||||
RateLimitWin: 60,
|
||||
}
|
||||
apiClient := api.NewProtonMailClient(cfg)
|
||||
apiClient := api.NewProtonMailClient(cfg, nil)
|
||||
apiClient.SetAuthHeader("my-test-token")
|
||||
client := NewClient(apiClient)
|
||||
|
||||
@@ -1324,7 +1327,7 @@ func TestListMessages_Timeout(t *testing.T) {
|
||||
RateLimitReq: 100,
|
||||
RateLimitWin: 60,
|
||||
}
|
||||
apiClient := api.NewProtonMailClient(cfg)
|
||||
apiClient := api.NewProtonMailClient(cfg, nil)
|
||||
apiClient.SetAuthHeader("test-token")
|
||||
client := NewClient(apiClient)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user