Skip to content

Commit 5d2e0f0

Browse files
Add the option to establish more connections per dask worker.
1 parent 61cf70d commit 5d2e0f0

1 file changed

Lines changed: 19 additions & 11 deletions

File tree

btrdb/experimental/dask.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,11 @@
55
import dask.dataframe
66
import pyarrow
77
import pandas
8+
import random
89

910
# This process local connection variable is initialized in all
1011
# dask worker processes by the configure function.
11-
_btrdb_conn = None
12+
_btrdb_conns = []
1213

1314

1415
def get_btrdb():
@@ -29,12 +30,12 @@ def get_btrdb():
2930
Ensure that the `configure` function is called to set up the
3031
BtrDB credentials before calling this function.
3132
"""
32-
conn = _btrdb_conn
33-
if conn is None:
33+
conns = _btrdb_conns
34+
if len(conns) == 0:
3435
raise btrdb.exceptions.ConnectionError(
3536
"call configure to configure btrdb credentials for the cluster"
3637
)
37-
return conn
38+
return random.choice(conns)
3839

3940

4041
class BtrdbConnectionPlugin(dask.distributed.WorkerPlugin):
@@ -49,16 +50,20 @@ class BtrdbConnectionPlugin(dask.distributed.WorkerPlugin):
4950
This plugin should not be used directly, and instead be used via `configure`.
5051
"""
5152

52-
def __init__(self, endpoints=None, apikey=None):
53+
def __init__(self, connections=None, endpoints=None, apikey=None):
54+
self._connections = connections
5355
self._endpoints = endpoints
5456
self._apikey = apikey
5557

5658
def setup(self, worker):
57-
global _btrdb_conn
58-
_btrdb_conn = btrdb._connect(endpoints=self._endpoints, apikey=self._apikey)
59+
global _btrdb_conns
60+
_btrdb_conns = [
61+
btrdb._connect(endpoints=self._endpoints, apikey=self._apikey)
62+
for i in range(self._connections)
63+
]
5964

6065

61-
def configure(client=None, conn_str=None, apikey=None, profile=None):
66+
def configure(client=None, conn_str=None, apikey=None, profile=None, connections=1):
6267
"""
6368
Configure a btrdb connection on all worker nodes in the dask cluster.
6469
"""
@@ -70,8 +75,11 @@ def configure(client=None, conn_str=None, apikey=None, profile=None):
7075
pass
7176
if client is None:
7277
# We have a threaded scheduler.
73-
global _btrdb_conn
74-
_btrdb_conn = btrdb.connect(conn_str=conn_str, apikey=apikey, profile=profile)
78+
global _btrdb_conns
79+
_btrdb_conns = [
80+
btrdb.connect(conn_str=conn_str, apikey=apikey, profile=profile)
81+
for i in range(connections)
82+
]
7583
else:
7684
if profile is not None:
7785
creds = btrdb.credentials_by_profile(profile)
@@ -84,7 +92,7 @@ def configure(client=None, conn_str=None, apikey=None, profile=None):
8492
)
8593

8694
# Configure the distributed scheduler.
87-
plugin = BtrdbConnectionPlugin(**creds)
95+
plugin = BtrdbConnectionPlugin(connections=connections, **creds)
8896
client.register_worker_plugin(plugin, name="btrdb_connection")
8997

9098

0 commit comments

Comments
 (0)