diff --git a/internal/mail/pgp.go b/internal/mail/pgp.go index 9d289fe..52f8d42 100644 --- a/internal/mail/pgp.go +++ b/internal/mail/pgp.go @@ -25,15 +25,25 @@ func NewPGPService(privateKeyArmored string) (*PGPService, error) { return nil, fmt.Errorf("failed to parse private key: %w", err) } - publicKey, err := privateKey.GetPublicKey() + pubKeyBytes, err := privateKey.GetPublicKey() if err != nil { return nil, fmt.Errorf("failed to extract public key: %w", err) } + pubKey, err := crypto.NewKey(pubKeyBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + + pubArmor, err := pubKey.Armor() + if err != nil { + return nil, fmt.Errorf("failed to armor public key: %w", err) + } + return &PGPService{ keyRing: &PGPKeyRing{ PrivateKey: privateKey, - PublicKey: publicKey, + PublicKey: []byte(pubArmor), PrivateKeyData: []byte(privateKeyArmored), }, }, nil @@ -68,7 +78,7 @@ func (s *PGPService) EncryptBody(plaintext string, passphrase string) (string, e return "", fmt.Errorf("failed to get public key: %w", err) } - pubKey, err := crypto.NewKeyFromArmored(string(pubKeyBytes)) + pubKey, err := crypto.NewKey(pubKeyBytes) if err != nil { return "", fmt.Errorf("failed to parse public key: %w", err) } @@ -131,11 +141,17 @@ func (s *PGPService) getUnlockedKeyRing(passphrase string) (*crypto.KeyRing, err } if passphrase != "" { - unlockedKey, err := key.Unlock([]byte(passphrase)) + isLocked, err := key.IsLocked() if err != nil { - return nil, fmt.Errorf("failed to unlock private key: %w", err) + return nil, fmt.Errorf("failed to check key lock status: %w", err) + } + if isLocked { + unlockedKey, err := key.Unlock([]byte(passphrase)) + if err != nil { + return nil, fmt.Errorf("failed to unlock private key: %w", err) + } + key = unlockedKey } - key = unlockedKey } return crypto.NewKeyRing(key) @@ -176,7 +192,15 @@ func (s *PGPService) GenerateKeyPair(email string, passphrase string) (privateKe return "", "", fmt.Errorf("failed to extract public key: %w", err) } - pubArmor := string(pubKeyBytes) + pubKey, err := crypto.NewKey(pubKeyBytes) + if err != nil { + return "", "", fmt.Errorf("failed to parse public key: %w", err) + } + + pubArmor, err := pubKey.Armor() + if err != nil { + return "", "", fmt.Errorf("failed to armor public key: %w", err) + } return privateArmor, pubArmor, nil } @@ -229,7 +253,7 @@ func (s *PGPService) EncryptAttachment(data []byte, recipientPublicKey *crypto.K pgpMessage := crypto.NewPlainMessage(data) - sk, err := crypto.NewSessionKeyFromToken(symKey, "AES256").Encrypt(pgpMessage) + sk, err := crypto.NewSessionKeyFromToken(symKey, "aes256").Encrypt(pgpMessage) if err != nil { return nil, fmt.Errorf("failed to encrypt attachment: %w", err) } @@ -241,7 +265,7 @@ func (s *PGPService) EncryptAttachment(data []byte, recipientPublicKey *crypto.K } encryptedSymKey, err := recipientKeyRing.EncryptSessionKey( - crypto.NewSessionKeyFromToken(symKey, "AES256"), + crypto.NewSessionKeyFromToken(symKey, "aes256"), ) if err != nil { return nil, fmt.Errorf("failed to encrypt symmetric key: %w", err) diff --git a/internal/mail/pgp_test.go b/internal/mail/pgp_test.go new file mode 100644 index 0000000..999e9c3 --- /dev/null +++ b/internal/mail/pgp_test.go @@ -0,0 +1,557 @@ +package mail + +import ( + "strings" + "testing" + + "github.com/ProtonMail/gopenpgp/v2/crypto" +) + +// testKey generates a fresh PGP key pair for tests. +func testKey(t *testing.T) (privateKey, publicKey, passphrase string) { + t.Helper() + svc := &PGPService{} + privateKey, publicKey, err := svc.GenerateKeyPair("test@example.com", "test-passphrase") + if err != nil { + t.Fatalf("GenerateKeyPair: %v", err) + } + return privateKey, publicKey, "test-passphrase" +} + +// newTestService creates a PGPService from a freshly generated key. +func newTestService(t *testing.T) (*PGPService, string, string) { + t.Helper() + priv, pub, pass := testKey(t) + svc, err := NewPGPService(priv) + if err != nil { + t.Fatalf("NewPGPService: %v", err) + } + return svc, pub, pass +} + +// newLockedTestService creates a PGPService with a properly passphrase-locked key. +// crypto.GenerateKey creates unlocked keys, so we explicitly lock the key after generation. +func newLockedTestService(t *testing.T) (*PGPService, string, string) { + t.Helper() + key, err := crypto.GenerateKey("test@example.com", "", "RSA", 4096) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + lockedKey, err := key.Lock([]byte("test-passphrase")) + if err != nil { + t.Fatalf("Lock: %v", err) + } + privArmor, err := lockedKey.Armor() + if err != nil { + t.Fatalf("Armor: %v", err) + } + pubKeyBytes, err := lockedKey.GetPublicKey() + if err != nil { + t.Fatalf("GetPublicKey: %v", err) + } + pubKey, err := crypto.NewKey(pubKeyBytes) + if err != nil { + t.Fatalf("NewKey: %v", err) + } + pubArmor, err := pubKey.Armor() + if err != nil { + t.Fatalf("Armor public key: %v", err) + } + svc, err := NewPGPService(privArmor) + if err != nil { + t.Fatalf("NewPGPService: %v", err) + } + return svc, pubArmor, "test-passphrase" +} + +// ---------- NewPGPService ---------- + +func TestNewPGPService_ValidKey(t *testing.T) { + priv, _, _ := testKey(t) + svc, err := NewPGPService(priv) + if err != nil { + t.Fatalf("NewPGPService: %v", err) + } + if svc.keyRing == nil { + t.Fatal("keyRing is nil") + } + if svc.keyRing.PrivateKey == nil { + t.Fatal("PrivateKey is nil") + } + if len(svc.keyRing.PublicKey) == 0 { + t.Fatal("PublicKey is empty") + } +} + +func TestNewPGPService_EmptyKey(t *testing.T) { + _, err := NewPGPService("") + if err == nil { + t.Fatal("expected error for empty key") + } + if !strings.Contains(err.Error(), "failed to parse private key") { + t.Errorf("unexpected error message: %s", err.Error()) + } +} + +func TestNewPGPService_InvalidKey(t *testing.T) { + _, err := NewPGPService("NOT A PGP KEY") + if err == nil { + t.Fatal("expected error for invalid key") + } +} + +// ---------- GenerateKeyPair ---------- + +func TestGenerateKeyPair_Success(t *testing.T) { + svc := &PGPService{} + priv, pub, err := svc.GenerateKeyPair("alice@example.com", "pass123") + if err != nil { + t.Fatalf("GenerateKeyPair: %v", err) + } + if !strings.Contains(priv, "BEGIN PGP PRIVATE KEY BLOCK") { + t.Error("private key missing armored header") + } + if !strings.Contains(pub, "BEGIN PGP PUBLIC KEY BLOCK") { + t.Error("public key missing armored header") + } +} + +func TestGenerateKeyPair_EmptyEmail(t *testing.T) { + svc := &PGPService{} + priv, pub, err := svc.GenerateKeyPair("", "pass123") + if err != nil { + t.Fatalf("GenerateKeyPair with empty email: %v", err) + } + if !strings.Contains(priv, "BEGIN PGP PRIVATE KEY BLOCK") { + t.Error("private key missing armored header") + } + if !strings.Contains(pub, "BEGIN PGP PUBLIC KEY BLOCK") { + t.Error("public key missing armored header") + } +} + +// ---------- GetFingerprint ---------- + +func TestGetFingerprint_Success(t *testing.T) { + svc, _, _ := newTestService(t) + fp, err := svc.GetFingerprint() + if err != nil { + t.Fatalf("GetFingerprint: %v", err) + } + if len(fp) != 40 { + t.Errorf("expected 40-char fingerprint, got %d", len(fp)) + } +} + +func TestGetFingerprint_NoKeyRing(t *testing.T) { + svc := &PGPService{} + _, err := svc.GetFingerprint() + if err == nil { + t.Fatal("expected error for nil keyRing") + } + if !strings.Contains(err.Error(), "no key ring available") { + t.Errorf("unexpected error: %s", err.Error()) + } +} + +// ---------- ZeroPrivateKeyData ---------- + +func TestZeroPrivateKeyData_Success(t *testing.T) { + svc, _, _ := newTestService(t) + initialLen := len(svc.keyRing.PrivateKeyData) + svc.ZeroPrivateKeyData() + for i, b := range svc.keyRing.PrivateKeyData { + if b != 0 { + t.Errorf("byte %d not zeroed: %d", i, b) + } + } + if len(svc.keyRing.PrivateKeyData) != initialLen { + t.Error("PrivateKeyData length changed after zeroing") + } +} + +func TestZeroPrivateKeyData_NilKeyRing(t *testing.T) { + svc := &PGPService{} + svc.ZeroPrivateKeyData() // should not panic +} + +// ---------- Encrypt / Decrypt roundtrip ---------- + +func TestEncryptDecrypt_Roundtrip(t *testing.T) { + svc, pubArmor, pass := newTestService(t) + + recipientKey, err := crypto.NewKeyFromArmored(pubArmor) + if err != nil { + t.Fatalf("parse recipient key: %v", err) + } + + plaintext := "Hello, encrypted world!" + encrypted, err := svc.Encrypt(plaintext, recipientKey) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + if !strings.Contains(encrypted, "BEGIN PGP MESSAGE") { + t.Error("encrypted output missing PGP message header") + } + + decrypted, err := svc.Decrypt(encrypted, pass) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if decrypted != plaintext { + t.Errorf("roundtrip mismatch: got %q, want %q", decrypted, plaintext) + } +} + +func TestEncryptDecrypt_LargePayload(t *testing.T) { + svc, pubArmor, pass := newTestService(t) + + recipientKey, err := crypto.NewKeyFromArmored(pubArmor) + if err != nil { + t.Fatalf("parse recipient key: %v", err) + } + + payload := strings.Repeat("ABCDEFGHijklmnop12345678\n", 100) + encrypted, err := svc.Encrypt(payload, recipientKey) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + decrypted, err := svc.Decrypt(encrypted, pass) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if decrypted != payload { + t.Errorf("large payload roundtrip mismatch") + } +} + +func TestDecrypt_InvalidMessage(t *testing.T) { + svc, _, _ := newTestService(t) + _, err := svc.Decrypt("NOT A PGP MESSAGE", "test-passphrase") + if err == nil { + t.Fatal("expected error for invalid message") + } +} + +func TestDecrypt_WrongPassphrase(t *testing.T) { + svc, pubArmor, _ := newLockedTestService(t) + + recipientKey, err := crypto.NewKeyFromArmored(pubArmor) + if err != nil { + t.Fatalf("parse recipient key: %v", err) + } + + encrypted, err := svc.Encrypt("secret", recipientKey) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + _, err = svc.Decrypt(encrypted, "wrong-passphrase") + if err == nil { + t.Fatal("expected error for wrong passphrase") + } +} + +// ---------- EncryptBody ---------- + +func TestEncryptBody_Success(t *testing.T) { + svc, _, pass := newTestService(t) + + plaintext := "Body content to encrypt" + encrypted, err := svc.EncryptBody(plaintext, pass) + if err != nil { + t.Fatalf("EncryptBody: %v", err) + } + if !strings.Contains(encrypted, "BEGIN PGP MESSAGE") { + t.Error("encrypted body missing PGP message header") + } + + decrypted, err := svc.Decrypt(encrypted, pass) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if decrypted != plaintext { + t.Errorf("EncryptBody roundtrip mismatch: got %q, want %q", decrypted, plaintext) + } +} + +func TestEncryptBody_WrongPassphrase(t *testing.T) { + svc, _, _ := newLockedTestService(t) + + _, err := svc.EncryptBody("content", "wrong-passphrase") + if err == nil { + t.Fatal("expected error for wrong passphrase") + } +} + +// ---------- EncryptAndSign ---------- + +func TestEncryptAndSign_Success(t *testing.T) { + svc, pubArmor, pass := newTestService(t) + + recipientKey, err := crypto.NewKeyFromArmored(pubArmor) + if err != nil { + t.Fatalf("parse recipient key: %v", err) + } + + plaintext := "Signed and encrypted content" + encrypted, err := svc.EncryptAndSign(plaintext, recipientKey, pass) + if err != nil { + t.Fatalf("EncryptAndSign: %v", err) + } + if !strings.Contains(encrypted, "BEGIN PGP MESSAGE") { + t.Error("encrypted+signed output missing PGP message header") + } + + decrypted, err := svc.Decrypt(encrypted, pass) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if decrypted != plaintext { + t.Errorf("EncryptAndSign roundtrip mismatch: got %q, want %q", decrypted, plaintext) + } +} + +func TestEncryptAndSign_WrongPassphrase(t *testing.T) { + svc, pubArmor, _ := newLockedTestService(t) + + recipientKey, err := crypto.NewKeyFromArmored(pubArmor) + if err != nil { + t.Fatalf("parse recipient key: %v", err) + } + + _, err = svc.EncryptAndSign("content", recipientKey, "wrong-passphrase") + if err == nil { + t.Fatal("expected error for wrong passphrase") + } +} + +// ---------- SignData ---------- + +func TestSignData_Success(t *testing.T) { + svc, _, pass := newTestService(t) + + data := []byte("Data to be signed") + signed, err := svc.SignData(data, pass) + if err != nil { + t.Fatalf("SignData: %v", err) + } + if !strings.Contains(signed, "BEGIN PGP SIGNATURE") { + t.Error("signed output missing PGP signature header") + } +} + +func TestSignData_WrongPassphrase(t *testing.T) { + svc, _, _ := newLockedTestService(t) + + _, err := svc.SignData([]byte("data"), "wrong-passphrase") + if err == nil { + t.Fatal("expected error for wrong passphrase") + } +} + +func TestSignData_EmptyData(t *testing.T) { + svc, _, pass := newTestService(t) + + signed, err := svc.SignData([]byte(""), pass) + if err != nil { + t.Fatalf("SignData empty: %v", err) + } + if !strings.Contains(signed, "BEGIN PGP SIGNATURE") { + t.Error("empty data signature missing PGP signature header") + } +} + +// ---------- EncryptAttachment / DecryptAttachment ---------- + +func TestEncryptDecryptAttachment_Roundtrip(t *testing.T) { + svc, pubArmor, pass := newTestService(t) + + recipientKey, err := crypto.NewKeyFromArmored(pubArmor) + if err != nil { + t.Fatalf("parse recipient key: %v", err) + } + + original := []byte("Attachment binary content") + attachment, err := svc.EncryptAttachment(original, recipientKey) + if err != nil { + t.Fatalf("EncryptAttachment: %v", err) + } + if attachment == nil { + t.Fatal("attachment is nil") + } + if attachment.DataEnc == "" { + t.Error("DataEnc is empty") + } + if len(attachment.Keys) == 0 { + t.Error("Keys slice is empty") + } + + decrypted, err := svc.DecryptAttachment(attachment, pass) + if err != nil { + t.Fatalf("DecryptAttachment: %v", err) + } + if string(decrypted) != string(original) { + t.Errorf("attachment roundtrip mismatch: got %q, want %q", string(decrypted), string(original)) + } +} + +func TestEncryptDecryptAttachment_LargeData(t *testing.T) { + svc, pubArmor, pass := newTestService(t) + + recipientKey, err := crypto.NewKeyFromArmored(pubArmor) + if err != nil { + t.Fatalf("parse recipient key: %v", err) + } + + original := make([]byte, 10240) + for i := range original { + original[i] = byte(i % 256) + } + + attachment, err := svc.EncryptAttachment(original, recipientKey) + if err != nil { + t.Fatalf("EncryptAttachment: %v", err) + } + + decrypted, err := svc.DecryptAttachment(attachment, pass) + if err != nil { + t.Fatalf("DecryptAttachment: %v", err) + } + if len(decrypted) != len(original) { + t.Errorf("size mismatch: got %d, want %d", len(decrypted), len(original)) + } + for i := range original { + if decrypted[i] != original[i] { + t.Errorf("byte %d mismatch: got %d, want %d", i, decrypted[i], original[i]) + break + } + } +} + +func TestDecryptAttachment_NoKeys(t *testing.T) { + svc, _, pass := newTestService(t) + + attachment := &Attachment{DataEnc: "some-data"} + _, err := svc.DecryptAttachment(attachment, pass) + if err == nil { + t.Fatal("expected error for attachment with no keys") + } + if !strings.Contains(err.Error(), "no keys available") { + t.Errorf("unexpected error: %s", err.Error()) + } +} + +func TestDecryptAttachment_WrongPassphrase(t *testing.T) { + svc, pubArmor, _ := newLockedTestService(t) + + recipientKey, err := crypto.NewKeyFromArmored(pubArmor) + if err != nil { + t.Fatalf("parse recipient key: %v", err) + } + + attachment, err := svc.EncryptAttachment([]byte("content"), recipientKey) + if err != nil { + t.Fatalf("EncryptAttachment: %v", err) + } + + _, err = svc.DecryptAttachment(attachment, "wrong-passphrase") + if err == nil { + t.Fatal("expected error for wrong passphrase") + } +} + +// ---------- Cross-key Encrypt/Decrypt ---------- + +func TestEncryptDecrypt_CrossKey(t *testing.T) { + sender, senderPub, senderPass := newTestService(t) + _, _, _ = sender, senderPub, senderPass + + recipientPriv, recipientPub, _ := testKey(t) + recipientKey, err := crypto.NewKeyFromArmored(recipientPub) + if err != nil { + t.Fatalf("parse recipient pub key: %v", err) + } + + plaintext := "Cross-key encrypted message" + encrypted, err := sender.Encrypt(plaintext, recipientKey) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + recipientSVC, err := NewPGPService(recipientPriv) + if err != nil { + t.Fatalf("NewPGPService for recipient: %v", err) + } + + decrypted, err := recipientSVC.Decrypt(encrypted, "test-passphrase") + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if decrypted != plaintext { + t.Errorf("cross-key roundtrip mismatch: got %q, want %q", decrypted, plaintext) + } +} + +// ---------- EncryptAndSign with cross-key ---------- + +func TestEncryptAndSign_CrossKey(t *testing.T) { + sender, _, senderPass := newTestService(t) + + recipientPriv, recipientPub, _ := testKey(t) + recipientKey, err := crypto.NewKeyFromArmored(recipientPub) + if err != nil { + t.Fatalf("parse recipient pub key: %v", err) + } + + plaintext := "Cross-key signed+encrypted" + encrypted, err := sender.EncryptAndSign(plaintext, recipientKey, senderPass) + if err != nil { + t.Fatalf("EncryptAndSign: %v", err) + } + + recipientSVC, err := NewPGPService(recipientPriv) + if err != nil { + t.Fatalf("NewPGPService for recipient: %v", err) + } + + decrypted, err := recipientSVC.Decrypt(encrypted, "test-passphrase") + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if decrypted != plaintext { + t.Errorf("cross-key EncryptAndSign mismatch: got %q, want %q", decrypted, plaintext) + } +} + +// ---------- Attachment cross-key ---------- + +func TestEncryptDecryptAttachment_CrossKey(t *testing.T) { + sender, _, _ := newTestService(t) + + recipientPriv, recipientPub, _ := testKey(t) + recipientKey, err := crypto.NewKeyFromArmored(recipientPub) + if err != nil { + t.Fatalf("parse recipient pub key: %v", err) + } + + original := []byte("Cross-key attachment data") + attachment, err := sender.EncryptAttachment(original, recipientKey) + if err != nil { + t.Fatalf("EncryptAttachment: %v", err) + } + + recipientSVC, err := NewPGPService(recipientPriv) + if err != nil { + t.Fatalf("NewPGPService for recipient: %v", err) + } + + decrypted, err := recipientSVC.DecryptAttachment(attachment, "test-passphrase") + if err != nil { + t.Fatalf("DecryptAttachment: %v", err) + } + if string(decrypted) != string(original) { + t.Errorf("cross-key attachment mismatch: got %q, want %q", string(decrypted), string(original)) + } +}