@@ -20,6 +20,7 @@ import (
2020 "context"
2121 "flag"
2222 "fmt"
23+ "net/url"
2324 "os"
2425 "sort"
2526 "strings"
@@ -31,16 +32,30 @@ import (
3132 "github.com/specterops/dawgs/opengraph"
3233 "github.com/specterops/dawgs/util/size"
3334
34- // Register drivers
35- _ "github.com/specterops/dawgs/drivers/neo4j"
35+ "github.com/specterops/dawgs/drivers/neo4j"
3636)
3737
3838var (
39- driverFlag = flag .String ("driver" , "pg" , "database driver to test against (pg, neo4j)" )
40- connStrFlag = flag .String ("connection" , "" , "database connection string (overrides PG_CONNECTION_STRING env var)" )
4139 localDatasetFlag = flag .String ("local-dataset" , "" , "name of a local dataset to test (e.g. local/phantom)" )
4240)
4341
42+ // driverFromConnStr returns the dawgs driver name based on the connection string scheme.
43+ func driverFromConnStr (connStr string ) (string , error ) {
44+ u , err := url .Parse (connStr )
45+ if err != nil {
46+ return "" , fmt .Errorf ("failed to parse connection string: %w" , err )
47+ }
48+
49+ switch u .Scheme {
50+ case "postgresql" , "postgres" :
51+ return pg .DriverName , nil
52+ case neo4j .DriverName , "neo4j+s" , "neo4j+ssc" :
53+ return neo4j .DriverName , nil
54+ default :
55+ return "" , fmt .Errorf ("unknown connection string scheme %q" , u .Scheme )
56+ }
57+ }
58+
4459// SetupDB opens a database connection for the selected driver, asserts a schema
4560// derived from the given datasets, and registers cleanup. Returns the database
4661// and a background context.
@@ -49,29 +64,30 @@ func SetupDB(t *testing.T, datasets ...string) (graph.Database, context.Context)
4964
5065 ctx := context .Background ()
5166
52- connStr := * connStrFlag
67+ connStr := os . Getenv ( "CONNECTION_STRING" )
5368 if connStr == "" {
54- connStr = os . Getenv ( "PG_CONNECTION_STRING " )
69+ t . Fatal ( "CONNECTION_STRING env var is not set " )
5570 }
56- if connStr == "" {
57- t .Fatal ("no connection string: set -connection flag or PG_CONNECTION_STRING env var" )
71+
72+ driver , err := driverFromConnStr (connStr )
73+ if err != nil {
74+ t .Fatalf ("Failed to detect driver: %v" , err )
5875 }
5976
6077 cfg := dawgs.Config {
6178 GraphQueryMemoryLimit : size .Gibibyte ,
6279 ConnectionString : connStr ,
6380 }
6481
65- // PG needs a pool with composite type registration
66- if * driverFlag == pg .DriverName {
82+ if driver == pg .DriverName {
6783 pool , err := pg .NewPool (connStr )
6884 if err != nil {
6985 t .Fatalf ("Failed to create PG pool: %v" , err )
7086 }
7187 cfg .Pool = pool
7288 }
7389
74- db , err := dawgs .Open (ctx , * driverFlag , cfg )
90+ db , err := dawgs .Open (ctx , driver , cfg )
7591 if err != nil {
7692 t .Fatalf ("Failed to open database: %v" , err )
7793 }
0 commit comments