55import dask .dataframe
66import pyarrow
77import pandas
8+ import random
89
910# This process-local list of connections is initialized in all
1011# dask worker processes by the `configure` function.
11- _btrdb_conn = None
12+ _btrdb_conns = []
1213
1314
1415def get_btrdb ():
@@ -29,12 +30,12 @@ def get_btrdb():
2930 Ensure that the `configure` function is called to set up the
3031 BtrDB credentials before calling this function.
3132 """
32- conn = _btrdb_conn
33- if conn is None :
33+ conns = _btrdb_conns
34+ if len ( conns ) == 0 :
3435 raise btrdb .exceptions .ConnectionError (
3536 "call configure to configure btrdb credentials for the cluster"
3637 )
37- return conn
38+ return random . choice ( conns )
3839
3940
4041class BtrdbConnectionPlugin (dask .distributed .WorkerPlugin ):
@@ -49,16 +50,20 @@ class BtrdbConnectionPlugin(dask.distributed.WorkerPlugin):
4950 This plugin should not be used directly, and instead be used via `configure`.
5051 """
5152
def __init__(self, connections=None, endpoints=None, apikey=None):
    """
    Store the settings `setup` uses to open btrdb connections.

    Parameters
    ----------
    connections : int, optional
        Number of connections `setup` opens in each worker process.
        `None` (the default) is treated as one connection, so building
        the plugin without this argument does not fail later in `setup`.
    endpoints : optional
        Stored and handed unchanged to `btrdb._connect` by `setup`.
    apikey : optional
        Stored and handed unchanged to `btrdb._connect` by `setup`.
    """
    # `setup` passes this value to `range`, which raises TypeError on
    # None; normalise the default here so the plugin is always usable.
    self._connections = 1 if connections is None else connections
    self._endpoints = endpoints
    self._apikey = apikey
5557
def setup(self, worker):
    """
    Open the btrdb connections for this worker process.

    Rebinds the module-level `_btrdb_conns` to a new list holding
    `self._connections` connections, each created from the endpoints and
    API key stored on the plugin. The `worker` argument is not used.
    """
    global _btrdb_conns
    pool = []
    for _ in range(self._connections):
        pool.append(
            btrdb._connect(endpoints=self._endpoints, apikey=self._apikey)
        )
    # Publish only after every connection was created, so a failure
    # part-way through leaves the previous list in place.
    _btrdb_conns = pool
5964
6065
61- def configure (client = None , conn_str = None , apikey = None , profile = None ):
66+ def configure (client = None , conn_str = None , apikey = None , profile = None , connections = 1 ):
6267 """
6368 Configure a btrdb connection on all worker nodes in the dask cluster.
6469 """
@@ -70,8 +75,11 @@ def configure(client=None, conn_str=None, apikey=None, profile=None):
7075 pass
7176 if client is None :
7277 # We have a threaded scheduler.
73- global _btrdb_conn
74- _btrdb_conn = btrdb .connect (conn_str = conn_str , apikey = apikey , profile = profile )
78+ global _btrdb_conns
79+ _btrdb_conns = [
80+ btrdb .connect (conn_str = conn_str , apikey = apikey , profile = profile )
81+ for i in range (connections )
82+ ]
7583 else :
7684 if profile is not None :
7785 creds = btrdb .credentials_by_profile (profile )
@@ -84,7 +92,7 @@ def configure(client=None, conn_str=None, apikey=None, profile=None):
8492 )
8593
8694 # Configure the distributed scheduler.
87- plugin = BtrdbConnectionPlugin (** creds )
95+ plugin = BtrdbConnectionPlugin (connections = connections , ** creds )
8896 client .register_worker_plugin (plugin , name = "btrdb_connection" )
8997
9098
0 commit comments