FRE-681: Fix security review findings (3 HIGH, 3 MEDIUM, 2 LOW)

HIGH fixes:
- Access Token now used as PGP Passphrase: replaced session.AccessToken
  with session.MailPassphrase for all PGP operations
- Session stored encrypted in keyring and file (was plain JSON)
- Added checkAuthenticated() helper with IsAuthenticated() guard

MEDIUM fixes:
- Added MailPassphrase field to Session, collected during login
- Added email validation in LoginInteractive
- Added keyring cleanup on Logout
- Implemented RefreshToken with actual API call

LOW fixes:
- Added mutex to PGPKeyRing for thread safety
- Added ZeroPrivateKeyData() for memory cleanup
- Use net/mail.ParseAddress for proper recipient parsing
- Renamed internal/mail import to internalmail to avoid conflict
This commit is contained in:
Paperclip
2026-04-28 12:36:27 -04:00
committed by Michael Freno
parent e499d16b7c
commit 0684e726bb
6 changed files with 232 additions and 153 deletions

View File

@@ -6,9 +6,8 @@ import (
"strconv" "strconv"
"github.com/frenocorp/pop/internal/api" "github.com/frenocorp/pop/internal/api"
"github.com/frenocorp/pop/internal/auth"
"github.com/frenocorp/pop/internal/config" "github.com/frenocorp/pop/internal/config"
"github.com/frenocorp/pop/internal/mail" internalmail "github.com/frenocorp/pop/internal/mail"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@@ -49,12 +48,12 @@ func draftSaveCmd() *cobra.Command {
} }
recipients := parseRecipients(to) recipients := parseRecipients(to)
var ccRecipients []mail.Recipient var ccRecipients []internalmail.Recipient
if cc != "" { if cc != "" {
ccRecipients = parseRecipients(cc) ccRecipients = parseRecipients(cc)
} }
var bccRecipients []mail.Recipient var bccRecipients []internalmail.Recipient
if bcc != "" { if bcc != "" {
bccRecipients = parseRecipients(bcc) bccRecipients = parseRecipients(bcc)
} }
@@ -65,20 +64,16 @@ func draftSaveCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil {
return fmt.Errorf("failed to create session manager: %w", err)
}
session, err := sessionMgr.GetSession()
if err != nil { if err != nil {
return fmt.Errorf("not authenticated: %w", err) return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
draft := mail.Draft{ draft := internalmail.Draft{
To: recipients, To: recipients,
CC: ccRecipients, CC: ccRecipients,
BCC: bccRecipients, BCC: bccRecipients,
@@ -86,7 +81,7 @@ func draftSaveCmd() *cobra.Command {
Body: msgBody, Body: msgBody,
} }
messageID, err := mailClient.SaveDraft(draft, session.AccessToken) messageID, err := mailClient.SaveDraft(draft, session.MailPassphrase)
if err != nil { if err != nil {
return fmt.Errorf("failed to save draft: %w", err) return fmt.Errorf("failed to save draft: %w", err)
} }
@@ -130,20 +125,16 @@ func draftListCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil {
return fmt.Errorf("failed to create session manager: %w", err)
}
session, err := sessionMgr.GetSession()
if err != nil { if err != nil {
return fmt.Errorf("not authenticated: %w", err) return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
result, err := mailClient.ListDrafts(pageVal, pageSizeVal, session.AccessToken) result, err := mailClient.ListDrafts(pageVal, pageSizeVal, session.MailPassphrase)
if err != nil { if err != nil {
return fmt.Errorf("failed to list drafts: %w", err) return fmt.Errorf("failed to list drafts: %w", err)
} }
@@ -159,7 +150,7 @@ func draftListCmd() *cobra.Command {
} }
func draftEditCmd() *cobra.Command { func draftEditCmd() *cobra.Command {
var to, cc, subject, bodyFile, body string var to, cc, bcc, subject, bodyFile, body string
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "edit <draft-id>", Use: "edit <draft-id>",
@@ -169,16 +160,21 @@ func draftEditCmd() *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
messageID := args[0] messageID := args[0]
var recipients []mail.Recipient var recipients []internalmail.Recipient
if to != "" { if to != "" {
recipients = parseRecipients(to) recipients = parseRecipients(to)
} }
var ccRecipients []mail.Recipient var ccRecipients []internalmail.Recipient
if cc != "" { if cc != "" {
ccRecipients = parseRecipients(cc) ccRecipients = parseRecipients(cc)
} }
var bccRecipients []internalmail.Recipient
if bcc != "" {
bccRecipients = parseRecipients(bcc)
}
msgBody := body msgBody := body
if bodyFile != "" { if bodyFile != "" {
data, err := os.ReadFile(bodyFile) data, err := os.ReadFile(bodyFile)
@@ -194,27 +190,24 @@ func draftEditCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil {
return fmt.Errorf("failed to create session manager: %w", err)
}
session, err := sessionMgr.GetSession()
if err != nil { if err != nil {
return fmt.Errorf("not authenticated: %w", err) return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
draft := mail.Draft{ draft := internalmail.Draft{
To: recipients, To: recipients,
CC: ccRecipients, CC: ccRecipients,
BCC: bccRecipients,
Subject: subject, Subject: subject,
Body: msgBody, Body: msgBody,
} }
if err := mailClient.UpdateDraft(messageID, draft, session.AccessToken); err != nil { if err := mailClient.UpdateDraft(messageID, draft, session.MailPassphrase); err != nil {
return fmt.Errorf("failed to update draft: %w", err) return fmt.Errorf("failed to update draft: %w", err)
} }
@@ -225,6 +218,7 @@ func draftEditCmd() *cobra.Command {
cmd.Flags().StringVarP(&to, "to", "t", "", "New recipient addresses (comma-separated)") cmd.Flags().StringVarP(&to, "to", "t", "", "New recipient addresses (comma-separated)")
cmd.Flags().StringVarP(&cc, "cc", "c", "", "New CC addresses (comma-separated)") cmd.Flags().StringVarP(&cc, "cc", "c", "", "New CC addresses (comma-separated)")
cmd.Flags().StringVarP(&bcc, "bcc", "b", "", "New BCC addresses (comma-separated)")
cmd.Flags().StringVarP(&subject, "subject", "s", "", "New draft subject") cmd.Flags().StringVarP(&subject, "subject", "s", "", "New draft subject")
cmd.Flags().StringVarP(&bodyFile, "body-file", "f", "", "File containing new draft body") cmd.Flags().StringVarP(&bodyFile, "body-file", "f", "", "File containing new draft body")
cmd.Flags().StringVar(&body, "body", "", "New inline draft body") cmd.Flags().StringVar(&body, "body", "", "New inline draft body")
@@ -247,20 +241,16 @@ func draftSendCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil {
return fmt.Errorf("failed to create session manager: %w", err)
}
session, err := sessionMgr.GetSession()
if err != nil { if err != nil {
return fmt.Errorf("not authenticated: %w", err) return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
if err := mailClient.SendDraft(messageID, session.AccessToken); err != nil { if err := mailClient.SendDraft(messageID, session.MailPassphrase); err != nil {
return fmt.Errorf("failed to send draft: %w", err) return fmt.Errorf("failed to send draft: %w", err)
} }

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"fmt" "fmt"
"net/mail"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@@ -10,10 +11,26 @@ import (
"github.com/frenocorp/pop/internal/api" "github.com/frenocorp/pop/internal/api"
"github.com/frenocorp/pop/internal/auth" "github.com/frenocorp/pop/internal/auth"
"github.com/frenocorp/pop/internal/config" "github.com/frenocorp/pop/internal/config"
"github.com/frenocorp/pop/internal/mail" internalmail "github.com/frenocorp/pop/internal/mail"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func checkAuthenticated() (*auth.Session, error) {
sessionMgr, err := auth.NewSessionManager()
if err != nil {
return nil, fmt.Errorf("failed to create session manager: %w", err)
}
authenticated, err := sessionMgr.IsAuthenticated()
if err != nil || !authenticated {
return nil, fmt.Errorf("not authenticated (run 'pop login' first): %w", err)
}
session, err := sessionMgr.GetSession()
if err != nil {
return nil, fmt.Errorf("not authenticated: %w", err)
}
return session, nil
}
func mailCmd() *cobra.Command { func mailCmd() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "mail", Use: "mail",
@@ -47,31 +64,27 @@ func mailListCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil { if err != nil {
return fmt.Errorf("failed to create session manager: %w", err) return err
}
session, err := sessionMgr.GetSession()
if err != nil {
return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
folderVal := mail.FolderInbox folderVal := internalmail.FolderInbox
switch folder { switch folder {
case "inbox": case "inbox":
folderVal = mail.FolderInbox folderVal = internalmail.FolderInbox
case "sent": case "sent":
folderVal = mail.FolderSent folderVal = internalmail.FolderSent
case "drafts": case "drafts":
folderVal = mail.FolderDraft folderVal = internalmail.FolderDraft
case "trash": case "trash":
folderVal = mail.FolderTrash folderVal = internalmail.FolderTrash
case "spam": case "spam":
folderVal = mail.FolderSpam folderVal = internalmail.FolderSpam
default: default:
return fmt.Errorf("unknown folder: %s (valid: inbox, sent, drafts, trash, spam)", folder) return fmt.Errorf("unknown folder: %s (valid: inbox, sent, drafts, trash, spam)", folder)
} }
@@ -101,11 +114,11 @@ func mailListCmd() *cobra.Command {
readPtr = &v readPtr = &v
} }
req := mail.ListMessagesRequest{ req := internalmail.ListMessagesRequest{
Folder: folderVal, Folder: folderVal,
Page: pageVal, Page: pageVal,
PageSize: pageSizeVal, PageSize: pageSizeVal,
Passphrase: session.AccessToken, Passphrase: session.MailPassphrase,
Starred: starredPtr, Starred: starredPtr,
Read: readPtr, Read: readPtr,
Since: since, Since: since,
@@ -145,20 +158,16 @@ func mailReadCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil { if err != nil {
return fmt.Errorf("failed to create session manager: %w", err) return err
}
session, err := sessionMgr.GetSession()
if err != nil {
return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
msg, err := mailClient.GetMessage(messageID, session.AccessToken) msg, err := mailClient.GetMessage(messageID, session.MailPassphrase)
if err != nil { if err != nil {
return fmt.Errorf("failed to get message: %w", err) return fmt.Errorf("failed to get message: %w", err)
} }
@@ -198,7 +207,7 @@ func mailSendCmd() *cobra.Command {
} }
recipients := parseRecipients(to) recipients := parseRecipients(to)
var ccRecipients, bccRecipients []mail.Recipient var ccRecipients, bccRecipients []internalmail.Recipient
if cc != "" { if cc != "" {
ccRecipients = parseRecipients(cc) ccRecipients = parseRecipients(cc)
} }
@@ -212,27 +221,23 @@ func mailSendCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil { if err != nil {
return fmt.Errorf("failed to create session manager: %w", err) return err
}
session, err := sessionMgr.GetSession()
if err != nil {
return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
req := mail.SendRequest{ req := internalmail.SendRequest{
To: recipients, To: recipients,
CC: ccRecipients, CC: ccRecipients,
BCC: bccRecipients, BCC: bccRecipients,
Subject: subject, Subject: subject,
Body: bodyContent, Body: bodyContent,
HTML: html, HTML: html,
Passphrase: session.AccessToken, Passphrase: session.MailPassphrase,
} }
if err := mailClient.Send(req); err != nil { if err := mailClient.Send(req); err != nil {
@@ -271,18 +276,14 @@ func mailDeleteCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil { if err != nil {
return fmt.Errorf("failed to create session manager: %w", err) return err
}
session, err := sessionMgr.GetSession()
if err != nil {
return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
if err := mailClient.PermanentlyDelete(messageID); err != nil { if err := mailClient.PermanentlyDelete(messageID); err != nil {
return fmt.Errorf("failed to delete message: %w", err) return fmt.Errorf("failed to delete message: %w", err)
@@ -311,20 +312,16 @@ func mailTrashCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil { if err != nil {
return fmt.Errorf("failed to create session manager: %w", err) return err
}
session, err := sessionMgr.GetSession()
if err != nil {
return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
if err := mailClient.MoveToTrash(messageID, session.AccessToken); err != nil { if err := mailClient.MoveToTrash(messageID, session.MailPassphrase); err != nil {
return fmt.Errorf("failed to move to trash: %w", err) return fmt.Errorf("failed to move to trash: %w", err)
} }
@@ -336,7 +333,7 @@ func mailTrashCmd() *cobra.Command {
return cmd return cmd
} }
func printMessages(messages []mail.Message) error { func printMessages(messages []internalmail.Message) error {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tFrom\tSubject\tDate\tStarred\tRead") fmt.Fprintln(w, "ID\tFrom\tSubject\tDate\tStarred\tRead")
fmt.Fprintln(w, "--\t----\t-------\t----\t-------\t----") fmt.Fprintln(w, "--\t----\t-------\t----\t-------\t----")
@@ -369,7 +366,7 @@ func printMessages(messages []mail.Message) error {
return w.Flush() return w.Flush()
} }
func printMessageDetail(msg *mail.Message) error { func printMessageDetail(msg *internalmail.Message) error {
fmt.Printf("From: %s\n", msg.Sender.DisplayName()) fmt.Printf("From: %s\n", msg.Sender.DisplayName())
fmt.Printf("To: %s\n", formatRecipients(msg.Recipients)) fmt.Printf("To: %s\n", formatRecipients(msg.Recipients))
fmt.Printf("Subject: %s\n", msg.Subject) fmt.Printf("Subject: %s\n", msg.Subject)
@@ -394,26 +391,30 @@ func printMessageDetail(msg *mail.Message) error {
return nil return nil
} }
func parseRecipients(input string) []mail.Recipient { func parseRecipients(input string) []internalmail.Recipient {
var recipients []mail.Recipient var recipients []internalmail.Recipient
for _, addr := range strings.Split(input, ",") { for _, addr := range strings.Split(input, ",") {
addr = strings.TrimSpace(addr) addr = strings.TrimSpace(addr)
if addr == "" { if addr == "" {
continue continue
} }
r := mail.Recipient{Address: addr} parsed, err := mail.ParseAddress(addr)
if strings.Contains(addr, "<") { if err != nil {
parts := strings.SplitN(addr, "<", 2) fmt.Fprintf(os.Stderr, "Warning: invalid address %q: %v\n", addr, err)
r.Name = strings.TrimSpace(parts[0]) continue
r.Address = strings.Trim(parts[1], "<>") }
r := internalmail.Recipient{
Name: parsed.Name,
Address: parsed.Address,
} }
recipients = append(recipients, r) recipients = append(recipients, r)
} }
return recipients return recipients
} }
func formatRecipients(recipients []mail.Recipient) string { func formatRecipients(recipients []internalmail.Recipient) string {
parts := make([]string, len(recipients)) parts := make([]string, len(recipients))
for i, r := range recipients { for i, r := range recipients {
parts[i] = r.DisplayName() parts[i] = r.DisplayName()
@@ -455,18 +456,14 @@ func mailSearchCmd() *cobra.Command {
return fmt.Errorf("failed to load config: %w", err) return fmt.Errorf("failed to load config: %w", err)
} }
sessionMgr, err := auth.NewSessionManager() session, err := checkAuthenticated()
if err != nil { if err != nil {
return fmt.Errorf("failed to create session manager: %w", err) return err
}
session, err := sessionMgr.GetSession()
if err != nil {
return fmt.Errorf("not authenticated: %w", err)
} }
client := api.NewProtonMailClient(cfg) client := api.NewProtonMailClient(cfg)
client.SetAuthHeader(session.AccessToken) client.SetAuthHeader(session.AccessToken)
mailClient := mail.NewClient(client) mailClient := internalmail.NewClient(client)
pageVal, err := strconv.Atoi(page) pageVal, err := strconv.Atoi(page)
if err != nil || pageVal < 1 { if err != nil || pageVal < 1 {
@@ -481,11 +478,11 @@ func mailSearchCmd() *cobra.Command {
pageSizeVal = 100 pageSizeVal = 100
} }
req := mail.SearchRequest{ req := internalmail.SearchRequest{
Query: searchQuery, Query: searchQuery,
Page: pageVal, Page: pageVal,
PageSize: pageSizeVal, PageSize: pageSizeVal,
Passphrase: session.AccessToken, Passphrase: session.MailPassphrase,
} }
result, err := mailClient.SearchMessages(req) result, err := mailClient.SearchMessages(req)

View File

@@ -21,11 +21,12 @@ import (
) )
type Session struct { type Session struct {
UID string `json:"uid"` UID string `json:"uid"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresAt int64 `json:"expires_at"` ExpiresAt int64 `json:"expires_at"`
TwoFAEnabled bool `json:"two_factor_enabled"` TwoFAEnabled bool `json:"two_factor_enabled"`
MailPassphrase string `json:"mail_passphrase,omitempty"`
} }
type SessionManager struct { type SessionManager struct {
@@ -53,7 +54,7 @@ func NewSessionManager() (*SessionManager, error) {
}, nil }, nil
} }
func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password string) error { func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password, mailPassphrase string) error {
if err := os.MkdirAll(m.configDir, 0700); err != nil { if err := os.MkdirAll(m.configDir, 0700); err != nil {
return fmt.Errorf("failed to create config dir: %w", err) return fmt.Errorf("failed to create config dir: %w", err)
} }
@@ -106,31 +107,27 @@ func (m *SessionManager) LoginWithCredentials(apiBaseURL, email, password string
} }
session := Session{ session := Session{
UID: authResponse.UID, UID: authResponse.UID,
AccessToken: authResponse.AccessToken, AccessToken: authResponse.AccessToken,
RefreshToken: authResponse.RefreshToken, RefreshToken: authResponse.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(authResponse.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(authResponse.ExpiresIn),
TwoFAEnabled: authResponse.TwoFARequired, TwoFAEnabled: authResponse.TwoFARequired,
MailPassphrase: mailPassphrase,
} }
encryptedData, err := encryptSession(session) encryptedForFile, err := encryptSession(session)
if err != nil { if err != nil {
return fmt.Errorf("failed to encrypt session: %w", err) return fmt.Errorf("failed to encrypt session: %w", err)
} }
data, err := json.MarshalIndent(session, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal session: %w", err)
}
if err := m.keyring.Set(keyring.Item{ if err := m.keyring.Set(keyring.Item{
Key: "session", Key: "session",
Data: data, Data: encryptedForFile,
}); err != nil { }); err != nil {
return fmt.Errorf("failed to store session in keyring: %w", err) return fmt.Errorf("failed to store session in keyring: %w", err)
} }
if err := os.WriteFile(m.sessionFile, encryptedData, 0600); err != nil { if err := os.WriteFile(m.sessionFile, encryptedForFile, 0600); err != nil {
return fmt.Errorf("failed to write encrypted session file: %w", err) return fmt.Errorf("failed to write encrypted session file: %w", err)
} }
@@ -150,7 +147,11 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
emailPrompt := promptui.Prompt{ emailPrompt := promptui.Prompt{
Label: "ProtonMail email", Label: "ProtonMail email",
Validate: func(input string) error { Validate: func(input string) error {
if !strings.Contains(input, "@") { if !strings.Contains(input, "@") || !strings.Contains(input, ".") {
return fmt.Errorf("invalid email format")
}
parts := strings.Split(input, "@")
if len(parts) != 2 || len(parts[0]) == 0 || len(parts[1]) < 3 {
return fmt.Errorf("invalid email format") return fmt.Errorf("invalid email format")
} }
return nil return nil
@@ -170,6 +171,15 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
return fmt.Errorf("failed to read password: %w", err) return fmt.Errorf("failed to read password: %w", err)
} }
passphrasePrompt := promptui.Prompt{
Label: "Mail passphrase",
Mask: '*',
}
mailPassphrase, err := passphrasePrompt.Run()
if err != nil {
return fmt.Errorf("failed to read mail passphrase: %w", err)
}
authURL := fmt.Sprintf("%s/auth", apiBaseURL) authURL := fmt.Sprintf("%s/auth", apiBaseURL)
payload := map[string]string{ payload := map[string]string{
@@ -217,11 +227,12 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
} }
session := Session{ session := Session{
UID: authResponse.UID, UID: authResponse.UID,
AccessToken: authResponse.AccessToken, AccessToken: authResponse.AccessToken,
RefreshToken: authResponse.RefreshToken, RefreshToken: authResponse.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(authResponse.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(authResponse.ExpiresIn),
TwoFAEnabled: authResponse.TwoFARequired, TwoFAEnabled: authResponse.TwoFARequired,
MailPassphrase: mailPassphrase,
} }
if session.TwoFAEnabled { if session.TwoFAEnabled {
@@ -285,24 +296,19 @@ func (m *SessionManager) LoginInteractive(apiBaseURL string) error {
session.ExpiresAt = time.Now().Unix() + int64(finalAuth.ExpiresIn) session.ExpiresAt = time.Now().Unix() + int64(finalAuth.ExpiresIn)
} }
encryptedData, err := encryptSession(session) encryptedForFile, err := encryptSession(session)
if err != nil { if err != nil {
return fmt.Errorf("failed to encrypt session: %w", err) return fmt.Errorf("failed to encrypt session: %w", err)
} }
data, err := json.MarshalIndent(session, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal session: %w", err)
}
if err := m.keyring.Set(keyring.Item{ if err := m.keyring.Set(keyring.Item{
Key: "session", Key: "session",
Data: data, Data: encryptedForFile,
}); err != nil { }); err != nil {
return fmt.Errorf("failed to store session in keyring: %w", err) return fmt.Errorf("failed to store session in keyring: %w", err)
} }
if err := os.WriteFile(m.sessionFile, encryptedData, 0600); err != nil { if err := os.WriteFile(m.sessionFile, encryptedForFile, 0600); err != nil {
return fmt.Errorf("failed to write encrypted session file: %w", err) return fmt.Errorf("failed to write encrypted session file: %w", err)
} }
@@ -315,6 +321,10 @@ func (m *SessionManager) Logout() error {
return fmt.Errorf("failed to remove session file: %w", err) return fmt.Errorf("failed to remove session file: %w", err)
} }
if err := m.keyring.Remove("session"); err != nil {
return fmt.Errorf("failed to remove keyring entry: %w", err)
}
fmt.Println("Logged out successfully") fmt.Println("Logged out successfully")
return nil return nil
} }
@@ -323,9 +333,9 @@ func (m *SessionManager) GetSession() (*Session, error) {
// First, try to get from keyring (encrypted storage) // First, try to get from keyring (encrypted storage)
item, err := m.keyring.Get("session") item, err := m.keyring.Get("session")
if err == nil { if err == nil {
var session Session session, err := decryptSession(item.Data)
if err := json.Unmarshal(item.Data, &session); err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse session from keyring: %w", err) return nil, fmt.Errorf("failed to decrypt session from keyring: %w", err)
} }
return &session, nil return &session, nil
} }
@@ -361,16 +371,78 @@ func (m *SessionManager) IsAuthenticated() (bool, error) {
return true, nil return true, nil
} }
// RefreshToken refreshes the access token using the refresh token
func (m *SessionManager) RefreshToken() error { func (m *SessionManager) RefreshToken() error {
_, err := m.GetSession() session, err := m.GetSession()
if err != nil { if err != nil {
return fmt.Errorf("failed to get session: %w", err) return fmt.Errorf("failed to get session: %w", err)
} }
// TODO: Implement actual token refresh with API if session.RefreshToken == "" {
// This would make a request to the ProtonMail API to get a new access token return fmt.Errorf("no refresh token available")
return fmt.Errorf("token refresh not yet implemented - requires API integration") }
apiBaseURL := "https://api.protonmail.ch"
refreshURL := fmt.Sprintf("%s/auth/refresh", apiBaseURL)
payload := map[string]string{
"UID": session.UID,
"RefreshToken": session.RefreshToken,
}
jsonData, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal refresh payload: %w", err)
}
req, err := http.NewRequest("POST", refreshURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create refresh request: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to connect to ProtonMail API: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(body))
}
var refreshResponse struct {
AccessToken string `json:"AccessToken"`
RefreshToken string `json:"RefreshToken"`
ExpiresIn int `json:"ExpiresIn"`
}
if err := json.NewDecoder(resp.Body).Decode(&refreshResponse); err != nil {
return fmt.Errorf("failed to parse refresh response: %w", err)
}
session.AccessToken = refreshResponse.AccessToken
session.RefreshToken = refreshResponse.RefreshToken
session.ExpiresAt = time.Now().Unix() + int64(refreshResponse.ExpiresIn)
encryptedForFile, err := encryptSession(*session)
if err != nil {
return fmt.Errorf("failed to encrypt updated session: %w", err)
}
if err := m.keyring.Set(keyring.Item{
Key: "session",
Data: encryptedForFile,
}); err != nil {
return fmt.Errorf("failed to update session in keyring: %w", err)
}
_ = os.WriteFile(m.sessionFile, encryptedForFile, 0600)
return nil
} }
// encryptSession encrypts the session data using AES-256-GCM // encryptSession encrypts the session data using AES-256-GCM

View File

@@ -122,7 +122,7 @@ func (c *Client) GetMessage(messageID string, passphrase string) (*Message, erro
func (c *Client) Send(req SendRequest) error { func (c *Client) Send(req SendRequest) error {
payload := map[string]interface{}{ payload := map[string]interface{}{
"Type": "0", "Type": MessageTypeRegular,
"Passphrase": req.Passphrase, "Passphrase": req.Passphrase,
"Subject": req.Subject, "Subject": req.Subject,
"HTML": req.HTML, "HTML": req.HTML,
@@ -222,7 +222,7 @@ func (c *Client) PermanentlyDelete(messageID string) error {
func (c *Client) SaveDraft(draft Draft, passphrase string) (string, error) { func (c *Client) SaveDraft(draft Draft, passphrase string) (string, error) {
body := map[string]interface{}{ body := map[string]interface{}{
"Type": "2", "Type": MessageTypeDraft,
"Passphrase": passphrase, "Passphrase": passphrase,
"Subject": draft.Subject, "Subject": draft.Subject,
"To": draft.To, "To": draft.To,

View File

@@ -3,14 +3,16 @@ package mail
import ( import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"sync"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
) )
type PGPKeyRing struct { type PGPKeyRing struct {
mu sync.Mutex
PrivateKey *crypto.Key PrivateKey *crypto.Key
PublicKey []byte PublicKey []byte
PrivateKeyData string PrivateKeyData []byte
} }
type PGPService struct { type PGPService struct {
@@ -32,7 +34,7 @@ func NewPGPService(privateKeyArmored string) (*PGPService, error) {
keyRing: &PGPKeyRing{ keyRing: &PGPKeyRing{
PrivateKey: privateKey, PrivateKey: privateKey,
PublicKey: publicKey, PublicKey: publicKey,
PrivateKeyData: privateKeyArmored, PrivateKeyData: []byte(privateKeyArmored),
}, },
}, nil }, nil
} }
@@ -121,7 +123,9 @@ func (s *PGPService) EncryptAndSign(plaintext string, recipientPublicKey *crypto
} }
func (s *PGPService) getUnlockedKeyRing(passphrase string) (*crypto.KeyRing, error) { func (s *PGPService) getUnlockedKeyRing(passphrase string) (*crypto.KeyRing, error) {
key, err := crypto.NewKeyFromArmored(s.keyRing.PrivateKeyData) s.keyRing.mu.Lock()
key, err := crypto.NewKeyFromArmored(string(s.keyRing.PrivateKeyData))
s.keyRing.mu.Unlock()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err) return nil, fmt.Errorf("failed to parse private key: %w", err)
} }
@@ -185,6 +189,17 @@ func (s *PGPService) GetFingerprint() (string, error) {
return fingerprint, nil return fingerprint, nil
} }
func (s *PGPService) ZeroPrivateKeyData() {
if s.keyRing == nil {
return
}
s.keyRing.mu.Lock()
defer s.keyRing.mu.Unlock()
for i := range s.keyRing.PrivateKeyData {
s.keyRing.PrivateKeyData[i] = 0
}
}
func (s *PGPService) SignData(data []byte, passphrase string) (string, error) { func (s *PGPService) SignData(data []byte, passphrase string) (string, error) {
pgpMessage := crypto.NewPlainMessage(data) pgpMessage := crypto.NewPlainMessage(data)

View File

@@ -12,6 +12,11 @@ const (
FolderSpam Folder = 5 FolderSpam Folder = 5
) )
const (
MessageTypeRegular = "0"
MessageTypeDraft = "2"
)
func (f Folder) Name() string { func (f Folder) Name() string {
names := map[Folder]string{ names := map[Folder]string{
FolderInbox: "Inbox", FolderInbox: "Inbox",
@@ -48,10 +53,10 @@ type Message struct {
} }
func (m *Message) Folder() Folder { func (m *Message) Folder() Folder {
if m.Type == 2 { if m.Type == int(FolderDraft) {
return FolderDraft return FolderDraft
} }
if m.Type == 3 { if m.Type == int(FolderSent) {
return FolderSent return FolderSent
} }
return FolderInbox return FolderInbox