Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.apache.doris.mysql.MysqlHandshakePacket;
import org.apache.doris.mysql.MysqlSslContext;
import org.apache.doris.mysql.ProxyMysqlChannel;
import org.apache.doris.mysql.privilege.Auth;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.stats.StatsErrorEstimator;
Expand Down Expand Up @@ -369,6 +370,46 @@ public void clearLastDBOfCatalog() {
lastDBOfCatalog.clear();
}

public void resetConnection() {
closeTxn();
if (!dbToTempTableNamesMap.isEmpty()) {
deleteTempTable();
}
dbToTempTableNamesMap.clear();
resetSessionVariable();
userVars = new HashMap<>();
preparedQuerys.clear();
preparedStatementContextMap.clear();
preparedStmtId = Integer.MIN_VALUE;
runningQuery = null;
changeDefaultCatalog(InternalCatalog.INTERNAL_CATALOG_NAME);
clearLastDBOfCatalog();
command = MysqlCommand.COM_SLEEP;
returnRows = 0;
}

private void resetSessionVariable() {
sessionVariable = VariableMgr.newSessionVariable();
applyUserSessionVariableDefaults();
if (Config.use_fuzzy_session_variable) {
sessionVariable.initFuzzyModeVariables();
}
}

private void applyUserSessionVariableDefaults() {
String qualifiedUser = getQualifiedUser();
if (Strings.isNullOrEmpty(qualifiedUser)) {
return;
}
Env currentEnv = env == null ? Env.getCurrentEnv() : env;
Auth auth = currentEnv == null ? null : currentEnv.getAuth();
if (auth == null) {
return;
}
setUserQueryTimeout(auth.getQueryTimeout(qualifiedUser));
setUserInsertTimeout(auth.getInsertTimeout(qualifiedUser));
}

public void setNotEvalNondeterministicFunction(boolean notEvalNondeterministicFunction) {
this.notEvalNondeterministicFunction = notEvalNondeterministicFunction;
}
Expand All @@ -385,12 +426,9 @@ public void init() {
state = new QueryState();
returnRows = 0;
isKilled = false;
sessionVariable = VariableMgr.newSessionVariable();
resetSessionVariable();
userVars = new HashMap<>();
command = MysqlCommand.COM_SLEEP;
if (Config.use_fuzzy_session_variable) {
sessionVariable.initFuzzyModeVariables();
}

sessionId = UUID.randomUUID().toString();
if (!FeConstants.runningUnitTest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import org.apache.doris.common.util.SqlUtils;
import org.apache.doris.common.util.Util;
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.metric.MetricRepo;
import org.apache.doris.mysql.MysqlChannel;
import org.apache.doris.mysql.MysqlCommand;
Expand Down Expand Up @@ -155,8 +154,7 @@ protected void handleDebug() {
}

protected void handleResetConnection() {
ctx.changeDefaultCatalog(InternalCatalog.INTERNAL_CATALOG_NAME);
ctx.clearLastDBOfCatalog();
ctx.resetConnection();
ctx.getState().setOk();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.doris.qe;

import org.apache.doris.analysis.ResourceTypeEnum;
import org.apache.doris.analysis.SetVar;
import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.cloud.qe.ComputeGroupException;
Expand Down Expand Up @@ -73,6 +75,50 @@ public void setUp() throws Exception {
Mockito.when(catalogMgr.getCatalog(Mockito.anyString())).thenReturn(internalCatalog);
}

@Test
public void testResetConnectionClearsSessionState() throws Exception {
ConnectContext ctx = new ConnectContext();
ctx.setEnv(env);
ctx.setCurrentUserIdentity(UserIdentity.createAnalyzedUserIdentWithIp("testUser", "%"));
Mockito.when(env.getAuth()).thenReturn(auth);
Mockito.when(auth.getQueryTimeout("testUser")).thenReturn(123);
Mockito.when(auth.getInsertTimeout("testUser")).thenReturn(456);
ctx.setUserQueryTimeout(123);
ctx.setUserInsertTimeout(456);
VariableMgr.setVar(ctx.getSessionVariable(),
new SetVar(SessionVariable.SQL_SELECT_LIMIT, new StringLiteral("0")));
ctx.getSessionVariable().setQueryTimeoutS(1);
ctx.getSessionVariable().setInsertTimeoutS(2);
ctx.setUserVar("user_var", new StringLiteral("value"));
ctx.changeDefaultCatalog("external_catalog");
ctx.currentDb = "test_db";
ctx.currentDbId = 10;
ctx.addLastDBOfCatalog("external_catalog", "test_db");
ctx.addPreparedQuery("1", "select 1");
ctx.setRunningQuery("select 1");
ctx.setCommand(MysqlCommand.COM_QUERY);
ctx.updateReturnRows(10);

Assert.assertEquals(0, ctx.getSessionVariable().getSqlSelectLimit());
Assert.assertEquals(1, ctx.getSessionVariable().getQueryTimeoutS());
Assert.assertEquals(2, ctx.getSessionVariable().getInsertTimeoutS());
Assert.assertFalse(ctx.getUserVars().isEmpty());

ctx.resetConnection();

Assert.assertEquals(-1, ctx.getSessionVariable().getSqlSelectLimit());
Assert.assertEquals(123, ctx.getSessionVariable().getQueryTimeoutS());
Assert.assertEquals(456, ctx.getSessionVariable().getInsertTimeoutS());
Assert.assertTrue(ctx.getUserVars().isEmpty());
Assert.assertEquals(InternalCatalog.INTERNAL_CATALOG_NAME, ctx.getDefaultCatalog());
Assert.assertEquals("", ctx.getDatabase());
Assert.assertNull(ctx.getLastDBOfCatalog("external_catalog"));
Assert.assertNull(ctx.getPreparedQuery("1"));
Assert.assertNull(ctx.getRunningQuery());
Assert.assertEquals(MysqlCommand.COM_SLEEP, ctx.getCommand());
Assert.assertEquals(0, ctx.getReturnRows());
}

@Test
public void testNormal() {
try (MockedStatic<Env> mockedEnv = Mockito.mockStatic(Env.class)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
package org.apache.doris.regression.suite

import com.google.common.collect.Maps
import com.mysql.cj.NativeSession
import com.mysql.cj.jdbc.JdbcConnection
import com.mysql.cj.protocol.a.NativeConstants
import com.mysql.cj.protocol.a.NativePacketPayload
import groovy.transform.CompileStatic
import org.apache.doris.regression.Config
import org.apache.doris.regression.util.OutputUtils
Expand Down Expand Up @@ -437,6 +441,18 @@ class SuiteContext implements Closeable {
connectTo(connInfo.conn.getMetaData().getURL(), connInfo.username, connInfo.password);
}

public void resetConnection() {
ConnectionInfo connInfo = threadLocalConn.get()
if (connInfo == null) {
return
}
NativeSession session = (NativeSession) connInfo.conn.unwrap(JdbcConnection.class).getSession()
// COM_RESET_CONNECTION has no payload besides the command byte.
NativePacketPayload packet = new NativePacketPayload(1)
packet.writeInteger(NativeConstants.IntegerDataType.INT1, 0x1f)
session.sendCommand(packet, false, 0)
}

public void connectTo(String url, String username, String password) {
ConnectionInfo oldConn = threadLocalConn.get()
if (oldConn != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_reset_connection_session_variable", "p0") {
sql "set sql_select_limit = 0"

def limitedResult = sql "select 1 union all select 2"
assertEquals(0, limitedResult.size())

resetConnection()

def resetResult = sql "select 1 union all select 2"
assertEquals(2, resetResult.size())
}
Loading