class FakeCompute(ComputeInterface):
    """Fake compute environment that answers with fixed, local-only values.

    Nothing in this class queries a scheduler: the node list is just the local
    hostname, the scratch directory is the system temp directory, and the
    worker count, memory and CPU figures are the constants defined below.
    """

    # Canned resource figures reported for the single fake worker.
    _NUM_WORKERS = 1
    _WORKER_MEMORY_GB = 32
    _WORKER_NUM_CPUS = 12

    def get_node_memory_overhead_gb(
        self, driver_memory_gb: int, node_memory_overhead_gb: int
    ) -> int:
        """Return the driver memory plus the configured node memory overhead.

        Args:
            driver_memory_gb: Memory reserved for the Spark driver, in GB.
            node_memory_overhead_gb: Currently unused; the overhead is read
                from ``self._config.runtime.node_memory_overhead_gb`` instead.

        Returns:
            Sum of ``driver_memory_gb`` and the configured overhead.
        """
        # NOTE(review): the ``node_memory_overhead_gb`` argument is ignored in
        # favor of the config value - confirm this is intentional.
        return driver_memory_gb + self._config.runtime.node_memory_overhead_gb

    def get_num_workers(self) -> int:
        """Return the fixed number of fake workers."""
        return self._NUM_WORKERS

    def get_node_names(self) -> list[str]:
        """Return a one-element list holding the local hostname."""
        return [gethostname()]

    def get_scratch_dir(self) -> Path:
        """Return the system temporary directory as the scratch location."""
        return Path(tempfile.gettempdir())

    def get_worker_node_names(self) -> list[str]:
        """Return the worker node names, which are the same as all node names."""
        return self.get_node_names()

    def get_worker_memory_gb(self) -> int:
        """Return the fixed amount of memory reported per worker, in GB."""
        return self._WORKER_MEMORY_GB

    def get_worker_num_cpus(self) -> int:
        """Return the fixed number of CPUs reported per worker."""
        return self._WORKER_NUM_CPUS

    def is_heterogeneous_slurm_job(self) -> bool:
        """Return False; the fake environment never reports a heterogeneous job."""
        return False

    def is_worker_node(self, node_name: str) -> bool:
        """Return True for any ``node_name``; every node counts as a worker."""
        return True

    def run_checks(self) -> None:
        """Do nothing; the fake environment has no checks to perform."""
b/src/sparkctl/models.py index 7e0d90a..59e21c1 100644 --- a/src/sparkctl/models.py +++ b/src/sparkctl/models.py @@ -197,6 +197,7 @@ class ComputeEnvironment(StrEnum): NATIVE = "native" # sparkctl detects workers through Slurm environment variables. SLURM = "slurm" + FAKE = "fake" class PostgresScripts(SparkctlBaseModel): diff --git a/tests/conftest.py b/tests/conftest.py index 162129f..fd303fa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -139,7 +139,7 @@ def extract_tarball(src_file: Path, extract_dir: Path) -> None: @pytest.fixture def setup_local_env(tmp_path): - config = _create_default_config(SPARK_DIR, JAVA_DIR, tmp_path, ComputeEnvironment.NATIVE) + config = _create_default_config(SPARK_DIR, JAVA_DIR, tmp_path, ComputeEnvironment.FAKE) config.binaries.spark_path = SPARK_DIR match sys.platform: case "linux":