Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/sparkctl/compute_interface_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sparkctl.compute_interface import ComputeInterface
from sparkctl.fake_compute import FakeCompute
from sparkctl.models import ComputeEnvironment, SparkConfig
from sparkctl.native_compute import NativeCompute
from sparkctl.slurm_compute import SlurmCompute
Expand All @@ -11,6 +12,8 @@ def make_compute_interface(config: SparkConfig) -> ComputeInterface:
return SlurmCompute(config)
case ComputeEnvironment.NATIVE:
return NativeCompute(config)
case ComputeEnvironment.FAKE:
return FakeCompute(config)
case _:
msg = f"{config.environment=} is not supported"
raise NotImplementedError(msg)
41 changes: 41 additions & 0 deletions src/sparkctl/fake_compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import tempfile
from pathlib import Path
from socket import gethostname

from sparkctl.compute_interface import ComputeInterface


class FakeCompute(ComputeInterface):
"""Provides interface to Slurm."""

def get_node_memory_overhead_gb(
self, driver_memory_gb: int, node_memory_overhead_gb: int
) -> int:
return driver_memory_gb + self._config.runtime.node_memory_overhead_gb

def get_num_workers(self) -> int:
return 1

def get_node_names(self) -> list[str]:
return [gethostname()]

def get_scratch_dir(self) -> Path:
return Path(tempfile.gettempdir())

def get_worker_node_names(self) -> list[str]:
return self.get_node_names()

def get_worker_memory_gb(self) -> int:
return 32

def get_worker_num_cpus(self) -> int:
return 12

def is_heterogeneous_slurm_job(self) -> bool:
return False

def is_worker_node(self, node_name: str) -> bool:
return True

def run_checks(self) -> None:
pass
1 change: 1 addition & 0 deletions src/sparkctl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class ComputeEnvironment(StrEnum):
NATIVE = "native"
# sparkctl detects workers through Slurm environment variables.
SLURM = "slurm"
FAKE = "fake"


class PostgresScripts(SparkctlBaseModel):
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def extract_tarball(src_file: Path, extract_dir: Path) -> None:

@pytest.fixture
def setup_local_env(tmp_path):
config = _create_default_config(SPARK_DIR, JAVA_DIR, tmp_path, ComputeEnvironment.NATIVE)
config = _create_default_config(SPARK_DIR, JAVA_DIR, tmp_path, ComputeEnvironment.FAKE)
config.binaries.spark_path = SPARK_DIR
match sys.platform:
case "linux":
Expand Down
Loading