From 55303731ae6e5b0e3842b351cd12caf7787c0d63 Mon Sep 17 00:00:00 2001 From: Karen Chen <64801825+karenc-bq@users.noreply.github.com> Date: Wed, 11 Mar 2026 20:09:50 -0700 Subject: [PATCH] feat: GDB RW Splitting --- common/lib/connection_plugin_chain_builder.ts | 3 + .../connection_string_host_list_provider.ts | 4 + .../rds_host_list_provider.ts | 15 ++- common/lib/partial_plugin_service.ts | 2 +- common/lib/pg_client_wrapper.ts | 2 +- common/lib/plugin_service.ts | 23 ++++ .../plugins/bluegreen/blue_green_plugin.ts | 15 ++- .../bluegreen/blue_green_plugin_factory.ts | 2 +- .../lib/plugins/failover/failover_plugin.ts | 16 ++- .../failover/failover_plugin_factory.ts | 2 +- .../failover/writer_failover_handler.ts | 25 +++- .../gdb_read_write_splitting_plugin.ts | 96 ++++++++++++++++ ...gdb_read_write_splitting_plugin_factory.ts | 39 +++++++ .../read_write_splitting_plugin.ts | 4 +- .../fastest_response_strategy_plugin.ts | 14 +-- ...fastest_respose_strategy_plugin_factory.ts | 5 +- .../host_response_time_monitor.ts | 108 +++++++++++------- .../host_response_time_service.ts | 97 +++++++--------- common/lib/utils/errors.ts | 2 + common/lib/utils/messages.ts | 5 + common/lib/wrapper_property.ts | 14 +++ index.ts | 2 +- tests/integration/container/tests/config.ts | 5 + .../tests/failover/gdb_failover.test.ts | 1 + .../tests/iam_authentication.test.ts | 9 -- .../tests/parameterized_queries.test.ts | 2 +- .../container/tests/pg_pool.test.ts | 6 + .../container/tests/utils/test_environment.ts | 13 +++ .../host/TestEnvironmentConfig.java | 4 + .../host/util/ContainerHelper.java | 1 - tests/unit/failover_plugin.test.ts | 9 +- tests/unit/read_write_splitting.test.ts | 1 + tests/unit/writer_failover_handler.test.ts | 101 +++++++++++++++- 33 files changed, 492 insertions(+), 155 deletions(-) create mode 100644 common/lib/plugins/read_write_splitting/gdb_read_write_splitting_plugin.ts create mode 100644 common/lib/plugins/read_write_splitting/gdb_read_write_splitting_plugin_factory.ts diff --git a/common/lib/connection_plugin_chain_builder.ts b/common/lib/connection_plugin_chain_builder.ts index 1cad14c72..38ec0a4f1 100644 --- a/common/lib/connection_plugin_chain_builder.ts +++ b/common/lib/connection_plugin_chain_builder.ts @@ -45,6 +45,7 @@ import { HostMonitoring2PluginFactory } from "./plugins/efm2/host_monitoring2_pl import { BlueGreenPluginFactory } from "./plugins/bluegreen/blue_green_plugin_factory"; import { GlobalDbFailoverPluginFactory } from "./plugins/gdb_failover/global_db_failover_plugin_factory"; import { FullServicesContainer } from "./utils/full_services_container"; +import { GdbReadWriteSplittingPluginFactory } from "./plugins/read_write_splitting/gdb_read_write_splitting_plugin_factory"; /* Type alias used for plugin factory sorting. It holds a reference to a plugin @@ -65,6 +66,7 @@ export class ConnectionPluginChainBuilder { ["staleDns", { factory: StaleDnsPluginFactory, weight: 500 }], ["bg", { factory: BlueGreenPluginFactory, weight: 550 }], ["readWriteSplitting", { factory: ReadWriteSplittingPluginFactory, weight: 600 }], + ["gdbReadWriteSplitting", { factory: GdbReadWriteSplittingPluginFactory, weight: 610 }], ["failover", { factory: FailoverPluginFactory, weight: 700 }], ["failover2", { factory: Failover2PluginFactory, weight: 710 }], ["gdbFailover", { factory: GlobalDbFailoverPluginFactory, weight: 720 }], @@ -87,6 +89,7 @@ export class ConnectionPluginChainBuilder { [StaleDnsPluginFactory, 500], [BlueGreenPluginFactory, 550], [ReadWriteSplittingPluginFactory, 600], + [GdbReadWriteSplittingPluginFactory, 610], [FailoverPluginFactory, 700], [Failover2PluginFactory, 710], [GlobalDbFailoverPluginFactory, 720], diff --git a/common/lib/host_list_provider/connection_string_host_list_provider.ts b/common/lib/host_list_provider/connection_string_host_list_provider.ts index a7973111a..d21e004a2 100644 --- a/common/lib/host_list_provider/connection_string_host_list_provider.ts +++ b/common/lib/host_list_provider/connection_string_host_list_provider.ts @@ -99,4 +99,8 @@ export class ConnectionStringHostListProvider implements StaticHostListProvider getClusterId(): string { throw new AwsWrapperError("ConnectionStringHostListProvider does not support getClusterId."); } + + forceMonitoringRefresh(shouldVerifyWriter: boolean, timeoutMs: number): Promise { + throw new AwsWrapperError("ConnectionStringHostListProvider does not support forceMonitoringRefresh."); + } } diff --git a/common/lib/host_list_provider/rds_host_list_provider.ts b/common/lib/host_list_provider/rds_host_list_provider.ts index 9558e1aa4..4185d7db7 100644 --- a/common/lib/host_list_provider/rds_host_list_provider.ts +++ b/common/lib/host_list_provider/rds_host_list_provider.ts @@ -144,11 +144,21 @@ export class RdsHostListProvider implements DynamicHostListProvider { if (!this.pluginService.isDialectConfirmed()) { // We need to confirm the dialect before creating a topology monitor so that it uses the correct SQL queries. - // We will return the original hosts parsed from the connections string until the dialect has been confirmed. + // Return the original hosts parsed from the connection string. return this.initialHostList; } - return await this.forceRefreshMonitor(verifyTopology, timeoutMs); + const hosts = await this.forceRefreshMonitor(verifyTopology, timeoutMs); + if (hosts && hosts.length > 0) { + return hosts; + } + + // Check for cached topology as a fallback. + const storedTopology = this.getStoredTopology(); + if (storedTopology && storedTopology.length > 0) { + return storedTopology; + } + return this.initialHostList; } async getHostRole(client: ClientWrapper, _dialect: DatabaseDialect): Promise { @@ -236,6 +246,7 @@ export class RdsHostListProvider implements DynamicHostListProvider { } async getCurrentTopology(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise { + this.init(); return await this.topologyUtils.queryForTopology(targetClient, dialect, this.initialHost, this.clusterInstanceTemplate); } diff --git a/common/lib/partial_plugin_service.ts b/common/lib/partial_plugin_service.ts index 2097d093e..b3b8f7295 100644 --- a/common/lib/partial_plugin_service.ts +++ b/common/lib/partial_plugin_service.ts @@ -264,7 +264,7 @@ export class PartialPluginService implements PluginService, HostListProviderServ } isDialectConfirmed(): boolean { - throw new AwsWrapperError(Messages.get("PartialPluginService.unexpectedMethodCall", "isDialectConfirmed")); + return true; } setInTransaction(inTransaction: boolean): void { diff --git a/common/lib/pg_client_wrapper.ts b/common/lib/pg_client_wrapper.ts index 71af129c1..a9c2e3204 100644 --- a/common/lib/pg_client_wrapper.ts +++ b/common/lib/pg_client_wrapper.ts @@ -60,7 +60,7 @@ export class PgClientWrapper implements ClientWrapper { async abort(): Promise { try { - return await ClientUtils.queryWithTimeout(this.end(), this.properties); + this.client?.connection?.stream?.destroy(); } catch (error: any) { // Ignore } diff --git a/common/lib/plugin_service.ts b/common/lib/plugin_service.ts index 68f69c979..8e8a5a1a8 100644 --- a/common/lib/plugin_service.ts +++ b/common/lib/plugin_service.ts @@ -281,7 +281,28 @@ export class PluginServiceImpl implements PluginService, HostListProviderService return this.dialect; } + private static readonly DIALECT_CONFIRMED_STATUS_KEY = "DialectConfirmed"; + + private getDialectConfirmedCacheKey(): string { + let clusterId = WrapperProperties.CLUSTER_ID.defaultValue; + try { + clusterId = this._hostListProvider?.getClusterId() ?? WrapperProperties.CLUSTER_ID.defaultValue; + } catch (e) { + // May fail if the host list provider does not support getClusterId. In this case use the default value. + } + return `${clusterId}::${PluginServiceImpl.DIALECT_CONFIRMED_STATUS_KEY}`; + } + isDialectConfirmed(): boolean { + if (this._isDialectConfirmed) { + return true; + } + + const cacheItem = this.storageService.get(StatusCacheItem, this.getDialectConfirmedCacheKey()); + if (cacheItem && cacheItem.status === true) { + this._isDialectConfirmed = true; + } + return this._isDialectConfirmed; } @@ -634,6 +655,8 @@ export class PluginServiceImpl implements PluginService, HostListProviderService this.dialect = await this.dbDialectProvider.getDialectForUpdate(targetClient, this.initialHost, this.props.get(WrapperProperties.HOST.name)); this._isDialectConfirmed = true; + this.storageService.set(this.getDialectConfirmedCacheKey(), new StatusCacheItem(true)); + if (originalDialect === this.dialect) { return; } diff --git a/common/lib/plugins/bluegreen/blue_green_plugin.ts b/common/lib/plugins/bluegreen/blue_green_plugin.ts index 17cacb8d9..ddcd34e89 100644 --- a/common/lib/plugins/bluegreen/blue_green_plugin.ts +++ b/common/lib/plugins/bluegreen/blue_green_plugin.ts @@ -27,9 +27,10 @@ import { IamAuthenticationPlugin } from "../../authentication/iam_authentication import { BlueGreenRole } from "./blue_green_role"; import { ExecuteRouting, RoutingResultHolder } from "./routing/execute_routing"; import { CanReleaseResources } from "../../can_release_resources"; +import { FullServicesContainer } from "../../utils/full_services_container"; export interface BlueGreenProviderSupplier { - create(pluginService: PluginService, props: Map, bgdId: string): BlueGreenStatusProvider; + create(servicesContainer: FullServicesContainer, props: Map, bgdId: string): BlueGreenStatusProvider; } export class BlueGreenPlugin extends AbstractConnectionPlugin implements CanReleaseResources { @@ -42,6 +43,7 @@ export class BlueGreenPlugin extends AbstractConnectionPlugin implements CanRele private static readonly CLOSED_METHOD_NAMES: Set = new Set(["end", "abort"]); protected readonly pluginService: PluginService; + protected readonly servicesContainer: FullServicesContainer; protected readonly properties: Map; protected bgProviderSupplier: BlueGreenProviderSupplier; protected bgStatus: BlueGreenStatus = null; @@ -53,18 +55,19 @@ export class BlueGreenPlugin extends AbstractConnectionPlugin implements CanRele protected endTimeNano: bigint = BigInt(0); private static provider: Map = new Map(); - constructor(pluginService: PluginService, properties: Map, bgProviderSupplier: BlueGreenProviderSupplier = null) { + constructor(servicesContainer: FullServicesContainer, properties: Map, bgProviderSupplier: BlueGreenProviderSupplier = null) { super(); if (!bgProviderSupplier) { bgProviderSupplier = { - create: (pluginService: PluginService, props: Map, bgdId: string): BlueGreenStatusProvider => { - return new BlueGreenStatusProvider(pluginService, props, bgdId); + create: (servicesContainer: FullServicesContainer, props: Map, bgdId: string): BlueGreenStatusProvider => { + return new BlueGreenStatusProvider(servicesContainer, props, bgdId); } }; } this.properties = properties; - this.pluginService = pluginService; + this.servicesContainer = servicesContainer; + this.pluginService = servicesContainer.pluginService; this.bgProviderSupplier = bgProviderSupplier; this.bgdId = WrapperProperties.BGD_ID.get(this.properties).trim().toLowerCase(); } @@ -215,7 +218,7 @@ export class BlueGreenPlugin extends AbstractConnectionPlugin implements CanRele private initProvider() { const provider = BlueGreenPlugin.provider.get(this.bgdId); if (!provider) { - const provider = this.bgProviderSupplier.create(this.pluginService, this.properties, this.bgdId); + const provider = this.bgProviderSupplier.create(this.servicesContainer, this.properties, this.bgdId); BlueGreenPlugin.provider.set(this.bgdId, provider); } } diff --git a/common/lib/plugins/bluegreen/blue_green_plugin_factory.ts b/common/lib/plugins/bluegreen/blue_green_plugin_factory.ts index d20e00b66..c7f581067 100644 --- a/common/lib/plugins/bluegreen/blue_green_plugin_factory.ts +++ b/common/lib/plugins/bluegreen/blue_green_plugin_factory.ts @@ -28,7 +28,7 @@ export class BlueGreenPluginFactory extends ConnectionPluginFactory { if (!BlueGreenPluginFactory.blueGreenPlugin) { BlueGreenPluginFactory.blueGreenPlugin = await import("./blue_green_plugin"); } - return new BlueGreenPluginFactory.blueGreenPlugin.BlueGreenPlugin(servicesContainer.pluginService, props); + return new BlueGreenPluginFactory.blueGreenPlugin.BlueGreenPlugin(servicesContainer, props); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "BlueGreenPluginFactory")); } diff --git a/common/lib/plugins/failover/failover_plugin.ts b/common/lib/plugins/failover/failover_plugin.ts index bbb81e73e..ceefae36c 100644 --- a/common/lib/plugins/failover/failover_plugin.ts +++ b/common/lib/plugins/failover/failover_plugin.ts @@ -43,6 +43,7 @@ import { ClientWrapper } from "../../client_wrapper"; import { getWriter, logTopology } from "../../utils/utils"; import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; import { TelemetryTraceLevel } from "../../utils/telemetry/telemetry_trace_level"; +import { FullServicesContainer } from "../../utils/full_services_container"; export class FailoverPlugin extends AbstractConnectionPlugin { private static readonly TELEMETRY_WRITER_FAILOVER = "failover to writer instance"; @@ -79,18 +80,19 @@ export class FailoverPlugin extends AbstractConnectionPlugin { private hostListProviderService?: HostListProviderService; private readonly pluginService: PluginService; + private readonly servicesContainer: FullServicesContainer; protected enableFailoverSetting: boolean = WrapperProperties.ENABLE_CLUSTER_AWARE_FAILOVER.defaultValue; - constructor(pluginService: PluginService, properties: Map, rdsHelper: RdsUtils); + constructor(servicesContainer: FullServicesContainer, properties: Map, rdsHelper: RdsUtils); constructor( - pluginService: PluginService, + servicesContainer: FullServicesContainer, properties: Map, rdsHelper: RdsUtils, readerFailoverHandler: ClusterAwareReaderFailoverHandler, writerFailoverHandler: ClusterAwareWriterFailoverHandler ); constructor( - pluginService: PluginService, + servicesContainer: FullServicesContainer, properties: Map, rdsHelper: RdsUtils, readerFailoverHandler?: ClusterAwareReaderFailoverHandler, @@ -98,7 +100,8 @@ export class FailoverPlugin extends AbstractConnectionPlugin { ) { super(); this._properties = properties; - this.pluginService = pluginService; + this.pluginService = servicesContainer.pluginService; + this.servicesContainer = servicesContainer; this._rdsHelper = rdsHelper; this.initSettings(); @@ -106,7 +109,7 @@ export class FailoverPlugin extends AbstractConnectionPlugin { this._readerFailoverHandler = readerFailoverHandler ? readerFailoverHandler : new ClusterAwareReaderFailoverHandler( - pluginService, + this.pluginService, properties, this.failoverTimeoutMsSetting, this.failoverReaderConnectTimeoutMsSetting, @@ -115,7 +118,8 @@ export class FailoverPlugin extends AbstractConnectionPlugin { this._writerFailoverHandler = writerFailoverHandler ? writerFailoverHandler : new ClusterAwareWriterFailoverHandler( - pluginService, + this.pluginService, + this.servicesContainer, this._readerFailoverHandler, properties, this.failoverTimeoutMsSetting, diff --git a/common/lib/plugins/failover/failover_plugin_factory.ts b/common/lib/plugins/failover/failover_plugin_factory.ts index cab67b400..82ac44835 100644 --- a/common/lib/plugins/failover/failover_plugin_factory.ts +++ b/common/lib/plugins/failover/failover_plugin_factory.ts @@ -29,7 +29,7 @@ export class FailoverPluginFactory extends ConnectionPluginFactory { if (!FailoverPluginFactory.failoverPlugin) { FailoverPluginFactory.failoverPlugin = await import("./failover_plugin"); } - return new FailoverPluginFactory.failoverPlugin.FailoverPlugin(servicesContainer.pluginService, properties, new RdsUtils()); + return new FailoverPluginFactory.failoverPlugin.FailoverPlugin(servicesContainer, properties, new RdsUtils()); } catch (error: any) { throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "FailoverPlugin")); } diff --git a/common/lib/plugins/failover/writer_failover_handler.ts b/common/lib/plugins/failover/writer_failover_handler.ts index b8780b888..e14a73fd3 100644 --- a/common/lib/plugins/failover/writer_failover_handler.ts +++ b/common/lib/plugins/failover/writer_failover_handler.ts @@ -27,6 +27,9 @@ import { logger } from "../../../logutils"; import { WrapperProperties } from "../../wrapper_property"; import { ClientWrapper } from "../../client_wrapper"; import { FailoverRestriction } from "./failover_restriction"; +import { FullServicesContainer } from "../../utils/full_services_container"; +import { ServiceUtils } from "../../utils/service_utils"; +import { DatabaseDialect } from "../../database_dialect/database_dialect"; export interface WriterFailoverHandler { failover(currentTopology: HostInfo[]): Promise; @@ -47,6 +50,7 @@ export class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler static readonly RECONNECT_WRITER_TASK = "TaskA"; static readonly WAIT_NEW_WRITER_TASK = "TaskB"; private readonly pluginService: PluginService; + private readonly servicesContainer: FullServicesContainer; private readonly readerFailoverHandler: ClusterAwareReaderFailoverHandler; private readonly initialConnectionProps: Map; maxFailoverTimeoutMs = 60000; // 60 sec @@ -55,6 +59,7 @@ export class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler constructor( pluginService: PluginService, + servicesContainer: FullServicesContainer, readerFailoverHandler: ClusterAwareReaderFailoverHandler, initialConnectionProps: Map, failoverTimeoutMs?: number, @@ -62,6 +67,7 @@ export class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler reconnectWriterIntervalMs?: number ) { this.pluginService = pluginService; + this.servicesContainer = servicesContainer; this.readerFailoverHandler = readerFailoverHandler; this.initialConnectionProps = initialConnectionProps; this.maxFailoverTimeoutMs = failoverTimeoutMs ?? this.maxFailoverTimeoutMs; @@ -69,16 +75,29 @@ export class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler this.reconnectionWriterIntervalMs = reconnectWriterIntervalMs ?? this.reconnectionWriterIntervalMs; } + protected async newServicesContainer(): Promise { + const container = ServiceUtils.instance.createMinimalServiceContainerFrom(this.servicesContainer, this.initialConnectionProps); + await container.pluginManager.init(); + const initialHostInfo = this.pluginService.getInitialConnectionHostInfo(); + if (initialHostInfo) { + container.hostListProviderService.setInitialConnectionHostInfo(initialHostInfo); + } + return container; + } + async failover(currentTopology: HostInfo[]): Promise { if (!currentTopology || currentTopology.length == 0) { logger.error(Messages.get("ClusterAwareWriterFailoverHandler.failoverCalledWithInvalidTopology")); return ClusterAwareWriterFailoverHandler.DEFAULT_RESULT; } + const taskAContainer = await this.newServicesContainer(); + const taskBContainer = await this.newServicesContainer(); + const reconnectToWriterHandlerTask = new ReconnectToWriterHandlerTask( currentTopology, getWriter(currentTopology), - this.pluginService, + taskAContainer.pluginService, this.initialConnectionProps, this.reconnectionWriterIntervalMs, Date.now() + this.maxFailoverTimeoutMs @@ -88,7 +107,7 @@ export class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler currentTopology, getWriter(currentTopology), this.readerFailoverHandler, - this.pluginService, + taskBContainer.pluginService, this.initialConnectionProps, this.readTopologyIntervalMs, Date.now() + this.maxFailoverTimeoutMs @@ -379,7 +398,7 @@ class WaitForNewWriterHandlerTask { async refreshTopologyAndConnectToNewWriter(): Promise { const allowOldWriter: boolean = this.pluginService.getDialect().getFailoverRestrictions().includes(FailoverRestriction.ENABLE_WRITER_IN_TASK_B); - while (this.pluginService.getCurrentClient() && Date.now() < this.endTime && !this.failoverCompleted) { + while (Date.now() < this.endTime && !this.failoverCompleted) { try { if (this.currentReaderTargetClient) { await this.pluginService.forceRefreshHostList(); diff --git a/common/lib/plugins/read_write_splitting/gdb_read_write_splitting_plugin.ts b/common/lib/plugins/read_write_splitting/gdb_read_write_splitting_plugin.ts new file mode 100644 index 000000000..b9f4d94f5 --- /dev/null +++ b/common/lib/plugins/read_write_splitting/gdb_read_write_splitting_plugin.ts @@ -0,0 +1,96 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { ReadWriteSplittingPlugin } from "./read_write_splitting_plugin"; +import { WrapperProperties } from "../../wrapper_property"; +import { HostInfo } from "../../host_info"; +import { RdsUtils } from "../../utils/rds_utils"; +import { ReadWriteSplittingError } from "../../utils/errors"; +import { Messages } from "../../utils/messages"; +import { logger } from "../../../logutils"; +import { ClientWrapper } from "../../client_wrapper"; +import { equalsIgnoreCase } from "../../utils/utils"; + +export class GdbReadWriteSplittingPlugin extends ReadWriteSplittingPlugin { + protected readonly rdsUtils: RdsUtils = new RdsUtils(); + + protected restrictWriterToHomeRegion: boolean; + protected restrictReaderToHomeRegion: boolean; + + protected isInitialized: boolean = false; + protected homeRegion: string; + + protected initSettings(initHostInfo: HostInfo, properties: Map): void { + if (this.isInitialized) { + return; + } + this.restrictWriterToHomeRegion = WrapperProperties.GDB_RW_RESTRICT_WRITER_TO_HOME_REGION.get(properties); + this.restrictReaderToHomeRegion = WrapperProperties.GDB_RW_RESTRICT_READER_TO_HOME_REGION.get(properties); + + this.homeRegion = WrapperProperties.GDB_RW_HOME_REGION.get(properties); + if (!this.homeRegion) { + const rdsUrlType = this.rdsUtils.identifyRdsType(initHostInfo.host); + if (rdsUrlType.hasRegion) { + this.homeRegion = this.rdsUtils.getRdsRegion(initHostInfo.host); + } + } + + if (!this.homeRegion) { + throw new ReadWriteSplittingError(Messages.get("GdbReadWriteSplittingPlugin.missingHomeRegion", initHostInfo.host)); + } + + logger.debug(Messages.get("GdbReadWriteSplittingPlugin.parameterValue", "gdbRwHomeRegion", this.homeRegion)); + + this.isInitialized = true; + } + + override async connect( + hostInfo: HostInfo, + props: Map, + isInitialConnection: boolean, + connectFunc: () => Promise + ): Promise { + this.initSettings(hostInfo, props); + return super.connect(hostInfo, props, isInitialConnection, connectFunc); + } + + override setWriterClient(writerTargetClient: ClientWrapper | undefined, writerHostInfo: HostInfo) { + if ( + this.restrictWriterToHomeRegion && + writerHostInfo != null && + !equalsIgnoreCase(this.rdsUtils.getRdsRegion(writerHostInfo.host), this.homeRegion) + ) { + throw new ReadWriteSplittingError( + Messages.get("GdbReadWriteSplittingPlugin.cantConnectWriterOutOfHomeRegion", writerHostInfo.host, this.homeRegion) + ); + } + super.setWriterClient(writerTargetClient, writerHostInfo); + } + + protected getReaderHostCandidates(): HostInfo[] { + if (this.restrictReaderToHomeRegion) { + const hostsInRegion: HostInfo[] = this.pluginService + .getHosts() + .filter((x) => equalsIgnoreCase(this.rdsUtils.getRdsRegion(x.host), this.homeRegion)); + + if (hostsInRegion.length === 0) { + throw new ReadWriteSplittingError(Messages.get("GdbReadWriteSplittingPlugin.noAvailableReadersInHomeRegion", this.homeRegion)); + } + return hostsInRegion; + } + return super.getReaderHostCandidates(); + } +} diff --git a/common/lib/plugins/read_write_splitting/gdb_read_write_splitting_plugin_factory.ts b/common/lib/plugins/read_write_splitting/gdb_read_write_splitting_plugin_factory.ts new file mode 100644 index 000000000..7aa6d1ad5 --- /dev/null +++ b/common/lib/plugins/read_write_splitting/gdb_read_write_splitting_plugin_factory.ts @@ -0,0 +1,39 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { ConnectionPluginFactory } from "../../plugin_factory"; +import { ConnectionPlugin } from "../../connection_plugin"; +import { AwsWrapperError } from "../../utils/errors"; +import { Messages } from "../../utils/messages"; +import { FullServicesContainer } from "../../utils/full_services_container"; + +export class GdbReadWriteSplittingPluginFactory extends ConnectionPluginFactory { + private static gdbReadWriteSplittingPlugin: any; + + async getInstance(servicesContainer: FullServicesContainer, properties: Map): Promise { + try { + if (!GdbReadWriteSplittingPluginFactory.gdbReadWriteSplittingPlugin) { + GdbReadWriteSplittingPluginFactory.gdbReadWriteSplittingPlugin = await import("./gdb_read_write_splitting_plugin"); + } + return new GdbReadWriteSplittingPluginFactory.gdbReadWriteSplittingPlugin.GdbReadWriteSplittingPlugin( + servicesContainer.pluginService, + properties + ); + } catch (error: any) { + throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "gdbReadWriteSplittingPlugin")); + } + } +} diff --git a/common/lib/plugins/read_write_splitting/read_write_splitting_plugin.ts b/common/lib/plugins/read_write_splitting/read_write_splitting_plugin.ts index 6ade242fa..2d70b8936 100644 --- a/common/lib/plugins/read_write_splitting/read_write_splitting_plugin.ts +++ b/common/lib/plugins/read_write_splitting/read_write_splitting_plugin.ts @@ -64,7 +64,7 @@ export class ReadWriteSplittingPlugin extends AbstractReadWriteSplittingPlugin { if (!isInitialConnection || !this._hostListProviderService?.isDynamicHostListProvider()) { return result; } - const currentRole = this.pluginService.getCurrentHostInfo()?.role; + const currentRole = await this.pluginService.getHostRole(result); if (currentRole == HostRole.UNKNOWN) { logAndThrowError(Messages.get("ReadWriteSplittingPlugin.errorVerifyingInitialHostRole")); @@ -181,7 +181,7 @@ export class ReadWriteSplittingPlugin extends AbstractReadWriteSplittingPlugin { } } - protected getReaderHostCandidates(): HostInfo[] | undefined { + protected getReaderHostCandidates(): HostInfo[] { return this.pluginService.getHosts(); } } diff --git a/common/lib/plugins/strategy/fastest_response/fastest_response_strategy_plugin.ts b/common/lib/plugins/strategy/fastest_response/fastest_response_strategy_plugin.ts index e0d3a1748..54e10a7a9 100644 --- a/common/lib/plugins/strategy/fastest_response/fastest_response_strategy_plugin.ts +++ b/common/lib/plugins/strategy/fastest_response/fastest_response_strategy_plugin.ts @@ -26,6 +26,7 @@ import { HostChangeOptions } from "../../../host_change_options"; import { RandomHostSelector } from "../../../random_host_selector"; import { Messages } from "../../../utils/messages"; import { equalsIgnoreCase, logAndThrowError } from "../../../utils/utils"; +import { FullServicesContainer } from "../../../utils/full_services_container"; export class FastestResponseStrategyPlugin extends AbstractConnectionPlugin { static readonly FASTEST_RESPONSE_STRATEGY_NAME: string = "fastestResponse"; @@ -43,13 +44,13 @@ export class FastestResponseStrategyPlugin extends AbstractConnectionPlugin { private pluginService: PluginService; private randomHostSelector: RandomHostSelector = new RandomHostSelector(); - constructor(pluginService: PluginService, properties: Map, hostResponseTimeService?: HostResponseTimeService) { + constructor(servicesContainer: FullServicesContainer, properties: Map, hostResponseTimeService?: HostResponseTimeService) { super(); - this.pluginService = pluginService; + this.pluginService = servicesContainer.pluginService; this.properties = properties; this.hostResponseTimeService = hostResponseTimeService ?? - new HostResponseTimeServiceImpl(pluginService, properties, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get(this.properties)); + new HostResponseTimeServiceImpl(servicesContainer, properties, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get(this.properties)); this.cacheExpirationNanos = BigInt(WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get(this.properties) * 1_000_000); } @@ -91,18 +92,13 @@ export class FastestResponseStrategyPlugin extends AbstractConnectionPlugin { if (!this.acceptsStrategy(role, strategy)) { logAndThrowError(Messages.get("FastestResponseStrategyPlugin.unsupportedHostSelectorStrategy", strategy)); } - // The cache holds a host with the fastest response time. - // If the cache doesn't have a host for a role, it's necessary to find the fastest host in the topology. const fastestResponseHost: HostInfo = FastestResponseStrategyPlugin.cachedFastestResponseHostByRole.get(role); if (fastestResponseHost) { - // Found the fastest host. Find the host in the latest topology. const foundHost = this.pluginService.getHosts().find((host) => host === fastestResponseHost); if (foundHost) { - // Found a host in the topology. return foundHost; } } - // Cached result isn't available. Need to find the fastest response time host. const calculatedFastestResponseHost: ResponseTimeTuple[] = this.pluginService .getHosts() .filter((host) => role === host.role) @@ -113,8 +109,6 @@ export class FastestResponseStrategyPlugin extends AbstractConnectionPlugin { const calculatedHost = calculatedFastestResponseHost.length === 0 ? null : calculatedFastestResponseHost[0]; if (!calculatedHost) { - // Unable to identify the fastest response host. - // As a last resort, let's use a random host selector. return this.randomHostSelector.getHost(hosts, role, this.properties); } FastestResponseStrategyPlugin.cachedFastestResponseHostByRole.put(role, calculatedHost.hostInfo, Number(this.cacheExpirationNanos)); diff --git a/common/lib/plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory.ts b/common/lib/plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory.ts index 9518fdfd9..340e6f720 100644 --- a/common/lib/plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory.ts +++ b/common/lib/plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory.ts @@ -28,10 +28,7 @@ export class FastestResponseStrategyPluginFactory extends ConnectionPluginFactor if (!FastestResponseStrategyPluginFactory.fastestResponseStrategyPlugin) { FastestResponseStrategyPluginFactory.fastestResponseStrategyPlugin = await import("./fastest_response_strategy_plugin"); } - return new FastestResponseStrategyPluginFactory.fastestResponseStrategyPlugin.FastestResponseStrategyPlugin( - servicesContainer.pluginService, - properties - ); + return new FastestResponseStrategyPluginFactory.fastestResponseStrategyPlugin.FastestResponseStrategyPlugin(servicesContainer, properties); } catch (error: any) { throw new AwsWrapperError( Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "FastestResponseStrategyPluginFactory") diff --git a/common/lib/plugins/strategy/fastest_response/host_response_time_monitor.ts b/common/lib/plugins/strategy/fastest_response/host_response_time_monitor.ts index 5e0a8b809..050747423 100644 --- a/common/lib/plugins/strategy/fastest_response/host_response_time_monitor.ts +++ b/common/lib/plugins/strategy/fastest_response/host_response_time_monitor.ts @@ -17,66 +17,95 @@ import { HostInfo } from "../../../host_info"; import { PluginService } from "../../../plugin_service"; import { TelemetryFactory } from "../../../utils/telemetry/telemetry_factory"; -import { sleep } from "../../../utils/utils"; +import { sleepWithAbort } from "../../../utils/utils"; import { logger } from "../../../../logutils"; import { Messages } from "../../../utils/messages"; import { TelemetryTraceLevel } from "../../../utils/telemetry/telemetry_trace_level"; import { ClientWrapper } from "../../../client_wrapper"; -import { TelemetryContext } from "../../../utils/telemetry/telemetry_context"; import { WrapperProperties } from "../../../wrapper_property"; +import { AbstractMonitor } from "../../../utils/monitoring/monitor"; +import { FullServicesContainer } from "../../../utils/full_services_container"; -export class HostResponseTimeMonitor { - static readonly MONITORING_PROPERTY_PREFIX = "frt_"; +export class ResponseTimeHolder { + private readonly url: string; + private readonly responseTime: number; + + constructor(url: string, responseTime: number) { + this.url = url; + this.responseTime = responseTime; + } + + getUrl(): string { + return this.url; + } + + getResponseTime(): number { + return this.responseTime; + } +} + +export class HostResponseTimeMonitor extends AbstractMonitor { static readonly NUM_OF_MEASURES = 5; + private static readonly TERMINATION_TIMEOUT_SEC = 5; private readonly intervalMs: number; private readonly hostInfo: HostInfo; - private stopped = false; private responseTimeMs = Number.MAX_SAFE_INTEGER; - private checkTimestamp = Date.now(); private readonly properties: Map; - private pluginService: PluginService; - private telemetryFactory: TelemetryFactory; + private readonly servicesContainer: FullServicesContainer; + private readonly pluginService: PluginService; + private readonly telemetryFactory: TelemetryFactory; protected monitoringClient: ClientWrapper | null = null; + private abortSleep?: () => void; - constructor(pluginService: PluginService, hostInfo: HostInfo, properties: Map, intervalMs: number) { - this.pluginService = pluginService; + constructor(servicesContainer: FullServicesContainer, hostInfo: HostInfo, properties: Map, intervalMs: number) { + super(HostResponseTimeMonitor.TERMINATION_TIMEOUT_SEC); + this.servicesContainer = servicesContainer; + this.pluginService = servicesContainer.pluginService; this.hostInfo = hostInfo; this.properties = properties; this.intervalMs = intervalMs; this.telemetryFactory = this.pluginService.getTelemetryFactory(); - const hostId: string = this.hostInfo.hostId ?? this.getHostInfo().host; - /** - * Report current response time (in milliseconds) to telemetry engine. - * Report -1 if response time couldn't be measured. - */ - this.telemetryFactory.createGauge(`frt.response.time.${hostId}`, () => this.getResponseTime() == Number.MAX_SAFE_INTEGER); - this.run(); + const hostId: string = this.hostInfo.hostId ?? this.hostInfo.host; + this.telemetryFactory.createGauge(`frt.response.time.${hostId}`, () => + this.responseTimeMs === Number.MAX_SAFE_INTEGER ? -1 : this.responseTimeMs + ); } - getResponseTime() { + getResponseTime(): number { return this.responseTimeMs; } - getCheckTimeStamp() { - return this.checkTimestamp; - } - - getHostInfo() { + getHostInfo(): HostInfo { return this.hostInfo; } async close(): Promise { - this.stopped = true; - await sleep(500); - logger.debug(Messages.get("HostResponseTimeMonitor.stopped", this.hostInfo.host)); + if (this.abortSleep) { + try { + this.abortSleep(); + } catch (error) { + // ignore + } + this.abortSleep = undefined; + } + if (this.monitoringClient) { + try { + await this.monitoringClient.abort(); + } catch (error) { + // ignore + } + this.monitoringClient = null; + } } - async run(): Promise { - const telemetryContext: TelemetryContext = this.telemetryFactory.openTelemetryContext("host response time task", TelemetryTraceLevel.TOP_LEVEL); + async monitor(): Promise { + const telemetryContext = this.telemetryFactory.openTelemetryContext("host response time task", TelemetryTraceLevel.TOP_LEVEL); telemetryContext.setAttribute("url", this.hostInfo.host); - while (!this.stopped) { + + while (!this._stop) { + this.lastActivityTimestampNanos = BigInt(Date.now() * 1_000_000); await telemetryContext.start(async () => { try { await this.openConnection(); @@ -84,33 +113,29 @@ export class HostResponseTimeMonitor { let responseTimeSum = 0; let count = 0; for (let i = 0; i < HostResponseTimeMonitor.NUM_OF_MEASURES; i++) { - if (this.stopped) { + if (this._stop) { break; } const startTime = Date.now(); if (await this.pluginService.isClientValid(this.monitoringClient)) { - const responseTime = Date.now() - startTime; - responseTimeSum += responseTime; + responseTimeSum += Date.now() - startTime; count++; } } if (count > 0) { this.responseTimeMs = responseTimeSum / count; + this.servicesContainer.storageService.set(this.hostInfo.url, new ResponseTimeHolder(this.hostInfo.url, this.responseTimeMs)); } else { this.responseTimeMs = Number.MAX_SAFE_INTEGER; + this.servicesContainer.storageService.remove(ResponseTimeHolder, this.hostInfo.url); } - this.checkTimestamp = Date.now(); logger.debug(Messages.get("HostResponseTimeMonitor.responseTime", this.hostInfo.host, this.responseTimeMs.toString())); } - await sleep(this.intervalMs); + const [sleepPromise, abortFn] = sleepWithAbort(this.intervalMs); + this.abortSleep = abortFn as () => void; + await sleepPromise; } catch (error) { logger.debug(Messages.get("HostResponseTimeMonitor.interruptedErrorDuringMonitoring", this.hostInfo.host, error.message)); - } finally { - this.stopped = true; - if (this.monitoringClient) { - await this.monitoringClient.abort(); - } - this.monitoringClient = null; } }); } @@ -119,8 +144,7 @@ export class HostResponseTimeMonitor { async openConnection(): Promise { try { if (this.monitoringClient) { - const clientIsValid = await this.pluginService.isClientValid(this.monitoringClient); - if (clientIsValid) { + if (await this.pluginService.isClientValid(this.monitoringClient)) { return; } } diff --git a/common/lib/plugins/strategy/fastest_response/host_response_time_service.ts b/common/lib/plugins/strategy/fastest_response/host_response_time_service.ts index 627a69b88..be6025e14 100644 --- a/common/lib/plugins/strategy/fastest_response/host_response_time_service.ts +++ b/common/lib/plugins/strategy/fastest_response/host_response_time_service.ts @@ -15,80 +15,65 @@ */ import { HostInfo } from "../../../host_info"; -import { PluginService } from "../../../plugin_service"; -import { TelemetryFactory } from "../../../utils/telemetry/telemetry_factory"; -import { SlidingExpirationCache } from "../../../utils/sliding_expiration_cache"; -import { HostResponseTimeMonitor } from "./host_response_time_monitor"; +import { HostResponseTimeMonitor, ResponseTimeHolder } from "./host_response_time_monitor"; +import { FullServicesContainer } from "../../../utils/full_services_container"; +import { MonitorErrorResponse } from "../../../utils/monitoring/monitor"; export interface HostResponseTimeService { - /** - * Return a response time in milliseconds to the host. - * Return Number.MAX_SAFE_INTEGER if response time is not available. - * - * @param hostInfo the host details - * @return response time in milliseconds for a desired host. - */ getResponseTime(hostInfo: HostInfo): number; - - /** - * Provides an updated host list to a service. - */ setHosts(hosts: HostInfo[]): void; } export class HostResponseTimeServiceImpl implements HostResponseTimeService { - static readonly CACHE_EXPIRATION_NANOS: bigint = BigInt(10 * 60_000_000_000); // 10 minutes - static readonly CACHE_CLEANUP_NANOS: bigint = BigInt(60_000_000_000); // 1 minute + private static readonly MONITOR_DISPOSAL_TIME_NANOS: bigint = BigInt(10 * 60_000_000_000); // 10 minutes + private static readonly INACTIVE_TIMEOUT_NANOS: bigint = BigInt(3 * 60_000_000_000); // 3 minutes - private readonly pluginService: PluginService; - readonly properties: Map; - readonly intervalMs: number; - protected hosts: HostInfo[]; - private readonly telemetryFactory: TelemetryFactory; - protected static monitoringHosts: SlidingExpirationCache = new SlidingExpirationCache( - HostResponseTimeServiceImpl.CACHE_CLEANUP_NANOS, - undefined, - async (monitor: HostResponseTimeMonitor) => { - { - try { - await monitor.close(); - } catch (error) { - // ignore - } - } - } - ); + private readonly servicesContainer: FullServicesContainer; + private readonly properties: Map; + private readonly intervalMs: number; + private hosts: HostInfo[] = []; - constructor(pluginService: PluginService, properties: Map, intervalMs: number) { - this.pluginService = pluginService; + constructor(servicesContainer: FullServicesContainer, properties: Map, intervalMs: number) { + this.servicesContainer = servicesContainer; this.properties = properties; this.intervalMs = intervalMs; - this.telemetryFactory = this.pluginService.getTelemetryFactory(); - HostResponseTimeServiceImpl.monitoringHosts.cleanupIntervalNs = BigInt(intervalMs) ?? HostResponseTimeServiceImpl.CACHE_CLEANUP_NANOS; - this.telemetryFactory.createGauge("frt.hosts.count", () => HostResponseTimeServiceImpl.monitoringHosts.size); + + this.servicesContainer.storageService.registerItemClassIfAbsent( + ResponseTimeHolder, + true, + HostResponseTimeServiceImpl.MONITOR_DISPOSAL_TIME_NANOS, + null, + null + ); + + this.servicesContainer.monitorService.registerMonitorTypeIfAbsent( + HostResponseTimeMonitor, + HostResponseTimeServiceImpl.MONITOR_DISPOSAL_TIME_NANOS, + HostResponseTimeServiceImpl.INACTIVE_TIMEOUT_NANOS, + new Set([MonitorErrorResponse.RECREATE]), + ResponseTimeHolder + ); } getResponseTime(hostInfo: HostInfo): number { - const monitor: HostResponseTimeMonitor = HostResponseTimeServiceImpl.monitoringHosts.get( - hostInfo.url, - HostResponseTimeServiceImpl.CACHE_EXPIRATION_NANOS - ); - if (!monitor) { - return Number.MAX_SAFE_INTEGER; - } - return monitor.getResponseTime(); + const holder: ResponseTimeHolder | null = this.servicesContainer.storageService.get(ResponseTimeHolder, hostInfo.url); + return holder === null ? Number.MAX_SAFE_INTEGER : holder.getResponseTime(); } setHosts(hosts: HostInfo[]): void { - const oldHostMap: string[] = hosts.flatMap((host) => host.url); + const oldHostUrls: Set = new Set(this.hosts.map((host) => host.url)); + this.hosts = hosts; + + const servicesContainer = this.servicesContainer; + const properties = this.properties; + const intervalMs = this.intervalMs; + hosts - .filter((hostInfo: HostInfo) => !(hostInfo.url in oldHostMap)) - .forEach((hostInfo: HostInfo) => { - HostResponseTimeServiceImpl.monitoringHosts.computeIfAbsent( - hostInfo.url, - (key) => new HostResponseTimeMonitor(this.pluginService, hostInfo, this.properties, this.intervalMs), - HostResponseTimeServiceImpl.CACHE_EXPIRATION_NANOS - ); + .filter((hostInfo) => !oldHostUrls.has(hostInfo.url)) + .forEach((hostInfo) => { + servicesContainer.monitorService.runIfAbsent(HostResponseTimeMonitor, hostInfo.url, servicesContainer, properties, { + createMonitor: (sc: FullServicesContainer) => new HostResponseTimeMonitor(sc, hostInfo, properties, intervalMs) + }); }); } } diff --git a/common/lib/utils/errors.ts b/common/lib/utils/errors.ts index 247bbba0a..58e37aec1 100644 --- a/common/lib/utils/errors.ts +++ b/common/lib/utils/errors.ts @@ -48,6 +48,8 @@ export class FailoverFailedError extends FailoverError {} export class TransactionResolutionUnknownError extends FailoverError {} +export class ReadWriteSplittingError extends AwsWrapperError {} + export class LoginError extends AwsWrapperError {} export class AwsTimeoutError extends AwsWrapperError {} diff --git a/common/lib/utils/messages.ts b/common/lib/utils/messages.ts index 353c00ded..32d9b9df7 100644 --- a/common/lib/utils/messages.ts +++ b/common/lib/utils/messages.ts @@ -421,6 +421,11 @@ const MESSAGES: Record = { "GlobalDbFailoverPlugin.unableToFindCandidateWithMatchingRole": "Unable to find a candidate host with the expected role (%s) based on the given host selection strategy: %s", "GlobalDbFailoverPlugin.unableToConnect": "Unable to establish a connection during Global DB failover.", + "GdbReadWriteSplittingPlugin.missingHomeRegion": + "Unable to parse home region from endpoint '%s'. Please ensure you have set the 'gdbRwHomeRegion' connection parameter.", + "GdbReadWriteSplittingPlugin.cantConnectWriterOutOfHomeRegion": "Writer connection to '%s' is not allowed since it is out of home region '%s'.", + "GdbReadWriteSplittingPlugin.noAvailableReadersInHomeRegion": "No available reader nodes in home region '%s'.", + "GdbReadWriteSplittingPlugin.parameterValue": "%s=%s", "BatchingEventPublisher.errorDeliveringImmediateEvent": "Error delivering immediate event: %s", "WrapperProperty.invalidValue": "Invalid value '%s' for property '%s'. Allowed values: %s" }; diff --git a/common/lib/wrapper_property.ts b/common/lib/wrapper_property.ts index 2a997577e..a3a5a84f4 100644 --- a/common/lib/wrapper_property.ts +++ b/common/lib/wrapper_property.ts @@ -561,6 +561,20 @@ export class WrapperProperties { ["writer", "none"] ); + static readonly GDB_RW_HOME_REGION = new WrapperProperty("gdbRwHomeRegion", "Specifies the home region for read/write splitting.", null); + + static readonly GDB_RW_RESTRICT_WRITER_TO_HOME_REGION = new WrapperProperty( + "gdbRwRestrictWriterToHomeRegion", + "Prevents connections to a writer node outside of the defined home region.", + true + ); + + static readonly GDB_RW_RESTRICT_READER_TO_HOME_REGION = new WrapperProperty( + "gdbRwRestrictReaderToHomeRegion", + "Prevents connections to a reader node outside of the defined home region.", + true + ); + private static readonly PREFIXES = [ WrapperProperties.MONITORING_PROPERTY_PREFIX, WrapperProperties.TOPOLOGY_MONITORING_PROPERTY_PREFIX, diff --git a/index.ts b/index.ts index eb152e55c..6292539cd 100644 --- a/index.ts +++ b/index.ts @@ -44,7 +44,7 @@ export type { ConnectionProvider } from "./common/lib/connection_provider"; export type { HostSelector } from "./common/lib/host_selector"; export type { DatabaseDialect } from "./common/lib/database_dialect/database_dialect"; export type { PluginService } from "./common/lib/plugin_service"; -export type { HostListProvider, BlockingHostListProvider } from "./common/lib/host_list_provider/host_list_provider"; +export type { HostListProvider } from "./common/lib/host_list_provider/host_list_provider"; export type { ErrorHandler } from "./common/lib/error_handler"; export type { SessionStateService } from "./common/lib/session_state_service"; export type { DriverDialect } from "./common/lib/driver_dialect/driver_dialect"; diff --git a/tests/integration/container/tests/config.ts b/tests/integration/container/tests/config.ts index b50208cfa..221b397dc 100644 --- a/tests/integration/container/tests/config.ts +++ b/tests/integration/container/tests/config.ts @@ -15,6 +15,7 @@ */ import { CustomConsole, LogMessage, LogType } from "@jest/console"; +import { TestEnvironment } from "./utils/test_environment"; function simpleFormatter(type: LogType, message: LogMessage): string { return message @@ -34,3 +35,7 @@ const testInfo = JSON.parse(infoJson); const request = testInfo.request; export const features = request.features; export const instanceCount = request.numOfInstances; + +afterAll(async () => { + await TestEnvironment.shutdownTelemetry(); +}); diff --git a/tests/integration/container/tests/failover/gdb_failover.test.ts b/tests/integration/container/tests/failover/gdb_failover.test.ts index 261d98c0b..2d5348a39 100644 --- a/tests/integration/container/tests/failover/gdb_failover.test.ts +++ b/tests/integration/container/tests/failover/gdb_failover.test.ts @@ -166,6 +166,7 @@ describe("gdb failover", () => { const config = await initDefaultConfig(initialWriterHost, initialWriterPort, true); config["activeHomeFailoverMode"] = "home-reader-or-writer"; config["inactiveHomeFailoverMode"] = "home-reader-or-writer"; + config["wrapperQueryTimeout"] = 2000; client = initClientFunc(config); await client.connect(); diff --git a/tests/integration/container/tests/iam_authentication.test.ts b/tests/integration/container/tests/iam_authentication.test.ts index 5b516c527..a38766630 100644 --- a/tests/integration/container/tests/iam_authentication.test.ts +++ b/tests/integration/container/tests/iam_authentication.test.ts @@ -23,7 +23,6 @@ import { readFileSync } from "fs"; import { logger } from "../../../../common/logutils"; import { TestEnvironmentFeatures } from "./utils/test_environment_features"; import { features } from "./config"; -import { jest } from "@jest/globals"; const itIf = !features.includes(TestEnvironmentFeatures.PERFORMANCE) && @@ -72,9 +71,6 @@ async function validateConnection() { describe("iam authentication", () => { beforeEach(async () => { logger.info(`Test started: ${expect.getState().currentTestName}`); - jest.useFakeTimers({ - doNotFake: ["nextTick"] - }); client = null; env = await TestEnvironment.getCurrent(); driver = DriverHelper.getDriverForDatabaseEngine(env.engine); @@ -93,11 +89,6 @@ describe("iam authentication", () => { logger.info(`Test finished: ${expect.getState().currentTestName}`); }, 1320000); - afterAll(async () => { - jest.runOnlyPendingTimers(); - jest.useRealTimers(); - }); - itIf( "iam wrong database username", async () => { diff --git a/tests/integration/container/tests/parameterized_queries.test.ts b/tests/integration/container/tests/parameterized_queries.test.ts index a8ceac145..0be0f2a81 100644 --- a/tests/integration/container/tests/parameterized_queries.test.ts +++ b/tests/integration/container/tests/parameterized_queries.test.ts @@ -141,7 +141,7 @@ beforeEach(async () => { logger.info(`Test started: ${expect.getState().currentTestName}`); await TestEnvironment.verifyClusterStatus(); client = null; -}, 60000); +}, 1320000); afterEach(async () => { if (client !== null) { diff --git a/tests/integration/container/tests/pg_pool.test.ts b/tests/integration/container/tests/pg_pool.test.ts index 4b09f224f..2b4744d2d 100644 --- a/tests/integration/container/tests/pg_pool.test.ts +++ b/tests/integration/container/tests/pg_pool.test.ts @@ -156,6 +156,12 @@ describe("pg pool integration tests", () => { itIfPGTwoInstances("failover writer during multi-statement transaction", async () => { client = await factory(); + try { + await client.query("DROP TABLE IF EXISTS test_table"); + } catch { + // Ignore + } + const initialWriterId = await auroraTestUtility.queryInstanceId(client); expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toStrictEqual(true); diff --git a/tests/integration/container/tests/utils/test_environment.ts b/tests/integration/container/tests/utils/test_environment.ts index 459229fa0..a03caa1ca 100644 --- a/tests/integration/container/tests/utils/test_environment.ts +++ b/tests/integration/container/tests/utils/test_environment.ts @@ -44,6 +44,7 @@ import { readFileSync } from "fs"; export class TestEnvironment { private static env?: TestEnvironment; + private static sdk?: NodeSDK; private readonly _info: TestEnvironmentInfo; private proxies?: { [s: string]: ProxyInfo }; @@ -281,6 +282,7 @@ export class TestEnvironment { // this enables the API to record telemetry sdk.start(); + TestEnvironment.sdk = sdk; // gracefully shut down the SDK on process exit process.on("SIGTERM", () => { sdk @@ -293,6 +295,17 @@ export class TestEnvironment { return env; } + static async shutdownTelemetry(): Promise { + if (TestEnvironment.sdk) { + try { + await TestEnvironment.sdk.shutdown(); + } catch (error) { + // ignore + } + TestEnvironment.sdk = undefined; + } + } + static async initProxies(environment: TestEnvironment) { if (environment.features.includes(TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED)) { environment.proxies = {}; diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfig.java b/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfig.java index f4c8fb2e5..bfe94c656 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfig.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfig.java @@ -13,6 +13,7 @@ import software.amazon.awssdk.services.rds.model.DBCluster; import software.amazon.awssdk.services.rds.model.DBInstance; +import java.io.File; import java.io.IOException; import java.net.URISyntaxException; import java.net.UnknownHostException; @@ -1073,6 +1074,9 @@ private static void createTestContainer(TestEnvironmentConfig env) { .withEnv("AWS_SESSION_TOKEN", env.awsSessionToken); } + // Ensure the reports directory exists before mounting it into the container. + new File("../../../tests/integration/container/reports").mkdirs(); + env.testContainer.start(); } diff --git a/tests/integration/host/src/test/java/integration/host/util/ContainerHelper.java b/tests/integration/host/src/test/java/integration/host/util/ContainerHelper.java index 88fc25c19..777b9a299 100644 --- a/tests/integration/host/src/test/java/integration/host/util/ContainerHelper.java +++ b/tests/integration/host/src/test/java/integration/host/util/ContainerHelper.java @@ -221,7 +221,6 @@ protected Long execInContainer( .execCreateCmd(containerId) .withAttachStdout(true) .withAttachStderr(true) - .withEnv(Arrays.asList("JEST_HTML_REPORTER_OUTPUT_PATH", "./tests/integration/container/reports/hello.html")) .withCmd(command); if (!StringUtils.isNullOrEmpty(workingDir)) { diff --git a/tests/unit/failover_plugin.test.ts b/tests/unit/failover_plugin.test.ts index 627500902..42e510251 100644 --- a/tests/unit/failover_plugin.test.ts +++ b/tests/unit/failover_plugin.test.ts @@ -44,6 +44,7 @@ import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_tele import { HostChangeOptions } from "../../common/lib/host_change_options"; import { Messages } from "../../common/lib/utils/messages"; import { RdsHostListProvider } from "../../common/lib/host_list_provider/rds_host_list_provider"; +import { FullServicesContainer } from "../../common/lib/utils/full_services_container"; const builder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); @@ -61,6 +62,9 @@ const mockWriterResult: WriterFailoverResult = mock(WriterFailoverResult); const mockClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, new Map(), new MySQL2DriverDialect()); +const mockServicesContainer: FullServicesContainer = mock(); +const mockServicesContainerInstance = instance(mockServicesContainer); + const properties: Map = new Map(); let plugin: FailoverPlugin; @@ -76,10 +80,11 @@ function initializePlugin( readerFailoverHandler?: ClusterAwareReaderFailoverHandler, writerFailoverHandler?: ClusterAwareWriterFailoverHandler ): void { + when(mockServicesContainer.pluginService).thenReturn(mockPluginServiceInstance); plugin = readerFailoverHandler && writerFailoverHandler - ? new FailoverPlugin(mockPluginServiceInstance, properties, new RdsUtils(), readerFailoverHandler, writerFailoverHandler) - : new FailoverPlugin(mockPluginServiceInstance, properties, new RdsUtils()); + ? new FailoverPlugin(mockServicesContainerInstance, properties, new RdsUtils(), readerFailoverHandler, writerFailoverHandler) + : new FailoverPlugin(mockServicesContainerInstance, properties, new RdsUtils()); } describe("reader failover handler", () => { diff --git a/tests/unit/read_write_splitting.test.ts b/tests/unit/read_write_splitting.test.ts index 612435e19..775ca5836 100644 --- a/tests/unit/read_write_splitting.test.ts +++ b/tests/unit/read_write_splitting.test.ts @@ -405,6 +405,7 @@ describe("reader write splitting test", () => { when(mockPluginService.getCurrentHostInfo()).thenReturn(writerHostUnknownRole); when(mockPluginService.acceptsStrategy(anything(), anything())).thenReturn(true); + when(mockPluginService.getHostRole(anything())).thenResolve(HostRole.UNKNOWN); when(mockHostListProviderService.isDynamicHostListProvider()).thenReturn(true); const target = new TestReadWriteSplitting( diff --git a/tests/unit/writer_failover_handler.test.ts b/tests/unit/writer_failover_handler.test.ts index bb07d9d92..6daf39869 100644 --- a/tests/unit/writer_failover_handler.test.ts +++ b/tests/unit/writer_failover_handler.test.ts @@ -28,6 +28,10 @@ import { PgDatabaseDialect } from "../../pg/lib/dialect/pg_database_dialect"; import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; import { DriverDialect } from "../../common/lib/driver_dialect/driver_dialect"; import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; +import { FullServicesContainer } from "../../common/lib/utils/full_services_container"; +import { PluginManager } from "../../common/lib/plugin_manager"; +import { ServiceUtils } from "../../common/lib/utils/service_utils"; +import { HostListProviderService } from "../../common/lib/host_list_provider_service"; const builder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); @@ -45,6 +49,7 @@ const mockClientInstance = instance(mockClient); const mockPluginService = mock(PluginServiceImpl); const mockReaderFailover = mock(ClusterAwareReaderFailoverHandler); const mockDriverDialect: DriverDialect = mock(MySQL2DriverDialect); +const mockPluginManager = mock(PluginManager); const mockTargetClient = { client: 123 }; const mockClientWrapper: ClientWrapper = new MySQLClientWrapper( @@ -62,7 +67,23 @@ const mockClientWrapperB: ClientWrapper = new MySQLClientWrapper( mockDriverDialect ); +const mockServicesContainer = { + pluginService: null as any, + storageService: null as any, + monitorService: null as any, + eventPublisher: null as any, + defaultConnectionProvider: null as any, + telemetryFactory: null as any, + pluginManager: null as any, + hostListProviderService: null as any, + importantEventService: null as any +} as FullServicesContainer; + describe("writer failover handler", () => { + const originalServiceUtils = ServiceUtils.instance; + const mockServiceUtils = mock(ServiceUtils); + const mockHostListProviderService = mock(); + beforeEach(() => { writer.addAlias("writer-host"); newWriterHost.addAlias("new-writer-host"); @@ -70,11 +91,30 @@ describe("writer failover handler", () => { readerB.addAlias("reader-b-host"); when(mockPluginService.getDialect()).thenReturn(new PgDatabaseDialect()); + + // Mock ServiceUtils.createMinimalServiceContainerFrom to return a container + // that uses the same mock plugin service for both TaskA and TaskB. + const mockPluginManagerInstance = instance(mockPluginManager); + when(mockPluginManager.init()).thenResolve(); + when(mockServiceUtils.createMinimalServiceContainerFrom(anything(), anything())).thenReturn({ + pluginService: instance(mockPluginService), + pluginManager: mockPluginManagerInstance, + hostListProviderService: instance(mockHostListProviderService) + } as unknown as FullServicesContainer); + + // Replace the singleton instance with the mock. + Object.defineProperty(ServiceUtils, "instance", { get: () => instance(mockServiceUtils) }); }); afterEach(() => { reset(mockPluginService); reset(mockReaderFailover); + reset(mockPluginManager); + reset(mockServiceUtils); + reset(mockHostListProviderService); + + // Restore the original singleton instance. + Object.defineProperty(ServiceUtils, "instance", { get: () => originalServiceUtils }); }); it("test reconnect to writer - task B reader error", async () => { @@ -86,7 +126,15 @@ describe("writer failover handler", () => { const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); - const target = new ClusterAwareWriterFailoverHandler(mockPluginServiceInstance, mockReaderFailoverInstance, properties, 5000, 2000, 2000); + const target = new ClusterAwareWriterFailoverHandler( + mockPluginServiceInstance, + mockServicesContainer, + mockReaderFailoverInstance, + properties, + 5000, + 2000, + 2000 + ); const result = await target.failover(topology); expect(result.isConnected).toBe(true); @@ -111,7 +159,15 @@ describe("writer failover handler", () => { const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); - const target = new ClusterAwareWriterFailoverHandler(mockPluginServiceInstance, mockReaderFailoverInstance, properties, 60000, 5000, 5000); + const target = new ClusterAwareWriterFailoverHandler( + mockPluginServiceInstance, + mockServicesContainer, + mockReaderFailoverInstance, + properties, + 60000, + 5000, + 5000 + ); const result = await target.failover(topology); expect(result.isConnected).toBe(true); @@ -138,7 +194,15 @@ describe("writer failover handler", () => { const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); - const target = new ClusterAwareWriterFailoverHandler(mockPluginServiceInstance, mockReaderFailoverInstance, properties, 60000, 2000, 2000); + const target = new ClusterAwareWriterFailoverHandler( + mockPluginServiceInstance, + mockServicesContainer, + mockReaderFailoverInstance, + properties, + 60000, + 2000, + 2000 + ); const result: WriterFailoverResult = await target.failover(topology); expect(result.isConnected).toBe(true); @@ -168,6 +232,7 @@ describe("writer failover handler", () => { const target: ClusterAwareWriterFailoverHandler = new ClusterAwareWriterFailoverHandler( mockPluginServiceInstance, + mockServicesContainer, mockReaderFailoverInstance, properties, 60000, @@ -202,7 +267,15 @@ describe("writer failover handler", () => { const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); - const target = new ClusterAwareWriterFailoverHandler(mockPluginServiceInstance, mockReaderFailoverInstance, properties, 60000, 5000, 2000); + const target = new ClusterAwareWriterFailoverHandler( + mockPluginServiceInstance, + mockServicesContainer, + mockReaderFailoverInstance, + properties, + 60000, + 5000, + 2000 + ); const result: WriterFailoverResult = await target.failover(topology); expect(result.isConnected).toBe(true); @@ -238,7 +311,15 @@ describe("writer failover handler", () => { const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); - const target = new ClusterAwareWriterFailoverHandler(mockPluginServiceInstance, mockReaderFailoverInstance, properties, 5000, 2000, 2000); + const target = new ClusterAwareWriterFailoverHandler( + mockPluginServiceInstance, + mockServicesContainer, + mockReaderFailoverInstance, + properties, + 5000, + 2000, + 2000 + ); const startTime = Date.now(); const result = await target.failover(topology); @@ -264,7 +345,15 @@ describe("writer failover handler", () => { const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); - const target = new ClusterAwareWriterFailoverHandler(mockPluginServiceInstance, mockReaderFailoverInstance, properties, 5000, 2000, 2000); + const target = new ClusterAwareWriterFailoverHandler( + mockPluginServiceInstance, + mockServicesContainer, + mockReaderFailoverInstance, + properties, + 5000, + 2000, + 2000 + ); const result = await target.failover(topology); expect(result.isConnected).toBe(false);