diff --git a/.github/workflows/postgres-integration.yml b/.github/workflows/postgres-integration.yml new file mode 100644 index 000000000..123e52a25 --- /dev/null +++ b/.github/workflows/postgres-integration.yml @@ -0,0 +1,42 @@ +name: CI Checks - PostgreSQL Integration Tests + +on: [push, pull_request] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:latest + ports: + - 5432:5432 + env: + POSTGRES_DB: postgres + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + - name: Install Rust stable toolchain + run: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile=minimal --default-toolchain stable + - name: Run PostgreSQL store tests + env: + TEST_POSTGRES_URL: "host=localhost user=postgres password=postgres" + run: cargo test --features postgres io::postgres_store + - name: Run PostgreSQL integration tests + env: + TEST_POSTGRES_URL: "host=localhost user=postgres password=postgres" + run: | + RUSTFLAGS="--cfg no_download --cfg cycle_tests" cargo test --features postgres --test integration_tests_postgres diff --git a/Cargo.toml b/Cargo.toml index 539941677..d4e63f9d4 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,9 @@ codegen-units = 1 # Reduce number of codegen units to increase optimizations. panic = 'abort' # Abort on panic [features] -default = [] +default = ["sqlite"] +sqlite = ["dep:rusqlite"] +postgres = ["dep:tokio-postgres"] [dependencies] #lightning = { version = "0.2.0", features = ["std"] } @@ -58,7 +60,7 @@ bdk_wallet = { version = "2.3.0", default-features = false, features = ["std", " bitreq = { version = "0.3", default-features = false, features = ["async-https", "json-using-serde"] } rustls = { version = "0.23", default-features = false } -rusqlite = { version = "0.31.0", features = ["bundled"] } +rusqlite = { version = "0.31.0", features = ["bundled"], optional = true } bitcoin = "0.32.7" bip39 = { version = "2.0.0", features = ["rand"] } bip21 = { version = "0.5", features = ["std"], default-features = false } @@ -76,6 +78,7 @@ serde_json = { version = "1.0.128", default-features = false, features = ["std"] log = { version = "0.4.22", default-features = false, features = ["std"]} async-trait = { version = "0.1", default-features = false } +tokio-postgres = { version = "0.7", default-features = false, features = ["runtime"], optional = true } vss-client = { package = "vss-client-ng", version = "0.5" } prost = { version = "0.11.6", default-features = false} #bitcoin-payment-instructions = { version = "0.6" } diff --git a/src/builder.rs b/src/builder.rs index cd8cc184f..72b7a17c0 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -53,6 +53,9 @@ use crate::entropy::NodeEntropy; use crate::event::EventQueue; use crate::fee_estimator::OnchainFeeEstimator; use crate::gossip::GossipSource; +#[cfg(feature = "sqlite")] +use crate::io; +#[cfg(feature = "sqlite")] use crate::io::sqlite_store::SqliteStore; use crate::io::utils::{ read_event_queue, read_external_pathfinding_scores_from_cache, read_network_graph, @@ -61,7 +64,7 @@ use crate::io::utils::{ }; use crate::io::vss_store::VssStoreBuilder; use crate::io::{ - self, PAYMENT_INFO_PERSISTENCE_PRIMARY_NAMESPACE, PAYMENT_INFO_PERSISTENCE_SECONDARY_NAMESPACE, + PAYMENT_INFO_PERSISTENCE_PRIMARY_NAMESPACE, PAYMENT_INFO_PERSISTENCE_SECONDARY_NAMESPACE, PENDING_PAYMENT_INFO_PERSISTENCE_PRIMARY_NAMESPACE, PENDING_PAYMENT_INFO_PERSISTENCE_SECONDARY_NAMESPACE, }; @@ -616,6 +619,7 @@ impl NodeBuilder { /// Builds a [`Node`] instance with a [`SqliteStore`] backend and according to the options /// previously configured. + #[cfg(feature = "sqlite")] pub fn build(&self, node_entropy: NodeEntropy) -> Result { let storage_dir_path = self.config.storage_dir_path.clone(); fs::create_dir_all(storage_dir_path.clone()) @@ -629,6 +633,24 @@ impl NodeBuilder { self.build_with_store(node_entropy, kv_store) } + /// Builds a [`Node`] instance with a [PostgreSQL] backend and according to the options + /// previously configured. + /// + /// Connects to the PostgreSQL database at the given `connection_string`. + /// The given `kv_table_name` will be used or default to + /// [`DEFAULT_KV_TABLE_NAME`](crate::io::postgres_store::DEFAULT_KV_TABLE_NAME). + /// + /// [PostgreSQL]: https://www.postgresql.org + #[cfg(feature = "postgres")] + pub fn build_with_postgres_store( + &self, node_entropy: NodeEntropy, connection_string: &str, kv_table_name: Option, + ) -> Result { + let kv_store = + crate::io::postgres_store::PostgresStore::new(connection_string, kv_table_name) + .map_err(|_| BuildError::KVStoreSetupFailed)?; + self.build_with_store(node_entropy, kv_store) + } + /// Builds a [`Node`] instance with a [`FilesystemStore`] backend and according to the options /// previously configured. pub fn build_with_fs_store(&self, node_entropy: NodeEntropy) -> Result { @@ -1083,10 +1105,27 @@ impl ArcedNodeBuilder { /// Builds a [`Node`] instance with a [`SqliteStore`] backend and according to the options /// previously configured. + #[cfg(feature = "sqlite")] pub fn build(&self, node_entropy: Arc) -> Result, BuildError> { self.inner.read().unwrap().build(*node_entropy).map(Arc::new) } + /// Builds a [`Node`] instance with a [PostgreSQL] backend and according to the options + /// previously configured. + /// + /// [PostgreSQL]: https://www.postgresql.org + #[cfg(feature = "postgres")] + pub fn build_with_postgres_store( + &self, node_entropy: Arc, connection_string: String, + kv_table_name: Option, + ) -> Result, BuildError> { + self.inner + .read() + .unwrap() + .build_with_postgres_store(*node_entropy, &connection_string, kv_table_name) + .map(Arc::new) + } + /// Builds a [`Node`] instance with a [`FilesystemStore`] backend and according to the options /// previously configured. pub fn build_with_fs_store( diff --git a/src/io/mod.rs b/src/io/mod.rs index e080d39f7..d2cad17d3 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -7,6 +7,9 @@ //! Objects and traits for data persistence. +#[cfg(feature = "postgres")] +pub mod postgres_store; +#[cfg(feature = "sqlite")] pub mod sqlite_store; #[cfg(test)] pub(crate) mod test_utils; diff --git a/src/io/postgres_store/migrations.rs b/src/io/postgres_store/migrations.rs new file mode 100644 index 000000000..c9add1c57 --- /dev/null +++ b/src/io/postgres_store/migrations.rs @@ -0,0 +1,21 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +use lightning::io; +use tokio_postgres::Client; + +pub(super) async fn migrate_schema( + _client: &Client, _kv_table_name: &str, from_version: u16, to_version: u16, +) -> io::Result<()> { + assert!(from_version < to_version); + // Future migrations go here, e.g.: + // if from_version == 1 && to_version >= 2 { + // migrate_v1_to_v2(client, kv_table_name).await?; + // from_version = 2; + // } + Ok(()) +} diff --git a/src/io/postgres_store/mod.rs b/src/io/postgres_store/mod.rs new file mode 100644 index 000000000..7b7f78879 --- /dev/null +++ b/src/io/postgres_store/mod.rs @@ -0,0 +1,986 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +//! Objects related to [`PostgresStore`] live here. +use std::collections::HashMap; +use std::future::Future; +use std::sync::atomic::{AtomicI64, AtomicU64, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +use lightning::io; +use lightning::util::persist::{ + KVStore, KVStoreSync, PageToken, PaginatedKVStore, PaginatedKVStoreSync, PaginatedListResponse, +}; +use lightning_types::string::PrintableString; +use tokio_postgres::NoTls; + +use crate::io::utils::check_namespace_key_validity; + +mod migrations; + +/// The default table in which we store all data. +pub const DEFAULT_KV_TABLE_NAME: &str = "ldk_data"; + +// The current schema version for the PostgreSQL store. +const SCHEMA_VERSION: u16 = 1; + +// The number of entries returned per page in paginated list operations. +const PAGE_SIZE: usize = 50; + +// The number of worker threads for the internal runtime used by sync operations. +const INTERNAL_RUNTIME_WORKERS: usize = 2; + +/// A [`KVStoreSync`] implementation that writes to and reads from a [PostgreSQL] database. +/// +/// [PostgreSQL]: https://www.postgresql.org +pub struct PostgresStore { + inner: Arc, + + // Version counter to ensure that writes are applied in the correct order. It is assumed that read and list + // operations aren't sensitive to the order of execution. + next_write_version: AtomicU64, + + // An internal runtime we use to avoid any deadlocks we could hit when waiting on async + // operations to finish from a sync context. + internal_runtime: Option, +} + +// tokio::sync::Mutex (used for the DB client) contains UnsafeCell which opts out of +// RefUnwindSafe. std::sync::Mutex (used by SqliteStore) doesn't have this issue because +// it poisons on panic. This impl is needed for do_read_write_remove_list_persist which +// requires K: KVStoreSync + RefUnwindSafe. +impl std::panic::RefUnwindSafe for PostgresStore {} + +impl PostgresStore { + /// Constructs a new [`PostgresStore`]. + /// + /// Connects to the PostgreSQL database at the given `connection_string`. + /// + /// The given `kv_table_name` will be used or default to [`DEFAULT_KV_TABLE_NAME`]. + pub fn new(connection_string: &str, kv_table_name: Option) -> io::Result { + let internal_runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .thread_name_fn(|| { + static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("ldk-node-pg-runtime-{id}") + }) + .worker_threads(INTERNAL_RUNTIME_WORKERS) + .max_blocking_threads(INTERNAL_RUNTIME_WORKERS) + .build() + .unwrap(); + + let connection_string = connection_string.to_string(); + let inner = tokio::task::block_in_place(|| { + internal_runtime.block_on(async { + PostgresStoreInner::new(&connection_string, kv_table_name).await + }) + })?; + + let inner = Arc::new(inner); + let next_write_version = AtomicU64::new(1); + Ok(Self { inner, next_write_version, internal_runtime: Some(internal_runtime) }) + } + + fn build_locking_key( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> String { + format!("{primary_namespace}#{secondary_namespace}#{key}") + } + + fn get_new_version_and_lock_ref( + &self, locking_key: String, + ) -> (Arc>, u64) { + let version = self.next_write_version.fetch_add(1, Ordering::Relaxed); + if version == u64::MAX { + panic!("PostgresStore version counter overflowed"); + } + + let inner_lock_ref = self.inner.get_inner_lock_ref(locking_key); + + (inner_lock_ref, version) + } +} + +impl Drop for PostgresStore { + fn drop(&mut self) { + let internal_runtime = self.internal_runtime.take(); + tokio::task::block_in_place(move || drop(internal_runtime)); + } +} + +impl KVStore for PostgresStore { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> impl Future, io::Error>> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { inner.read_internal(&primary_namespace, &secondary_namespace, &key).await } + } + + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> impl Future> + 'static + Send { + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .write_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + buf, + ) + .await + } + } + + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool, + ) -> impl Future> + 'static + Send { + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .remove_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + ) + .await + } + } + + fn list( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> impl Future, io::Error>> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + async move { inner.list_internal(&primary_namespace, &secondary_namespace).await } + } +} + +impl KVStoreSync for PostgresStore { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner.read_internal(&primary_namespace, &secondary_namespace, &key).await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> io::Result<()> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .write_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + buf, + ) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool, + ) -> io::Result<()> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .remove_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + ) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + let fut = + async move { inner.list_internal(&primary_namespace, &secondary_namespace).await }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } +} + +impl PaginatedKVStoreSync for PostgresStore { + fn list_paginated( + &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, + ) -> io::Result { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .list_paginated_internal(&primary_namespace, &secondary_namespace, page_token) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } +} + +impl PaginatedKVStore for PostgresStore { + fn list_paginated( + &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, + ) -> impl Future> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .list_paginated_internal(&primary_namespace, &secondary_namespace, page_token) + .await + } + } +} + +struct PostgresStoreInner { + client: tokio::sync::Mutex, + kv_table_name: String, + write_version_locks: Mutex>>>, + next_sort_order: AtomicI64, +} + +impl PostgresStoreInner { + async fn new(connection_string: &str, kv_table_name: Option) -> io::Result { + let kv_table_name = kv_table_name.unwrap_or(DEFAULT_KV_TABLE_NAME.to_string()); + + let (client, connection) = + tokio_postgres::connect(connection_string, NoTls).await.map_err(|e| { + let msg = format!("Failed to connect to PostgreSQL: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + // Spawn the connection task so it runs in the background. + tokio::spawn(async move { + if let Err(e) = connection.await { + log::error!("PostgreSQL connection error: {e}"); + } + }); + + // Create the KV data table if it doesn't exist. + let sql = format!( + "CREATE TABLE IF NOT EXISTS {kv_table_name} ( + primary_namespace TEXT NOT NULL, + secondary_namespace TEXT NOT NULL DEFAULT '', + key TEXT NOT NULL CHECK (key <> ''), + value BYTEA, + sort_order BIGINT NOT NULL DEFAULT 0, + PRIMARY KEY (primary_namespace, secondary_namespace, key) + )" + ); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to create table {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + // Read the schema version from the table comment (analogous to SQLite's PRAGMA user_version). + let sql = format!("SELECT obj_description('{kv_table_name}'::regclass, 'pg_class')"); + let row = client.query_one(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to read schema version for {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + let version_res: u16 = match row.get::<_, Option<&str>>(0) { + Some(version_str) => version_str.parse().map_err(|_| { + let msg = format!("Invalid schema version: {version_str}"); + io::Error::new(io::ErrorKind::Other, msg) + })?, + None => 0, + }; + + if version_res == 0 { + // New table, set our SCHEMA_VERSION. + let sql = format!("COMMENT ON TABLE {kv_table_name} IS '{SCHEMA_VERSION}'"); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to set schema version: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + } else if version_res < SCHEMA_VERSION { + migrations::migrate_schema(&client, &kv_table_name, version_res, SCHEMA_VERSION) + .await?; + } else if version_res > SCHEMA_VERSION { + let msg = format!( + "Failed to open database: incompatible schema version {version_res}. Expected: {SCHEMA_VERSION}" + ); + return Err(io::Error::new(io::ErrorKind::Other, msg)); + } + + // Create composite index for paginated listing. + let sql = format!( + "CREATE INDEX IF NOT EXISTS idx_{kv_table_name}_paginated ON {kv_table_name} (primary_namespace, secondary_namespace, sort_order DESC, key ASC)" + ); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to create index on table {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + // Initialize next_sort_order from the max existing value. + let sql = format!("SELECT COALESCE(MAX(sort_order), 0) FROM {kv_table_name}"); + let row = client.query_one(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to read max sort_order from {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + let max_sort_order: i64 = row.get(0); + let next_sort_order = AtomicI64::new(max_sort_order + 1); + + let client = tokio::sync::Mutex::new(client); + let write_version_locks = Mutex::new(HashMap::new()); + Ok(Self { client, kv_table_name, write_version_locks, next_sort_order }) + } + + fn get_inner_lock_ref(&self, locking_key: String) -> Arc> { + let mut outer_lock = self.write_version_locks.lock().unwrap(); + Arc::clone(&outer_lock.entry(locking_key).or_default()) + } + + async fn read_internal( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result> { + check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "read")?; + + let locked_client = self.client.lock().await; + let sql = format!( + "SELECT value FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2 AND key=$3", + self.kv_table_name + ); + + let row = locked_client + .query_opt(sql.as_str(), &[&primary_namespace, &secondary_namespace, &key]) + .await + .map_err(|e| { + let msg = format!( + "Failed to read from key {}/{}/{}: {}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key), + e + ); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + match row { + Some(row) => { + let value: Vec = row.get(0); + Ok(value) + }, + None => { + let msg = format!( + "Failed to read as key could not be found: {}/{}/{}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key), + ); + Err(io::Error::new(io::ErrorKind::NotFound, msg)) + }, + } + } + + async fn write_internal( + &self, inner_lock_ref: Arc>, locking_key: String, version: u64, + primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> io::Result<()> { + check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "write")?; + + self.execute_locked_write(inner_lock_ref, locking_key, version, async move || { + let locked_client = self.client.lock().await; + + let sort_order = self.next_sort_order.fetch_add(1, Ordering::Relaxed); + + let sql = format!( + "INSERT INTO {} (primary_namespace, secondary_namespace, key, value, sort_order) \ + VALUES ($1, $2, $3, $4, $5) \ + ON CONFLICT (primary_namespace, secondary_namespace, key) DO UPDATE SET value = EXCLUDED.value", + self.kv_table_name + ); + + locked_client + .execute( + sql.as_str(), + &[ + &primary_namespace, + &secondary_namespace, + &key, + &buf, + &sort_order, + ], + ) + .await + .map(|_| ()) + .map_err(|e| { + let msg = format!( + "Failed to write to key {}/{}/{}: {}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key), + e + ); + io::Error::new(io::ErrorKind::Other, msg) + }) + }) + .await + } + + async fn remove_internal( + &self, inner_lock_ref: Arc>, locking_key: String, version: u64, + primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result<()> { + check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "remove")?; + + self.execute_locked_write(inner_lock_ref, locking_key, version, async move || { + let locked_client = self.client.lock().await; + + let sql = format!( + "DELETE FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2 AND key=$3", + self.kv_table_name + ); + + locked_client + .execute(sql.as_str(), &[&primary_namespace, &secondary_namespace, &key]) + .await + .map_err(|e| { + let msg = format!( + "Failed to delete key {}/{}/{}: {}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key), + e + ); + io::Error::new(io::ErrorKind::Other, msg) + })?; + Ok(()) + }) + .await + } + + async fn list_internal( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> io::Result> { + check_namespace_key_validity(primary_namespace, secondary_namespace, None, "list")?; + + let locked_client = self.client.lock().await; + + let sql = format!( + "SELECT key FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2", + self.kv_table_name + ); + + let rows = locked_client + .query(sql.as_str(), &[&primary_namespace, &secondary_namespace]) + .await + .map_err(|e| { + let msg = format!("Failed to retrieve queried rows: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + let keys: Vec = rows.iter().map(|row| row.get(0)).collect(); + Ok(keys) + } + + async fn list_paginated_internal( + &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, + ) -> io::Result { + check_namespace_key_validity( + primary_namespace, + secondary_namespace, + None, + "list_paginated", + )?; + + let locked_client = self.client.lock().await; + + // Fetch one extra row beyond PAGE_SIZE to determine whether a next page exists. + let fetch_limit = (PAGE_SIZE + 1) as i64; + + let mut entries: Vec<(String, i64)> = match page_token { + Some(ref token) => { + let token_sort_order: i64 = token.as_str().parse().map_err(|_| { + let token_str = token.as_str(); + let msg = format!("Invalid page token: {token_str}"); + io::Error::new(io::ErrorKind::InvalidInput, msg) + })?; + let sql = format!( + "SELECT key, sort_order FROM {} \ + WHERE primary_namespace=$1 \ + AND secondary_namespace=$2 \ + AND sort_order < $3 \ + ORDER BY sort_order DESC, key ASC \ + LIMIT $4", + self.kv_table_name + ); + + let rows = locked_client + .query( + sql.as_str(), + &[ + &primary_namespace, + &secondary_namespace, + &token_sort_order, + &fetch_limit, + ], + ) + .await + .map_err(|e| { + let msg = format!("Failed to retrieve queried rows: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + rows.iter().map(|row| (row.get(0), row.get(1))).collect() + }, + None => { + let sql = format!( + "SELECT key, sort_order FROM {} \ + WHERE primary_namespace=$1 \ + AND secondary_namespace=$2 \ + ORDER BY sort_order DESC, key ASC \ + LIMIT $3", + self.kv_table_name + ); + + let rows = locked_client + .query(sql.as_str(), &[&primary_namespace, &secondary_namespace, &fetch_limit]) + .await + .map_err(|e| { + let msg = format!("Failed to retrieve queried rows: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + rows.into_iter().map(|row| (row.get(0), row.get(1))).collect() + }, + }; + + let has_more = entries.len() > PAGE_SIZE; + entries.truncate(PAGE_SIZE); + + let next_page_token = if has_more { + let (_, last_sort_order) = *entries.last().expect("must be non-empty"); + Some(PageToken::new(last_sort_order.to_string())) + } else { + None + }; + + let keys = entries.into_iter().map(|(k, _)| k).collect(); + Ok(PaginatedListResponse { keys, next_page_token }) + } + + async fn execute_locked_write>, FN: FnOnce() -> F>( + &self, inner_lock_ref: Arc>, locking_key: String, version: u64, + callback: FN, + ) -> Result<(), io::Error> { + let res = { + let mut last_written_version = inner_lock_ref.lock().await; + + // Check if we already have a newer version written/removed. This is used in async contexts to realize eventual + // consistency. + let is_stale_version = version <= *last_written_version; + + // If the version is not stale, we execute the callback. Otherwise, we can and must skip writing. + if is_stale_version { + Ok(()) + } else { + callback().await.map(|_| { + *last_written_version = version; + }) + } + }; + + self.clean_locks(&inner_lock_ref, locking_key); + + res + } + + fn clean_locks(&self, inner_lock_ref: &Arc>, locking_key: String) { + // If there are no arcs in use elsewhere, this means that there are no in-flight writes. We can remove the map + // entry to prevent leaking memory. The two arcs that are expected are the one in the map and the one held here + // in inner_lock_ref. The outer lock is obtained first, to avoid a new arc being cloned after we've already + // counted. + let mut outer_lock = self.write_version_locks.lock().unwrap(); + + let strong_count = Arc::strong_count(inner_lock_ref); + debug_assert!(strong_count >= 2, "Unexpected PostgresStore strong count"); + + if strong_count == 2 { + outer_lock.remove(&locking_key); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::io::test_utils::{do_read_write_remove_list_persist, do_test_store}; + + fn test_connection_string() -> String { + std::env::var("TEST_POSTGRES_URL") + .unwrap_or_else(|_| "host=localhost user=postgres password=postgres".to_string()) + } + + fn create_test_store(table_name: &str) -> PostgresStore { + PostgresStore::new(&test_connection_string(), Some(table_name.to_string())).unwrap() + } + + fn cleanup_store(store: &PostgresStore) { + if let Some(ref runtime) = store.internal_runtime { + let kv_table = store.inner.kv_table_name.clone(); + let inner = Arc::clone(&store.inner); + let _ = tokio::task::block_in_place(|| { + runtime.block_on(async { + let client = inner.client.lock().await; + let _ = client.execute(&format!("DROP TABLE IF EXISTS {kv_table}"), &[]).await; + }) + }); + } + } + + #[test] + fn read_write_remove_list_persist() { + let store = create_test_store("test_rwrl"); + do_read_write_remove_list_persist(&store); + cleanup_store(&store); + } + + #[test] + fn test_postgres_store() { + let store_0 = create_test_store("test_pg_store_0"); + let store_1 = create_test_store("test_pg_store_1"); + do_test_store(&store_0, &store_1); + cleanup_store(&store_0); + cleanup_store(&store_1); + } + + #[test] + fn test_postgres_store_paginated_listing() { + let store = create_test_store("test_pg_paginated"); + + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + let num_entries = 225; + + for i in 0..num_entries { + let key = format!("key_{:04}", i); + let data = vec![i as u8; 32]; + KVStoreSync::write(&store, primary_namespace, secondary_namespace, &key, data).unwrap(); + } + + // Paginate through all entries and collect them + let mut all_keys = Vec::new(); + let mut page_token = None; + let mut page_count = 0; + + loop { + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + page_token, + ) + .unwrap(); + + all_keys.extend(response.keys.clone()); + page_count += 1; + + match response.next_page_token { + Some(token) => page_token = Some(token), + None => break, + } + } + + // Verify we got exactly the right number of entries + assert_eq!(all_keys.len(), num_entries); + + // Verify correct number of pages (225 entries at 50 per page = 5 pages) + assert_eq!(page_count, 5); + + // Verify no duplicates + let mut unique_keys = all_keys.clone(); + unique_keys.sort(); + unique_keys.dedup(); + assert_eq!(unique_keys.len(), num_entries); + + // Verify ordering: newest first (highest sort_order first). + assert_eq!(all_keys[0], format!("key_{:04}", num_entries - 1)); + assert_eq!(all_keys[num_entries - 1], "key_0000"); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_update_preserves_order() { + let store = create_test_store("test_pg_paginated_update"); + + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "first", vec![1u8; 8]) + .unwrap(); + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "second", vec![2u8; 8]) + .unwrap(); + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "third", vec![3u8; 8]) + .unwrap(); + + // Update the first entry + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "first", vec![99u8; 8]) + .unwrap(); + + // Paginated listing should still show "first" with its original creation order + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + None, + ) + .unwrap(); + + // Newest first: third, second, first + assert_eq!(response.keys, vec!["third", "second", "first"]); + + // Verify the updated value was persisted + let data = + KVStoreSync::read(&store, primary_namespace, secondary_namespace, "first").unwrap(); + assert_eq!(data, vec![99u8; 8]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_empty_namespace() { + let store = create_test_store("test_pg_paginated_empty"); + + // Paginating an empty or unknown namespace returns an empty result with no token. + let response = + PaginatedKVStoreSync::list_paginated(&store, "nonexistent", "ns", None).unwrap(); + assert!(response.keys.is_empty()); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_namespace_isolation() { + let store = create_test_store("test_pg_paginated_isolation"); + + KVStoreSync::write(&store, "ns_a", "sub", "key_1", vec![1u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_a", "sub", "key_2", vec![2u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_b", "sub", "key_3", vec![3u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_a", "other", "key_4", vec![4u8; 8]).unwrap(); + + // ns_a/sub should only contain key_1 and key_2 (newest first). + let response = PaginatedKVStoreSync::list_paginated(&store, "ns_a", "sub", None).unwrap(); + assert_eq!(response.keys, vec!["key_2", "key_1"]); + assert!(response.next_page_token.is_none()); + + // ns_b/sub should only contain key_3. + let response = PaginatedKVStoreSync::list_paginated(&store, "ns_b", "sub", None).unwrap(); + assert_eq!(response.keys, vec!["key_3"]); + + // ns_a/other should only contain key_4. + let response = PaginatedKVStoreSync::list_paginated(&store, "ns_a", "other", None).unwrap(); + assert_eq!(response.keys, vec!["key_4"]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_removal() { + let store = create_test_store("test_pg_paginated_removal"); + + let ns = "test_ns"; + let sub = "test_sub"; + + KVStoreSync::write(&store, ns, sub, "a", vec![1u8; 8]).unwrap(); + KVStoreSync::write(&store, ns, sub, "b", vec![2u8; 8]).unwrap(); + KVStoreSync::write(&store, ns, sub, "c", vec![3u8; 8]).unwrap(); + + KVStoreSync::remove(&store, ns, sub, "b", false).unwrap(); + + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys, vec!["c", "a"]); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_exact_page_boundary() { + let store = create_test_store("test_pg_paginated_boundary"); + + let ns = "test_ns"; + let sub = "test_sub"; + + // Write exactly PAGE_SIZE entries (50). + for i in 0..PAGE_SIZE { + let key = format!("key_{:04}", i); + KVStoreSync::write(&store, ns, sub, &key, vec![i as u8; 8]).unwrap(); + } + + // Exactly PAGE_SIZE entries: all returned in one page with no next-page token. + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), PAGE_SIZE); + assert!(response.next_page_token.is_none()); + + // Add one more entry (PAGE_SIZE + 1 total). First page should now have a token. + KVStoreSync::write(&store, ns, sub, "key_extra", vec![0u8; 8]).unwrap(); + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), PAGE_SIZE); + assert!(response.next_page_token.is_some()); + + // Second page should have exactly 1 entry and no token. + let response = + PaginatedKVStoreSync::list_paginated(&store, ns, sub, response.next_page_token) + .unwrap(); + assert_eq!(response.keys.len(), 1); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_fewer_than_page_size() { + let store = create_test_store("test_pg_paginated_few"); + + let ns = "test_ns"; + let sub = "test_sub"; + + // Write fewer entries than PAGE_SIZE. + for i in 0..5 { + let key = format!("key_{i}"); + KVStoreSync::write(&store, ns, sub, &key, vec![i as u8; 8]).unwrap(); + } + + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), 5); + // Fewer than PAGE_SIZE means no next page. + assert!(response.next_page_token.is_none()); + // Newest first. + assert_eq!(response.keys, vec!["key_4", "key_3", "key_2", "key_1", "key_0"]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_write_version_persists_across_restart() { + let table_name = "test_pg_write_version_restart"; + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + + { + let store = create_test_store(table_name); + + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_a", + vec![1u8; 8], + ) + .unwrap(); + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_b", + vec![2u8; 8], + ) + .unwrap(); + + // Don't clean up since we want to reopen + } + + // Open a new store instance on the same database table and write more + { + let store = create_test_store(table_name); + + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_c", + vec![3u8; 8], + ) + .unwrap(); + + // Paginated listing should show newest first: key_c, key_b, key_a + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + None, + ) + .unwrap(); + + assert_eq!(response.keys, vec!["key_c", "key_b", "key_a"]); + + cleanup_store(&store); + } + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 4f68f9825..1c112b379 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -29,6 +29,7 @@ use electrsd::{corepc_node, ElectrsD}; use electrum_client::ElectrumApi; use ldk_node::config::{AsyncPaymentsRole, Config, ElectrumSyncConfig, EsploraSyncConfig}; use ldk_node::entropy::{generate_entropy_mnemonic, NodeEntropy}; +#[cfg(feature = "sqlite")] use ldk_node::io::sqlite_store::SqliteStore; use ldk_node::payment::{PaymentDirection, PaymentKind, PaymentStatus}; use ldk_node::{ @@ -329,6 +330,7 @@ pub(crate) enum TestChainSource<'a> { #[derive(Clone, Copy)] pub(crate) enum TestStoreType { TestSyncStore, + #[cfg(feature = "sqlite")] Sqlite, } @@ -486,6 +488,7 @@ pub(crate) fn setup_node(chain_source: &TestChainSource, config: TestConfig) -> let kv_store = TestSyncStore::new(config.node_config.storage_dir_path.into()); builder.build_with_store(config.node_entropy.into(), kv_store).unwrap() }, + #[cfg(feature = "sqlite")] TestStoreType::Sqlite => builder.build(config.node_entropy.into()).unwrap(), }; @@ -1519,6 +1522,7 @@ struct TestSyncStoreInner { serializer: RwLock<()>, test_store: TestStore, fs_store: FilesystemStore, + #[cfg(feature = "sqlite")] sqlite_store: SqliteStore, } @@ -1528,8 +1532,11 @@ impl TestSyncStoreInner { let mut fs_dir = dest_dir.clone(); fs_dir.push("fs_store"); let fs_store = FilesystemStore::new(fs_dir); + #[cfg(feature = "sqlite")] let mut sql_dir = dest_dir.clone(); + #[cfg(feature = "sqlite")] sql_dir.push("sqlite_store"); + #[cfg(feature = "sqlite")] let sqlite_store = SqliteStore::new( sql_dir, Some("test_sync_db".to_string()), @@ -1537,24 +1544,34 @@ impl TestSyncStoreInner { ) .unwrap(); let test_store = TestStore::new(false); - Self { serializer, fs_store, sqlite_store, test_store } + #[cfg(feature = "sqlite")] + { + return Self { serializer, fs_store, sqlite_store, test_store }; + } + #[cfg(not(feature = "sqlite"))] + { + Self { serializer, fs_store, test_store } + } } fn do_list( &self, primary_namespace: &str, secondary_namespace: &str, ) -> lightning::io::Result> { let fs_res = KVStoreSync::list(&self.fs_store, primary_namespace, secondary_namespace); - let sqlite_res = - KVStoreSync::list(&self.sqlite_store, primary_namespace, secondary_namespace); + #[cfg(feature = "sqlite")] + let sqlite_res = KVStoreSync::list(&self.sqlite_store, primary_namespace, secondary_namespace); let test_res = KVStoreSync::list(&self.test_store, primary_namespace, secondary_namespace); match fs_res { Ok(mut list) => { list.sort(); - let mut sqlite_list = sqlite_res.unwrap(); - sqlite_list.sort(); - assert_eq!(list, sqlite_list); + #[cfg(feature = "sqlite")] + { + let mut sqlite_list = sqlite_res.unwrap(); + sqlite_list.sort(); + assert_eq!(list, sqlite_list); + } let mut test_list = test_res.unwrap(); test_list.sort(); @@ -1563,6 +1580,7 @@ impl TestSyncStoreInner { Ok(list) }, Err(e) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_err()); assert!(test_res.is_err()); Err(e) @@ -1576,6 +1594,7 @@ impl TestSyncStoreInner { let _guard = self.serializer.read().unwrap(); let fs_res = KVStoreSync::read(&self.fs_store, primary_namespace, secondary_namespace, key); + #[cfg(feature = "sqlite")] let sqlite_res = KVStoreSync::read(&self.sqlite_store, primary_namespace, secondary_namespace, key); let test_res = @@ -1583,13 +1602,17 @@ impl TestSyncStoreInner { match fs_res { Ok(read) => { + #[cfg(feature = "sqlite")] assert_eq!(read, sqlite_res.unwrap()); assert_eq!(read, test_res.unwrap()); Ok(read) }, Err(e) => { - assert!(sqlite_res.is_err()); - assert_eq!(e.kind(), unsafe { sqlite_res.unwrap_err_unchecked().kind() }); + #[cfg(feature = "sqlite")] + { + assert!(sqlite_res.is_err()); + assert_eq!(e.kind(), unsafe { sqlite_res.unwrap_err_unchecked().kind() }); + } assert!(test_res.is_err()); assert_eq!(e.kind(), unsafe { test_res.unwrap_err_unchecked().kind() }); Err(e) @@ -1608,6 +1631,7 @@ impl TestSyncStoreInner { key, buf.clone(), ); + #[cfg(feature = "sqlite")] let sqlite_res = KVStoreSync::write( &self.sqlite_store, primary_namespace, @@ -1630,11 +1654,13 @@ impl TestSyncStoreInner { match fs_res { Ok(()) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_ok()); assert!(test_res.is_ok()); Ok(()) }, Err(e) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_err()); assert!(test_res.is_err()); Err(e) @@ -1648,6 +1674,7 @@ impl TestSyncStoreInner { let _guard = self.serializer.write().unwrap(); let fs_res = KVStoreSync::remove(&self.fs_store, primary_namespace, secondary_namespace, key, lazy); + #[cfg(feature = "sqlite")] let sqlite_res = KVStoreSync::remove( &self.sqlite_store, primary_namespace, @@ -1670,11 +1697,13 @@ impl TestSyncStoreInner { match fs_res { Ok(()) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_ok()); assert!(test_res.is_ok()); Ok(()) }, Err(e) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_err()); assert!(test_res.is_err()); Err(e) diff --git a/tests/integration_tests_postgres.rs b/tests/integration_tests_postgres.rs new file mode 100644 index 000000000..eb1e3d86b --- /dev/null +++ b/tests/integration_tests_postgres.rs @@ -0,0 +1,128 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +#![cfg(feature = "postgres")] + +mod common; + +use ldk_node::entropy::NodeEntropy; +use ldk_node::Builder; +use rand::RngCore; + +fn test_connection_string() -> String { + std::env::var("TEST_POSTGRES_URL") + .unwrap_or_else(|_| "host=localhost user=postgres password=postgres".to_string()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn channel_full_cycle_with_postgres_store() { + let (bitcoind, electrsd) = common::setup_bitcoind_and_electrsd(); + println!("== Node A =="); + let esplora_url = format!("http://{}", electrsd.esplora_url.as_ref().unwrap()); + let config_a = common::random_config(true); + let mut builder_a = Builder::from_config(config_a.node_config); + builder_a.set_chain_source_esplora(esplora_url.clone(), None); + let node_a = builder_a + .build_with_postgres_store( + config_a.node_entropy, + &test_connection_string(), + Some("channel_cycle_a".to_string()), + ) + .unwrap(); + node_a.start().unwrap(); + + println!("\n== Node B =="); + let config_b = common::random_config(true); + let mut builder_b = Builder::from_config(config_b.node_config); + builder_b.set_chain_source_esplora(esplora_url.clone(), None); + let node_b = builder_b + .build_with_postgres_store( + config_b.node_entropy, + &test_connection_string(), + Some("channel_cycle_b".to_string()), + ) + .unwrap(); + node_b.start().unwrap(); + + common::do_channel_full_cycle( + node_a, + node_b, + &bitcoind.client, + &electrsd.client, + false, + true, + false, + ) + .await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn postgres_node_restart() { + let (bitcoind, electrsd) = common::setup_bitcoind_and_electrsd(); + let esplora_url = format!("http://{}", electrsd.esplora_url.as_ref().unwrap()); + let connection_string = test_connection_string(); + + let storage_path = common::random_storage_path().to_str().unwrap().to_owned(); + let mut seed_bytes = [42u8; 64]; + rand::rng().fill_bytes(&mut seed_bytes); + let node_entropy = NodeEntropy::from_seed_bytes(seed_bytes); + + // Setup initial node and fund it. + let (expected_balance_sats, expected_node_id) = { + let mut builder = Builder::new(); + builder.set_network(bitcoin::Network::Regtest); + builder.set_storage_dir_path(storage_path.clone()); + builder.set_chain_source_esplora(esplora_url.clone(), None); + let node = builder + .build_with_postgres_store( + node_entropy, + &connection_string, + Some("restart_test".to_string()), + ) + .unwrap(); + + node.start().unwrap(); + let addr = node.onchain_payment().new_address().unwrap(); + common::premine_and_distribute_funds( + &bitcoind.client, + &electrsd.client, + vec![addr], + bitcoin::Amount::from_sat(100_000), + ) + .await; + node.sync_wallets().unwrap(); + + let balance = node.list_balances().spendable_onchain_balance_sats; + assert!(balance > 0); + let node_id = node.node_id(); + + node.stop().unwrap(); + (balance, node_id) + }; + + // Verify node can be restarted from PostgreSQL backend. + let mut builder = Builder::new(); + builder.set_network(bitcoin::Network::Regtest); + builder.set_storage_dir_path(storage_path); + builder.set_chain_source_esplora(esplora_url, None); + + let node = builder + .build_with_postgres_store( + node_entropy, + &connection_string, + Some("restart_test".to_string()), + ) + .unwrap(); + + node.start().unwrap(); + node.sync_wallets().unwrap(); + + assert_eq!(expected_node_id, node.node_id()); + assert_eq!(expected_balance_sats, node.list_balances().spendable_onchain_balance_sats); + + node.stop().unwrap(); +} diff --git a/tests/integration_tests_rust.rs b/tests/integration_tests_rust.rs index 413b2d44a..711bd4bc3 100644 --- a/tests/integration_tests_rust.rs +++ b/tests/integration_tests_rust.rs @@ -5,6 +5,8 @@ // http://opensource.org/licenses/MIT>, at your option. You may not use this file except in // accordance with one or both of these licenses. +#![cfg(feature = "sqlite")] + mod common; use std::collections::HashSet; diff --git a/tests/reorg_test.rs b/tests/reorg_test.rs index 295d9fdd2..74892d831 100644 --- a/tests/reorg_test.rs +++ b/tests/reorg_test.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "sqlite")] + mod common; use std::collections::HashMap;