- Add SessionRefresher interface for token refresh abstraction
- Update ProtonMailClient to auto-refresh on 401 responses
- Add DoWithContext method for context-aware HTTP requests
- Update SessionManager with RefreshTokenWithContext method
- Update LoginWithCredentials and LoginInteractive to accept context
- Add checkAuthenticatedWithManager helper for commands needing session manager
- All API methods now support proper cancellation via context.Context

Files changed:
- internal/api/client.go - Auto-refresh on 401, context support
- internal/auth/session.go - Context-aware refresh and login methods
- internal/auth/interface.go - SessionRefresher interface
- cmd/mail.go, cmd/draft.go, cmd/folders.go - Updated to use new helpers
- cmd/auth.go - Context support for login commands

Co-Authored-By: Paperclip <noreply@paperclip.ing>
538 lines
15 KiB
Go
538 lines
15 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/99designs/keyring"
|
|
"github.com/frenocorp/pop/internal/config"
|
|
"github.com/manifoldco/promptui"
|
|
)
|
|
|
|
// Session is the authentication state that gets persisted between runs.
// It is marshalled to JSON using the tags below before being encrypted
// and stored.
type Session struct {
	// UID is the session identifier taken from the auth response.
	UID string `json:"uid"`

	// AccessToken is sent as the Bearer token on authenticated requests.
	AccessToken string `json:"access_token"`

	// RefreshToken is exchanged for a new AccessToken on refresh.
	RefreshToken string `json:"refresh_token"`

	// ExpiresAt is the expiry instant in Unix seconds
	// (login/refresh time plus the server-reported ExpiresIn).
	ExpiresAt int64 `json:"expires_at"`

	// TwoFAEnabled mirrors the TwoFARequired flag of the auth response.
	TwoFAEnabled bool `json:"two_factor_enabled"`

	// MailPassphrase is the mailbox passphrase supplied at login.
	// It is omitted from the JSON when empty.
	MailPassphrase string `json:"mail_passphrase,omitempty"`
}
|
|
|
|
type SessionManager struct {
|
|
configDir string
|
|
sessionFile string
|
|
keyring keyring.Keyring
|
|
}
|
|
|
|
func NewSessionManager() (*SessionManager, error) {
|
|
cfg := config.NewConfigManager()
|
|
configDir := cfg.ConfigDir()
|
|
|
|
k, err := keyring.Open(keyring.Config{
|
|
ServiceName: "pop-cli",
|
|
FileDir: filepath.Join(configDir, "keyring"),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open keyring: %w", err)
|
|
}
|
|
|
|
return &SessionManager{
|
|
configDir: configDir,
|
|
sessionFile: filepath.Join(configDir, "session.json"),
|
|
keyring: k,
|
|
}, nil
|
|
}
|
|
|
|
func (m *SessionManager) LoginWithCredentials(ctx context.Context, apiBaseURL, email, password, mailPassphrase string) error {
|
|
if err := os.MkdirAll(m.configDir, 0700); err != nil {
|
|
return fmt.Errorf("failed to create config dir: %w", err)
|
|
}
|
|
|
|
if err := os.MkdirAll(filepath.Join(m.configDir, "keyring"), 0700); err != nil {
|
|
return fmt.Errorf("failed to create keyring dir: %w", err)
|
|
}
|
|
|
|
authURL := fmt.Sprintf("%s/auth", apiBaseURL)
|
|
|
|
payload := map[string]string{
|
|
"Email": email,
|
|
"Password": password,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal auth payload: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create auth request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to ProtonMail API: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("authentication failed (status %d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var authResponse struct {
|
|
UID string `json:"UID"`
|
|
AccessToken string `json:"AccessToken"`
|
|
RefreshToken string `json:"RefreshToken"`
|
|
ExpiresIn int `json:"ExpiresIn"`
|
|
TwoFARequired bool `json:"TwoFARequired"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil {
|
|
return fmt.Errorf("failed to parse auth response: %w", err)
|
|
}
|
|
|
|
session := Session{
|
|
UID: authResponse.UID,
|
|
AccessToken: authResponse.AccessToken,
|
|
RefreshToken: authResponse.RefreshToken,
|
|
ExpiresAt: time.Now().Unix() + int64(authResponse.ExpiresIn),
|
|
TwoFAEnabled: authResponse.TwoFARequired,
|
|
MailPassphrase: mailPassphrase,
|
|
}
|
|
|
|
encryptedForFile, err := encryptSession(session)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt session: %w", err)
|
|
}
|
|
|
|
if err := m.keyring.Set(keyring.Item{
|
|
Key: "session",
|
|
Data: encryptedForFile,
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to store session in keyring: %w", err)
|
|
}
|
|
|
|
if err := os.WriteFile(m.sessionFile, encryptedForFile, 0600); err != nil {
|
|
return fmt.Errorf("failed to write encrypted session file: %w", err)
|
|
}
|
|
|
|
fmt.Println("Logged in successfully")
|
|
return nil
|
|
}
|
|
|
|
func (m *SessionManager) LoginInteractive(ctx context.Context, apiBaseURL string) error {
|
|
if err := os.MkdirAll(m.configDir, 0700); err != nil {
|
|
return fmt.Errorf("failed to create config dir: %w", err)
|
|
}
|
|
|
|
if err := os.MkdirAll(filepath.Join(m.configDir, "keyring"), 0700); err != nil {
|
|
return fmt.Errorf("failed to create keyring dir: %w", err)
|
|
}
|
|
|
|
emailPrompt := promptui.Prompt{
|
|
Label: "ProtonMail email",
|
|
Validate: func(input string) error {
|
|
if !strings.Contains(input, "@") || !strings.Contains(input, ".") {
|
|
return fmt.Errorf("invalid email format")
|
|
}
|
|
parts := strings.Split(input, "@")
|
|
if len(parts) != 2 || len(parts[0]) == 0 || len(parts[1]) < 3 {
|
|
return fmt.Errorf("invalid email format")
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
email, err := emailPrompt.Run()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read email: %w", err)
|
|
}
|
|
|
|
passwordPrompt := promptui.Prompt{
|
|
Label: "ProtonMail password",
|
|
Mask: '*',
|
|
}
|
|
password, err := passwordPrompt.Run()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read password: %w", err)
|
|
}
|
|
|
|
passphrasePrompt := promptui.Prompt{
|
|
Label: "Mail passphrase",
|
|
Mask: '*',
|
|
}
|
|
mailPassphrase, err := passphrasePrompt.Run()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read mail passphrase: %w", err)
|
|
}
|
|
|
|
authURL := fmt.Sprintf("%s/auth", apiBaseURL)
|
|
|
|
payload := map[string]string{
|
|
"Email": email,
|
|
"Password": password,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal auth payload: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", authURL, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create auth request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to ProtonMail API: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("authentication failed (status %d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var authResponse struct {
|
|
UID string `json:"UID"`
|
|
AccessToken string `json:"AccessToken"`
|
|
RefreshToken string `json:"RefreshToken"`
|
|
ExpiresIn int `json:"ExpiresIn"`
|
|
TwoFARequired bool `json:"TwoFARequired"`
|
|
TwoFAChallenge struct {
|
|
Type string `json:"Type"`
|
|
} `json:"TwoFAChallenge"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil {
|
|
return fmt.Errorf("failed to parse auth response: %w", err)
|
|
}
|
|
|
|
session := Session{
|
|
UID: authResponse.UID,
|
|
AccessToken: authResponse.AccessToken,
|
|
RefreshToken: authResponse.RefreshToken,
|
|
ExpiresAt: time.Now().Unix() + int64(authResponse.ExpiresIn),
|
|
TwoFAEnabled: authResponse.TwoFARequired,
|
|
MailPassphrase: mailPassphrase,
|
|
}
|
|
|
|
if session.TwoFAEnabled {
|
|
fmt.Println("\n2FA authentication required")
|
|
|
|
totpPrompt := promptui.Prompt{
|
|
Label: "Enter TOTP code",
|
|
Validate: func(input string) error {
|
|
if len(input) != 6 {
|
|
return fmt.Errorf("TOTP code must be 6 digits")
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
totpCode, err := totpPrompt.Run()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read TOTP code: %w", err)
|
|
}
|
|
|
|
totpURL := fmt.Sprintf("%s/auth/verify", apiBaseURL)
|
|
totpPayload := map[string]string{
|
|
"UID": session.UID,
|
|
"Code": totpCode,
|
|
}
|
|
|
|
totpJSON, err := json.Marshal(totpPayload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal TOTP payload: %w", err)
|
|
}
|
|
|
|
totpReq, err := http.NewRequestWithContext(ctx, "POST", totpURL, bytes.NewBuffer(totpJSON))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create TOTP request: %w", err)
|
|
}
|
|
totpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
|
|
totpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
totpResp, err := client.Do(totpReq)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to verify TOTP: %w", err)
|
|
}
|
|
defer totpResp.Body.Close()
|
|
|
|
if totpResp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(totpResp.Body)
|
|
return fmt.Errorf("TOTP verification failed (status %d): %s", totpResp.StatusCode, string(body))
|
|
}
|
|
|
|
var finalAuth struct {
|
|
AccessToken string `json:"AccessToken"`
|
|
RefreshToken string `json:"RefreshToken"`
|
|
ExpiresIn int `json:"ExpiresIn"`
|
|
}
|
|
|
|
if err := json.NewDecoder(totpResp.Body).Decode(&finalAuth); err != nil {
|
|
return fmt.Errorf("failed to parse TOTP response: %w", err)
|
|
}
|
|
|
|
session.AccessToken = finalAuth.AccessToken
|
|
session.RefreshToken = finalAuth.RefreshToken
|
|
session.ExpiresAt = time.Now().Unix() + int64(finalAuth.ExpiresIn)
|
|
}
|
|
|
|
encryptedForFile, err := encryptSession(session)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt session: %w", err)
|
|
}
|
|
|
|
if err := m.keyring.Set(keyring.Item{
|
|
Key: "session",
|
|
Data: encryptedForFile,
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to store session in keyring: %w", err)
|
|
}
|
|
|
|
if err := os.WriteFile(m.sessionFile, encryptedForFile, 0600); err != nil {
|
|
return fmt.Errorf("failed to write encrypted session file: %w", err)
|
|
}
|
|
|
|
fmt.Println("Logged in successfully")
|
|
return nil
|
|
}
|
|
|
|
func (m *SessionManager) Logout() error {
|
|
if err := os.Remove(m.sessionFile); err != nil {
|
|
return fmt.Errorf("failed to remove session file: %w", err)
|
|
}
|
|
|
|
if err := m.keyring.Remove("session"); err != nil {
|
|
return fmt.Errorf("failed to remove keyring entry: %w", err)
|
|
}
|
|
|
|
fmt.Println("Logged out successfully")
|
|
return nil
|
|
}
|
|
|
|
func (m *SessionManager) GetSession() (*Session, error) {
|
|
// First, try to get from keyring (encrypted storage)
|
|
item, err := m.keyring.Get("session")
|
|
if err == nil {
|
|
session, err := decryptSession(item.Data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt session from keyring: %w", err)
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// If not in keyring, read from encrypted file
|
|
data, err := os.ReadFile(m.sessionFile)
|
|
if err != nil {
|
|
if err == os.ErrNotExist {
|
|
return nil, fmt.Errorf("no session found: %w", err)
|
|
}
|
|
return nil, fmt.Errorf("failed to read session file: %w", err)
|
|
}
|
|
|
|
// Decrypt the session data
|
|
session, err := decryptSession(data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt session: %w", err)
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
func (m *SessionManager) IsAuthenticated() (bool, error) {
|
|
session, err := m.GetSession()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if time.Now().Unix() > session.ExpiresAt {
|
|
return false, fmt.Errorf("session expired at %d", session.ExpiresAt)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (m *SessionManager) RefreshToken() error {
|
|
return m.RefreshTokenWithContext(context.Background())
|
|
}
|
|
|
|
func (m *SessionManager) RefreshTokenWithContext(ctx context.Context) error {
|
|
session, err := m.GetSession()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get session: %w", err)
|
|
}
|
|
|
|
if session.RefreshToken == "" {
|
|
return fmt.Errorf("no refresh token available")
|
|
}
|
|
|
|
apiBaseURL := "https://api.protonmail.ch"
|
|
refreshURL := fmt.Sprintf("%s/auth/refresh", apiBaseURL)
|
|
|
|
payload := map[string]string{
|
|
"UID": session.UID,
|
|
"RefreshToken": session.RefreshToken,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal refresh payload: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", refreshURL, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create refresh request: %w", err)
|
|
}
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to ProtonMail API: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var refreshResponse struct {
|
|
AccessToken string `json:"AccessToken"`
|
|
RefreshToken string `json:"RefreshToken"`
|
|
ExpiresIn int `json:"ExpiresIn"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&refreshResponse); err != nil {
|
|
return fmt.Errorf("failed to parse refresh response: %w", err)
|
|
}
|
|
|
|
session.AccessToken = refreshResponse.AccessToken
|
|
session.RefreshToken = refreshResponse.RefreshToken
|
|
session.ExpiresAt = time.Now().Unix() + int64(refreshResponse.ExpiresIn)
|
|
|
|
encryptedForFile, err := encryptSession(*session)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt updated session: %w", err)
|
|
}
|
|
|
|
if err := m.keyring.Set(keyring.Item{
|
|
Key: "session",
|
|
Data: encryptedForFile,
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to update session in keyring: %w", err)
|
|
}
|
|
|
|
_ = os.WriteFile(m.sessionFile, encryptedForFile, 0600)
|
|
|
|
return nil
|
|
}
|
|
|
|
// encryptSession encrypts the session data using AES-256-GCM
|
|
func encryptSession(session Session) ([]byte, error) {
|
|
// Generate a random 256-bit key
|
|
key := make([]byte, 32)
|
|
if _, err := rand.Read(key); err != nil {
|
|
return nil, fmt.Errorf("failed to generate encryption key: %w", err)
|
|
}
|
|
|
|
// Generate a random 12-byte nonce
|
|
nonce := make([]byte, 12)
|
|
if _, err := rand.Read(nonce); err != nil {
|
|
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
|
}
|
|
|
|
// Create AES-GCM cipher
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
|
|
}
|
|
|
|
aead, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
// Encrypt the session data
|
|
sessionData, err := json.Marshal(session)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal session for encryption: %w", err)
|
|
}
|
|
|
|
sealedData := aead.Seal(nil, nonce, sessionData, nil)
|
|
|
|
// Prepend key and nonce (base64 encoded for readability in file)
|
|
header := fmt.Sprintf("%s|%s|", base64.StdEncoding.EncodeToString(key), base64.StdEncoding.EncodeToString(nonce))
|
|
return []byte(header + string(sealedData)), nil
|
|
}
|
|
|
|
// decryptSession decrypts the session data
|
|
func decryptSession(encryptedData []byte) (Session, error) {
|
|
// Split header and encrypted data
|
|
parts := strings.Split(string(encryptedData), "|")
|
|
if len(parts) != 3 {
|
|
return Session{}, fmt.Errorf("invalid encrypted data format")
|
|
}
|
|
|
|
// Decode key and nonce
|
|
key, err := base64.StdEncoding.DecodeString(parts[0])
|
|
if err != nil {
|
|
return Session{}, fmt.Errorf("failed to decode key: %w", err)
|
|
}
|
|
|
|
nonce, err := base64.StdEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return Session{}, fmt.Errorf("failed to decode nonce: %w", err)
|
|
}
|
|
|
|
// Decrypt
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return Session{}, fmt.Errorf("failed to create AES cipher: %w", err)
|
|
}
|
|
|
|
aead, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return Session{}, fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
sealedData, err := base64.StdEncoding.DecodeString(parts[2])
|
|
if err != nil {
|
|
return Session{}, fmt.Errorf("failed to decode sealed data: %w", err)
|
|
}
|
|
|
|
data, err := aead.Open(nil, nonce, sealedData, nil)
|
|
if err != nil {
|
|
return Session{}, fmt.Errorf("failed to decrypt session: %w", err)
|
|
}
|
|
|
|
var session Session
|
|
if err := json.Unmarshal(data, &session); err != nil {
|
|
return Session{}, fmt.Errorf("failed to unmarshal session: %w", err)
|
|
}
|
|
|
|
return session, nil
|
|
}
|