Files
pop/internal/auth/session.go

461 lines
12 KiB
Go

package auth
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/99designs/keyring"
"github.com/frenocorp/pop/internal/config"
"github.com/manifoldco/promptui"
)
// Session holds the tokens obtained from the authentication endpoint.
// It is stored as JSON (using the json tags below) both in the keyring and,
// encrypted, in the session file.
type Session struct {
// UID is the session identifier returned by the API; it is sent back in the
// TOTP verification request.
UID string `json:"uid"`
// AccessToken is sent as a Bearer token (e.g. on the TOTP verification request).
AccessToken string `json:"access_token"`
// RefreshToken is stored but not yet used: SessionManager.RefreshToken is a stub.
RefreshToken string `json:"refresh_token"`
// ExpiresAt is a Unix timestamp, computed at login as time.Now().Unix() + ExpiresIn.
// NOTE(review): assumes the API's ExpiresIn is in seconds — confirm.
ExpiresAt int64 `json:"expires_at"`
// TwoFAEnabled mirrors the TwoFARequired flag from the initial auth response.
TwoFAEnabled bool `json:"two_factor_enabled"`
}
// SessionManager logs in, loads, and clears the CLI's authentication session.
// A session is written both to the keyring (key "session") and to an
// encrypted file; GetSession consults the keyring first and falls back to
// the file.
type SessionManager struct {
// configDir is the pop configuration directory (created on login).
configDir string
// sessionFile is <configDir>/session.json, the encrypted copy of the session.
sessionFile string
// keyring is the handle opened in NewSessionManager with service "pop-cli".
keyring keyring.Keyring
}
// NewSessionManager builds a SessionManager rooted at the directory reported
// by config.NewConfigManager().ConfigDir(). It opens the "pop-cli" keyring,
// using <configDir>/keyring for the file backend, and points the encrypted
// session file at <configDir>/session.json. No directories are created here;
// the login methods create them.
//
// NOTE(review): no FilePasswordFunc is configured, so if the keyring library
// selects its file backend, storing items may fail — confirm intended backends.
func NewSessionManager() (*SessionManager, error) {
	dir := config.NewConfigManager().ConfigDir()

	ring, err := keyring.Open(keyring.Config{
		ServiceName: "pop-cli",
		FileDir:     filepath.Join(dir, "keyring"),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to open keyring: %w", err)
	}

	mgr := &SessionManager{
		configDir:   dir,
		sessionFile: filepath.Join(dir, "session.json"),
		keyring:     ring,
	}
	return mgr, nil
}
// LoginWithCredentials authenticates non-interactively by POSTing the email
// and password as JSON to <apiBaseURL>/auth. On success it stores the session
// twice: as plain JSON in the keyring under the key "session", and encrypted
// (via encryptSession) in the session file with mode 0600.
//
// This path cannot collect a TOTP code. When the API reports TwoFARequired,
// the tokens in that first response are only provisional (LoginInteractive
// replaces them after calling /auth/verify). Instead of persisting them and
// reporting success, an error is returned; use LoginInteractive for accounts
// with two-factor authentication.
func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password string) error {
	// Upper bound on how much of an error response body is echoed back.
	const maxErrorBody = 4 << 10

	if err := os.MkdirAll(m.configDir, 0700); err != nil {
		return fmt.Errorf("failed to create config dir: %w", err)
	}
	if err := os.MkdirAll(filepath.Join(m.configDir, "keyring"), 0700); err != nil {
		return fmt.Errorf("failed to create keyring dir: %w", err)
	}

	// Tolerate a trailing slash on the base URL so we never request "//auth".
	authURL := strings.TrimSuffix(apiBaseURL, "/") + "/auth"
	jsonData, err := json.Marshal(map[string]string{
		"Email":    email,
		"Password": password,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal auth payload: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, authURL, bytes.NewReader(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create auth request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to connect to ProtonMail API: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message, so a read
		// failure is deliberately ignored.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return fmt.Errorf("authentication failed (status %d): %s", resp.StatusCode, string(body))
	}

	var authResponse struct {
		UID           string `json:"UID"`
		AccessToken   string `json:"AccessToken"`
		RefreshToken  string `json:"RefreshToken"`
		ExpiresIn     int    `json:"ExpiresIn"`
		TwoFARequired bool   `json:"TwoFARequired"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil {
		return fmt.Errorf("failed to parse auth response: %w", err)
	}
	if authResponse.TwoFARequired {
		return fmt.Errorf("two-factor authentication is required for this account; use interactive login")
	}
	if authResponse.AccessToken == "" {
		return fmt.Errorf("authentication response did not include an access token")
	}

	session := Session{
		UID:          authResponse.UID,
		AccessToken:  authResponse.AccessToken,
		RefreshToken: authResponse.RefreshToken,
		ExpiresAt:    time.Now().Unix() + int64(authResponse.ExpiresIn),
		TwoFAEnabled: authResponse.TwoFARequired,
	}

	encryptedData, err := encryptSession(session)
	if err != nil {
		return fmt.Errorf("failed to encrypt session: %w", err)
	}
	data, err := json.MarshalIndent(session, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal session: %w", err)
	}
	if err := m.keyring.Set(keyring.Item{
		Key:  "session",
		Data: data,
	}); err != nil {
		return fmt.Errorf("failed to store session in keyring: %w", err)
	}
	if err := os.WriteFile(m.sessionFile, encryptedData, 0600); err != nil {
		return fmt.Errorf("failed to write encrypted session file: %w", err)
	}
	fmt.Println("Logged in successfully")
	return nil
}
// LoginInteractive prompts for a ProtonMail email and password and
// authenticates by POSTing them as JSON to <apiBaseURL>/auth.
//
// When the response reports TwoFARequired, it prompts for a 6-digit TOTP code
// and POSTs it to <apiBaseURL>/auth/verify. The access token, refresh token,
// and expiry from that response replace the provisional ones.
//
// The final session is stored as plain JSON in the keyring under the key
// "session", and encrypted (via encryptSession) in the session file with
// mode 0600.
func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
	// Upper bound on how much of an error response body is echoed back.
	const maxErrorBody = 4 << 10

	if err := os.MkdirAll(m.configDir, 0700); err != nil {
		return fmt.Errorf("failed to create config dir: %w", err)
	}
	if err := os.MkdirAll(filepath.Join(m.configDir, "keyring"), 0700); err != nil {
		return fmt.Errorf("failed to create keyring dir: %w", err)
	}

	emailPrompt := promptui.Prompt{
		Label: "ProtonMail email",
		Validate: func(input string) error {
			if !strings.Contains(input, "@") {
				return fmt.Errorf("invalid email format")
			}
			return nil
		},
	}
	email, err := emailPrompt.Run()
	if err != nil {
		return fmt.Errorf("failed to read email: %w", err)
	}
	passwordPrompt := promptui.Prompt{
		Label: "ProtonMail password",
		Mask:  '*',
	}
	password, err := passwordPrompt.Run()
	if err != nil {
		return fmt.Errorf("failed to read password: %w", err)
	}

	// Tolerate a trailing slash on the base URL so we never request "//auth".
	baseURL := strings.TrimSuffix(apiBaseURL, "/")
	jsonData, err := json.Marshal(map[string]string{
		"Email":    email,
		"Password": password,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal auth payload: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, baseURL+"/auth", bytes.NewReader(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create auth request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to connect to ProtonMail API: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return fmt.Errorf("authentication failed (status %d): %s", resp.StatusCode, string(body))
	}

	var authResponse struct {
		UID           string `json:"UID"`
		AccessToken   string `json:"AccessToken"`
		RefreshToken  string `json:"RefreshToken"`
		ExpiresIn     int    `json:"ExpiresIn"`
		TwoFARequired bool   `json:"TwoFARequired"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil {
		return fmt.Errorf("failed to parse auth response: %w", err)
	}
	session := Session{
		UID:          authResponse.UID,
		AccessToken:  authResponse.AccessToken,
		RefreshToken: authResponse.RefreshToken,
		ExpiresAt:    time.Now().Unix() + int64(authResponse.ExpiresIn),
		TwoFAEnabled: authResponse.TwoFARequired,
	}

	if session.TwoFAEnabled {
		fmt.Println("\n2FA authentication required")
		totpPrompt := promptui.Prompt{
			Label: "Enter TOTP code",
			Validate: func(input string) error {
				// The message promises digits, so enforce that, not just length.
				nonDigit := strings.IndexFunc(input, func(r rune) bool {
					return r < '0' || r > '9'
				})
				if len(input) != 6 || nonDigit >= 0 {
					return fmt.Errorf("TOTP code must be 6 digits")
				}
				return nil
			},
		}
		totpCode, err := totpPrompt.Run()
		if err != nil {
			return fmt.Errorf("failed to read TOTP code: %w", err)
		}

		totpJSON, err := json.Marshal(map[string]string{
			"UID":  session.UID,
			"Code": totpCode,
		})
		if err != nil {
			return fmt.Errorf("failed to marshal TOTP payload: %w", err)
		}
		totpReq, err := http.NewRequest(http.MethodPost, baseURL+"/auth/verify", bytes.NewReader(totpJSON))
		if err != nil {
			return fmt.Errorf("failed to create TOTP request: %w", err)
		}
		// The provisional access token authorizes the verification call.
		totpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
		totpReq.Header.Set("Content-Type", "application/json")
		totpReq.Header.Set("Accept", "application/json")

		totpResp, err := client.Do(totpReq)
		if err != nil {
			return fmt.Errorf("failed to verify TOTP: %w", err)
		}
		defer totpResp.Body.Close()
		if totpResp.StatusCode != http.StatusOK {
			// Best effort: the body only enriches the error message.
			body, _ := io.ReadAll(io.LimitReader(totpResp.Body, maxErrorBody))
			return fmt.Errorf("TOTP verification failed (status %d): %s", totpResp.StatusCode, string(body))
		}

		var finalAuth struct {
			AccessToken  string `json:"AccessToken"`
			RefreshToken string `json:"RefreshToken"`
			ExpiresIn    int    `json:"ExpiresIn"`
		}
		if err := json.NewDecoder(totpResp.Body).Decode(&finalAuth); err != nil {
			return fmt.Errorf("failed to parse TOTP response: %w", err)
		}
		session.AccessToken = finalAuth.AccessToken
		session.RefreshToken = finalAuth.RefreshToken
		session.ExpiresAt = time.Now().Unix() + int64(finalAuth.ExpiresIn)
	}

	// Never persist (or report success for) a session with no usable token.
	if session.AccessToken == "" {
		return fmt.Errorf("authentication response did not include an access token")
	}

	encryptedData, err := encryptSession(session)
	if err != nil {
		return fmt.Errorf("failed to encrypt session: %w", err)
	}
	data, err := json.MarshalIndent(session, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal session: %w", err)
	}
	if err := m.keyring.Set(keyring.Item{
		Key:  "session",
		Data: data,
	}); err != nil {
		return fmt.Errorf("failed to store session in keyring: %w", err)
	}
	if err := os.WriteFile(m.sessionFile, encryptedData, 0600); err != nil {
		return fmt.Errorf("failed to write encrypted session file: %w", err)
	}
	fmt.Println("Logged in successfully")
	return nil
}
// Logout deletes the stored session from both places the login methods write
// it: the keyring entry "session" and the encrypted session file. Both must be
// removed because GetSession reads the keyring first; deleting only the file
// would leave the user logged in.
//
// Entries that are already absent are not errors, so Logout is safe to call
// when no session exists.
func (m *SessionManager) Logout() error {
	// Backends report a missing item either as keyring.ErrKeyNotFound or, for
	// file-backed storage, as a not-exist filesystem error; accept both.
	if err := m.keyring.Remove("session"); err != nil && err != keyring.ErrKeyNotFound && !os.IsNotExist(err) {
		return fmt.Errorf("failed to remove session from keyring: %w", err)
	}
	if err := os.Remove(m.sessionFile); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("failed to remove session file: %w", err)
	}
	fmt.Println("Logged out successfully")
	return nil
}
// GetSession returns the stored session. The keyring entry "session" is
// preferred. If it cannot be read for any reason, the encrypted session file
// is read and decrypted instead.
//
// When no session file exists, the returned error wraps the underlying
// not-exist error, so callers can detect it with errors.Is(err, os.ErrNotExist).
func (m *SessionManager) GetSession() (*Session, error) {
	if item, err := m.keyring.Get("session"); err == nil {
		var session Session
		if err := json.Unmarshal(item.Data, &session); err != nil {
			return nil, fmt.Errorf("failed to parse session from keyring: %w", err)
		}
		return &session, nil
	}

	// Keyring empty or unavailable: fall back to the encrypted file.
	data, err := os.ReadFile(m.sessionFile)
	if err != nil {
		// os.ReadFile returns a *PathError, which is never == os.ErrNotExist;
		// os.IsNotExist unwraps it.
		if os.IsNotExist(err) {
			return nil, fmt.Errorf("no session found: %w", err)
		}
		return nil, fmt.Errorf("failed to read session file: %w", err)
	}

	session, err := decryptSession(data)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt session: %w", err)
	}
	return &session, nil
}
// IsAuthenticated reports whether a session can be loaded and its ExpiresAt
// timestamp has not yet passed. A load failure is returned as-is. An expired
// session is reported as false together with an error naming the expiry time.
func (m *SessionManager) IsAuthenticated() (bool, error) {
	s, err := m.GetSession()
	switch {
	case err != nil:
		return false, err
	case time.Now().Unix() > s.ExpiresAt:
		return false, fmt.Errorf("session expired at %d", s.ExpiresAt)
	default:
		return true, nil
	}
}
// RefreshToken is meant to exchange the stored refresh token for a new access
// token. It is currently a stub: it only checks that a session can be loaded,
// wrapping any load error, and otherwise always returns a "not yet
// implemented" error.
func (m *SessionManager) RefreshToken() error {
	if _, err := m.GetSession(); err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	// TODO: call the ProtonMail API to obtain a new access token.
	return fmt.Errorf("token refresh not yet implemented - requires API integration")
}
// encryptSession serializes session to JSON and seals it with AES-256-GCM
// under a freshly generated random key and nonce.
//
// The output is ASCII: three '|'-separated standard-base64 fields,
//
//	base64(key) | base64(nonce) | base64(ciphertext+tag)
//
// which is the format decryptSession parses. The ciphertext must be base64
// encoded too, because raw GCM output is arbitrary binary. It may contain '|'
// itself, and decryptSession base64-decodes that field.
//
// NOTE(security): the key is stored alongside the ciphertext, so anyone who
// can read the output can decrypt it. This only detects accidental corruption
// and provides no confidentiality. Real protection needs the key kept
// elsewhere (e.g. in the keyring).
func encryptSession(session Session) ([]byte, error) {
	key := make([]byte, 32) // 32-byte key selects AES-256
	if _, err := rand.Read(key); err != nil {
		return nil, fmt.Errorf("failed to generate encryption key: %w", err)
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create AES cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	// Size the nonce from the AEAD rather than hard-coding it.
	nonce := make([]byte, aead.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	sessionData, err := json.Marshal(session)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal session for encryption: %w", err)
	}
	sealedData := aead.Seal(nil, nonce, sessionData, nil)

	enc := base64.StdEncoding
	out := enc.EncodeToString(key) + "|" + enc.EncodeToString(nonce) + "|" + enc.EncodeToString(sealedData)
	return []byte(out), nil
}
// decryptSession reverses the encrypted session format. It expects three
// '|'-separated standard-base64 fields (key, nonce, GCM ciphertext+tag) and
// returns the JSON-decoded Session.
//
// The input comes from a file on disk, so every field is validated before
// use. The nonce length is checked explicitly because cipher.AEAD.Open panics,
// rather than returning an error, when given a nonce of the wrong size. A
// corrupted session file would otherwise crash the program.
func decryptSession(encryptedData []byte) (Session, error) {
	parts := strings.Split(string(encryptedData), "|")
	if len(parts) != 3 {
		return Session{}, fmt.Errorf("invalid encrypted data format")
	}

	key, err := base64.StdEncoding.DecodeString(parts[0])
	if err != nil {
		return Session{}, fmt.Errorf("failed to decode key: %w", err)
	}
	nonce, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return Session{}, fmt.Errorf("failed to decode nonce: %w", err)
	}

	// aes.NewCipher rejects keys that are not 16, 24, or 32 bytes.
	block, err := aes.NewCipher(key)
	if err != nil {
		return Session{}, fmt.Errorf("failed to create AES cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return Session{}, fmt.Errorf("failed to create GCM: %w", err)
	}
	if len(nonce) != aead.NonceSize() {
		return Session{}, fmt.Errorf("invalid nonce length %d (want %d)", len(nonce), aead.NonceSize())
	}

	sealedData, err := base64.StdEncoding.DecodeString(parts[2])
	if err != nil {
		return Session{}, fmt.Errorf("failed to decode sealed data: %w", err)
	}
	data, err := aead.Open(nil, nonce, sealedData, nil)
	if err != nil {
		return Session{}, fmt.Errorf("failed to decrypt session: %w", err)
	}

	var session Session
	if err := json.Unmarshal(data, &session); err != nil {
		return Session{}, fmt.Errorf("failed to unmarshal session: %w", err)
	}
	return session, nil
}