diff --git a/hpgsql-tests/DbUtils.hs b/hpgsql-tests/DbUtils.hs index f7695af..874d275 100644 --- a/hpgsql-tests/DbUtils.hs +++ b/hpgsql-tests/DbUtils.hs @@ -16,7 +16,7 @@ import qualified Data.Text as Text import Hpgsql (ErrorDetail (..), HPgConnection, IrrecoverableHpgsqlError (..), PostgresError (..), execute, execute_) import Hpgsql.Connection (defaultConnectOpts, withConnection, withConnectionOpts) import Hpgsql.InternalTypes (ConnectOpts (..), ConnectionString (..)) -import System.Environment (getEnv) +import System.Environment (getEnv, lookupEnv) import System.Mem (performGC) import Test.Hspec @@ -26,7 +26,8 @@ testConnInfo = do hostname <- Text.pack <$> getEnv "PGHOST" database <- Text.pack <$> getEnv "PGDATABASE" user <- Text.pack <$> getEnv "PGUSER" - pure ConnectionString {user, database, hostname, port = read portStr, password = "", options = ""} + password <- maybe "" Text.pack <$> lookupEnv "PGPASSWORD" + pure ConnectionString {user, database, hostname, port = read portStr, password, options = ""} aroundConn :: SpecWith HPgConnection -> Spec aroundConn = around $ \act -> do diff --git a/hpgsql/hpgsql.cabal b/hpgsql/hpgsql.cabal index 0259dd2..3efd2a5 100644 --- a/hpgsql/hpgsql.cabal +++ b/hpgsql/hpgsql.cabal @@ -47,6 +47,7 @@ library Hpgsql.Msgs Hpgsql.Networking Hpgsql.QueryInternal + Hpgsql.ScramSHA256 Hpgsql.SimpleParser Hpgsql.TransactionStatusInternal Paths_hpgsql @@ -99,7 +100,8 @@ library case-insensitive >= 1.2 && < 1.3, cereal >= 0.5 && < 0.6, containers >= 0.6 && < 0.8, - cryptohash-md5 >= 0.11.7.1 && < 0.12, + crypton >= 1.0.0 && < 1.1, + memory >= 0.18.0 && < 0.19, hashable >= 1.5 && < 1.6, haskell-src-meta >= 0.8 && < 0.9, network >= 3.2 && < 3.3, diff --git a/hpgsql/src/Hpgsql/Internal.hs b/hpgsql/src/Hpgsql/Internal.hs index b132819..34d9f84 100644 --- a/hpgsql/src/Hpgsql/Internal.hs +++ b/hpgsql/src/Hpgsql/Internal.hs @@ -139,10 +139,11 @@ import Hpgsql.Encoding (FieldInfo (..), FromPgRow (..), RowDecoder (..), RowEnco import Hpgsql.Encoding.RowDecoderMonadic (ConversionState (..), RowDecoderMonadic (..)) import Hpgsql.InternalTypes (BindComplete (..), CommandComplete (..), ConnectOpts (..), ConnectionString (..), CopyInResponse (..), CopyQueryState (..), DataRow (..), Either3 (..), EncodingContext (..), ErrorDetail (..), ErrorResponse (..), HPgConnection (..), InternalConnectionState (..), IrrecoverableHpgsqlError (..), NoData (..), NotificationResponse (..), ParseComplete (..), Pipeline (..), PostgresError (..), Query (..), QueryId (..), QueryProtocol (..), QueryState (..), ReadyForQuery (..), ResetConnectionOpts (..), ResponseMsg (..), ResponseMsgsReceived (..), RowDescription (..), SingleQuery (..), TransactionStatus (..), WeakThreadId (..), mkMutex, queryToByteString, throwIrrecoverableError) import Hpgsql.Locking (getMyWeakThreadId, withMutex) -import Hpgsql.Msgs (AuthenticationMethod (..), AuthenticationResponse (..), BackendKeyData (..), Bind (..), CancelRequest (..), CopyData (..), CopyDone (..), Describe (..), Execute (..), FromPgMessage (..), NoticeResponse (..), ParameterStatus (..), Parse (..), PasswordMessage (..), PgMsgParser (..), StartupMessage (..), Sync (..), Terminate (..), ToPgMessage (..), parsePgMessage) +import Hpgsql.Msgs (AuthenticationMethod (..), AuthenticationResponse (..), BackendKeyData (..), Bind (..), CancelRequest (..), CopyData (..), CopyDone (..), Describe (..), Execute (..), FromPgMessage (..), NoticeResponse (..), ParameterStatus (..), Parse (..), PasswordMessage (..), PgMsgParser (..), SASLInitialResponse (..), SASLResponse (..), StartupMessage (..), Sync (..), Terminate (..), ToPgMessage (..), parsePgMessage) import qualified Hpgsql.Msgs as Msgs import Hpgsql.Networking (recvNonBlocking, sendNonBlocking, socketWaitRead, socketWaitWrite) import Hpgsql.Query (breakQueryIntoStatements) +import qualified Hpgsql.ScramSHA256 as ScramSHA256 import qualified Hpgsql.SimpleParser as Parser import Hpgsql.TypeInfo (ArrayTypeDetails (..), TypeDetails (..), TypeInfo (..), buildTypeInfoCache, builtinPgTypesMap) import Network.Socket (AddrInfo (..)) @@ -284,6 +285,43 @@ internalConnectOrCancel connectOrCancel connOpts originalConnStr@ConnectionStrin AuthMD5Password _ -> do nonAtomicSendMsg hpgConnPartialDoNotReturn $ PasswordMessage authMethod (Text.unpack user) (Text.unpack password) receiveAuthOkOrThrow hpgConnPartialDoNotReturn + AuthSASL mechanisms + | "SCRAM-SHA-256" `elem` mechanisms -> do + -- 1. Generate and send client-first-message + clientFirstMsg <- ScramSHA256.generateClientFirstMessage + nonAtomicSendMsg hpgConnPartialDoNotReturn $ SASLInitialResponse "SCRAM-SHA-256" clientFirstMsg.fullMessage + + -- 2. Receive server-first-message + saslContinueMsg <- + receiveNextMsgUnsafe + hpgConnPartialDoNotReturn + (Right <$> msgParser @AuthenticationResponse <|> Left <$> msgParser @ErrorResponse) + serverFirst <- case saslContinueMsg of + Right (AuthenticationResponse (AuthSASLContinue d)) -> pure d + Left errResp -> throw $ IrrecoverableHpgsqlError {hpgsqlDetails = "SCRAM-SHA-256 authentication failed", innerException = Just $ toException $ mkPostgresError "" errResp, relatedStatement = Nothing} + _ -> throwIrrecoverableError "Expected AuthSASLContinue during SCRAM-SHA-256 authentication" + + -- 3. Generate and send client-final-message + clientFinalMsg <- case ScramSHA256.handleServerFirstMsg password clientFirstMsg serverFirst of + Left err -> throwIrrecoverableError $ "SCRAM-SHA-256 error: " <> Text.pack err + Right r -> pure r + nonAtomicSendMsg hpgConnPartialDoNotReturn $ SASLResponse clientFinalMsg + + -- 4. Receive server-final-message + saslFinalMsg <- + receiveNextMsgUnsafe + hpgConnPartialDoNotReturn + (Right <$> msgParser @AuthenticationResponse <|> Left <$> msgParser @ErrorResponse) + case saslFinalMsg of + Right (AuthenticationResponse (AuthSASLFinal serverFinalData)) -> + case ScramSHA256.verifyServerFinal clientFinalMsg serverFinalData of + Left err -> throwIrrecoverableError $ "SCRAM-SHA-256 server verification failed: " <> Text.pack err + Right () -> pure () + Left errResp -> throw $ IrrecoverableHpgsqlError {hpgsqlDetails = "SCRAM-SHA-256 authentication failed", innerException = Just $ toException $ mkPostgresError "" errResp, relatedStatement = Nothing} + _ -> throwIrrecoverableError "Expected AuthSASLFinal during SCRAM-SHA-256 authentication" + + -- 5. Receive AuthOk + receiveAuthOkOrThrow hpgConnPartialDoNotReturn _ -> throwIrrecoverableError $ "Hpgsql does not yet support authenticating with method " <> Text.pack (show authMethod) errorOrBackendKeyData <- receiveNextMsgUnsafe hpgConnPartialDoNotReturn $ Right <$> msgParser @BackendKeyData <|> Left <$> msgParser @ErrorResponse case errorOrBackendKeyData of diff --git a/hpgsql/src/Hpgsql/Msgs.hs b/hpgsql/src/Hpgsql/Msgs.hs index aa206cc..dc2e638 100644 --- a/hpgsql/src/Hpgsql/Msgs.hs +++ b/hpgsql/src/Hpgsql/Msgs.hs @@ -1,12 +1,13 @@ -module Hpgsql.Msgs (AuthenticationResponse (..), AuthenticationMethod (..), BackendKeyData (..), Bind (..), BindComplete (..), CancelRequest (..), CommandComplete (..), CopyData (..), CopyDone (..), CopyFail (..), CopyInResponse (..), DataRow (..), Describe (..), ErrorDetail (..), ErrorResponse (..), Execute (..), Flush (..), NoData (..), ParameterStatus (..), Query (..), ReadyForQuery (..), RowDescription (..), StartupMessage (..), ToPgMessage (..), FromPgMessage (..), PgMsgParser (..), Terminate (..), TransactionStatus (..), NoticeResponse (..), NotificationResponse (..), Parse (..), ParseComplete (..), PasswordMessage (..), Sync (..), parsePgMessage, nulTermCString) where +module Hpgsql.Msgs (AuthenticationResponse (..), AuthenticationMethod (..), BackendKeyData (..), Bind (..), BindComplete (..), CancelRequest (..), CommandComplete (..), CopyData (..), CopyDone (..), CopyFail (..), CopyInResponse (..), DataRow (..), Describe (..), ErrorDetail (..), ErrorResponse (..), Execute (..), Flush (..), NoData (..), ParameterStatus (..), Query (..), ReadyForQuery (..), RowDescription (..), SASLInitialResponse (..), SASLResponse (..), StartupMessage (..), ToPgMessage (..), FromPgMessage (..), PgMsgParser (..), Terminate (..), TransactionStatus (..), NoticeResponse (..), NotificationResponse (..), Parse (..), ParseComplete (..), PasswordMessage (..), Sync (..), parsePgMessage, nulTermCString) where import Control.Applicative (Alternative (..)) import Control.Monad (replicateM) -import qualified Crypto.Hash.MD5 as MD5 +import qualified Crypto.Hash as Crypto import qualified Data.Attoparsec.ByteString as Parsec import qualified Data.Attoparsec.ByteString.Lazy as LazyParsec import qualified Data.Attoparsec.Text as TextParsec import Data.Bifunctor (first) +import Data.ByteArray (convert) import Data.ByteString (ByteString) import qualified Data.ByteString as BS import Data.ByteString.Internal (w2c) @@ -18,11 +19,12 @@ import qualified Data.Map.Strict as Map import Data.Maybe (fromMaybe, mapMaybe) import qualified Data.Serialize as Cereal import Data.Text (Text) -import Data.Text.Encoding (decodeASCII, decodeUtf8) +import Data.Text.Encoding (decodeASCII, decodeUtf8, encodeUtf8) import Data.Word (Word8) import Hpgsql.Builder (BinaryField, Builder, builderLength) import qualified Hpgsql.Builder as Builder import Hpgsql.InternalTypes (BindComplete (..), CommandComplete (..), CopyInResponse (..), DataRow (..), ErrorDetail (..), ErrorResponse (..), NoData (..), NotificationResponse (..), ParseComplete (..), ReadyForQuery (..), RowDescription (..), TransactionStatus (..)) +import Hpgsql.ScramSHA256 (ScramClientFinalMessage (..), ScramServerFirstMessage (..)) import Hpgsql.TypeInfo (Oid (..)) class ToPgMessage a where @@ -71,7 +73,13 @@ nulTerminatedCStringParser = do newtype AuthenticationResponse = AuthenticationResponse AuthenticationMethod deriving stock (Show) -data AuthenticationMethod = AuthOk | AuthKerberosV5 | AuthCleartextPassword | AuthMD5Password LBS.ByteString | AuthGSS | AuthGSSContinue | AuthSSPI | AuthSASL | AuthSASLContinue | AuthSASLFinal +data AuthenticationMethod = AuthOk | AuthKerberosV5 | AuthCleartextPassword | AuthMD5Password LBS.ByteString | AuthGSS | AuthGSSContinue | AuthSSPI | AuthSASL [Text] | AuthSASLContinue !ScramServerFirstMessage | AuthSASLFinal ByteString + deriving stock (Show) + +data SASLInitialResponse = SASLInitialResponse Text ByteString + deriving stock (Show) + +newtype SASLResponse = SASLResponse ScramClientFinalMessage deriving stock (Show) data PasswordMessage @@ -136,19 +144,19 @@ data Terminate = Terminate instance FromPgMessage AuthenticationResponse where msgParser = PgMsgParser $ \c restOfMsg -> case c of - 'R' -> case LBS.length restOfMsg of - 4 -> - case Cereal.decodeLazy @Int32 restOfMsg of - Right 0 -> Just $ AuthenticationResponse AuthOk - Right 2 -> Just $ AuthenticationResponse AuthKerberosV5 - Right 3 -> Just $ AuthenticationResponse AuthCleartextPassword - Right 7 -> Just $ AuthenticationResponse AuthGSS - Right 9 -> Just $ AuthenticationResponse AuthSSPI - _ -> Nothing - 8 -> case first (Cereal.decodeLazy @Int32) $ LBS.splitAt 4 restOfMsg of - (Right 5, salt) -> Just $ AuthenticationResponse (AuthMD5Password salt) - _ -> Nothing -- We don't care too much about parsing every possible response yet - _ -> Nothing -- We don't care too much about parsing every possible response yet + 'R' -> case first (Cereal.decodeLazy @Int32) $ LBS.splitAt 4 restOfMsg of + (Right 0, _) -> Just $ AuthenticationResponse AuthOk + (Right 2, _) -> Just $ AuthenticationResponse AuthKerberosV5 + (Right 3, _) -> Just $ AuthenticationResponse AuthCleartextPassword + (Right 5, salt) -> Just $ AuthenticationResponse (AuthMD5Password salt) + (Right 7, _) -> Just $ AuthenticationResponse AuthGSS + (Right 9, _) -> Just $ AuthenticationResponse AuthSSPI + (Right 10, saslMechanismName) -> case Parsec.parseOnly (Parsec.many1 nulTerminatedCStringParser) (LBS.toStrict saslMechanismName) of + Right sm -> Just $ AuthenticationResponse (AuthSASL sm) + Left _ -> Nothing + (Right 11, saslData) -> Just $ AuthenticationResponse (AuthSASLContinue (ScramServerFirstMessage $ LBS.toStrict saslData)) + (Right 12, saslData) -> Just $ AuthenticationResponse (AuthSASLFinal (LBS.toStrict saslData)) + _ -> Nothing _ -> Nothing instance FromPgMessage BackendKeyData where @@ -250,12 +258,29 @@ instance ToPgMessage PasswordMessage where AuthMD5Password (LBS.toStrict -> salt) -> let passwordBytes = Builder.toStrictByteString (Builder.string7 password) usernameBytes = Builder.toStrictByteString (Builder.string7 username) - innerHex = bytestringToHex (MD5.hash (passwordBytes <> usernameBytes)) - outerHex = bytestringToHex (MD5.hash (innerHex <> salt)) + innerHex = bytestringToHex (md5Hash (passwordBytes <> usernameBytes)) + outerHex = bytestringToHex (md5Hash (innerHex <> salt)) in Builder.byteString "md5" <> Builder.byteString outerHex <> Builder.word8 0 _ -> error "PasswordMessage method not implemented" msgLen = builderLength passwordBs + 4 in Builder.char7 'p' <> Builder.int32BE msgLen <> passwordBs + where + md5Hash :: ByteString -> ByteString + md5Hash = convert @(Crypto.Digest Crypto.MD5) . Crypto.hash + +instance ToPgMessage SASLInitialResponse where + toPgMessage (SASLInitialResponse mechanism clientFirstMsg) = + let mechanismBs = Builder.byteString (encodeUtf8 mechanism) <> Builder.word8 0 + dataLen = fromIntegral (BS.length clientFirstMsg) :: Int32 + contents = mechanismBs <> Builder.int32BE dataLen <> Builder.byteString clientFirstMsg + msgLen = builderLength contents + 4 + in Builder.char7 'p' <> Builder.int32BE msgLen <> contents + +instance ToPgMessage SASLResponse where + toPgMessage (SASLResponse responseData) = + let contents = Builder.byteString responseData.clientFinalMessage + msgLen = builderLength contents + 4 + in Builder.char7 'p' <> Builder.int32BE msgLen <> contents instance ToPgMessage Query where toPgMessage (Query bs) = diff --git a/hpgsql/src/Hpgsql/ScramSHA256.hs b/hpgsql/src/Hpgsql/ScramSHA256.hs new file mode 100644 index 0000000..46012ad --- /dev/null +++ b/hpgsql/src/Hpgsql/ScramSHA256.hs @@ -0,0 +1,192 @@ +module Hpgsql.ScramSHA256 + ( ScramClientFirstMessage (..), + ScramServerFirstMessage (..), + ScramClientFinalMessage (..), + generateClientFirstMessage, + handleServerFirstMsg, + verifyServerFinal, + ) +where + +-- https://datatracker.ietf.org/doc/html/rfc5802#section-3 +-- The RFC for SCRAM-SHA-1 (which SCRAM-SHA-256 follows except it uses SHA256) +-- says this: +-- +-- "Informative Note: Implementors are encouraged to create test cases +-- that use both usernames and passwords with non-ASCII codepoints. In +-- particular, it's useful to test codepoints whose "Unicode +-- Normalization Form C" and "Unicode Normalization Form KC" are +-- different. Some examples of such codepoints include Vulgar Fraction +-- One Half (U+00BD) and Acute Accent (U+00B4)." +-- +-- TODO: We should test these scenarios, but we don't yet. + +import Control.Applicative ((<|>)) +import Control.Monad (void) +import Crypto.Hash (Digest, SHA256, hash) +import Crypto.KDF.PBKDF2 (Parameters (..), fastPBKDF2_SHA256) +import Crypto.MAC.HMAC (HMAC, hmac, hmacGetDigest) +import Crypto.Random (getRandomBytes) +import Data.Attoparsec.ByteString (parseOnly, string) +import qualified Data.Attoparsec.ByteString.Char8 as Parsec +import Data.ByteArray (convert, xor) +import Data.ByteArray.Encoding (Base (Base64), convertFromBase, convertToBase) +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Char8 as BS8 +import Data.Text (Text) +import Data.Text.Encoding (encodeUtf8) + +-- From the RFC: +-- SCRAM is a SASL mechanism whose client response and server challenge +-- messages are text-based messages containing one or more attribute- +-- value pairs separated by commas. Each attribute has a one-letter +-- name. The messages and their attributes are described in +-- Section 5.1, and defined in Section 7. + +-- SCRAM is a client-first SASL mechanism (see [RFC4422], Section 5, +-- item 2a), and returns additional data together with a server's +-- indication of a successful outcome. + +-- This is a simple example of a SCRAM-SHA-1 authentication exchange +-- when the client doesn't support channel bindings (username 'user' and +-- password 'pencil' are used): + +-- C: n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL +-- S: r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92, +-- i=4096 +-- C: c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j, +-- p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts= +-- S: v=rmF9pqV8S7suAoZWja4dJRkFsKQ= +-- +-- +-- And from Postgres (https://www.postgresql.org/docs/current/sasl-authentication.html): +-- 1. The server sends an AuthenticationSASL message. It includes a list of SASL authentication mechanisms that the server can accept. This will be SCRAM-SHA-256-PLUS and SCRAM-SHA-256 if the server is built with SSL support, or else just the latter. +-- 2. The client responds by sending a SASLInitialResponse message, which indicates the chosen mechanism, SCRAM-SHA-256 or SCRAM-SHA-256-PLUS. (A client is free to choose either mechanism, but for better security it should choose the channel-binding variant if it can support it.) In the Initial Client response field, the message contains the SCRAM client-first-message. The client-first-message also contains the channel binding type chosen by the client. +-- TODO: hpgsql right now only supports SCRAM-SHA-256, not the -PLUS variant. +-- 3. Server sends an AuthenticationSASLContinue message, with a SCRAM server-first-message as the content. +-- 4. Client sends a SASLResponse message, with SCRAM client-final-message as the content. +-- 5. Server sends an AuthenticationSASLFinal message, with the SCRAM server-final-message, followed immediately by an AuthenticationOk message. + +-- | The "client-first-message", something like `n,,n=user,r=nonce`, where +-- the nonce is something like `fyko+d2lbbFgONRv9qkxdawL`. +data ScramClientFirstMessage = ScramClientFirstMessage + { -- | This is the string "n,," + gs2Header :: !ByteString, + -- | This is the string "n=user,r=nonce", except that "user" is the + -- empty string because postgres ignores it anyway. + clientFirstBare :: !ByteString, + -- | This is just the concatenation of the two strings above. + fullMessage :: !ByteString + } + +-- | Generate the client-first-message for SCRAM-SHA-256. +generateClientFirstMessage :: IO ScramClientFirstMessage +generateClientFirstMessage = do + -- We don't send a username in this message even if it's necessary according + -- to the RFC, because postgres says + -- "When SCRAM-SHA-256 is used in PostgreSQL, the server will ignore the user name that the client sends in the client-first-message. The user name that was already sent in the startup message is used instead." + nonce <- generateNonce + let gs2Header = "n,," + clientFirstBare = "n=,r=" <> nonce + fullMessage = gs2Header <> clientFirstBare + pure ScramClientFirstMessage {..} + where + generateNonce :: IO ByteString + generateNonce = convertToBase Base64 <$> (getRandomBytes 24 :: IO ByteString) + +newtype ScramServerFirstMessage = ScramServerFirstMessage + { -- | As per the RFC, this looks something like (notice the concatenation + -- of client and server nonces, and some salt): + -- S: r=ClientNonceServerNonce,s=QSXCR+Q6sek8bf92, + -- i=4096 + fullMessage :: ByteString + } + deriving stock (Show) + +data ScramClientFinalMessage = ScramClientFinalMessage + { -- | This looks something like: + -- c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j, + -- p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts= + clientFinalMessage :: !ByteString, + serverSignature :: !ByteString + } + deriving stock (Show) + +-- | From the RFC: +-- In response, the server sends a "server-first-message" containing the +-- user's iteration count i and the user's salt, and appends its own +-- nonce to the client-specified one. + +-- The client then responds by sending a "client-final-message" with the +-- same nonce and a ClientProof computed using the selected hash +-- function as explained earlier. +handleServerFirstMsg :: + Text -> + ScramClientFirstMessage -> + ScramServerFirstMessage -> + Either String ScramClientFinalMessage +handleServerFirstMsg (encodeUtf8 -> password) ScramClientFirstMessage {gs2Header, clientFirstBare} serverFirst = do + (serverNonce, salt, iterations) <- parseServerFirst serverFirst.fullMessage + let saltedPassword = hi salt iterations -- "hi" takes the password implicitly + clientKey = hmacSHA256 saltedPassword "Client Key" + storedKey = hashSHA256 clientKey + channelBinding = "c=" <> convertToBase Base64 gs2Header + clientFinalWithoutProof = channelBinding <> ",r=" <> serverNonce + authMessage = clientFirstBare <> "," <> serverFirst.fullMessage <> "," <> clientFinalWithoutProof + clientSignature = hmacSHA256 storedKey authMessage + clientProof = Data.ByteArray.xor clientKey clientSignature :: ByteString + clientFinalMessage = clientFinalWithoutProof <> ",p=" <> convertToBase Base64 clientProof + serverKey = hmacSHA256 saltedPassword "Server Key" + serverSignature = convertToBase Base64 (hmacSHA256 serverKey authMessage) :: ByteString + pure ScramClientFinalMessage {clientFinalMessage, serverSignature} + where + parseServerFirst :: ByteString -> Either String (ByteString, ByteString, Int) + parseServerFirst = parseOnly $ do + -- As per the RFC, this looks something like (notice the concatenation + -- of client and server nonces, and some salt): + -- S: r=ClientNonceServerNonce,s=QSXCR+Q6sek8bf92, + -- i=4096 + void $ string "r=" <|> fail "Missing nonce in server-first-message (1)" + nonce <- Parsec.takeWhile1 (/= ',') <|> fail "Missing nonce in server-first-message (2)" + void $ string "," + void $ string "s=" <|> fail "Missing salt in server-first-message (1)" + saltB64 <- Parsec.takeWhile1 (/= ',') <|> fail "Missing salt in server-first-message (2)" + void $ string "," + salt <- case convertFromBase Base64 saltB64 of + Left e -> fail $ "Error converting salt from base64: " ++ e + Right s -> pure s + void $ string "i=" <|> fail "Missing iteration count in server-first-message" + iters <- Parsec.decimal <|> fail "Iteration count not a number in server-first-message" + Parsec.endOfInput + pure (nonce, salt, iters) + hi :: ByteString -> Int -> ByteString + hi salt iterations = + fastPBKDF2_SHA256 + Parameters {iterCounts = iterations, outputLength = 32} + password + salt + +-- | Verify the server-final-message matches the expected server signature. +verifyServerFinal :: ScramClientFinalMessage -> ByteString -> Either String () +verifyServerFinal finalMsg serverFinal = do + actualSig <- parseServerFinal serverFinal + if actualSig == finalMsg.serverSignature + then Right () + else Left "Server signature mismatch" + where + parseServerFinal :: ByteString -> Either String ByteString + parseServerFinal msg + -- This looks like: + -- v=rmF9pqV8S7suAoZWja4dJRkFsKQ= + -- And v is a base64-encoded ServerSignature, or it looks like + -- e=some-error + | BS.isPrefixOf "v=" msg = Right (BS.drop 2 msg) + | BS.isPrefixOf "e=" msg = Left $ "Server error: " <> BS8.unpack (BS.drop 2 msg) + | otherwise = Left $ "Invalid server-final-message: " <> BS8.unpack msg + +hmacSHA256 :: ByteString -> ByteString -> ByteString +hmacSHA256 key msg = convert $ hmacGetDigest (hmac key msg :: HMAC SHA256) + +hashSHA256 :: ByteString -> ByteString +hashSHA256 bs = convert (hash bs :: Digest SHA256) diff --git a/hpgsql/src/Hpgsql/Transaction.hs b/hpgsql/src/Hpgsql/Transaction.hs index 89a8ab2..cef6ffc 100644 --- a/hpgsql/src/Hpgsql/Transaction.hs +++ b/hpgsql/src/Hpgsql/Transaction.hs @@ -1,7 +1,7 @@ module Hpgsql.Transaction (withTransaction, withTransactionMode, begin, beginMode, commit, rollback, transactionStatus, IsolationLevel (..), ReadWriteMode (..), TransactionStatus (..)) where import qualified Control.Concurrent.STM as STM -import Control.Exception.Safe (Exception (..), bracketWithError, throw, tryAny, tryJust) +import Control.Exception.Safe (Exception (..), bracketWithError, throw, tryJust) import Control.Monad (unless) import Hpgsql.Internal (execute_, fullTransactionStatus, transactionStatus) import Hpgsql.InternalTypes (HPgConnection (..), InternalConnectionState (..), IrrecoverableHpgsqlError) diff --git a/nix/run-db-tests.nix b/nix/run-db-tests.nix index e8a1f10..8cac7d1 100644 --- a/nix/run-db-tests.nix +++ b/nix/run-db-tests.nix @@ -23,7 +23,7 @@ in pg_ctl -l "$out/pg_ctl_init.log" start scripts/wait-for-pg-ready.sh # Create things needed by tests - psql -c "CREATE EXTENSION citext; CREATE USER user_pass WITH PASSWORD 'hpgsql-password'; SET password_encryption = 'md5'; CREATE USER user_md5 WITH PASSWORD 'hpgsql-password'" + psql -c "CREATE EXTENSION citext; CREATE USER user_pass WITH PASSWORD 'hpgsql-password'; CREATE USER user_md5 WITH PASSWORD 'hpgsql-password'" ${hpgsql-tests}/bin/hpgsql-tests ${hspecArgs} diff --git a/nix/run-hpgsql-simple-compat-db-tests.nix b/nix/run-hpgsql-simple-compat-db-tests.nix index 5fd7a65..d7be572 100644 --- a/nix/run-hpgsql-simple-compat-db-tests.nix +++ b/nix/run-hpgsql-simple-compat-db-tests.nix @@ -25,7 +25,7 @@ in scripts/wait-for-pg-ready.sh # Create things needed by tests - psql -c "CREATE EXTENSION citext; CREATE USER user_pass WITH PASSWORD 'hpgsql-password'; SET password_encryption = 'md5'; CREATE USER user_md5 WITH PASSWORD 'hpgsql-password'" + psql -c "CREATE EXTENSION citext; CREATE USER user_pass WITH PASSWORD 'hpgsql-password'; CREATE USER user_md5 WITH PASSWORD 'hpgsql-password'" ${hpgsql-simple-compat-tests}/bin/hpgsql-simple-compat-tests ${hspecArgs} diff --git a/nix/test-shell-pg.nix b/nix/test-shell-pg.nix index 90f2ef5..5999b6f 100644 --- a/nix/test-shell-pg.nix +++ b/nix/test-shell-pg.nix @@ -20,6 +20,6 @@ pkgs.mkShell { scripts/wait-for-pg-ready.sh # Create things needed by tests - psql -c "CREATE EXTENSION citext; CREATE USER user_pass WITH PASSWORD 'hpgsql-password'; SET password_encryption = 'md5'; CREATE USER user_md5 WITH PASSWORD 'hpgsql-password'" + psql -c "CREATE EXTENSION citext; CREATE USER user_pass WITH PASSWORD 'hpgsql-password'; CREATE USER user_md5 WITH PASSWORD 'hpgsql-password'" ''; }