diff --git a/source/Nevermore.Tests/RelationalStore/ReadTransactionFixture.cs b/source/Nevermore.Tests/RelationalStore/ReadTransactionFixture.cs index 8a13a130..2092ca2e 100644 --- a/source/Nevermore.Tests/RelationalStore/ReadTransactionFixture.cs +++ b/source/Nevermore.Tests/RelationalStore/ReadTransactionFixture.cs @@ -19,11 +19,11 @@ public class ReadTransactionFixture RelationalTransactionRegistry registry; readonly List createdConnections = new(); - DbConnection ConnectionFactory(string s) + (DbConnection connection, bool ownsConnection) ConnectionFactory(string s) { var c = new FakeSqlConnection { ConnectionString = s }; createdConnections.Add(c); - return c; + return (c, true); } [SetUp] @@ -95,11 +95,11 @@ public override void Open() [Test] public void OpenWillRetryATransientFailure() { - DbConnection ConnectionFactoryTransientFailure(string s) + (DbConnection connection, bool ownsConnection) ConnectionFactoryTransientFailure(string s) { var c = new FakeSqlConnectionWhichThrowsOnFirstOpen { ConnectionString = s }; createdConnections.Add(c); - return c; + return (c, true); } var c = new ReadTransaction(null!, registry, RetriableOperation.Select, new RelationalStoreConfiguration(FakeConnectionString), ConnectionFactoryTransientFailure); @@ -115,11 +115,11 @@ DbConnection ConnectionFactoryTransientFailure(string s) [Test] public async Task OpenAsyncWillRetryATransientFailure() { - DbConnection ConnectionFactoryTransientFailure(string s) + (DbConnection connection, bool ownsConnection) ConnectionFactoryTransientFailure(string s) { var c = new FakeSqlConnectionWhichThrowsOnFirstOpen { ConnectionString = s }; createdConnections.Add(c); - return c; + return (c, true); } var c = new ReadTransaction(null!, registry, RetriableOperation.Select, new RelationalStoreConfiguration(FakeConnectionString), ConnectionFactoryTransientFailure); @@ -136,11 +136,11 @@ DbConnection ConnectionFactoryTransientFailure(string s) [Test] public void OpenWithIsolationWillRetryATransientFailure() { - DbConnection ConnectionFactoryTransientFailure(string s) + (DbConnection connection, bool ownsConnection) ConnectionFactoryTransientFailure(string s) { var c = new FakeSqlConnectionWhichThrowsOnFirstOpen { ConnectionString = s }; createdConnections.Add(c); - return c; + return (c, true); } var c = new ReadTransaction(null!, registry, RetriableOperation.Select, new RelationalStoreConfiguration(FakeConnectionString), ConnectionFactoryTransientFailure); @@ -183,11 +183,11 @@ public override DbTransaction BeginTransaction(IsolationLevel iso, string transa [Test] public void OpenWithIsolationWillRetryATransientFailureFromTransaction() { - DbConnection ConnectionFactoryTransientFailure(string s) + (DbConnection connection, bool ownsConnection) ConnectionFactoryTransientFailure(string s) { var c = new FakeSqlConnectionWhichThrowsOnFirstTransaction { ConnectionString = s }; createdConnections.Add(c); - return c; + return (c, true); } var c = new ReadTransaction(null!, registry, RetriableOperation.Select, new RelationalStoreConfiguration(FakeConnectionString), ConnectionFactoryTransientFailure); @@ -203,11 +203,11 @@ DbConnection ConnectionFactoryTransientFailure(string s) [Test] public async Task OpenAsyncWithIsolationWillRetryATransientFailureFromTransaction() { - DbConnection ConnectionFactoryTransientFailure(string s) + (DbConnection connection, bool ownsConnection) ConnectionFactoryTransientFailure(string s) { var c = new FakeSqlConnectionWhichThrowsOnFirstTransaction { ConnectionString = s }; createdConnections.Add(c); - return c; + return (c, true); } var c = new ReadTransaction(null!, registry, RetriableOperation.Select, new RelationalStoreConfiguration(FakeConnectionString), ConnectionFactoryTransientFailure); diff --git a/source/Nevermore/Advanced/ReadTransaction.cs b/source/Nevermore/Advanced/ReadTransaction.cs index 6fe857e0..9724f2af 100644 --- a/source/Nevermore/Advanced/ReadTransaction.cs +++ b/source/Nevermore/Advanced/ReadTransaction.cs @@ -26,14 +26,14 @@ namespace Nevermore.Advanced public class ReadTransaction : IReadTransaction, ITransactionDiagnostic { static readonly ILog Log = LogProvider.For(); - protected static DbConnection DefaultConnectionFactory(string connectionString) => new SqlConnection(connectionString); + protected static (DbConnection connection, bool ownsConnection) DefaultConnectionFactory(string connectionString) => (new SqlConnection(connectionString), true); readonly IRelationalTransactionRegistry registry; readonly RetriableOperation operationsToRetry; readonly IRelationalStoreConfiguration configuration; readonly ITableAliasGenerator tableAliasGenerator = new TableAliasGenerator(); - readonly Func connectionFactory; + readonly Func connectionFactory; readonly Action? customCommandTrace; readonly string name; @@ -86,12 +86,12 @@ public ReadTransaction( OwnsSqlTransaction = false; } - internal ReadTransaction( + public ReadTransaction( IRelationalStore store, IRelationalTransactionRegistry registry, RetriableOperation operationsToRetry, IRelationalStoreConfiguration configuration, - Func connectionFactory, + Func connectionFactory, Action? customCommandTrace = null, string? name = null) { @@ -126,8 +126,13 @@ public void Open() if (!OwnsSqlTransaction) throw new InvalidOperationException("An existing connection and transaction were provided, they should have been opened externally"); - connection = connectionFactory(configuration.ConnectionString); - connection.OpenWithRetry(); + var (connectionFactoryConnection, ownsConnection) = connectionFactory(configuration.ConnectionString); + connection = connectionFactoryConnection; + + if (ownsConnection) + { + connection.OpenWithRetry(); + } TransactionTimer = new TimedSection(ms => configuration.TransactionLogger.Write(ms, name)); } @@ -137,8 +142,13 @@ public async Task OpenAsync(CancellationToken cancellationToken = default) if (!OwnsSqlTransaction) throw new InvalidOperationException("An existing connection and transaction were provided, they should have been opened externally"); - connection = connectionFactory(configuration.ConnectionString); - await connection.OpenWithRetryAsync(cancellationToken).ConfigureAwait(false); + var (connectionFactoryConnection, ownsConnection) = connectionFactory(configuration.ConnectionString); + connection = connectionFactoryConnection; + + if (ownsConnection) + { + await connection.OpenWithRetryAsync(cancellationToken).ConfigureAwait(false); + } TransactionTimer = new TimedSection(ms => configuration.TransactionLogger.Write(ms, name)); } diff --git a/source/Nevermore/Advanced/WriteTransaction.cs b/source/Nevermore/Advanced/WriteTransaction.cs index 2befdd7c..8f6abc95 100644 --- a/source/Nevermore/Advanced/WriteTransaction.cs +++ b/source/Nevermore/Advanced/WriteTransaction.cs @@ -55,6 +55,22 @@ public WriteTransaction( this.keyAllocator = keyAllocator; builder = new DataModificationQueryBuilder(configuration, AllocateId); } + + public WriteTransaction( + IRelationalStore store, + IRelationalTransactionRegistry registry, + RetriableOperation operationsToRetry, + IRelationalStoreConfiguration configuration, + IKeyAllocator keyAllocator, + Func connectionFactory, + Action? customCommandTrace = null, + string? name = null + ) : base(store, registry, operationsToRetry, configuration, connectionFactory, customCommandTrace, name) + { + this.configuration = configuration; + this.keyAllocator = keyAllocator; + builder = new DataModificationQueryBuilder(configuration, AllocateId); + } #nullable disable public void Insert(TDocument document, InsertOptions options = null) where TDocument : class diff --git a/source/Nevermore/RelationalStore.cs b/source/Nevermore/RelationalStore.cs index 3456b6b6..836b6149 100644 --- a/source/Nevermore/RelationalStore.cs +++ b/source/Nevermore/RelationalStore.cs @@ -96,6 +96,37 @@ public async Task BeginWriteTransactionAsync(IsolationLevel i throw; } } + + public IWriteTransaction BeginWriteTransactionFromExistingConnectionFactory(Func connectionFactory, IsolationLevel isolationLevel = NevermoreDefaults.IsolationLevel, RetriableOperation retriableOperation = NevermoreDefaults.RetriableOperations, string? name = null, CancellationToken cancellationToken = default) + { + var txn = CreateWriteTransactionFromExistingConnectionFactory(connectionFactory, retriableOperation, name); + try + { + txn.Open(isolationLevel); + return txn; + } + catch + { + txn.Dispose(); + throw; + } + } + + public async Task BeginWriteTransactionFromExistingConnectionFactoryAsync(Func connectionFactory, IsolationLevel isolationLevel = NevermoreDefaults.IsolationLevel, RetriableOperation retriableOperation = NevermoreDefaults.RetriableOperations, string? name = null, CancellationToken cancellationToken = default) + { + var txn = CreateWriteTransactionFromExistingConnectionFactory(connectionFactory, retriableOperation, name); + try + { + await txn.OpenAsync(isolationLevel, cancellationToken).ConfigureAwait(false); + return txn; + } + catch + { + txn.Dispose(); + throw; + } + } + public IRelationalTransaction BeginTransaction(IsolationLevel isolationLevel = NevermoreDefaults.IsolationLevel, RetriableOperation retriableOperation = NevermoreDefaults.RetriableOperations, string? name = null) { @@ -140,6 +171,21 @@ public IWriteTransaction CreateWriteTransactionFromExistingConnectionAndTransact customCommandTrace, name); } + + public WriteTransaction CreateWriteTransactionFromExistingConnectionFactory( + Func connectionFactory, + RetriableOperation retriableOperation, + string? name = null) + { + return new WriteTransaction( + this, + registry.Value, + retriableOperation, + Configuration, + keyAllocator.Value, + connectionFactory, + name: name); + } ReadTransaction CreateReadTransaction(RetriableOperation retriableOperation, string? name = null) {