!!! note This document is a work in progress. It is not yet complete and may contain errors or omissions. Please check back later for updates, and ask in the Discord for help.
This guide will walk you through the steps to set up and manage a TPU cluster using Google Cloud Platform (GCP) and Ray. Currently, this is the only way to use multinode TPU (i.e. more than a v4-8) with Marin.
Before you begin, ensure you have:
- Completed the basic installation
- Access to a Google Cloud project with TPU quota
- Google Cloud SDK installed
- Ray installed
A typical TPU cluster consists of:
- Head Node: A persistent GCP VM that coordinates the cluster
- Worker Nodes: Autoscaling TPU VMs that execute the actual computation
The cluster configuration is defined in a YAML template file. This template allows you to customize various aspects of your cluster:
- Machine types for head and worker nodes
- Number of CPUs and amount of RAM
- Disk size and type
- TPU types and versions
- Minimum and maximum worker counts
- Docker images and initialization commands
- Network configuration
- Environment variables and secrets. We use GCP Secret Manager to store secrets.
For our clusters, each config is in a separate file in the infra directory. These files are automatically generated by the
scripts/ray/cluster.py script, which reads the infra/marin-cluster-template.yaml file.
More information can be found in the infra/README.md.
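To make the template-to-config relationship concrete, here is a minimal sketch of how one config per region could be rendered from a single template. It is not the actual logic of scripts/ray/cluster.py: the trimmed template, the `REGIONS` table, the zone names, the worker limits, and the `render_configs` helper are all assumptions made for illustration. Only the top-level keys (`cluster_name`, `max_workers`, `provider`) follow Ray's cluster-config schema for GCP.

```python
from string import Template

# Hypothetical, heavily trimmed stand-in for infra/marin-cluster-template.yaml.
# The real template has many more fields; these keys follow Ray's
# cluster-config schema for the GCP provider.
TEMPLATE = Template("""\
cluster_name: marin-$region
max_workers: $max_workers
provider:
  type: gcp
  region: $region
  availability_zone: $zone
  project_id: $project
""")

# Assumed per-region settings; zones and worker limits are placeholders.
REGIONS = {
    "us-central2": {"zone": "us-central2-b", "max_workers": 64},
    "europe-west4": {"zone": "europe-west4-b", "max_workers": 32},
}


def render_configs(project: str) -> dict[str, str]:
    """Return a mapping of output filename -> rendered YAML text."""
    rendered = {}
    for region, settings in REGIONS.items():
        text = TEMPLATE.substitute(region=region, project=project, **settings)
        rendered[f"infra/marin-{region}.yaml"] = text
    return rendered


if __name__ == "__main__":
    for path, text in render_configs("my-gcp-project").items():
        print(f"# {path}\n{text}")
```

Keeping the per-region differences in one small table means a change to the shared template reaches every cluster config on the next regeneration.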
We use Docker images to run jobs inside the Ray cluster. We provide a convenient make target, make cluster_docker, that builds the Docker image and uploads it to the GCP Artifact Registry in several regions. This make target performs the following steps:
- Build the Docker image defined in [Dockerfile.cluster](https://github.com/marin-community/tree/main/docker/marin/Dockerfile.cluster)
- Tag the docker image with the different region names
- Push the docker images to respective GCP artifact registries
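The build, tag, and push steps above can be sketched as a list of Docker commands. This is an illustration of the flow, not the contents of the Makefile: the `docker_commands` helper and its `project`, `repo`, `image`, and `tag` parameters are hypothetical. The `<region>-docker.pkg.dev` host follows the Artifact Registry naming used elsewhere in this guide.

```python
import shlex


def docker_commands(regions, project, repo, image, tag="latest",
                    dockerfile="docker/marin/Dockerfile.cluster"):
    """Return the docker invocations for one build plus a tag/push per region."""
    local = f"{image}:{tag}"
    # Step 1: build once from the cluster Dockerfile.
    commands = [["docker", "build", "-f", dockerfile, "-t", local, "."]]
    for region in regions:
        # Artifact Registry image names look like
        # REGION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE:TAG
        remote = f"{region}-docker.pkg.dev/{project}/{repo}/{local}"
        # Step 2: tag the local image with the regional name.
        commands.append(["docker", "tag", local, remote])
        # Step 3: push it to that region's registry.
        commands.append(["docker", "push", remote])
    return commands


if __name__ == "__main__":
    # Placeholder project, repository, and image names: print only, run nothing.
    for cmd in docker_commands(["us-central2", "europe-west4"],
                               "my-gcp-project", "marin", "marin_cluster"):
        print(shlex.join(cmd))
```

Building once and re-tagging per region avoids rebuilding the image for every registry.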
- Install Marin with TPU support:
    uv sync --extra=tpu
- Set up your GCP environment:
    # Configure Docker for GCP artifact registry
    gcloud auth configure-docker <region>-docker.pkg.dev
- Start the cluster (replace $REGION with your desired region, e.g., us-central2):
    ray up -y infra/marin-$REGION.yaml
To run any meaningful job on the cluster, you need to create a GCS bucket and set the MARIN_PREFIX environment variable to point at it, e.g. gs://marin-us-central2. We use this bucket to store all checkpoints, datasets, and other artifacts.
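As an illustration of how a job might derive its output locations from MARIN_PREFIX, here is a small sketch. The `marin_path` helper is hypothetical and not part of Marin; it shows the idea of rooting every artifact path under the bucket and failing early when the variable is missing.

```python
import os


def marin_path(*parts, env=None):
    """Join path components under MARIN_PREFIX (hypothetical helper)."""
    env = os.environ if env is None else env
    prefix = env.get("MARIN_PREFIX")
    if not prefix:
        raise RuntimeError(
            "MARIN_PREFIX is not set; export it first, "
            "e.g. MARIN_PREFIX=gs://marin-us-central2"
        )
    # Normalise slashes so a trailing "/" on the prefix does not double up.
    cleaned = [p.strip("/") for p in parts]
    return "/".join([prefix.rstrip("/"), *cleaned])


if __name__ == "__main__":
    demo_env = {"MARIN_PREFIX": "gs://marin-us-central2"}
    print(marin_path("checkpoints", "hello-world", env=demo_env))
```

Failing early on a missing prefix is cheaper than discovering it after a job has been scheduled onto TPU workers.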
If you haven't set up your SSH keys yet, use
make dev_setup
to set up the keys required for accessing the Marin Ray clusters.
- Connect to the cluster using ray dashboard:
    ray dashboard infra/marin-$REGION.yaml
- In another terminal, submit a job using ray job submit:
    export RAY_ADDRESS=http://localhost:8265
    ray job submit --working-dir . -- python experiments/tutorials/hello_world.py
    Note that RAY_ADDRESS needs to use the http:// URL scheme when connecting via the dashboard.
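The same submission can also be made from Python with Ray's job submission SDK. The sketch below is one possible approach, not something Marin ships: the `dashboard_address` and `submit_hello_world` helpers are hypothetical, and falling back to http://localhost:8265 when RAY_ADDRESS is unset is an assumption. `JobSubmissionClient` and `submit_job` are part of Ray's public `ray.job_submission` module.

```python
import os
from urllib.parse import urlparse


def dashboard_address(env=None) -> str:
    """Return RAY_ADDRESS, insisting on an http:// or https:// scheme."""
    env = os.environ if env is None else env
    # Assumed default: the port forwarded by `ray dashboard`.
    address = env.get("RAY_ADDRESS", "http://localhost:8265")
    if urlparse(address).scheme not in ("http", "https"):
        raise ValueError(
            f"RAY_ADDRESS={address!r} must start with http:// "
            "when connecting through the dashboard port"
        )
    return address


def submit_hello_world() -> str:
    """Submit the tutorial job and return Ray's submission id."""
    # Imported lazily so the address check above works without Ray installed.
    from ray.job_submission import JobSubmissionClient

    client = JobSubmissionClient(dashboard_address())
    return client.submit_job(
        entrypoint="python experiments/tutorials/hello_world.py",
        # Upload the current directory, like `--working-dir .` on the CLI.
        runtime_env={"working_dir": "."},
    )
```

Calling `submit_hello_world()` requires the `ray dashboard` port-forward to be running in another terminal, just as with the CLI.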
- Connect to the cluster using ray dashboard:
    ray dashboard infra/marin-$REGION.yaml
- We provide a thin wrapper on top of ray submit, called ray-run, that makes it easy to specify additional pip requirements and environment variables (apart from those specified in the Dockerfile or cluster config). ray-run ensures that:
    - All the dependencies in pyproject.toml are installed
    - Some of the necessary environment variables are set from the file vars.py
    - The src directories of submodules are added to PYTHONPATH. This is useful for co-developing
You can use --env_vars and --pip_deps to specify additional environment variables and pip dependencies. For example, to run a job with the additional pip dependencies jax==0.4.35 and async-lru, along with the environment variable WANDB_API_KEY, you can use:
python marin/run/ray_run.py --env_vars WANDB_API_KEY ${WANDB_API_KEY} --pip_deps jax==0.4.35,async-lru -- python experiments/tutorials/hello_world.py
Most cluster management for Ray & GCP is handled by scripts/ray/cluster.py.
You can specify either a path to a config file or a region to operate on, e.g. eu-west4.
# Basic cluster operations
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml start-cluster
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml stop-cluster
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml restart-cluster
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml get-status
# List and update configurations
uv run scripts/ray/cluster.py list-configs
uv run scripts/ray/cluster.py update-configs
# SSH operations
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml ssh-head
uv run scripts/ray/cluster.py ssh-tpu 10.130.1.40
# Maintenance operations
uv run scripts/ray/cluster.py clean-tpu-processes --dry-run
uv run scripts/ray/cluster.py clean-tpu-processes --tpu-type v4-8
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml clean-preempted-tpus
# Job operations
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml list-jobs
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml submit-job "python experiments/tutorials/hello_world.py"
# Monitoring and logs
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml monitor-cluster --wandb
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml show-logs --tail 100
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml open-dashboard
The original Ray commands still work but are less convenient:
# Monitor cluster status
ray status infra/marin-$REGION.yaml
# SSH into head node
ray attach infra/marin-$REGION.yaml
# View autoscaler logs
ray exec infra/marin-$REGION.yaml "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
# Stop the cluster
ray down -y infra/marin-$REGION.yaml
- Data Location: Use a bucket in the same region as your cluster to minimize data transfer costs.
- Preemption Handling: If your compute is preemptible, VMs can be shut down at any time. Design your jobs to:
    - Be idempotent (can run multiple times safely)
    - Handle interruptions gracefully
    - Save checkpoints regularly
- Resource Management:
    - Avoid running heavy computation on the head node
    - Break jobs into tasks that complete within 1-10 minutes
    - Use appropriate TPU types for your workload (v4-8 for v4 clusters, v5e-1 for v5e clusters)
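The preemption advice above can be illustrated with a small, self-contained sketch of a resumable job. It uses a local JSON file and a toy "sum the items" workload purely for demonstration. The `save_checkpoint`, `load_checkpoint`, and `run` helpers are hypothetical, the `stop_after` argument only simulates a preemption, and on a real cluster checkpoints would live under your GCS bucket, where the atomic-rename trick shown here does not apply as written.

```python
import json
import os
import tempfile


def save_checkpoint(path: str, state: dict) -> None:
    """Write state atomically so an interrupted write keeps the old file."""
    directory = os.path.dirname(path) or "."
    fd, tmp = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp, path)  # atomic rename on the same local filesystem


def load_checkpoint(path: str) -> dict:
    """Return the saved state, or a fresh state if no checkpoint exists."""
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    return {"next_index": 0, "total": 0}


def run(items, checkpoint_path, every=100, stop_after=None):
    """Sum `items`, resuming from the last checkpoint if one exists.

    `stop_after` simulates a preemption after that many items in this call.
    """
    state = load_checkpoint(checkpoint_path)
    processed = 0
    for i in range(state["next_index"], len(items)):
        if stop_after is not None and processed >= stop_after:
            return state  # simulated preemption: unsaved progress is lost
        state["total"] += items[i]
        state["next_index"] = i + 1
        processed += 1
        if state["next_index"] % every == 0:
            save_checkpoint(checkpoint_path, state)
    save_checkpoint(checkpoint_path, state)
    return state


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt = os.path.join(tmpdir, "state.json")
        data = list(range(1000))
        run(data, ckpt, stop_after=250)    # interrupted part-way through
        print(run(data, ckpt)["total"])    # resumes from the last checkpoint
```

Because progress is recorded only at checkpoints, a rerun repeats at most `every` items of work, and running the finished job again changes nothing, which is the idempotence property the list asks for.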
If you encounter issues:
- Cluster Not Scaling: Check autoscaler logs for errors:
    ray exec infra/marin-$REGION.yaml "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
- Worker Connection Issues: If workers aren't connecting, try:
    - Checking GCP Console for VM/TPU status
    - Reviewing Ray dashboard on the head node
    - Restarting the cluster if needed
- Cluster State Issues: If the cluster state becomes inconsistent:
    - Bring down the cluster completely
    - Delete any orphaned resources in GCP Console
    - Bring the cluster back up
- Try running a basic training job to verify your setup
- Explore the Language Modeling Pipeline to understand how Marin uses TPUs for training
- Read the Ray documentation for more details on cluster management
- Levanter Ray autoscaling setup