Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions hpgsql-tests/DbUtils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import qualified Data.Text as Text
import Hpgsql (ErrorDetail (..), HPgConnection, IrrecoverableHpgsqlError (..), PostgresError (..), execute, execute_)
import Hpgsql.Connection (defaultConnectOpts, withConnection, withConnectionOpts)
import Hpgsql.InternalTypes (ConnectOpts (..), ConnectionString (..))
import System.Environment (getEnv)
import System.Environment (getEnv, lookupEnv)
import System.Mem (performGC)
import Test.Hspec

Expand All @@ -26,7 +26,8 @@ testConnInfo = do
hostname <- Text.pack <$> getEnv "PGHOST"
database <- Text.pack <$> getEnv "PGDATABASE"
user <- Text.pack <$> getEnv "PGUSER"
pure ConnectionString {user, database, hostname, port = read portStr, password = "", options = ""}
password <- maybe "" Text.pack <$> lookupEnv "PGPASSWORD"
pure ConnectionString {user, database, hostname, port = read portStr, password, options = ""}

aroundConn :: SpecWith HPgConnection -> Spec
aroundConn = around $ \act -> do
Expand Down
4 changes: 3 additions & 1 deletion hpgsql/hpgsql.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ library
Hpgsql.Msgs
Hpgsql.Networking
Hpgsql.QueryInternal
Hpgsql.ScramSHA256
Hpgsql.SimpleParser
Hpgsql.TransactionStatusInternal
Paths_hpgsql
Expand Down Expand Up @@ -99,7 +100,8 @@ library
case-insensitive >= 1.2 && < 1.3,
cereal >= 0.5 && < 0.6,
containers >= 0.6 && < 0.8,
cryptohash-md5 >= 0.11.7.1 && < 0.12,
crypton >= 1.0.0 && < 1.1,
memory >= 0.18.0 && < 0.19,
hashable >= 1.5 && < 1.6,
haskell-src-meta >= 0.8 && < 0.9,
network >= 3.2 && < 3.3,
Expand Down
40 changes: 39 additions & 1 deletion hpgsql/src/Hpgsql/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ import Hpgsql.Encoding (FieldInfo (..), FromPgRow (..), RowDecoder (..), RowEnco
import Hpgsql.Encoding.RowDecoderMonadic (ConversionState (..), RowDecoderMonadic (..))
import Hpgsql.InternalTypes (BindComplete (..), CommandComplete (..), ConnectOpts (..), ConnectionString (..), CopyInResponse (..), CopyQueryState (..), DataRow (..), Either3 (..), EncodingContext (..), ErrorDetail (..), ErrorResponse (..), HPgConnection (..), InternalConnectionState (..), IrrecoverableHpgsqlError (..), NoData (..), NotificationResponse (..), ParseComplete (..), Pipeline (..), PostgresError (..), Query (..), QueryId (..), QueryProtocol (..), QueryState (..), ReadyForQuery (..), ResetConnectionOpts (..), ResponseMsg (..), ResponseMsgsReceived (..), RowDescription (..), SingleQuery (..), TransactionStatus (..), WeakThreadId (..), mkMutex, queryToByteString, throwIrrecoverableError)
import Hpgsql.Locking (getMyWeakThreadId, withMutex)
import Hpgsql.Msgs (AuthenticationMethod (..), AuthenticationResponse (..), BackendKeyData (..), Bind (..), CancelRequest (..), CopyData (..), CopyDone (..), Describe (..), Execute (..), FromPgMessage (..), NoticeResponse (..), ParameterStatus (..), Parse (..), PasswordMessage (..), PgMsgParser (..), StartupMessage (..), Sync (..), Terminate (..), ToPgMessage (..), parsePgMessage)
import Hpgsql.Msgs (AuthenticationMethod (..), AuthenticationResponse (..), BackendKeyData (..), Bind (..), CancelRequest (..), CopyData (..), CopyDone (..), Describe (..), Execute (..), FromPgMessage (..), NoticeResponse (..), ParameterStatus (..), Parse (..), PasswordMessage (..), PgMsgParser (..), SASLInitialResponse (..), SASLResponse (..), StartupMessage (..), Sync (..), Terminate (..), ToPgMessage (..), parsePgMessage)
import qualified Hpgsql.Msgs as Msgs
import Hpgsql.Networking (recvNonBlocking, sendNonBlocking, socketWaitRead, socketWaitWrite)
import Hpgsql.Query (breakQueryIntoStatements)
import qualified Hpgsql.ScramSHA256 as ScramSHA256
import qualified Hpgsql.SimpleParser as Parser
import Hpgsql.TypeInfo (ArrayTypeDetails (..), TypeDetails (..), TypeInfo (..), buildTypeInfoCache, builtinPgTypesMap)
import Network.Socket (AddrInfo (..))
Expand Down Expand Up @@ -284,6 +285,43 @@ internalConnectOrCancel connectOrCancel connOpts originalConnStr@ConnectionStrin
AuthMD5Password _ -> do
nonAtomicSendMsg hpgConnPartialDoNotReturn $ PasswordMessage authMethod (Text.unpack user) (Text.unpack password)
receiveAuthOkOrThrow hpgConnPartialDoNotReturn
AuthSASL mechanisms
| "SCRAM-SHA-256" `elem` mechanisms -> do
-- 1. Generate and send client-first-message
clientFirstMsg <- ScramSHA256.generateClientFirstMessage
nonAtomicSendMsg hpgConnPartialDoNotReturn $ SASLInitialResponse "SCRAM-SHA-256" clientFirstMsg.fullMessage

-- 2. Receive server-first-message
saslContinueMsg <-
receiveNextMsgUnsafe
hpgConnPartialDoNotReturn
(Right <$> msgParser @AuthenticationResponse <|> Left <$> msgParser @ErrorResponse)
serverFirst <- case saslContinueMsg of
Right (AuthenticationResponse (AuthSASLContinue d)) -> pure d
Left errResp -> throw $ IrrecoverableHpgsqlError {hpgsqlDetails = "SCRAM-SHA-256 authentication failed", innerException = Just $ toException $ mkPostgresError "" errResp, relatedStatement = Nothing}
_ -> throwIrrecoverableError "Expected AuthSASLContinue during SCRAM-SHA-256 authentication"

-- 3. Generate and send client-final-message
clientFinalMsg <- case ScramSHA256.handleServerFirstMsg password clientFirstMsg serverFirst of
Left err -> throwIrrecoverableError $ "SCRAM-SHA-256 error: " <> Text.pack err
Right r -> pure r
nonAtomicSendMsg hpgConnPartialDoNotReturn $ SASLResponse clientFinalMsg

-- 4. Receive server-final-message
saslFinalMsg <-
receiveNextMsgUnsafe
hpgConnPartialDoNotReturn
(Right <$> msgParser @AuthenticationResponse <|> Left <$> msgParser @ErrorResponse)
case saslFinalMsg of
Right (AuthenticationResponse (AuthSASLFinal serverFinalData)) ->
case ScramSHA256.verifyServerFinal clientFinalMsg serverFinalData of
Left err -> throwIrrecoverableError $ "SCRAM-SHA-256 server verification failed: " <> Text.pack err
Right () -> pure ()
Left errResp -> throw $ IrrecoverableHpgsqlError {hpgsqlDetails = "SCRAM-SHA-256 authentication failed", innerException = Just $ toException $ mkPostgresError "" errResp, relatedStatement = Nothing}
_ -> throwIrrecoverableError "Expected AuthSASLFinal during SCRAM-SHA-256 authentication"

-- 5. Receive AuthOk
receiveAuthOkOrThrow hpgConnPartialDoNotReturn
_ -> throwIrrecoverableError $ "Hpgsql does not yet support authenticating with method " <> Text.pack (show authMethod)
errorOrBackendKeyData <- receiveNextMsgUnsafe hpgConnPartialDoNotReturn $ Right <$> msgParser @BackendKeyData <|> Left <$> msgParser @ErrorResponse
case errorOrBackendKeyData of
Expand Down
63 changes: 44 additions & 19 deletions hpgsql/src/Hpgsql/Msgs.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
module Hpgsql.Msgs (AuthenticationResponse (..), AuthenticationMethod (..), BackendKeyData (..), Bind (..), BindComplete (..), CancelRequest (..), CommandComplete (..), CopyData (..), CopyDone (..), CopyFail (..), CopyInResponse (..), DataRow (..), Describe (..), ErrorDetail (..), ErrorResponse (..), Execute (..), Flush (..), NoData (..), ParameterStatus (..), Query (..), ReadyForQuery (..), RowDescription (..), StartupMessage (..), ToPgMessage (..), FromPgMessage (..), PgMsgParser (..), Terminate (..), TransactionStatus (..), NoticeResponse (..), NotificationResponse (..), Parse (..), ParseComplete (..), PasswordMessage (..), Sync (..), parsePgMessage, nulTermCString) where
module Hpgsql.Msgs (AuthenticationResponse (..), AuthenticationMethod (..), BackendKeyData (..), Bind (..), BindComplete (..), CancelRequest (..), CommandComplete (..), CopyData (..), CopyDone (..), CopyFail (..), CopyInResponse (..), DataRow (..), Describe (..), ErrorDetail (..), ErrorResponse (..), Execute (..), Flush (..), NoData (..), ParameterStatus (..), Query (..), ReadyForQuery (..), RowDescription (..), SASLInitialResponse (..), SASLResponse (..), StartupMessage (..), ToPgMessage (..), FromPgMessage (..), PgMsgParser (..), Terminate (..), TransactionStatus (..), NoticeResponse (..), NotificationResponse (..), Parse (..), ParseComplete (..), PasswordMessage (..), Sync (..), parsePgMessage, nulTermCString) where

import Control.Applicative (Alternative (..))
import Control.Monad (replicateM)
import qualified Crypto.Hash.MD5 as MD5
import qualified Crypto.Hash as Crypto
import qualified Data.Attoparsec.ByteString as Parsec
import qualified Data.Attoparsec.ByteString.Lazy as LazyParsec
import qualified Data.Attoparsec.Text as TextParsec
import Data.Bifunctor (first)
import Data.ByteArray (convert)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.ByteString.Internal (w2c)
Expand All @@ -18,11 +19,12 @@ import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe, mapMaybe)
import qualified Data.Serialize as Cereal
import Data.Text (Text)
import Data.Text.Encoding (decodeASCII, decodeUtf8)
import Data.Text.Encoding (decodeASCII, decodeUtf8, encodeUtf8)
import Data.Word (Word8)
import Hpgsql.Builder (BinaryField, Builder, builderLength)
import qualified Hpgsql.Builder as Builder
import Hpgsql.InternalTypes (BindComplete (..), CommandComplete (..), CopyInResponse (..), DataRow (..), ErrorDetail (..), ErrorResponse (..), NoData (..), NotificationResponse (..), ParseComplete (..), ReadyForQuery (..), RowDescription (..), TransactionStatus (..))
import Hpgsql.ScramSHA256 (ScramClientFinalMessage (..), ScramServerFirstMessage (..))
import Hpgsql.TypeInfo (Oid (..))

class ToPgMessage a where
Expand Down Expand Up @@ -71,7 +73,13 @@ nulTerminatedCStringParser = do
newtype AuthenticationResponse = AuthenticationResponse AuthenticationMethod
deriving stock (Show)

data AuthenticationMethod = AuthOk | AuthKerberosV5 | AuthCleartextPassword | AuthMD5Password LBS.ByteString | AuthGSS | AuthGSSContinue | AuthSSPI | AuthSASL | AuthSASLContinue | AuthSASLFinal
data AuthenticationMethod = AuthOk | AuthKerberosV5 | AuthCleartextPassword | AuthMD5Password LBS.ByteString | AuthGSS | AuthGSSContinue | AuthSSPI | AuthSASL [Text] | AuthSASLContinue !ScramServerFirstMessage | AuthSASLFinal ByteString
deriving stock (Show)

data SASLInitialResponse = SASLInitialResponse Text ByteString
deriving stock (Show)

newtype SASLResponse = SASLResponse ScramClientFinalMessage
deriving stock (Show)

data PasswordMessage
Expand Down Expand Up @@ -136,19 +144,19 @@ data Terminate = Terminate

instance FromPgMessage AuthenticationResponse where
msgParser = PgMsgParser $ \c restOfMsg -> case c of
'R' -> case LBS.length restOfMsg of
4 ->
case Cereal.decodeLazy @Int32 restOfMsg of
Right 0 -> Just $ AuthenticationResponse AuthOk
Right 2 -> Just $ AuthenticationResponse AuthKerberosV5
Right 3 -> Just $ AuthenticationResponse AuthCleartextPassword
Right 7 -> Just $ AuthenticationResponse AuthGSS
Right 9 -> Just $ AuthenticationResponse AuthSSPI
_ -> Nothing
8 -> case first (Cereal.decodeLazy @Int32) $ LBS.splitAt 4 restOfMsg of
(Right 5, salt) -> Just $ AuthenticationResponse (AuthMD5Password salt)
_ -> Nothing -- We don't care too much about parsing every possible response yet
_ -> Nothing -- We don't care too much about parsing every possible response yet
'R' -> case first (Cereal.decodeLazy @Int32) $ LBS.splitAt 4 restOfMsg of
(Right 0, _) -> Just $ AuthenticationResponse AuthOk
(Right 2, _) -> Just $ AuthenticationResponse AuthKerberosV5
(Right 3, _) -> Just $ AuthenticationResponse AuthCleartextPassword
(Right 5, salt) -> Just $ AuthenticationResponse (AuthMD5Password salt)
(Right 7, _) -> Just $ AuthenticationResponse AuthGSS
(Right 9, _) -> Just $ AuthenticationResponse AuthSSPI
(Right 10, saslMechanismName) -> case Parsec.parseOnly (Parsec.many1 nulTerminatedCStringParser) (LBS.toStrict saslMechanismName) of
Right sm -> Just $ AuthenticationResponse (AuthSASL sm)
Left _ -> Nothing
(Right 11, saslData) -> Just $ AuthenticationResponse (AuthSASLContinue (ScramServerFirstMessage $ LBS.toStrict saslData))
(Right 12, saslData) -> Just $ AuthenticationResponse (AuthSASLFinal (LBS.toStrict saslData))
_ -> Nothing
_ -> Nothing

instance FromPgMessage BackendKeyData where
Expand Down Expand Up @@ -250,12 +258,29 @@ instance ToPgMessage PasswordMessage where
AuthMD5Password (LBS.toStrict -> salt) ->
let passwordBytes = Builder.toStrictByteString (Builder.string7 password)
usernameBytes = Builder.toStrictByteString (Builder.string7 username)
innerHex = bytestringToHex (MD5.hash (passwordBytes <> usernameBytes))
outerHex = bytestringToHex (MD5.hash (innerHex <> salt))
innerHex = bytestringToHex (md5Hash (passwordBytes <> usernameBytes))
outerHex = bytestringToHex (md5Hash (innerHex <> salt))
in Builder.byteString "md5" <> Builder.byteString outerHex <> Builder.word8 0
_ -> error "PasswordMessage method not implemented"
msgLen = builderLength passwordBs + 4
in Builder.char7 'p' <> Builder.int32BE msgLen <> passwordBs
where
md5Hash :: ByteString -> ByteString
md5Hash = convert @(Crypto.Digest Crypto.MD5) . Crypto.hash

instance ToPgMessage SASLInitialResponse where
toPgMessage (SASLInitialResponse mechanism clientFirstMsg) =
let mechanismBs = Builder.byteString (encodeUtf8 mechanism) <> Builder.word8 0
dataLen = fromIntegral (BS.length clientFirstMsg) :: Int32
contents = mechanismBs <> Builder.int32BE dataLen <> Builder.byteString clientFirstMsg
msgLen = builderLength contents + 4
in Builder.char7 'p' <> Builder.int32BE msgLen <> contents

instance ToPgMessage SASLResponse where
toPgMessage (SASLResponse responseData) =
let contents = Builder.byteString responseData.clientFinalMessage
msgLen = builderLength contents + 4
in Builder.char7 'p' <> Builder.int32BE msgLen <> contents

instance ToPgMessage Query where
toPgMessage (Query bs) =
Expand Down
Loading