diff --git a/crypto/rsa/key_test.go b/crypto/rsa/key_test.go index bccbfeb..e862481 100644 --- a/crypto/rsa/key_test.go +++ b/crypto/rsa/key_test.go @@ -1,6 +1,7 @@ package rsa import ( + "crypto/x509" "encoding/base64" "testing" @@ -165,3 +166,101 @@ MCB+kOgWk51uJwuiuHlffGMBPxku/t+skxI7Bw==` require.True(t, pub.VerifyASN1([]byte("message"), sig)) } + +func TestPublicKeyPKCS1RoundTrip(t *testing.T) { + pub, _, err := GenerateKeyPair(2048) + require.NoError(t, err) + + der := x509.MarshalPKCS1PublicKey(pub.Unwrap()) + rt, err := PublicKeyFromPKCS1DER(der) + require.NoError(t, err) + require.True(t, pub.Equal(rt)) +} + +func TestPublicKeyFromNERoundTrip(t *testing.T) { + pub, _, err := GenerateKeyPair(2048) + require.NoError(t, err) + + rt, err := PublicKeyFromNE(pub.NBytes(), pub.EBytes()) + require.NoError(t, err) + require.True(t, pub.Equal(rt)) +} + +func TestPrivateKeyFromNEDPQRoundTrip(t *testing.T) { + pub, priv, err := GenerateKeyPair(2048) + require.NoError(t, err) + + rt, err := PrivateKeyFromNEDPQ(pub.NBytes(), pub.EBytes(), priv.DBytes(), priv.PBytes(), priv.QBytes()) + require.NoError(t, err) + require.True(t, priv.Equal(rt)) + require.True(t, pub.Equal(rt.Public())) +} + +func TestRejectWeirdPublicKeyInputs(t *testing.T) { + // Reuse a known-good 2048-bit modulus so each case changes only the input under test. + validModulus, err := base64.RawURLEncoding.DecodeString("sbX82NTV6IylxCh7MfV4hlyvaniCajuP97GyOqSvTmoEdBOflFvZ06kR_9D6ctt45Fk6hskfnag2GG69NALVH2o4RCR6tQiLRpKcMRtDYE_thEmfBvDzm_VVkOIYfxu-Ipuo9J_S5XDNDjczx2v-3oDh5-CIHkU46hvFeCvpUS-L8TJSbgX0kjVk_m4eIb9wh63rtmD6Uz_KBtCo5mmR4TEtcLZKYdqMp3wCjN-TlgHiz_4oVXWbHUefCEe8rFnX1iQnpDHU49_SaXQoud1jCaexFn25n-Aa8f8bc5Vm-5SeRwidHa6ErvEhTvf1dz6GoNPp2iRvm-wJ1gxwWJEYPQ") + require.NoError(t, err) + + evenModulus := append([]byte{}, validModulus...) + evenModulus[len(evenModulus)-1] &^= 1 + + // Build a 8193-bit odd modulus to exceed the package's 8192-bit upper bound. + tooLargeModulus := make([]byte, 1025) + tooLargeModulus[0] = 0x80 + tooLargeModulus[len(tooLargeModulus)-1] = 0x01 + + // This is 2^63, which does not fit in a signed int64 and must be rejected. + exponentTooLarge := []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + + for _, tc := range []struct { + name string + n []byte + e []byte + }{ + { + name: "empty modulus", + n: nil, + e: []byte{0x03}, + }, + { + name: "too small modulus", + n: validModulus[:len(validModulus)-1], + e: []byte{0x03}, + }, + { + name: "too large modulus", + n: tooLargeModulus, + e: []byte{0x03}, + }, + { + name: "even modulus", + n: evenModulus, + e: []byte{0x03}, + }, + { + name: "empty exponent", + n: validModulus, + e: nil, + }, + { + name: "exponent one", + n: validModulus, + e: []byte{0x01}, + }, + { + name: "even exponent", + n: validModulus, + e: []byte{0x02}, + }, + { + name: "exponent too large", + n: validModulus, + e: exponentTooLarge, + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := PublicKeyFromNE(tc.n, tc.e) + require.Error(t, err) + }) + } +} diff --git a/crypto/rsa/public.go b/crypto/rsa/public.go index d89b2e6..e093c5c 100644 --- a/crypto/rsa/public.go +++ b/crypto/rsa/public.go @@ -53,6 +53,9 @@ func PublicKeyFromNE(n, e []byte) (*PublicKey, error) { if eBInt.Sign() <= 0 { return nil, fmt.Errorf("exponent must be positive") } + if eBInt.Cmp(big.NewInt(2)) < 0 { + return nil, fmt.Errorf("exponent too small") + } if eBInt.Bit(0) == 0 { return nil, fmt.Errorf("exponent must be odd") }