Skip to content

Commit 155f1d9

Browse files
pmorris-devclaude
andcommitted
feat: add configurable limit on loaded database count
Unbounded load() calls can exhaust memory via connection pool proliferation. Add a configurable max (default 50) to DbInstances, rejecting new load() calls with TOO_MANY_DATABASES when the limit is reached. Convert DbInstances from a tuple struct to named fields (inner, max) and expose Builder::max_databases() for configuration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent dd891a3 commit 155f1d9

3 files changed

Lines changed: 74 additions & 20 deletions

File tree

src/commands.rs

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ pub async fn load<R: Runtime>(
112112
// Wait for migrations to complete if registered for this database
113113
await_migrations(&migration_states, &db).await?;
114114

115-
let instances = db_instances.0.read().await;
115+
let instances = db_instances.inner.read().await;
116116

117117
// Return cached if db was already loaded
118118
if instances.contains_key(&db) {
@@ -121,7 +121,14 @@ pub async fn load<R: Runtime>(
121121

122122
drop(instances); // Release read lock before acquiring write lock
123123

124-
let mut instances = db_instances.0.write().await;
124+
let mut instances = db_instances.inner.write().await;
125+
126+
// Check database count limit before creating a new connection.
127+
// This check is before entry() to avoid borrow conflicts, and the write lock
128+
// prevents races between the len() check and the insert.
129+
if !instances.contains_key(&db) && instances.len() >= db_instances.max {
130+
return Err(Error::TooManyDatabases(db_instances.max));
131+
}
125132

126133
// Use entry API to atomically check and insert, avoiding race conditions
127134
// where two callers could both create wrappers
@@ -187,7 +194,7 @@ pub async fn execute(
187194
values: Vec<JsonValue>,
188195
attached: Option<Vec<AttachedDatabaseSpec>>,
189196
) -> Result<(u64, i64)> {
190-
let instances = db_instances.0.read().await;
197+
let instances = db_instances.inner.read().await;
191198

192199
let wrapper = instances
193200
.get(&db)
@@ -214,7 +221,7 @@ pub async fn execute_transaction(
214221
statements: Vec<Statement>,
215222
attached: Option<Vec<AttachedDatabaseSpec>>,
216223
) -> Result<Vec<WriteQueryResult>> {
217-
let instances = db_instances.0.read().await;
224+
let instances = db_instances.inner.read().await;
218225

219226
let wrapper = instances
220227
.get(&db)
@@ -292,7 +299,7 @@ pub async fn fetch_all(
292299
values: Vec<JsonValue>,
293300
attached: Option<Vec<AttachedDatabaseSpec>>,
294301
) -> Result<Vec<IndexMap<String, JsonValue>>> {
295-
let instances = db_instances.0.read().await;
302+
let instances = db_instances.inner.read().await;
296303

297304
let wrapper = instances
298305
.get(&db)
@@ -319,7 +326,7 @@ pub async fn fetch_one(
319326
values: Vec<JsonValue>,
320327
attached: Option<Vec<AttachedDatabaseSpec>>,
321328
) -> Result<Option<IndexMap<String, JsonValue>>> {
322-
let instances = db_instances.0.read().await;
329+
let instances = db_instances.inner.read().await;
323330

324331
let wrapper = instances
325332
.get(&db)
@@ -357,7 +364,7 @@ pub async fn fetch_page(
357364
));
358365
}
359366

360-
let instances = db_instances.0.read().await;
367+
let instances = db_instances.inner.read().await;
361368

362369
let wrapper = instances
363370
.get(&db)
@@ -394,7 +401,7 @@ pub async fn close(
394401
) -> Result<bool> {
395402
active_subs.remove_for_db(&db).await;
396403

397-
let mut instances = db_instances.0.write().await;
404+
let mut instances = db_instances.inner.write().await;
398405

399406
if let Some(wrapper) = instances.remove(&db) {
400407
wrapper.close().await?;
@@ -415,7 +422,7 @@ pub async fn close_all(
415422
) -> Result<()> {
416423
active_subs.abort_all().await;
417424

418-
let mut instances = db_instances.0.write().await;
425+
let mut instances = db_instances.inner.write().await;
419426

420427
// Collect all wrappers to close
421428
let wrappers: Vec<DatabaseWrapper> = instances.drain().map(|(_, v)| v).collect();
@@ -447,7 +454,7 @@ pub async fn remove(
447454
) -> Result<bool> {
448455
active_subs.remove_for_db(&db).await;
449456

450-
let mut instances = db_instances.0.write().await;
457+
let mut instances = db_instances.inner.write().await;
451458

452459
if let Some(wrapper) = instances.remove(&db) {
453460
wrapper.remove().await?;
@@ -489,7 +496,7 @@ pub async fn begin_interruptible_transaction(
489496
initial_statements: Vec<Statement>,
490497
attached: Option<Vec<AttachedDatabaseSpec>>,
491498
) -> Result<TransactionToken> {
492-
let instances = db_instances.0.read().await;
499+
let instances = db_instances.inner.read().await;
493500

494501
let wrapper = instances
495502
.get(&db)
@@ -641,7 +648,7 @@ pub async fn observe(
641648
// enable_observation() drops the old broker
642649
active_subs.remove_for_db(&db).await;
643650

644-
let mut instances = db_instances.0.write().await;
651+
let mut instances = db_instances.inner.write().await;
645652

646653
let wrapper = instances
647654
.get_mut(&db)
@@ -683,7 +690,7 @@ pub async fn subscribe(
683690
tables: Vec<String>,
684691
on_event: Channel<TableChangePayload>,
685692
) -> Result<String> {
686-
let instances = db_instances.0.read().await;
693+
let instances = db_instances.inner.read().await;
687694

688695
let wrapper = instances
689696
.get(&db)
@@ -747,7 +754,7 @@ pub async fn unobserve(
747754
// Abort all subscriptions for this database first
748755
active_subs.remove_for_db(&db).await;
749756

750-
let mut instances = db_instances.0.write().await;
757+
let mut instances = db_instances.inner.write().await;
751758

752759
let wrapper = instances
753760
.get_mut(&db)

src/error.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ pub enum Error {
3939
#[error("observation not enabled for database: {0}")]
4040
ObservationNotEnabled(String),
4141

42+
/// Too many databases loaded simultaneously.
43+
#[error("cannot load more than {0} databases")]
44+
TooManyDatabases(usize),
45+
4246
/// Invalid configuration parameter.
4347
#[error("invalid configuration: {0}")]
4448
InvalidConfig(String),
@@ -78,6 +82,7 @@ impl Error {
7882
Error::PathTraversal(_) => "PATH_TRAVERSAL".to_string(),
7983
Error::DatabaseNotLoaded(_) => "DATABASE_NOT_LOADED".to_string(),
8084
Error::ObservationNotEnabled(_) => "OBSERVATION_NOT_ENABLED".to_string(),
85+
Error::TooManyDatabases(_) => "TOO_MANY_DATABASES".to_string(),
8186
Error::InvalidConfig(_) => "INVALID_CONFIG".to_string(),
8287
Error::Other(_) => "ERROR".to_string(),
8388
}

src/lib.rs

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,38 @@ pub use sqlx_sqlite_toolkit::{
2222
TransactionExecutionBuilder, WriteQueryResult,
2323
};
2424

25+
/// Default maximum number of concurrently loaded databases.
26+
const DEFAULT_MAX_DATABASES: usize = 50;
27+
2528
/// Database instances managed by the plugin.
2629
///
2730
/// This struct maintains a thread-safe map of database paths to their corresponding
28-
/// connection wrappers.
29-
#[derive(Clone, Default)]
30-
pub struct DbInstances(pub Arc<RwLock<HashMap<String, DatabaseWrapper>>>);
31+
/// connection wrappers, with a configurable upper limit on how many databases can be
32+
/// loaded simultaneously.
33+
#[derive(Clone)]
34+
pub struct DbInstances {
35+
pub(crate) inner: Arc<RwLock<HashMap<String, DatabaseWrapper>>>,
36+
pub(crate) max: usize,
37+
}
38+
39+
impl Default for DbInstances {
40+
fn default() -> Self {
41+
Self {
42+
inner: Arc::new(RwLock::new(HashMap::new())),
43+
max: DEFAULT_MAX_DATABASES,
44+
}
45+
}
46+
}
47+
48+
impl DbInstances {
49+
/// Create a new instance with the given maximum database count.
50+
pub fn new(max: usize) -> Self {
51+
Self {
52+
inner: Arc::new(RwLock::new(HashMap::new())),
53+
max,
54+
}
55+
}
56+
}
3157

3258
/// Migration status for a database.
3359
#[derive(Debug, Clone)]
@@ -136,6 +162,8 @@ pub struct Builder {
136162
migrations: HashMap<String, Arc<Migrator>>,
137163
/// Timeout for interruptible transactions. Defaults to 5 minutes.
138164
transaction_timeout: Option<std::time::Duration>,
165+
/// Maximum number of concurrently loaded databases. Defaults to 50.
166+
max_databases: Option<usize>,
139167
}
140168

141169
impl Builder {
@@ -144,6 +172,7 @@ impl Builder {
144172
Self {
145173
migrations: HashMap::new(),
146174
transaction_timeout: None,
175+
max_databases: None,
147176
}
148177
}
149178

@@ -182,10 +211,20 @@ impl Builder {
182211
self
183212
}
184213

214+
/// Set the maximum number of databases that can be loaded simultaneously.
215+
///
216+
/// Prevents unbounded memory growth from connection pool proliferation.
217+
/// Defaults to 50.
218+
pub fn max_databases(mut self, max: usize) -> Self {
219+
self.max_databases = Some(max);
220+
self
221+
}
222+
185223
/// Build the plugin with command registration and state management.
186224
pub fn build<R: Runtime>(self) -> tauri::plugin::TauriPlugin<R> {
187225
let migrations = Arc::new(self.migrations);
188226
let transaction_timeout = self.transaction_timeout;
227+
let max_databases = self.max_databases;
189228

190229
PluginBuilder::<R>::new("sqlite")
191230
.invoke_handler(tauri::generate_handler![
@@ -208,7 +247,10 @@ impl Builder {
208247
commands::unobserve,
209248
])
210249
.setup(move |app, _api| {
211-
app.manage(DbInstances::default());
250+
app.manage(match max_databases {
251+
Some(max) => DbInstances::new(max),
252+
None => DbInstances::default(),
253+
});
212254
app.manage(MigrationStates::default());
213255
app.manage(match transaction_timeout {
214256
Some(timeout) => ActiveInterruptibleTransactions::new(timeout),
@@ -286,7 +328,7 @@ impl Builder {
286328

287329
// Close databases (each wrapper's close() disables its own
288330
// observer at the crate level, unregistering SQLite hooks)
289-
let mut guard = instances_clone.0.write().await;
331+
let mut guard = instances_clone.inner.write().await;
290332
let wrappers: Vec<DatabaseWrapper> =
291333
guard.drain().map(|(_, v)| v).collect();
292334

@@ -329,7 +371,7 @@ impl Builder {
329371
// ExitRequested should have already closed all databases
330372
// This is just a safety check
331373
let instances = app.state::<DbInstances>();
332-
match instances.0.try_read() {
374+
match instances.inner.try_read() {
333375
Ok(guard) => {
334376
if !guard.is_empty() {
335377
warn!(

0 commit comments

Comments
 (0)