diff --git a/google/cloud/dataproc_magics/magics.py b/google/cloud/dataproc_magics/magics.py index 278cc81..8614e19 100644 --- a/google/cloud/dataproc_magics/magics.py +++ b/google/cloud/dataproc_magics/magics.py @@ -14,6 +14,9 @@ """Dataproc magic implementations.""" +import pyarrow as pa +import pyspark.sql.connect.proto as pb2 +import re import shlex from IPython.core.magic import (Magics, magics_class, line_magic) from google.cloud.dataproc_spark_connect import DataprocSparkSession @@ -22,6 +25,12 @@ @magics_class class DataprocMagics(Magics): + PIP_INSTALL_FAILURE_MSG = "Pip install failed with non-zero exit code" + DATAPROC_COMMAND_RUNNER = ( + "org.apache.spark.sql.artifact.DataprocCommandRunner" + ) + PIP_INSTALL_COMMAND = "PipInstallPackages" + def __init__( self, shell, @@ -37,40 +46,70 @@ def dpip(self, line): """ try: args = shlex.split(line) + packages, session = self._check_preconditions(args) + + print(f"Installing packages: {packages}") + output = self._run_command(packages, session) - if not args or args[0] != "install": - raise RuntimeError( - "Usage: %dpip install ..." - ) + failure_match = re.search(self.PIP_INSTALL_FAILURE_MSG, output) + if failure_match: + raise RuntimeError(output) - packages = args[1:] # remove `install` + print(output) + print("Finished installing packages.") + except Exception as e: + raise RuntimeError(f"Failed to install packages: {e}") from e - if not packages: - raise RuntimeError("Error: No packages specified.") + def _check_preconditions(self, args): + if not args or args[0] != "install": + raise RuntimeError("Usage: %dpip install ...") - if any(pkg.startswith("-") for pkg in packages): - raise RuntimeError("Error: Flags are not currently supported.") + packages = args[1:] # remove `install` - sessions = [ - (key, value) - for key, value in self.shell.user_ns.items() - if isinstance(value, DataprocSparkSession) - ] + if not packages: + raise RuntimeError("Error: No packages specified.") - if not sessions: - raise RuntimeError( - "Error: No active Dataproc Spark Session found. Please create one first." - ) - if len(sessions) > 1: - raise RuntimeError( - "Error: Found more than one active Dataproc Spark Sessions." - ) + if any(pkg.startswith("-") for pkg in packages): + raise RuntimeError("Error: Flags are not currently supported.") - ((name, session),) = sessions - print(f"Active session found: {name}") - print(f"Installing packages: {packages}") - session.addArtifacts(*packages, pypi=True) + sessions = [ + (key, value) + for key, value in self.shell.user_ns.items() + if isinstance(value, DataprocSparkSession) + ] - print("Finished installing packages.") + if not sessions: + raise RuntimeError( + "Error: No active Dataproc Spark Session found. Please create one first." + ) + if len(sessions) > 1: + raise RuntimeError( + "Error: Found more than one active Dataproc Spark Sessions." + ) + + ((name, session),) = sessions + print(f"Active session found: {name}") + return packages, session + + def _run_command(self, packages, session): + command = pb2.Command() + command.execute_external_command.runner = self.DATAPROC_COMMAND_RUNNER + command.execute_external_command.command = self.PIP_INSTALL_COMMAND + + for index, package in enumerate(packages): + command.execute_external_command.options[str(index)] = package + + _, properties, _ = session.client.execute_command(command) + + try: + binary_data = properties["sql_command_result"].local_relation.data + + # decode the Arrow stream and return the output + table = pa.ipc.RecordBatchStreamReader(binary_data).read_all() + return "\n".join(table.column(0).to_pylist()) + except (KeyError, AttributeError) as e: + raise RuntimeError( + "Unexpected response structure: missing binary data." + ) from e except Exception as e: - raise RuntimeError(f"Failed to install packages: {e}") from e + raise RuntimeError(f"Error decoding Arrow data: {e}") from e diff --git a/tests/unit/dataproc_magics/test_magics.py b/tests/unit/dataproc_magics/test_magics.py index 83d0b3e..86e750d 100644 --- a/tests/unit/dataproc_magics/test_magics.py +++ b/tests/unit/dataproc_magics/test_magics.py @@ -17,6 +17,8 @@ from contextlib import redirect_stdout from unittest import mock +import pyarrow as pa +import pyspark.sql.connect.proto as pb2 from google.cloud.dataproc_spark_connect import DataprocSparkSession from google.cloud.dataproc_magics import DataprocMagics from IPython.core.interactiveshell import InteractiveShell @@ -31,6 +33,14 @@ def setUp(self): self.shell.config = Config() self.magics = DataprocMagics(shell=self.shell) + def _create_mock_arrow_binary(self, lines: list[str]) -> bytes: + schema = pa.schema([pa.field("output", pa.string())]) + table = pa.Table.from_arrays([lines], schema=schema) + sink = pa.BufferOutputStream() + with pa.ipc.RecordBatchStreamWriter(sink, table.schema) as writer: + writer.write_table(table) + return sink.getvalue() + def test_dpip_with_flags(self): with self.assertRaisesRegex( RuntimeError, "Error: Flags are not currently supported." @@ -74,29 +84,89 @@ def test_dpip_no_packages_specified(self): def test_dpip_install_packages_success(self): mock_session = mock.Mock(spec=DataprocSparkSession) + mock_session.client = mock.Mock() + + # Create a mock for the properties object + properties = mock.Mock() + + # Create a pyarrow table and serialize it + binary_data = self._create_mock_arrow_binary( + ["Collecting pandas", "Successfully installed pandas"] + ) + + # Set up the mock response structure + properties.sql_command_result.local_relation.data = binary_data + mock_session.client.execute_command.return_value = ( + None, + {"sql_command_result": properties.sql_command_result}, + None, + ) + self.shell.user_ns["spark"] = mock_session f = io.StringIO() with redirect_stdout(f): self.magics.dpip("install pandas numpy") - mock_session.addArtifacts.assert_called_once_with( - "pandas", "numpy", pypi=True + # Check that execute_command was called + mock_session.client.execute_command.assert_called_once() + call_args = mock_session.client.execute_command.call_args[0][0] + self.assertIsInstance(call_args, pb2.Command) + self.assertEqual( + call_args.execute_external_command.command, + DataprocMagics.PIP_INSTALL_COMMAND, + ) + self.assertEqual( + call_args.execute_external_command.options["0"], "pandas" + ) + self.assertEqual( + call_args.execute_external_command.options["1"], "numpy" ) - self.assertEqual(mock_session.addArtifacts.call_count, 1) - self.assertIn("Finished installing packages.", f.getvalue()) - def test_dpip_add_artifacts_fails(self): + output = f.getvalue() + self.assertIn("Installing packages: ['pandas', 'numpy']", output) + self.assertIn("Collecting pandas", output) + self.assertIn("Successfully installed pandas", output) + self.assertIn("Finished installing packages.", output) + + def test_dpip_install_failure(self): mock_session = mock.Mock(spec=DataprocSparkSession) - mock_session.addArtifacts.side_effect = Exception("Failed") + mock_session.client = mock.Mock() + + # Create a mock for the properties object with failure message + properties = mock.Mock() + binary_data = self._create_mock_arrow_binary( + [ + DataprocMagics.PIP_INSTALL_FAILURE_MSG, + "ERROR: some pip error", + ] + ) + + properties.sql_command_result.local_relation.data = binary_data + mock_session.client.execute_command.return_value = ( + None, + {"sql_command_result": properties.sql_command_result}, + None, + ) + self.shell.user_ns["spark"] = mock_session with self.assertRaisesRegex( - RuntimeError, "Failed to install packages: Failed" + RuntimeError, "Failed to install packages: Pip install failed" ): - self.magics.dpip("install pandas") + self.magics.dpip("install non-existent-package") + + def test_dpip_unexpected_response(self): + mock_session = mock.Mock(spec=DataprocSparkSession) + mock_session.client = mock.Mock() + # Return response without 'sql_command_result' + mock_session.client.execute_command.return_value = (None, {}, None) + self.shell.user_ns["spark"] = mock_session - mock_session.addArtifacts.assert_called_once_with("pandas", pypi=True) + with self.assertRaisesRegex( + RuntimeError, "Unexpected response structure: missing binary data" + ): + self.magics.dpip("install pandas") if __name__ == "__main__":