-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path__init__.py
More file actions
176 lines (144 loc) · 4.99 KB
/
__init__.py
File metadata and controls
176 lines (144 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""WaveSpeedAI Serverless SDK.

This module provides a serverless worker implementation compatible with
RunPod serverless infrastructure (Waverless).

Example usage:
    import wavespeed.serverless as serverless

    def handler(job):
        return {"output": job["input"]["data"]}

    serverless.start({"handler": handler})
"""
import argparse
import signal
import sys
from typing import Any, Dict
from wavespeed.config import _detect_serverless_env
from .modules.handler import is_generator
from .modules.local import run_local
from .modules.logger import log
from .worker import run_worker
__all__ = ["start", "log"]
def _parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the serverless worker.

    Every worker option is registered under a ``--waverless_*`` name with a
    ``--rp_*`` alias. argparse derives the attribute name from the first long
    option, so the values are always read as ``args.waverless_*`` regardless
    of which spelling was used on the command line.

    Arguments this parser does not define are ignored instead of aborting the
    process: the worker reads the command line of the host script, which may
    carry options of its own. Known options are still validated (for example
    a non-integer port is rejected by argparse).

    Returns:
        Namespace with ``waverless_log_level``, ``waverless_debugger``,
        ``waverless_serve_api``, ``waverless_api_host``,
        ``waverless_api_port`` and ``test_input``.
    """
    parser = argparse.ArgumentParser(description="WaveSpeedAI Serverless Worker")
    parser.add_argument(
        "--waverless_log_level",
        "--rp_log_level",
        type=str,
        default=None,
        help="Log level (NOTSET, TRACE, DEBUG, INFO, WARN, ERROR)",
    )
    parser.add_argument(
        "--waverless_debugger",
        "--rp_debugger",
        action="store_true",
        default=False,
        help="Enable debugger/profiler",
    )
    parser.add_argument(
        "--waverless_serve_api",
        "--rp_serve_api",
        action="store_true",
        default=False,
        help="Serve API server for local development",
    )
    parser.add_argument(
        "--waverless_api_host",
        "--rp_api_host",
        type=str,
        default="0.0.0.0",
        help="API server host (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--waverless_api_port",
        "--rp_api_port",
        type=int,
        default=8000,
        help="API server port (default: 8000)",
    )
    parser.add_argument(
        "--test_input",
        type=str,
        default=None,
        help="JSON test input for local testing",
    )

    # parse_known_args() instead of parse_args(): options that belong to the
    # host script must not terminate the worker with a usage error.
    known_args, _unrecognized = parser.parse_known_args()
    return known_args
def _setup_signal_handlers(worker_config: Dict[str, Any]) -> None:
    """Install SIGINT and SIGTERM handlers that request a shutdown.

    When either signal arrives, the handler logs a message, sets
    ``worker_config["_shutdown"]`` to ``True`` and exits the process with
    status 0 via ``sys.exit``.

    Args:
        worker_config: Worker configuration dictionary; its ``_shutdown``
            entry is written by the installed handler.
    """

    def _on_shutdown_signal(signum: int, _frame: Any) -> None:
        log.info("Received shutdown signal, stopping worker...")
        worker_config["_shutdown"] = True
        sys.exit(0)

    # The same callback serves both the interactive and the termination signal.
    for shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(shutdown_signal, _on_shutdown_signal)
def start(config: Dict[str, Any]) -> None:
    """Start the serverless worker.

    This is the main entry point for the serverless worker. It handles:
    - Validating that a callable handler was supplied (before any side effect)
    - Parsing command-line arguments
    - Detecting serverless environment (RunPod/Waverless)
    - Setting up signal handlers
    - Running in local test mode, API server mode, or worker mode

    Args:
        config: Configuration dictionary containing:
            - handler: The handler function to process jobs
            - return_aggregate_stream (optional): Whether to aggregate
              streaming results
            - concurrency_modifier (optional): Function to adjust concurrency

            The dictionary is updated in place with the private keys
            ``_args``, ``_shutdown``, ``_serverless_env`` and
            ``_is_generator``.

    Raises:
        ValueError: If ``config`` has no ``handler`` entry, or the entry is
            not callable.

    Example:
        def handler(job):
            return {"output": "processed"}

        serverless.start({"handler": handler})
    """
    if "handler" not in config:
        raise ValueError("config must contain a 'handler' function")

    # Fail fast on a non-callable handler: otherwise the mistake only shows up
    # after signal handlers are installed and a server has been started.
    handler = config["handler"]
    if not callable(handler):
        raise ValueError(
            "config['handler'] must be callable, got %s" % type(handler).__name__
        )

    # Parse CLI arguments
    args = _parse_args()

    # Set log level from CLI if one was given
    if args.waverless_log_level:
        log.set_level(args.waverless_log_level)

    # Config is auto-loaded at import time in wavespeed.config
    # Just detect environment for logging and storing in config
    serverless_env = _detect_serverless_env()
    if serverless_env == "runpod":
        log.debug("Detected RunPod environment")
    elif serverless_env == "waverless":
        log.debug("Detected native Waverless environment")

    # Store parsed args and runtime state in config
    config["_args"] = args
    config["_shutdown"] = False
    config["_serverless_env"] = serverless_env

    # Setup signal handlers
    _setup_signal_handlers(config)

    # Record whether the handler is a generator function
    if is_generator(handler):
        log.debug("Handler is a generator function")
        config["_is_generator"] = True
    else:
        config["_is_generator"] = False

    # Determine run mode
    if args.test_input is not None:
        # Local test mode with CLI input
        log.info("Running in local test mode (CLI input)")
        run_local(config)
    elif args.waverless_serve_api:
        # API server mode; imported here so it is only loaded in this mode
        log.info("Running in API server mode")
        from .modules.fastapi import WorkerAPI

        api = WorkerAPI(config)
        api.start(
            host=args.waverless_api_host,
            port=args.waverless_api_port,
        )
    else:
        # Standard worker mode with health check server
        log.info("Starting serverless worker...")
        from .modules.health import HealthServer

        health_server = HealthServer(
            host=args.waverless_api_host,
            port=args.waverless_api_port,
        )
        health_server.start()
        try:
            run_worker(config)
        finally:
            # Stop the health server however the worker loop ends
            health_server.stop()