-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathpg.go
More file actions
98 lines (80 loc) · 2.85 KB
/
pg.go
File metadata and controls
98 lines (80 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package pg
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/specterops/dawgs"
"github.com/specterops/dawgs/cypher/models/pgsql"
"github.com/specterops/dawgs/drivers"
"github.com/specterops/dawgs/graph"
)
const (
	// DriverName is the name under which init registers this driver with dawgs.
	DriverName = "pg"

	// defaultBatchWriteSize (2,000) trades off the cost of thousands of round-trips against the cost of holding
	// table locks for too long during batched writes.
	defaultBatchWriteSize = 2_000

	// poolInitConnectionTimeout is the lifetime of the context NewPool hands to pgxpool when creating the pool.
	poolInitConnectionTimeout = 10 * time.Second
)
// afterPooledConnectionEstablished is installed as the pool's AfterConnect hook. For every entry in
// pgsql.CompositeTypes it loads the type definition from the server and registers it in the new connection's type
// map, so pgx can decode those composite values into Go structs.
//
// A type the server reports as missing (StateObjectDoesNotExist) is skipped without error. Any other lookup failure
// is wrapped and returned, which fails the connection.
func afterPooledConnectionEstablished(ctx context.Context, conn *pgx.Conn) error {
	for _, compositeType := range pgsql.CompositeTypes {
		definition, err := conn.LoadType(ctx, compositeType.String())

		switch {
		case err == nil:
			conn.TypeMap().RegisterType(definition)

		case StateObjectDoesNotExist.ErrorMatches(err):
			// Tolerated here. Such a connection is later refused re-pooling by the AfterRelease hook because the type
			// is absent from its type map. NOTE(review): presumably this covers a schema that is not yet created —
			// confirm.

		default:
			return fmt.Errorf("failed to match composite type %s to database: %w", compositeType, err)
		}
	}

	return nil
}
// afterPooledConnectionRelease is installed as the pool's AfterRelease hook. It reports whether a connection being
// released may go back into the pool: every entry in pgsql.CompositeTypes must be registered in the connection's
// type map.
//
// On the first missing type it logs a warning and returns false, which tells pgxpool to destroy the connection
// instead of re-pooling it.
func afterPooledConnectionRelease(conn *pgx.Conn) bool {
	typeMap := conn.TypeMap()

	for _, compositeType := range pgsql.CompositeTypes {
		if _, registered := typeMap.TypeForName(compositeType.String()); registered {
			continue
		}

		// Without the schema's composite type definitions this connection cannot marshal those values, so it must
		// not be handed out again.
		slog.Warn(fmt.Sprintf("Unable to find expected data type: %s. This database connection will not be pooled.", compositeType))
		return false
	}

	return true
}
// NewPool parses cfg's PostgreSQL connection string and builds a pgx connection pool from it.
//
// The returned pool is configured with:
//   - fixed limits of 5 minimum and 50 maximum connections. These overwrite whatever pgxpool.ParseConfig derived,
//     including any pool_min_conns / pool_max_conns in the connection string.
//   - AfterConnect / AfterRelease hooks that register the schema's composite types on each connection and refuse
//     to re-pool connections that lack them. Without composite type registration, pgx cannot marshal those PG OIDs
//     to their respective Go structs.
//   - when cfg.EnableRDSIAMAuth is set, a BeforeConnect hook that takes the password from a freshly built
//     connection string for every new connection.
//
// Errors are wrapped with the step that failed. The underlying pgx error remains reachable via errors.Is / errors.As.
func NewPool(cfg drivers.DatabaseConfiguration) (*pgxpool.Pool, error) {
	poolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString())
	if err != nil {
		return nil, fmt.Errorf("parsing postgresql connection string: %w", err)
	}

	// TODO: Min and Max connections for the pool should be configurable
	poolCfg.MinConns = 5
	poolCfg.MaxConns = 50

	// Bind functions to the AfterConnect and AfterRelease hooks to ensure that composite type registration occurs.
	poolCfg.AfterConnect = afterPooledConnectionEstablished
	poolCfg.AfterRelease = afterPooledConnectionRelease

	if cfg.EnableRDSIAMAuth {
		// Only enable the BeforeConnect handler if RDS IAM Auth is enabled. The connection string is rebuilt for each
		// new connection so the current credential is used. NOTE(review): presumably PostgreSQLConnectionString mints
		// a short-lived IAM auth token here — confirm.
		poolCfg.BeforeConnect = func(ctx context.Context, connCfg *pgx.ConnConfig) error {
			refreshedCfg, parseErr := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString())
			if parseErr != nil {
				return fmt.Errorf("refreshing postgresql credentials for RDS IAM auth: %w", parseErr)
			}

			connCfg.Password = refreshedCfg.ConnConfig.Password
			return nil
		}
	}

	poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout)

	pool, err := pgxpool.NewWithConfig(poolCtx, poolCfg)
	if err != nil {
		done()
		return nil, fmt.Errorf("creating postgresql connection pool: %w", err)
	}

	// In pgx v5, NewWithConfig hands its ctx to the background goroutine that opens the initial MinConns
	// connections. Cancelling poolCtx as soon as NewPool returns (the former `defer done()`) would abort that
	// warm-up. Instead, the context stays alive for the full poolInitConnectionTimeout and is then released.
	time.AfterFunc(poolInitConnectionTimeout, done)

	return pool, nil
}
// init registers a constructor for this driver with dawgs under DriverName. The constructor builds the database
// with NewDriver from the config's GraphQueryMemoryLimit and Pool, and never returns an error.
func init() {
	constructor := func(_ context.Context, cfg dawgs.Config) (graph.Database, error) {
		return NewDriver(cfg.GraphQueryMemoryLimit, cfg.Pool), nil
	}

	dawgs.Register(DriverName, constructor)
}