Skip to content

Commit d59d32a

Browse files
authored
Implement on_kill() trigger hook for Databricks triggers (#65672)
1 parent 381a0d3 commit d59d32a

2 files changed

Lines changed: 28 additions & 0 deletions

File tree

providers/databricks/src/airflow/providers/databricks/triggers/databricks.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
8484
},
8585
)
8686

87+
async def on_kill(self) -> None:
88+
"""Cancel the Databricks run when the trigger is cancelled by a user action."""
89+
if self.run_id:
90+
from asgiref.sync import sync_to_async
91+
92+
self.log.info("Cancelling Databricks run %s.", self.run_id)
93+
await sync_to_async(self.hook.cancel_run)(self.run_id)
94+
8795
async def run(self):
8896
async with self.hook:
8997
while True:
@@ -167,6 +175,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
167175
},
168176
)
169177

178+
async def on_kill(self) -> None:
179+
"""Cancel the Databricks SQL statement when the trigger is cancelled by a user action."""
180+
if self.statement_id:
181+
from asgiref.sync import sync_to_async
182+
183+
self.log.info("Cancelling Databricks SQL statement %s.", self.statement_id)
184+
await sync_to_async(self.hook.cancel_sql_statement)(self.statement_id)
185+
170186
async def run(self):
171187
async with self.hook:
172188
while self.end_time > time.time():

providers/databricks/tests/unit/databricks/triggers/test_databricks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,12 @@ async def test_sleep_between_retries(
259259
mock_sleep.assert_called_once()
260260
mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS)
261261

262+
@pytest.mark.asyncio
263+
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.cancel_run")
264+
async def test_on_kill_cancels_run(self, mock_cancel_run):
265+
await self.trigger.on_kill()
266+
mock_cancel_run.assert_called_once_with(RUN_ID)
267+
262268

263269
class TestDatabricksSQLStatementExecutionTrigger:
264270
@pytest.fixture(autouse=True)
@@ -361,3 +367,9 @@ async def test_sleep_between_retries(self, mock_a_get_sql_statement_state, mock_
361367
)
362368
mock_sleep.assert_called_once()
363369
mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS)
370+
371+
@pytest.mark.asyncio
372+
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.cancel_sql_statement")
373+
async def test_on_kill_cancels_statement(self, mock_cancel_sql_statement):
374+
await self.trigger.on_kill()
375+
mock_cancel_sql_statement.assert_called_once_with(STATEMENT_ID)

0 commit comments

Comments
 (0)