From 6f96e803a7deb292440d1f20fe8cecd76110672f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 Apr 2026 23:58:14 +0000 Subject: [PATCH] Bump github.com/jackc/pgx/v5 from 5.7.1 to 5.9.2 Bumps [github.com/jackc/pgx/v5](https://github.com/jackc/pgx) from 5.7.1 to 5.9.2. - [Changelog](https://github.com/jackc/pgx/blob/master/CHANGELOG.md) - [Commits](https://github.com/jackc/pgx/compare/v5.7.1...v5.9.2) --- updated-dependencies: - dependency-name: github.com/jackc/pgx/v5 dependency-version: 5.9.2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 4 +- vendor/github.com/jackc/pgx/v5/.golangci.yml | 26 + vendor/github.com/jackc/pgx/v5/CHANGELOG.md | 143 +++ vendor/github.com/jackc/pgx/v5/CLAUDE.md | 73 ++ .../github.com/jackc/pgx/v5/CONTRIBUTING.md | 26 +- vendor/github.com/jackc/pgx/v5/README.md | 28 +- vendor/github.com/jackc/pgx/v5/Rakefile | 2 +- vendor/github.com/jackc/pgx/v5/batch.go | 128 ++- vendor/github.com/jackc/pgx/v5/conn.go | 204 ++-- vendor/github.com/jackc/pgx/v5/copy_from.go | 16 +- .../github.com/jackc/pgx/v5/derived_types.go | 28 +- vendor/github.com/jackc/pgx/v5/doc.go | 67 +- .../pgx/v5/internal/iobufpool/iobufpool.go | 38 +- .../jackc/pgx/v5/internal/pgio/write.go | 22 +- .../pgx/v5/internal/sanitize/benchmark.sh | 60 + .../pgx/v5/internal/sanitize/sanitize.go | 280 ++++- .../pgx/v5/internal/stmtcache/lru_cache.go | 151 ++- .../v5/internal/stmtcache/unlimited_cache.go | 77 -- .../jackc/pgx/v5/pgconn/auth_oauth.go | 67 ++ .../jackc/pgx/v5/pgconn/auth_scram.go | 171 ++- .../github.com/jackc/pgx/v5/pgconn/config.go | 189 ++- .../pgx/v5/pgconn/ctxwatch/context_watcher.go | 46 +- .../github.com/jackc/pgx/v5/pgconn/errors.go | 27 +- vendor/github.com/jackc/pgx/v5/pgconn/krb5.go | 2 +- .../github.com/jackc/pgx/v5/pgconn/pgconn.go | 1024 +++++++++++++---- .../authentication_cleartext_password.go | 3 +- .../pgx/v5/pgproto3/authentication_ok.go | 3 +- .../pgx/v5/pgproto3/authentication_sasl.go | 1 + .../jackc/pgx/v5/pgproto3/backend.go | 23 +- .../jackc/pgx/v5/pgproto3/backend_key_data.go | 33 +- .../github.com/jackc/pgx/v5/pgproto3/bind.go | 8 +- .../jackc/pgx/v5/pgproto3/cancel_request.go | 45 +- .../pgx/v5/pgproto3/copy_both_response.go | 2 +- .../jackc/pgx/v5/pgproto3/copy_done.go | 3 +- .../jackc/pgx/v5/pgproto3/copy_fail.go | 4 + .../jackc/pgx/v5/pgproto3/copy_in_response.go | 2 +- .../pgx/v5/pgproto3/copy_out_response.go | 2 +- .../jackc/pgx/v5/pgproto3/data_row.go | 7 +- .../jackc/pgx/v5/pgproto3/frontend.go | 21 +- .../jackc/pgx/v5/pgproto3/function_call.go | 27 +- .../pgx/v5/pgproto3/function_call_response.go | 4 +- .../jackc/pgx/v5/pgproto3/gss_enc_request.go | 3 +- .../v5/pgproto3/negotiate_protocol_version.go | 93 ++ .../pgx/v5/pgproto3/parameter_description.go | 2 +- .../github.com/jackc/pgx/v5/pgproto3/parse.go | 2 +- .../jackc/pgx/v5/pgproto3/password_message.go | 2 +- .../github.com/jackc/pgx/v5/pgproto3/query.go | 4 + .../jackc/pgx/v5/pgproto3/row_description.go | 3 +- .../pgx/v5/pgproto3/sasl_initial_response.go | 3 + .../jackc/pgx/v5/pgproto3/ssl_request.go | 3 +- .../jackc/pgx/v5/pgproto3/startup_message.go | 11 +- .../github.com/jackc/pgx/v5/pgproto3/trace.go | 4 +- .../github.com/jackc/pgx/v5/pgtype/array.go | 20 +- .../jackc/pgx/v5/pgtype/array_codec.go | 12 +- vendor/github.com/jackc/pgx/v5/pgtype/bits.go | 7 +- vendor/github.com/jackc/pgx/v5/pgtype/bool.go | 11 +- vendor/github.com/jackc/pgx/v5/pgtype/box.go | 7 +- .../jackc/pgx/v5/pgtype/builtin_wrappers.go | 4 +- .../github.com/jackc/pgx/v5/pgtype/bytea.go | 1 - .../github.com/jackc/pgx/v5/pgtype/circle.go | 6 +- .../jackc/pgx/v5/pgtype/composite.go | 7 +- .../github.com/jackc/pgx/v5/pgtype/convert.go | 28 +- vendor/github.com/jackc/pgx/v5/pgtype/date.go | 127 +- vendor/github.com/jackc/pgx/v5/pgtype/doc.go | 89 +- .../github.com/jackc/pgx/v5/pgtype/float4.go | 12 +- .../github.com/jackc/pgx/v5/pgtype/float8.go | 12 +- .../github.com/jackc/pgx/v5/pgtype/hstore.go | 20 +- vendor/github.com/jackc/pgx/v5/pgtype/inet.go | 3 +- vendor/github.com/jackc/pgx/v5/pgtype/int.go | 34 +- .../github.com/jackc/pgx/v5/pgtype/int.go.erb | 9 +- .../pgtype/integration_benchmark_test.go.erb | 4 +- .../jackc/pgx/v5/pgtype/interval.go | 11 +- vendor/github.com/jackc/pgx/v5/pgtype/json.go | 78 +- vendor/github.com/jackc/pgx/v5/pgtype/line.go | 7 +- vendor/github.com/jackc/pgx/v5/pgtype/lseg.go | 7 +- .../jackc/pgx/v5/pgtype/multirange.go | 16 +- .../github.com/jackc/pgx/v5/pgtype/numeric.go | 81 +- vendor/github.com/jackc/pgx/v5/pgtype/path.go | 9 +- .../github.com/jackc/pgx/v5/pgtype/pgtype.go | 198 ++-- .../jackc/pgx/v5/pgtype/pgtype_default.go | 26 +- .../github.com/jackc/pgx/v5/pgtype/point.go | 9 +- .../github.com/jackc/pgx/v5/pgtype/polygon.go | 9 +- .../github.com/jackc/pgx/v5/pgtype/range.go | 13 +- .../jackc/pgx/v5/pgtype/record_codec.go | 1 - vendor/github.com/jackc/pgx/v5/pgtype/text.go | 9 +- vendor/github.com/jackc/pgx/v5/pgtype/tid.go | 7 +- vendor/github.com/jackc/pgx/v5/pgtype/time.go | 7 +- .../jackc/pgx/v5/pgtype/timestamp.go | 44 +- .../jackc/pgx/v5/pgtype/timestamptz.go | 23 +- .../jackc/pgx/v5/pgtype/tsvector.go | 507 ++++++++ .../github.com/jackc/pgx/v5/pgtype/uint32.go | 33 +- .../github.com/jackc/pgx/v5/pgtype/uint64.go | 323 ++++++ vendor/github.com/jackc/pgx/v5/pgtype/uuid.go | 16 +- vendor/github.com/jackc/pgx/v5/pgtype/xml.go | 2 +- .../github.com/jackc/pgx/v5/pgxpool/pool.go | 269 +++-- .../github.com/jackc/pgx/v5/pgxpool/stat.go | 7 + vendor/github.com/jackc/pgx/v5/rows.go | 77 +- vendor/github.com/jackc/pgx/v5/test.sh | 170 +++ vendor/github.com/jackc/pgx/v5/tx.go | 32 +- vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go | 77 -- vendor/modules.txt | 5 +- 102 files changed, 4427 insertions(+), 1278 deletions(-) create mode 100644 vendor/github.com/jackc/pgx/v5/.golangci.yml create mode 100644 vendor/github.com/jackc/pgx/v5/CLAUDE.md create mode 100644 vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh delete mode 100644 vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go create mode 100644 vendor/github.com/jackc/pgx/v5/pgtype/uint64.go create mode 100644 vendor/github.com/jackc/pgx/v5/test.sh delete mode 100644 vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go diff --git a/go.mod b/go.mod index a3ac5cae..711630cf 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/golang-migrate/migrate/v4 v4.17.1 github.com/google/uuid v1.6.0 github.com/habx/pg-commands v0.6.1 - github.com/jackc/pgx/v5 v5.7.1 + github.com/jackc/pgx/v5 v5.9.2 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index 76b274e8..c34e5787 100644 --- a/go.sum +++ b/go.sum @@ -88,8 +88,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= -github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= diff --git a/vendor/github.com/jackc/pgx/v5/.golangci.yml b/vendor/github.com/jackc/pgx/v5/.golangci.yml new file mode 100644 index 00000000..d0903ab3 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/.golangci.yml @@ -0,0 +1,26 @@ +# See for configurations: https://golangci-lint.run/usage/configuration/ +version: "2" + +linters: + default: none + enable: + - govet + - ineffassign + +# See: https://golangci-lint.run/usage/formatters/ +formatters: + enable: + - gofmt # https://pkg.go.dev/cmd/gofmt + - gofumpt # https://github.com/mvdan/gofumpt + + settings: + gofmt: + simplify: true # Simplify code: gofmt with `-s` option. + + gofumpt: + # Module path which contains the source code being formatted. + # Default: "" + module-path: github.com/jackc/pgx/v5 # Should match with module in go.mod + # Choose whether to use the extra rules. + # Default: false + extra-rules: true diff --git a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md index a0ff9ba3..87c9ebfb 100644 --- a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md +++ b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md @@ -1,3 +1,146 @@ +# 5.9.2 (April 18, 2026) + +Fix SQL Injection via placeholder confusion with dollar quoted string literals (GHSA-j88v-2chj-qfwx) + +SQL injection can occur when: + +1. The non-default simple protocol is used. +2. A dollar quoted string literal is used in the SQL query. +3. That query contains text that would be would be interpreted outside as a placeholder outside of a string literal. +4. The value of that placeholder is controllable by the attacker. + +e.g. + +```go +attackValue := `$tag$; drop table canary; --` +_, err = tx.Exec(ctx, `select $tag$ $1 $tag$, $1`, pgx.QueryExecModeSimpleProtocol, attackValue) +``` + +This is unlikely to occur outside of a contrived scenario. + +# 5.9.1 (March 22, 2026) + +* Fix: batch result format corruption when using cached prepared statements (reported by Dirkjan Bussink) + +# 5.9.0 (March 21, 2026) + +This release includes a number of new features such as SCRAM-SHA-256-PLUS support, OAuth authentication support, and +PostgreSQL protocol 3.2 support. + +It significantly reduces the amount of network traffic when using prepared statements (which are used automatically by +default) by avoiding unnecessary Describe Portal messages. This also reduces local memory usage. + +It also includes multiple fixes for potential DoS due to panic or OOM if connected to a malicious server that sends +deliberately malformed messages. + +* Require Go 1.25+ +* Add SCRAM-SHA-256-PLUS support (Adam Brightwell) +* Add OAuth authentication support for PostgreSQL 18 (David Schneider) +* Add PostgreSQL protocol 3.2 support (Dirkjan Bussink) +* Add tsvector type support (Adam Brightwell) +* Skip Describe Portal for cached prepared statements reducing network round trips +* Make LoadTypes query easier to support on "postgres-like" servers (Jelte Fennema-Nio) +* Default empty user to current OS user matching libpq behavior (ShivangSrivastava) +* Optimize LRU statement cache with custom linked list and node pooling (Mathias Bogaert) +* Optimize date scanning by replacing regex with manual parsing (Mathias Bogaert) +* Optimize pgio append/set functions with direct byte shifts (Mathias Bogaert) +* Make RowsAffected faster (Abhishek Chanda) +* Fix: Pipeline.Close panic when server sends multiple FATAL errors (Varun Chawla) +* Fix: ContextWatcher goroutine leak (Hank Donnay) +* Fix: stdlib discard connections with open transactions in ResetSession (Jeremy Schneider) +* Fix: pipelineBatchResults.Exec silently swallowing lastRows error +* Fix: ColumnTypeLength using BPCharArrayOID instead of BPCharOID +* Fix: TSVector text encoding returning nil for valid empty tsvector +* Fix: wrong error messages for Int2 and Int4 underflow +* Fix: Numeric nil Int pointer dereference with Valid: true +* Fix: reversed strings.ContainsAny arguments in Numeric.ScanScientific +* Fix: message length parsing on 32-bit platforms +* Fix: FunctionCallResponse.Decode mishandling of signed result size +* Fix: returning wrong error in configTLS when DecryptPEMBlock fails (Maxim Motyshen) +* Fix: misleading ParseConfig error when default_query_exec_mode is invalid (Skarm) +* Fix: missed Unwatch in Pipeline error paths +* Clarify too many failed acquire attempts error message +* Better error wrapping with context and SQL statement (Aneesh Makala) +* Enable govet and ineffassign linters (Federico Guerinoni) +* Guard against various malformed binary messages (arrays, hstore, multirange, protocol messages) +* Fix various godoc comments (ferhat elmas) +* Fix typos in comments (Oleksandr Redko) + +# 5.8.0 (December 26, 2025) + +* Require Go 1.24+ +* Remove golang.org/x/crypto dependency +* Add OptionShouldPing to control ResetSession ping behavior (ilyam8) +* Fix: Avoid overflow when MaxConns is set to MaxInt32 +* Fix: Close batch pipeline after a query error (Anthonin Bonnefoy) +* Faster shutdown of pgxpool.Pool background goroutines (Blake Gentry) +* Add pgxpool ping timeout (Amirsalar Safaei) +* Fix: Rows.FieldDescriptions for empty query +* Scan unknown types into *any as string or []byte based on format code +* Optimize pgtype.Numeric (Philip Dubé) +* Add AfterNetConnect hook to pgconn.Config +* Fix: Handle for preparing statements that fail during the Describe phase +* Fix overflow in numeric scanning (Ilia Demianenko) +* Fix: json/jsonb sql.Scanner source type is []byte +* Migrate from math/rand to math/rand/v2 (Mathias Bogaert) +* Optimize internal iobufpool (Mathias Bogaert) +* Optimize stmtcache invalidation (Mathias Bogaert) +* Fix: missing error case in interval parsing (Maxime Soulé) +* Fix: invalidate statement/description cache in Exec (James Hartig) +* ColumnTypeLength method return the type length for varbit type (DengChan) +* Array and Composite codecs handle typed nils + +# 5.7.6 (September 8, 2025) + +* Use ParseConfigError in pgx.ParseConfig and pgxpool.ParseConfig (Yurasov Ilia) +* Add PrepareConn hook to pgxpool (Jonathan Hall) +* Reduce allocations in QueryContext (Dominique Lefevre) +* Add MarshalJSON and UnmarshalJSON for pgtype.Uint32 (Panos Koutsovasilis) +* Configure ping behavior on pgxpool with ShouldPing (Christian Kiely) +* zeronull int types implement Int64Valuer and Int64Scanner (Li Zeghong) +* Fix panic when receiving terminate connection message during CopyFrom (Michal Drausowski) +* Fix statement cache not being invalidated on error during batch (Muhammadali Nazarov) + +# 5.7.5 (May 17, 2025) + +* Support sslnegotiation connection option (divyam234) +* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487. +* TraceLog now logs Acquire and Release at the debug level (dave sinclair) +* Add support for PGTZ environment variable +* Add support for PGOPTIONS environment variable +* Unpin memory used by Rows quicker +* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit. + +# 5.7.4 (March 24, 2025) + +* Fix / revert change to scanning JSON `null` (Felix Röhrich) + +# 5.7.3 (March 21, 2025) + +* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32) +* Improve SQL sanitizer performance (ninedraft) +* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich) +* Fix Values() for xml type always returning nil instead of []byte +* Add ability to send Flush message in pipeline mode (zenkovev) +* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou) +* Better error messages when scanning structs (logicbomb) +* Fix handling of error on batch write (bonnefoa) +* Match libpq's connection fallback behavior more closely (felix-roehrich) +* Add MinIdleConns to pgxpool (djahandarie) + +# 5.7.2 (December 21, 2024) + +* Fix prepared statement already exists on batch prepare failure +* Add commit query to tx options (Lucas Hild) +* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels) +* Add message body size limits in frontend and backend (zene) +* Add xid8 type +* Ensure planning encodes and scans cannot infinitely recurse +* Implement pgtype.UUID.String() (Konstantin Grachev) +* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev) +* Update golang.org/x/crypto +* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo) + # 5.7.1 (September 10, 2024) * Fix data race in tracelog.TraceLog diff --git a/vendor/github.com/jackc/pgx/v5/CLAUDE.md b/vendor/github.com/jackc/pgx/v5/CLAUDE.md new file mode 100644 index 00000000..e3ed1a2e --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/CLAUDE.md @@ -0,0 +1,73 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +pgx is a PostgreSQL driver and toolkit for Go (`github.com/jackc/pgx/v5`). It provides both a native PostgreSQL interface and a `database/sql` compatible driver. Requires Go 1.25+ and supports PostgreSQL 14+ and CockroachDB. + +## Build & Test Commands + +```bash +# Run all tests (requires PGX_TEST_DATABASE to be set) +go test ./... + +# Run a specific test +go test -run TestFunctionName ./... + +# Run tests for a specific package +go test ./pgconn/... + +# Run tests with race detector +go test -race ./... + +# DevContainer: run tests against specific PostgreSQL versions +./test.sh pg18 # Default: PostgreSQL 18 +./test.sh pg16 -run TestConnect # Specific test against PG16 +./test.sh crdb # CockroachDB +./test.sh all # All targets (pg14-18 + crdb) + +# Format (always run after making changes) +goimports -w . + +# Lint +golangci-lint run ./... +``` + +## Test Database Setup + +Tests require `PGX_TEST_DATABASE` environment variable. In the devcontainer, `test.sh` handles this. For local development: + +```bash +export PGX_TEST_DATABASE="host=localhost user=postgres password=postgres dbname=pgx_test" +``` + +The test database needs extensions: `hstore`, `ltree`, and a `uint64` domain. See `testsetup/postgresql_setup.sql` for full setup. Many tests are skipped unless additional `PGX_TEST_*` env vars are set (for TLS, SCRAM, MD5, unix socket, PgBouncer testing). + +## Architecture + +The codebase is a layered architecture, bottom-up: + +- **pgproto3/** — PostgreSQL wire protocol v3 encoder/decoder. Defines `FrontendMessage` and `BackendMessage` types for every protocol message. +- **pgconn/** — Low-level connection layer (roughly libpq-equivalent). Handles authentication, TLS, query execution, COPY protocol, and notifications. `PgConn` is the core type. +- **pgx** (root package) — High-level query interface built on `pgconn`. Provides `Conn`, `Rows`, `Tx`, `Batch`, `CopyFrom`, and generic helpers like `CollectRows`/`ForEachRow`. Includes automatic statement caching (LRU). +- **pgtype/** — Type system mapping between Go and PostgreSQL types (70+ types). Key interfaces: `Codec`, `Type`, `TypeMap`. Custom types (enums, composites, domains) are registered through `TypeMap`. +- **pgxpool/** — Concurrency-safe connection pool built on `puddle/v2`. `Pool` is the main type; wraps `pgx.Conn`. +- **stdlib/** — `database/sql` compatibility adapter. + +Supporting packages: +- **internal/stmtcache/** — Prepared statement cache with LRU eviction +- **internal/sanitize/** — SQL query sanitization +- **tracelog/** — Logging adapter that implements tracer interfaces +- **multitracer/** — Composes multiple tracers into one +- **pgxtest/** — Test helpers for running tests across connection types + +## Key Design Conventions + +- **Semantic versioning** — strictly followed. Do not break the public API (no removing or renaming exported types, functions, methods, or fields; no changing function signatures). +- **Minimal dependencies** — adding new dependencies is strongly discouraged (see CONTRIBUTING.md). +- **Context-based** — all blocking operations take `context.Context`. +- **Tracer interfaces** — observability via `QueryTracer`, `BatchTracer`, `CopyFromTracer`, `PrepareTracer` on `ConnConfig.Tracer`. +- **Formatting** — always run `goimports -w .` after making changes to ensure code is properly formatted. CI checks formatting via `gofmt -l -s -w . && git diff --exit-code`. `gofumpt` with extra rules is also enforced via `golangci-lint`. +- **Linters** — `govet` and `ineffassign` only (configured in `.golangci.yml`). +- **CI matrix** — tests run against Go 1.25/1.26 × PostgreSQL 14-18 + CockroachDB, on Linux and Windows. Race detector enabled on Linux only. diff --git a/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md b/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md index c975a937..2283ae67 100644 --- a/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md +++ b/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md @@ -10,6 +10,18 @@ proposal. This will help to ensure your proposed change has a reasonable chance Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change that adds a dependency is no. +## AI + +Using AI is acceptable (not that it can really be stopped) under one the following conditions. + +* AI was used, but you deeply understand the code and you can answer questions regarding your change. You are not going + to answer questions with "I don't know", AI did it. You are not going to "answer" questions by relaying them to your + agent. This is wasteful of the code reviewer's time. +* AI was used to solve a problem without your deep understanding. This can still be a good starting point for a fix or + feature. But you need to clearly state that this is an AI proposal. You should include additional information such as + the AI used and what prompts were used. You should also be aware that large, complicated, or subtle changes may be + rejected simply because the reviewer is not confident in a change that no human understands. + ## Development Environment Setup pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` @@ -17,7 +29,12 @@ environment variable. The `PGX_TEST_DATABASE` environment variable can either be the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable handling. -### Using an Existing PostgreSQL Cluster +### Devcontainer + +The easiest way to start development is with the included devcontainer. It includes containers for each supported +PostgreSQL version as well as CockroachDB. `./test.sh all` will run the tests against all database types. + +### Using an Existing PostgreSQL Cluster Outside of a Devcontainer If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different @@ -49,7 +66,7 @@ go test ./... This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods). -### Creating a New PostgreSQL Cluster Exclusively for Testing +### Creating a New PostgreSQL Cluster Exclusively for Testing Outside of a Devcontainer The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`. @@ -63,10 +80,11 @@ export POSTGRESQL_DATA_DIR=postgresql export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test" export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" -export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test" +export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test channel_binding=disable" +export PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test channel_binding=require" export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret" -export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem" +export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem channel_binding=disable" export PGX_SSL_PASSWORD=certpw export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key" ``` diff --git a/vendor/github.com/jackc/pgx/v5/README.md b/vendor/github.com/jackc/pgx/v5/README.md index 0cf2c291..aa35e4a3 100644 --- a/vendor/github.com/jackc/pgx/v5/README.md +++ b/vendor/github.com/jackc/pgx/v5/README.md @@ -84,7 +84,7 @@ It is also possible to use the `database/sql` interface and convert a connection ## Testing -See CONTRIBUTING.md for setup instructions. +See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions. ## Architecture @@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube. ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.25 and higher and PostgreSQL 14 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy @@ -120,13 +120,15 @@ pgerrcode contains constants for the PostgreSQL error codes. * [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) * [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) +* [github.com/ColeBurch/pgx-govalues-decimal](https://github.com/ColeBurch/pgx-govalues-decimal) * [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos)) * [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) ## Adapters for 3rd Party Tracers -* [https://github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) +* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) +* [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx) ## Adapters for 3rd Party Loggers @@ -156,7 +158,7 @@ Library for scanning data from a database into Go structs and more. A carefully designed SQL client for making using SQL easier, more productive, and less error-prone on Golang. -### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) +### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) Adds GSSAPI / Kerberos authentication support. @@ -169,6 +171,22 @@ Explicit data mapping and scanning library for Go structs and slices. Type safe and flexible package for scanning database data into Go types. Supports, structs, maps, slices and custom mapping functions. -### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) +### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) Code first migration library for native pgx (no database/sql abstraction). + +### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring) + +A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry. + +### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox) + +Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver. + +### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy) + +Simplifies working with the pgx library, providing convenient scanning of nested structures. + +### [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter) + +Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters. diff --git a/vendor/github.com/jackc/pgx/v5/Rakefile b/vendor/github.com/jackc/pgx/v5/Rakefile index d957573e..3e3aa503 100644 --- a/vendor/github.com/jackc/pgx/v5/Rakefile +++ b/vendor/github.com/jackc/pgx/v5/Rakefile @@ -2,7 +2,7 @@ require "erb" rule '.go' => '.go.erb' do |task| erb = ERB.new(File.read(task.source)) - File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding)) + File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding)) sh "goimports", "-w", task.name end diff --git a/vendor/github.com/jackc/pgx/v5/batch.go b/vendor/github.com/jackc/pgx/v5/batch.go index c3c2834f..805cc39e 100644 --- a/vendor/github.com/jackc/pgx/v5/batch.go +++ b/vendor/github.com/jackc/pgx/v5/batch.go @@ -8,7 +8,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) -// QueuedQuery is a query that has been queued for execution via a Batch. +// QueuedQuery is a query that has been queued for execution via a [Batch]. type QueuedQuery struct { SQL string Arguments []any @@ -43,6 +43,10 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { } // Exec sets fn to be called when the response to qq is received. +// +// Note: for simple batch insert uses where it is not required to handle +// each potential error individually, it's sufficient to not set any callbacks, +// and just handle the return value of [BatchResults.Close]. func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { qq.Fn = func(br BatchResults) error { ct, err := br.Exec() @@ -61,12 +65,13 @@ type Batch struct { } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option -// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode. +// argument that is supported is [QueryRewriter]. Queries are executed using the connection's DefaultQueryExecMode +// (see [ConnConfig.DefaultQueryExecMode]). // -// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should -// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query, -// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that -// include the current query may reference the wrong query. +// While query can contain multiple statements if the connection's DefaultQueryExecMode is [QueryExecModeSimpleProtocol], +// this should be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, +// [QueuedQuery.Query], [QueuedQuery.QueryRow], and [QueuedQuery.Exec] must not be called. In addition, any error +// messages or tracing that include the current query may reference the wrong query. func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { qq := &QueuedQuery{ SQL: query, @@ -82,22 +87,25 @@ func (b *Batch) Len() int { } type BatchResults interface { - // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer - // calling Exec on the QueuedQuery. + // Exec reads the results from the next query in the batch as if the query has been sent with [Conn.Exec]. Prefer + // calling Exec on the QueuedQuery, or just calling Close. Exec() (pgconn.CommandTag, error) - // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer - // calling Query on the QueuedQuery. + // Query reads the results from the next query in the batch as if the query has been sent with [Conn.Query]. Prefer + // calling [QueuedQuery.Query]. Query() (Rows, error) - // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. - // Prefer calling QueryRow on the QueuedQuery. + // QueryRow reads the results from the next query in the batch as if the query has been sent with [Conn.QueryRow]. + // Prefer calling [QueuedQuery.QueryRow]. QueryRow() Row // Close closes the batch operation. All unread results are read and any callback functions registered with - // QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an + // [QueuedQuery.Query], [QueuedQuery.QueryRow], or [QueuedQuery.Exec] will be called. If a callback function returns an // error or the batch encounters an error subsequent callback functions will not be called. // + // For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks, + // and just handle the return value of Close. + // // Close must be called before the underlying connection can be used again. Any error that occurred during a batch // operation may have made it impossible to resyncronize the connection with the server. In this case the underlying // connection will have been closed. @@ -207,7 +215,6 @@ func (br *batchResults) Query() (Rows, error) { func (br *batchResults) QueryRow() Row { rows, _ := br.Query() return (*connRow)(rows.(*baseRows)) - } // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to @@ -220,6 +227,8 @@ func (br *batchResults) Close() error { } br.endTraced = true } + + invalidateCachesOnBatchResultsError(br.conn, br.b, br.err) }() if br.err != nil { @@ -264,7 +273,7 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { ok = true br.qqIdx++ } - return + return query, args, ok } type pipelineBatchResults struct { @@ -288,6 +297,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err return pgconn.CommandTag{}, br.err } @@ -378,7 +388,6 @@ func (br *pipelineBatchResults) Query() (Rows, error) { func (br *pipelineBatchResults) QueryRow() Row { rows, _ := br.Query() return (*connRow)(rows.(*baseRows)) - } // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to @@ -391,11 +400,12 @@ func (br *pipelineBatchResults) Close() error { } br.endTraced = true } + + invalidateCachesOnBatchResultsError(br.conn, br.b, br.err) }() if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { br.err = br.lastRows.err - return br.err } if br.closed { @@ -441,3 +451,87 @@ func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, er br.qqIdx++ return bi.SQL, bi.Arguments, nil } + +type emptyBatchResults struct { + conn *Conn + closed bool +} + +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *emptyBatchResults) Exec() (pgconn.CommandTag, error) { + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") + } + return pgconn.CommandTag{}, errors.New("no more results in batch") +} + +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *emptyBatchResults) Query() (Rows, error) { + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + } + + rows := br.conn.getRows(context.Background(), "", nil) + rows.err = errors.New("no more results in batch") + rows.closed = true + return rows, rows.err +} + +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *emptyBatchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) +} + +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *emptyBatchResults) Close() error { + br.closed = true + return nil +} + +// invalidates statement and description caches on batch results error +func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) { + if err != nil && conn != nil && b != nil { + if sc := conn.statementCache; sc != nil { + for _, bi := range b.QueuedQueries { + sc.Invalidate(bi.SQL) + } + } + + if sc := conn.descriptionCache; sc != nil { + for _, bi := range b.QueuedQueries { + sc.Invalidate(bi.SQL) + } + } + } +} + +// ErrPreprocessingBatch occurs when an error is encountered while preprocessing a batch. +// The two preprocessing steps are "prepare" (server-side SQL parse/plan) and +// "build" (client-side argument encoding). +type ErrPreprocessingBatch struct { + step string // "prepare" or "build" + sql string + err error +} + +func newErrPreprocessingBatch(step, sql string, err error) ErrPreprocessingBatch { + return ErrPreprocessingBatch{step: step, sql: sql, err: err} +} + +func (e ErrPreprocessingBatch) Error() string { + // intentionally not including the SQL query in the error message + // to avoid leaking potentially sensitive information into logs. + // If the user wants the SQL, they can call SQL(). + return fmt.Sprintf("error preprocessing batch (%s): %v", e.step, e.err) +} + +func (e ErrPreprocessingBatch) Unwrap() error { + return e.err +} + +func (e ErrPreprocessingBatch) SQL() string { + return e.sql +} diff --git a/vendor/github.com/jackc/pgx/v5/conn.go b/vendor/github.com/jackc/pgx/v5/conn.go index 187b3dd5..4f27a5df 100644 --- a/vendor/github.com/jackc/pgx/v5/conn.go +++ b/vendor/github.com/jackc/pgx/v5/conn.go @@ -17,8 +17,8 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and -// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. +// ConnConfig contains all the options used to establish a connection. It must be created by [ParseConfig] and +// then it can be modified. A manually initialized ConnConfig will cause [ConnectConfig] to panic. type ConnConfig struct { pgconn.Config @@ -37,8 +37,8 @@ type ConnConfig struct { // DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol // and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as - // PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same - // functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument. + // PGBouncer. In this case it may be preferable to use [QueryExecModeExec] or [QueryExecModeSimpleProtocol]. The same + // functionality can be controlled on a per query basis by passing a [QueryExecMode] as the first query argument. DefaultQueryExecMode QueryExecMode createdByParseConfig bool // Used to enforce created by ParseConfig rule. @@ -65,11 +65,12 @@ func (cc *ConnConfig) ConnString() string { return cc.connString } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access // to multiple database connections from multiple goroutines. type Conn struct { - pgConn *pgconn.PgConn - config *ConnConfig // config used when establishing this connection - preparedStatements map[string]*pgconn.StatementDescription - statementCache stmtcache.Cache - descriptionCache stmtcache.Cache + pgConn *pgconn.PgConn + config *ConnConfig // config used when establishing this connection + preparedStatements map[string]*pgconn.StatementDescription + failedDescribeStatement string + statementCache stmtcache.Cache + descriptionCache stmtcache.Cache queryTracer QueryTracer batchTracer BatchTracer @@ -130,7 +131,7 @@ var ( ) // Connect establishes a connection with a PostgreSQL server with a connection string. See -// pgconn.Connect for details. +// [pgconn.Connect] for details. func Connect(ctx context.Context, connString string) (*Conn, error) { connConfig, err := ParseConfig(connString) if err != nil { @@ -140,7 +141,7 @@ func Connect(ctx context.Context, connString string) (*Conn, error) { } // ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to -// provide a GetSSLPassword function. +// provide a [pgconn.GetSSLPasswordFunc] function. func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) { connConfig, err := ParseConfigWithOptions(connString, options) if err != nil { @@ -150,7 +151,7 @@ func ConnectWithOptions(ctx context.Context, connString string, options ParseCon } // ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct. -// connConfig must have been created by ParseConfig. +// connConfig must have been created by [ParseConfig]. func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { // In general this improves safety. In particular avoid the config.Config.OnNotification mutation from affecting other // connections with the same config. See https://github.com/jackc/pgx/issues/618. @@ -159,8 +160,8 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { return connect(ctx, connConfig) } -// ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At the present options is -// only used to provide a GetSSLPassword function. +// ParseConfigWithOptions behaves exactly as [ParseConfig] does with the addition of options. At the present options is +// only used to provide a [pgconn.GetSSLPasswordFunc] function. func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) { config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions) if err != nil { @@ -172,7 +173,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con delete(config.RuntimeParams, "statement_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err) } statementCacheCapacity = int(n) } @@ -182,7 +183,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con delete(config.RuntimeParams, "description_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err) } descriptionCacheCapacity = int(n) } @@ -202,7 +203,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con case "simple_protocol": defaultQueryExecMode = QueryExecModeSimpleProtocol default: - return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s) + return nil, pgconn.NewParseConfigError( + connString, "invalid default_query_exec_mode", fmt.Errorf("unknown value %q", s), + ) } } @@ -305,8 +308,8 @@ func (c *Conn) Close(ctx context.Context) error { } // Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These -// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and -// Exec to execute the statement. It can also be used with Batch.Queue. +// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with [Conn.Query], +// [Conn.QueryRow], and [Conn.Exec] to execute the statement. It can also be used with [Batch.Queue]. // // The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if // name == sql. @@ -314,6 +317,14 @@ func (c *Conn) Close(ctx context.Context) error { // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This // allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared. func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.failedDescribeStatement != "" { + err = c.Deallocate(ctx, c.failedDescribeStatement) + if err != nil { + return nil, fmt.Errorf("failed to deallocate previously failed statement %q: %w", c.failedDescribeStatement, err) + } + c.failedDescribeStatement = "" + } + if c.prepareTracer != nil { ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) } @@ -346,6 +357,10 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem sd, err = c.pgConn.Prepare(ctx, psName, sql, nil) if err != nil { + var pErr *pgconn.PrepareError + if errors.As(err, &pErr) { + c.failedDescribeStatement = psKey + } return nil, err } @@ -420,7 +435,7 @@ func (c *Conn) IsClosed() bool { return c.pgConn.IsClosed() } -func (c *Conn) die(err error) { +func (c *Conn) die() { if c.IsClosed() { return } @@ -502,6 +517,18 @@ optionLoop: mode = QueryExecModeSimpleProtocol } + defer func() { + if err != nil { + if sc := c.statementCache; sc != nil { + sc.Invalidate(sql) + } + + if sc := c.descriptionCache; sc != nil { + sc.Invalidate(sql) + } + } + }() + if sd, ok := c.preparedStatements[sql]; ok { return c.execPrepared(ctx, sd, arguments) } @@ -583,19 +610,11 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription return pgconn.CommandTag{}, err } - result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + result := c.pgConn.ExecStatement(ctx, sd, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } -type unknownArgumentTypeQueryExecModeExecError struct { - arg any -} - -func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { - return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg) -} - func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { err := c.eqb.Build(c.typeMap, nil, args) if err != nil { @@ -650,21 +669,33 @@ const ( // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are // unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. + // + // On rare occasions user defined types may behave differently when encoded in the text format instead of the binary + // format. For example, this could happen if a "type RomanNumeral int32" implements fmt.Stringer to format integers as + // Roman numerals (e.g. 7 is VII). The binary format would properly encode the integer 7 as the binary value for 7. + // But the text format would encode the integer 7 as the string "VII". As QueryExecModeExec uses the text format, it + // is possible that changing query mode from another mode to QueryExecModeExec could change the behavior of the query. + // This should not occur with types pgx supports directly and can be avoided by registering the types with + // pgtype.Map.RegisterDefaultPgType and implementing the appropriate type interfaces. In the cas of RomanNumeral, it + // should implement pgtype.Int64Valuer. QueryExecModeExec - // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. - // Queries are executed in a single round trip. Type mappings can be registered with + // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is + // especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used + // instead for text type values including json and jsonb. Type mappings can be registered with // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous. - // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use - // a map[string]string directly as an argument. This mode cannot. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a + // map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip. // - // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor - // exceptions such as behavior when multiple result returning queries are erroneously sent in a single string. + // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes + // the warning regarding differences in text format and binary format encoding with user defined types. There may be + // other minor exceptions such as behavior when multiple result returning queries are erroneously sent in a single + // string. // // QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer - // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol - // should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does - // not support the extended protocol. + // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol should + // only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does not + // support the extended protocol. QueryExecModeSimpleProtocol ) @@ -813,7 +844,7 @@ optionLoop: if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe { rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats) } else { - rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) + rows.resultReader = c.pgConn.ExecStatement(ctx, sd, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) } } else if mode == QueryExecModeExec { err := c.eqb.Build(c.typeMap, nil, args) @@ -902,9 +933,16 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { } // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless -// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection +// explicit transaction control statements are executed. The returned [BatchResults] must be closed before the connection // is used again. +// +// Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table +// and using it in a subsequent query in the same batch can fail. func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { + if len(b.QueuedQueries) == 0 { + return &emptyBatchResults{conn: c} + } + if c.batchTracer != nil { ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) defer func() { @@ -1126,62 +1164,82 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d // Prepare any needed queries if len(distinctNewQueries) > 0 { - for _, sd := range distinctNewQueries { - pipeline.SendPrepare(sd.Name, sd.SQL, nil) - } + err := func() (err error) { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } - err := pipeline.Sync() - if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} - } + // Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will + // clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Put(sd) + } + } + + // If something goes wrong preparing the statements, we need to invalidate the cache entries we just added. + defer func() { + if err != nil && sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Invalidate(sd.SQL) + } + } + }() + + err = pipeline.Sync() + if err != nil { + return err + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return newErrPreprocessingBatch("prepare", sd.SQL, err) + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return fmt.Errorf("expected statement description, got %T", results) + } + + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } - for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + return err } - resultSD, ok := results.(*pgconn.StatementDescription) + _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} + return fmt.Errorf("expected sync, got %T", results) } - // Fill in the previously empty / pending statement descriptions. - sd.ParamOIDs = resultSD.ParamOIDs - sd.Fields = resultSD.Fields - } - - results, err := pipeline.GetResults() + return nil + }() if err != nil { return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } - - _, ok := results.(*pgconn.PipelineSync) - if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} - } - } - - // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. - if sdCache != nil { - for _, sd := range distinctNewQueries { - sdCache.Put(sd) - } } // Queue the queries. for _, bi := range b.QueuedQueries { err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments) if err != nil { - // we wrap the error so we the user can understand which query failed inside the batch - err = fmt.Errorf("error building query %s: %w", bi.SQL, err) + err = newErrPreprocessingBatch("build", bi.SQL, err) return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } if bi.sd.Name == "" { pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) } else { - pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + // Copy ResultFormats because SendQueryStatement stores the slice for later use, and eqb.Build reuses the + // backing array on the next iteration. + resultFormats := make([]int16, len(c.eqb.ResultFormats)) + copy(resultFormats, c.eqb.ResultFormats) + pipeline.SendQueryStatement(bi.sd, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) } } @@ -1219,7 +1277,7 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { return sanitize.SanitizeSQL(sql, valueArgs...) } -// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be +// LoadType inspects the database for typeName and produces a [pgtype.Type] suitable for registration. typeName must be // the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular, // typeName must be one of the following: // - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered. diff --git a/vendor/github.com/jackc/pgx/v5/copy_from.go b/vendor/github.com/jackc/pgx/v5/copy_from.go index abcd2239..038c568c 100644 --- a/vendor/github.com/jackc/pgx/v5/copy_from.go +++ b/vendor/github.com/jackc/pgx/v5/copy_from.go @@ -10,8 +10,8 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) -// CopyFromRows returns a CopyFromSource interface over the provided rows slice -// making it usable by *Conn.CopyFrom. +// CopyFromRows returns a [CopyFromSource] interface over the provided rows slice +// making it usable by [Conn.CopyFrom]. func CopyFromRows(rows [][]any) CopyFromSource { return ©FromRows{rows: rows, idx: -1} } @@ -34,8 +34,8 @@ func (ctr *copyFromRows) Err() error { return nil } -// CopyFromSlice returns a CopyFromSource interface over a dynamic func -// making it usable by *Conn.CopyFrom. +// CopyFromSlice returns a [CopyFromSource] interface over a dynamic func +// making it usable by [Conn.CopyFrom]. func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource { return ©FromSlice{next: next, idx: -1, len: length} } @@ -64,7 +64,7 @@ func (cts *copyFromSlice) Err() error { return cts.err } -// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values. +// CopyFromFunc returns a [CopyFromSource] interface that relies on nxtf for values. // nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil, // or it returns an error. If nxtf returns an error, the copy is aborted. func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource { @@ -91,7 +91,7 @@ func (g *copyFromFunc) Err() error { return g.err } -// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. +// CopyFromSource is the interface used by [Conn.CopyFrom] as the source for copy data. type CopyFromSource interface { // Next returns true if there is another row and makes the next row data // available to Values(). When there are no more rows available or an error @@ -260,8 +260,8 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b // CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered // for the type of each column. Almost all types implemented by pgx support the binary format. // -// Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with -// Conn.LoadType and pgtype.Map.RegisterType. +// Even though enum types appear to be strings they still must be registered to use with [Conn.CopyFrom]. This can be done with +// [Conn.LoadType] and [pgtype.Map.RegisterType]. func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { ct := ©From{ conn: c, diff --git a/vendor/github.com/jackc/pgx/v5/derived_types.go b/vendor/github.com/jackc/pgx/v5/derived_types.go index 22ab069c..89b9a77c 100644 --- a/vendor/github.com/jackc/pgx/v5/derived_types.go +++ b/vendor/github.com/jackc/pgx/v5/derived_types.go @@ -24,7 +24,7 @@ func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string { // This should not occur; this will not return any types typeNamesClause = "= ''" } else { - typeNamesClause = "= ANY($1)" + typeNamesClause = "= ANY($1::text[])" } parts := make([]string, 0, 10) @@ -161,7 +161,7 @@ type derivedTypeInfo struct { // The result of this call can be passed into RegisterTypes to complete the process. func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) { m := c.TypeMap() - if typeNames == nil || len(typeNames) == 0 { + if len(typeNames) == 0 { return nil, fmt.Errorf("No type names were supplied.") } @@ -169,13 +169,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ // the SQL not support recent structures such as multirange serverVersion, _ := serverVersion(c) sql := buildLoadDerivedTypesSQL(serverVersion, typeNames) - var rows Rows - var err error - if typeNames == nil { - rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol) - } else { - rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) - } + rows, err := c.Query(ctx, sql, QueryResultFormats{TextFormatCode}, typeNames) if err != nil { return nil, fmt.Errorf("While generating load types query: %w", err) } @@ -232,15 +226,15 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ default: return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName) } - if type_ != nil { - m.RegisterType(type_) - if ti.NspName != "" { - nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec} - m.RegisterType(nspType) - result = append(result, nspType) - } - result = append(result, type_) + + // the type_ is impossible to be null + m.RegisterType(type_) + if ti.NspName != "" { + nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec} + m.RegisterType(nspType) + result = append(result, nspType) } + result = append(result, type_) } return result, nil } diff --git a/vendor/github.com/jackc/pgx/v5/doc.go b/vendor/github.com/jackc/pgx/v5/doc.go index 0e91d64e..225b4647 100644 --- a/vendor/github.com/jackc/pgx/v5/doc.go +++ b/vendor/github.com/jackc/pgx/v5/doc.go @@ -1,8 +1,8 @@ // Package pgx is a PostgreSQL database driver. /* -pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar -to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use -github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for +pgx provides a native PostgreSQL driver and can act as a [database/sql/driver]. The native PostgreSQL interface is similar +to the [database/sql] interface while providing better speed and access to PostgreSQL specific features. Use +[github.com/jackc/pgx/v5/stdlib] to use pgx as a database/sql compatible driver. See that package's documentation for details. Establishing a Connection @@ -19,15 +19,15 @@ string. Connection Pool [*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package -github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool. +[github.com/jackc/pgx/v5/pgxpool] for a concurrency safe connection pool. Query Interface -pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and -ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(), -rows.Scan, and rows.Err(). +pgx implements [Conn.Query] in the familiar database/sql style. However, pgx provides generic functions such as [CollectRows] and +[ForEachRow] that are a simpler and safer way of processing rows than manually calling defer [Rows.Close], [Rows.Next], +[Rows.Scan], and [Rows.Err]. -CollectRows can be used collect all returned rows into a slice. +[CollectRows] can be used collect all returned rows into a slice. rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5) numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) @@ -36,7 +36,7 @@ CollectRows can be used collect all returned rows into a slice. } // numbers => [1 2 3 4 5] -ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows +[ForEachRow] can be used to execute a callback function for every row. This is often easier than iterating over rows directly. var sum, n int32 @@ -49,7 +49,7 @@ directly. return err } -pgx also implements QueryRow in the same style as database/sql. +pgx also implements [Conn.QueryRow] in the same style as database/sql. var name string var weight int64 @@ -58,7 +58,7 @@ pgx also implements QueryRow in the same style as database/sql. return err } -Use Exec to execute a query that does not return a result set. +Use [Conn.Exec] to execute a query that does not return a result set. commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42) if err != nil { @@ -70,13 +70,13 @@ Use Exec to execute a query that does not return a result set. PostgreSQL Data Types -pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types +pgx uses the [pgtype] package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may require type registration. See that package's documentation for details. Transactions -Transactions are started by calling Begin. +Transactions are started by calling [Conn.Begin]. tx, err := conn.Begin(context.Background()) if err != nil { @@ -96,13 +96,13 @@ Transactions are started by calling Begin. return err } -The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions. +The [Tx] returned from [Conn.Begin] also implements the [Tx.Begin] method. This can be used to implement pseudo nested transactions. These are internally implemented with savepoints. -Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of +Use [Conn.BeginTx] to control the transaction mode. [Conn.BeginTx] also can be used to ensure a new transaction is created instead of a pseudo nested transaction. -BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the +[BeginFunc] and [BeginTxFunc] are functions that begin a transaction, execute a function, and commit or rollback the transaction depending on the return value of the function. These can be simpler and less error prone to use. err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { @@ -115,16 +115,16 @@ transaction depending on the return value of the function. These can be simpler Prepared Statements -Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx -includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are -automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig -for information on how to customize or disable the statement cache. +Prepared statements can be manually created with the [Conn.Prepare] method. However, this is rarely necessary because pgx +includes an automatic statement cache by default. Queries run through the normal [Conn.Query], [Conn.QueryRow], and [Conn.Exec] +functions are automatically prepared on first execution and the prepared statement is reused on subsequent executions. +See [ParseConfig] for information on how to customize or disable the statement cache. Copy Protocol -Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a -CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface. -Or implement CopyFromSource to avoid buffering the entire data set in memory. +Use [Conn.CopyFrom] to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. [Conn.CopyFrom] accepts a +[CopyFromSource] interface. If the data is already in a [][]any use [CopyFromRows] to wrap it in a [CopyFromSource] interface. +Or implement [CopyFromSource] to avoid buffering the entire data set in memory. rows := [][]any{ {"John", "Smith", int32(36)}, @@ -138,7 +138,7 @@ Or implement CopyFromSource to avoid buffering the entire data set in memory. pgx.CopyFromRows(rows), ) -When you already have a typed array using CopyFromSlice can be more convenient. +When you already have a typed array using [CopyFromSlice] can be more convenient. rows := []User{ {"John", "Smith", 36}, @@ -158,7 +158,7 @@ CopyFrom can be faster than an insert with as few as 5 rows. Listen and Notify -pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a +pgx can listen to the PostgreSQL notification system with the [Conn.WaitForNotification] method. It blocks until a notification is received or the context is canceled. _, err := conn.Exec(context.Background(), "listen channelname") @@ -175,20 +175,25 @@ notification is received or the context is canceled. Tracing and Logging -pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer. +pgx supports tracing by setting [ConnConfig.Tracer]. To combine several tracers you can use the [github.com/jackc/pgx/v5/multitracer.Tracer]. -In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer. +In addition, the [github.com/jackc/pgx/v5/tracelog] package provides the [github.com/jackc/pgx/v5/tracelog.TraceLog] type which lets a +traditional logger act as a [QueryTracer]. -For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3. +For debug tracing of the actual PostgreSQL wire protocol messages see [github.com/jackc/pgx/v5/pgproto3]. Lower Level PostgreSQL Functionality -github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in -implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer. +[github.com/jackc/pgx/v5/pgconn] contains a lower level PostgreSQL driver roughly at the level of libpq. [Conn] is +implemented on top of [pgconn.PgConn]. The [Conn.PgConn] method can be used to access this lower layer. PgBouncer By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be -disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode. +disabled by setting a different [QueryExecMode] in [ConnConfig.DefaultQueryExecMode]. */ package pgx + +import ( + _ "github.com/jackc/pgx/v5/pgconn" // Just for allowing godoc to resolve "pgconn" +) diff --git a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go index 89e0c227..abc41f65 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go +++ b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go @@ -4,7 +4,10 @@ // an allocation is purposely not documented. https://github.com/golang/go/issues/16323 package iobufpool -import "sync" +import ( + "math/bits" + "sync" +) const minPoolExpOf2 = 8 @@ -37,15 +40,14 @@ func Get(size int) *[]byte { } func getPoolIdx(size int) int { - size-- - size >>= minPoolExpOf2 - i := 0 - for size > 0 { - size >>= 1 - i++ + if size < 2 { + return 0 } - - return i + idx := bits.Len(uint(size-1)) - minPoolExpOf2 + if idx < 0 { + return 0 + } + return idx } // Put returns buf to the pool. @@ -59,12 +61,18 @@ func Put(buf *[]byte) { } func putPoolIdx(size int) int { - minPoolSize := 1 << minPoolExpOf2 - for i := range pools { - if size == minPoolSize<= len(pools) { + return -1 } - return -1 + return idx } diff --git a/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go b/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go index 96aedf9d..3a6700dc 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go +++ b/vendor/github.com/jackc/pgx/v5/internal/pgio/write.go @@ -1,26 +1,18 @@ package pgio -import "encoding/binary" - func AppendUint16(buf []byte, n uint16) []byte { - wp := len(buf) - buf = append(buf, 0, 0) - binary.BigEndian.PutUint16(buf[wp:], n) - return buf + return append(buf, byte(n>>8), byte(n)) } func AppendUint32(buf []byte, n uint32) []byte { - wp := len(buf) - buf = append(buf, 0, 0, 0, 0) - binary.BigEndian.PutUint32(buf[wp:], n) - return buf + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) } func AppendUint64(buf []byte, n uint64) []byte { - wp := len(buf) - buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) - binary.BigEndian.PutUint64(buf[wp:], n) - return buf + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n), + ) } func AppendInt16(buf []byte, n int16) []byte { @@ -36,5 +28,5 @@ func AppendInt64(buf []byte, n int64) []byte { } func SetInt32(buf []byte, n int32) { - binary.BigEndian.PutUint32(buf, uint32(n)) + *(*[4]byte)(buf) = [4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)} } diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh new file mode 100644 index 00000000..b4ee3fe7 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +current_branch=$(git rev-parse --abbrev-ref HEAD) +if [ "$current_branch" == "HEAD" ]; then + current_branch=$(git rev-parse HEAD) +fi + +restore_branch() { + echo "Restoring original branch/commit: $current_branch" + git checkout "$current_branch" +} +trap restore_branch EXIT + +# Check if there are uncommitted changes +if ! git diff --quiet || ! git diff --cached --quiet; then + echo "There are uncommitted changes. Please commit or stash them before running this script." + exit 1 +fi + +# Ensure that at least one commit argument is passed +if [ "$#" -lt 1 ]; then + echo "Usage: $0 ... " + exit 1 +fi + +commits=("$@") +benchmarks_dir=benchmarks + +if ! mkdir -p "${benchmarks_dir}"; then + echo "Unable to create dir for benchmarks data" + exit 1 +fi + +# Benchmark results +bench_files=() + +# Run benchmark for each listed commit +for i in "${!commits[@]}"; do + commit="${commits[i]}" + git checkout "$commit" || { + echo "Failed to checkout $commit" + exit 1 + } + + # Sanitized commit message + commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') + + # Benchmark data will go there + bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" + + if ! go test -bench=. -count=10 >"$bench_file"; then + echo "Benchmarking failed for commit $commit" + exit 1 + fi + + bench_files+=("$bench_file") +done + +# go install golang.org/x/perf/cmd/benchstat[@latest] +benchstat "${bench_files[@]}" diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go b/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go index df58c448..033a4143 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go +++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go @@ -4,8 +4,11 @@ import ( "bytes" "encoding/hex" "fmt" + "math" + "slices" "strconv" "strings" + "sync" "time" "unicode/utf8" ) @@ -24,18 +27,33 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 +const maxBufSize = 16384 // 16 Ki + +var bufPool = &pool[*bytes.Buffer]{ + new: func() *bytes.Buffer { + return &bytes.Buffer{} + }, + reset: func(b *bytes.Buffer) bool { + n := b.Len() + b.Reset() + return n < maxBufSize + }, +} + +var null = []byte("null") + func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) - buf := &bytes.Buffer{} + buf := bufPool.get() + defer bufPool.put(buf) for _, part := range q.Parts { - var str string switch part := part.(type) { case string: - str = part + buf.WriteString(part) case int: argIdx := part - 1 - + var p []byte if argIdx < 0 { return "", fmt.Errorf("first sql argument must be > 0") } @@ -43,34 +61,41 @@ func (q *Query) Sanitize(args ...any) (string, error) { if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + buf.WriteByte(' ') + arg := args[argIdx] switch arg := arg.(type) { case nil: - str = "null" + p = null case int64: - str = strconv.FormatInt(arg, 10) + p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) case float64: - str = strconv.FormatFloat(arg, 'f', -1, 64) + p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) case bool: - str = strconv.FormatBool(arg) + p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - str = QuoteBytes(arg) + p = QuoteBytes(buf.AvailableBuffer(), arg) case string: - str = QuoteString(arg) + p = QuoteString(buf.AvailableBuffer(), arg) case time.Time: - str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond). + AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true + buf.Write(p) + // Prevent SQL injection via Line Comment Creation // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - str = " " + str + " " + buf.WriteByte(' ') default: return "", fmt.Errorf("invalid Part type: %T", part) } - buf.WriteString(str) } for i, used := range argUse { @@ -82,35 +107,109 @@ func (q *Query) Sanitize(args ...any) (string, error) { } func NewQuery(sql string) (*Query, error) { - l := &sqlLexer{ - src: sql, - stateFn: rawState, + query := &Query{} + query.init(sql) + + return query, nil +} + +var sqlLexerPool = &pool[*sqlLexer]{ + new: func() *sqlLexer { + return &sqlLexer{} + }, + reset: func(sl *sqlLexer) bool { + *sl = sqlLexer{} + return true + }, +} + +func (q *Query) init(sql string) { + parts := q.Parts[:0] + if parts == nil { + // dirty, but fast heuristic to preallocate for ~90% usecases + n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 + parts = make([]Part, 0, n) } + l := sqlLexerPool.get() + defer sqlLexerPool.put(l) + + l.src = sql + l.stateFn = rawState + l.parts = parts + for l.stateFn != nil { l.stateFn = l.stateFn(l) } - query := &Query{Parts: l.parts} - - return query, nil + q.Parts = l.parts } -func QuoteString(str string) string { - return "'" + strings.ReplaceAll(str, "'", "''") + "'" +func QuoteString(dst []byte, str string) []byte { + const quote = '\'' + + // Preallocate space for the worst case scenario + dst = slices.Grow(dst, len(str)*2+2) + + // Add opening quote + dst = append(dst, quote) + + // Iterate through the string without allocating + for i := 0; i < len(str); i++ { + if str[i] == quote { + dst = append(dst, quote, quote) + } else { + dst = append(dst, str[i]) + } + } + + // Add closing quote + dst = append(dst, quote) + + return dst } -func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" +func QuoteBytes(dst, buf []byte) []byte { + if len(buf) == 0 { + return append(dst, `'\x'`...) + } + + // Calculate required length + requiredLen := 3 + hex.EncodedLen(len(buf)) + 1 + + // Ensure dst has enough capacity + if cap(dst)-len(dst) < requiredLen { + newDst := make([]byte, len(dst), len(dst)+requiredLen) + copy(newDst, dst) + dst = newDst + } + + // Record original length and extend slice + origLen := len(dst) + dst = dst[:origLen+requiredLen] + + // Add prefix + dst[origLen] = '\'' + dst[origLen+1] = '\\' + dst[origLen+2] = 'x' + + // Encode bytes directly into dst + hex.Encode(dst[origLen+3:len(dst)-1], buf) + + // Add suffix + dst[len(dst)-1] = '\'' + + return dst } type sqlLexer struct { - src string - start int - pos int - nested int // multiline comment nesting level. - stateFn stateFn - parts []Part + src string + start int + pos int + nested int // multiline comment nesting level. + dollarTag string // active tag while inside a dollar-quoted string (may be empty for $$). + stateFn stateFn + parts []Part } type stateFn func(*sqlLexer) stateFn @@ -140,6 +239,15 @@ func rawState(l *sqlLexer) stateFn { l.start = l.pos return placeholderState } + // PostgreSQL dollar-quoted string: $[tag]$...$[tag]$. The $ was + // just consumed; try to match the rest of the opening tag. + // Without this, placeholders embedded inside dollar-quoted + // literals would be incorrectly substituted. + if tagLen, ok := scanDollarQuoteTag(l.src[l.pos:]); ok { + l.dollarTag = l.src[l.pos : l.pos+tagLen] + l.pos += tagLen + 1 // advance past tag and closing '$' + return dollarQuoteState + } case '-': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune == '-' { @@ -222,8 +330,16 @@ func placeholderState(l *sqlLexer) stateFn { l.pos += width if '0' <= r && r <= '9' { - num *= 10 - num += int(r - '0') + // Clamp rather than silently wrap on pathological input like + // "$92233720368547758070" which would otherwise overflow int and + // could land on a valid args index. Any value above MaxInt32 far + // exceeds any plausible args length, so Sanitize will correctly + // return "insufficient arguments". + if num > (math.MaxInt32-9)/10 { + num = math.MaxInt32 + } else { + num = num*10 + int(r-'0') + } } else { l.parts = append(l.parts, num) l.pos -= width @@ -233,6 +349,68 @@ func placeholderState(l *sqlLexer) stateFn { } } +// dollarQuoteState consumes the body of a PostgreSQL dollar-quoted string +// ($[tag]$...$[tag]$). The opening tag (including its terminating '$') has +// already been consumed. +func dollarQuoteState(l *sqlLexer) stateFn { + closer := "$" + l.dollarTag + "$" + idx := strings.Index(l.src[l.pos:], closer) + if idx < 0 { + // Unterminated — mirror the behavior of other quoted-string states by + // consuming the remaining input into the current part and stopping. + if len(l.src)-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:]) + l.start = len(l.src) + } + l.pos = len(l.src) + return nil + } + l.pos += idx + len(closer) + l.dollarTag = "" + return rawState +} + +// scanDollarQuoteTag checks whether src begins with an optional dollar-quoted +// string tag followed by a closing '$'. src must point just past the opening +// '$'. Returns the byte length of the tag (zero for an anonymous $$) and +// whether a valid tag was found. +// +// Tag grammar matches the PostgreSQL lexer (scan.l): +// +// dolq_start: [A-Za-z_\x80-\xff] +// dolq_cont: [A-Za-z0-9_\x80-\xff] +func scanDollarQuoteTag(src string) (int, bool) { + first := true + for i := 0; i < len(src); { + r, w := utf8.DecodeRuneInString(src[i:]) + if r == '$' { + return i, true + } + if !isDollarTagRune(r, first) { + return 0, false + } + first = false + i += w + } + return 0, false +} + +func isDollarTagRune(r rune, first bool) bool { + switch { + case r == '_': + return true + case 'a' <= r && r <= 'z': + return true + case 'A' <= r && r <= 'Z': + return true + case !first && '0' <= r && r <= '9': + return true + case r >= 0x80 && r != utf8.RuneError: + return true + } + return false +} + func escapeStringState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) @@ -319,13 +497,45 @@ func multilineCommentState(l *sqlLexer) stateFn { } } +var queryPool = &pool[*Query]{ + new: func() *Query { + return &Query{} + }, + reset: func(q *Query) bool { + n := len(q.Parts) + q.Parts = q.Parts[:0] + return n < 64 // drop too large queries + }, +} + // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. func SanitizeSQL(sql string, args ...any) (string, error) { - query, err := NewQuery(sql) - if err != nil { - return "", err - } + query := queryPool.get() + query.init(sql) + defer queryPool.put(query) + return query.Sanitize(args...) } + +type pool[E any] struct { + p sync.Pool + new func() E + reset func(E) bool +} + +func (pool *pool[E]) get() E { + v, ok := pool.p.Get().(E) + if !ok { + v = pool.new() + } + + return v +} + +func (p *pool[E]) put(v E) { + if p.reset(v) { + p.p.Put(v) + } +} diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go index dec83f47..b677d29c 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go +++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go @@ -1,37 +1,54 @@ package stmtcache import ( - "container/list" - "github.com/jackc/pgx/v5/pgconn" ) +// lruNode is a typed doubly-linked list node with freelist support. +type lruNode struct { + sd *pgconn.StatementDescription + prev *lruNode + next *lruNode +} + // LRUCache implements Cache with a Least Recently Used (LRU) cache. type LRUCache struct { - cap int - m map[string]*list.Element - l *list.List + m map[string]*lruNode + head *lruNode + + tail *lruNode + len int + cap int + freelist *lruNode + invalidStmts []*pgconn.StatementDescription + invalidSet map[string]struct{} } // NewLRUCache creates a new LRUCache. cap is the maximum size of the cache. func NewLRUCache(cap int) *LRUCache { + head := &lruNode{} + tail := &lruNode{} + head.next = tail + tail.prev = head + return &LRUCache{ - cap: cap, - m: make(map[string]*list.Element), - l: list.New(), + cap: cap, + m: make(map[string]*lruNode, cap), + head: head, + tail: tail, + invalidSet: make(map[string]struct{}), } } // Get returns the statement description for sql. Returns nil if not found. func (c *LRUCache) Get(key string) *pgconn.StatementDescription { - if el, ok := c.m[key]; ok { - c.l.MoveToFront(el) - return el.Value.(*pgconn.StatementDescription) + node, ok := c.m[key] + if !ok { + return nil } - - return nil - + c.moveToFront(node) + return node.sd } // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or @@ -46,39 +63,49 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) { } // The statement may have been invalidated but not yet handled. Do not readd it to the cache. - for _, invalidSD := range c.invalidStmts { - if invalidSD.SQL == sd.SQL { - return - } + if _, invalidated := c.invalidSet[sd.SQL]; invalidated { + return } - if c.l.Len() == c.cap { + if c.len == c.cap { c.invalidateOldest() } - el := c.l.PushFront(sd) - c.m[sd.SQL] = el + node := c.allocNode() + node.sd = sd + c.insertAfter(c.head, node) + c.m[sd.SQL] = node + c.len++ } // Invalidate invalidates statement description identified by sql. Does nothing if not found. func (c *LRUCache) Invalidate(sql string) { - if el, ok := c.m[sql]; ok { - delete(c.m, sql) - c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) - c.l.Remove(el) + node, ok := c.m[sql] + if !ok { + return } + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, node.sd) + c.invalidSet[sql] = struct{}{} + c.unlink(node) + c.len-- + c.freeNode(node) } // InvalidateAll invalidates all statement descriptions. func (c *LRUCache) InvalidateAll() { - el := c.l.Front() - for el != nil { - c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) - el = el.Next() + for node := c.head.next; node != c.tail; { + next := node.next + c.invalidStmts = append(c.invalidStmts, node.sd) + c.invalidSet[node.sd.SQL] = struct{}{} + c.freeNode(node) + node = next } - c.m = make(map[string]*list.Element) - c.l = list.New() + clear(c.m) + c.head.next = c.tail + c.tail.prev = c.head + c.len = 0 } // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. @@ -90,12 +117,13 @@ func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were // never seen by the call to GetInvalidated. func (c *LRUCache) RemoveInvalidated() { - c.invalidStmts = nil + c.invalidStmts = c.invalidStmts[:0] + clear(c.invalidSet) } // Len returns the number of cached prepared statement descriptions. func (c *LRUCache) Len() int { - return c.l.Len() + return c.len } // Cap returns the maximum number of cached prepared statement descriptions. @@ -104,9 +132,56 @@ func (c *LRUCache) Cap() int { } func (c *LRUCache) invalidateOldest() { - oldest := c.l.Back() - sd := oldest.Value.(*pgconn.StatementDescription) - c.invalidStmts = append(c.invalidStmts, sd) - delete(c.m, sd.SQL) - c.l.Remove(oldest) + node := c.tail.prev + if node == c.head { + return + } + c.invalidStmts = append(c.invalidStmts, node.sd) + c.invalidSet[node.sd.SQL] = struct{}{} + delete(c.m, node.sd.SQL) + c.unlink(node) + c.len-- + c.freeNode(node) +} + +// List operations - sentinel nodes eliminate nil checks + +func (c *LRUCache) insertAfter(at, node *lruNode) { + node.prev = at + node.next = at.next + at.next.prev = node + at.next = node +} + +func (c *LRUCache) unlink(node *lruNode) { + node.prev.next = node.next + node.next.prev = node.prev +} + +func (c *LRUCache) moveToFront(node *lruNode) { + if node.prev == c.head { + return + } + c.unlink(node) + c.insertAfter(c.head, node) +} + +// Node pool operations - reuse evicted nodes to avoid allocations + +func (c *LRUCache) allocNode() *lruNode { + if c.freelist != nil { + node := c.freelist + c.freelist = node.next + node.next = nil + node.prev = nil + return node + } + return &lruNode{} +} + +func (c *LRUCache) freeNode(node *lruNode) { + node.sd = nil + node.prev = nil + node.next = c.freelist + c.freelist = node } diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go deleted file mode 100644 index 69641329..00000000 --- a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go +++ /dev/null @@ -1,77 +0,0 @@ -package stmtcache - -import ( - "math" - - "github.com/jackc/pgx/v5/pgconn" -) - -// UnlimitedCache implements Cache with no capacity limit. -type UnlimitedCache struct { - m map[string]*pgconn.StatementDescription - invalidStmts []*pgconn.StatementDescription -} - -// NewUnlimitedCache creates a new UnlimitedCache. -func NewUnlimitedCache() *UnlimitedCache { - return &UnlimitedCache{ - m: make(map[string]*pgconn.StatementDescription), - } -} - -// Get returns the statement description for sql. Returns nil if not found. -func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { - return c.m[sql] -} - -// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. -func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { - if sd.SQL == "" { - panic("cannot store statement description with empty SQL") - } - - if _, present := c.m[sd.SQL]; present { - return - } - - c.m[sd.SQL] = sd -} - -// Invalidate invalidates statement description identified by sql. Does nothing if not found. -func (c *UnlimitedCache) Invalidate(sql string) { - if sd, ok := c.m[sql]; ok { - delete(c.m, sql) - c.invalidStmts = append(c.invalidStmts, sd) - } -} - -// InvalidateAll invalidates all statement descriptions. -func (c *UnlimitedCache) InvalidateAll() { - for _, sd := range c.m { - c.invalidStmts = append(c.invalidStmts, sd) - } - - c.m = make(map[string]*pgconn.StatementDescription) -} - -// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. -func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { - return c.invalidStmts -} - -// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a -// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were -// never seen by the call to GetInvalidated. -func (c *UnlimitedCache) RemoveInvalidated() { - c.invalidStmts = nil -} - -// Len returns the number of cached prepared statement descriptions. -func (c *UnlimitedCache) Len() int { - return len(c.m) -} - -// Cap returns the maximum number of cached prepared statement descriptions. -func (c *UnlimitedCache) Cap() int { - return math.MaxInt -} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go new file mode 100644 index 00000000..991f6585 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go @@ -0,0 +1,67 @@ +package pgconn + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/jackc/pgx/v5/pgproto3" +) + +func (c *PgConn) oauthAuth(ctx context.Context) error { + if c.config.OAuthTokenProvider == nil { + return errors.New("OAuth authentication required but no token provider configured") + } + + token, err := c.config.OAuthTokenProvider(ctx) + if err != nil { + return fmt.Errorf("failed to obtain OAuth token: %w", err) + } + + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1 + initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01") + + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "OAUTHBEARER", + Data: initialResponse, + } + c.frontend.Send(saslInitialResponse) + err = c.flushWithPotentialWriteReadDeadlock() + if err != nil { + return err + } + + msg, err := c.receiveMessage() + if err != nil { + return err + } + + switch m := msg.(type) { + case *pgproto3.AuthenticationOk: + return nil + case *pgproto3.AuthenticationSASLContinue: + // Server sent error response in SASL continue + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2 + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3 + errResponse := struct { + Status string `json:"status"` + Scope string `json:"scope"` + OpenIDConfiguration string `json:"openid-configuration"` + }{} + err := json.Unmarshal(m.Data, &errResponse) + if err != nil { + return fmt.Errorf("invalid OAuth error response from server: %w", err) + } + + // Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01. + // However, since the connection will be closed anyway, we can skip this + return fmt.Errorf("OAuth authentication failed: %s", errResponse.Status) + + case *pgproto3.ErrorResponse: + return ErrorResponseToPgError(m) + + default: + return fmt.Errorf("unexpected message type during OAuth auth: %T", msg) + } +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go index 06498361..f59d39c4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go @@ -1,7 +1,8 @@ -// SCRAM-SHA-256 authentication +// SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication // // Resources: // https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc5929 // https://tools.ietf.org/html/rfc8265 // https://www.postgresql.org/docs/current/sasl-authentication.html // @@ -15,19 +16,28 @@ package pgconn import ( "bytes" "crypto/hmac" + "crypto/pbkdf2" "crypto/rand" "crypto/sha256" + "crypto/sha512" + "crypto/tls" + "crypto/x509" "encoding/base64" "errors" "fmt" + "hash" + "slices" "strconv" "github.com/jackc/pgx/v5/pgproto3" - "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" ) -const clientNonceLen = 18 +const ( + clientNonceLen = 18 + scramSHA256Name = "SCRAM-SHA-256" + scramSHA256PlusName = "SCRAM-SHA-256-PLUS" +) // Perform SCRAM authentication. func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { @@ -36,9 +46,35 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { return err } + serverHasPlus := slices.Contains(sc.serverAuthMechanisms, scramSHA256PlusName) + if c.config.ChannelBinding == "require" && !serverHasPlus { + return errors.New("channel binding required but server does not support SCRAM-SHA-256-PLUS") + } + + // If we have a TLS connection and channel binding is not disabled, attempt to + // extract the server certificate hash for tls-server-end-point channel binding. + if tlsConn, ok := c.conn.(*tls.Conn); ok && c.config.ChannelBinding != "disable" { + certHash, err := getTLSCertificateHash(tlsConn) + if err != nil && c.config.ChannelBinding == "require" { + return fmt.Errorf("channel binding required but failed to get server certificate hash: %w", err) + } + + // Upgrade to SCRAM-SHA-256-PLUS if we have binding data and the server supports it. + if certHash != nil && serverHasPlus { + sc.authMechanism = scramSHA256PlusName + } + + sc.channelBindingData = certHash + sc.hasTLS = true + } + + if c.config.ChannelBinding == "require" && sc.channelBindingData == nil { + return errors.New("channel binding required but channel binding data is not available") + } + // Send client-first-message in a SASLInitialResponse saslInitialResponse := &pgproto3.SASLInitialResponse{ - AuthMechanism: "SCRAM-SHA-256", + AuthMechanism: sc.authMechanism, Data: sc.clientFirstMessage(), } c.frontend.Send(saslInitialResponse) @@ -107,10 +143,31 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { type scramClient struct { serverAuthMechanisms []string - password []byte + password string clientNonce []byte + // authMechanism is the selected SASL mechanism for the client. Must be + // either SCRAM-SHA-256 (default) or SCRAM-SHA-256-PLUS. + // + // Upgraded to SCRAM-SHA-256-PLUS during authentication when channel binding + // is not disabled, channel binding data is available (TLS connection with + // an obtainable server certificate hash) and the server advertises + // SCRAM-SHA-256-PLUS. + authMechanism string + + // hasTLS indicates whether the connection is using TLS. This is + // needed because the GS2 header must distinguish between a client that + // supports channel binding but the server does not ("y,,") versus one + // that does not support it at all ("n,,"). + hasTLS bool + + // channelBindingData is the hash of the server's TLS certificate, computed + // per the tls-server-end-point channel binding type (RFC 5929). Used as + // the binding input in SCRAM-SHA-256-PLUS. nil when not in use. + channelBindingData []byte + clientFirstMessageBare []byte + clientGS2Header []byte serverFirstMessage []byte clientAndServerNonce []byte @@ -124,26 +181,23 @@ type scramClient struct { func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { sc := &scramClient{ serverAuthMechanisms: serverAuthMechanisms, + authMechanism: scramSHA256Name, } - // Ensure server supports SCRAM-SHA-256 - hasScramSHA256 := false - for _, mech := range sc.serverAuthMechanisms { - if mech == "SCRAM-SHA-256" { - hasScramSHA256 = true - break - } - } - if !hasScramSHA256 { + // Ensure the server supports SCRAM-SHA-256. SCRAM-SHA-256-PLUS is the + // channel binding variant and is only advertised when the server supports + // SSL. PostgreSQL always advertises the base SCRAM-SHA-256 mechanism + // regardless of SSL. + if !slices.Contains(sc.serverAuthMechanisms, scramSHA256Name) { return nil, errors.New("server does not support SCRAM-SHA-256") } // precis.OpaqueString is equivalent to SASLprep for password. var err error - sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + sc.password, err = precis.OpaqueString.String(password) if err != nil { // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. - sc.password = []byte(password) + sc.password = password } buf := make([]byte, clientNonceLen) @@ -158,8 +212,32 @@ func newScramClient(serverAuthMechanisms []string, password string) (*scramClien } func (sc *scramClient) clientFirstMessage() []byte { - sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) - return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) + // The client-first-message is the GS2 header concatenated with the bare + // message (username + client nonce). The GS2 header communicates the + // client's channel binding capability to the server: + // + // "n,," - client is not using TLS (channel binding not possible) + // "y,," - client is using TLS but channel binding is not + // in use (e.g., server did not advertise SCRAM-SHA-256-PLUS + // or the server certificate hash was not obtainable) + // "p=tls-server-end-point,," - channel binding is active via SCRAM-SHA-256-PLUS + // + // See: + // https://www.rfc-editor.org/rfc/rfc5802#section-6 + // https://www.rfc-editor.org/rfc/rfc5929#section-4 + // https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256 + + sc.clientFirstMessageBare = fmt.Appendf(nil, "n=,r=%s", sc.clientNonce) + + if sc.authMechanism == scramSHA256PlusName { + sc.clientGS2Header = []byte("p=tls-server-end-point,,") + } else if sc.hasTLS { + sc.clientGS2Header = []byte("y,,") + } else { + sc.clientGS2Header = []byte("n,,") + } + + return append(sc.clientGS2Header, sc.clientFirstMessageBare...) } func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { @@ -218,9 +296,25 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { } func (sc *scramClient) clientFinalMessage() string { - clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + // The c= attribute carries the base64-encoded channel binding input. + // + // Without channel binding this is just the GS2 header alone ("biws" for + // "n,," or "eSws" for "y,,"). + // + // With channel binding, this is the GS2 header with the channel binding data + // (certificate hash) appended. + channelBindInput := sc.clientGS2Header + if sc.authMechanism == scramSHA256PlusName { + channelBindInput = slices.Concat(sc.clientGS2Header, sc.channelBindingData) + } + channelBindingEncoded := base64.StdEncoding.EncodeToString(channelBindInput) + clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=%s,r=%s", channelBindingEncoded, sc.clientAndServerNonce) - sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + var err error + sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32) + if err != nil { + panic(err) // This should never happen. + } sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) @@ -254,7 +348,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { clientSignature := computeHMAC(storedKey[:], authMessage) clientProof := make([]byte, len(clientSignature)) - for i := 0; i < len(clientSignature); i++ { + for i := range clientSignature { clientProof[i] = clientKey[i] ^ clientSignature[i] } @@ -263,10 +357,43 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { return buf } -func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { +func computeServerSignature(saltedPassword, authMessage []byte) []byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) base64.StdEncoding.Encode(buf, serverSignature) return buf } + +// Get the server certificate hash for SCRAM channel binding type +// tls-server-end-point. +func getTLSCertificateHash(conn *tls.Conn) ([]byte, error) { + state := conn.ConnectionState() + if len(state.PeerCertificates) == 0 { + return nil, errors.New("no peer certificates for channel binding") + } + + cert := state.PeerCertificates[0] + + // Per RFC 5929 section 4.1: If the certificate's signatureAlgorithm uses + // MD5 or SHA-1, use SHA-256. Otherwise use the hash from the signature + // algorithm. + // + // See: https://www.rfc-editor.org/rfc/rfc5929.html#section-4.1 + var h hash.Hash + switch cert.SignatureAlgorithm { + case x509.MD5WithRSA, x509.SHA1WithRSA, x509.ECDSAWithSHA1: + h = sha256.New() + case x509.SHA256WithRSA, x509.SHA256WithRSAPSS, x509.ECDSAWithSHA256: + h = sha256.New() + case x509.SHA384WithRSA, x509.SHA384WithRSAPSS, x509.ECDSAWithSHA384: + h = sha512.New384() + case x509.SHA512WithRSA, x509.SHA512WithRSAPSS, x509.ECDSAWithSHA512: + h = sha512.New() + default: + return nil, fmt.Errorf("tls-server-end-point channel binding is undefined for certificate signature algorithm %v", cert.SignatureAlgorithm) + } + + h.Write(cert.Raw) + return h.Sum(nil), nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/config.go b/vendor/github.com/jackc/pgx/v5/pgconn/config.go index 6a198e67..0177d22c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/config.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/config.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "maps" "math" "net" "net/url" @@ -23,9 +24,11 @@ import ( "github.com/jackc/pgx/v5/pgproto3" ) -type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error -type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -type GetSSLPasswordFunc func(ctx context.Context) string +type ( + AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error + ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error + GetSSLPasswordFunc func(ctx context.Context) string +) // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A // manually initialized Config will cause ConnectConfig to panic. @@ -51,6 +54,15 @@ type Config struct { KerberosSpn string Fallbacks []*FallbackConfig + SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct + + // AfterNetConnect is called after the network connection, including TLS if applicable, is established but before any + // PostgreSQL protocol communication. It takes the established net.Conn and returns a net.Conn that will be used in + // its place. It can be used to wrap the net.Conn (e.g. for logging, diagnostics, or testing). Its functionality has + // some overlap with DialFunc. However, DialFunc takes place before TLS is established and cannot be used to control + // the final net.Conn used for PostgreSQL protocol communication while AfterNetConnect can. + AfterNetConnect func(ctx context.Context, config *Config, conn net.Conn) (net.Conn, error) + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. @@ -71,6 +83,23 @@ type Config struct { // that you close on FATAL errors by returning false. OnPgError PgErrorHandler + // OAuthTokenProvider is a function that returns an OAuth token for authentication. If set, it will be used for + // OAUTHBEARER SASL authentication when the server requests it. + OAuthTokenProvider func(context.Context) (string, error) + + // MinProtocolVersion is the minimum acceptable PostgreSQL protocol version. + // If the server does not support at least this version, the connection will fail. + // Valid values: "3.0", "3.2", "latest". Defaults to "3.0". + MinProtocolVersion string + + // MaxProtocolVersion is the maximum PostgreSQL protocol version to request from the server. + // Valid values: "3.0", "3.2", "latest". Defaults to "3.0" for compatibility. + MaxProtocolVersion string + + // ChannelBinding is the channel_binding parameter for SCRAM-SHA-256-PLUS authentication. + // Valid values: "disable", "prefer", "require". Defaults to "prefer". + ChannelBinding string + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -92,9 +121,7 @@ func (c *Config) Copy() *Config { } if newConf.RuntimeParams != nil { newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) - for k, v := range c.RuntimeParams { - newConf.RuntimeParams[k] = v - } + maps.Copy(newConf.RuntimeParams, c.RuntimeParams) } if newConf.Fallbacks != nil { newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) @@ -177,7 +204,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // values that will be tried in order. This can be used as part of a high availability system. See -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // // # Example URL // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb @@ -198,13 +225,17 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGSSLKEY // PGSSLROOTCERT // PGSSLPASSWORD +// PGOPTIONS // PGAPPNAME // PGCONNECT_TIMEOUT // PGTARGETSESSIONATTRS +// PGTZ +// PGMINPROTOCOLVERSION +// PGMAXPROTOCOLVERSION // -// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. +// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables. // -// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are // usually but not always the environment variable name downcased and without the "PG" prefix. // // Important Security Notes: @@ -212,7 +243,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if // not set. // -// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// See http://www.postgresql.org/docs/current/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // security each sslmode provides. // // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of @@ -318,6 +349,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "sslnegotiation": {}, "sslpassword": {}, "sslsni": {}, "krbspn": {}, @@ -325,6 +357,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "target_session_attrs": {}, "service": {}, "servicefile": {}, + "min_protocol_version": {}, + "max_protocol_version": {}, + "channel_binding": {}, } // Adding kerberos configuration @@ -386,6 +421,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con config.Port = fallbacks[0].Port config.TLSConfig = fallbacks[0].TLSConfig config.Fallbacks = fallbacks[1:] + config.SSLNegotiation = settings["sslnegotiation"] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { @@ -416,6 +452,52 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} } + minProto, err := parseProtocolVersion(settings["min_protocol_version"]) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("invalid min_protocol_version: %q", settings["min_protocol_version"]), err: err} + } + maxProto, err := parseProtocolVersion(settings["max_protocol_version"]) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("invalid max_protocol_version: %q", settings["max_protocol_version"]), err: err} + } + + config.MinProtocolVersion = settings["min_protocol_version"] + config.MaxProtocolVersion = settings["max_protocol_version"] + + if config.MinProtocolVersion == "" { + config.MinProtocolVersion = "3.0" + } + + // When max_protocol_version is not explicitly set, default based on + // min_protocol_version. This matches libpq behavior: if min > 3.0, + // default max to latest; otherwise default to 3.0 for compatibility + // with older servers/poolers that don't support NegotiateProtocolVersion. + if config.MaxProtocolVersion == "" { + if minProto > pgproto3.ProtocolVersion30 { + config.MaxProtocolVersion = "latest" + } else { + config.MaxProtocolVersion = "3.0" + } + } + + // Only error when max_protocol_version was explicitly set and conflicts + // with min_protocol_version. When max_protocol_version is not explicitly + // set, the auto-raise logic above already ensures a valid default. + if minProto > maxProto && settings["max_protocol_version"] != "" { + return nil, &ParseConfigError{ConnString: connString, msg: "min_protocol_version cannot be greater than max_protocol_version"} + } + + switch channelBinding := settings["channel_binding"]; channelBinding { + case "", "prefer": + config.ChannelBinding = "prefer" + case "disable": + config.ChannelBinding = "disable" + case "require": + config.ChannelBinding = "require" + default: + return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown channel_binding value: %v", channelBinding)} + } + return config, nil } @@ -423,9 +505,7 @@ func mergeSettings(settingSets ...map[string]string) map[string]string { settings := make(map[string]string) for _, s2 := range settingSets { - for k, v := range s2 { - settings[k] = v - } + maps.Copy(settings, s2) } return settings @@ -449,9 +529,14 @@ func parseEnvSettings() map[string]string { "PGSSLSNI": "sslsni", "PGSSLROOTCERT": "sslrootcert", "PGSSLPASSWORD": "sslpassword", + "PGSSLNEGOTIATION": "sslnegotiation", "PGTARGETSESSIONATTRS": "target_session_attrs", "PGSERVICE": "service", "PGSERVICEFILE": "servicefile", + "PGTZ": "timezone", + "PGOPTIONS": "options", + "PGMINPROTOCOLVERSION": "min_protocol_version", + "PGMAXPROTOCOLVERSION": "max_protocol_version", } for envname, realname := range nameMap { @@ -476,7 +561,9 @@ func parseURLSettings(connString string) (map[string]string, error) { } if parsedURL.User != nil { - settings["user"] = parsedURL.User.Username() + if u := parsedURL.User.Username(); u != "" { + settings["user"] = u + } if password, present := parsedURL.User.Password(); present { settings["password"] = password } @@ -485,7 +572,7 @@ func parseURLSettings(connString string) (map[string]string, error) { // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. var hosts []string var ports []string - for _, host := range strings.Split(parsedURL.Host, ",") { + for host := range strings.SplitSeq(parsedURL.Host, ",") { if host == "" { continue } @@ -603,6 +690,9 @@ func parseKeywordValueSettings(s string) (map[string]string, error) { return nil, errors.New("invalid keyword/value") } + if key == "user" && val == "" { + continue + } settings[key] = val } @@ -646,6 +736,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P sslkey := settings["sslkey"] sslpassword := settings["sslpassword"] sslsni := settings["sslsni"] + sslnegotiation := settings["sslnegotiation"] // Match libpq default behavior if sslmode == "" { @@ -657,6 +748,13 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P tlsConfig := &tls.Config{} + if sslnegotiation == "direct" { + tlsConfig.NextProtos = []string{"postgresql"} + if sslmode == "prefer" { + sslmode = "require" + } + } + if sslrootcert != "" { var caCertPool *x509.CertPool @@ -696,7 +794,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P // According to PostgreSQL documentation, if a root CA file exists, // the behavior of sslmode=require should be the same as that of verify-ca // - // See https://www.postgresql.org/docs/12/libpq-ssl.html + // See https://www.postgresql.org/docs/current/libpq-ssl.html if sslrootcert != "" { goto nextCase } @@ -765,10 +863,10 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P // Attempt decryption with pass phrase // NOTE: only supports RSA (PKCS#1) if sslpassword != "" { - decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) //nolint:ineffassign } - //if sslpassword not provided or has decryption error when use it - //try to find sslpassword with callback function + // if sslpassword not provided or has decryption error when use it + // try to find sslpassword with callback function if sslpassword == "" || decryptedError != nil { if parseConfigOptions.GetSSLPassword != nil { sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) @@ -780,7 +878,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) // Should we also provide warning for PKCS#1 needed? if decryptedError != nil { - return nil, fmt.Errorf("unable to decrypt key: %w", err) + return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError) } pemBytes := pem.Block{ @@ -861,12 +959,12 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { // ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) == "on" { + if string(result[0].Rows[0][0]) == "on" { return errors.New("read only connection") } @@ -876,12 +974,12 @@ func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgC // ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-only. func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "on" { + if string(result[0].Rows[0][0]) != "on" { return errors.New("connection is not read only") } @@ -891,12 +989,12 @@ func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgCo // ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=standby. func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "t" { + if string(result[0].Rows[0][0]) != "t" { return errors.New("server is not in hot standby mode") } @@ -906,12 +1004,12 @@ func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgCon // ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=primary. func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) == "t" { + if string(result[0].Rows[0][0]) == "t" { return errors.New("server is in standby mode") } @@ -921,14 +1019,25 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon // ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=prefer-standby. func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "t" { + if string(result[0].Rows[0][0]) != "t" { return &NotPreferredError{err: errors.New("server is not in hot standby mode")} } return nil } + +func parseProtocolVersion(s string) (uint32, error) { + switch s { + case "", "3.0": + return pgproto3.ProtocolVersion30, nil + case "3.2", "latest": + return pgproto3.ProtocolVersion32, nil + default: + return 0, fmt.Errorf("invalid protocol version: %q", s) + } +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go b/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go index db8884eb..b8892e68 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go @@ -8,12 +8,13 @@ import ( // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a // time. type ContextWatcher struct { - handler Handler - unwatchChan chan struct{} + handler Handler - lock sync.Mutex - watchInProgress bool - onCancelWasCalled bool + // Lock protects the members below. + lock sync.Mutex + // Stop is the handle for an "after func". See [context.AfterFunc]. + stop func() bool + done chan struct{} } // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. @@ -21,8 +22,7 @@ type ContextWatcher struct { // onCancel called. func NewContextWatcher(handler Handler) *ContextWatcher { cw := &ContextWatcher{ - handler: handler, - unwatchChan: make(chan struct{}), + handler: handler, } return cw @@ -33,25 +33,16 @@ func (cw *ContextWatcher) Watch(ctx context.Context) { cw.lock.Lock() defer cw.lock.Unlock() - if cw.watchInProgress { - panic("Watch already in progress") + if cw.stop != nil { + panic("watch already in progress") } - cw.onCancelWasCalled = false - if ctx.Done() != nil { - cw.watchInProgress = true - go func() { - select { - case <-ctx.Done(): - cw.handler.HandleCancel(ctx) - cw.onCancelWasCalled = true - <-cw.unwatchChan - case <-cw.unwatchChan: - } - }() - } else { - cw.watchInProgress = false + cw.done = make(chan struct{}) + cw.stop = context.AfterFunc(ctx, func() { + cw.handler.HandleCancel(ctx) + close(cw.done) + }) } } @@ -61,12 +52,13 @@ func (cw *ContextWatcher) Unwatch() { cw.lock.Lock() defer cw.lock.Unlock() - if cw.watchInProgress { - cw.unwatchChan <- struct{}{} - if cw.onCancelWasCalled { + if cw.stop != nil { + if !cw.stop() { + <-cw.done cw.handler.HandleUnwatchAfterCancel() } - cw.watchInProgress = false + cw.stop = nil + cw.done = nil } } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go index ec4a6d47..bc1e31e3 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go @@ -27,7 +27,7 @@ func Timeout(err error) bool { } // PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// http://www.postgresql.org/docs/current/static/protocol-error-fields.html for // detailed field description. type PgError struct { Severity string @@ -112,6 +112,14 @@ type ParseConfigError struct { err error } +func NewParseConfigError(conn, msg string, err error) error { + return &ParseConfigError{ + ConnString: conn, + msg: msg, + err: err, + } +} + func (e *ParseConfigError) Error() string { // Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only // return a static string. That would ensure that the error message cannot leak a password. The ConnString field would @@ -246,3 +254,20 @@ func (e *NotPreferredError) SafeToRetry() bool { func (e *NotPreferredError) Unwrap() error { return e.err } + +type PrepareError struct { + err error + + ParseComplete bool // Indicates whether the error occurred after a ParseComplete message was received. +} + +func (e *PrepareError) Error() string { + if e.ParseComplete { + return fmt.Sprintf("prepare failed after ParseComplete: %s", e.err.Error()) + } + return e.err.Error() +} + +func (e *PrepareError) Unwrap() error { + return e.err +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go index 3c1af347..efb0d61b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go @@ -28,7 +28,7 @@ func RegisterGSSProvider(newGSSArg NewGSSFunc) { // GSS provides GSSAPI authentication (e.g., Kerberos). type GSS interface { - GetInitToken(host string, service string) ([]byte, error) + GetInitToken(host, service string) ([]byte, error) GetInitTokenFromSPN(spn string) ([]byte, error) Continue(inToken []byte) (done bool, outToken []byte, err error) } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 7efb522a..d6587cef 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -1,6 +1,7 @@ package pgconn import ( + "container/list" "context" "crypto/md5" "crypto/tls" @@ -9,6 +10,7 @@ import ( "errors" "fmt" "io" + "maps" "math" "net" "strconv" @@ -21,6 +23,7 @@ import ( "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" ) const ( @@ -74,7 +77,7 @@ type NotificationHandler func(*PgConn, *Notification) type PgConn struct { conn net.Conn pid uint32 // backend pid - secretKey uint32 // key to use to send a cancel query message to the server + secretKey []byte // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend @@ -134,7 +137,7 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. @@ -267,12 +270,15 @@ func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []* var pgErr *PgError if errors.As(err, &pgErr) { - const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password - const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings - const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist - const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + // pgx will try next host even if libpq does not in certain cases (see #2246) + // consider change for the next major version + + const ERRCODE_INVALID_PASSWORD = "28P01" + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + + // auth failed due to invalid password, db does not exist or user has no permission if pgErr.Code == ERRCODE_INVALID_PASSWORD || - pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil || pgErr.Code == ERRCODE_INVALID_CATALOG_NAME || pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { return nil, allErrors @@ -313,6 +319,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo return e } + maxProtocolVersion, err := parseProtocolVersion(config.MaxProtocolVersion) + if err != nil { + return nil, newPerDialConnectError("invalid max_protocol_version", err) + } + minProtocolVersion, err := parseProtocolVersion(config.MinProtocolVersion) + if err != nil { + return nil, newPerDialConnectError("invalid min_protocol_version", err) + } + pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address) if err != nil { return nil, newPerDialConnectError("dial error", err) @@ -321,7 +336,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo if connectConfig.tlsConfig != nil { pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) pgConn.contextWatcher.Watch(ctx) - tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig) + var ( + tlsConn net.Conn + err error + ) + if config.SSLNegotiation == "direct" { + tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig) + } else { + tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig) + } pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { pgConn.conn.Close() @@ -331,6 +354,14 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo pgConn.conn = tlsConn } + if config.AfterNetConnect != nil { + pgConn.conn, err = config.AfterNetConnect(ctx, config, pgConn.conn) + if err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("AfterNetConnect failed", err) + } + } + pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() @@ -349,14 +380,12 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ - ProtocolVersion: pgproto3.ProtocolVersionNumber, + ProtocolVersion: maxProtocolVersion, Parameters: make(map[string]string), } // Copy default run-time params - for k, v := range config.RuntimeParams { - startupMsg.Parameters[k] = v - } + maps.Copy(startupMsg.Parameters, config.RuntimeParams) startupMsg.Parameters["user"] = config.User if config.Database != "" { @@ -399,7 +428,20 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo return nil, newPerDialConnectError("failed to write password message", err) } case *pgproto3.AuthenticationSASL: - err = pgConn.scramAuth(msg.AuthMechanisms) + // Check if OAUTHBEARER is supported + serverSupportsOAuthBearer := false + for _, mech := range msg.AuthMechanisms { + if mech == "OAUTHBEARER" { + serverSupportsOAuthBearer = true + break + } + } + + if serverSupportsOAuthBearer && pgConn.config.OAuthTokenProvider != nil { + err = pgConn.oauthAuth(ctx) + } else { + err = pgConn.scramAuth(msg.AuthMechanisms) + } if err != nil { pgConn.conn.Close() return nil, newPerDialConnectError("failed SASL auth", err) @@ -432,6 +474,12 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo return pgConn, nil case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: // handled by ReceiveMessage + case *pgproto3.NegotiateProtocolVersion: + serverVersion := pgproto3.ProtocolVersion30&0xFFFF0000 | uint32(msg.NewestMinorProtocol) + if serverVersion < minProtocolVersion { + pgConn.conn.Close() + return nil, newPerDialConnectError("server protocol version too low", nil) + } case *pgproto3.ErrorResponse: pgConn.conn.Close() return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg)) @@ -490,7 +538,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} { } // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the -// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// connection is not busy. e.g. It is an error to call [PgConn.ReceiveMessage] while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger // the OnNotification callback. // @@ -564,6 +612,10 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { // receiveMessage receives a message without setting up context cancellation func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + if pgConn.status == connStatusClosed { + return nil, &connLockError{status: "conn closed"} + } + msg, err := pgConn.peekMessage() if err != nil { return nil, err @@ -621,7 +673,7 @@ func (pgConn *PgConn) TxStatus() byte { } // SecretKey returns the backend secret key used to send a cancel query message to the server. -func (pgConn *PgConn) SecretKey() uint32 { +func (pgConn *PgConn) SecretKey() []byte { return pgConn.secretKey } @@ -758,25 +810,20 @@ func NewCommandTag(s string) CommandTag { // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { - // Find last non-digit - idx := -1 + // Parse the number from the end in a single pass. + var n int64 + var mult int64 = 1 + for i := len(ct.s) - 1; i >= 0; i-- { - if ct.s[i] >= '0' && ct.s[i] <= '9' { - idx = i + c := ct.s[i] + if c >= '0' && c <= '9' { + n += int64(c-'0') * mult + mult *= 10 } else { break } } - if idx == -1 { - return 0 - } - - var n int64 - for _, b := range ct.s[idx:] { - n = n*10 + int64(b-'0') - } - return n } @@ -814,13 +861,15 @@ type FieldDescription struct { Format int16 } -func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription { - if cap(dst) >= len(rd.Fields) { - dst = dst[:len(rd.Fields):len(rd.Fields)] +func (pgConn *PgConn) getFieldDescriptionSlice(n int) []FieldDescription { + if cap(pgConn.fieldDescriptions) >= n { + return pgConn.fieldDescriptions[:n:n] } else { - dst = make([]FieldDescription, len(rd.Fields)) + return make([]FieldDescription, n) } +} +func convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) { for i := range rd.Fields { dst[i].Name = string(rd.Fields[i].Name) dst[i].TableOID = rd.Fields[i].TableOID @@ -830,8 +879,6 @@ func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3 dst[i].TypeModifier = rd.Fields[i].TypeModifier dst[i].Format = rd.Fields[i].Format } - - return dst } type StatementDescription struct { @@ -846,6 +893,10 @@ type StatementDescription struct { // // Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages // directly. +// +// In extremely rare cases, Prepare may fail after the Parse is successful, but before the Describe is complete. In this +// case, the returned error will be an error where errors.As with a *PrepareError succeeds and the *PrepareError has +// ParseComplete set to true. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, err @@ -873,7 +924,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd := &StatementDescription{Name: name, SQL: sql} - var parseErr error + var ParseComplete bool + var pgErr *PgError readloop: for { @@ -884,20 +936,23 @@ readloop: } switch msg := msg.(type) { + case *pgproto3.ParseComplete: + ParseComplete = true case *pgproto3.ParameterDescription: psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - psd.Fields = pgConn.convertRowDescription(nil, msg) + psd.Fields = make([]FieldDescription, len(msg.Fields)) + convertRowDescription(psd.Fields, msg) case *pgproto3.ErrorResponse: - parseErr = ErrorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } } - if parseErr != nil { - return nil, parseErr + if pgErr != nil { + return nil, &PrepareError{err: pgErr, ParseComplete: ParseComplete} } return psd, nil } @@ -979,7 +1034,8 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there -// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +// is no way to be sure a query was canceled. +// See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-CANCELING-REQUESTS func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be @@ -1016,11 +1072,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { defer contextWatcher.Unwatch() } - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) + buf := make([]byte, 12+len(pgConn.secretKey)) + binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf))) binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], pgConn.pid) - binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey) + copy(buf[12:], pgConn.secretKey) if _, err := cancelConn.Write(buf); err != nil { return fmt.Errorf("write to connection for cancellation: %w", err) @@ -1069,7 +1125,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. // -// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. +// Prefer [PgConn.ExecParams] unless executing arbitrary SQL that may contain multiple queries. func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ @@ -1127,8 +1183,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text format. // -// ResultReader must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { +// [ResultReader] must be closed before [PgConn] can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result @@ -1137,7 +1193,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(result) + pgConn.execExtendedSuffix(result, nil, nil) return result } @@ -1153,8 +1209,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text format. // -// ResultReader must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { +// [ResultReader] must be closed before [PgConn] can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result @@ -1162,7 +1218,36 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(result) + pgConn.execExtendedSuffix(result, nil, nil) + + return result +} + +// ExecStatement enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// This differs from [PgConn.ExecPrepared] in that it takes a [*StatementDescription] instead of the prepared statement name. +// Because it has the [*StatementDescription] it can avoid the Describe Portal message that [PgConn.ExecPrepared] must send to get +// the result column descriptions. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecStatement will panic if len(paramFormats) is not +// 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or binary +// format. If resultFormats is nil all results will be in text format. +// +// [ResultReader] must be closed before [PgConn] can be used again. +func (pgConn *PgConn) ExecStatement(ctx context.Context, statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + + pgConn.execExtendedSuffix(result, statementDescription, resultFormats) return result } @@ -1202,8 +1287,10 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { - pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) +func (pgConn *PgConn) execExtendedSuffix(result *ResultReader, statementDescription *StatementDescription, resultFormats []int16) { + if statementDescription == nil { + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + } pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) @@ -1217,7 +1304,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { return } - result.readUntilRowDescription() + result.readUntilRowDescription(statementDescription, resultFormats) } // CopyTo executes the copy command sql and copies the results to w. @@ -1309,10 +1396,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co copyErrChan := make(chan error, 1) signalMessageChan := pgConn.signalMessage() var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() + wg.Go(func() { buf := iobufpool.Get(65536) defer iobufpool.Put(buf) (*buf)[0] = 'd' @@ -1344,7 +1428,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co default: } } - }() + }) var pgErr error var copyErr error @@ -1361,7 +1445,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co close(pgConn.cleanupDone) return CommandTag{}, normalizeTimeoutError(ctx, err) } - msg, _ := pgConn.receiveMessage() + // peekMessage never returns err in the bufferingReceive mode - it only forwards the bufferingReceive variables. + // Therefore, the only case for receiveMessage to return err is during handling of the ErrorResponse message type + // and using pgOnError handler to determine the connection is no longer valid (and thus closing the conn). + msg, serverError := pgConn.receiveMessage() + if serverError != nil { + close(abortCopyChan) + return CommandTag{}, serverError + } switch msg := msg.(type) { case *pgproto3.ErrorResponse: @@ -1408,12 +1499,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - pipeline *Pipeline + pgConn *PgConn + ctx context.Context rr *ResultReader + // Data from when the batch was queued. + statementDescriptions []*StatementDescription + resultFormats [][]int16 + closed bool err error } @@ -1443,12 +1537,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: mrr.closed = true - if mrr.pipeline != nil { - mrr.pipeline.expectedReadyForQueryCount-- - } else { - mrr.pgConn.contextWatcher.Unwatch() - mrr.pgConn.unlock() - } + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = ErrorResponseToPgError(msg) } @@ -1459,6 +1549,39 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) // NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. func (mrr *MultiResultReader) NextResult() bool { for !mrr.closed && mrr.err == nil { + msg, _ := mrr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + if len(mrr.statementDescriptions) > 0 { + rr := ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + } + + // This result corresponds to a prepared statement description that was provided when queuing the batch. + sd := mrr.statementDescriptions[0] + mrr.statementDescriptions = mrr.statementDescriptions[1:] + + resultFormats := mrr.resultFormats[0] + mrr.resultFormats = mrr.resultFormats[1:] + + sdFields := sd.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) + + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) + } + + mrr.pgConn.resultReader = rr + mrr.rr = &mrr.pgConn.resultReader + return true + } + + mrr.err = fmt.Errorf("unexpected DataRow message without preceding RowDescription") + return false + } + msg, err := mrr.receiveMessage() if err != nil { return false @@ -1470,8 +1593,9 @@ func (mrr *MultiResultReader) NextResult() bool { pgConn: mrr.pgConn, multiResultReader: mrr, ctx: mrr.ctx, - fieldDescriptions: mrr.pgConn.convertRowDescription(mrr.pgConn.fieldDescriptions[:], msg), + fieldDescriptions: mrr.pgConn.getFieldDescriptionSlice(len(msg.Fields)), } + convertRowDescription(mrr.pgConn.resultReader.fieldDescriptions, msg) mrr.rr = &mrr.pgConn.resultReader return true @@ -1484,7 +1608,12 @@ func (mrr *MultiResultReader) NextResult() bool { mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.EmptyQueryResponse: - return false + mrr.pgConn.resultReader = ResultReader{ + commandConcluded: true, + closed: true, + } + mrr.rr = &mrr.pgConn.resultReader + return true } } @@ -1518,6 +1647,7 @@ type ResultReader struct { fieldDescriptions []FieldDescription rowValues [][]byte commandTag CommandTag + preloaded bool commandConcluded bool closed bool err error @@ -1559,6 +1689,11 @@ func (rr *ResultReader) Read() *Result { // NextRow advances the ResultReader to the next row and returns true if a row is available. func (rr *ResultReader) NextRow() bool { + if rr.preloaded { + rr.preloaded = false + return true + } + for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { @@ -1575,6 +1710,11 @@ func (rr *ResultReader) NextRow() bool { return false } +func (rr *ResultReader) preloadRowValues(values [][]byte) { + rr.rowValues = values + rr.preloaded = true +} + // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until // the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was // encountered.) @@ -1627,19 +1767,34 @@ func (rr *ResultReader) Close() (CommandTag, error) { // readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any // error will be stored in the ResultReader. -func (rr *ResultReader) readUntilRowDescription() { +func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementDescription, resultFormats []int16) { for !rr.commandConcluded { - // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. - // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are - // manually used to construct a query that does not issue a describe statement. - msg, _ := rr.pgConn.peekMessage() - if _, ok := msg.(*pgproto3.DataRow); ok { + msg, _ := rr.receiveMessage() + switch msg := msg.(type) { + case *pgproto3.RowDescription: return - } + case *pgproto3.DataRow: + rr.preloadRowValues(msg.Values) + if statementDescription != nil { + sdFields := statementDescription.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) - // Consume the message - msg, _ = rr.receiveMessage() - if _, ok := msg.(*pgproto3.RowDescription); ok { + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) + } + } + return + case *pgproto3.CommandComplete: + if statementDescription != nil { + sdFields := statementDescription.Fields + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(sdFields)) + + err := combineFieldDescriptionsAndResultFormats(rr.fieldDescriptions, sdFields, resultFormats) + if err != nil { + rr.concludeCommand(CommandTag{}, err) + } + } return } } @@ -1666,13 +1821,18 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error switch msg := msg.(type) { case *pgproto3.RowDescription: - rr.fieldDescriptions = rr.pgConn.convertRowDescription(rr.pgConn.fieldDescriptions[:], msg) + rr.fieldDescriptions = rr.pgConn.getFieldDescriptionSlice(len(msg.Fields)) + convertRowDescription(rr.fieldDescriptions, msg) case *pgproto3.CommandComplete: rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: rr.concludeCommand(CommandTag{}, nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg)) + pgErr := ErrorResponseToPgError(msg) + if rr.pipeline != nil { + rr.pipeline.state.HandleError(pgErr) + } + rr.concludeCommand(CommandTag{}, pgErr) } return msg, nil @@ -1696,12 +1856,14 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { - buf []byte - err error + buf []byte + statementDescriptions []*StatementDescription + resultFormats [][]int16 + err error } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. -func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { if batch.err != nil { return } @@ -1714,7 +1876,7 @@ func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uin } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. -func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { if batch.err != nil { return } @@ -1735,6 +1897,30 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor } } +// ExecStatement appends an ExecStatement command to the batch. See PgConn.ExecPrepared for parameter descriptions. +// +// This differs from ExecPrepared in that it takes a *StatementDescription instead of just the prepared statement name. +// Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get +// the result column descriptions. +func (batch *Batch) ExecStatement(statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) { + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.statementDescriptions = append(batch.statementDescriptions, statementDescription) + batch.resultFormats = append(batch.resultFormats, resultFormats) + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } +} + // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing // multiple queries in a single round trip than using pipeline mode. @@ -1754,8 +1940,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR } pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, + pgConn: pgConn, + ctx: ctx, + statementDescriptions: batch.statementDescriptions, + resultFormats: batch.resultFormats, } multiResult := &pgConn.multiResultReader @@ -1773,19 +1961,23 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) if batch.err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err) multiResult.closed = true - multiResult.err = batch.err - pgConn.unlock() + pgConn.asyncClose() return multiResult } - pgConn.enterPotentialWriteReadDeadlock() - defer pgConn.exitPotentialWriteReadDeadlock() - _, err := pgConn.conn.Write(batch.buf) + _, err := func(buf []byte) (int, error) { + pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() + return pgConn.conn.Write(buf) + }(batch.buf) if err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, err) multiResult.closed = true - multiResult.err = err - pgConn.unlock() + pgConn.asyncClose() return multiResult } @@ -1886,7 +2078,7 @@ func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { // // This should not be confused with the PostgreSQL protocol Sync message. func (pgConn *PgConn) SyncConn(ctx context.Context) error { - for i := 0; i < 10; i++ { + for range 10 { if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 { return nil } @@ -1914,7 +2106,7 @@ func (pgConn *PgConn) CustomData() map[string]any { type HijackedConn struct { Conn net.Conn PID uint32 // backend pid - SecretKey uint32 // key to use to send a cancel query message to the server + SecretKey []byte // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte Frontend *pgproto3.Frontend @@ -1986,9 +2178,10 @@ func Construct(hc *HijackedConn) (*PgConn, error) { // Pipeline represents a connection in pipeline mode. // -// SendPrepare, SendQueryParams, and SendQueryPrepared queue requests to the server. These requests are not written until -// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between -// synchronization points are implicitly transactional unless explicit transaction control statements have been issued. +// SendPrepare, SendQueryParams, SendQueryPrepared, and SendQueryStatement queue requests to the server. These requests +// are not written until pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. +// Requests between synchronization points are implicitly transactional unless explicit transaction control statements +// have been issued. // // The context the pipeline was started with is in effect for the entire life of the Pipeline. // @@ -1999,9 +2192,7 @@ type Pipeline struct { conn *PgConn ctx context.Context - expectedReadyForQueryCount int - pendingSync bool - + state pipelineState err error closed bool } @@ -2012,6 +2203,150 @@ type PipelineSync struct{} // CloseComplete is returned by GetResults when a CloseComplete message is received. type CloseComplete struct{} +type pipelineRequestType int + +const ( + pipelineNil pipelineRequestType = iota + pipelinePrepare + pipelineQueryParams + pipelineQueryPrepared + pipelineQueryStatement + pipelineDeallocate + pipelineSyncRequest + pipelineFlushRequest +) + +type pipelineRequestEvent struct { + RequestType pipelineRequestType + WasSentToServer bool + BeforeFlushOrSync bool +} + +type pipelineState struct { + requestEventQueue list.List + statementDescriptionsQueue list.List + resultFormatsQueue list.List + lastRequestType pipelineRequestType + pgErr *PgError + expectedReadyForQueryCount int +} + +func (s *pipelineState) Init() { + s.requestEventQueue.Init() + s.statementDescriptionsQueue.Init() + s.resultFormatsQueue.Init() + s.lastRequestType = pipelineNil +} + +func (s *pipelineState) RegisterSendingToServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.WasSentToServer { + return + } + val.WasSentToServer = true + elem.Value = val + } +} + +func (s *pipelineState) registerFlushingBufferOnServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.BeforeFlushOrSync { + return + } + val.BeforeFlushOrSync = true + elem.Value = val + } +} + +func (s *pipelineState) PushBackRequestType(req pipelineRequestType) { + if req == pipelineNil { + return + } + + if req != pipelineFlushRequest { + s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req}) + } + if req == pipelineFlushRequest || req == pipelineSyncRequest { + s.registerFlushingBufferOnServer() + } + s.lastRequestType = req + + if req == pipelineSyncRequest { + s.expectedReadyForQueryCount++ + } +} + +func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { + for { + elem := s.requestEventQueue.Front() + if elem == nil { + return pipelineNil + } + val := elem.Value.(pipelineRequestEvent) + if !(val.WasSentToServer && val.BeforeFlushOrSync) { + return pipelineNil + } + + s.requestEventQueue.Remove(elem) + if val.RequestType == pipelineSyncRequest { + s.pgErr = nil + } + if s.pgErr == nil { + return val.RequestType + } + } +} + +func (s *pipelineState) PushBackStatementData(sd *StatementDescription, resultFormats []int16) { + s.statementDescriptionsQueue.PushBack(sd) + s.resultFormatsQueue.PushBack(resultFormats) +} + +func (s *pipelineState) ExtractFrontStatementData() (*StatementDescription, []int16) { + sdElem := s.statementDescriptionsQueue.Front() + var sd *StatementDescription + if sdElem != nil { + s.statementDescriptionsQueue.Remove(sdElem) + sd = sdElem.Value.(*StatementDescription) + } + + rfElem := s.resultFormatsQueue.Front() + var resultFormats []int16 + if rfElem != nil { + s.resultFormatsQueue.Remove(rfElem) + resultFormats = rfElem.Value.([]int16) + } + + return sd, resultFormats +} + +func (s *pipelineState) HandleError(err *PgError) { + s.pgErr = err +} + +func (s *pipelineState) HandleReadyForQuery() { + s.expectedReadyForQueryCount-- +} + +func (s *pipelineState) PendingSync() bool { + var notPendingSync bool + + if elem := s.requestEventQueue.Back(); elem != nil { + val := elem.Value.(pipelineRequestEvent) + notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer + } else { + notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil) + } + + return !notPendingSync +} + +func (s *pipelineState) ExpectedReadyForQuery() int { + return s.expectedReadyForQueryCount +} + // StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent // to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection // to normal mode. While in pipeline mode, no methods that communicate with the server may be called except @@ -2020,16 +2355,23 @@ type CloseComplete struct{} // Prefer ExecBatch when only sending one group of queries at once. func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { if err := pgConn.lock(); err != nil { - return &Pipeline{ + pipeline := &Pipeline{ closed: true, err: err, } + pipeline.state.Init() + + return pipeline } + pgConn.resultReader = ResultReader{closed: true} + pgConn.pipeline = Pipeline{ conn: pgConn, ctx: ctx, } + pgConn.pipeline.state.Init() + pipeline := &pgConn.pipeline if ctx != context.Background() { @@ -2052,10 +2394,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(pipelinePrepare) } // SendDeallocate deallocates a prepared statement. @@ -2063,34 +2405,77 @@ func (p *Pipeline) SendDeallocate(name string) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(pipelineDeallocate) } -// SendQueryParams is the pipeline version of *PgConn.QueryParams. -func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { +// SendQueryParams is the pipeline version of *PgConn.ExecParams. +func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryParams) } -// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. -func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { +// SendQueryPrepared is the pipeline version of *PgConn.ExecPrepared. +func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryPrepared) +} + +// SendQueryStatement is the pipeline version of *PgConn.ExecStatement. +func (p *Pipeline) SendQueryStatement(statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) { + if p.closed { + return + } + + p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryStatement) + p.state.PushBackStatementData(statementDescription, resultFormats) +} + +// SendFlushRequest sends a request for the server to flush its output buffer. +// +// The server flushes its output buffer automatically as a result of Sync being called, +// or on any request when not in pipeline mode; this function is useful to cause the server +// to flush its output buffer in pipeline mode without establishing a synchronization point. +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendFlushRequest. +func (p *Pipeline) SendFlushRequest() { + if p.closed { + return + } + + p.conn.frontend.Send(&pgproto3.Flush{}) + p.state.PushBackRequestType(pipelineFlushRequest) +} + +// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message +// without flushing the send buffer. This serves as the delimiter of an implicit +// transaction and an error recovery point. +// +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendPipelineSync. +func (p *Pipeline) SendPipelineSync() { + if p.closed { + return + } + + p.conn.frontend.SendSync(&pgproto3.Sync{}) + p.state.PushBackRequestType(pipelineSyncRequest) } // Flush flushes the queued requests without establishing a synchronization point. @@ -2115,28 +2500,14 @@ func (p *Pipeline) Flush() error { return err } + p.state.RegisterSendingToServer() return nil } // Sync establishes a synchronization point and flushes the queued requests. func (p *Pipeline) Sync() error { - if p.closed { - if p.err != nil { - return p.err - } - return errors.New("pipeline closed") - } - - p.conn.frontend.SendSync(&pgproto3.Sync{}) - err := p.Flush() - if err != nil { - return err - } - - p.pendingSync = false - p.expectedReadyForQueryCount++ - - return nil + p.SendPipelineSync() + return p.Flush() } // GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or @@ -2150,98 +2521,315 @@ func (p *Pipeline) GetResults() (results any, err error) { return nil, errors.New("pipeline closed") } - if p.expectedReadyForQueryCount == 0 { - return nil, nil - } - return p.getResults() } func (p *Pipeline) getResults() (results any, err error) { - for { - msg, err := p.conn.receiveMessage() + if !p.conn.resultReader.closed { + _, err := p.conn.resultReader.Close() if err != nil { - p.closed = true - p.err = err - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - p.conn.resultReader = ResultReader{ - pgConn: p.conn, - pipeline: p, - ctx: p.ctx, - fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg), - } - return &p.conn.resultReader, nil - case *pgproto3.CommandComplete: - p.conn.resultReader = ResultReader{ - commandTag: p.conn.makeCommandTag(msg.CommandTag), - commandConcluded: true, - closed: true, - } - return &p.conn.resultReader, nil - case *pgproto3.ParseComplete: - peekedMsg, err := p.conn.peekMessage() - if err != nil { - p.conn.asyncClose() - return nil, normalizeTimeoutError(p.ctx, err) - } - if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { - return p.getResultsPrepare() - } - case *pgproto3.CloseComplete: - return &CloseComplete{}, nil - case *pgproto3.ReadyForQuery: - p.expectedReadyForQueryCount-- - return &PipelineSync{}, nil - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - return nil, pgErr + return nil, err } + } + currentRequestType := p.state.ExtractFrontRequestType() + switch currentRequestType { + case pipelineNil: + return nil, nil + case pipelinePrepare: + return p.getResultsPrepare() + case pipelineQueryParams: + return p.getResultsQueryParams() + case pipelineQueryPrepared: + return p.getResultsQueryPrepared() + case pipelineQueryStatement: + return p.getResultsQueryStatement() + case pipelineDeallocate: + return p.getResultsDeallocate() + case pipelineSyncRequest: + return p.getResultsSync() + case pipelineFlushRequest: + return nil, errors.New("BUG: pipelineFlushRequest should not be in request queue") + default: + return nil, errors.New("BUG: unknown pipeline request type") } } func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { + err := p.receiveParseComplete("Prepare") + if err != nil { + return nil, err + } + psd := &StatementDescription{} + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Prepare ParameterDescription", msg) + } + + msg, err = p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + psd.Fields = make([]FieldDescription, len(msg.Fields)) + convertRowDescription(psd.Fields, msg) + return psd, nil + + // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING + // clause. + case *pgproto3.NoData: + return psd, nil + + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Prepare RowDescription", msg) + } +} + +func (p *Pipeline) getResultsQueryParams() (*ResultReader, error) { + err := p.receiveParseComplete("QueryParams") + if err != nil { + return nil, err + } + + err = p.receiveBindComplete("QueryParams") + if err != nil { + return nil, err + } + + return p.receiveDescribedResultReader("QueryParams") +} + +func (p *Pipeline) getResultsQueryPrepared() (*ResultReader, error) { + err := p.receiveBindComplete("QueryPrepared") + if err != nil { + return nil, err + } + + return p.receiveDescribedResultReader("QueryPrepared") +} + +func (p *Pipeline) getResultsQueryStatement() (*ResultReader, error) { + err := p.receiveBindComplete("QueryStatement") + if err != nil { + return nil, err + } + + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + sd, resultFormats := p.state.ExtractFrontStatementData() + if sd == nil { + return nil, errors.New("BUG: missing statement description or result formats for QueryStatement") + } + sdFields := sd.Fields + fieldDescriptions := p.conn.getFieldDescriptionSlice(len(sdFields)) + err = combineFieldDescriptionsAndResultFormats(fieldDescriptions, sdFields, resultFormats) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr := ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: fieldDescriptions, + } + rr.preloadRowValues(msg.Values) + p.conn.resultReader = rr + return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + fieldDescriptions: fieldDescriptions, + } + return &p.conn.resultReader, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("QueryStatement", msg) + } +} + +func (p *Pipeline) getResultsDeallocate() (*CloseComplete, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.CloseComplete: + return &CloseComplete{}, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Deallocate", msg) + } +} + +func (p *Pipeline) getResultsSync() (*PipelineSync, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + p.state.HandleReadyForQuery() + return &PipelineSync{}, nil + case *pgproto3.ErrorResponse: + // Error message that is received while expecting a Sync message still consumes the expected Sync. Put it back. + p.state.requestEventQueue.PushFront(pipelineRequestEvent{RequestType: pipelineSyncRequest, WasSentToServer: true, BeforeFlushOrSync: true}) + + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage("Sync", msg) + } +} + +func (p *Pipeline) receiveParseComplete(errStr string) error { + msg, err := p.receiveMessage() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ParseComplete: + return nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return pgErr + default: + return p.handleUnexpectedMessage(fmt.Sprintf("%s Parse", errStr), msg) + } +} + +func (p *Pipeline) receiveBindComplete(errStr string) error { + msg, err := p.receiveMessage() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.BindComplete: + return nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return pgErr + default: + return p.handleUnexpectedMessage(fmt.Sprintf("%s Bind", errStr), msg) + } +} + +func (p *Pipeline) receiveDescribedResultReader(errStr string) (*ResultReader, error) { + msg, err := p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.getFieldDescriptionSlice(len(msg.Fields)), + } + convertRowDescription(p.conn.resultReader.fieldDescriptions, msg) + return &p.conn.resultReader, nil + case *pgproto3.NoData: + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage(fmt.Sprintf("%s RowDescription or NoData", errStr), msg) + } + + msg, err = p.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return &p.conn.resultReader, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + p.conn.resultReader.closed = true + return nil, pgErr + default: + return nil, p.handleUnexpectedMessage(fmt.Sprintf("%s CommandComplete", errStr), msg) + } +} + +func (p *Pipeline) receiveMessage() (pgproto3.BackendMessage, error) { for { msg, err := p.conn.receiveMessage() if err != nil { + p.err = err p.conn.asyncClose() return nil, normalizeTimeoutError(p.ctx, err) } switch msg := msg.(type) { - case *pgproto3.ParameterDescription: - psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) - copy(psd.ParamOIDs, msg.ParameterOIDs) - case *pgproto3.RowDescription: - psd.Fields = p.conn.convertRowDescription(nil, msg) - return psd, nil - - // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING - // clause. - case *pgproto3.NoData: - return psd, nil - - // These should never happen here. But don't take chances that could lead to a deadlock. - case *pgproto3.ErrorResponse: - pgErr := ErrorResponseToPgError(msg) - return nil, pgErr - case *pgproto3.CommandComplete: - p.conn.asyncClose() - return nil, errors.New("BUG: received CommandComplete while handling Describe") - case *pgproto3.ReadyForQuery: - p.conn.asyncClose() - return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse, *pgproto3.NotificationResponse: + // Filter these message types out in pipeline mode. The normal processing is handled by PgConn.receiveMessage. + default: + return msg, nil } } } +func (p *Pipeline) handleUnexpectedMessage(errStr string, msg pgproto3.BackendMessage) error { + p.err = fmt.Errorf("pipeline: %s: received unexpected message type %T", errStr, msg) + p.conn.asyncClose() + return p.err +} + // Close closes the pipeline and returns the connection to normal mode. func (p *Pipeline) Close() error { if p.closed { @@ -2250,7 +2838,7 @@ func (p *Pipeline) Close() error { p.closed = true - if p.pendingSync { + if p.state.PendingSync() { p.conn.asyncClose() p.err = errors.New("pipeline has unsynced requests") p.conn.contextWatcher.Unwatch() @@ -2259,8 +2847,8 @@ func (p *Pipeline) Close() error { return p.err } - for p.expectedReadyForQueryCount > 0 { - _, err := p.getResults() + for p.state.ExpectedReadyForQuery() > 0 { + results, err := p.getResults() if err != nil { p.err = err var pgErr *PgError @@ -2268,6 +2856,15 @@ func (p *Pipeline) Close() error { p.conn.asyncClose() break } + } else if results == nil { + // getResults returns (nil, nil) when the request queue is exhausted but + // ExpectedReadyForQuery is still > 0. This can happen when FATAL errors consume + // queued request slots without the server ever sending ReadyForQuery. + p.conn.asyncClose() + if p.err == nil { + p.err = errors.New("pipeline: no more results but expected ReadyForQuery") + } + break } } @@ -2344,3 +2941,32 @@ func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() { h.Conn.conn.SetDeadline(time.Time{}) } + +func combineFieldDescriptionsAndResultFormats(outputFields, inputFields []FieldDescription, resultFormats []int16) error { + switch { + case len(resultFormats) == 0: + // No format codes provided means text format for all columns. + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = pgtype.TextFormatCode + } + case len(resultFormats) == 1: + // Single format code applies to all columns. + format := resultFormats[0] + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = format + } + case len(resultFormats) == len(inputFields): + // One format code per column. + for i := range inputFields { + outputFields[i] = inputFields[i] + outputFields[i].Format = resultFormats[i] + } + default: + // This should not occur if Bind validation is correct, but handle gracefully + return fmt.Errorf("result format codes length %d does not match field count %d", len(resultFormats), len(inputFields)) + } + + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go index ac2962e9..415e1a24 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go @@ -9,8 +9,7 @@ import ( ) // AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. -type AuthenticationCleartextPassword struct { -} +type AuthenticationCleartextPassword struct{} // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationCleartextPassword) Backend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go index ec11d39f..98c0b2d6 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go @@ -9,8 +9,7 @@ import ( ) // AuthenticationOk is a message sent from the backend indicating that authentication was successful. -type AuthenticationOk struct { -} +type AuthenticationOk struct{} // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationOk) Backend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go index e66580f4..69e22821 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go @@ -33,6 +33,7 @@ func (dst *AuthenticationSASL) Decode(src []byte) error { return errors.New("bad auth type") } + dst.AuthMechanisms = dst.AuthMechanisms[:0] authMechanisms := src[4:] for len(authMechanisms) > 1 { idx := bytes.IndexByte(authMechanisms, 0) diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go index d146c338..65388ad4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go @@ -46,8 +46,8 @@ type Backend struct { } const ( - minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. - maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10_000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. ) // NewBackend creates a new Backend. @@ -123,7 +123,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { if err != nil { return nil, err } - msgSize := int(binary.BigEndian.Uint32(buf) - 4) + msgSize := int(int32(binary.BigEndian.Uint32(buf)) - 4) if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) @@ -137,7 +137,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { code := binary.BigEndian.Uint32(buf) switch code { - case ProtocolVersionNumber: + case ProtocolVersion30, ProtocolVersion32: err = b.startupMessage.Decode(buf) if err != nil { return nil, err @@ -175,7 +175,13 @@ func (b *Backend) Receive() (FrontendMessage, error) { } b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgLength := int(int32(binary.BigEndian.Uint32(header[1:]))) + if msgLength < 4 { + return nil, fmt.Errorf("invalid message length: %d", msgLength) + } + + b.bodyLen = msgLength - 4 if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen { return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen} } @@ -282,9 +288,10 @@ func (b *Backend) SetAuthType(authType uint32) error { return nil } -// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return -// an error. This is useful for protecting against malicious clients that send large messages with the intent of -// causing memory exhaustion. +// SetMaxBodyLen sets the maximum length of a message body in octets. +// If a message body exceeds this length, Receive will return an error. +// This is useful for protecting against malicious clients that send +// large messages with the intent of causing memory exhaustion. // The default value is 0. // If maxBodyLen is 0, then no maximum is enforced. func (b *Backend) SetMaxBodyLen(maxBodyLen int) { diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go b/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go index 23f5da67..c73b2da0 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/hex" "encoding/json" "github.com/jackc/pgx/v5/internal/pgio" @@ -9,7 +10,7 @@ import ( type BackendKeyData struct { ProcessID uint32 - SecretKey uint32 + SecretKey []byte } // Backend identifies this message as sendable by the PostgreSQL backend. @@ -18,12 +19,13 @@ func (*BackendKeyData) Backend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *BackendKeyData) Decode(src []byte) error { - if len(src) != 8 { + if len(src) < 8 { return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} } dst.ProcessID = binary.BigEndian.Uint32(src[:4]) - dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + dst.SecretKey = make([]byte, len(src)-4) + copy(dst.SecretKey, src[4:]) return nil } @@ -32,7 +34,7 @@ func (dst *BackendKeyData) Decode(src []byte) error { func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) - dst = pgio.AppendUint32(dst, src.SecretKey) + dst = append(dst, src.SecretKey...) return finishMessage(dst, sp) } @@ -41,10 +43,29 @@ func (src BackendKeyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProcessID uint32 - SecretKey uint32 + SecretKey string }{ Type: "BackendKeyData", ProcessID: src.ProcessID, - SecretKey: src.SecretKey, + SecretKey: hex.EncodeToString(src.SecretKey), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *BackendKeyData) UnmarshalJSON(data []byte) error { + var msg struct { + ProcessID uint32 + SecretKey string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.ProcessID = msg.ProcessID + secretKey, err := hex.DecodeString(msg.SecretKey) + if err != nil { + return err + } + dst.SecretKey = secretKey + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go b/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go index ad6ac48b..fb56e4dc 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go @@ -54,7 +54,7 @@ func (dst *Bind) Decode(src []byte) error { if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { return &invalidMessageFormatErr{messageType: "Bind"} } - for i := 0; i < parameterFormatCodeCount; i++ { + for i := range parameterFormatCodeCount { dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 } @@ -69,7 +69,7 @@ func (dst *Bind) Decode(src []byte) error { if parameterCount > 0 { dst.Parameters = make([][]byte, parameterCount) - for i := 0; i < parameterCount; i++ { + for i := range parameterCount { if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "Bind"} } @@ -82,7 +82,7 @@ func (dst *Bind) Decode(src []byte) error { continue } - if len(src[rp:]) < msgSize { + if msgSize < 0 || len(src[rp:]) < msgSize { return &invalidMessageFormatErr{messageType: "Bind"} } @@ -101,7 +101,7 @@ func (dst *Bind) Decode(src []byte) error { if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { return &invalidMessageFormatErr{messageType: "Bind"} } - for i := 0; i < resultFormatCodeCount; i++ { + for i := range resultFormatCodeCount { dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go b/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go index 6b52dd97..63ebe5c4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/hex" "encoding/json" "errors" @@ -12,35 +13,42 @@ const cancelRequestCode = 80877102 type CancelRequest struct { ProcessID uint32 - SecretKey uint32 + SecretKey []byte } // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*CancelRequest) Frontend() {} func (dst *CancelRequest) Decode(src []byte) error { - if len(src) != 12 { - return errors.New("bad cancel request size") + if len(src) < 12 { + return errors.New("cancel request too short") + } + if len(src) > 264 { + return errors.New("cancel request too long") } requestCode := binary.BigEndian.Uint32(src) - if requestCode != cancelRequestCode { return errors.New("bad cancel request code") } dst.ProcessID = binary.BigEndian.Uint32(src[4:]) - dst.SecretKey = binary.BigEndian.Uint32(src[8:]) + dst.SecretKey = make([]byte, len(src)-8) + copy(dst.SecretKey, src[8:]) return nil } // Encode encodes src into dst. dst will include the 4 byte message length. func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { - dst = pgio.AppendInt32(dst, 16) + if len(src.SecretKey) > 256 { + return nil, errors.New("secret key too long") + } + msgLen := int32(12 + len(src.SecretKey)) + dst = pgio.AppendInt32(dst, msgLen) dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendUint32(dst, src.ProcessID) - dst = pgio.AppendUint32(dst, src.SecretKey) + dst = append(dst, src.SecretKey...) return dst, nil } @@ -49,10 +57,29 @@ func (src CancelRequest) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProcessID uint32 - SecretKey uint32 + SecretKey string }{ Type: "CancelRequest", ProcessID: src.ProcessID, - SecretKey: src.SecretKey, + SecretKey: hex.EncodeToString(src.SecretKey), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CancelRequest) UnmarshalJSON(data []byte) error { + var msg struct { + ProcessID uint32 + SecretKey string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.ProcessID = msg.ProcessID + secretKey, err := hex.DecodeString(msg.SecretKey) + if err != nil { + return err + } + dst.SecretKey = secretKey + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go index 99e1afea..e2a402f9 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go @@ -35,7 +35,7 @@ func (dst *CopyBothResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go index 040814db..c3421a9b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go @@ -4,8 +4,7 @@ import ( "encoding/json" ) -type CopyDone struct { -} +type CopyDone struct{} // Backend identifies this message as sendable by the PostgreSQL backend. func (*CopyDone) Backend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go index 72a85fd0..f8a00b8b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go @@ -15,6 +15,10 @@ func (*CopyFail) Frontend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *CopyFail) Decode(src []byte) error { + if len(src) == 0 { + return &invalidMessageFormatErr{messageType: "CopyFail"} + } + idx := bytes.IndexByte(src, 0) if idx != len(src)-1 { return &invalidMessageFormatErr{messageType: "CopyFail"} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go index 06cf99ce..0633935b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go @@ -35,7 +35,7 @@ func (dst *CopyInResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go index 549e916c..006864ac 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go @@ -34,7 +34,7 @@ func (dst *CopyOutResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go b/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go index fdfb0f7f..54418d58 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go @@ -31,16 +31,13 @@ func (dst *DataRow) Decode(src []byte) error { // large reallocate. This is too avoid one row with many columns from // permanently allocating memory. if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { - newCap := 32 - if newCap < fieldCount { - newCap = fieldCount - } + newCap := max(32, fieldCount) dst.Values = make([][]byte, fieldCount, newCap) } else { dst.Values = dst.Values[:fieldCount] } - for i := 0; i < fieldCount; i++ { + for i := range fieldCount { if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "DataRow"} } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go index b41abbe1..3d66518b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go @@ -52,8 +52,10 @@ type Frontend struct { readyForQuery ReadyForQuery rowDescription RowDescription portalSuspended PortalSuspended + negotiateProtocolVersion NegotiateProtocolVersion bodyLen int + maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. msgType byte partialMsg bool authType uint32 @@ -229,7 +231,7 @@ func (f *Frontend) SendExecute(msg *Execute) { f.wbuf = newBuf if f.tracer != nil { - f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) + f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg) } } @@ -311,12 +313,15 @@ func (f *Frontend) Receive() (BackendMessage, error) { f.msgType = header[0] - msgLength := int(binary.BigEndian.Uint32(header[1:])) + msgLength := int(int32(binary.BigEndian.Uint32(header[1:]))) if msgLength < 4 { return nil, fmt.Errorf("invalid message length: %d", msgLength) } f.bodyLen = msgLength - 4 + if f.maxBodyLen > 0 && f.bodyLen > f.maxBodyLen { + return nil, &ExceededMaxBodyLenErr{f.maxBodyLen, f.bodyLen} + } f.partialMsg = true } @@ -379,6 +384,8 @@ func (f *Frontend) Receive() (BackendMessage, error) { msg = &f.copyBothResponse case 'Z': msg = &f.readyForQuery + case 'v': + msg = &f.negotiateProtocolVersion default: return nil, fmt.Errorf("unknown message type: %c", f.msgType) } @@ -452,3 +459,13 @@ func (f *Frontend) GetAuthType() uint32 { func (f *Frontend) ReadBufferLen() int { return f.cr.wp - f.cr.rp } + +// SetMaxBodyLen sets the maximum length of a message body in octets. +// If a message body exceeds this length, Receive will return an error. +// This is useful for protecting against a corrupted server that sends +// messages with incorrect length, which can cause memory exhaustion. +// The default value is 0. +// If maxBodyLen is 0, then no maximum is enforced. +func (f *Frontend) SetMaxBodyLen(maxBodyLen int) { + f.maxBodyLen = maxBodyLen +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go index 7d83579f..23bbd8b8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go @@ -23,6 +23,11 @@ func (*FunctionCall) Frontend() {} func (dst *FunctionCall) Decode(src []byte) error { *dst = FunctionCall{} rp := 0 + + if len(src) < 8 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + // Specifies the object ID of the function to call. dst.Function = binary.BigEndian.Uint32(src[rp:]) rp += 4 @@ -32,8 +37,13 @@ func (dst *FunctionCall) Decode(src []byte) error { // or it can equal the actual number of arguments. nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 + + if len(src[rp:]) < nArgumentCodes*2+2 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + argumentCodes := make([]uint16, nArgumentCodes) - for i := 0; i < nArgumentCodes; i++ { + for i := range nArgumentCodes { // The argument format codes. Each must presently be zero (text) or one (binary). ac := binary.BigEndian.Uint16(src[rp:]) if ac != 0 && ac != 1 { @@ -48,14 +58,22 @@ func (dst *FunctionCall) Decode(src []byte) error { nArguments := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 arguments := make([][]byte, nArguments) - for i := 0; i < nArguments; i++ { + for i := range nArguments { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } // The length of the argument value, in bytes (this count does not include itself). Can be zero. // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. - argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + argumentLength := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 if argumentLength == -1 { arguments[i] = nil + } else if argumentLength < 0 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} } else { + if len(src[rp:]) < argumentLength { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } // The value of the argument, in the format indicated by the associated format code. n is the above length. argumentValue := src[rp : rp+argumentLength] rp += argumentLength @@ -64,6 +82,9 @@ func (dst *FunctionCall) Decode(src []byte) error { } dst.Arguments = arguments // The format code for the function result. Must presently be zero (text) or one (binary). + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } resultFormatCode := binary.BigEndian.Uint16(src[rp:]) if resultFormatCode != 0 && resultFormatCode != 1 { return &invalidMessageFormatErr{messageType: "FunctionCall"} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go index 1f273495..6b6ed8b9 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go @@ -22,7 +22,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } rp := 0 - resultSize := int(binary.BigEndian.Uint32(src[rp:])) + resultSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 if resultSize == -1 { @@ -30,7 +30,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { return nil } - if len(src[rp:]) != resultSize { + if resultSize < 0 || len(src[rp:]) != resultSize { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go b/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go index 70cb20cd..122d1341 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go @@ -10,8 +10,7 @@ import ( const gssEncReqNumber = 80877104 -type GSSEncRequest struct { -} +type GSSEncRequest struct{} // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*GSSEncRequest) Frontend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go b/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go new file mode 100644 index 00000000..43bd7ec6 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go @@ -0,0 +1,93 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type NegotiateProtocolVersion struct { + NewestMinorProtocol uint32 + UnrecognizedOptions []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NegotiateProtocolVersion) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NegotiateProtocolVersion) Decode(src []byte) error { + if len(src) < 8 { + return &invalidMessageLenErr{messageType: "NegotiateProtocolVersion", expectedLen: 8, actualLen: len(src)} + } + + dst.NewestMinorProtocol = binary.BigEndian.Uint32(src[:4]) + optionCount := int(binary.BigEndian.Uint32(src[4:8])) + + rp := 8 + + // Use the remaining message size as an upper bound for capacity to prevent + // malicious optionCount values from causing excessive memory allocation. + capHint := optionCount + if remaining := len(src) - rp; capHint > remaining { + capHint = remaining + } + dst.UnrecognizedOptions = make([]string, 0, capHint) + for i := 0; i < optionCount; i++ { + if rp >= len(src) { + return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"} + } + end := rp + for end < len(src) && src[end] != 0 { + end++ + } + if end >= len(src) { + return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"} + } + dst.UnrecognizedOptions = append(dst.UnrecognizedOptions, string(src[rp:end])) + rp = end + 1 + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NegotiateProtocolVersion) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'v') + dst = pgio.AppendUint32(dst, src.NewestMinorProtocol) + dst = pgio.AppendUint32(dst, uint32(len(src.UnrecognizedOptions))) + for _, option := range src.UnrecognizedOptions { + dst = append(dst, option...) + dst = append(dst, 0) + } + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src NegotiateProtocolVersion) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + NewestMinorProtocol uint32 + UnrecognizedOptions []string + }{ + Type: "NegotiateProtocolVersion", + NewestMinorProtocol: src.NewestMinorProtocol, + UnrecognizedOptions: src.UnrecognizedOptions, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *NegotiateProtocolVersion) UnmarshalJSON(data []byte) error { + var msg struct { + NewestMinorProtocol uint32 + UnrecognizedOptions []string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.NewestMinorProtocol = msg.NewestMinorProtocol + dst.UnrecognizedOptions = msg.UnrecognizedOptions + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go b/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go index 1ef27b75..58eb26ef 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go @@ -33,7 +33,7 @@ func (dst *ParameterDescription) Decode(src []byte) error { *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} - for i := 0; i < parameterCount; i++ { + for i := range parameterCount { dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go b/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go index 6ba3486c..8fb8de5d 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go @@ -43,7 +43,7 @@ func (dst *Parse) Decode(src []byte) error { } parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) - for i := 0; i < parameterOIDCount; i++ { + for range parameterOIDCount { if buf.Len() < 4 { return &invalidMessageFormatErr{messageType: "Parse"} } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go b/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go index d820d327..67b78515 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go @@ -12,7 +12,7 @@ type PasswordMessage struct { // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} -// Frontend identifies this message as an authentication response. +// InitialResponse identifies this message as an authentication response. func (*PasswordMessage) InitialResponse() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/query.go b/vendor/github.com/jackc/pgx/v5/pgproto3/query.go index aebdfde8..9e16465c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/query.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/query.go @@ -15,6 +15,10 @@ func (*Query) Frontend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *Query) Decode(src []byte) error { + if len(src) == 0 { + return &invalidMessageFormatErr{messageType: "Query"} + } + i := bytes.IndexByte(src, 0) if i != len(src)-1 { return &invalidMessageFormatErr{messageType: "Query"} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go b/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go index dc2a4ddf..b46f510d 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go @@ -56,7 +56,6 @@ func (*RowDescription) Backend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *RowDescription) Decode(src []byte) error { - if len(src) < 2 { return &invalidMessageFormatErr{messageType: "RowDescription"} } @@ -65,7 +64,7 @@ func (dst *RowDescription) Decode(src []byte) error { dst.Fields = dst.Fields[0:0] - for i := 0; i < fieldCount; i++ { + for range fieldCount { var fd FieldDescription idx := bytes.IndexByte(src[rp:], 0) diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go index 9eb1b6a4..123f3cd6 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go @@ -32,6 +32,9 @@ func (dst *SASLInitialResponse) Decode(src []byte) error { dst.AuthMechanism = string(src[rp:idx]) rp = idx + 1 + if len(src[rp:]) < 4 { + return errors.New("invalid SASLInitialResponse") + } rp += 4 // The rest of the message is data so we can just skip the size dst.Data = src[rp:] diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go b/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go index b0fc2847..bdfc7c42 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go @@ -10,8 +10,7 @@ import ( const sslRequestNumber = 80877103 -type SSLRequest struct { -} +type SSLRequest struct{} // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*SSLRequest) Frontend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go b/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go index 3af4587d..eb48f72b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go @@ -10,7 +10,12 @@ import ( "github.com/jackc/pgx/v5/internal/pgio" ) -const ProtocolVersionNumber = 196608 // 3.0 +const ( + ProtocolVersion30 = 196608 // 3.0 + ProtocolVersion32 = 196610 // 3.2 + ProtocolVersionLatest = ProtocolVersion32 // Latest is 3.2 + ProtocolVersionNumber = ProtocolVersion30 // Default is still 3.0 +) type StartupMessage struct { ProtocolVersion uint32 @@ -30,8 +35,8 @@ func (dst *StartupMessage) Decode(src []byte) error { dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 - if dst.ProtocolVersion != ProtocolVersionNumber { - return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + if dst.ProtocolVersion != ProtocolVersion30 && dst.ProtocolVersion != ProtocolVersion32 { + return fmt.Errorf("Bad startup message version number. Expected %d or %d, got %d", ProtocolVersion30, ProtocolVersion32, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go b/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go index 6cc7d3e3..2f9da628 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go @@ -82,7 +82,7 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { case *ErrorResponse: t.traceErrorResponse(sender, encodedLen, msg) case *Execute: - t.TraceQueryute(sender, encodedLen, msg) + t.traceExecute(sender, encodedLen, msg) case *Flush: t.traceFlush(sender, encodedLen, msg) case *FunctionCall: @@ -260,7 +260,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes t.writeTrace(sender, encodedLen, "ErrorResponse", nil) } -func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { +func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) { t.writeTrace(sender, encodedLen, "Execute", func() { fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) }) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array.go b/vendor/github.com/jackc/pgx/v5/pgtype/array.go index 06b824ad..10b96e7b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array.go @@ -38,6 +38,10 @@ func cardinality(dimensions []ArrayDimension) int { elementCount *= int(d.Length) } + if elementCount < 0 { + return 0 + } + return elementCount } @@ -51,16 +55,20 @@ func (dst *arrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { numDims := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 + if numDims > 6 { + return 0, fmt.Errorf("array has too many dimensions: %d", numDims) + } + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 dst.ElementOID = binary.BigEndian.Uint32(src[rp:]) rp += 4 - dst.Dimensions = make([]ArrayDimension, numDims) if len(src) < 12+numDims*8 { return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } + dst.Dimensions = make([]ArrayDimension, numDims) for i := range dst.Dimensions { dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -299,7 +307,7 @@ func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { return "", false, err } case '"': - r, _, err = buf.ReadRune() + _, _, err = buf.ReadRune() if err != nil { return "", false, err } @@ -374,8 +382,8 @@ func quoteArrayElementIfNeeded(src string) string { return src } -// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves -// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. +// Array represents a PostgreSQL array for T. It implements the [ArrayGetter] and [ArraySetter] interfaces. It preserves +// PostgreSQL dimensions and custom lower bounds. Use [FlatArray] if these are not needed. type Array[T any] struct { Elements []T Dims []ArrayDimension @@ -419,8 +427,8 @@ func (a Array[T]) ScanIndexType() any { return new(T) } -// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions -// and custom lower bounds. Use Array to preserve these. +// FlatArray implements the [ArrayGetter] and [ArraySetter] interfaces for any slice of T. It ignores PostgreSQL dimensions +// and custom lower bounds. Use [Array] to preserve these. type FlatArray[T any] []T func (a FlatArray[T]) Dimensions() []ArrayDimension { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go b/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go index bf5f6989..f6b36f43 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go @@ -118,7 +118,7 @@ func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, var encodePlan EncodePlan var lastElemType reflect.Type inElemBuf := make([]byte, 0, 32) - for i := 0; i < elementCount; i++ { + for i := range elementCount { if i > 0 { buf = append(buf, ',') } @@ -131,7 +131,7 @@ func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, elem := array.Index(i) var elemBuf []byte - if elem != nil { + if isNil, _ := isNilDriverValuer(elem); !isNil { elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType @@ -189,13 +189,13 @@ func (p *encodePlanArrayCodecBinary) Encode(value any, buf []byte) (newBuf []byt var encodePlan EncodePlan var lastElemType reflect.Type - for i := 0; i < elementCount; i++ { + for i := range elementCount { sp := len(buf) buf = pgio.AppendInt32(buf, -1) elem := array.Index(i) var elemBuf []byte - if elem != nil { + if isNil, _ := isNilDriverValuer(elem); !isNil { elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType @@ -270,7 +270,7 @@ func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array Arr elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) } - for i := 0; i < elementCount; i++ { + for i := range elementCount { elem := array.ScanIndex(i) elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -388,7 +388,7 @@ func isRagged(slice reflect.Value) bool { sliceLen := slice.Len() innerLen := 0 - for i := 0; i < sliceLen; i++ { + for i := range sliceLen { if i == 0 { innerLen = slice.Index(i).Len() } else { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bits.go b/vendor/github.com/jackc/pgx/v5/pgtype/bits.go index e7a1d016..2a48e354 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bits.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bits.go @@ -23,16 +23,18 @@ type Bits struct { Valid bool } +// ScanBits implements the [BitsScanner] interface. func (b *Bits) ScanBits(v Bits) error { *b = v return nil } +// BitsValue implements the [BitsValuer] interface. func (b Bits) BitsValue() (Bits, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Bits) Scan(src any) error { if src == nil { *dst = Bits{} @@ -47,7 +49,7 @@ func (dst *Bits) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Bits) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -127,7 +129,6 @@ func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go index 71caffa7..955f01fe 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go @@ -22,16 +22,18 @@ type Bool struct { Valid bool } +// ScanBool implements the [BoolScanner] interface. func (b *Bool) ScanBool(v Bool) error { *b = v return nil } +// BoolValue implements the [BoolValuer] interface. func (b Bool) BoolValue() (Bool, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Bool) Scan(src any) error { if src == nil { *dst = Bool{} @@ -61,7 +63,7 @@ func (dst *Bool) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Bool) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -70,6 +72,7 @@ func (src Bool) Value() (driver.Value, error) { return src.Bool, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Bool) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -82,6 +85,7 @@ func (src Bool) MarshalJSON() ([]byte, error) { } } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Bool) UnmarshalJSON(b []byte) error { var v *bool err := json.Unmarshal(b, &v) @@ -200,7 +204,6 @@ func (encodePlanBoolCodecTextBool) Encode(value any, buf []byte) (newBuf []byte, } func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -328,7 +331,7 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { return s.ScanBool(Bool{Bool: v, Valid: true}) } -// https://www.postgresql.org/docs/11/datatype-boolean.html +// https://www.postgresql.org/docs/current/datatype-boolean.html func planTextToBool(src []byte) (bool, error) { s := string(bytes.ToLower(bytes.TrimSpace(src))) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/box.go b/vendor/github.com/jackc/pgx/v5/pgtype/box.go index 887d268b..d243f58e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/box.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/box.go @@ -24,16 +24,18 @@ type Box struct { Valid bool } +// ScanBox implements the [BoxScanner] interface. func (b *Box) ScanBox(v Box) error { *b = v return nil } +// BoxValue implements the [BoxValuer] interface. func (b Box) BoxValue() (Box, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Box) Scan(src any) error { if src == nil { *dst = Box{} @@ -48,7 +50,7 @@ func (dst *Box) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Box) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -127,7 +129,6 @@ func (encodePlanBoxCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go b/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go index b39d3fa1..126e0be2 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go @@ -527,6 +527,7 @@ func (w *netIPNetWrapper) ScanNetipPrefix(v netip.Prefix) error { return nil } + func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) { ip, ok := netip.AddrFromSlice(w.IP) if !ok { @@ -881,7 +882,6 @@ func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error return nil } - } func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value { @@ -892,7 +892,7 @@ func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type sliceLen := int(dimensions[0].Length) slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen) - for i := 0; i < sliceLen; i++ { + for i := range sliceLen { subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length))) slice.Index(i).Set(subSlice) } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go b/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go index a247705e..6c4f0c5e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go @@ -148,7 +148,6 @@ func (encodePlanBytesCodecTextBytesValuer) Encode(value any, buf []byte) (newBuf } func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/circle.go b/vendor/github.com/jackc/pgx/v5/pgtype/circle.go index e8f118cc..fb9b4c11 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/circle.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/circle.go @@ -25,16 +25,18 @@ type Circle struct { Valid bool } +// ScanCircle implements the [CircleScanner] interface. func (c *Circle) ScanCircle(v Circle) error { *c = v return nil } +// CircleValue implements the [CircleValuer] interface. func (c Circle) CircleValue() (Circle, error) { return c, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Circle) Scan(src any) error { if src == nil { *dst = Circle{} @@ -49,7 +51,7 @@ func (dst *Circle) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Circle) Value() (driver.Value, error) { if !src.Valid { return nil, nil diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/composite.go b/vendor/github.com/jackc/pgx/v5/pgtype/composite.go index fb372325..4667036b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/composite.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/composite.go @@ -276,7 +276,6 @@ func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byt default: return nil, fmt.Errorf("unknown format code %d", format) } - } type CompositeBinaryScanner struct { @@ -290,7 +289,7 @@ type CompositeBinaryScanner struct { err error } -// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +// NewCompositeBinaryScanner a scanner over a binary encoded composite value. func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner { rp := 0 if len(src[rp:]) < 4 { @@ -477,7 +476,7 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) { return } - if field == nil { + if isNil, _ := isNilDriverValuer(field); isNil { b.buf = pgio.AppendUint32(b.buf, oid) b.buf = pgio.AppendInt32(b.buf, -1) b.fieldCount++ @@ -534,7 +533,7 @@ func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) { return } - if field == nil { + if isNil, _ := isNilDriverValuer(field); isNil { b.buf = append(b.buf, ',') return } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go index 8a9cee9c..5cfc0ea3 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go @@ -90,19 +90,19 @@ func GetAssignToDstType(dst any) (any, bool) { func init() { kindTypes = map[reflect.Kind]reflect.Type{ - reflect.Bool: reflect.TypeOf(false), - reflect.Float32: reflect.TypeOf(float32(0)), - reflect.Float64: reflect.TypeOf(float64(0)), - reflect.Int: reflect.TypeOf(int(0)), - reflect.Int8: reflect.TypeOf(int8(0)), - reflect.Int16: reflect.TypeOf(int16(0)), - reflect.Int32: reflect.TypeOf(int32(0)), - reflect.Int64: reflect.TypeOf(int64(0)), - reflect.Uint: reflect.TypeOf(uint(0)), - reflect.Uint8: reflect.TypeOf(uint8(0)), - reflect.Uint16: reflect.TypeOf(uint16(0)), - reflect.Uint32: reflect.TypeOf(uint32(0)), - reflect.Uint64: reflect.TypeOf(uint64(0)), - reflect.String: reflect.TypeOf(""), + reflect.Bool: reflect.TypeFor[bool](), + reflect.Float32: reflect.TypeFor[float32](), + reflect.Float64: reflect.TypeFor[float64](), + reflect.Int: reflect.TypeFor[int](), + reflect.Int8: reflect.TypeFor[int8](), + reflect.Int16: reflect.TypeFor[int16](), + reflect.Int32: reflect.TypeFor[int32](), + reflect.Int64: reflect.TypeFor[int64](), + reflect.Uint: reflect.TypeFor[uint](), + reflect.Uint8: reflect.TypeFor[uint8](), + reflect.Uint16: reflect.TypeFor[uint16](), + reflect.Uint32: reflect.TypeFor[uint32](), + reflect.Uint64: reflect.TypeFor[uint64](), + reflect.String: reflect.TypeFor[string](), } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/date.go b/vendor/github.com/jackc/pgx/v5/pgtype/date.go index 784b16de..68c9585e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/date.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/date.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "encoding/json" "fmt" - "regexp" "strconv" "time" @@ -26,11 +25,13 @@ type Date struct { Valid bool } +// ScanDate implements the [DateScanner] interface. func (d *Date) ScanDate(v Date) error { *d = v return nil } +// DateValue implements the [DateValuer] interface. func (d Date) DateValue() (Date, error) { return d, nil } @@ -40,7 +41,7 @@ const ( infinityDayOffset = 2147483647 ) -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Date) Scan(src any) error { if src == nil { *dst = Date{} @@ -58,7 +59,7 @@ func (dst *Date) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Date) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -70,6 +71,7 @@ func (src Date) Value() (driver.Value, error) { return src.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Date) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -89,6 +91,7 @@ func (src Date) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Date) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -223,7 +226,6 @@ func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -268,8 +270,6 @@ func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst any) error { type scanPlanTextAnyToDateScanner struct{} -var dateRegexp = regexp.MustCompile(`^(\d{4,})-(\d\d)-(\d\d)( BC)?$`) - func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error { scanner := (dst).(DateScanner) @@ -277,41 +277,104 @@ func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error { return scanner.ScanDate(Date{}) } - sbuf := string(src) - match := dateRegexp.FindStringSubmatch(sbuf) - if match != nil { - year, err := strconv.ParseInt(match[1], 10, 32) - if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (year): %w", err) - } + // Check infinity cases first + if len(src) == 8 && string(src) == "infinity" { + return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) + } + if len(src) == 9 && string(src) == "-infinity" { + return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) + } - month, err := strconv.ParseInt(match[2], 10, 32) - if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err) - } + // Format: YYYY-MM-DD or YYYY...-MM-DD BC + // Minimum: 10 chars (2000-01-01), with BC: 13 chars + if len(src) < 10 { + return fmt.Errorf("invalid date format") + } - day, err := strconv.ParseInt(match[3], 10, 32) - if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err) - } + // Check for BC suffix + bc := false + datePart := src + if len(src) >= 13 && string(src[len(src)-3:]) == " BC" { + bc = true + datePart = src[:len(src)-3] + } - // BC matched - if len(match[4]) > 0 { - year = -year + 1 + // Find year-month separator (first dash after at least 4 digits) + yearEnd := -1 + for i := 4; i < len(datePart); i++ { + if datePart[i] == '-' { + yearEnd = i + break + } + if datePart[i] < '0' || datePart[i] > '9' { + return fmt.Errorf("invalid date format") } + } + if yearEnd == -1 || yearEnd+6 > len(datePart) { + return fmt.Errorf("invalid date format") + } - t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC) - return scanner.ScanDate(Date{Time: t, Valid: true}) + // Validate: -MM-DD structure after year + if datePart[yearEnd+3] != '-' { + return fmt.Errorf("invalid date format") } - switch sbuf { - case "infinity": - return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) - case "-infinity": - return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) - default: + // Parse year + year, err := parseDigits(datePart[:yearEnd]) + if err != nil { + return fmt.Errorf("invalid date format") + } + + // Parse month (2 digits) + month, err := parse2Digits(datePart[yearEnd+1 : yearEnd+3]) + if err != nil { + return fmt.Errorf("invalid date format") + } + + // Parse day (2 digits) + day, err := parse2Digits(datePart[yearEnd+4 : yearEnd+6]) + if err != nil { + return fmt.Errorf("invalid date format") + } + + // Ensure nothing extra after day + if yearEnd+6 != len(datePart) { return fmt.Errorf("invalid date format") } + + if bc { + year = -year + 1 + } + + t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC) + return scanner.ScanDate(Date{Time: t, Valid: true}) +} + +// parse2Digits parses exactly 2 ASCII digits. +func parse2Digits(b []byte) (int64, error) { + if len(b) != 2 { + return 0, fmt.Errorf("expected 2 digits") + } + d1, d2 := b[0], b[1] + if d1 < '0' || d1 > '9' || d2 < '0' || d2 > '9' { + return 0, fmt.Errorf("expected digits") + } + return int64(d1-'0')*10 + int64(d2-'0'), nil +} + +// parseDigits parses a sequence of ASCII digits. +func parseDigits(b []byte) (int64, error) { + if len(b) == 0 { + return 0, fmt.Errorf("empty") + } + var n int64 + for _, c := range b { + if c < '0' || c > '9' { + return 0, fmt.Errorf("non-digit") + } + n = n*10 + int64(c-'0') + } + return n, nil } func (c DateCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/doc.go b/vendor/github.com/jackc/pgx/v5/pgtype/doc.go index 7687ea8f..dbcdf692 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/doc.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/doc.go @@ -1,10 +1,10 @@ // Package pgtype converts between Go and PostgreSQL values. /* -The primary type is the Map type. It is a map of PostgreSQL types identified by OID (object ID) to a Codec. A Codec is -responsible for converting between Go and PostgreSQL values. NewMap creates a Map with all supported standard PostgreSQL -types already registered. Additional types can be registered with Map.RegisterType. +The primary type is the [Map] type. It is a map of PostgreSQL types identified by OID (object ID) to a [Codec]. A [Codec] is +responsible for converting between Go and PostgreSQL values. [NewMap] creates a [Map] with all supported standard PostgreSQL +types already registered. Additional types can be registered with [Map.RegisterType]. -Use Map.Scan and Map.Encode to decode PostgreSQL values to Go and encode Go values to PostgreSQL respectively. +Use [Map.Scan] and [Map.Encode] to decode PostgreSQL values to Go and encode Go values to PostgreSQL respectively. Base Type Mapping @@ -53,8 +53,8 @@ similar fashion to database/sql. The second is to use a pointer to a pointer. return err } -When using nullable pgtype types as parameters for queries, one has to remember -to explicitly set their Valid field to true, otherwise the parameter's value will be NULL. +When using nullable pgtype types as parameters for queries, one has to remember to explicitly set their Valid field to +true, otherwise the parameter's value will be NULL. JSON Support @@ -63,8 +63,8 @@ pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL Extending Existing PostgreSQL Type Support Generally, all Codecs will support interfaces that can be implemented to enable scanning and encoding. For example, -PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use -pgtype.Point and application can directly use its own point type with pgtype as long as it implements those interfaces. +[PointCodec] can use any Go type that implements the [PointScanner] and [PointValuer] interfaces. So rather than use +[Point] an application can directly use its own point type with pgtype as long as it implements those interfaces. See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. @@ -77,10 +77,10 @@ New PostgreSQL Type Support pgtype uses the PostgreSQL OID to determine how to encode or decode a value. pgtype supports array, composite, domain, and enum types. However, any type created in PostgreSQL with CREATE TYPE will receive a new OID. This means that the OID -of each new PostgreSQL type must be registered for pgtype to handle values of that type with the correct Codec. +of each new PostgreSQL type must be registered for pgtype to handle values of that type with the correct [Codec]. -The pgx.Conn LoadType method can return a *Type for array, composite, domain, and enum types by inspecting the database -metadata. This *Type can then be registered with Map.RegisterType. +The [github.com/jackc/pgx/v5.Conn.LoadType] method can return a [*Type] for array, composite, domain, and enum types by +inspecting the database metadata. This [*Type] can then be registered with [Map.RegisterType]. For example, the following function could be called after a connection is established: @@ -106,30 +106,30 @@ For example, the following function could be called after a connection is establ A type cannot be registered unless all types it depends on are already registered. e.g. An array type cannot be registered until its element type is registered. -ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an -ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays. +[ArrayCodec] implements support for arrays. If pgtype supports type T then it can easily support []T by registering an +[ArrayCodec] for the appropriate PostgreSQL OID. In addition, [Array] type can support multi-dimensional arrays. -CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of -the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and -CompositeIndexGetter. +[CompositeCodec] implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of +the struct are in the exact order and type of the PostgreSQL type or by implementing [CompositeIndexScanner] and +[CompositeIndexGetter]. Domain types are treated as their underlying type if the underlying type and the domain type are registered. -PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can +PostgreSQL enums can usually be treated as text. However, [EnumCodec] implements support for interning strings which can reduce memory usage. While pgtype will often still work with unregistered types it is highly recommended that all types be registered due to an improvement in performance and the elimination of certain edge cases. If an entirely new PostgreSQL type (e.g. PostGIS types) is used then the application or a library can create a new -Codec. Then the OID / Codec mapping can be registered with Map.RegisterType. There is no difference between a Codec -defined and registered by the application and a Codec built in to pgtype. See any of the Codecs in pgtype for Codec +[Codec]. Then the OID / [Codec] mapping can be registered with [Map.RegisterType]. There is no difference between a [Codec] +defined and registered by the application and a [Codec] built in to pgtype. See any of the [Codec]s in pgtype for [Codec] examples and for examples of type registration. Encoding Unknown Types pgtype works best when the OID of the PostgreSQL type is known. But in some cases such as using the simple protocol the -OID is unknown. In this case Map.RegisterDefaultPgType can be used to register an assumed OID for a particular Go type. +OID is unknown. In this case [Map.RegisterDefaultPgType] can be used to register an assumed OID for a particular Go type. Renamed Types @@ -137,18 +137,18 @@ If pgtype does not recognize a type and that type is a renamed simple type simpl as if it is the underlying type. It currently cannot automatically detect the underlying type of renamed structs (eg.g. type MyTime time.Time). -Compatibility with database/sql +Compatibility with [database/sql] -pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer +pgtype also includes support for custom types implementing the [database/sql.Scanner] and [database/sql/driver.Valuer] interfaces. Encoding Typed Nils -pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec -system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil). +pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the [Codec] +system. This means that [Codec]s and other encoding logic do not have to handle nil or *T(nil). -However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore, -driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See +However, [database/sql] compatibility requires Value to be called on T(nil) when T implements [database/sql/driver.Valuer]. Therefore, +[database/sql/driver.Valuer] values are only considered NULL when *T(nil) where [database/sql/driver.Valuer] is implemented on T not on *T. See https://github.com/golang/go/issues/8415 and https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. @@ -159,33 +159,38 @@ example_child_records_test.go for an example. Overview of Scanning Implementation -The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID -from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for -scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are -interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner and -PointValuer interfaces. +The first step is to use the OID to lookup the correct [Codec]. The [Map] will call the [Codec.PlanScan] method to get a +plan for scanning into the Go value. A [Codec] will support scanning into one or more Go types. Oftentime these Go types +are interfaces rather than explicit types. For example, [PointCodec] can use any Go type that implements the [PointScanner] +and [PointValuer] interfaces. -If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again. -For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that +If a Go value is not supported directly by a [Codec] then [Map] will try see if it is a [database/sql.Scanner]. If is then that +interface will be used to scan the value. Most [database/sql.Scanner]s require the input to be in the text format (e.g. UUIDs and +numeric). However, pgx will typically have received the value in the binary format. In this case the binary value will be +parsed, reencoded as text, and then passed to the [database/sql.Scanner]. This may incur additional overhead for query results with +a large number of affected values. + +If a Go value is not supported directly by a [Codec] then [Map] will try wrapping it with additional logic and try again. +For example, [Int8Codec] does not support scanning into a renamed type (e.g. type myInt64 int64). But [Map] will detect that myInt64 is a renamed type and create a plan that converts the value to the underlying int64 type and then passes that to -the Codec (see TryFindUnderlyingTypeScanPlan). +the [Codec] (see [TryFindUnderlyingTypeScanPlan]). -These plan wrappers are contained in Map.TryWrapScanPlanFuncs. By default these contain shared logic to handle renamed +These plan wrappers are contained in [Map.TryWrapScanPlanFuncs]. By default these contain shared logic to handle renamed types, pointers to pointers, slices, composite types, etc. Additional plan wrappers can be added to seamlessly integrate types that do not support pgx directly. For example, the before mentioned https://github.com/jackc/pgx-shopspring-decimal package detects decimal.Decimal values, wraps them in something -implementing NumericScanner and passes that to the Codec. +implementing [NumericScanner] and passes that to the [Codec]. -Map.Scan and Map.Encode are convenience methods that wrap Map.PlanScan and Map.PlanEncode. Determining how to scan or +[Map.Scan] and [Map.Encode] are convenience methods that wrap [Map.PlanScan] and [Map.PlanEncode]. Determining how to scan or encode a particular type may be a time consuming operation. Hence the planning and execution steps of a conversion are internally separated. Reducing Compiled Binary Size -pgx.QueryExecModeExec and pgx.QueryExecModeSimpleProtocol require the default PostgreSQL type to be registered for each -Go type used as a query parameter. By default pgx does this for all supported types and their array variants. If an -application does not use those query execution modes or manually registers the default PostgreSQL type for the types it -uses as query parameters it can use the build tag nopgxregisterdefaulttypes. This omits the default type registration -and reduces the compiled binary size by ~2MB. +[github.com/jackc/pgx/v5.QueryExecModeExec] and [github.com/jackc/pgx/v5.QueryExecModeSimpleProtocol] require the default +PostgreSQL type to be registered for each Go type used as a query parameter. By default pgx does this for all supported +types and their array variants. If an application does not use those query execution modes or manually registers the default +PostgreSQL type for the types it uses as query parameters it can use the build tag nopgxregisterdefaulttypes. This omits +the default type registration and reduces the compiled binary size by ~2MB. */ package pgtype diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/float4.go b/vendor/github.com/jackc/pgx/v5/pgtype/float4.go index 8646d9d2..241a25ad 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/float4.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/float4.go @@ -16,26 +16,29 @@ type Float4 struct { Valid bool } -// ScanFloat64 implements the Float64Scanner interface. +// ScanFloat64 implements the [Float64Scanner] interface. func (f *Float4) ScanFloat64(n Float8) error { *f = Float4{Float32: float32(n.Float64), Valid: n.Valid} return nil } +// Float64Value implements the [Float64Valuer] interface. func (f Float4) Float64Value() (Float8, error) { return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (f *Float4) ScanInt64(n Int8) error { *f = Float4{Float32: float32(n.Int64), Valid: n.Valid} return nil } +// Int64Value implements the [Int64Valuer] interface. func (f Float4) Int64Value() (Int8, error) { return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (f *Float4) Scan(src any) error { if src == nil { *f = Float4{} @@ -58,7 +61,7 @@ func (f *Float4) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (f Float4) Value() (driver.Value, error) { if !f.Valid { return nil, nil @@ -66,6 +69,7 @@ func (f Float4) Value() (driver.Value, error) { return float64(f.Float32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (f Float4) MarshalJSON() ([]byte, error) { if !f.Valid { return []byte("null"), nil @@ -73,6 +77,7 @@ func (f Float4) MarshalJSON() ([]byte, error) { return json.Marshal(f.Float32) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (f *Float4) UnmarshalJSON(b []byte) error { var n *float32 err := json.Unmarshal(b, &n) @@ -170,7 +175,6 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (new } func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/float8.go b/vendor/github.com/jackc/pgx/v5/pgtype/float8.go index 9c923c9a..54d6781e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/float8.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/float8.go @@ -24,26 +24,29 @@ type Float8 struct { Valid bool } -// ScanFloat64 implements the Float64Scanner interface. +// ScanFloat64 implements the [Float64Scanner] interface. func (f *Float8) ScanFloat64(n Float8) error { *f = n return nil } +// Float64Value implements the [Float64Valuer] interface. func (f Float8) Float64Value() (Float8, error) { return f, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (f *Float8) ScanInt64(n Int8) error { *f = Float8{Float64: float64(n.Int64), Valid: n.Valid} return nil } +// Int64Value implements the [Int64Valuer] interface. func (f Float8) Int64Value() (Int8, error) { return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (f *Float8) Scan(src any) error { if src == nil { *f = Float8{} @@ -66,7 +69,7 @@ func (f *Float8) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (f Float8) Value() (driver.Value, error) { if !f.Valid { return nil, nil @@ -74,6 +77,7 @@ func (f Float8) Value() (driver.Value, error) { return f.Float64, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (f Float8) MarshalJSON() ([]byte, error) { if !f.Valid { return []byte("null"), nil @@ -81,6 +85,7 @@ func (f Float8) MarshalJSON() ([]byte, error) { return json.Marshal(f.Float64) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (f *Float8) UnmarshalJSON(b []byte) error { var n *float64 err := json.Unmarshal(b, &n) @@ -208,7 +213,6 @@ func (encodePlanTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, e } func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go index 2f34f4c9..c5fa22c6 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go @@ -22,16 +22,18 @@ type HstoreValuer interface { // associated with its keys. type Hstore map[string]*string +// ScanHstore implements the [HstoreScanner] interface. func (h *Hstore) ScanHstore(v Hstore) error { *h = v return nil } +// HstoreValue implements the [HstoreValuer] interface. func (h Hstore) HstoreValue() (Hstore, error) { return h, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (h *Hstore) Scan(src any) error { if src == nil { *h = nil @@ -46,7 +48,7 @@ func (h *Hstore) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (h Hstore) Value() (driver.Value, error) { if h == nil { return nil, nil @@ -162,7 +164,6 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e } func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -197,17 +198,24 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += uint32Len + if pairCount < 0 { + return fmt.Errorf("hstore invalid pair count: %d", pairCount) + } + hstore := make(Hstore, pairCount) // one allocation for all *string, rather than one per string, just like text parsing valueStrings := make([]string, pairCount) - for i := 0; i < pairCount; i++ { + for i := range pairCount { if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += uint32Len + if keyLen < 0 { + return fmt.Errorf("hstore invalid key length: %d", keyLen) + } if len(src[rp:]) < keyLen { return fmt.Errorf("hstore incomplete %v", src) } @@ -298,7 +306,7 @@ func (p *hstoreParser) consume() (b byte, end bool) { return b, false } -func unexpectedByteErr(actualB byte, expectedB byte) error { +func unexpectedByteErr(actualB, expectedB byte) error { return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) } @@ -316,7 +324,7 @@ func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { // consumeExpected2 consumes two expected bytes or returns an error. // This was a bit faster than using a string argument (better inlining? Not sure). -func (p *hstoreParser) consumeExpected2(one byte, two byte) error { +func (p *hstoreParser) consumeExpected2(one, two byte) error { if p.pos+2 > len(p.str) { return errors.New("unexpected end of string") } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/inet.go b/vendor/github.com/jackc/pgx/v5/pgtype/inet.go index 6ca10ea0..b92edb23 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/inet.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/inet.go @@ -24,7 +24,7 @@ type NetipPrefixValuer interface { NetipPrefixValue() (netip.Prefix, error) } -// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are netip.Prefix and netip.Addr. If +// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are [netip.Prefix] and [netip.Addr]. If // IsValid() is false then they are treated as SQL NULL. type InetCodec struct{} @@ -107,7 +107,6 @@ func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/int.go b/vendor/github.com/jackc/pgx/v5/pgtype/int.go index 90a20a26..95032e5a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/int.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/int.go @@ -1,4 +1,5 @@ -// Do not edit. Generated from pgtype/int.go.erb +// Code generated from pgtype/int.go.erb. DO NOT EDIT. + package pgtype import ( @@ -25,7 +26,7 @@ type Int2 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int2) ScanInt64(n Int8) error { if !n.Valid { *dst = Int2{} @@ -43,11 +44,12 @@ func (dst *Int2) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int2) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int2) Scan(src any) error { if src == nil { *dst = Int2{} @@ -76,7 +78,7 @@ func (dst *Int2) Scan(src any) error { } if n < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) + return fmt.Errorf("%d is less than minimum value for Int2", n) } if n > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", n) @@ -86,7 +88,7 @@ func (dst *Int2) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int2) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -94,6 +96,7 @@ func (src Int2) Value() (driver.Value, error) { return int64(src.Int16), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int2) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -101,6 +104,7 @@ func (src Int2) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int2) UnmarshalJSON(b []byte) error { var n *int16 err := json.Unmarshal(b, &n) @@ -585,7 +589,7 @@ type Int4 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int4) ScanInt64(n Int8) error { if !n.Valid { *dst = Int4{} @@ -603,11 +607,12 @@ func (dst *Int4) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int4) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int4) Scan(src any) error { if src == nil { *dst = Int4{} @@ -636,7 +641,7 @@ func (dst *Int4) Scan(src any) error { } if n < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n) + return fmt.Errorf("%d is less than minimum value for Int4", n) } if n > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", n) @@ -646,7 +651,7 @@ func (dst *Int4) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int4) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -654,6 +659,7 @@ func (src Int4) Value() (driver.Value, error) { return int64(src.Int32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int4) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -661,6 +667,7 @@ func (src Int4) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int4) UnmarshalJSON(b []byte) error { var n *int32 err := json.Unmarshal(b, &n) @@ -1156,7 +1163,7 @@ type Int8 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int8) ScanInt64(n Int8) error { if !n.Valid { *dst = Int8{} @@ -1174,11 +1181,12 @@ func (dst *Int8) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int8) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int8) Scan(src any) error { if src == nil { *dst = Int8{} @@ -1217,7 +1225,7 @@ func (dst *Int8) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int8) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -1225,6 +1233,7 @@ func (src Int8) Value() (driver.Value, error) { return int64(src.Int64), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int8) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -1232,6 +1241,7 @@ func (src Int8) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int8) UnmarshalJSON(b []byte) error { var n *int64 err := json.Unmarshal(b, &n) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb b/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb index e0c8b7a3..c2d40f60 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb +++ b/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb @@ -27,7 +27,7 @@ type Int<%= pg_byte_size %> struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { if !n.Valid { *dst = Int<%= pg_byte_size %>{} @@ -45,11 +45,12 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int<%= pg_byte_size %>) Scan(src any) error { if src == nil { *dst = Int<%= pg_byte_size %>{} @@ -88,7 +89,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +97,7 @@ func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { return int64(src.Int<%= pg_bit_size %>), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -103,6 +105,7 @@ func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { var n *int<%= pg_bit_size %> err := json.Unmarshal(b, &n) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/integration_benchmark_test.go.erb b/vendor/github.com/jackc/pgx/v5/pgtype/integration_benchmark_test.go.erb index 0175700a..6f401153 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/integration_benchmark_test.go.erb +++ b/vendor/github.com/jackc/pgx/v5/pgtype/integration_benchmark_test.go.erb @@ -25,7 +25,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go rows, _ := conn.Query( ctx, `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, - []any{pgx.QueryResultFormats{<%= format_code %>}}, + pgx.QueryResultFormats{<%= format_code %>}, ) _, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) if err != nil { @@ -49,7 +49,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, <%= array_size %>) n`, - []any{pgx.QueryResultFormats{<%= format_code %>}}, + pgx.QueryResultFormats{<%= format_code %>}, ) _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/interval.go b/vendor/github.com/jackc/pgx/v5/pgtype/interval.go index 4b511629..b1bc7852 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/interval.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/interval.go @@ -11,7 +11,7 @@ import ( ) const ( - microsecondsPerSecond = 1000000 + microsecondsPerSecond = 1_000_000 microsecondsPerMinute = 60 * microsecondsPerSecond microsecondsPerHour = 60 * microsecondsPerMinute microsecondsPerDay = 24 * microsecondsPerHour @@ -33,16 +33,18 @@ type Interval struct { Valid bool } +// ScanInterval implements the [IntervalScanner] interface. func (interval *Interval) ScanInterval(v Interval) error { *interval = v return nil } +// IntervalValue implements the [IntervalValuer] interface. func (interval Interval) IntervalValue() (Interval, error) { return interval, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (interval *Interval) Scan(src any) error { if src == nil { *interval = Interval{} @@ -57,7 +59,7 @@ func (interval *Interval) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (interval Interval) Value() (driver.Value, error) { if !interval.Valid { return nil, nil @@ -157,7 +159,6 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, } func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -222,6 +223,8 @@ func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { months += int32(scalar) case "day", "days": days = int32(scalar) + default: + return fmt.Errorf("bad interval format: %q", parts[i+1]) } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/json.go b/vendor/github.com/jackc/pgx/v5/pgtype/json.go index c2aa0d3b..bf70735e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/json.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/json.go @@ -71,6 +71,27 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco } } +// JSON needs its on scan plan for pointers to handle 'null'::json(b). +// Consider making pointerPointerScanPlan more flexible in the future. +type jsonPointerScanPlan struct { + next ScanPlan +} + +func (p jsonPointerScanPlan) Scan(src []byte, dst any) error { + el := reflect.ValueOf(dst).Elem() + if src == nil || string(src) == "null" { + el.SetZero() + return nil + } + + el.Set(reflect.New(el.Type().Elem())) + if p.next != nil { + return p.next.Scan(src, el.Interface()) + } + + return nil +} + type encodePlanJSONCodecEitherFormatString struct{} func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { @@ -117,41 +138,35 @@ func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) ( return buf, nil } -func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan { + return c.planScan(m, oid, formatCode, target, 0) +} + +// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b), +// so we need to duplicate the logic here. +func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + switch target.(type) { case *string: - return scanPlanAnyToString{} - - case **string: - // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better - // solution would be. - // - // https://github.com/jackc/pgx/issues/1470 -- **string - // https://github.com/jackc/pgx/issues/1691 -- ** anything else - - if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { - if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { - if _, failed := nextPlan.(*scanPlanFail); !failed { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } - } - } - + return &scanPlanAnyToString{} case *[]byte: - return scanPlanJSONToByteSlice{} + return &scanPlanJSONToByteSlice{} case BytesScanner: - return scanPlanBinaryBytesToBytesScanner{} - - // Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence. - // - // https://github.com/jackc/pgx/issues/1418 + return &scanPlanBinaryBytesToBytesScanner{} case sql.Scanner: - return &scanPlanSQLScanner{formatCode: format} + return &scanPlanCodecSQLScanner{c: c, m: m, oid: oid, formatCode: formatCode} } - return &scanPlanJSONToJSONUnmarshal{ - unmarshal: c.Unmarshal, + rv := reflect.ValueOf(target) + if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer { + var plan jsonPointerScanPlan + plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1) + return plan + } else { + return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal} } } @@ -196,7 +211,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - elem := reflect.ValueOf(dst).Elem() + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Pointer || v.IsNil() { + return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst) + } + + elem := v.Elem() elem.Set(reflect.Zero(elem.Type())) return s.unmarshal(src, dst) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/line.go b/vendor/github.com/jackc/pgx/v5/pgtype/line.go index 4ae8003e..10efc8ce 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/line.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/line.go @@ -24,11 +24,13 @@ type Line struct { Valid bool } +// ScanLine implements the [LineScanner] interface. func (line *Line) ScanLine(v Line) error { *line = v return nil } +// LineValue implements the [LineValuer] interface. func (line Line) LineValue() (Line, error) { return line, nil } @@ -37,7 +39,7 @@ func (line *Line) Set(src any) error { return fmt.Errorf("cannot convert %v to Line", src) } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (line *Line) Scan(src any) error { if src == nil { *line = Line{} @@ -52,7 +54,7 @@ func (line *Line) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (line Line) Value() (driver.Value, error) { if !line.Valid { return nil, nil @@ -129,7 +131,6 @@ func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/lseg.go b/vendor/github.com/jackc/pgx/v5/pgtype/lseg.go index 05a86e1c..ed0d40d2 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/lseg.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/lseg.go @@ -24,16 +24,18 @@ type Lseg struct { Valid bool } +// ScanLseg implements the [LsegScanner] interface. func (lseg *Lseg) ScanLseg(v Lseg) error { *lseg = v return nil } +// LsegValue implements the [LsegValuer] interface. func (lseg Lseg) LsegValue() (Lseg, error) { return lseg, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (lseg *Lseg) Scan(src any) error { if src == nil { *lseg = Lseg{} @@ -48,7 +50,7 @@ func (lseg *Lseg) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (lseg Lseg) Value() (driver.Value, error) { if !lseg.Valid { return nil, nil @@ -127,7 +129,6 @@ func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go b/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go index e5763788..0c02575c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go @@ -98,7 +98,7 @@ func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf [] var encodePlan EncodePlan var lastElemType reflect.Type inElemBuf := make([]byte, 0, 32) - for i := 0; i < elementCount; i++ { + for i := range elementCount { if i > 0 { buf = append(buf, ',') } @@ -151,7 +151,7 @@ func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf var encodePlan EncodePlan var lastElemType reflect.Type - for i := 0; i < elementCount; i++ { + for i := range elementCount { sp := len(buf) buf = pgio.AppendInt32(buf, -1) @@ -210,6 +210,11 @@ func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, elementCount := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 + // Each element requires at least 4 bytes for its length prefix. + if elementCount > len(src)/4 { + return fmt.Errorf("multirange element count %d exceeds available data", elementCount) + } + err := multirange.SetLen(elementCount) if err != nil { return err @@ -224,7 +229,7 @@ func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) } - for i := 0; i < elementCount; i++ { + for i := range elementCount { elem := multirange.ScanIndex(i) elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -374,7 +379,6 @@ parseValueLoop: } return elements, nil - } func parseRange(buf *bytes.Buffer) (string, error) { @@ -403,8 +407,8 @@ func parseRange(buf *bytes.Buffer) (string, error) { // Multirange is a generic multirange type. // -// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to -// enforce the RangeScanner constraint. +// T should implement [RangeValuer] and *T should implement [RangeScanner]. However, there does not appear to be a way to +// enforce the [RangeScanner] constraint. type Multirange[T RangeValuer] []T func (r Multirange[T]) IsNull() bool { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go index 4dbec786..c9022abc 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go @@ -14,7 +14,7 @@ import ( ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 -const nbase = 10000 +const nbase = 10_000 const ( pgNumericNaN = 0x00000000c0000000 @@ -27,16 +27,19 @@ const ( pgNumericNegInfSign = 0xf000 ) -var big0 *big.Int = big.NewInt(0) -var big1 *big.Int = big.NewInt(1) -var big10 *big.Int = big.NewInt(10) -var big100 *big.Int = big.NewInt(100) -var big1000 *big.Int = big.NewInt(1000) +var ( + big1 *big.Int = big.NewInt(1) + big10 *big.Int = big.NewInt(10) + big100 *big.Int = big.NewInt(100) + big1000 *big.Int = big.NewInt(1000) +) -var bigNBase *big.Int = big.NewInt(nbase) -var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) -var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) -var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +var ( + bigNBase *big.Int = big.NewInt(nbase) + bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) + bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) + bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +) type NumericScanner interface { ScanNumeric(v Numeric) error @@ -54,15 +57,18 @@ type Numeric struct { Valid bool } +// ScanNumeric implements the [NumericScanner] interface. func (n *Numeric) ScanNumeric(v Numeric) error { *n = v return nil } +// NumericValue implements the [NumericValuer] interface. func (n Numeric) NumericValue() (Numeric, error) { return n, nil } +// Float64Value implements the [Float64Valuer] interface. func (n Numeric) Float64Value() (Float8, error) { if !n.Valid { return Float8{}, nil @@ -92,6 +98,7 @@ func (n Numeric) Float64Value() (Float8, error) { return Float8{Float64: f, Valid: true}, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (n *Numeric) ScanInt64(v Int8) error { if !v.Valid { *n = Numeric{} @@ -102,6 +109,7 @@ func (n *Numeric) ScanInt64(v Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Numeric) Int64Value() (Int8, error) { if !n.Valid { return Int8{}, nil @@ -120,7 +128,7 @@ func (n Numeric) Int64Value() (Int8, error) { } func (n *Numeric) ScanScientific(src string) error { - if !strings.ContainsAny("eE", src) { + if !strings.ContainsAny(src, "eE") { return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n) } @@ -157,7 +165,7 @@ func (n *Numeric) toBigInt() (*big.Int, error) { div.Exp(big10, big.NewInt(int64(-n.Exp)), nil) remainder := &big.Int{} num.DivMod(num, div, remainder) - if remainder.Cmp(big0) != 0 { + if remainder.Sign() != 0 { return nil, fmt.Errorf("cannot convert %v to integer", n) } return num, nil @@ -185,14 +193,11 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { } func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { - digits := len(src) / 2 - if digits > 4 { - digits = 4 - } + digits := min(len(src)/2, 4) rp := 0 - for i := 0; i < digits; i++ { + for i := range digits { if i > 0 { accum *= nbase } @@ -203,7 +208,7 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { return accum, rp, digits } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (n *Numeric) Scan(src any) error { if src == nil { *n = Numeric{} @@ -218,7 +223,7 @@ func (n *Numeric) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (n Numeric) Value() (driver.Value, error) { if !n.Valid { return nil, nil @@ -231,6 +236,7 @@ func (n Numeric) Value() (driver.Value, error) { return string(buf), err } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (n Numeric) MarshalJSON() ([]byte, error) { if !n.Valid { return []byte("null"), nil @@ -243,6 +249,7 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return n.numberTextBytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (n *Numeric) UnmarshalJSON(src []byte) error { if bytes.Equal(src, []byte(`null`)) { *n = Numeric{} @@ -257,6 +264,10 @@ func (n *Numeric) UnmarshalJSON(src []byte) error { // numberString returns a string of the number. undefined if NaN, infinite, or NULL func (n Numeric) numberTextBytes() []byte { + if n.Int == nil { + return []byte("0") + } + intStr := n.Int.String() buf := &bytes.Buffer{} @@ -269,14 +280,14 @@ func (n Numeric) numberTextBytes() []byte { exp := int(n.Exp) if exp > 0 { buf.WriteString(intStr) - for i := 0; i < exp; i++ { + for range exp { buf.WriteByte('0') } } else if exp < 0 { if len(intStr) <= -exp { buf.WriteString("0.") leadingZeros := -exp - len(intStr) - for i := 0; i < leadingZeros; i++ { + for range leadingZeros { buf.WriteByte('0') } buf.WriteString(intStr) @@ -398,7 +409,7 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { } var sign int16 - if n.Int.Cmp(big0) < 0 { + if n.Int != nil && n.Int.Sign() < 0 { sign = 16384 } @@ -406,7 +417,9 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { wholePart := &big.Int{} fracPart := &big.Int{} remainder := &big.Int{} - absInt.Abs(n.Int) + if n.Int != nil { + absInt.Abs(n.Int) + } // Normalize absInt and exp to where exp is always a multiple of 4. This makes // converting to 16-bit base 10,000 digits easier. @@ -436,12 +449,12 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { var wholeDigits, fracDigits []int16 - for wholePart.Cmp(big0) != 0 { + for wholePart.Sign() != 0 { wholePart.DivMod(wholePart, bigNBase, remainder) wholeDigits = append(wholeDigits, int16(remainder.Int64())) } - if fracPart.Cmp(big0) != 0 { + if fracPart.Sign() != 0 { for fracPart.Cmp(big1) != 0 { fracPart.DivMod(fracPart, bigNBase, remainder) fracDigits = append(fracDigits, int16(remainder.Int64())) @@ -553,7 +566,6 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { } func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -648,18 +660,19 @@ func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { exp := (int32(weight) - int32(ndigits) + 1) * 4 if dscale > 0 { - fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1) + fracNBaseDigits := int(ndigits) - int(weight) - 1 fracDecimalDigits := fracNBaseDigits * 4 + dscaleInt := int(dscale) - if dscale > fracDecimalDigits { - multCount := int(dscale - fracDecimalDigits) - for i := 0; i < multCount; i++ { + if dscaleInt > fracDecimalDigits { + multCount := dscaleInt - fracDecimalDigits + for range multCount { accum.Mul(accum, big10) exp-- } - } else if dscale < fracDecimalDigits { - divCount := int(fracDecimalDigits - dscale) - for i := 0; i < divCount; i++ { + } else if dscaleInt < fracDecimalDigits { + divCount := fracDecimalDigits - dscaleInt + for range divCount { accum.Div(accum, big10) exp++ } @@ -671,7 +684,7 @@ func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { if exp >= 0 { for { reduced.DivMod(accum, big10, remainder) - if remainder.Cmp(big0) != 0 { + if remainder.Sign() != 0 { break } accum.Set(reduced) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/path.go b/vendor/github.com/jackc/pgx/v5/pgtype/path.go index 73e0ec52..685996a8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/path.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/path.go @@ -25,16 +25,18 @@ type Path struct { Valid bool } +// ScanPath implements the [PathScanner] interface. func (path *Path) ScanPath(v Path) error { *path = v return nil } +// PathValue implements the [PathValuer] interface. func (path Path) PathValue() (Path, error) { return path, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (path *Path) Scan(src any) error { if src == nil { *path = Path{} @@ -49,7 +51,7 @@ func (path *Path) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (path Path) Value() (driver.Value, error) { if !path.Valid { return nil, nil @@ -154,7 +156,6 @@ func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -194,7 +195,7 @@ func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst any) error { } points := make([]Vec2, pointCount) - for i := 0; i < len(points); i++ { + for i := range points { x := binary.BigEndian.Uint64(src[rp:]) rp += 8 y := binary.BigEndian.Uint64(src[rp:]) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go index bdd9f05c..253d8096 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go @@ -29,6 +29,7 @@ const ( XMLOID = 142 XMLArrayOID = 143 JSONArrayOID = 199 + XID8ArrayOID = 271 PointOID = 600 LsegOID = 601 PathOID = 602 @@ -95,6 +96,8 @@ const ( RecordArrayOID = 2287 UUIDOID = 2950 UUIDArrayOID = 2951 + TSVectorOID = 3614 + TSVectorArrayOID = 3643 JSONBOID = 3802 JSONBArrayOID = 3807 DaterangeOID = 3912 @@ -117,6 +120,7 @@ const ( TstzmultirangeOID = 4534 DatemultirangeOID = 4535 Int8multirangeOID = 4536 + XID8OID = 5069 Int4multirangeArrayOID = 6150 NummultirangeArrayOID = 6151 TsmultirangeArrayOID = 6152 @@ -152,7 +156,7 @@ const ( BinaryFormatCode = 1 ) -// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a Map. +// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a [Map]. type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -183,7 +187,7 @@ func (e *nullAssignmentError) Error() string { return fmt.Sprintf("cannot assign NULL to %T", e.dst) } -// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a Map. +// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a [Map]. type Type struct { Codec Codec Name string @@ -200,7 +204,6 @@ type Map struct { reflectTypeToType map[reflect.Type]*Type - memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan memoizedEncodePlans map[uint32]map[reflect.Type][2]EncodePlan // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every @@ -234,13 +237,13 @@ func NewMap() *Map { reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), - memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ TryWrapDerefPointerEncodePlan, TryWrapBuiltinTypeEncodePlan, TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStringerEncodePlan, TryWrapStructEncodePlan, TryWrapSliceEncodePlan, TryWrapMultiDimSliceEncodePlan, @@ -266,7 +269,7 @@ func (m *Map) RegisterTypes(types []*Type) { } } -// RegisterType registers a data type with the Map. t must not be mutated after it is registered. +// RegisterType registers a data type with the [Map]. t must not be mutated after it is registered. func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t @@ -274,9 +277,6 @@ func (m *Map) RegisterType(t *Type) { // Invalidated by type registration m.reflectTypeToType = nil - for k := range m.memoizedScanPlans { - delete(m.memoizedScanPlans, k) - } for k := range m.memoizedEncodePlans { delete(m.memoizedEncodePlans, k) } @@ -290,15 +290,12 @@ func (m *Map) RegisterDefaultPgType(value any, name string) { // Invalidated by type registration m.reflectTypeToType = nil - for k := range m.memoizedScanPlans { - delete(m.memoizedScanPlans, k) - } for k := range m.memoizedEncodePlans { delete(m.memoizedEncodePlans, k) } } -// TypeForOID returns the Type registered for the given OID. The returned Type must not be mutated. +// TypeForOID returns the [Type] registered for the given OID. The returned [Type] must not be mutated. func (m *Map) TypeForOID(oid uint32) (*Type, bool) { if dt, ok := m.oidToType[oid]; ok { return dt, true @@ -308,7 +305,7 @@ func (m *Map) TypeForOID(oid uint32) (*Type, bool) { return dt, ok } -// TypeForName returns the Type registered for the given name. The returned Type must not be mutated. +// TypeForName returns the [Type] registered for the given name. The returned [Type] must not be mutated. func (m *Map) TypeForName(name string) (*Type, bool) { if dt, ok := m.nameToType[name]; ok { return dt, true @@ -327,8 +324,8 @@ func (m *Map) buildReflectTypeToType() { } } -// TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode -// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. The returned Type +// TypeForValue finds a data type suitable for v. Use [Map.RegisterType] to register types that can encode and decode +// themselves. Use [Map.RegisterDefaultPgType] to register that can be handled by a registered data type. The returned [Type] // must not be mutated. func (m *Map) TypeForValue(v any) (*Type, bool) { if m.reflectTypeToType == nil { @@ -395,6 +392,7 @@ type scanPlanSQLScanner struct { func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { scanner := dst.(sql.Scanner) + if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the // text format path would be converted to empty string. @@ -449,14 +447,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { // As a horrible hack try all types to find anything that can scan into dst. for oid := range plan.m.oidToType { // using planScan instead of Scan or PlanScan to avoid polluting the planned scan cache. - plan := plan.m.planScan(oid, plan.formatCode, dst) + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) if _, ok := plan.(*scanPlanFail); !ok { return plan.Scan(src, dst) } } for oid := range defaultMap.oidToType { if _, ok := plan.m.oidToType[oid]; !ok { - plan := plan.m.planScan(oid, plan.formatCode, dst) + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) if _, ok := plan.(*scanPlanFail); !ok { return plan.Scan(src, dst) } @@ -528,20 +526,20 @@ type SkipUnderlyingTypePlanner interface { } var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ - reflect.Int: reflect.TypeOf(new(int)), - reflect.Int8: reflect.TypeOf(new(int8)), - reflect.Int16: reflect.TypeOf(new(int16)), - reflect.Int32: reflect.TypeOf(new(int32)), - reflect.Int64: reflect.TypeOf(new(int64)), - reflect.Uint: reflect.TypeOf(new(uint)), - reflect.Uint8: reflect.TypeOf(new(uint8)), - reflect.Uint16: reflect.TypeOf(new(uint16)), - reflect.Uint32: reflect.TypeOf(new(uint32)), - reflect.Uint64: reflect.TypeOf(new(uint64)), - reflect.Float32: reflect.TypeOf(new(float32)), - reflect.Float64: reflect.TypeOf(new(float64)), - reflect.String: reflect.TypeOf(new(string)), - reflect.Bool: reflect.TypeOf(new(bool)), + reflect.Int: reflect.TypeFor[*int](), + reflect.Int8: reflect.TypeFor[*int8](), + reflect.Int16: reflect.TypeFor[*int16](), + reflect.Int32: reflect.TypeFor[*int32](), + reflect.Int64: reflect.TypeFor[*int64](), + reflect.Uint: reflect.TypeFor[*uint](), + reflect.Uint8: reflect.TypeFor[*uint8](), + reflect.Uint16: reflect.TypeFor[*uint16](), + reflect.Uint32: reflect.TypeFor[*uint32](), + reflect.Uint64: reflect.TypeFor[*uint64](), + reflect.Float32: reflect.TypeFor[*float32](), + reflect.Float64: reflect.TypeFor[*float64](), + reflect.String: reflect.TypeFor[*string](), + reflect.Bool: reflect.TypeFor[*bool](), } type underlyingTypeScanPlan struct { @@ -906,7 +904,7 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst any) error { return nil } -// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +// TryWrapStructScanPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { targetValue := reflect.ValueOf(target) if targetValue.Kind() != reflect.Ptr { @@ -1064,24 +1062,14 @@ func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error { // PlanScan prepares a plan to scan a value into target. func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { - oidMemo := m.memoizedScanPlans[oid] - if oidMemo == nil { - oidMemo = make(map[reflect.Type][2]ScanPlan) - m.memoizedScanPlans[oid] = oidMemo - } - targetReflectType := reflect.TypeOf(target) - typeMemo := oidMemo[targetReflectType] - plan := typeMemo[formatCode] - if plan == nil { - plan = m.planScan(oid, formatCode, target) - typeMemo[formatCode] = plan - oidMemo[targetReflectType] = typeMemo - } - - return plan + return m.planScan(oid, formatCode, target, 0) } -func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { +func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + if target == nil { return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} } @@ -1141,7 +1129,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { for _, f := range m.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(target); ok { - if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { + if nextPlan := m.planScan(oid, formatCode, nextDst, depth+1); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan @@ -1150,10 +1138,18 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { } } - if dt != nil { - if _, ok := target.(*any); ok { - return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode} + if _, ok := target.(*any); ok { + var codec Codec + if dt != nil { + codec = dt.Codec + } else { + if formatCode == TextFormatCode { + codec = TextCodec{} + } else { + codec = ByteaCodec{} + } } + return &pointerEmptyInterfaceScanPlan{codec: codec, m: m, oid: oid, formatCode: formatCode} } return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} @@ -1198,9 +1194,18 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src } } -// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be +// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { + return m.planEncodeDepth(oid, format, value, 0) +} + +func (m *Map) planEncodeDepth(oid uint32, format int16, value any, depth int) EncodePlan { + // Guard against infinite recursion. + if depth > 8 { + return nil + } + oidMemo := m.memoizedEncodePlans[oid] if oidMemo == nil { oidMemo = make(map[reflect.Type][2]EncodePlan) @@ -1210,7 +1215,7 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { typeMemo := oidMemo[targetReflectType] plan := typeMemo[format] if plan == nil { - plan = m.planEncode(oid, format, value) + plan = m.planEncode(oid, format, value, depth) typeMemo[format] = plan oidMemo[targetReflectType] = typeMemo } @@ -1218,7 +1223,7 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { return plan } -func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { +func (m *Map) planEncode(oid uint32, format int16, value any, depth int) EncodePlan { if format == TextFormatCode { switch value.(type) { case string: @@ -1249,7 +1254,7 @@ func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { for _, f := range m.TryWrapEncodePlanFuncs { if wrapperPlan, nextValue, ok := f(value); ok { - if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { + if nextPlan := m.planEncodeDepth(oid, format, nextValue, depth+1); nextPlan != nil { wrapperPlan.SetNext(nextPlan) return wrapperPlan } @@ -1370,23 +1375,23 @@ func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, } var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ - reflect.Int: reflect.TypeOf(int(0)), - reflect.Int8: reflect.TypeOf(int8(0)), - reflect.Int16: reflect.TypeOf(int16(0)), - reflect.Int32: reflect.TypeOf(int32(0)), - reflect.Int64: reflect.TypeOf(int64(0)), - reflect.Uint: reflect.TypeOf(uint(0)), - reflect.Uint8: reflect.TypeOf(uint8(0)), - reflect.Uint16: reflect.TypeOf(uint16(0)), - reflect.Uint32: reflect.TypeOf(uint32(0)), - reflect.Uint64: reflect.TypeOf(uint64(0)), - reflect.Float32: reflect.TypeOf(float32(0)), - reflect.Float64: reflect.TypeOf(float64(0)), - reflect.String: reflect.TypeOf(""), - reflect.Bool: reflect.TypeOf(false), -} - -var byteSliceType = reflect.TypeOf([]byte{}) + reflect.Int: reflect.TypeFor[int](), + reflect.Int8: reflect.TypeFor[int8](), + reflect.Int16: reflect.TypeFor[int16](), + reflect.Int32: reflect.TypeFor[int32](), + reflect.Int64: reflect.TypeFor[int64](), + reflect.Uint: reflect.TypeFor[uint](), + reflect.Uint8: reflect.TypeFor[uint8](), + reflect.Uint16: reflect.TypeFor[uint16](), + reflect.Uint32: reflect.TypeFor[uint32](), + reflect.Uint64: reflect.TypeFor[uint64](), + reflect.Float32: reflect.TypeFor[float32](), + reflect.Float64: reflect.TypeFor[float64](), + reflect.String: reflect.TypeFor[string](), + reflect.Bool: reflect.TypeFor[bool](), +} + +var byteSliceType = reflect.TypeFor[[]byte]() type underlyingTypeEncodePlan struct { nextValueType reflect.Type @@ -1442,6 +1447,24 @@ func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextS return nil, nil, false } +// TryWrapStringerEncodePlan tries to wrap a fmt.Stringer type with a wrapper that provides TextValuer. This is +// intentionally a separate function from TryWrapBuiltinTypeEncodePlan so it can be ordered after +// TryWrapFindUnderlyingTypeEncodePlan. This ensures that named types with an underlying builtin type (e.g. type MyEnum +// int32 with a String() method) prefer encoding via the underlying type's codec (e.g. as an integer) rather than via +// Stringer. Stringer is only used as a fallback when no type-specific encoding plan succeeds. +// (https://github.com/jackc/pgx/discussions/2527) +func TryWrapStringerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + if s, ok := value.(fmt.Stringer); ok { + return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{s}, true + } + + return nil, nil, false +} + type WrappedEncodePlanNextSetter interface { SetNext(EncodePlan) EncodePlan @@ -1502,8 +1525,6 @@ func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, return &wrapByte16EncodePlan{}, byte16Wrapper(value), true case []byte: return &wrapByteSliceEncodePlan{}, byteSliceWrapper(value), true - case fmt.Stringer: - return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true } return nil, nil, false @@ -1749,7 +1770,7 @@ func (plan *wrapFmtStringerEncodePlan) Encode(value any, buf []byte) (newBuf []b return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf) } -// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +// TryWrapStructEncodePlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. func TryWrapStructEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if _, ok := value.(driver.Valuer); ok { return nil, nil, false @@ -2005,37 +2026,18 @@ func (w *sqlScannerWrapper) Scan(src any) error { case []byte: bufSrc = src default: - bufSrc = []byte(fmt.Sprint(bufSrc)) + bufSrc = fmt.Append(nil, bufSrc) } } return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) } -// canBeNil returns true if value can be nil. -func canBeNil(value any) bool { - refVal := reflect.ValueOf(value) - kind := refVal.Kind() - switch kind { - case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: - return true - default: - return false - } -} - -// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil -// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual -// type. Yuck. -// -// This can be simplified in Go 1.22 with reflect.TypeFor. -// -// var valuerReflectType = reflect.TypeFor[driver.Valuer]() -var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var valuerReflectType = reflect.TypeFor[driver.Valuer]() // isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement // driver.Valuer if it is only implemented by T. -func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) { +func isNilDriverValuer(value any) (isNil, callNilDriverValuer bool) { if value == nil { return true, false } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go index c8125731..42b39d82 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go @@ -23,7 +23,6 @@ func initDefaultMap() { reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), - memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ @@ -82,6 +81,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "tsvector", OID: TSVectorOID, Codec: TSVectorCodec{}}) defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}}) @@ -90,7 +90,26 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{Marshal: xml.Marshal, Unmarshal: xml.Unmarshal}}) + defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}}) + defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{ + Marshal: xml.Marshal, + // xml.Unmarshal does not support unmarshalling into *any. However, XMLCodec.DecodeValue calls Unmarshal with a + // *any. Wrap xml.Marshal with a function that copies the data into a new byte slice in this case. Not implementing + // directly in XMLCodec.DecodeValue to allow for the unlikely possibility that someone uses an alternative XML + // unmarshaler that does support unmarshalling into *any. + // + // https://github.com/jackc/pgx/issues/2227 + // https://github.com/jackc/pgx/pull/2228 + Unmarshal: func(data []byte, v any) error { + if v, ok := v.(*any); ok { + dstBuf := make([]byte, len(data)) + copy(dstBuf, data) + *v = dstBuf + return nil + } + return xml.Unmarshal(data, v) + }, + }}) // Range types defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) @@ -146,6 +165,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}}) defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}}) defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_tsvector", OID: TSVectorArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TSVectorOID]}}) defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}}) defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) @@ -155,6 +175,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_xid8", OID: XID8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XID8OID]}}) defaultMap.RegisterType(&Type{Name: "_xml", OID: XMLArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XMLOID]}}) // Integer types that directly map to a PostgreSQL type @@ -223,6 +244,7 @@ func initDefaultMap() { registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange") registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange") registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange") + registerDefaultPgTypeVariants[TSVector](defaultMap, "tsvector") registerDefaultPgTypeVariants[UUID](defaultMap, "uuid") defaultMap.buildReflectTypeToType() diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/point.go b/vendor/github.com/jackc/pgx/v5/pgtype/point.go index 09b19bb5..b701513d 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/point.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/point.go @@ -30,11 +30,13 @@ type Point struct { Valid bool } +// ScanPoint implements the [PointScanner] interface. func (p *Point) ScanPoint(v Point) error { *p = v return nil } +// PointValue implements the [PointValuer] interface. func (p Point) PointValue() (Point, error) { return p, nil } @@ -68,7 +70,7 @@ func parsePoint(src []byte) (*Point, error) { return &Point{P: Vec2{x, y}, Valid: true}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Point) Scan(src any) error { if src == nil { *dst = Point{} @@ -83,7 +85,7 @@ func (dst *Point) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Point) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +98,7 @@ func (src Point) Value() (driver.Value, error) { return string(buf), err } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Point) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -108,6 +111,7 @@ func (src Point) MarshalJSON() ([]byte, error) { return buff.Bytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Point) UnmarshalJSON(point []byte) error { p, err := parsePoint(point) if err != nil { @@ -178,7 +182,6 @@ func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, er } func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go b/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go index 04b0ba6b..e18c9da6 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go @@ -24,16 +24,18 @@ type Polygon struct { Valid bool } +// ScanPolygon implements the [PolygonScanner] interface. func (p *Polygon) ScanPolygon(v Polygon) error { *p = v return nil } +// PolygonValue implements the [PolygonValuer] interface. func (p Polygon) PolygonValue() (Polygon, error) { return p, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (p *Polygon) Scan(src any) error { if src == nil { *p = Polygon{} @@ -48,7 +50,7 @@ func (p *Polygon) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (p Polygon) Value() (driver.Value, error) { if !p.Valid { return nil, nil @@ -139,7 +141,6 @@ func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte, } func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -177,7 +178,7 @@ func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst any) error { } points := make([]Vec2, pointCount) - for i := 0; i < len(points); i++ { + for i := range points { x := binary.BigEndian.Uint64(src[rp:]) rp += 8 y := binary.BigEndian.Uint64(src[rp:]) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/range.go b/vendor/github.com/jackc/pgx/v5/pgtype/range.go index 16427ccc..62d69990 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/range.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/range.go @@ -191,11 +191,13 @@ type untypedBinaryRange struct { // 18 = [ = 10010 // 24 = = 11000 -const emptyMask = 1 -const lowerInclusiveMask = 2 -const upperInclusiveMask = 4 -const lowerUnboundedMask = 8 -const upperUnboundedMask = 16 +const ( + emptyMask = 1 + lowerInclusiveMask = 2 + upperInclusiveMask = 4 + lowerUnboundedMask = 8 + upperUnboundedMask = 16 +) func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { ubr := &untypedBinaryRange{} @@ -273,7 +275,6 @@ func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { } return ubr, nil - } // Range is a generic range type. diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/record_codec.go b/vendor/github.com/jackc/pgx/v5/pgtype/record_codec.go index b3b16604..90b9bd4b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/record_codec.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/record_codec.go @@ -121,5 +121,4 @@ func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an default: return nil, fmt.Errorf("unknown format code %d", format) } - } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/text.go b/vendor/github.com/jackc/pgx/v5/pgtype/text.go index 021ee331..e08b1254 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/text.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/text.go @@ -19,16 +19,18 @@ type Text struct { Valid bool } +// ScanText implements the [TextScanner] interface. func (t *Text) ScanText(v Text) error { *t = v return nil } +// TextValue implements the [TextValuer] interface. func (t Text) TextValue() (Text, error) { return t, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Text) Scan(src any) error { if src == nil { *dst = Text{} @@ -47,7 +49,7 @@ func (dst *Text) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Text) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -55,6 +57,7 @@ func (src Text) Value() (driver.Value, error) { return src.String, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Text) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -63,6 +66,7 @@ func (src Text) MarshalJSON() ([]byte, error) { return json.Marshal(src.String) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Text) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -146,7 +150,6 @@ func (encodePlanTextCodecTextValuer) Encode(value any, buf []byte) (newBuf []byt } func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/tid.go b/vendor/github.com/jackc/pgx/v5/pgtype/tid.go index 9bc2c2a1..05c9e6d9 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/tid.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/tid.go @@ -35,16 +35,18 @@ type TID struct { Valid bool } +// ScanTID implements the [TIDScanner] interface. func (b *TID) ScanTID(v TID) error { *b = v return nil } +// TIDValue implements the [TIDValuer] interface. func (b TID) TIDValue() (TID, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *TID) Scan(src any) error { if src == nil { *dst = TID{} @@ -59,7 +61,7 @@ func (dst *TID) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src TID) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -131,7 +133,6 @@ func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/time.go b/vendor/github.com/jackc/pgx/v5/pgtype/time.go index f8fd9489..4b8f6908 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/time.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/time.go @@ -29,16 +29,18 @@ type Time struct { Valid bool } +// ScanTime implements the [TimeScanner] interface. func (t *Time) ScanTime(v Time) error { *t = v return nil } +// TimeValue implements the [TimeValuer] interface. func (t Time) TimeValue() (Time, error) { return t, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (t *Time) Scan(src any) error { if src == nil { *t = Time{} @@ -58,7 +60,7 @@ func (t *Time) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (t Time) Value() (driver.Value, error) { if !t.Valid { return nil, nil @@ -137,7 +139,6 @@ func (encodePlanTimeCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go index 677a2c6e..de500a19 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go @@ -11,7 +11,10 @@ import ( "github.com/jackc/pgx/v5/internal/pgio" ) -const pgTimestampFormat = "2006-01-02 15:04:05.999999999" +const ( + pgTimestampFormat = "2006-01-02 15:04:05.999999999" + jsonISO8601 = "2006-01-02T15:04:05.999999999" +) type TimestampScanner interface { ScanTimestamp(v Timestamp) error @@ -28,16 +31,18 @@ type Timestamp struct { Valid bool } +// ScanTimestamp implements the [TimestampScanner] interface. func (ts *Timestamp) ScanTimestamp(v Timestamp) error { *ts = v return nil } +// TimestampValue implements the [TimestampValuer] interface. func (ts Timestamp) TimestampValue() (Timestamp, error) { return ts, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (ts *Timestamp) Scan(src any) error { if src == nil { *ts = Timestamp{} @@ -55,7 +60,7 @@ func (ts *Timestamp) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (ts Timestamp) Value() (driver.Value, error) { if !ts.Valid { return nil, nil @@ -67,6 +72,7 @@ func (ts Timestamp) Value() (driver.Value, error) { return ts.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (ts Timestamp) MarshalJSON() ([]byte, error) { if !ts.Valid { return []byte("null"), nil @@ -76,7 +82,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) { switch ts.InfinityModifier { case Finite: - s = ts.Time.Format(time.RFC3339Nano) + s = ts.Time.Format(jsonISO8601) case Infinity: s = "infinity" case NegativeInfinity: @@ -86,6 +92,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (ts *Timestamp) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -104,15 +111,23 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error { case "-infinity": *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: - // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz - tim, err := time.Parse(time.RFC3339Nano, *s) - if err != nil { - return err + // Parse time with or without timezone + tss := *s + // PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestamp + tim, err := time.Parse(time.RFC3339Nano, tss) + if err == nil { + *ts = Timestamp{Time: tim, Valid: true} + return nil } - - *ts = Timestamp{Time: tim, Valid: true} + tim, err = time.ParseInLocation(jsonISO8601, tss, time.UTC) + if err == nil { + *ts = Timestamp{Time: tim, Valid: true} + return nil + } + ts.Valid = false + return fmt.Errorf("cannot unmarshal %s to timestamp with layout %s or %s (%w)", + *s, time.RFC3339Nano, jsonISO8601, err) } - return nil } @@ -161,7 +176,7 @@ func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []by switch ts.InfinityModifier { case Finite: t := discardTimeZone(ts.Time) - microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceUnixEpoch := t.Unix()*1_000_000 + int64(t.Nanosecond())/1000 microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K case Infinity: microsecSinceY2K = infinityMicrosecondOffset @@ -225,7 +240,6 @@ func discardTimeZone(t time.Time) time.Time { } func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -265,8 +279,8 @@ func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: tim := time.Unix( - microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, - (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + microsecFromUnixEpochToY2K/1_000_000+microsecSinceY2K/1_000_000, + (microsecFromUnixEpochToY2K%1_000_000*1_000)+(microsecSinceY2K%1_000_000*1000), ).UTC() if plan.location != nil { tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go index 7efbcffd..4d055bfa 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go @@ -11,10 +11,12 @@ import ( "github.com/jackc/pgx/v5/internal/pgio" ) -const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" -const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" -const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" -const microsecFromUnixEpochToY2K = 946684800 * 1000000 +const ( + pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" + pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" + pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" + microsecFromUnixEpochToY2K = 946_684_800 * 1_000_000 +) const ( negativeInfinityMicrosecondOffset = -9223372036854775808 @@ -36,16 +38,18 @@ type Timestamptz struct { Valid bool } +// ScanTimestamptz implements the [TimestamptzScanner] interface. func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error { *tstz = v return nil } +// TimestamptzValue implements the [TimestamptzValuer] interface. func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) { return tstz, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (tstz *Timestamptz) Scan(src any) error { if src == nil { *tstz = Timestamptz{} @@ -63,7 +67,7 @@ func (tstz *Timestamptz) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (tstz Timestamptz) Value() (driver.Value, error) { if !tstz.Valid { return nil, nil @@ -75,6 +79,7 @@ func (tstz Timestamptz) Value() (driver.Value, error) { return tstz.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (tstz Timestamptz) MarshalJSON() ([]byte, error) { if !tstz.Valid { return []byte("null"), nil @@ -94,6 +99,7 @@ func (tstz Timestamptz) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -225,7 +231,6 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by } func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -265,8 +270,8 @@ func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: tim := time.Unix( - microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, - (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + microsecFromUnixEpochToY2K/1_000_000+microsecSinceY2K/1_000_000, + (microsecFromUnixEpochToY2K%1_000_000*1_000)+(microsecSinceY2K%1_000_000*1_000), ) if plan.location != nil { tim = tim.In(plan.location) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go b/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go new file mode 100644 index 00000000..b357948a --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go @@ -0,0 +1,507 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type TSVectorScanner interface { + ScanTSVector(TSVector) error +} + +type TSVectorValuer interface { + TSVectorValue() (TSVector, error) +} + +// TSVector represents a PostgreSQL tsvector value. +type TSVector struct { + Lexemes []TSVectorLexeme + Valid bool +} + +// TSVectorLexeme represents a lexeme within a tsvector, consisting of a word and its positions. +type TSVectorLexeme struct { + Word string + Positions []TSVectorPosition +} + +// ScanTSVector implements the [TSVectorScanner] interface. +func (t *TSVector) ScanTSVector(v TSVector) error { + *t = v + return nil +} + +// TSVectorValue implements the [TSVectorValuer] interface. +func (t TSVector) TSVectorValue() (TSVector, error) { + return t, nil +} + +func (t TSVector) String() string { + buf, _ := encodePlanTSVectorCodecText{}.Encode(t, nil) + return string(buf) +} + +// Scan implements the [database/sql.Scanner] interface. +func (t *TSVector) Scan(src any) error { + if src == nil { + *t = TSVector{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToTSVectorScanner{}.scanString(src, t) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (t TSVector) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + + buf, err := TSVectorCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil) + if err != nil { + return nil, err + } + + return string(buf), nil +} + +// TSVectorWeight represents the weight label of a lexeme position in a tsvector. +type TSVectorWeight byte + +const ( + TSVectorWeightA = TSVectorWeight('A') + TSVectorWeightB = TSVectorWeight('B') + TSVectorWeightC = TSVectorWeight('C') + TSVectorWeightD = TSVectorWeight('D') +) + +// tsvectorWeightToBinary converts a TSVectorWeight to the 2-bit binary encoding used by PostgreSQL. +func tsvectorWeightToBinary(w TSVectorWeight) uint16 { + switch w { + case TSVectorWeightA: + return 3 + case TSVectorWeightB: + return 2 + case TSVectorWeightC: + return 1 + default: + return 0 // D or unset + } +} + +// tsvectorWeightFromBinary converts a 2-bit binary weight value to a TSVectorWeight. +func tsvectorWeightFromBinary(b uint16) TSVectorWeight { + switch b { + case 3: + return TSVectorWeightA + case 2: + return TSVectorWeightB + case 1: + return TSVectorWeightC + default: + return TSVectorWeightD + } +} + +// TSVectorPosition represents a lexeme position and its optional weight within a tsvector. +type TSVectorPosition struct { + Position uint16 + Weight TSVectorWeight +} + +func (p TSVectorPosition) String() string { + s := strconv.FormatUint(uint64(p.Position), 10) + if p.Weight != 0 && p.Weight != TSVectorWeightD { + s += string(p.Weight) + } + return s +} + +type TSVectorCodec struct{} + +func (TSVectorCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TSVectorCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TSVectorCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TSVectorValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTSVectorCodecBinary{} + case TextFormatCode: + return encodePlanTSVectorCodecText{} + } + + return nil +} + +type encodePlanTSVectorCodecBinary struct{} + +func (encodePlanTSVectorCodecBinary) Encode(value any, buf []byte) ([]byte, error) { + tsv, err := value.(TSVectorValuer).TSVectorValue() + if err != nil { + return nil, err + } + + if !tsv.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(tsv.Lexemes))) + + for _, entry := range tsv.Lexemes { + buf = append(buf, entry.Word...) + buf = append(buf, 0x00) + buf = pgio.AppendUint16(buf, uint16(len(entry.Positions))) + + // Each position is a uint16: weight (2 bits) | position (14 bits) + for _, pos := range entry.Positions { + packed := tsvectorWeightToBinary(pos.Weight)<<14 | uint16(pos.Position)&0x3FFF + buf = pgio.AppendUint16(buf, packed) + } + } + + return buf, nil +} + +type scanPlanBinaryTSVectorToTSVectorScanner struct{} + +func (scanPlanBinaryTSVectorToTSVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TSVectorScanner) + + if src == nil { + return scanner.ScanTSVector(TSVector{}) + } + + rp := 0 + + const ( + uint16Len = 2 + uint32Len = 4 + ) + + if len(src[rp:]) < uint32Len { + return fmt.Errorf("tsvector incomplete %v", src) + } + entryCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += uint32Len + + var tsv TSVector + if entryCount > 0 { + tsv.Lexemes = make([]TSVectorLexeme, entryCount) + } + + for i := range entryCount { + nullIndex := bytes.IndexByte(src[rp:], 0x00) + if nullIndex == -1 { + return fmt.Errorf("invalid tsvector binary format: missing null terminator") + } + + lexeme := TSVectorLexeme{Word: string(src[rp : rp+nullIndex])} + rp += nullIndex + 1 // skip past null terminator + + // Read position count. + if len(src[rp:]) < uint16Len { + return fmt.Errorf("invalid tsvector binary format: incomplete position count") + } + + numPositions := int(binary.BigEndian.Uint16(src[rp:])) + rp += uint16Len + + // Read each packed position: weight (2 bits) | position (14 bits) + if len(src[rp:]) < numPositions*uint16Len { + return fmt.Errorf("invalid tsvector binary format: incomplete positions") + } + + if numPositions > 0 { + lexeme.Positions = make([]TSVectorPosition, numPositions) + for pos := range numPositions { + packed := binary.BigEndian.Uint16(src[rp:]) + rp += uint16Len + lexeme.Positions[pos] = TSVectorPosition{ + Position: packed & 0x3FFF, + Weight: tsvectorWeightFromBinary(packed >> 14), + } + } + } + + tsv.Lexemes[i] = lexeme + } + tsv.Valid = true + + return scanner.ScanTSVector(tsv) +} + +var tsvectorLexemeReplacer = strings.NewReplacer( + `\`, `\\`, + `'`, `\'`, +) + +type encodePlanTSVectorCodecText struct{} + +func (encodePlanTSVectorCodecText) Encode(value any, buf []byte) ([]byte, error) { + tsv, err := value.(TSVectorValuer).TSVectorValue() + if err != nil { + return nil, err + } + + if !tsv.Valid { + return nil, nil + } + + if buf == nil { + buf = []byte{} + } + + for i, lex := range tsv.Lexemes { + if i > 0 { + buf = append(buf, ' ') + } + + buf = append(buf, '\'') + buf = append(buf, tsvectorLexemeReplacer.Replace(lex.Word)...) + buf = append(buf, '\'') + + sep := byte(':') + for _, p := range lex.Positions { + buf = append(buf, sep) + buf = append(buf, p.String()...) + sep = ',' + } + } + + return buf, nil +} + +func (TSVectorCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case TSVectorScanner: + return scanPlanBinaryTSVectorToTSVectorScanner{} + } + case TextFormatCode: + switch target.(type) { + case TSVectorScanner: + return scanPlanTextAnyToTSVectorScanner{} + } + } + + return nil +} + +type scanPlanTextAnyToTSVectorScanner struct{} + +func (s scanPlanTextAnyToTSVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TSVectorScanner) + + if src == nil { + return scanner.ScanTSVector(TSVector{}) + } + + return s.scanString(string(src), scanner) +} + +func (scanPlanTextAnyToTSVectorScanner) scanString(src string, scanner TSVectorScanner) error { + tsv, err := parseTSVector(src) + if err != nil { + return err + } + return scanner.ScanTSVector(tsv) +} + +func (c TSVectorCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c TSVectorCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var tsv TSVector + err := codecScan(c, m, oid, format, src, &tsv) + if err != nil { + return nil, err + } + return tsv, nil +} + +type tsvectorParser struct { + str string + pos int +} + +func (p *tsvectorParser) atEnd() bool { + return p.pos >= len(p.str) +} + +func (p *tsvectorParser) peek() byte { + return p.str[p.pos] +} + +func (p *tsvectorParser) consume() (byte, bool) { + if p.pos >= len(p.str) { + return 0, true + } + b := p.str[p.pos] + p.pos++ + return b, false +} + +func (p *tsvectorParser) consumeSpaces() { + for !p.atEnd() && p.peek() == ' ' { + p.consume() + } +} + +// consumeLexeme consumes a single-quoted lexeme, handling single quotes and backslash escapes. +func (p *tsvectorParser) consumeLexeme() (string, error) { + ch, end := p.consume() + if end || ch != '\'' { + return "", fmt.Errorf("invalid tsvector format: lexeme must start with a single quote") + } + + var buf strings.Builder + for { + ch, end := p.consume() + if end { + return "", fmt.Errorf("invalid tsvector format: unterminated quoted lexeme") + } + + switch ch { + case '\'': + // Escaped quote ('') — write a literal single quote + if !p.atEnd() && p.peek() == '\'' { + p.consume() + buf.WriteByte('\'') + } else { + // Closing quote — lexeme is complete + return buf.String(), nil + } + case '\\': + next, end := p.consume() + if end { + return "", fmt.Errorf("invalid tsvector format: unexpected end after backslash") + } + buf.WriteByte(next) + default: + buf.WriteByte(ch) + } + } +} + +// consumePositions consumes a comma-separated list of position[weight] values. +func (p *tsvectorParser) consumePositions() ([]TSVectorPosition, error) { + var positions []TSVectorPosition + + for { + pos, err := p.consumePosition() + if err != nil { + return nil, err + } + positions = append(positions, pos) + + if p.atEnd() || p.peek() != ',' { + break + } + + p.consume() // skip ',' + } + + return positions, nil +} + +// consumePosition consumes a single position number with optional weight letter. +func (p *tsvectorParser) consumePosition() (TSVectorPosition, error) { + start := p.pos + + for !p.atEnd() && p.peek() >= '0' && p.peek() <= '9' { + p.consume() + } + + if p.pos == start { + return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: expected position number") + } + + num, err := strconv.ParseUint(p.str[start:p.pos], 10, 16) + if err != nil { + return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: invalid position number %q", p.str[start:p.pos]) + } + + pos := TSVectorPosition{Position: uint16(num), Weight: TSVectorWeightD} + + // Check for optional weight letter + if !p.atEnd() { + switch p.peek() { + case 'A', 'a': + pos.Weight = TSVectorWeightA + case 'B', 'b': + pos.Weight = TSVectorWeightB + case 'C', 'c': + pos.Weight = TSVectorWeightC + case 'D', 'd': + pos.Weight = TSVectorWeightD + default: + return pos, nil + } + p.consume() + } + + return pos, nil +} + +// parseTSVector parses a PostgreSQL tsvector text representation. +func parseTSVector(s string) (TSVector, error) { + result := TSVector{} + p := &tsvectorParser{str: strings.TrimSpace(s), pos: 0} + + for !p.atEnd() { + p.consumeSpaces() + if p.atEnd() { + break + } + + word, err := p.consumeLexeme() + if err != nil { + return TSVector{}, err + } + + entry := TSVectorLexeme{Word: word} + + // Check for optional positions after ':' + if !p.atEnd() && p.peek() == ':' { + p.consume() // skip ':' + + positions, err := p.consumePositions() + if err != nil { + return TSVector{}, err + } + entry.Positions = positions + } + + result.Lexemes = append(result.Lexemes, entry) + } + + result.Valid = true + + return result, nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/uint32.go b/vendor/github.com/jackc/pgx/v5/pgtype/uint32.go index f2b2fa6d..e6d4b1cf 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/uint32.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/uint32.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "strconv" @@ -24,16 +25,18 @@ type Uint32 struct { Valid bool } +// ScanUint32 implements the [Uint32Scanner] interface. func (n *Uint32) ScanUint32(v Uint32) error { *n = v return nil } +// Uint32Value implements the [Uint32Valuer] interface. func (n Uint32) Uint32Value() (Uint32, error) { return n, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Uint32) Scan(src any) error { if src == nil { *dst = Uint32{} @@ -67,7 +70,7 @@ func (dst *Uint32) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Uint32) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -75,6 +78,31 @@ func (src Uint32) Value() (driver.Value, error) { return int64(src.Uint32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Uint32) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return json.Marshal(src.Uint32) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Uint32) UnmarshalJSON(b []byte) error { + var n *uint32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Uint32{} + } else { + *dst = Uint32{Uint32: *n, Valid: true} + } + + return nil +} + type Uint32Codec struct{} func (Uint32Codec) FormatSupported(format int16) bool { @@ -197,7 +225,6 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value any, buf []byte) (newBu } func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/uint64.go b/vendor/github.com/jackc/pgx/v5/pgtype/uint64.go new file mode 100644 index 00000000..68fd1661 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/uint64.go @@ -0,0 +1,323 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Uint64Scanner interface { + ScanUint64(v Uint64) error +} + +type Uint64Valuer interface { + Uint64Value() (Uint64, error) +} + +// Uint64 is the core type that is used to represent PostgreSQL types such as XID8. +type Uint64 struct { + Uint64 uint64 + Valid bool +} + +// ScanUint64 implements the [Uint64Scanner] interface. +func (n *Uint64) ScanUint64(v Uint64) error { + *n = v + return nil +} + +// Uint64Value implements the [Uint64Valuer] interface. +func (n Uint64) Uint64Value() (Uint64, error) { + return n, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Uint64) Scan(src any) error { + if src == nil { + *dst = Uint64{} + return nil + } + + var n uint64 + + switch src := src.(type) { + case int64: + if src < 0 { + return fmt.Errorf("%d is less than the minimum value for Uint64", src) + } + n = uint64(src) + case string: + un, err := strconv.ParseUint(src, 10, 64) + if err != nil { + return err + } + n = un + default: + return fmt.Errorf("cannot scan %T", src) + } + + *dst = Uint64{Uint64: n, Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Uint64) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + // If the value is greater than the maximum value for int64, return it as a string instead of losing data or returning + // an error. + if src.Uint64 > math.MaxInt64 { + return strconv.FormatUint(src.Uint64, 10), nil + } + + return int64(src.Uint64), nil +} + +type Uint64Codec struct{} + +func (Uint64Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Uint64Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Uint64Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case uint64: + return encodePlanUint64CodecBinaryUint64{} + case Uint64Valuer: + return encodePlanUint64CodecBinaryUint64Valuer{} + case Int64Valuer: + return encodePlanUint64CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case uint64: + return encodePlanUint64CodecTextUint64{} + case Int64Valuer: + return encodePlanUint64CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanUint64CodecBinaryUint64 struct{} + +func (encodePlanUint64CodecBinaryUint64) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint64) + return pgio.AppendUint64(buf, v), nil +} + +type encodePlanUint64CodecBinaryUint64Valuer struct{} + +func (encodePlanUint64CodecBinaryUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint64Valuer).Uint64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return pgio.AppendUint64(buf, v.Uint64), nil +} + +type encodePlanUint64CodecBinaryInt64Valuer struct{} + +func (encodePlanUint64CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + return pgio.AppendUint64(buf, uint64(v.Int64)), nil +} + +type encodePlanUint64CodecTextUint64 struct{} + +func (encodePlanUint64CodecTextUint64) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint64) + return append(buf, strconv.FormatUint(uint64(v), 10)...), nil +} + +type encodePlanUint64CodecTextUint64Valuer struct{} + +func (encodePlanUint64CodecTextUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint64Valuer).Uint64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(v.Uint64, 10)...), nil +} + +type encodePlanUint64CodecTextInt64Valuer struct{} + +func (encodePlanUint64CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + return append(buf, strconv.FormatInt(v.Int64, 10)...), nil +} + +func (Uint64Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *uint64: + return scanPlanBinaryUint64ToUint64{} + case Uint64Scanner: + return scanPlanBinaryUint64ToUint64Scanner{} + case TextScanner: + return scanPlanBinaryUint64ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *uint64: + return scanPlanTextAnyToUint64{} + case Uint64Scanner: + return scanPlanTextAnyToUint64Scanner{} + } + } + + return nil +} + +func (c Uint64Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n uint64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return int64(n), nil +} + +func (c Uint64Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n uint64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryUint64ToUint64 struct{} + +func (scanPlanBinaryUint64ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + p := (dst).(*uint64) + *p = binary.BigEndian.Uint64(src) + + return nil +} + +type scanPlanBinaryUint64ToUint64Scanner struct{} + +func (scanPlanBinaryUint64ToUint64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint64(Uint64{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + n := binary.BigEndian.Uint64(src) + + return s.ScanUint64(Uint64{Uint64: n, Valid: true}) +} + +type scanPlanBinaryUint64ToTextScanner struct{} + +func (scanPlanBinaryUint64ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + n := uint64(binary.BigEndian.Uint64(src)) + return s.ScanText(Text{String: strconv.FormatUint(n, 10), Valid: true}) +} + +type scanPlanTextAnyToUint64Scanner struct{} + +func (scanPlanTextAnyToUint64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint64(Uint64{}) + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + return s.ScanUint64(Uint64{Uint64: n, Valid: true}) +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go b/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go index d57c0f2f..83d0c412 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go @@ -20,11 +20,13 @@ type UUID struct { Valid bool } +// ScanUUID implements the [UUIDScanner] interface. func (b *UUID) ScanUUID(v UUID) error { *b = v return nil } +// UUIDValue implements the [UUIDValuer] interface. func (b UUID) UUIDValue() (UUID, error) { return b, nil } @@ -67,7 +69,7 @@ func encodeUUID(src [16]byte) string { return string(buf[:]) } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *UUID) Scan(src any) error { if src == nil { *dst = UUID{} @@ -87,7 +89,7 @@ func (dst *UUID) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src UUID) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +98,15 @@ func (src UUID) Value() (driver.Value, error) { return encodeUUID(src.Bytes), nil } +func (src UUID) String() string { + if !src.Valid { + return "" + } + + return encodeUUID(src.Bytes) +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src UUID) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -108,6 +119,7 @@ func (src UUID) MarshalJSON() ([]byte, error) { return buff.Bytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *UUID) UnmarshalJSON(src []byte) error { if bytes.Equal(src, []byte("null")) { *dst = UUID{} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/xml.go b/vendor/github.com/jackc/pgx/v5/pgtype/xml.go index fb4c49ad..79e3698a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/xml.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/xml.go @@ -113,7 +113,7 @@ func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPl // https://github.com/jackc/pgx/issues/1691 -- ** anything else if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { - if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { + if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan diff --git a/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go b/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go index fdcba724..dac8058c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go +++ b/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go @@ -2,8 +2,8 @@ package pgxpool import ( "context" - "fmt" - "math/rand" + "errors" + "math/rand/v2" "runtime" "strconv" "sync" @@ -15,11 +15,14 @@ import ( "github.com/jackc/puddle/v2" ) -var defaultMaxConns = int32(4) -var defaultMinConns = int32(0) -var defaultMaxConnLifetime = time.Hour -var defaultMaxConnIdleTime = time.Minute * 30 -var defaultHealthCheckPeriod = time.Minute +var ( + defaultMaxConns = int32(4) + defaultMinConns = int32(0) + defaultMinIdleConns = int32(0) + defaultMaxConnLifetime = time.Hour + defaultMaxConnIdleTime = time.Minute * 30 + defaultHealthCheckPeriod = time.Minute +) type connResource struct { conn *pgx.Conn @@ -83,15 +86,21 @@ type Pool struct { config *Config beforeConnect func(context.Context, *pgx.ConnConfig) error afterConnect func(context.Context, *pgx.Conn) error - beforeAcquire func(context.Context, *pgx.Conn) bool + prepareConn func(context.Context, *pgx.Conn) (bool, error) afterRelease func(*pgx.Conn) bool beforeClose func(*pgx.Conn) + shouldPing func(context.Context, ShouldPingParams) bool minConns int32 + minIdleConns int32 maxConns int32 maxConnLifetime time.Duration maxConnLifetimeJitter time.Duration maxConnIdleTime time.Duration healthCheckPeriod time.Duration + pingTimeout time.Duration + + healthCheckMu sync.Mutex + healthCheckTimer *time.Timer healthCheckChan chan struct{} @@ -102,12 +111,18 @@ type Pool struct { closeChan chan struct{} } +// ShouldPingParams are the parameters passed to ShouldPing. +type ShouldPingParams struct { + Conn *pgx.Conn + IdleDuration time.Duration +} + // Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be // modified. type Config struct { ConnConfig *pgx.ConnConfig - // BeforeConnect is called before a new connection is made. It is passed a copy of the underlying pgx.ConnConfig and + // BeforeConnect is called before a new connection is made. It is passed a copy of the underlying [pgx.ConnConfig] and // will not impact any existing open connections. BeforeConnect func(context.Context, *pgx.ConnConfig) error @@ -117,8 +132,23 @@ type Config struct { // BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the // acquisition or false to indicate that the connection should be destroyed and a different connection should be // acquired. + // + // Deprecated: Use PrepareConn instead. If both PrepareConn and BeforeAcquire are set, PrepareConn will take + // precedence, ignoring BeforeAcquire. BeforeAcquire func(context.Context, *pgx.Conn) bool + // PrepareConn is called before a connection is acquired from the pool. If this function returns true, the connection + // is considered valid, otherwise the connection is destroyed. If the function returns a non-nil error, the instigating + // query will fail with the returned error. + // + // Specifically, this means that: + // + // - If it returns true and a nil error, the query proceeds as normal. + // - If it returns true and an error, the connection will be returned to the pool, and the instigating query will fail with the returned error. + // - If it returns false, and an error, the connection will be destroyed, and the query will fail with the returned error. + // - If it returns false and a nil error, the connection will be destroyed, and the instigating query will be retried on a new connection. + PrepareConn func(context.Context, *pgx.Conn) (bool, error) + // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to // return the connection to the pool or false to destroy the connection. AfterRelease func(*pgx.Conn) bool @@ -126,6 +156,10 @@ type Config struct { // BeforeClose is called right before a connection is closed and removed from the pool. BeforeClose func(*pgx.Conn) + // ShouldPing is called after a connection is acquired from the pool. If it returns true, the connection is pinged to check for liveness. + // If this func is not set, the default behavior is to ping connections that have been idle for at least 1 second. + ShouldPing func(context.Context, ShouldPingParams) bool + // MaxConnLifetime is the duration since creation after which a connection will be automatically closed. MaxConnLifetime time.Duration @@ -136,6 +170,10 @@ type Config struct { // MaxConnIdleTime is the duration after which an idle connection will be automatically closed by the health check. MaxConnIdleTime time.Duration + // PingTimeout is the maximum amount of time to wait for a connection to pong before considering it as unhealthy and + // destroying it. If zero, the default is no timeout. + PingTimeout time.Duration + // MaxConns is the maximum size of the pool. The default is the greater of 4 or runtime.NumCPU(). MaxConns int32 @@ -144,6 +182,13 @@ type Config struct { // to create new connections. MinConns int32 + // MinIdleConns is the minimum number of idle connections in the pool. You can increase this to ensure that + // there are always idle connections available. This can help reduce tail latencies during request processing, + // as you can avoid the latency of establishing a new connection while handling requests. It is superior + // to MinConns for this purpose. + // Similar to MinConns, the pool might temporarily dip below MinIdleConns after connection closes. + MinIdleConns int32 + // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration @@ -173,7 +218,7 @@ func New(ctx context.Context, connString string) (*Pool, error) { return NewWithConfig(ctx, config) } -// NewWithConfig creates a new Pool. config must have been created by [ParseConfig]. +// NewWithConfig creates a new [Pool]. config must have been created by [ParseConfig]. func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. @@ -181,18 +226,27 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { panic("config must be created by ParseConfig") } + prepareConn := config.PrepareConn + if prepareConn == nil && config.BeforeAcquire != nil { + prepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) { + return config.BeforeAcquire(ctx, conn), nil + } + } + p := &Pool{ config: config, beforeConnect: config.BeforeConnect, afterConnect: config.AfterConnect, - beforeAcquire: config.BeforeAcquire, + prepareConn: prepareConn, afterRelease: config.AfterRelease, beforeClose: config.BeforeClose, minConns: config.MinConns, + minIdleConns: config.MinIdleConns, maxConns: config.MaxConns, maxConnLifetime: config.MaxConnLifetime, maxConnLifetimeJitter: config.MaxConnLifetimeJitter, maxConnIdleTime: config.MaxConnIdleTime, + pingTimeout: config.PingTimeout, healthCheckPeriod: config.HealthCheckPeriod, healthCheckChan: make(chan struct{}, 1), closeChan: make(chan struct{}), @@ -206,6 +260,14 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { p.releaseTracer = t } + if config.ShouldPing != nil { + p.shouldPing = config.ShouldPing + } else { + p.shouldPing = func(ctx context.Context, params ShouldPingParams) bool { + return params.IdleDuration > time.Second + } + } + var err error p.p, err = puddle.NewPool( &puddle.Config[*connResource]{ @@ -271,7 +333,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { } go func() { - p.createIdleResources(ctx, int(p.minConns)) + targetIdleResources := max(int(p.minConns), int(p.minIdleConns)) + p.createIdleResources(ctx, targetIdleResources) p.backgroundHealthCheck() }() @@ -281,20 +344,20 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the // addition of the following variables: // -// - pool_max_conns: integer greater than 0 -// - pool_min_conns: integer 0 or greater -// - pool_max_conn_lifetime: duration string -// - pool_max_conn_idle_time: duration string -// - pool_health_check_period: duration string -// - pool_max_conn_lifetime_jitter: duration string +// - pool_max_conns: integer greater than 0 (default 4) +// - pool_min_conns: integer 0 or greater (default 0) +// - pool_max_conn_lifetime: duration string (default 1 hour) +// - pool_max_conn_idle_time: duration string (default 30 minutes) +// - pool_health_check_period: duration string (default 1 minute) +// - pool_max_conn_lifetime_jitter: duration string (default 0) // // See Config for definitions of these arguments. // // # Example Keyword/Value -// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 pool_max_conn_lifetime=1h30m // // # Example URL -// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10 +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10&pool_max_conn_lifetime=1h30m func ParseConfig(connString string) (*Config, error) { connConfig, err := pgx.ParseConfig(connString) if err != nil { @@ -310,10 +373,10 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conns") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conns", err) } if n < 1 { - return nil, fmt.Errorf("pool_max_conns too small: %d", n) + return nil, pgconn.NewParseConfigError(connString, "pool_max_conns too small", err) } config.MaxConns = int32(n) } else { @@ -327,18 +390,29 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_min_conns") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_conns", err) } config.MinConns = int32(n) } else { config.MinConns = defaultMinConns } + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_idle_conns"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_idle_conns", err) + } + config.MinIdleConns = int32(n) + } else { + config.MinIdleConns = defaultMinIdleConns + } + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok { delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime", err) } config.MaxConnLifetime = d } else { @@ -349,7 +423,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_idle_time", err) } config.MaxConnIdleTime = d } else { @@ -360,7 +434,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_health_check_period") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_health_check_period: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_health_check_period", err) } config.HealthCheckPeriod = d } else { @@ -371,7 +445,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime_jitter", err) } config.MaxConnLifetimeJitter = d } @@ -379,7 +453,7 @@ func ParseConfig(connString string) (*Config, error) { return config, nil } -// Close closes all connections in the pool and rejects future Acquire calls. Blocks until all connections are returned +// Close closes all connections in the pool and rejects future [Pool.Acquire] calls. Blocks until all connections are returned // to pool and closed. func (p *Pool) Close() { p.closeOnce.Do(func() { @@ -393,15 +467,25 @@ func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool { } func (p *Pool) triggerHealthCheck() { - go func() { + const healthCheckDelay = 500 * time.Millisecond + + p.healthCheckMu.Lock() + defer p.healthCheckMu.Unlock() + + if p.healthCheckTimer == nil { // Destroy is asynchronous so we give it time to actually remove itself from // the pool otherwise we might try to check the pool size too soon - time.Sleep(500 * time.Millisecond) - select { - case p.healthCheckChan <- struct{}{}: - default: - } - }() + p.healthCheckTimer = time.AfterFunc(healthCheckDelay, func() { + select { + case <-p.closeChan: + case p.healthCheckChan <- struct{}{}: + default: + } + }) + return + } + + p.healthCheckTimer.Reset(healthCheckDelay) } func (p *Pool) backgroundHealthCheck() { @@ -472,7 +556,10 @@ func (p *Pool) checkMinConns() error { // TotalConns can include ones that are being destroyed but we should have // sleep(500ms) around all of the destroys to help prevent that from throwing // off this check - toCreate := p.minConns - p.Stat().TotalConns() + + // Create the number of connections needed to get to both minConns and minIdleConns + stat := p.Stat() + toCreate := max(p.minConns-stat.TotalConns(), p.minIdleConns-stat.IdleConns()) if toCreate > 0 { return p.createIdleResources(context.Background(), int(toCreate)) } @@ -485,7 +572,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in errs := make(chan error, targetResources) - for i := 0; i < targetResources; i++ { + for range targetResources { go func() { err := p.p.CreateResource(ctx) // Ignore ErrNotAvailable since it means that the pool has become full since we started creating resource. @@ -497,7 +584,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in } var firstError error - for i := 0; i < targetResources; i++ { + for range targetResources { err := <-errs if err != nil && firstError == nil { cancel() @@ -508,7 +595,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in return firstError } -// Acquire returns a connection (*Conn) from the Pool +// Acquire returns a connection ([Conn]) from the [Pool]. func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { if p.acquireTracer != nil { ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{}) @@ -521,7 +608,10 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { }() } - for { + // Try to acquire from the connection pool up to maxConns + 1 times, so that + // any that fatal errors would empty the pool and still at least try 1 fresh + // connection. + for range int(p.maxConns) + 1 { res, err := p.p.Acquire(ctx) if err != nil { return nil, err @@ -529,24 +619,46 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { cr := res.Value() - if res.IdleDuration() > time.Second { - err := cr.conn.Ping(ctx) + shouldPingParams := ShouldPingParams{Conn: cr.conn, IdleDuration: res.IdleDuration()} + if p.shouldPing(ctx, shouldPingParams) { + err := func() error { + pingCtx := ctx + if p.pingTimeout > 0 { + var cancel context.CancelFunc + pingCtx, cancel = context.WithTimeout(ctx, p.pingTimeout) + defer cancel() + } + return cr.conn.Ping(pingCtx) + }() if err != nil { res.Destroy() continue } } - if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { - return cr.getConn(p, res), nil + if p.prepareConn != nil { + ok, err := p.prepareConn(ctx, cr.conn) + if !ok { + res.Destroy() + } + if err != nil { + if ok { + res.Release() + } + return nil, err + } + if !ok { + continue + } } - res.Destroy() + return cr.getConn(p, res), nil } + return nil, errors.New("pgxpool: too many failed attempts acquiring connection; likely bug in PrepareConn, BeforeAcquire, or ShouldPing hook") } -// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the -// call of f. The return value is either an error acquiring the *Conn or the return value of f. The *Conn is +// AcquireFunc acquires a [Conn] and calls f with that [Conn]. ctx will only affect the [Pool.Acquire]. It has no effect on the +// call of f. The return value is either an error acquiring the [Conn] or the return value of f. The [Conn] is // automatically released after the call of f. func (p *Pool) AcquireFunc(ctx context.Context, f func(*Conn) error) error { conn, err := p.Acquire(ctx) @@ -565,11 +677,14 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { conns := make([]*Conn, 0, len(resources)) for _, res := range resources { cr := res.Value() - if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { - conns = append(conns, cr.getConn(p, res)) - } else { - res.Destroy() + if p.prepareConn != nil { + ok, err := p.prepareConn(ctx, cr.conn) + if !ok || err != nil { + res.Destroy() + continue + } } + conns = append(conns, cr.getConn(p, res)) } return conns @@ -584,7 +699,7 @@ func (p *Pool) Reset() { p.p.Reset() } -// Config returns a copy of config that was used to initialize this pool. +// Config returns a copy of config that was used to initialize this [Pool]. func (p *Pool) Config() *Config { return p.config.Copy() } // Stat returns a pgxpool.Stat struct with a snapshot of Pool statistics. @@ -597,10 +712,10 @@ func (p *Pool) Stat() *Stat { } } -// Exec acquires a connection from the Pool and executes the given SQL. +// Exec acquires a connection from the [Pool] and executes the given SQL. // SQL can be either a prepared statement name or an SQL string. // Arguments should be referenced positionally from the SQL string as $1, $2, etc. -// The acquired connection is returned to the pool when the Exec function returns. +// The acquired connection is returned to the pool when the [Pool.Exec] function returns. func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { @@ -611,15 +726,15 @@ func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C return c.Exec(ctx, sql, arguments...) } -// Query acquires a connection and executes a query that returns pgx.Rows. +// Query acquires a connection and executes a query that returns [pgx.Rows]. // Arguments should be referenced positionally from the SQL string as $1, $2, etc. -// See pgx.Rows documentation to close the returned Rows and return the acquired connection to the Pool. +// See [pgx.Rows] documentation to close the returned [pgx.Rows] and return the acquired connection to the [Pool]. // -// If there is an error, the returned pgx.Rows will be returned in an error state. -// If preferred, ignore the error returned from Query and handle errors using the returned pgx.Rows. +// If there is an error, the returned [pgx.Rows] will be returned in an error state. +// If preferred, ignore the error returned from [Pool.Query] and handle errors using the returned [pgx.Rows]. // -// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and -// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// For extra control over how the query is executed, the types [pgx.QueryExecMode], [pgx.QueryResultFormats], and +// [pgx.QueryResultFormatsByOID] may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { c, err := p.Acquire(ctx) @@ -637,16 +752,16 @@ func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, er } // QueryRow acquires a connection and executes a query that is expected -// to return at most one row (pgx.Row). Errors are deferred until pgx.Row's -// Scan method is called. If the query selects no rows, pgx.Row's Scan will -// return ErrNoRows. Otherwise, pgx.Row's Scan scans the first selected row -// and discards the rest. The acquired connection is returned to the Pool when -// pgx.Row's Scan method is called. +// to return at most one row ([pgx.Row]). Errors are deferred until [pgx.Row]'s +// Scan method is called. If the query selects no rows, [pgx.Row]'s Scan will +// return [pgx.ErrNoRows]. Otherwise, [pgx.Row]'s Scan scans the first selected row +// and discards the rest. The acquired connection is returned to the [Pool] when +// [pgx.Row]'s Scan method is called. // // Arguments should be referenced positionally from the SQL string as $1, $2, etc. // -// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and -// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// For extra control over how the query is executed, the types [pgx.QueryExecMode], [pgx.QueryResultFormats], and +// [pgx.QueryResultFormatsByOID] may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (p *Pool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { c, err := p.Acquire(ctx) @@ -668,18 +783,18 @@ func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return &poolBatchResults{br: br, c: c} } -// Begin acquires a connection from the Pool and starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no -// auto-rollback on context cancellation. Begin initiates a transaction block without explicitly setting a transaction mode for the block (see BeginTx with TxOptions if transaction mode is required). -// *pgxpool.Tx is returned, which implements the pgx.Tx interface. -// Commit or Rollback must be called on the returned transaction to finalize the transaction block. +// Begin acquires a connection from the [Pool] and starts a transaction. Unlike [database/sql], the context only affects the begin command. i.e. there is no +// auto-rollback on context cancellation. Begin initiates a transaction block without explicitly setting a transaction mode for the block (see [Pool.BeginTx] with [pgx.TxOptions] if transaction mode is required). +// [*Tx] is returned, which implements the [pgx.Tx] interface. +// [Tx.Commit] or [Tx.Rollback] must be called on the returned transaction to finalize the transaction block. func (p *Pool) Begin(ctx context.Context) (pgx.Tx, error) { return p.BeginTx(ctx, pgx.TxOptions{}) } -// BeginTx acquires a connection from the Pool and starts a transaction with pgx.TxOptions determining the transaction mode. -// Unlike database/sql, the context only affects the begin command. i.e. there is no auto-rollback on context cancellation. -// *pgxpool.Tx is returned, which implements the pgx.Tx interface. -// Commit or Rollback must be called on the returned transaction to finalize the transaction block. +// BeginTx acquires a connection from the [Pool] and starts a transaction with [pgx.TxOptions] determining the transaction mode. +// Unlike [database/sql], the context only affects the begin command. i.e. there is no auto-rollback on context cancellation. +// [*Tx] is returned, which implements the [pgx.Tx] interface. +// [Tx.Commit] or [Tx.Rollback] must be called on the returned transaction to finalize the transaction block. func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { c, err := p.Acquire(ctx) if err != nil { @@ -705,8 +820,8 @@ func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNam return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) } -// Ping acquires a connection from the Pool and executes an empty sql statement against it. -// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned. +// Ping acquires a connection from the [Pool] and executes an empty sql statement against it. +// If the sql returns without error, the database [Pool.Ping] is considered successful, otherwise, the error is returned. func (p *Pool) Ping(ctx context.Context) error { c, err := p.Acquire(ctx) if err != nil { diff --git a/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go b/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go index cfa0c4c5..e02b6ac3 100644 --- a/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go +++ b/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go @@ -82,3 +82,10 @@ func (s *Stat) MaxLifetimeDestroyCount() int64 { func (s *Stat) MaxIdleDestroyCount() int64 { return s.idleDestroyCount } + +// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires +// from the pool for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireWaitTime() time.Duration { + return s.s.EmptyAcquireWaitTime() +} diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go index f23625d4..2c5d2424 100644 --- a/vendor/github.com/jackc/pgx/v5/rows.go +++ b/vendor/github.com/jackc/pgx/v5/rows.go @@ -13,12 +13,12 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -// Rows is the result set returned from *Conn.Query. Rows must be closed before -// the *Conn can be used again. Rows are closed by explicitly calling Close(), -// calling Next() until it returns false, or when a fatal error occurs. +// Rows is the result set returned from [Conn.Query]. Rows must be closed before +// the [Conn] can be used again. Rows are closed by explicitly calling [Rows.Close], +// calling [Rows.Next] until it returns false, or when a fatal error occurs. // -// Once a Rows is closed the only methods that may be called are Close(), Err(), -// and CommandTag(). +// Once a Rows is closed the only methods that may be called are [Rows.Close], [Rows.Err], +// and [Rows.CommandTag]. // // Rows is an interface instead of a struct to allow tests to mock Query. However, // adding a method to an interface is technically a breaking change. Because of this @@ -29,9 +29,9 @@ type Rows interface { // to call Close after rows is already closed. Close() - // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by - // calling Close or by Next returning false). If it is called early it may return nil even if there was an error - // executing the query. + // Err returns any error that occurred while executing a query or reading its results. Err must be called after the + // Rows is closed (either by calling Close or by Next returning false) to check if the query was successful. If it is + // called before the Rows is closed it may return nil even if the query failed on the server. Err() error // CommandTag returns the command tag from this query. It is only available after Rows is closed. @@ -41,22 +41,19 @@ type Rows interface { // when there was an error executing the query. FieldDescriptions() []pgconn.FieldDescription - // Next prepares the next row for reading. It returns true if there is another - // row and false if no more rows are available or a fatal error has occurred. - // It automatically closes rows when all rows are read. + // Next prepares the next row for reading. It returns true if there is another row and false if no more rows are + // available or a fatal error has occurred. It automatically closes rows upon returning false (whether due to all rows + // having been read or due to an error). // - // Callers should check rows.Err() after rows.Next() returns false to detect - // whether result-set reading ended prematurely due to an error. See - // Conn.Query for details. + // Callers should check rows.Err() after rows.Next() returns false to detect whether result-set reading ended + // prematurely due to an error. See [Conn.Query] for details. // - // For simpler error handling, consider using the higher-level pgx v5 - // CollectRows() and ForEachRow() helpers instead. + // For simpler error handling, consider using the higher-level pgx v5 [CollectRows()] and [ForEachRow()] helpers instead. Next() bool - // Scan reads the values from the current row into dest values positionally. - // dest can include pointers to core types, values implementing the Scanner - // interface, and nil. nil will skip the value entirely. It is an error to - // call Scan without first calling Next() and checking that it returned true. + // Scan reads the values from the current row into dest values positionally. dest can include pointers to core types, + // values implementing the Scanner interface, and nil. nil will skip the value entirely. It is an error to call Scan + // without first calling Next() and checking that it returned true. Rows is automatically closed upon error. Scan(dest ...any) error // Values returns the decoded row values. As with Scan(), it is an error to @@ -73,7 +70,7 @@ type Rows interface { Conn() *Conn } -// Row is a convenience wrapper over Rows that is returned by QueryRow. +// Row is a convenience wrapper over [Rows] that is returned by [Conn.QueryRow]. // // Row is an interface instead of a struct to allow tests to mock QueryRow. However, // adding a method to an interface is technically a breaking change. Because of this @@ -188,6 +185,17 @@ func (rows *baseRows) Close() { } else if rows.queryTracer != nil { rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) } + + // Zero references to other memory allocations. This allows them to be GC'd even when the Rows still referenced. In + // particular, when using pgxpool GC could be delayed as pgxpool.poolRows are allocated in large slices. + // + // https://github.com/jackc/pgx/pull/2269 + rows.values = nil + rows.scanPlans = nil + rows.scanTypes = nil + rows.ctx = nil + rows.sql = "" + rows.args = nil } func (rows *baseRows) CommandTag() pgconn.CommandTag { @@ -272,7 +280,7 @@ func (rows *baseRows) Scan(dest ...any) error { err := rows.scanPlans[i].Scan(values[i], dst) if err != nil { - err = ScanArgError{ColumnIndex: i, Err: err} + err = ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err} rows.fatal(err) return err } @@ -334,18 +342,23 @@ func (rows *baseRows) Conn() *Conn { type ScanArgError struct { ColumnIndex int + FieldName string Err error } func (e ScanArgError) Error() string { - return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) + if e.FieldName == "?column?" { // Don't include the fieldname if it's unknown + return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) + } + + return fmt.Sprintf("can't scan into dest[%d] (col: %s): %v", e.ColumnIndex, e.FieldName, e.Err) } func (e ScanArgError) Unwrap() error { return e.Err } -// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. +// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level [pgconn] interface. // // typeMap - OID to Go type mapping. // fieldDescriptions - OID and format of values @@ -366,15 +379,15 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, v err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) if err != nil { - return ScanArgError{ColumnIndex: i, Err: err} + return ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err} } } return nil } -// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used -// to read from the lower level pgconn interface. +// RowsFromResultReader returns a [Rows] that will read from values resultReader and decode with typeMap. It can be used +// to read from the lower level [pgconn] interface. func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows { return &baseRows{ typeMap: typeMap, @@ -447,7 +460,7 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { } // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. -// CollectOneRow is to CollectRows as QueryRow is to Query. +// CollectOneRow is to [CollectRows] as [Conn.QueryRow] is to [Conn.Query]. // // This function closes the rows automatically on return. func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { @@ -468,6 +481,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { return value, err } + // The defer rows.Close() won't have executed yet. If the query returned more than one row, rows would still be open. + // rows.Close() must be called before rows.Err() so we explicitly call it here. rows.Close() return value, rows.Err() } @@ -514,7 +529,7 @@ func RowTo[T any](row CollectableRow) (T, error) { return value, err } -// RowTo returns a the address of a T scanned from row. +// RowToAddrOf returns the address of a T scanned from row. func RowToAddrOf[T any](row CollectableRow) (*T, error) { var value T err := row.Scan(&value) @@ -545,7 +560,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error { return nil } -// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row +// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number of public fields as row // has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be // ignored. func RowToStructByPos[T any](row CollectableRow) (T, error) { @@ -833,7 +848,7 @@ func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, normalize } } } - return + return i } // structRowField describes a field of a struct. diff --git a/vendor/github.com/jackc/pgx/v5/test.sh b/vendor/github.com/jackc/pgx/v5/test.sh new file mode 100644 index 00000000..8bab2d28 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/test.sh @@ -0,0 +1,170 @@ +#!/usr/bin/env bash +set -euo pipefail + +# test.sh - Run pgx tests against specific database targets +# +# Usage: +# ./test.sh [target] [go test flags...] +# +# Targets: +# pg14 - PostgreSQL 14 (port 5414) +# pg15 - PostgreSQL 15 (port 5415) +# pg16 - PostgreSQL 16 (port 5416) +# pg17 - PostgreSQL 17 (port 5417) +# pg18 - PostgreSQL 18 (port 5432) [default] +# crdb - CockroachDB (port 26257) +# all - Run against all targets sequentially +# +# Examples: +# ./test.sh # Test against PG18 +# ./test.sh pg14 # Test against PG14 +# ./test.sh crdb # Test against CockroachDB +# ./test.sh all # Test against all targets +# ./test.sh pg16 -run TestConnect # Test specific test against PG16 +# ./test.sh pg18 -count=1 -v # Verbose, no cache, PG18 + +# Color output (disabled if not a terminal) +if [ -t 1 ]; then + GREEN='\033[0;32m' + RED='\033[0;31m' + BLUE='\033[0;34m' + NC='\033[0m' +else + GREEN='' + RED='' + BLUE='' + NC='' +fi + +log_info() { echo -e "${BLUE}==> $*${NC}"; } +log_ok() { echo -e "${GREEN}==> $*${NC}"; } +log_err() { echo -e "${RED}==> $*${NC}" >&2; } + +# Wait for a database to accept connections +wait_for_ready() { + local connstr="$1" + local label="$2" + local max_attempts=30 + local attempt=0 + + log_info "Waiting for $label to be ready..." + while ! psql "$connstr" -c "SELECT 1" > /dev/null 2>&1; do + attempt=$((attempt + 1)) + if [ "$attempt" -ge "$max_attempts" ]; then + log_err "$label did not become ready after $max_attempts attempts" + return 1 + fi + sleep 1 + done + log_ok "$label is ready" +} + +# Directory containing this script (used to locate testsetup/) +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CERTS_DIR="$SCRIPT_DIR/testsetup/certs" + +# Copy client certificates to /tmp for TLS tests +setup_client_certs() { + if [ -d "$CERTS_DIR" ]; then + base64 -d "$CERTS_DIR/ca.pem.b64" > /tmp/ca.pem + base64 -d "$CERTS_DIR/pgx_sslcert.crt.b64" > /tmp/pgx_sslcert.crt + base64 -d "$CERTS_DIR/pgx_sslcert.key.b64" > /tmp/pgx_sslcert.key + fi +} + +# Initialize CockroachDB (create database if not exists) +init_crdb() { + local connstr="postgresql://root@localhost:26257/?sslmode=disable" + wait_for_ready "$connstr" "CockroachDB" + log_info "Ensuring pgx_test database exists on CockroachDB..." + psql "$connstr" -c "CREATE DATABASE IF NOT EXISTS pgx_test" 2>/dev/null || true +} + +# Run tests against a single target +run_tests() { + local target="$1" + shift + local extra_args=("$@") + + local label="" + local port="" + + case "$target" in + pg14) label="PostgreSQL 14"; port=5414 ;; + pg15) label="PostgreSQL 15"; port=5415 ;; + pg16) label="PostgreSQL 16"; port=5416 ;; + pg17) label="PostgreSQL 17"; port=5417 ;; + pg18) label="PostgreSQL 18"; port=5432 ;; + crdb) + label="CockroachDB (port 26257)" + init_crdb + log_info "Testing against $label" + if ! PGX_TEST_DATABASE="postgresql://root@localhost:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" \ + go test -count=1 "${extra_args[@]}" ./...; then + log_err "Tests FAILED against $label" + return 1 + fi + log_ok "Tests passed against $label" + return 0 + ;; + *) + log_err "Unknown target: $target" + log_err "Valid targets: pg14, pg15, pg16, pg17, pg18, crdb, all" + return 1 + ;; + esac + + setup_client_certs + + log_info "Testing against $label (port $port)" + if ! PGX_TEST_DATABASE="host=localhost port=$port user=postgres password=postgres dbname=pgx_test" \ + PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql port=$port user=postgres dbname=pgx_test" \ + PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \ + PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \ + PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" \ + PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" \ + PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_pw password=secret dbname=pgx_test" \ + PGX_TEST_TLS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" \ + PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost port=$port user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" \ + PGX_SSL_PASSWORD=certpw \ + go test -count=1 "${extra_args[@]}" ./...; then + log_err "Tests FAILED against $label" + return 1 + fi + log_ok "Tests passed against $label" +} + +# Main +main() { + local target="${1:-pg18}" + + if [ "$target" = "all" ]; then + shift || true + local targets=(pg14 pg15 pg16 pg17 pg18 crdb) + local failed=() + + for t in "${targets[@]}"; do + echo "" + log_info "==========================================" + log_info "Target: $t" + log_info "==========================================" + if ! run_tests "$t" "$@"; then + failed+=("$t") + log_err "FAILED: $t" + fi + done + + echo "" + if [ ${#failed[@]} -gt 0 ]; then + log_err "Failed targets: ${failed[*]}" + return 1 + else + log_ok "All targets passed" + fi + else + shift || true + run_tests "$target" "$@" + fi +} + +main "$@" diff --git a/vendor/github.com/jackc/pgx/v5/tx.go b/vendor/github.com/jackc/pgx/v5/tx.go index 8feeb512..3f93a6f2 100644 --- a/vendor/github.com/jackc/pgx/v5/tx.go +++ b/vendor/github.com/jackc/pgx/v5/tx.go @@ -3,7 +3,6 @@ package pgx import ( "context" "errors" - "fmt" "strconv" "strings" @@ -48,6 +47,8 @@ type TxOptions struct { // BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax // such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings. BeginQuery string + // CommitQuery is the SQL query that will be executed to commit the transaction. + CommitQuery string } var emptyTxOptions TxOptions @@ -88,24 +89,27 @@ var ErrTxClosed = errors.New("tx is closed") // it is treated as ROLLBACK. var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") -// Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no +// Begin starts a transaction. Unlike [database/sql], the context only affects the begin command. i.e. there is no // auto-rollback on context cancellation. func (c *Conn) Begin(ctx context.Context) (Tx, error) { return c.BeginTx(ctx, TxOptions{}) } -// BeginTx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only +// BeginTx starts a transaction with txOptions determining the transaction mode. Unlike [database/sql], the context only // affects the begin command. i.e. there is no auto-rollback on context cancellation. func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { _, err := c.Exec(ctx, txOptions.beginSQL()) if err != nil { // begin should never fail unless there is an underlying connection issue or // a context timeout. In either case, the connection is possibly broken. - c.die(errors.New("failed to begin transaction")) + c.die() return nil, err } - return &dbTx{conn: c}, nil + return &dbTx{ + conn: c, + commitQuery: txOptions.CommitQuery, + }, nil } // Tx represents a database transaction. @@ -154,6 +158,7 @@ type dbTx struct { conn *Conn savepointNum int64 closed bool + commitQuery string } // Begin starts a pseudo nested transaction implemented with a savepoint. @@ -177,7 +182,12 @@ func (tx *dbTx) Commit(ctx context.Context) error { return ErrTxClosed } - commandTag, err := tx.conn.Exec(ctx, "commit") + commandSQL := "commit" + if tx.commitQuery != "" { + commandSQL = tx.commitQuery + } + + commandTag, err := tx.conn.Exec(ctx, commandSQL) tx.closed = true if err != nil { if tx.conn.PgConn().TxStatus() != 'I' { @@ -205,7 +215,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error { tx.closed = true if err != nil { // A rollback failure leaves the connection in an undefined state - tx.conn.die(fmt.Errorf("rollback failed: %w", err)) + tx.conn.die() return err } @@ -375,8 +385,8 @@ func (sp *dbSimulatedNestedTx) Conn() *Conn { return sp.tx.Conn() } -// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn -// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls [Tx.Commit] on db. If fn +// returns an error it calls [Tx.Rollback] on db. The context will be used when executing the transaction control statements // (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. func BeginFunc( ctx context.Context, @@ -394,8 +404,8 @@ func BeginFunc( return beginFuncExec(ctx, tx, fn) } -// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn -// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls [Tx.Commit] on db. If fn +// returns an error it calls [Tx.Rollback] on db. The context will be used when executing the transaction control statements // (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. func BeginTxFunc( ctx context.Context, diff --git a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go deleted file mode 100644 index 28cd99c7..00000000 --- a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -/* -Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC -2898 / PKCS #5 v2.0. - -A key derivation function is useful when encrypting data based on a password -or any other not-fully-random data. It uses a pseudorandom function to derive -a secure encryption key based on the password. - -While v2.0 of the standard defines only one pseudorandom function to use, -HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved -Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To -choose, you can pass the `New` functions from the different SHA packages to -pbkdf2.Key. -*/ -package pbkdf2 - -import ( - "crypto/hmac" - "hash" -) - -// Key derives a key from the password, salt and iteration count, returning a -// []byte of length keylen that can be used as cryptographic key. The key is -// derived based on the method described as PBKDF2 with the HMAC variant using -// the supplied hash function. -// -// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you -// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by -// doing: -// -// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New) -// -// Remember to get a good random salt. At least 8 bytes is recommended by the -// RFC. -// -// Using a higher iteration count will increase the cost of an exhaustive -// search but will also make derivation proportionally slower. -func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { - prf := hmac.New(h, password) - hashLen := prf.Size() - numBlocks := (keyLen + hashLen - 1) / hashLen - - var buf [4]byte - dk := make([]byte, 0, numBlocks*hashLen) - U := make([]byte, hashLen) - for block := 1; block <= numBlocks; block++ { - // N.B.: || means concatenation, ^ means XOR - // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter - // U_1 = PRF(password, salt || uint(i)) - prf.Reset() - prf.Write(salt) - buf[0] = byte(block >> 24) - buf[1] = byte(block >> 16) - buf[2] = byte(block >> 8) - buf[3] = byte(block) - prf.Write(buf[:4]) - dk = prf.Sum(dk) - T := dk[len(dk)-hashLen:] - copy(U, T) - - // U_n = PRF(password, U_(n-1)) - for n := 2; n <= iter; n++ { - prf.Reset() - prf.Write(U) - U = U[:0] - U = prf.Sum(U) - for x := range U { - T[x] ^= U[x] - } - } - } - return dk[:keyLen] -} diff --git a/vendor/modules.txt b/vendor/modules.txt index 3506a2ad..af68efc7 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -148,8 +148,8 @@ github.com/jackc/pgpassfile # github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 ## explicit; go 1.14 github.com/jackc/pgservicefile -# github.com/jackc/pgx/v5 v5.7.1 -## explicit; go 1.21 +# github.com/jackc/pgx/v5 v5.9.2 +## explicit; go 1.25.0 github.com/jackc/pgx/v5 github.com/jackc/pgx/v5/internal/iobufpool github.com/jackc/pgx/v5/internal/pgio @@ -335,7 +335,6 @@ golang.org/x/crypto/chacha20 golang.org/x/crypto/curve25519 golang.org/x/crypto/internal/alias golang.org/x/crypto/internal/poly1305 -golang.org/x/crypto/pbkdf2 golang.org/x/crypto/ssh golang.org/x/crypto/ssh/internal/bcrypt_pbkdf # golang.org/x/sync v0.18.0