95 changes: 67 additions & 28 deletions google/cloud/dataproc_magics/magics.py
@@ -14,6 +14,9 @@
 
 """Dataproc magic implementations."""
 
+import pyarrow as pa
+import pyspark.sql.connect.proto as pb2
+import re
 import shlex
 from IPython.core.magic import (Magics, magics_class, line_magic)
 from google.cloud.dataproc_spark_connect import DataprocSparkSession
@@ -22,6 +25,12 @@
 @magics_class
 class DataprocMagics(Magics):
 
+    PIP_INSTALL_FAILURE_MSG = "Pip install failed with non-zero exit code"
+    DATAPROC_COMMAND_RUNNER = (
+        "org.apache.spark.sql.artifact.DataprocCommandRunner"
+    )
+    PIP_INSTALL_COMMAND = "PipInstallPackages"
+
     def __init__(
         self,
         shell,
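These class-level constants are shared by the implementation and the unit tests: DATAPROC_COMMAND_RUNNER carries the fully qualified name of the server-side command runner, PIP_INSTALL_COMMAND selects its pip subcommand, and PIP_INSTALL_FAILURE_MSG is searched for in pip's captured output to detect failure. Since the failure message is fed to re.search() as a pattern, a defensive variant would escape it first; a minimal sketch of that check (an illustration, not the PR's code):

import re

PIP_INSTALL_FAILURE_MSG = "Pip install failed with non-zero exit code"

def pip_failed(output: str) -> bool:
    # re.escape keeps the match literal even if the message ever gains
    # characters that regex treats specially.
    return re.search(re.escape(PIP_INSTALL_FAILURE_MSG), output) is not None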
Expand All @@ -37,40 +46,70 @@ def dpip(self, line):
"""
try:
args = shlex.split(line)
packages, session = self._check_preconditions(args)

print(f"Installing packages: {packages}")
output = self._run_command(packages, session)

if not args or args[0] != "install":
raise RuntimeError(
"Usage: %dpip install <package1> <package2> ..."
)
failure_match = re.search(self.PIP_INSTALL_FAILURE_MSG, output)
if failure_match:
raise RuntimeError(output)

packages = args[1:] # remove `install`
print(output)
print("Finished installing packages.")
except Exception as e:
raise RuntimeError(f"Failed to install packages: {e}") from e

if not packages:
raise RuntimeError("Error: No packages specified.")
def _check_preconditions(self, args):
if not args or args[0] != "install":
raise RuntimeError("Usage: %dpip install <package1> <package2> ...")

if any(pkg.startswith("-") for pkg in packages):
raise RuntimeError("Error: Flags are not currently supported.")
packages = args[1:] # remove `install`

sessions = [
(key, value)
for key, value in self.shell.user_ns.items()
if isinstance(value, DataprocSparkSession)
]
if not packages:
raise RuntimeError("Error: No packages specified.")

if not sessions:
raise RuntimeError(
"Error: No active Dataproc Spark Session found. Please create one first."
)
if len(sessions) > 1:
raise RuntimeError(
"Error: Found more than one active Dataproc Spark Sessions."
)
if any(pkg.startswith("-") for pkg in packages):
raise RuntimeError("Error: Flags are not currently supported.")

((name, session),) = sessions
print(f"Active session found: {name}")
print(f"Installing packages: {packages}")
session.addArtifacts(*packages, pypi=True)
sessions = [
(key, value)
for key, value in self.shell.user_ns.items()
if isinstance(value, DataprocSparkSession)
]

print("Finished installing packages.")
if not sessions:
raise RuntimeError(
"Error: No active Dataproc Spark Session found. Please create one first."
)
if len(sessions) > 1:
raise RuntimeError(
"Error: Found more than one active Dataproc Spark Sessions."
)

((name, session),) = sessions
print(f"Active session found: {name}")
return packages, session

def _run_command(self, packages, session):
command = pb2.Command()
command.execute_external_command.runner = self.DATAPROC_COMMAND_RUNNER
command.execute_external_command.command = self.PIP_INSTALL_COMMAND

for index, package in enumerate(packages):
command.execute_external_command.options[str(index)] = package

_, properties, _ = session.client.execute_command(command)

try:
binary_data = properties["sql_command_result"].local_relation.data

# decode the Arrow stream and return the output
table = pa.ipc.RecordBatchStreamReader(binary_data).read_all()
return "\n".join(table.column(0).to_pylist())
except (KeyError, AttributeError) as e:
raise RuntimeError(
"Unexpected response structure: missing binary data."
) from e
except Exception as e:
raise RuntimeError(f"Failed to install packages: {e}") from e
raise RuntimeError(f"Error decoding Arrow data: {e}") from e
88 changes: 79 additions & 9 deletions tests/unit/dataproc_magics/test_magics.py
@@ -17,6 +17,8 @@
 from contextlib import redirect_stdout
 from unittest import mock
 
+import pyarrow as pa
+import pyspark.sql.connect.proto as pb2
 from google.cloud.dataproc_spark_connect import DataprocSparkSession
 from google.cloud.dataproc_magics import DataprocMagics
 from IPython.core.interactiveshell import InteractiveShell
@@ -31,6 +33,14 @@ def setUp(self):
         self.shell.config = Config()
         self.magics = DataprocMagics(shell=self.shell)
 
+    def _create_mock_arrow_binary(self, lines: list[str]) -> bytes:
+        schema = pa.schema([pa.field("output", pa.string())])
+        table = pa.Table.from_arrays([lines], schema=schema)
+        sink = pa.BufferOutputStream()
+        with pa.ipc.RecordBatchStreamWriter(sink, table.schema) as writer:
+            writer.write_table(table)
+        return sink.getvalue().to_pybytes()
+
     def test_dpip_with_flags(self):
         with self.assertRaisesRegex(
             RuntimeError, "Error: Flags are not currently supported."
@@ -74,29 +84,89 @@ def test_dpip_no_packages_specified(self):
 
     def test_dpip_install_packages_success(self):
         mock_session = mock.Mock(spec=DataprocSparkSession)
+        mock_session.client = mock.Mock()
+
+        # Create a mock for the properties object
+        properties = mock.Mock()
+
+        # Create a pyarrow table and serialize it
+        binary_data = self._create_mock_arrow_binary(
+            ["Collecting pandas", "Successfully installed pandas"]
+        )
+
+        # Set up the mock response structure
+        properties.sql_command_result.local_relation.data = binary_data
+        mock_session.client.execute_command.return_value = (
+            None,
+            {"sql_command_result": properties.sql_command_result},
+            None,
+        )
+
         self.shell.user_ns["spark"] = mock_session
 
         f = io.StringIO()
         with redirect_stdout(f):
             self.magics.dpip("install pandas numpy")
 
-        mock_session.addArtifacts.assert_called_once_with(
-            "pandas", "numpy", pypi=True
+        # Check that execute_command was called
+        mock_session.client.execute_command.assert_called_once()
+        call_args = mock_session.client.execute_command.call_args[0][0]
+        self.assertIsInstance(call_args, pb2.Command)
+        self.assertEqual(
+            call_args.execute_external_command.command,
+            DataprocMagics.PIP_INSTALL_COMMAND,
         )
-        self.assertEqual(mock_session.addArtifacts.call_count, 1)
-        self.assertIn("Finished installing packages.", f.getvalue())
+        self.assertEqual(
+            call_args.execute_external_command.options["0"], "pandas"
+        )
+        self.assertEqual(
+            call_args.execute_external_command.options["1"], "numpy"
+        )
 
-    def test_dpip_add_artifacts_fails(self):
+        output = f.getvalue()
+        self.assertIn("Installing packages: ['pandas', 'numpy']", output)
+        self.assertIn("Collecting pandas", output)
+        self.assertIn("Successfully installed pandas", output)
+        self.assertIn("Finished installing packages.", output)
+
+    def test_dpip_install_failure(self):
         mock_session = mock.Mock(spec=DataprocSparkSession)
-        mock_session.addArtifacts.side_effect = Exception("Failed")
+        mock_session.client = mock.Mock()
+
+        # Create a mock for the properties object with failure message
+        properties = mock.Mock()
+        binary_data = self._create_mock_arrow_binary(
+            [
+                DataprocMagics.PIP_INSTALL_FAILURE_MSG,
+                "ERROR: some pip error",
+            ]
+        )
+
+        properties.sql_command_result.local_relation.data = binary_data
+        mock_session.client.execute_command.return_value = (
+            None,
+            {"sql_command_result": properties.sql_command_result},
+            None,
+        )
+
         self.shell.user_ns["spark"] = mock_session
 
         with self.assertRaisesRegex(
-            RuntimeError, "Failed to install packages: Failed"
+            RuntimeError, "Failed to install packages: Pip install failed"
         ):
-            self.magics.dpip("install pandas")
+            self.magics.dpip("install non-existent-package")
 
-        mock_session.addArtifacts.assert_called_once_with("pandas", pypi=True)
+    def test_dpip_unexpected_response(self):
+        mock_session = mock.Mock(spec=DataprocSparkSession)
+        mock_session.client = mock.Mock()
+        # Return response without 'sql_command_result'
+        mock_session.client.execute_command.return_value = (None, {}, None)
+        self.shell.user_ns["spark"] = mock_session
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Unexpected response structure: missing binary data"
+        ):
+            self.magics.dpip("install pandas")
 
 
 if __name__ == "__main__":
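For context, once DataprocMagics is registered with the running shell, %dpip can also be driven programmatically; a hedged usage sketch, assuming an IPython kernel and a live DataprocSparkSession bound as `spark` in the user namespace:

from IPython import get_ipython
from google.cloud.dataproc_magics import DataprocMagics

ip = get_ipython()  # returns None outside a running IPython shell
ip.register_magics(DataprocMagics)  # assumed registration step
# Equivalent to typing `%dpip install pandas numpy` in a notebook cell.
ip.run_line_magic("dpip", "install pandas numpy")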