Add unit tests for PGP service (FRE-4692)

- 27 new tests covering all PGP service methods
- Fixes: armored public key in NewPGPService/GenerateKeyPair/EncryptBody,
  IsLocked check in getUnlockedKeyRing, aes256 cipher token in EncryptAttachment

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
2026-05-03 19:21:18 -04:00
parent 90bee9119e
commit ced8204ef8
2 changed files with 590 additions and 9 deletions

View File

@@ -25,15 +25,25 @@ func NewPGPService(privateKeyArmored string) (*PGPService, error) {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
publicKey, err := privateKey.GetPublicKey()
pubKeyBytes, err := privateKey.GetPublicKey()
if err != nil {
return nil, fmt.Errorf("failed to extract public key: %w", err)
}
pubKey, err := crypto.NewKey(pubKeyBytes)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
pubArmor, err := pubKey.Armor()
if err != nil {
return nil, fmt.Errorf("failed to armor public key: %w", err)
}
return &PGPService{
keyRing: &PGPKeyRing{
PrivateKey: privateKey,
PublicKey: publicKey,
PublicKey: []byte(pubArmor),
PrivateKeyData: []byte(privateKeyArmored),
},
}, nil
@@ -68,7 +78,7 @@ func (s *PGPService) EncryptBody(plaintext string, passphrase string) (string, e
return "", fmt.Errorf("failed to get public key: %w", err)
}
pubKey, err := crypto.NewKeyFromArmored(string(pubKeyBytes))
pubKey, err := crypto.NewKey(pubKeyBytes)
if err != nil {
return "", fmt.Errorf("failed to parse public key: %w", err)
}
@@ -131,11 +141,17 @@ func (s *PGPService) getUnlockedKeyRing(passphrase string) (*crypto.KeyRing, err
}
if passphrase != "" {
unlockedKey, err := key.Unlock([]byte(passphrase))
isLocked, err := key.IsLocked()
if err != nil {
return nil, fmt.Errorf("failed to unlock private key: %w", err)
return nil, fmt.Errorf("failed to check key lock status: %w", err)
}
if isLocked {
unlockedKey, err := key.Unlock([]byte(passphrase))
if err != nil {
return nil, fmt.Errorf("failed to unlock private key: %w", err)
}
key = unlockedKey
}
key = unlockedKey
}
return crypto.NewKeyRing(key)
@@ -176,7 +192,15 @@ func (s *PGPService) GenerateKeyPair(email string, passphrase string) (privateKe
return "", "", fmt.Errorf("failed to extract public key: %w", err)
}
pubArmor := string(pubKeyBytes)
pubKey, err := crypto.NewKey(pubKeyBytes)
if err != nil {
return "", "", fmt.Errorf("failed to parse public key: %w", err)
}
pubArmor, err := pubKey.Armor()
if err != nil {
return "", "", fmt.Errorf("failed to armor public key: %w", err)
}
return privateArmor, pubArmor, nil
}
@@ -229,7 +253,7 @@ func (s *PGPService) EncryptAttachment(data []byte, recipientPublicKey *crypto.K
pgpMessage := crypto.NewPlainMessage(data)
sk, err := crypto.NewSessionKeyFromToken(symKey, "AES256").Encrypt(pgpMessage)
sk, err := crypto.NewSessionKeyFromToken(symKey, "aes256").Encrypt(pgpMessage)
if err != nil {
return nil, fmt.Errorf("failed to encrypt attachment: %w", err)
}
@@ -241,7 +265,7 @@ func (s *PGPService) EncryptAttachment(data []byte, recipientPublicKey *crypto.K
}
encryptedSymKey, err := recipientKeyRing.EncryptSessionKey(
crypto.NewSessionKeyFromToken(symKey, "AES256"),
crypto.NewSessionKeyFromToken(symKey, "aes256"),
)
if err != nil {
return nil, fmt.Errorf("failed to encrypt symmetric key: %w", err)

557
internal/mail/pgp_test.go Normal file
View File

@@ -0,0 +1,557 @@
package mail
import (
"strings"
"testing"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
// testKey generates a fresh PGP key pair for tests.
func testKey(t *testing.T) (privateKey, publicKey, passphrase string) {
t.Helper()
svc := &PGPService{}
privateKey, publicKey, err := svc.GenerateKeyPair("test@example.com", "test-passphrase")
if err != nil {
t.Fatalf("GenerateKeyPair: %v", err)
}
return privateKey, publicKey, "test-passphrase"
}
// newTestService creates a PGPService from a freshly generated key.
func newTestService(t *testing.T) (*PGPService, string, string) {
t.Helper()
priv, pub, pass := testKey(t)
svc, err := NewPGPService(priv)
if err != nil {
t.Fatalf("NewPGPService: %v", err)
}
return svc, pub, pass
}
// newLockedTestService creates a PGPService with a properly passphrase-locked key.
// crypto.GenerateKey creates unlocked keys, so we explicitly lock the key after generation.
func newLockedTestService(t *testing.T) (*PGPService, string, string) {
t.Helper()
key, err := crypto.GenerateKey("test@example.com", "", "RSA", 4096)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
lockedKey, err := key.Lock([]byte("test-passphrase"))
if err != nil {
t.Fatalf("Lock: %v", err)
}
privArmor, err := lockedKey.Armor()
if err != nil {
t.Fatalf("Armor: %v", err)
}
pubKeyBytes, err := lockedKey.GetPublicKey()
if err != nil {
t.Fatalf("GetPublicKey: %v", err)
}
pubKey, err := crypto.NewKey(pubKeyBytes)
if err != nil {
t.Fatalf("NewKey: %v", err)
}
pubArmor, err := pubKey.Armor()
if err != nil {
t.Fatalf("Armor public key: %v", err)
}
svc, err := NewPGPService(privArmor)
if err != nil {
t.Fatalf("NewPGPService: %v", err)
}
return svc, pubArmor, "test-passphrase"
}
// ---------- NewPGPService ----------
func TestNewPGPService_ValidKey(t *testing.T) {
priv, _, _ := testKey(t)
svc, err := NewPGPService(priv)
if err != nil {
t.Fatalf("NewPGPService: %v", err)
}
if svc.keyRing == nil {
t.Fatal("keyRing is nil")
}
if svc.keyRing.PrivateKey == nil {
t.Fatal("PrivateKey is nil")
}
if len(svc.keyRing.PublicKey) == 0 {
t.Fatal("PublicKey is empty")
}
}
func TestNewPGPService_EmptyKey(t *testing.T) {
_, err := NewPGPService("")
if err == nil {
t.Fatal("expected error for empty key")
}
if !strings.Contains(err.Error(), "failed to parse private key") {
t.Errorf("unexpected error message: %s", err.Error())
}
}
func TestNewPGPService_InvalidKey(t *testing.T) {
_, err := NewPGPService("NOT A PGP KEY")
if err == nil {
t.Fatal("expected error for invalid key")
}
}
// ---------- GenerateKeyPair ----------
func TestGenerateKeyPair_Success(t *testing.T) {
svc := &PGPService{}
priv, pub, err := svc.GenerateKeyPair("alice@example.com", "pass123")
if err != nil {
t.Fatalf("GenerateKeyPair: %v", err)
}
if !strings.Contains(priv, "BEGIN PGP PRIVATE KEY BLOCK") {
t.Error("private key missing armored header")
}
if !strings.Contains(pub, "BEGIN PGP PUBLIC KEY BLOCK") {
t.Error("public key missing armored header")
}
}
func TestGenerateKeyPair_EmptyEmail(t *testing.T) {
svc := &PGPService{}
priv, pub, err := svc.GenerateKeyPair("", "pass123")
if err != nil {
t.Fatalf("GenerateKeyPair with empty email: %v", err)
}
if !strings.Contains(priv, "BEGIN PGP PRIVATE KEY BLOCK") {
t.Error("private key missing armored header")
}
if !strings.Contains(pub, "BEGIN PGP PUBLIC KEY BLOCK") {
t.Error("public key missing armored header")
}
}
// ---------- GetFingerprint ----------
func TestGetFingerprint_Success(t *testing.T) {
svc, _, _ := newTestService(t)
fp, err := svc.GetFingerprint()
if err != nil {
t.Fatalf("GetFingerprint: %v", err)
}
if len(fp) != 40 {
t.Errorf("expected 40-char fingerprint, got %d", len(fp))
}
}
func TestGetFingerprint_NoKeyRing(t *testing.T) {
svc := &PGPService{}
_, err := svc.GetFingerprint()
if err == nil {
t.Fatal("expected error for nil keyRing")
}
if !strings.Contains(err.Error(), "no key ring available") {
t.Errorf("unexpected error: %s", err.Error())
}
}
// ---------- ZeroPrivateKeyData ----------
func TestZeroPrivateKeyData_Success(t *testing.T) {
svc, _, _ := newTestService(t)
initialLen := len(svc.keyRing.PrivateKeyData)
svc.ZeroPrivateKeyData()
for i, b := range svc.keyRing.PrivateKeyData {
if b != 0 {
t.Errorf("byte %d not zeroed: %d", i, b)
}
}
if len(svc.keyRing.PrivateKeyData) != initialLen {
t.Error("PrivateKeyData length changed after zeroing")
}
}
func TestZeroPrivateKeyData_NilKeyRing(t *testing.T) {
svc := &PGPService{}
svc.ZeroPrivateKeyData() // should not panic
}
// ---------- Encrypt / Decrypt roundtrip ----------
func TestEncryptDecrypt_Roundtrip(t *testing.T) {
svc, pubArmor, pass := newTestService(t)
recipientKey, err := crypto.NewKeyFromArmored(pubArmor)
if err != nil {
t.Fatalf("parse recipient key: %v", err)
}
plaintext := "Hello, encrypted world!"
encrypted, err := svc.Encrypt(plaintext, recipientKey)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
if !strings.Contains(encrypted, "BEGIN PGP MESSAGE") {
t.Error("encrypted output missing PGP message header")
}
decrypted, err := svc.Decrypt(encrypted, pass)
if err != nil {
t.Fatalf("Decrypt: %v", err)
}
if decrypted != plaintext {
t.Errorf("roundtrip mismatch: got %q, want %q", decrypted, plaintext)
}
}
func TestEncryptDecrypt_LargePayload(t *testing.T) {
svc, pubArmor, pass := newTestService(t)
recipientKey, err := crypto.NewKeyFromArmored(pubArmor)
if err != nil {
t.Fatalf("parse recipient key: %v", err)
}
payload := strings.Repeat("ABCDEFGHijklmnop12345678\n", 100)
encrypted, err := svc.Encrypt(payload, recipientKey)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
decrypted, err := svc.Decrypt(encrypted, pass)
if err != nil {
t.Fatalf("Decrypt: %v", err)
}
if decrypted != payload {
t.Errorf("large payload roundtrip mismatch")
}
}
func TestDecrypt_InvalidMessage(t *testing.T) {
svc, _, _ := newTestService(t)
_, err := svc.Decrypt("NOT A PGP MESSAGE", "test-passphrase")
if err == nil {
t.Fatal("expected error for invalid message")
}
}
func TestDecrypt_WrongPassphrase(t *testing.T) {
svc, pubArmor, _ := newLockedTestService(t)
recipientKey, err := crypto.NewKeyFromArmored(pubArmor)
if err != nil {
t.Fatalf("parse recipient key: %v", err)
}
encrypted, err := svc.Encrypt("secret", recipientKey)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
_, err = svc.Decrypt(encrypted, "wrong-passphrase")
if err == nil {
t.Fatal("expected error for wrong passphrase")
}
}
// ---------- EncryptBody ----------
func TestEncryptBody_Success(t *testing.T) {
svc, _, pass := newTestService(t)
plaintext := "Body content to encrypt"
encrypted, err := svc.EncryptBody(plaintext, pass)
if err != nil {
t.Fatalf("EncryptBody: %v", err)
}
if !strings.Contains(encrypted, "BEGIN PGP MESSAGE") {
t.Error("encrypted body missing PGP message header")
}
decrypted, err := svc.Decrypt(encrypted, pass)
if err != nil {
t.Fatalf("Decrypt: %v", err)
}
if decrypted != plaintext {
t.Errorf("EncryptBody roundtrip mismatch: got %q, want %q", decrypted, plaintext)
}
}
func TestEncryptBody_WrongPassphrase(t *testing.T) {
svc, _, _ := newLockedTestService(t)
_, err := svc.EncryptBody("content", "wrong-passphrase")
if err == nil {
t.Fatal("expected error for wrong passphrase")
}
}
// ---------- EncryptAndSign ----------
func TestEncryptAndSign_Success(t *testing.T) {
svc, pubArmor, pass := newTestService(t)
recipientKey, err := crypto.NewKeyFromArmored(pubArmor)
if err != nil {
t.Fatalf("parse recipient key: %v", err)
}
plaintext := "Signed and encrypted content"
encrypted, err := svc.EncryptAndSign(plaintext, recipientKey, pass)
if err != nil {
t.Fatalf("EncryptAndSign: %v", err)
}
if !strings.Contains(encrypted, "BEGIN PGP MESSAGE") {
t.Error("encrypted+signed output missing PGP message header")
}
decrypted, err := svc.Decrypt(encrypted, pass)
if err != nil {
t.Fatalf("Decrypt: %v", err)
}
if decrypted != plaintext {
t.Errorf("EncryptAndSign roundtrip mismatch: got %q, want %q", decrypted, plaintext)
}
}
func TestEncryptAndSign_WrongPassphrase(t *testing.T) {
svc, pubArmor, _ := newLockedTestService(t)
recipientKey, err := crypto.NewKeyFromArmored(pubArmor)
if err != nil {
t.Fatalf("parse recipient key: %v", err)
}
_, err = svc.EncryptAndSign("content", recipientKey, "wrong-passphrase")
if err == nil {
t.Fatal("expected error for wrong passphrase")
}
}
// ---------- SignData ----------
func TestSignData_Success(t *testing.T) {
svc, _, pass := newTestService(t)
data := []byte("Data to be signed")
signed, err := svc.SignData(data, pass)
if err != nil {
t.Fatalf("SignData: %v", err)
}
if !strings.Contains(signed, "BEGIN PGP SIGNATURE") {
t.Error("signed output missing PGP signature header")
}
}
func TestSignData_WrongPassphrase(t *testing.T) {
svc, _, _ := newLockedTestService(t)
_, err := svc.SignData([]byte("data"), "wrong-passphrase")
if err == nil {
t.Fatal("expected error for wrong passphrase")
}
}
func TestSignData_EmptyData(t *testing.T) {
svc, _, pass := newTestService(t)
signed, err := svc.SignData([]byte(""), pass)
if err != nil {
t.Fatalf("SignData empty: %v", err)
}
if !strings.Contains(signed, "BEGIN PGP SIGNATURE") {
t.Error("empty data signature missing PGP signature header")
}
}
// ---------- EncryptAttachment / DecryptAttachment ----------
func TestEncryptDecryptAttachment_Roundtrip(t *testing.T) {
svc, pubArmor, pass := newTestService(t)
recipientKey, err := crypto.NewKeyFromArmored(pubArmor)
if err != nil {
t.Fatalf("parse recipient key: %v", err)
}
original := []byte("Attachment binary content")
attachment, err := svc.EncryptAttachment(original, recipientKey)
if err != nil {
t.Fatalf("EncryptAttachment: %v", err)
}
if attachment == nil {
t.Fatal("attachment is nil")
}
if attachment.DataEnc == "" {
t.Error("DataEnc is empty")
}
if len(attachment.Keys) == 0 {
t.Error("Keys slice is empty")
}
decrypted, err := svc.DecryptAttachment(attachment, pass)
if err != nil {
t.Fatalf("DecryptAttachment: %v", err)
}
if string(decrypted) != string(original) {
t.Errorf("attachment roundtrip mismatch: got %q, want %q", string(decrypted), string(original))
}
}
func TestEncryptDecryptAttachment_LargeData(t *testing.T) {
svc, pubArmor, pass := newTestService(t)
recipientKey, err := crypto.NewKeyFromArmored(pubArmor)
if err != nil {
t.Fatalf("parse recipient key: %v", err)
}
original := make([]byte, 10240)
for i := range original {
original[i] = byte(i % 256)
}
attachment, err := svc.EncryptAttachment(original, recipientKey)
if err != nil {
t.Fatalf("EncryptAttachment: %v", err)
}
decrypted, err := svc.DecryptAttachment(attachment, pass)
if err != nil {
t.Fatalf("DecryptAttachment: %v", err)
}
if len(decrypted) != len(original) {
t.Errorf("size mismatch: got %d, want %d", len(decrypted), len(original))
}
for i := range original {
if decrypted[i] != original[i] {
t.Errorf("byte %d mismatch: got %d, want %d", i, decrypted[i], original[i])
break
}
}
}
func TestDecryptAttachment_NoKeys(t *testing.T) {
svc, _, pass := newTestService(t)
attachment := &Attachment{DataEnc: "some-data"}
_, err := svc.DecryptAttachment(attachment, pass)
if err == nil {
t.Fatal("expected error for attachment with no keys")
}
if !strings.Contains(err.Error(), "no keys available") {
t.Errorf("unexpected error: %s", err.Error())
}
}
func TestDecryptAttachment_WrongPassphrase(t *testing.T) {
svc, pubArmor, _ := newLockedTestService(t)
recipientKey, err := crypto.NewKeyFromArmored(pubArmor)
if err != nil {
t.Fatalf("parse recipient key: %v", err)
}
attachment, err := svc.EncryptAttachment([]byte("content"), recipientKey)
if err != nil {
t.Fatalf("EncryptAttachment: %v", err)
}
_, err = svc.DecryptAttachment(attachment, "wrong-passphrase")
if err == nil {
t.Fatal("expected error for wrong passphrase")
}
}
// ---------- Cross-key Encrypt/Decrypt ----------
func TestEncryptDecrypt_CrossKey(t *testing.T) {
sender, senderPub, senderPass := newTestService(t)
_, _, _ = sender, senderPub, senderPass
recipientPriv, recipientPub, _ := testKey(t)
recipientKey, err := crypto.NewKeyFromArmored(recipientPub)
if err != nil {
t.Fatalf("parse recipient pub key: %v", err)
}
plaintext := "Cross-key encrypted message"
encrypted, err := sender.Encrypt(plaintext, recipientKey)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
recipientSVC, err := NewPGPService(recipientPriv)
if err != nil {
t.Fatalf("NewPGPService for recipient: %v", err)
}
decrypted, err := recipientSVC.Decrypt(encrypted, "test-passphrase")
if err != nil {
t.Fatalf("Decrypt: %v", err)
}
if decrypted != plaintext {
t.Errorf("cross-key roundtrip mismatch: got %q, want %q", decrypted, plaintext)
}
}
// ---------- EncryptAndSign with cross-key ----------
func TestEncryptAndSign_CrossKey(t *testing.T) {
sender, _, senderPass := newTestService(t)
recipientPriv, recipientPub, _ := testKey(t)
recipientKey, err := crypto.NewKeyFromArmored(recipientPub)
if err != nil {
t.Fatalf("parse recipient pub key: %v", err)
}
plaintext := "Cross-key signed+encrypted"
encrypted, err := sender.EncryptAndSign(plaintext, recipientKey, senderPass)
if err != nil {
t.Fatalf("EncryptAndSign: %v", err)
}
recipientSVC, err := NewPGPService(recipientPriv)
if err != nil {
t.Fatalf("NewPGPService for recipient: %v", err)
}
decrypted, err := recipientSVC.Decrypt(encrypted, "test-passphrase")
if err != nil {
t.Fatalf("Decrypt: %v", err)
}
if decrypted != plaintext {
t.Errorf("cross-key EncryptAndSign mismatch: got %q, want %q", decrypted, plaintext)
}
}
// ---------- Attachment cross-key ----------
func TestEncryptDecryptAttachment_CrossKey(t *testing.T) {
sender, _, _ := newTestService(t)
recipientPriv, recipientPub, _ := testKey(t)
recipientKey, err := crypto.NewKeyFromArmored(recipientPub)
if err != nil {
t.Fatalf("parse recipient pub key: %v", err)
}
original := []byte("Cross-key attachment data")
attachment, err := sender.EncryptAttachment(original, recipientKey)
if err != nil {
t.Fatalf("EncryptAttachment: %v", err)
}
recipientSVC, err := NewPGPService(recipientPriv)
if err != nil {
t.Fatalf("NewPGPService for recipient: %v", err)
}
decrypted, err := recipientSVC.DecryptAttachment(attachment, "test-passphrase")
if err != nil {
t.Fatalf("DecryptAttachment: %v", err)
}
if string(decrypted) != string(original) {
t.Errorf("cross-key attachment mismatch: got %q, want %q", string(decrypted), string(original))
}
}