diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/SessionPool.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/SessionPool.java index 949fba1251c50..b902a4ccd284c 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/SessionPool.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/SessionPool.java @@ -788,7 +788,7 @@ private void occupy(ISession session) { occupied.put(session, session); } - /** close all connections in the pool */ + /** Closes all connections in the pool and unblocks any waiting threads. */ @Override public synchronized void close() { for (ISession session : queue) { @@ -819,6 +819,8 @@ public synchronized void close() { this.closed = true; queue.clear(); occupied.clear(); + // Notify all waiting threads in getSession() so they wake up immediately + this.notifyAll(); } @Override diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/TableSessionPool.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/TableSessionPool.java index 4e08f202a14ad..718f5c34f19cf 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/TableSessionPool.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/TableSessionPool.java @@ -36,6 +36,7 @@ public ITableSession getSession() throws IoTDBConnectionException { return sessionPool.getPooledTableSession(); } + /** Closes the underlying session pool and unblocks any waiting threads. */ @Override public void close() { this.sessionPool.close(); diff --git a/iotdb-client/session/src/test/java/org/apache/iotdb/session/pool/SessionPoolTest.java b/iotdb-client/session/src/test/java/org/apache/iotdb/session/pool/SessionPoolTest.java index ce48172af2b87..a12ebbf70c7b8 100644 --- a/iotdb-client/session/src/test/java/org/apache/iotdb/session/pool/SessionPoolTest.java +++ b/iotdb-client/session/src/test/java/org/apache/iotdb/session/pool/SessionPoolTest.java @@ -74,8 +74,11 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -1623,4 +1626,69 @@ private List FakedFirstFetchTsBlockResult() { return Collections.singletonList(tsBlock); } + + // Regression test for graceful shutdown + @Test(timeout = 5000) + public void testCloseNotifiesWaitingThreads() throws Exception { + SessionPool pool = + new SessionPool.Builder() + .host("localhost") + .port(6667) + .user("root") + .password("root") + .maxSize(1) + .waitToGetSessionTimeoutInMs(10000) + .build(); + + try { + Session mockSession = Mockito.mock(Session.class); + ConcurrentLinkedDeque queue = + (ConcurrentLinkedDeque) Whitebox.getInternalState(pool, "queue"); + queue.push(mockSession); + Whitebox.setInternalState(pool, "size", 1); + + ISession occupiedSession = (ISession) Whitebox.invokeMethod(pool, "getSession"); + assertEquals(mockSession, occupiedSession); + assertEquals(0, queue.size()); + + final Exception[] caughtException = {null}; + CountDownLatch latch = new CountDownLatch(1); + + Thread waiterThread = + new Thread( + () -> { + try { + latch.countDown(); + Whitebox.invokeMethod(pool, "getSession"); + } catch (Exception e) { + caughtException[0] = e; + } + }); + waiterThread.start(); + + assertTrue("Waiter thread should have started", latch.await(10, TimeUnit.SECONDS)); + // Give it a moment to enter the wait(1000) block in getSession() + Thread.sleep(200); + + pool.close(); + + waiterThread.join(500); + assertTrue("Waiter thread should be unblocked quickly", !waiterThread.isAlive()); + + assertNotNull("Waiter thread should have caught an exception", caughtException[0]); + assertTrue( + "Exception should be IoTDBConnectionException", + caughtException[0] instanceof IoTDBConnectionException); + assertTrue( + "Exception message should indicate pool is closed", + caughtException[0].getMessage().contains("closed")); + + } finally { + try { + pool.close(); + } catch (Exception e) { + // ignore + } + } + } }