// Package auth manages ProtonMail API sessions: logging in, persisting the
// resulting tokens, refreshing them and logging out.
package auth

import (
	"bytes"
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/99designs/keyring"
	"github.com/frenocorp/pop/internal/config"
	"github.com/manifoldco/promptui"
)

// Session is the authentication state persisted between CLI invocations.
type Session struct {
	UID            string `json:"uid"`
	AccessToken    string `json:"access_token"`
	RefreshToken   string `json:"refresh_token"`
	ExpiresAt      int64  `json:"expires_at"` // Unix seconds
	TwoFAEnabled   bool   `json:"two_factor_enabled"`
	MailPassphrase string `json:"mail_passphrase,omitempty"`
}

// SessionManager stores and retrieves a Session using both a keyring entry
// and a session file inside the configuration directory.
type SessionManager struct {
	configDir   string
	sessionFile string
	keyring     keyring.Keyring
}

// authResult holds the fields read from the /auth response body.
type authResult struct {
	UID           string `json:"UID"`
	AccessToken   string `json:"AccessToken"`
	RefreshToken  string `json:"RefreshToken"`
	ExpiresIn     int    `json:"ExpiresIn"`
	TwoFARequired bool   `json:"TwoFARequired"`
}

// NewSessionManager opens the "pop-cli" keyring (file-backed entries live in
// <configDir>/keyring) and returns a manager whose session file is
// <configDir>/session.json.
func NewSessionManager() (*SessionManager, error) {
	cfg := config.NewConfigManager()
	configDir := cfg.ConfigDir()

	k, err := keyring.Open(keyring.Config{
		ServiceName: "pop-cli",
		FileDir:     filepath.Join(configDir, "keyring"),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to open keyring: %w", err)
	}

	return &SessionManager{
		configDir:   configDir,
		sessionFile: filepath.Join(configDir, "session.json"),
		keyring:     k,
	}, nil
}

// ensureStorageDirs creates the config directory and its keyring
// subdirectory with owner-only permissions.
func (m *SessionManager) ensureStorageDirs() error {
	if err := os.MkdirAll(m.configDir, 0700); err != nil {
		return fmt.Errorf("failed to create config dir: %w", err)
	}
	if err := os.MkdirAll(filepath.Join(m.configDir, "keyring"), 0700); err != nil {
		return fmt.Errorf("failed to create keyring dir: %w", err)
	}
	return nil
}

// requestAuth posts the credentials to <apiBaseURL>/auth and decodes the
// response. Any non-200 status is returned as an error that includes the
// response body.
func requestAuth(ctx context.Context, client *http.Client, apiBaseURL, email, password string) (authResult, error) {
	jsonData, err := json.Marshal(map[string]string{
		"Email":    email,
		"Password": password,
	})
	if err != nil {
		return authResult{}, fmt.Errorf("failed to marshal auth payload: %w", err)
	}

	authURL := fmt.Sprintf("%s/auth", apiBaseURL)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, authURL, bytes.NewReader(jsonData))
	if err != nil {
		return authResult{}, fmt.Errorf("failed to create auth request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return authResult{}, fmt.Errorf("failed to connect to ProtonMail API: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The read error is deliberately dropped: the body only enriches the message.
		body, _ := io.ReadAll(resp.Body)
		return authResult{}, fmt.Errorf("authentication failed (status %d): %s", resp.StatusCode, string(body))
	}

	var result authResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return authResult{}, fmt.Errorf("failed to parse auth response: %w", err)
	}
	return result, nil
}

// newSession converts an auth response into a Session, turning the relative
// ExpiresIn (seconds) into an absolute Unix timestamp.
func newSession(auth authResult, mailPassphrase string) Session {
	return Session{
		UID:            auth.UID,
		AccessToken:    auth.AccessToken,
		RefreshToken:   auth.RefreshToken,
		ExpiresAt:      time.Now().Unix() + int64(auth.ExpiresIn),
		TwoFAEnabled:   auth.TwoFARequired,
		MailPassphrase: mailPassphrase,
	}
}

// storeSession encrypts the session and writes it to the keyring and to the
// session file (mode 0600). Both writes must succeed.
func (m *SessionManager) storeSession(session Session) error {
	encrypted, err := encryptSession(session)
	if err != nil {
		return fmt.Errorf("failed to encrypt session: %w", err)
	}

	if err := m.keyring.Set(keyring.Item{
		Key:  "session",
		Data: encrypted,
	}); err != nil {
		return fmt.Errorf("failed to store session in keyring: %w", err)
	}

	if err := os.WriteFile(m.sessionFile, encrypted, 0600); err != nil {
		return fmt.Errorf("failed to write encrypted session file: %w", err)
	}
	return nil
}

// LoginWithCredentials authenticates against <apiBaseURL>/auth with the given
// email and password and persists the resulting session together with
// mailPassphrase.
//
// NOTE(review): when the server answers with TwoFARequired this path stores
// the session without a TOTP verification step; LoginInteractive performs
// that step.
func (m *SessionManager) LoginWithCredentials(ctx context.Context, apiBaseURL, email, password, mailPassphrase string) error {
	if err := m.ensureStorageDirs(); err != nil {
		return err
	}

	client := &http.Client{Timeout: 30 * time.Second}
	auth, err := requestAuth(ctx, client, apiBaseURL, email, password)
	if err != nil {
		return err
	}

	if err := m.storeSession(newSession(auth, mailPassphrase)); err != nil {
		return err
	}

	fmt.Println("Logged in successfully")
	return nil
}

// validateEmail performs a shallow syntactic check on a prompted address:
// exactly one "@", a non-empty local part, a domain of at least three
// characters, and at least one "." somewhere in the input.
func validateEmail(input string) error {
	if !strings.Contains(input, "@") || !strings.Contains(input, ".") {
		return fmt.Errorf("invalid email format")
	}
	parts := strings.Split(input, "@")
	if len(parts) != 2 || len(parts[0]) == 0 || len(parts[1]) < 3 {
		return fmt.Errorf("invalid email format")
	}
	return nil
}

// validateTOTPCode accepts exactly six ASCII digits.
func validateTOTPCode(input string) error {
	if len(input) != 6 {
		return fmt.Errorf("TOTP code must be 6 digits")
	}
	for _, r := range input {
		if r < '0' || r > '9' {
			return fmt.Errorf("TOTP code must be 6 digits")
		}
	}
	return nil
}

// verifyTOTP posts the code to <apiBaseURL>/auth/verify, authorised with the
// session's current access token, and on success replaces the session's
// tokens and expiry with the values from the response.
func verifyTOTP(ctx context.Context, client *http.Client, apiBaseURL, totpCode string, session *Session) error {
	totpJSON, err := json.Marshal(map[string]string{
		"UID":  session.UID,
		"Code": totpCode,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal TOTP payload: %w", err)
	}

	totpURL := fmt.Sprintf("%s/auth/verify", apiBaseURL)
	totpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, totpURL, bytes.NewReader(totpJSON))
	if err != nil {
		return fmt.Errorf("failed to create TOTP request: %w", err)
	}
	totpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
	totpReq.Header.Set("Content-Type", "application/json")

	totpResp, err := client.Do(totpReq)
	if err != nil {
		return fmt.Errorf("failed to verify TOTP: %w", err)
	}
	defer totpResp.Body.Close()

	if totpResp.StatusCode != http.StatusOK {
		// The read error is deliberately dropped: the body only enriches the message.
		body, _ := io.ReadAll(totpResp.Body)
		return fmt.Errorf("TOTP verification failed (status %d): %s", totpResp.StatusCode, string(body))
	}

	var finalAuth struct {
		AccessToken  string `json:"AccessToken"`
		RefreshToken string `json:"RefreshToken"`
		ExpiresIn    int    `json:"ExpiresIn"`
	}
	if err := json.NewDecoder(totpResp.Body).Decode(&finalAuth); err != nil {
		return fmt.Errorf("failed to parse TOTP response: %w", err)
	}

	session.AccessToken = finalAuth.AccessToken
	session.RefreshToken = finalAuth.RefreshToken
	session.ExpiresAt = time.Now().Unix() + int64(finalAuth.ExpiresIn)
	return nil
}

// LoginInteractive prompts for email, password and mail passphrase,
// authenticates against <apiBaseURL>/auth, completes a TOTP challenge when the
// server requires one, and persists the resulting session.
func (m *SessionManager) LoginInteractive(ctx context.Context, apiBaseURL string) error {
	if err := m.ensureStorageDirs(); err != nil {
		return err
	}

	emailPrompt := promptui.Prompt{
		Label:    "ProtonMail email",
		Validate: validateEmail,
	}
	email, err := emailPrompt.Run()
	if err != nil {
		return fmt.Errorf("failed to read email: %w", err)
	}

	passwordPrompt := promptui.Prompt{
		Label: "ProtonMail password",
		Mask:  '*',
	}
	password, err := passwordPrompt.Run()
	if err != nil {
		return fmt.Errorf("failed to read password: %w", err)
	}

	passphrasePrompt := promptui.Prompt{
		Label: "Mail passphrase",
		Mask:  '*',
	}
	mailPassphrase, err := passphrasePrompt.Run()
	if err != nil {
		return fmt.Errorf("failed to read mail passphrase: %w", err)
	}

	client := &http.Client{Timeout: 30 * time.Second}
	auth, err := requestAuth(ctx, client, apiBaseURL, email, password)
	if err != nil {
		return err
	}
	session := newSession(auth, mailPassphrase)

	if session.TwoFAEnabled {
		fmt.Println("\n2FA authentication required")

		totpPrompt := promptui.Prompt{
			Label:    "Enter TOTP code",
			Validate: validateTOTPCode,
		}
		totpCode, err := totpPrompt.Run()
		if err != nil {
			return fmt.Errorf("failed to read TOTP code: %w", err)
		}

		if err := verifyTOTP(ctx, client, apiBaseURL, totpCode, &session); err != nil {
			return err
		}
	}

	if err := m.storeSession(session); err != nil {
		return err
	}

	fmt.Println("Logged in successfully")
	return nil
}

// Logout deletes the session file and the keyring entry. Both removals are
// always attempted so a missing or undeletable file cannot leave the keyring
// copy behind; a copy that is already absent is not treated as an error.
func (m *SessionManager) Logout() error {
	fileErr := os.Remove(m.sessionFile)
	if errors.Is(fileErr, os.ErrNotExist) {
		fileErr = nil
	}

	keyErr := m.keyring.Remove("session")
	// Backends report a missing item either with the keyring sentinel or,
	// for file-based storage, with the underlying not-exist error.
	if errors.Is(keyErr, keyring.ErrKeyNotFound) || errors.Is(keyErr, os.ErrNotExist) {
		keyErr = nil
	}

	if fileErr != nil {
		return fmt.Errorf("failed to remove session file: %w", fileErr)
	}
	if keyErr != nil {
		return fmt.Errorf("failed to remove keyring entry: %w", keyErr)
	}

	fmt.Println("Logged out successfully")
	return nil
}

// GetSession returns the stored session. The keyring entry is preferred; if
// it cannot be fetched, the session file is read instead. A keyring entry
// that is present but cannot be decrypted is reported as an error rather than
// falling back to the file.
func (m *SessionManager) GetSession() (*Session, error) {
	item, err := m.keyring.Get("session")
	if err == nil {
		session, err := decryptSession(item.Data)
		if err != nil {
			return nil, fmt.Errorf("failed to decrypt session from keyring: %w", err)
		}
		return &session, nil
	}

	data, err := os.ReadFile(m.sessionFile)
	if err != nil {
		// os.ReadFile wraps the sentinel in *PathError, so a direct ==
		// comparison would never match.
		if errors.Is(err, os.ErrNotExist) {
			return nil, fmt.Errorf("no session found: %w", err)
		}
		return nil, fmt.Errorf("failed to read session file: %w", err)
	}

	session, err := decryptSession(data)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt session: %w", err)
	}
	return &session, nil
}

// IsAuthenticated reports whether a stored session exists and has not passed
// its ExpiresAt timestamp. A missing, unreadable or expired session yields
// false together with a non-nil error.
func (m *SessionManager) IsAuthenticated() (bool, error) {
	session, err := m.GetSession()
	if err != nil {
		return false, err
	}

	if time.Now().Unix() > session.ExpiresAt {
		return false, fmt.Errorf("session expired at %d", session.ExpiresAt)
	}
	return true, nil
}

// RefreshToken refreshes the stored session using context.Background().
func (m *SessionManager) RefreshToken() error {
	return m.RefreshTokenWithContext(context.Background())
}

// RefreshTokenWithContext exchanges the stored refresh token for new tokens
// and persists the updated session. It fails if no session is stored or the
// session has no refresh token.
func (m *SessionManager) RefreshTokenWithContext(ctx context.Context) error {
	session, err := m.GetSession()
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	if session.RefreshToken == "" {
		return fmt.Errorf("no refresh token available")
	}

	// NOTE(review): the base URL is fixed here, whereas the login methods
	// take it as a parameter; a session created against another host is
	// still refreshed against this one.
	apiBaseURL := "https://api.protonmail.ch"
	refreshURL := fmt.Sprintf("%s/auth/refresh", apiBaseURL)

	jsonData, err := json.Marshal(map[string]string{
		"UID":          session.UID,
		"RefreshToken": session.RefreshToken,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal refresh payload: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create refresh request: %w", err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to connect to ProtonMail API: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The read error is deliberately dropped: the body only enriches the message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(body))
	}

	var refreshResponse struct {
		AccessToken  string `json:"AccessToken"`
		RefreshToken string `json:"RefreshToken"`
		ExpiresIn    int    `json:"ExpiresIn"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&refreshResponse); err != nil {
		return fmt.Errorf("failed to parse refresh response: %w", err)
	}

	session.AccessToken = refreshResponse.AccessToken
	session.RefreshToken = refreshResponse.RefreshToken
	session.ExpiresAt = time.Now().Unix() + int64(refreshResponse.ExpiresIn)

	encrypted, err := encryptSession(*session)
	if err != nil {
		return fmt.Errorf("failed to encrypt updated session: %w", err)
	}

	if err := m.keyring.Set(keyring.Item{
		Key:  "session",
		Data: encrypted,
	}); err != nil {
		return fmt.Errorf("failed to update session in keyring: %w", err)
	}

	// Best effort: GetSession prefers the keyring copy, which was just
	// updated, so a failed write of the fallback file does not fail the refresh.
	_ = os.WriteFile(m.sessionFile, encrypted, 0600)

	return nil
}

// encryptSession serialises the session as JSON, seals it with AES-256-GCM
// under a freshly generated key and nonce, and returns
// "<base64 key>|<base64 nonce>|<base64 ciphertext>".
//
// NOTE(review): the key is stored next to the ciphertext, so anyone who can
// read the output can decrypt it. This is obfuscation, not confidentiality;
// real protection would need the key kept somewhere separate from the data.
func encryptSession(session Session) ([]byte, error) {
	key := make([]byte, 32)
	if _, err := rand.Read(key); err != nil {
		return nil, fmt.Errorf("failed to generate encryption key: %w", err)
	}

	nonce := make([]byte, 12)
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create AES cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	sessionData, err := json.Marshal(session)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal session for encryption: %w", err)
	}
	sealedData := aead.Seal(nil, nonce, sessionData, nil)

	// All three fields are base64 so none can contain the '|' separator and
	// decryptSession can decode each one the same way.
	enc := base64.StdEncoding
	out := enc.EncodeToString(key) + "|" + enc.EncodeToString(nonce) + "|" + enc.EncodeToString(sealedData)
	return []byte(out), nil
}

// decryptSession reverses encryptSession. It returns an error, never a
// panic, for malformed, truncated or tampered input.
func decryptSession(encryptedData []byte) (Session, error) {
	parts := strings.Split(string(encryptedData), "|")
	if len(parts) != 3 {
		return Session{}, fmt.Errorf("invalid encrypted data format")
	}

	key, err := base64.StdEncoding.DecodeString(parts[0])
	if err != nil {
		return Session{}, fmt.Errorf("failed to decode key: %w", err)
	}
	nonce, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return Session{}, fmt.Errorf("failed to decode nonce: %w", err)
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return Session{}, fmt.Errorf("failed to create AES cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return Session{}, fmt.Errorf("failed to create GCM: %w", err)
	}

	// Open panics on a nonce of the wrong length, so reject it up front.
	if len(nonce) != aead.NonceSize() {
		return Session{}, fmt.Errorf("invalid nonce length %d", len(nonce))
	}

	sealedData, err := base64.StdEncoding.DecodeString(parts[2])
	if err != nil {
		return Session{}, fmt.Errorf("failed to decode sealed data: %w", err)
	}

	data, err := aead.Open(nil, nonce, sealedData, nil)
	if err != nil {
		return Session{}, fmt.Errorf("failed to decrypt session: %w", err)
	}

	var session Session
	if err := json.Unmarshal(data, &session); err != nil {
		return Session{}, fmt.Errorf("failed to unmarshal session: %w", err)
	}
	return session, nil
}