Setting up a TPU Cluster

!!! note This document is a work in progress. It is not yet complete and may contain errors or omissions. Please check back later for updates, and ask in the Discord for help.

This guide walks you through setting up and managing a TPU cluster using Google Cloud Platform (GCP) and Ray. Currently, this is the only way to use multi-node TPUs (i.e., anything larger than a v4-8) with Marin.

Prerequisites

Before you begin, ensure you have:

TPU Cluster Architecture

A typical TPU cluster consists of:

  • Head Node: A persistent GCP VM that coordinates the cluster
  • Worker Nodes: Autoscaling TPU VMs that execute the actual computation

Cluster Configuration

The cluster configuration is defined in a YAML template file. This template allows you to customize various aspects of your cluster:

  • Machine types for head and worker nodes
  • Number of CPUs and amount of RAM
  • Disk size and type
  • TPU types and versions
  • Minimum and maximum worker counts
  • Docker images and initialization commands
  • Network configuration
  • Environment variables and secrets (we store secrets in GCP Secret Manager)

For our clusters, each config is in a separate file in the infra directory. These files are automatically generated by the scripts/ray/cluster.py script, which reads the infra/marin-cluster-template.yaml file. More information can be found in the infra/README.md.
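The generation step described above can be pictured as filling region-specific values into a shared template. The following is only a minimal sketch of that idea, not the actual logic of scripts/ray/cluster.py: the template contents, the `render_config` and `config_path` helpers, and the `max_workers` value are all hypothetical.

```python
from string import Template

# Hypothetical, heavily simplified stand-in for infra/marin-cluster-template.yaml.
# The real template contains many more fields (machine types, TPU types, Docker image, ...).
TEMPLATE = Template(
    "cluster_name: marin-$region\n"
    "max_workers: $max_workers\n"
    "provider:\n"
    "  type: gcp\n"
    "  region: $region\n"
)


def render_config(region: str, max_workers: int) -> str:
    """Fill the region-specific values into the shared template."""
    return TEMPLATE.substitute(region=region, max_workers=max_workers)


def config_path(region: str) -> str:
    """Per-region config file name, following the infra/marin-$REGION.yaml pattern."""
    return f"infra/marin-{region}.yaml"


if __name__ == "__main__":
    for region in ["us-central2", "eu-west4"]:
        print(f"# {config_path(region)}")
        print(render_config(region, max_workers=4))
```

The real template is considerably larger; infra/README.md remains the reference for the fields it supports.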

Building and Uploading the Docker Image

We use Docker images to run jobs inside the Ray cluster. We provide a convenient make target, make cluster_docker, that builds the Docker image and uploads it to the GCP Artifact Registry in each of the supported regions. This make target performs the following steps:

  1. Build the Docker image from [Dockerfile.cluster](https://github.com/marin-community/tree/main/docker/marin/Dockerfile.cluster)
  2. Tag the docker image with the different region names
  3. Push the docker images to respective GCP artifact registries
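The three steps above correspond roughly to the Docker commands generated by the sketch below. This is an illustration under assumptions, not the contents of the make target: the image name, tag, project ID, and repository name are hypothetical placeholders, and the Dockerfile path is taken from the link above.

```python
import shlex


def docker_commands(image, tag, project, repository, regions,
                    dockerfile="docker/marin/Dockerfile.cluster"):
    """Return the build, tag, and push commands as argument lists (nothing is executed)."""
    local = f"{image}:{tag}"
    # Artifact Registry image names have the form
    # <region>-docker.pkg.dev/<project>/<repository>/<image>:<tag>
    remotes = [
        f"{region}-docker.pkg.dev/{project}/{repository}/{image}:{tag}"
        for region in regions
    ]
    commands = [["docker", "build", "-f", dockerfile, "-t", local, "."]]  # step 1
    commands += [["docker", "tag", local, remote] for remote in remotes]  # step 2
    commands += [["docker", "push", remote] for remote in remotes]        # step 3
    return commands


if __name__ == "__main__":
    # All argument values here are hypothetical.
    for cmd in docker_commands("marin_cluster", "latest", "my-project", "marin",
                               ["us-central2", "eu-west4"]):
        print(shlex.join(cmd))
```

Printing the commands rather than running them makes it easy to check the per-region image names before pushing anything.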

Setting Up the Cluster

  1. Install Marin with TPU support:

    uv sync --extra=tpu
  2. Set up your GCP environment:

    # Configure Docker for GCP artifact registry
    gcloud auth configure-docker <region>-docker.pkg.dev
  3. Start the cluster (replace $REGION with your desired region, e.g., us-central2):

    ray up -y infra/marin-$REGION.yaml

Executing a Job on the Cluster

To run any meaningful job on the cluster, you need to create a GCS bucket and set the MARIN_PREFIX environment variable to its URL, e.g. gs://marin-us-central2. We use this bucket to store checkpoints, datasets, and other artifacts.
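As a rough illustration of how a job might derive artifact locations from MARIN_PREFIX, consider the sketch below. The `marin_path` helper and the `checkpoints/run-1` layout are hypothetical and are not part of Marin's API; the only thing taken from this guide is that the prefix is a gs:// bucket URL.

```python
import os
from typing import Optional


def marin_path(*parts: str, prefix: Optional[str] = None) -> str:
    """Join path components onto MARIN_PREFIX (hypothetical helper, not Marin's API)."""
    if prefix is None:
        prefix = os.environ.get("MARIN_PREFIX", "")
    # Reject a missing prefix or a bare bucket name without the gs:// scheme.
    if not prefix.startswith("gs://") or len(prefix) <= len("gs://"):
        raise ValueError(f"MARIN_PREFIX must look like gs://<bucket>, got {prefix!r}")
    return "/".join([prefix.rstrip("/"), *(p.strip("/") for p in parts)])


if __name__ == "__main__":
    print(marin_path("checkpoints", "run-1", prefix="gs://marin-us-central2"))
```

Failing early on a malformed prefix is cheaper than discovering mid-run that outputs were written to the wrong place.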

If you haven't set up your SSH keys yet, use

make dev_setup

to set up the required keys for accessing the Marin Ray clusters.

Using Ray Submit

  1. Connect to the cluster using ray dashboard:

    ray dashboard infra/marin-$REGION.yaml
  2. In another terminal, submit a job using ray job submit:

    export RAY_ADDRESS=http://localhost:8265
    ray job submit --working-dir . -- python experiments/tutorials/hello_world.py

    Note that RAY_ADDRESS needs to use the http:// URL specifier when connecting via the dashboard.
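The same submission can also be made from Python through Ray's job submission SDK. The sketch below assumes that ray is installed locally and that the dashboard port-forward from step 1 is still running; `job_server_address` and `submit_hello_world` are hypothetical helper names, while `JobSubmissionClient` and `submit_job` are part of Ray's SDK.

```python
import os
from typing import Optional


def job_server_address(address: Optional[str] = None) -> str:
    """Return the job server address, insisting on an explicit http(s):// scheme."""
    if address is None:
        address = os.environ.get("RAY_ADDRESS", "http://localhost:8265")
    if not address.startswith(("http://", "https://")):
        raise ValueError(f"expected an http:// address, got {address!r}")
    return address


def submit_hello_world(address: str) -> str:
    """Submit the tutorial script and return the job ID assigned by Ray."""
    # Imported lazily so that the address check above works without ray installed.
    from ray.job_submission import JobSubmissionClient

    client = JobSubmissionClient(address)
    return client.submit_job(
        entrypoint="python experiments/tutorials/hello_world.py",
        runtime_env={"working_dir": "."},  # same role as --working-dir . on the CLI
    )


# Usage, with the dashboard from step 1 running in another terminal:
#   job_id = submit_hello_world(job_server_address())
```

The address check mirrors the note above: an address such as localhost:8265 without the scheme is rejected before any connection is attempted.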

Using Ray Run

  1. Connect to the cluster using ray dashboard:

    ray dashboard infra/marin-$REGION.yaml
  2. We provide a thin wrapper around ray job submit called ray-run, which makes it easy to specify additional pip requirements and environment variables (beyond those specified in the Dockerfile or cluster config). ray-run ensures the following:
    • All the dependencies in pyproject.toml are installed
    • Some of the necessary environment variables are set from file vars.py
    • src directories of submodules are added to PYTHONPATH. This is useful for co-developing

You can use --env_vars and --pip_deps to specify additional environment variables and pip dependencies. For example, to run a job with the additional pip dependencies jax==0.4.35 and async-lru, along with the environment variable WANDB_API_KEY, you can use:

python marin/run/ray_run.py --env_vars WANDB_API_KEY ${WANDB_API_KEY} --pip_deps jax==0.4.35,async-lru -- python experiments/tutorials/hello_world.py

Managing the Cluster

Using the Cluster Management Script

Most cluster management for Ray and GCP is handled by scripts/ray/cluster.py. You can specify either a path to a config file or a region to operate on, e.g. eu-west4.

# Basic cluster operations
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml start-cluster
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml stop-cluster
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml restart-cluster
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml get-status

# List and update configurations
uv run scripts/ray/cluster.py list-configs
uv run scripts/ray/cluster.py update-configs

# SSH operations
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml ssh-head
uv run scripts/ray/cluster.py ssh-tpu 10.130.1.40

# Maintenance operations
uv run scripts/ray/cluster.py clean-tpu-processes --dry-run
uv run scripts/ray/cluster.py clean-tpu-processes --tpu-type v4-8
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml clean-preempted-tpus

# Job operations
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml list-jobs
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml submit-job "python experiments/tutorials/hello_world.py"

# Monitoring and logs
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml monitor-cluster --wandb
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml show-logs --tail 100
uv run scripts/ray/cluster.py --config infra/marin-$REGION.yaml open-dashboard

Legacy Commands (still supported)

The original Ray commands still work but are less convenient:

# Monitor cluster status
ray status infra/marin-$REGION.yaml

# SSH into head node
ray attach infra/marin-$REGION.yaml

# View autoscaler logs
ray exec infra/marin-$REGION.yaml "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"

# Stop the cluster
ray down -y infra/marin-$REGION.yaml

Best Practices

  1. Data Location: Use a bucket in the same region as your cluster to minimize data transfer costs.

  2. Preemption Handling: If your compute is preemptible, VMs can be shut down at any time. Design your jobs to:

    • Be idempotent (can run multiple times safely)
    • Handle interruptions gracefully
    • Save checkpoints regularly
  3. Resource Management:

    • Avoid running heavy computation on the head node
    • Break jobs into tasks that complete within 1-10 minutes
    • Use appropriate TPU types for your workload (v4-8 for v4 clusters, v5e-1 for v5e clusters)
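The preemption and task-size advice above can be combined into one pattern: split the work into small shards, write each shard's output atomically, and skip shards whose output already exists. The sketch below illustrates this with a local directory standing in for a GCS bucket; `process_shard`, `run_all`, and the squared-number "work" are hypothetical placeholders.

```python
import json
import os
import tempfile


def process_shard(shard_id: int, out_dir: str) -> bool:
    """Process one shard; return False if its output already exists (idempotent)."""
    out_path = os.path.join(out_dir, f"shard-{shard_id:05d}.json")
    if os.path.exists(out_path):
        return False  # finished in an earlier run, so do not redo it
    result = {"shard": shard_id, "value": shard_id * shard_id}  # placeholder work
    # Write to a temporary file and rename it into place, so that a preemption
    # mid-write never leaves a partial file that looks like a finished shard.
    fd, tmp_path = tempfile.mkstemp(dir=out_dir, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(result, f)
    os.replace(tmp_path, out_path)
    return True


def run_all(num_shards: int, out_dir: str) -> int:
    """Run every shard and return how many were actually processed this time."""
    return sum(process_shard(i, out_dir) for i in range(num_shards))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as out_dir:
        print("first run processed:", run_all(4, out_dir))
        print("second run processed:", run_all(4, out_dir))
```

Rerunning the job after an interruption repeats only the unfinished shards, which is why keeping each shard small limits the work lost to a preemption.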

Troubleshooting

If you encounter issues:

  1. Cluster Not Scaling: Check autoscaler logs for errors:

    ray exec infra/marin-$REGION.yaml "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
  2. Worker Connection Issues: If workers aren't connecting, try:

    • Checking GCP Console for VM/TPU status
    • Reviewing Ray dashboard on the head node
    • Restarting the cluster if needed
  3. Cluster State Issues: If the cluster state becomes inconsistent:

    • Bring down the cluster completely
    • Delete any orphaned resources in GCP Console
    • Bring the cluster back up

Next Steps

  1. Try running a basic training job to verify your setup
  2. Explore the Language Modeling Pipeline to understand how Marin uses TPUs for training
  3. Read the Ray documentation for more details on cluster management
  4. Levanter Ray autoscaling setup