- Project Description
- Project Structure
- Setup Instructions
- API Specification
- Deployment to AWS
- Profiling
- Running Benchmarks
This project implements a producer / consumer model to run tasks in a pool of workers. Each worker is a process that listens for tasks from a dispatcher on a queue, processes them, and can return results to the dispatcher if requested.
Each worker can be associated with a limited set of labels at any given time, and each task is assigned a particular label. The motivation for this is given below. Assigning a new label to a worker is assumed to be expensive, so ideally the dispatcher sends tasks to workers that already have the required label assigned.
This arose from a system where we had several workers in a pool running tasks that consisted of inference with ML models. The workers had access to a common cache of these models, but loading them into memory was very expensive and could cause significant delays in processing tasks. This led to the idea of a system that automatically sends tasks directly to workers that already have the required model loaded, instead of placing them on a common queue, thus reducing the loading overhead.
In this project, the ML models have been abstracted into "labels" that workers can assign and unassign on request. The dispatcher keeps track of which workers have which labels assigned and whether they are available or currently processing a task. It then tries to send each task to a worker that already has the required label assigned and is available.
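To make the routing idea concrete, here is a minimal illustrative sketch in Python. The actual dispatcher is a Go service and its selection logic may differ; all names below are hypothetical.

```python
# Minimal sketch of label-aware worker selection; not the actual dispatcher code.
from dataclasses import dataclass, field


@dataclass
class WorkerState:
    labels: set[str] = field(default_factory=set)  # labels currently assigned
    busy: bool = False                              # currently processing a task


def pick_worker(workers: dict[str, WorkerState], label: str) -> str | None:
    """Prefer an idle worker that already holds the label; otherwise any idle worker."""
    idle = {wid: w for wid, w in workers.items() if not w.busy}
    for wid, w in idle.items():
        if label in w.labels:
            return wid  # label hit: no model loading needed
    # Label miss: any idle worker will do, but it has to assign (load) the label first.
    return next(iter(idle), None)


workers = {
    "worker-1": WorkerState(labels={"model-a"}),
    "worker-2": WorkerState(labels={"model-b"}, busy=True),
}
print(pick_worker(workers, "model-a"))  # -> "worker-1" (hit)
print(pick_worker(workers, "model-b"))  # -> "worker-1" (miss: worker-2 is busy)
```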
The architecture of the project is based on a producer / consumer model with the following components:
- Dispatcher: This acts as the central hub to receive tasks and run them on the workers. It is an HTTP server implemented in Go that receives tasks from a producer and dispatches them to workers based on the labels assigned to them.
- Worker: Workers are in charge of performing ML inference on the given inputs. They are connected to the dispatcher via persistent sockets. The messages between the worker and dispatcher are encoded using protocol buffers. Each worker consists of two components:
- A Python manager process that communicates with the dispatcher and manages model loading
- A Triton Inference Server that handles the actual ML inference workloads
- Log Collector: This is a Go process that collects logs from the workers for later debugging purposes.
```mermaid
graph TD;
    A[Producer] --> B(Dispatcher);
    C[Producer] --> B;
    B -->|Socket| Worker1;
    subgraph Worker1
        D[Worker] <--> D2[Triton Server]
    end
    B -->|Socket| Worker2;
    subgraph Worker2
        E[Worker] <--> E2[Triton Server]
    end
    Worker1 --> G[Log Collector];
    Worker2 --> G;
```
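The worker side is only summarized above. Conceptually, each worker can hold a limited number of models at once, so loading a new one may require unloading another; each such load is a "label miss". The following is an illustrative sketch of that bookkeeping, not the actual worker code: the LRU eviction policy is only an example, and load_model / unload_model are hypothetical stand-ins for the real Triton model-management calls.

```python
# Illustrative sketch of a worker's label bookkeeping; not the actual worker code.
from collections import OrderedDict


def load_model(label: str) -> None:
    # Hypothetical stand-in for loading a model into the Triton server (expensive).
    print(f"loading model for {label}")


def unload_model(label: str) -> None:
    # Hypothetical stand-in for unloading a model from the Triton server.
    print(f"unloading model for {label}")


class LabelCache:
    """Keeps at most `capacity` labels loaded, evicting the least recently used."""

    def __init__(self, capacity: int = 2):
        self.capacity = capacity
        self.loaded: OrderedDict[str, None] = OrderedDict()

    def ensure(self, label: str) -> bool:
        """Ensure `label` is loaded; return True on a label miss (a load was needed)."""
        if label in self.loaded:
            self.loaded.move_to_end(label)  # refresh recency on a hit
            return False
        if len(self.loaded) >= self.capacity:
            evicted, _ = self.loaded.popitem(last=False)  # drop least recently used
            unload_model(evicted)
        load_model(label)
        self.loaded[label] = None
        return True
```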
The project has components implemented in both Go and Python. The source code for these components lives in the packages folder and is organized as follows:
packages/
├── dispatcher # (GoLang) Component that dispatches tasks to workers
├── log-collector # (GoLang) Component to collect logs from workers
├── worker # (Python) Component that processes tasks and manages Triton
├── uploader # (Python) Tools to convert models to ONNX format for Triton
└── benchmark # (GoLang) Components to run benchmarks with simulated traffic
kubernetes/
├── smart-routing-inference/ # Helm chart for deploying the system
│ ├── templates/ # Kubernetes resource templates
│ └── values.yaml # Default configuration values
└── kind-cfg/ # Configuration for local Kubernetes cluster
terraform/ # Infrastructure as Code (IaC) for AWS deployment
├── cluster/ # EKS cluster setup
└── helm/ # Helm deployment to EKS with Terraform
The dispatcher, log collector, and worker have Dockerfiles for containerization. The system can be deployed using Kubernetes with the provided Helm chart. The root directory contains a Taskfile.yml file to run tasks using the Task tool.
In order to run the project, you should have the following installed on your machine:
- Docker: For building container images
- Python 3.10+: For the worker component
- Go 1.20+: For the dispatcher and log collector components
- kubectl: For interacting with Kubernetes clusters
- kind: For running local Kubernetes clusters (install from https://kind.sigs.k8s.io/)
- Helm 3: For deploying the application (install from https://helm.sh/)
- Task: For running project tasks (install from https://taskfile.dev/)
Python packages are managed with the uv tool; pytest is used for testing and ruff for formatting.
The project now uses Kubernetes for local orchestration. To deploy locally:
- Start the services. This will automatically create a kind cluster, build Docker images, load them into the cluster, and deploy via Helm:

  task start

- Access the dispatcher. The dispatcher is exposed on port 30080 (a readiness-check sketch follows this list):

  curl http://localhost:30080/health

- Stop the services. This will uninstall the Helm release and delete the kind cluster:

  task stop
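After task start completes, a quick programmatic readiness check can also be handy. This is only a sketch based on the /health endpoint and NodePort above; the retry count and delay are arbitrary choices.

```python
# Poll the dispatcher's health endpoint until the local deployment responds.
import time

import requests

URL = "http://localhost:30080/health"

for _ in range(30):
    try:
        rsp = requests.get(URL, timeout=2)
        if rsp.ok:
            print("dispatcher is up:", rsp.json())  # expected: {"message": "OK"}
            break
    except requests.ConnectionError:
        pass
    time.sleep(5)
else:
    raise SystemExit("dispatcher did not become healthy in time")
```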
The Taskfile.yml contains several useful tasks:
- task unit-tests: Run unit tests for all components
- task start: Start all services in a local Kubernetes cluster (simple one-command setup)
- task stop: Stop services and clean up Kubernetes resources
- task build-producer: Build the benchmark producer binary
- task prepare-local-k8s VERSION=<version>: Advanced: manually set up a local Kubernetes cluster
- task load-imgs-kind: Advanced: build and load Docker images into the kind cluster
The dispatcher service exposes the following REST API endpoints:
- Endpoint: GET /health
  - Description: Check the health status of the dispatcher service
  - Response:
    { "message": "OK" }
- Endpoint: GET /workers
  - Description: Get the list of all registered workers
  - Response:
    { "workers": ["worker-id-1", "worker-id-2"] }
- Endpoint: POST /run-task
  - Description: Submit a task to be executed by a worker (an example call is sketched after this endpoint list)
  - Request Body:
    { "task_id": "unique-task-identifier", "task_type": "embed_texts", "label": "model-name-or-label", "parameters_json": "{\"texts\": [\"text to process\"]}" }
  - Response: Task execution result (format depends on the task type)
- Endpoint: POST /upload-model
  - Description: Forward a request to convert a SentenceTransformer model to ONNX format for use with Triton Server
  - Request Body:
    { "model_id": "sentence-transformers/all-MiniLM-L6-v2", "max_batch_size": 32, "overwrite": false }
    - model_id: HuggingFace model ID (required)
    - max_batch_size: Maximum batch size for Triton (default: 32)
    - overwrite: Whether to overwrite an existing model (default: false)
  - Response:
    { "status": "success", "model_name": "sentence-transformers--all-MiniLM-L6-v2", "message": "Model sentence-transformers/all-MiniLM-L6-v2 successfully converted to ONNX format" }
  - Note: Requires the uploader service to be configured and running
- Endpoint: GET /labels
  - Description: Get a summary of which labels are loaded on which workers
  - Response:
    { "label-1": ["worker-id-1", "worker-id-2"], "label-2": ["worker-id-3"] }
- Endpoint: /worker-listen
  - Description: WebSocket endpoint for workers to connect and receive tasks (a connection sketch follows this list)
  - Headers Required: the Worker-Id header must be set to a unique worker identifier
  - Protocol: WebSocket with Protocol Buffer encoding
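As a usage illustration of the REST endpoints above, a minimal Python client might look like the following. This is only a sketch: it assumes the dispatcher is reachable on localhost:30080 (as in the local kind setup), and the label value assumes it matches the converted model name, which may differ in practice.

```python
# Illustrative client calls against the dispatcher's REST API (placeholder values).
import requests

BASE = "http://localhost:30080"  # assumed local NodePort from the kind deployment

# Convert a SentenceTransformer model to ONNX so Triton can serve it
# (requires the uploader service to be running).
upload = requests.post(f"{BASE}/upload-model", json={
    "model_id": "sentence-transformers/all-MiniLM-L6-v2",
    "max_batch_size": 32,
    "overwrite": False,
})
print(upload.json())

# Submit an embedding task; the label routes it to a worker that has the model loaded.
# The label value below is an assumption: it mirrors the model_name returned above.
task = requests.post(f"{BASE}/run-task", json={
    "task_id": "task-0001",
    "task_type": "embed_texts",
    "label": "sentence-transformers--all-MiniLM-L6-v2",
    "parameters_json": '{"texts": ["text to process"]}',
})
print(task.status_code, task.content)

# Inspect which labels are currently loaded on which workers.
print(requests.get(f"{BASE}/labels").json())
```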
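The worker connection itself is handled by the Python worker component, and the message format depends on the project's Protocol Buffer definitions, which are not reproduced here. Purely as a sketch of the handshake, assuming the third-party websocket-client package and the local port above:

```python
# Sketch of a worker connecting to the dispatcher's WebSocket endpoint.
# Assumes the websocket-client package; decoding is omitted because it requires
# the project's generated Protocol Buffer classes.
import websocket

ws = websocket.create_connection(
    "ws://localhost:30080/worker-listen",
    header=["Worker-Id: worker-example-1"],  # required unique worker identifier
)
frame = ws.recv()  # task messages arrive as Protocol Buffer encoded frames
print(f"received {len(frame)} bytes")
ws.close()
```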
The project can be deployed to an AWS Kubernetes cluster on EKS using Terraform and the Helm chart. The setup for this is in the terraform directory, along with documentation on how to carry out the deployment.
You can use Go's pprof package to generate CPU profiles of the dispatcher service. First, build the dispatcher binary with task build-dispatcher, then start it with ./bin/dispatcher --cpu-prof=path/to/output/file.prof to initialize the server. You can then run a few instances of packages/worker/dummy_client in separate terminals to add some workers, and use a script like the following to send requests to the dispatcher:
```python
# Script to run profiling requests
import random

import requests

labels = [f"label-{i + 1}" for i in range(5)]
task_types = ["sample_task"]
total_requests = 8000

for i in range(total_requests):
    rsp = requests.post(
        "http://localhost:8080/run-task",
        json={
            "task_type": random.choice(task_types),
            "task_id": f"task-{i + 1}",
            "label": random.choice(labels),
            "parameters_json": """{"data": "sample"}""",
        },
    )
    if i % 250 == 0:
        wrsp = requests.get("http://localhost:8080/workers")
        print("WORKERS", wrsp.content)
    if rsp.status_code // 100 != 2:
        print("FAIL", rsp, rsp.content)
```

Once this is done, you can stop the dispatcher service and the dummy workers. Finally, you can access the profiling report with go tool pprof path/to/output/file.prof.
In order to run benchmarks on the service, you can use the task run-benchmark command. This will build and spin up the services, then run a number of requests from a given number of concurrent producers. Finally, it will parse the log files produced to show some summary statistics, and shut down the services.
Important: Before running benchmarks, ensure that the models to be used are ready by calling the /upload-model endpoint on the dispatcher. This prepares the required models for the benchmark tests.
You can run the benchmarks with the following command (showing some of the possible options):
```sh
# Options:
# 'NUM_WORKERS': Number of workers to run in the pool: 2
# 'RANDOM_DISPATCH': Set to true to send all tasks to the common queue
# 'requests': Number of requests to send: 10
# 'producers': Number of clients that will send requests concurrently: 2
# 'wait-avg': Avg wait time each client waits between requests (seconds): 1.0
# 'wait-stddev': Std deviation of the wait time between requests: 0.5
task run-benchmark NUM_WORKERS=2 RANDOM_DISPATCH=false -- \
  --requests 10 \
  --producers 2 \
  --wait-avg 1.0 \
  --wait-stddev 0.5
```

Here are some benchmark results from running 2 workers locally, each with capacity for 2 models. There were 5 possible labels to ensure that some loading / unloading of models is needed.
Call:

```sh
task run-benchmark NUM_WORKERS=2 RANDOM_DISPATCH=true -- \
  --requests 2400 \
  --producers 2 \
  --wait-avg 2.5 \
  --wait-stddev 1.0
```

Results:
Worker Log Summary:
Total Label Misses: 1448
Tasks Completed: 2400
Average Task Runtime: 0.3231 seconds
Producer Log Summary:
Total Requests: 2400
Total Failed Requests: 0
Total Synchronous Requests: 2400
Total Dispatched Jobs: 2400
Call:

```sh
task run-benchmark NUM_WORKERS=2 RANDOM_DISPATCH=false -- \
  --requests 2400 \
  --producers 2 \
  --wait-avg 2.5 \
  --wait-stddev 1.0
```

Results:
Worker Log Summary:
Total Label Misses: 563
Tasks Completed: 2400
Average Task Runtime: 0.1791 seconds
Producer Log Summary:
Total Requests: 2400
Total Failed Requests: 0
Total Synchronous Requests: 2400
Total Dispatched Jobs: 2400
Here we can see that the number of label misses, that is, the number of times a model had to be loaded into memory, dropped by more than half (563 vs. 1448) with the routing scheme enabled.
Next are some results from a benchmark with 4 different models, that is, the full combined capacity of the two workers. The time between requests was also lowered substantially.
Call:

```sh
task run-benchmark NUM_WORKERS=2 RANDOM_DISPATCH=true -- \
  --requests 2400 \
  --producers 2 \
  --wait-avg 0.75 \
  --wait-stddev 1.0
```

Results:
Worker Log Summary:
Total Label Misses: 1193
Tasks Completed: 2400
Average Task Runtime: 0.2483 seconds
Producer Log Summary:
Total Requests: 2400
Total Failed Requests: 0
Total Synchronous Requests: 2400
Total Dispatched Jobs: 2400
Call:

```sh
task run-benchmark NUM_WORKERS=2 RANDOM_DISPATCH=false -- \
  --requests 2400 \
  --producers 2 \
  --wait-avg 0.75 \
  --wait-stddev 1.0
```

Results:
Worker Log Summary:
Total Label Misses: 138
Tasks Completed: 2400
Average Task Runtime: 0.0639 seconds
Producer Log Summary:
Total Requests: 2400
Total Failed Requests: 0
Total Synchronous Requests: 2400
Total Dispatched Jobs: 2400
Here the difference is even more pronounced: label misses drop by nearly a factor of nine (138 vs. 1193), and the average task runtime is almost four times lower (0.0639 s vs. 0.2483 s).
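For reference, these ratios follow directly from the reported figures:

```python
# Recompute the headline ratios from the reported benchmark figures above.
random_misses, routed_misses = 1193, 138
random_runtime, routed_runtime = 0.2483, 0.0639

print(f"label-miss reduction: {random_misses / routed_misses:.1f}x")    # ~8.6x
print(f"runtime speedup:      {random_runtime / routed_runtime:.1f}x")  # ~3.9x
print(f"miss rate (random):   {random_misses / 2400:.0%}")              # ~50%
print(f"miss rate (routed):   {routed_misses / 2400:.0%}")              # ~6%
```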