From 696e88dc2a09ed9658d543176cd32e746567a199 Mon Sep 17 00:00:00 2001 From: asahoo Date: Wed, 1 Jul 2026 18:12:34 -0500 Subject: [PATCH 1/7] docs: document --use-cog-base-image and prebuilt base images Closes #1820 --- docs/base-images.md | 70 +++++++++++++++++++++++++++++++ docs/getting-started-own-model.md | 2 + docs/getting-started.md | 5 +++ docs/llms.txt | 7 ++++ mkdocs.yml | 1 + 5 files changed, 85 insertions(+) create mode 100644 docs/base-images.md diff --git a/docs/base-images.md b/docs/base-images.md new file mode 100644 index 0000000000..557cc91339 --- /dev/null +++ b/docs/base-images.md @@ -0,0 +1,70 @@ +# Base images + +Cog builds your model into a Docker image. To speed up builds and reduce cold boot times, Cog uses **prebuilt base images** by default. These images contain the common dependencies that most models need, so Cog doesn't have to install them from scratch every time you build or push. + +## What is a Cog base image? + +A Cog base image is a Docker image maintained by Replicate that includes: + +- **Python runtime** — the Python version specified in your `cog.yaml`. +- **System libraries** — common dependencies for machine learning and media processing, including `ffmpeg`, `git`, `curl`, `build-essential`, `cmake`, OpenCV libraries, audio libraries (`sox`, `libsndfile1`), and graphics libraries (`libgl1`, `libsm6`, `libxext6`). +- **CUDA and PyTorch** (GPU images only) — the appropriate CUDA toolkit and PyTorch stack for your configuration. +- **Cog runtime** — the Cog SDK and coglet prediction server. + +Base images are tagged by their configuration, for example: + +``` +r8.im/cog-base:cuda12-python3.13-torch2.6 +r8.im/cog-base:python3.13 +``` + +When you run `cog build` or `cog push`, Cog selects the base image that matches your Python version, CUDA version, and PyTorch version. Because these images are pre-pulled on Replicate's infrastructure, models built on top of them start faster. + +## Using the Cog base image + +The `--use-cog-base-image` flag controls whether Cog uses a prebuilt base image. It is **enabled by default** on the following commands: + +- [`cog build`](cli.md#cog-build) +- [`cog push`](cli.md#cog-push) +- [`cog run`](cli.md#cog-run) +- [`cog serve`](cli.md#cog-serve) +- [`cog exec`](cli.md#cog-exec) + +Since it's on by default, you don't need to pass any flags: + +```bash +cog push r8.im/your-username/my-model +``` + +This builds and pushes your model using a prebuilt Cog base image for faster cold boots. + +## Disabling the Cog base image + +If you encounter build issues or need a custom base layer, you can disable the Cog base image: + +```bash +cog build --use-cog-base-image=false +``` + +When disabled, Cog generates a Dockerfile from scratch using either an NVIDIA CUDA base image or a slim Python base image, depending on the `--use-cuda-base-image` flag. + +## Relationship to `--use-cuda-base-image` + +The `--use-cuda-base-image` flag controls which underlying base image Cog uses **when the Cog base image is disabled**. It has no effect when `--use-cog-base-image` is enabled (the default), because the Cog base image already includes the appropriate CUDA and Python stack. + +When you disable the Cog base image with `--use-cog-base-image=false`, Cog chooses the base image automatically: + +- **GPU models** (`gpu: true` in `cog.yaml`): uses an NVIDIA CUDA base image. +- **CPU models**: uses a slim Python base image. + +These flags — along with `--dockerfile` — are **mutually exclusive**: you can only set one of them explicitly on a given command. To customize the base image, disable the Cog base image and let Cog choose between CUDA and Python automatically. + +## Troubleshooting + +If `cog build` or `cog push` fails with the Cog base image enabled, try disabling it: + +```bash +cog push --use-cog-base-image=false r8.im/your-username/my-model +``` + +This falls back to building from a standard CUDA or Python base image, which can help diagnose whether the issue is with the base image or your model's configuration. diff --git a/docs/getting-started-own-model.md b/docs/getting-started-own-model.md index 4bc858b58c..20a830295f 100644 --- a/docs/getting-started-own-model.md +++ b/docs/getting-started-own-model.md @@ -58,6 +58,8 @@ Type "help", "copyright", "credits" or "license" for more information. This is handy for ensuring a consistent environment for development or training. +By default, Cog builds on top of a [prebuilt base image](base-images.md) that includes Python, common system libraries, and the Cog runtime. This speeds up builds and reduces cold boot times when you deploy with `cog push`. + With `cog.yaml`, you can also install system packages and other things. [Take a look at the full reference to see what else you can do.](yaml.md) ## Define how to run your model diff --git a/docs/getting-started.md b/docs/getting-started.md index e11587fd14..ae32b53cc2 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -151,6 +151,8 @@ Note: The first time you run `cog run`, the build process will be triggered to g We can bake your model's code, the trained weights, and the Docker environment into a Docker image. This image serves an HTTP server, and can be deployed to anywhere that Docker runs to serve real-time inference. +By default, Cog builds on top of a [prebuilt base image](base-images.md) that includes Python, common system libraries, and the Cog runtime. This significantly reduces cold boot times when deploying your model. + ```bash cog build -t resnet # Building Docker image... @@ -198,6 +200,9 @@ cog push The Docker image is now accessible to anyone or any system that has access to this Docker registry. +> [!TIP] +> `cog push` uses a [prebuilt Cog base image](base-images.md) by default for faster cold boots. If you run into build issues, try disabling it with `--use-cog-base-image=false`. + ## Next steps Those are the basics! Next, you might want to take a look at: diff --git a/docs/llms.txt b/docs/llms.txt index 11cb30b033..7639b4aef2 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -1187,6 +1187,8 @@ Type "help", "copyright", "credits" or "license" for more information. This is handy for ensuring a consistent environment for development or training. +By default, Cog builds on top of a [prebuilt base image](base-images.md) that includes Python, common system libraries, and the Cog runtime. This speeds up builds and reduces cold boot times when you deploy with `cog push`. + With `cog.yaml`, you can also install system packages and other things. [Take a look at the full reference to see what else you can do.](yaml.md) ## Define how to run your model @@ -1443,6 +1445,8 @@ Note: The first time you run `cog run`, the build process will be triggered to g We can bake your model's code, the trained weights, and the Docker environment into a Docker image. This image serves an HTTP server, and can be deployed to anywhere that Docker runs to serve real-time inference. +By default, Cog builds on top of a [prebuilt base image](base-images.md) that includes Python, common system libraries, and the Cog runtime. This significantly reduces cold boot times when deploying your model. + ```bash cog build -t resnet # Building Docker image... @@ -1490,6 +1494,9 @@ cog push The Docker image is now accessible to anyone or any system that has access to this Docker registry. +> [!TIP] +> `cog push` uses a [prebuilt Cog base image](base-images.md) by default for faster cold boots. If you run into build issues, try disabling it with `--use-cog-base-image=false`. + ## Next steps Those are the basics! Next, you might want to take a look at: diff --git a/mkdocs.yml b/mkdocs.yml index 6c394d3c31..dfc0390ad5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,6 +12,7 @@ nav: - Training API: training.md - HTTP API: http.md - CLI: cli.md + - Base images: base-images.md - Environment variables: environment.md - Private registry: private-package-registry.md - Notebooks: notebooks.md From d71f7445de3da906fd379c5f8cd64fb6a63f4368 Mon Sep 17 00:00:00 2001 From: asahoo Date: Wed, 1 Jul 2026 18:25:19 -0500 Subject: [PATCH 2/7] docs: add link to available base image list Closes #1820 --- docs/base-images.md | 2 ++ docs/llms.txt | 76 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/docs/base-images.md b/docs/base-images.md index 557cc91339..843ea44fa7 100644 --- a/docs/base-images.md +++ b/docs/base-images.md @@ -20,6 +20,8 @@ r8.im/cog-base:python3.13 When you run `cog build` or `cog push`, Cog selects the base image that matches your Python version, CUDA version, and PyTorch version. Because these images are pre-pulled on Replicate's infrastructure, models built on top of them start faster. +The entire list of available base images can be viewed at [https://r8.im/v2/cog-base/tags/list](https://r8.im/v2/cog-base/tags/list). + ## Using the Cog base image The `--use-cog-base-image` flag controls whether Cog uses a prebuilt base image. It is **enabled by default** on the following commands: diff --git a/docs/llms.txt b/docs/llms.txt index 7639b4aef2..dbae30e7d3 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -242,6 +242,82 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for how to set up a development environme +--- + +# Base images + +Cog builds your model into a Docker image. To speed up builds and reduce cold boot times, Cog uses **prebuilt base images** by default. These images contain the common dependencies that most models need, so Cog doesn't have to install them from scratch every time you build or push. + +## What is a Cog base image? + +A Cog base image is a Docker image maintained by Replicate that includes: + +- **Python runtime** — the Python version specified in your `cog.yaml`. +- **System libraries** — common dependencies for machine learning and media processing, including `ffmpeg`, `git`, `curl`, `build-essential`, `cmake`, OpenCV libraries, audio libraries (`sox`, `libsndfile1`), and graphics libraries (`libgl1`, `libsm6`, `libxext6`). +- **CUDA and PyTorch** (GPU images only) — the appropriate CUDA toolkit and PyTorch stack for your configuration. +- **Cog runtime** — the Cog SDK and coglet prediction server. + +Base images are tagged by their configuration, for example: + +``` +r8.im/cog-base:cuda12-python3.13-torch2.6 +r8.im/cog-base:python3.13 +``` + +When you run `cog build` or `cog push`, Cog selects the base image that matches your Python version, CUDA version, and PyTorch version. Because these images are pre-pulled on Replicate's infrastructure, models built on top of them start faster. + +The entire list of available base images can be viewed at [https://r8.im/v2/cog-base/tags/list](https://r8.im/v2/cog-base/tags/list). + +## Using the Cog base image + +The `--use-cog-base-image` flag controls whether Cog uses a prebuilt base image. It is **enabled by default** on the following commands: + +- [`cog build`](cli.md#cog-build) +- [`cog push`](cli.md#cog-push) +- [`cog run`](cli.md#cog-run) +- [`cog serve`](cli.md#cog-serve) +- [`cog exec`](cli.md#cog-exec) + +Since it's on by default, you don't need to pass any flags: + +```bash +cog push r8.im/your-username/my-model +``` + +This builds and pushes your model using a prebuilt Cog base image for faster cold boots. + +## Disabling the Cog base image + +If you encounter build issues or need a custom base layer, you can disable the Cog base image: + +```bash +cog build --use-cog-base-image=false +``` + +When disabled, Cog generates a Dockerfile from scratch using either an NVIDIA CUDA base image or a slim Python base image, depending on the `--use-cuda-base-image` flag. + +## Relationship to `--use-cuda-base-image` + +The `--use-cuda-base-image` flag controls which underlying base image Cog uses **when the Cog base image is disabled**. It has no effect when `--use-cog-base-image` is enabled (the default), because the Cog base image already includes the appropriate CUDA and Python stack. + +When you disable the Cog base image with `--use-cog-base-image=false`, Cog chooses the base image automatically: + +- **GPU models** (`gpu: true` in `cog.yaml`): uses an NVIDIA CUDA base image. +- **CPU models**: uses a slim Python base image. + +These flags — along with `--dockerfile` — are **mutually exclusive**: you can only set one of them explicitly on a given command. To customize the base image, disable the Cog base image and let Cog choose between CUDA and Python automatically. + +## Troubleshooting + +If `cog build` or `cog push` fails with the Cog base image enabled, try disabling it: + +```bash +cog push --use-cog-base-image=false r8.im/your-username/my-model +``` + +This falls back to building from a standard CUDA or Python base image, which can help diagnose whether the issue is with the base image or your model's configuration. + + --- # CLI reference From e433b85fc376eb08f006b733c6495eeaadbfcf1f Mon Sep 17 00:00:00 2001 From: asahoo Date: Wed, 1 Jul 2026 18:32:51 -0500 Subject: [PATCH 3/7] feat: add concurrent decorator --- crates/coglet-python/src/lib.rs | 40 ++++- crates/coglet-python/src/worker_bridge.rs | 12 +- docs/deploy.md | 34 ++-- docs/environment.md | 13 +- docs/examples.md | 2 +- docs/http.md | 2 +- docs/llms.txt | 69 +++++--- docs/python.md | 15 +- docs/yaml.md | 3 +- examples/hello-concurrency/README.md | 15 +- examples/hello-concurrency/cog.yaml | 2 - examples/hello-concurrency/run.py | 2 + .../concurrent/concurrent_test.go | 10 +- pkg/config/validate.go | 8 + pkg/config/validate_test.go | 19 +++ pkg/image/build.go | 96 ++++++++++- pkg/image/build_test.go | 100 +++++++++++ pkg/schema/generator.go | 11 +- pkg/schema/python/annotations.go | 105 ++++++++++++ pkg/schema/python/parser.go | 52 ++++++ pkg/schema/python/parser_test.go | 157 ++++++++++++++++++ pkg/schema/python/state.go | 2 + pkg/schema/types.go | 2 + python/cog/__init__.py | 37 +++++ python/tests/test_concurrent.py | 74 +++++++++ 25 files changed, 819 insertions(+), 63 deletions(-) create mode 100644 python/tests/test_concurrent.py diff --git a/crates/coglet-python/src/lib.rs b/crates/coglet-python/src/lib.rs index 35130783ba..980b6c9cfd 100644 --- a/crates/coglet-python/src/lib.rs +++ b/crates/coglet-python/src/lib.rs @@ -188,11 +188,25 @@ fn detect_version(py: Python<'_>, build: &BuildInfo) -> VersionInfo { fn read_max_concurrency() -> usize { match std::env::var("COG_MAX_CONCURRENCY") { - Ok(val) => val.parse::().unwrap_or(1), + Ok(val) => match parse_max_concurrency(&val) { + Some(n) => n, + None => { + warn!(value = %val, "Invalid COG_MAX_CONCURRENCY value, defaulting to 1"); + 1 + } + }, Err(_) => 1, } } +fn parse_max_concurrency(val: &str) -> Option { + match val.parse::() { + Ok(0) => None, + Ok(n) => Some(n), + Err(_) => None, + } +} + fn read_setup_timeout() -> Option { match std::env::var("COG_SETUP_TIMEOUT") { Ok(val) => match val.parse::() { @@ -542,10 +556,10 @@ async fn run_worker_with_init() -> Result<(), String> { info!(predictor_ref = %predictor_ref, num_slots, is_train, "Init received, connecting to transport"); let handler = Arc::new(if is_train { - worker_bridge::PythonPredictHandler::new_train(predictor_ref) + worker_bridge::PythonPredictHandler::new_train(predictor_ref, num_slots) .map_err(|e| format!("Failed to create handler: {}", e))? } else { - worker_bridge::PythonPredictHandler::new(predictor_ref) + worker_bridge::PythonPredictHandler::new(predictor_ref, num_slots) .map_err(|e| format!("Failed to create handler: {}", e))? }); @@ -568,6 +582,26 @@ async fn run_worker_with_init() -> Result<(), String> { .map_err(|e| format!("Worker error: {}", e)) } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_max_concurrency_reads_valid_value() { + assert_eq!(parse_max_concurrency("4"), Some(4)); + } + + #[test] + fn parse_max_concurrency_rejects_invalid_value() { + assert_eq!(parse_max_concurrency("wat"), None); + } + + #[test] + fn parse_max_concurrency_rejects_zero() { + assert_eq!(parse_max_concurrency("0"), None); + } +} + // ============================================================================= // Module init // ============================================================================= diff --git a/crates/coglet-python/src/worker_bridge.rs b/crates/coglet-python/src/worker_bridge.rs index 4378e729f8..3e697b364b 100644 --- a/crates/coglet-python/src/worker_bridge.rs +++ b/crates/coglet-python/src/worker_bridge.rs @@ -101,11 +101,12 @@ pub struct PythonPredictHandler { async_loop: Mutex>>, /// Handle to the asyncio loop thread for joining on shutdown. async_thread: Mutex>>, + max_concurrency: usize, } impl PythonPredictHandler { /// Create a handler in prediction mode. - pub fn new(predictor_ref: String) -> Result { + pub fn new(predictor_ref: String, max_concurrency: usize) -> Result { let (loop_obj, thread) = Self::init_async_loop()?; Ok(Self { predictor_ref, @@ -114,6 +115,7 @@ impl PythonPredictHandler { mode: HandlerMode::Predict, async_loop: Mutex::new(Some(loop_obj)), async_thread: Mutex::new(Some(thread)), + max_concurrency, }) } @@ -122,7 +124,7 @@ impl PythonPredictHandler { /// NOTE: For bug-for-bug compatibility with cog mainline, use new() instead. /// Cog mainline's training routes incorrectly use a predict-mode worker. #[allow(dead_code)] - pub fn new_train(predictor_ref: String) -> Result { + pub fn new_train(predictor_ref: String, max_concurrency: usize) -> Result { let (loop_obj, thread) = Self::init_async_loop()?; Ok(Self { predictor_ref, @@ -131,6 +133,7 @@ impl PythonPredictHandler { mode: HandlerMode::Train, async_loop: Mutex::new(Some(loop_obj)), async_thread: Mutex::new(Some(thread)), + max_concurrency, }) } @@ -281,6 +284,11 @@ impl PredictHandler for PythonPredictHandler { let pred = PythonPredictor::load(py, &self.predictor_ref) .map_err(|e| SetupError::load(e.to_string()))?; + if self.max_concurrency > 1 && !pred.is_async() { + return Err(SetupError::setup( + "COG_MAX_CONCURRENCY > 1 requires an async run() or predict() method", + )); + } // Detect SDK implementation let sdk_impl = match py.import("cog") { diff --git a/docs/deploy.md b/docs/deploy.md index 63a5fa83c7..a1e833786f 100644 --- a/docs/deploy.md +++ b/docs/deploy.md @@ -76,11 +76,11 @@ curl http://localhost:5001/predictions -X POST \ ```json { - "status": "succeeded", - "output": "data:image/png;base64,...", - "metrics": { - "predict_time": 4.52 - } + "status": "succeeded", + "output": "data:image/png;base64,...", + "metrics": { + "predict_time": 4.52 + } } ``` @@ -144,8 +144,8 @@ the response contains a base64-encoded data URL by default: ```json { - "status": "succeeded", - "output": "data:image/png;base64,iVBORw0KGgo..." + "status": "succeeded", + "output": "data:image/png;base64,iVBORw0KGgo..." } ``` @@ -172,8 +172,8 @@ contains the uploaded URL instead of a data URL: ```json { - "status": "succeeded", - "output": "https://example.com/upload/image.png" + "status": "succeeded", + "output": "https://example.com/upload/image.png" } ``` @@ -215,21 +215,25 @@ to stop.) ## Concurrency -By default, the server processes one run at a time. To enable concurrent runs, set the `concurrency.max` option in `cog.yaml`: +By default, the server processes one run at a time. To enable concurrent runs, make your `run()` method async and decorate it with `@cog.concurrent(max=N)`: -```yaml -concurrency: - max: 4 +```py +import cog + +class Runner(cog.BaseRunner): + @cog.concurrent(max=4) + async def run(self) -> str: + return "hello world" ``` -See the [`cog.yaml` reference](yaml.md#concurrency) for more details. +The deprecated [`concurrency.max`](yaml.md#concurrency) field in `cog.yaml` is still supported and takes precedence over the decorator by baking `COG_MAX_CONCURRENCY` into the image. ## Environment variables You can configure runtime behavior with environment variables: - `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout). -- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). +- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). Overrides both `@cog.concurrent` and deprecated `cog.yaml` concurrency. See the [environment variables reference](environment.md) for the full list. diff --git a/docs/environment.md b/docs/environment.md index bbf02e2a90..ee9b3afc6b 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -173,9 +173,18 @@ These variables affect a running model server. Set them in `cog.yaml` under `env ### `COG_MAX_CONCURRENCY` -Controls how many predictions the model server can run concurrently. +Controls how many predictions the model server can run concurrently. This overrides both `@cog.concurrent(max=N)` and the deprecated `concurrency.max` field in `cog.yaml`. -By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. +By default, Cog runs one prediction at a time unless the model uses `@cog.concurrent(max=N)`. Invalid values are ignored and the default of `1` is used. + +Values greater than `1` require an async `run()` method. This applies even when `COG_MAX_CONCURRENCY` is set as a runtime operator override. + +Concurrency is resolved in this order, from highest to lowest precedence: + +1. `COG_MAX_CONCURRENCY` set at runtime +2. Deprecated `concurrency.max` in `cog.yaml`, which is baked into the image as `COG_MAX_CONCURRENCY` +3. `@cog.concurrent(max=N)` on the async `run()` method +4. Default: `1` ```console $ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model diff --git a/docs/examples.md b/docs/examples.md index 973c09093b..31a930f35a 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -12,7 +12,7 @@ If you want a working project to copy from, start with one of these: - [`hello-context`](https://github.com/replicate/cog/tree/main/examples/hello-context): shows how to read prediction context - [`hello-concurrency`](https://github.com/replicate/cog/tree/main/examples/hello-concurrency): - demonstrates the `concurrency` setting in `cog.yaml` + demonstrates the `@cog.concurrent` decorator - [`hello-train`](https://github.com/replicate/cog/tree/main/examples/hello-train): defines a training interface - [`hello-replicate`](https://github.com/replicate/cog/tree/main/examples/hello-replicate): diff --git a/docs/http.md b/docs/http.md index 592c4cee59..de2a36dac4 100644 --- a/docs/http.md +++ b/docs/http.md @@ -255,7 +255,7 @@ Prefer: respond-async Endpoints for creating and canceling a prediction idempotently accept a `prediction_id` parameter in their path. By default, the server runs one prediction at a time, -but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting. +but this can be increased with [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency). When all prediction slots are in use, the server returns `409 Conflict`. The client should ensure prediction slots are available before creating a new prediction with a different ID. diff --git a/docs/llms.txt b/docs/llms.txt index 11cb30b033..1c61c7c8cf 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -633,11 +633,11 @@ curl http://localhost:5001/predictions -X POST \ ```json { - "status": "succeeded", - "output": "data:image/png;base64,...", - "metrics": { - "predict_time": 4.52 - } + "status": "succeeded", + "output": "data:image/png;base64,...", + "metrics": { + "predict_time": 4.52 + } } ``` @@ -701,8 +701,8 @@ the response contains a base64-encoded data URL by default: ```json { - "status": "succeeded", - "output": "data:image/png;base64,iVBORw0KGgo..." + "status": "succeeded", + "output": "data:image/png;base64,iVBORw0KGgo..." } ``` @@ -729,8 +729,8 @@ contains the uploaded URL instead of a data URL: ```json { - "status": "succeeded", - "output": "https://example.com/upload/image.png" + "status": "succeeded", + "output": "https://example.com/upload/image.png" } ``` @@ -772,21 +772,25 @@ to stop.) ## Concurrency -By default, the server processes one run at a time. To enable concurrent runs, set the `concurrency.max` option in `cog.yaml`: +By default, the server processes one run at a time. To enable concurrent runs, make your `run()` method async and decorate it with `@cog.concurrent(max=N)`: -```yaml -concurrency: - max: 4 +```py +import cog + +class Runner(cog.BaseRunner): + @cog.concurrent(max=4) + async def run(self) -> str: + return "hello world" ``` -See the [`cog.yaml` reference](yaml.md#concurrency) for more details. +The deprecated [`concurrency.max`](yaml.md#concurrency) field in `cog.yaml` is still supported and takes precedence over the decorator by baking `COG_MAX_CONCURRENCY` into the image. ## Environment variables You can configure runtime behavior with environment variables: - `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout). -- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). +- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). Overrides both `@cog.concurrent` and deprecated `cog.yaml` concurrency. See the [environment variables reference](environment.md) for the full list. @@ -974,9 +978,18 @@ These variables affect a running model server. Set them in `cog.yaml` under `env ### `COG_MAX_CONCURRENCY` -Controls how many predictions the model server can run concurrently. +Controls how many predictions the model server can run concurrently. This overrides both `@cog.concurrent(max=N)` and the deprecated `concurrency.max` field in `cog.yaml`. + +By default, Cog runs one prediction at a time unless the model uses `@cog.concurrent(max=N)`. Invalid values are ignored and the default of `1` is used. + +Values greater than `1` require an async `run()` method. This applies even when `COG_MAX_CONCURRENCY` is set as a runtime operator override. + +Concurrency is resolved in this order, from highest to lowest precedence: -By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. +1. `COG_MAX_CONCURRENCY` set at runtime +2. Deprecated `concurrency.max` in `cog.yaml`, which is baked into the image as `COG_MAX_CONCURRENCY` +3. `@cog.concurrent(max=N)` on the async `run()` method +4. Default: `1` ```console $ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model @@ -1096,7 +1109,7 @@ If you want a working project to copy from, start with one of these: - [`hello-context`](https://github.com/replicate/cog/tree/main/examples/hello-context): shows how to read prediction context - [`hello-concurrency`](https://github.com/replicate/cog/tree/main/examples/hello-concurrency): - demonstrates the `concurrency` setting in `cog.yaml` + demonstrates the `@cog.concurrent` decorator - [`hello-train`](https://github.com/replicate/cog/tree/main/examples/hello-train): defines a training interface - [`hello-replicate`](https://github.com/replicate/cog/tree/main/examples/hello-replicate): @@ -1759,7 +1772,7 @@ Prefer: respond-async Endpoints for creating and canceling a prediction idempotently accept a `prediction_id` parameter in their path. By default, the server runs one prediction at a time, -but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting. +but this can be increased with [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency). When all prediction slots are in use, the server returns `409 Conflict`. The client should ensure prediction slots are available before creating a new prediction with a different ID. @@ -2349,7 +2362,20 @@ class Runner(BaseRunner): return "hello world"; ``` -Models that have an async `run()` function can run concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. +Models that have an async `run()` function can run concurrently. Use `@cog.concurrent(max=N)` to configure the default maximum concurrency for the function: + +```py +import cog + +class Runner(cog.BaseRunner): + @cog.concurrent(max=4) + async def run(self) -> str: + return "hello world" +``` + +Attempting to exceed this limit will return a 409 Conflict response. `max` values greater than `1` require an async `run()` method. The `COG_MAX_CONCURRENCY` environment variable can override the decorator at runtime. + +The `max` value must be an integer literal so Cog can configure the model at build time. Use `COG_MAX_CONCURRENCY` when you need to set concurrency dynamically at runtime. ## `Input(**kwargs)` @@ -3574,8 +3600,9 @@ build: ## `concurrency` > Added in cog 0.14.0. +> Deprecated: use [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency) on your async `run()` method instead. -This stanza describes the concurrency capabilities of the model. It has one option: +This stanza describes the concurrency capabilities of the model. It is still supported for backwards compatibility, but new models should use `@cog.concurrent(max=N)`. It has one option: ### `max` diff --git a/docs/python.md b/docs/python.md index fe9307aa3b..218af5f33f 100644 --- a/docs/python.md +++ b/docs/python.md @@ -128,7 +128,20 @@ class Runner(BaseRunner): return "hello world"; ``` -Models that have an async `run()` function can run concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. +Models that have an async `run()` function can run concurrently. Use `@cog.concurrent(max=N)` to configure the default maximum concurrency for the function: + +```py +import cog + +class Runner(cog.BaseRunner): + @cog.concurrent(max=4) + async def run(self) -> str: + return "hello world" +``` + +Attempting to exceed this limit will return a 409 Conflict response. `max` values greater than `1` require an async `run()` method. The `COG_MAX_CONCURRENCY` environment variable can override the decorator at runtime. + +The `max` value must be an integer literal so Cog can configure the model at build time. Use `COG_MAX_CONCURRENCY` when you need to set concurrency dynamically at runtime. ## `Input(**kwargs)` diff --git a/docs/yaml.md b/docs/yaml.md index 535b9aafa4..4c0000f8bb 100644 --- a/docs/yaml.md +++ b/docs/yaml.md @@ -189,8 +189,9 @@ build: ## `concurrency` > Added in cog 0.14.0. +> Deprecated: use [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency) on your async `run()` method instead. -This stanza describes the concurrency capabilities of the model. It has one option: +This stanza describes the concurrency capabilities of the model. It is still supported for backwards compatibility, but new models should use `@cog.concurrent(max=N)`. It has one option: ### `max` diff --git a/examples/hello-concurrency/README.md b/examples/hello-concurrency/README.md index a3e85c8552..68c8cfed3e 100644 --- a/examples/hello-concurrency/README.md +++ b/examples/hello-concurrency/README.md @@ -1,13 +1,16 @@ # hello-concurrency -This is an example Cog project that demonstrates the newly added concurrency support within -cog >= 0.14.0. +This is an example Cog project that demonstrates concurrency support within Cog. -The key piece is the new `concurrency` field in the cog.yaml. +The key piece is the `@concurrent(max=4)` decorator on the async `run()` method. -```yaml -concurrency: - max: 4 +```py +from cog import BaseRunner, concurrent + +class Runner(BaseRunner): + @concurrent(max=4) + async def run(self) -> str: + return "hello" ``` This combined with the async setup and run methods in `run.py` allows Cog to run up to diff --git a/examples/hello-concurrency/cog.yaml b/examples/hello-concurrency/cog.yaml index 9be8201142..83e0e89397 100644 --- a/examples/hello-concurrency/cog.yaml +++ b/examples/hello-concurrency/cog.yaml @@ -5,5 +5,3 @@ build: python_version: "3.12" python_requirements: requirements.txt run: "run.py:Runner" -concurrency: - max: 4 diff --git a/examples/hello-concurrency/run.py b/examples/hello-concurrency/run.py index 090f5d562e..f5616bbabc 100644 --- a/examples/hello-concurrency/run.py +++ b/examples/hello-concurrency/run.py @@ -19,6 +19,7 @@ BaseRunner, Input, __version__, + concurrent, current_scope, ) @@ -72,6 +73,7 @@ async def setup(self) -> None: logging.info(f"completed setup in {duration} seconds") span.set_attribute("model.setup_time_seconds", duration) + @concurrent(max=4) async def run( # pyright: ignore self, total: int = Input(default=5), diff --git a/integration-tests/concurrent/concurrent_test.go b/integration-tests/concurrent/concurrent_test.go index bb1cc2ece6..b5ce645ce7 100644 --- a/integration-tests/concurrent/concurrent_test.go +++ b/integration-tests/concurrent/concurrent_test.go @@ -259,15 +259,14 @@ func allocatePort() (int, error) { const cogYAML = `build: python_version: "3.11" predict: "predict.py:Predictor" -concurrency: - max: 5 ` const predictPy = `import asyncio -from cog import BasePredictor +from cog import BasePredictor, concurrent class Predictor(BasePredictor): + @concurrent(max=5) async def predict(self, s: str, sleep: float) -> str: await asyncio.sleep(sleep) return f"wake up {s}" @@ -399,14 +398,13 @@ func TestConcurrentAsyncMetrics(t *testing.T) { metricsCogYAML := `build: python_version: "3.12" predict: "predict.py:Predictor" -concurrency: - max: 5 ` metricsPredictPy := `import asyncio -from cog import BasePredictor, current_scope +from cog import BasePredictor, concurrent, current_scope class Predictor(BasePredictor): + @concurrent(max=5) async def predict(self, idx: int = 0, sleep: float = 0.5) -> str: scope = current_scope() scope.record_metric("prediction_index", idx) diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 71aaaa2437..cd16752af0 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -688,6 +688,14 @@ func checkDeprecatedFields(cfg *configFile, result *ValidationResult) { }) } + if cfg.Concurrency != nil && cfg.Concurrency.Max != nil { + result.AddWarning(DeprecationWarning{ + Field: "concurrency.max", + Replacement: "@cog.concurrent(max=...)", + Message: "configure prediction concurrency with @cog.concurrent on your async run() method instead", + }) + } + if cfg.Build == nil { return } diff --git a/pkg/config/validate_test.go b/pkg/config/validate_test.go index 25ea320695..316698ecf0 100644 --- a/pkg/config/validate_test.go +++ b/pkg/config/validate_test.go @@ -159,6 +159,25 @@ func TestValidateConfigFileConcurrencyType(t *testing.T) { result := ValidateConfigFile(cfg) require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors) + require.Contains(t, result.Warnings, DeprecationWarning{ + Field: "concurrency.max", + Replacement: "@cog.concurrent(max=...)", + Message: "configure prediction concurrency with @cog.concurrent on your async run() method instead", + }) +} + +func TestValidateConfigFileConcurrencyDeprecationWithoutBuild(t *testing.T) { + cfg := &configFile{ + Run: new("run.py:Runner"), + Concurrency: &concurrencyFile{ + Max: new(2), + }, + } + + result := ValidateConfigFile(cfg) + require.False(t, result.HasErrors()) + require.Len(t, result.Warnings, 1) + require.Equal(t, "concurrency.max", result.Warnings[0].Field) } func TestValidateConfigFileDeprecatedPythonPackages(t *testing.T) { diff --git a/pkg/image/build.go b/pkg/image/build.go index eb8a98477e..57362f4423 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -13,6 +13,7 @@ import ( "os/exec" "path/filepath" "slices" + "strconv" "strings" "time" @@ -124,6 +125,10 @@ func Build( // Generate schema before the Docker build so schema errors fail fast and the // schema file is available in the build context. var schemaJSON []byte + predictorInfo, err := generatePredictorMetadata(cfg, dir) + if err != nil { + return "", fmt.Errorf("image build failed: %w", err) + } switch { case needsSchema: if err := validateStaticSchemaSDKVersion(cfg); err != nil { @@ -159,6 +164,11 @@ func Build( } } + dockerfileCfg, err := configWithDecoratorConcurrency(cfg, predictorInfo) + if err != nil { + return "", err + } + // --- Runtime weights manifest (/.cog/weights.json) --- // When managed weights are configured and a lockfile exists, project the // lockfile to the minimal runtime manifest (spec §3.3) and write it into @@ -206,8 +216,11 @@ func Build( if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil { return "", fmt.Errorf("Failed to build Docker image: %w", err) } + if err := addConcurrencyToCustomDockerfileImage(ctx, dockerCommand, tmpImageId, dockerfileCfg.Concurrency, progressOutput, bp.buildDir); err != nil { + return "", err + } } else { - generator, err := dockerfile.NewStandardGenerator(cfg, dir, bp.buildDir, configFilename, dockerCommand, client, true) + generator, err := dockerfile.NewStandardGenerator(dockerfileCfg, dir, bp.buildDir, configFilename, dockerCommand, client, true) if err != nil { return "", fmt.Errorf("Error creating Dockerfile generator: %w", err) } @@ -446,6 +459,83 @@ func generateStaticSchema(cfg *config.Config, dir string) ([]byte, error) { return nil, fmt.Errorf("no predict or train reference found in cog.yaml") } return schema.GenerateCombined(dir, cfg.Predict, cfg.Train, schema.PathAwareParser(python.ParsePredictorWithSourcePath)) +} + +func generatePredictorMetadata(cfg *config.Config, dir string) (*schema.PredictorInfo, error) { + if cfg.Predict == "" && cfg.Train == "" { + return nil, nil + } + if cfg.Predict != "" { + return schema.GenerateInfo(cfg.Predict, dir, schema.ModePredict, schema.PathAwareParser(python.ParsePredictorMetadataWithSourcePath)) + } + return schema.GenerateInfo(cfg.Train, dir, schema.ModeTrain, schema.PathAwareParser(python.ParsePredictorMetadataWithSourcePath)) +} + +func configWithDecoratorConcurrency(cfg *config.Config, info *schema.PredictorInfo) (*config.Config, error) { + if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && info != nil && !info.IsAsync { + return nil, fmt.Errorf("concurrency.max > 1 requires an async run() or predict() method") + } + if cfg.Concurrency != nil || info == nil || info.ConcurrencyMax == nil { + return cfg, validateEffectiveConcurrency(cfg, info) + } + if *info.ConcurrencyMax > 1 && !info.IsAsync { + return nil, fmt.Errorf("@cog.concurrent(max > 1) requires an async run() or predict() method") + } + cfgCopy := *cfg + cfgCopy.Concurrency = &config.Concurrency{Max: *info.ConcurrencyMax} + return &cfgCopy, validateEffectiveConcurrency(&cfgCopy, info) +} + +func addConcurrencyToCustomDockerfileImage(ctx context.Context, dockerCommand command.Command, imageName string, concurrency *config.Concurrency, progressOutput string, buildCacheDir string) error { + if concurrency == nil || concurrency.Max <= 0 { + return nil + } + buildOpts := command.ImageBuildOptions{ + DockerfileContents: concurrencyDockerfile(imageName, concurrency.Max), + ImageName: imageName, + ProgressOutput: progressOutput, + BuildCacheDir: buildCacheDir, + } + if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil { + return fmt.Errorf("Failed to add concurrency configuration to Docker image: %w", err) + } + return nil +} + +func validateEffectiveConcurrency(cfg *config.Config, info *schema.PredictorInfo) error { + if cfg.Concurrency == nil || cfg.Concurrency.Max <= 1 { + return nil + } + if info != nil && !info.IsAsync { + return fmt.Errorf("concurrency.max > 1 requires an async run() or predict() method") + } + if cfg.Build == nil || cfg.Build.PythonVersion == "" { + return nil + } + major, minor, err := splitPythonVersionForBuild(cfg.Build.PythonVersion) + if err != nil { + return nil + } + if major == config.MinimumMajorPythonVersion && minor < config.MinimumMinorPythonVersionForConcurrency { + return fmt.Errorf("concurrency requires Python %d.%d or higher", config.MinimumMajorPythonVersion, config.MinimumMinorPythonVersionForConcurrency) + } + return nil +} + +func splitPythonVersionForBuild(version string) (major int, minor int, err error) { + parts := strings.Split(version, ".") + if len(parts) < 2 { + return 0, 0, fmt.Errorf("invalid Python version %q", version) + } + major, err = strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, err + } + minor, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, err + } + return major, minor, nil } @@ -546,6 +636,10 @@ func bundleDockerfile(baseImage string, files []string) string { return b.String() } +func concurrencyDockerfile(baseImage string, maxConcurrency int) string { + return fmt.Sprintf("FROM %s\nENV COG_MAX_CONCURRENCY=%d\n", baseImage, maxConcurrency) +} + func isGitWorkTree(ctx context.Context, dir string) bool { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() diff --git a/pkg/image/build_test.go b/pkg/image/build_test.go index bf2af1e229..1089e22e51 100644 --- a/pkg/image/build_test.go +++ b/pkg/image/build_test.go @@ -14,7 +14,10 @@ import ( "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/docker/dockertest" "github.com/replicate/cog/pkg/dotcog" + "github.com/replicate/cog/pkg/schema" "github.com/replicate/cog/pkg/weights/lockfile" ) @@ -62,6 +65,103 @@ func TestIsGitWorkTree(t *testing.T) { r.True(isGitWorkTree(ctx, setupGitWorkTree(t))) } +func TestConfigWithDecoratorConcurrencyAppliesWhenYAMLAbsent(t *testing.T) { + concurrencyMax := 4 + cfg := &config.Config{} + info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} + + got, err := configWithDecoratorConcurrency(cfg, info) + + require.NoError(t, err) + require.NotSame(t, cfg, got) + require.Nil(t, cfg.Concurrency) + require.NotNil(t, got.Concurrency) + require.Equal(t, 4, got.Concurrency.Max) +} + +func TestConfigWithDecoratorConcurrencyPreservesYAMLPrecedence(t *testing.T) { + concurrencyMax := 4 + cfg := &config.Config{Concurrency: &config.Concurrency{Max: 2}} + info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} + + got, err := configWithDecoratorConcurrency(cfg, info) + + require.NoError(t, err) + require.Same(t, cfg, got) + require.Equal(t, 2, got.Concurrency.Max) +} + +func TestConfigWithDecoratorConcurrencyRejectsSyncYAMLConcurrency(t *testing.T) { + cfg := &config.Config{Concurrency: &config.Concurrency{Max: 2}} + info := &schema.PredictorInfo{IsAsync: false} + + _, err := configWithDecoratorConcurrency(cfg, info) + + require.Error(t, err) + require.Contains(t, err.Error(), "requires an async") +} + +func TestConfigWithDecoratorConcurrencyRejectsOldPythonVersion(t *testing.T) { + concurrencyMax := 2 + cfg := &config.Config{Build: &config.Build{PythonVersion: "3.10"}} + info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} + + _, err := configWithDecoratorConcurrency(cfg, info) + + require.Error(t, err) + require.Contains(t, err.Error(), "Python 3.11 or higher") +} + +func TestConcurrencyDockerfileSetsEnv(t *testing.T) { + dockerfile := concurrencyDockerfile("my-image", 4) + + require.Equal(t, "FROM my-image\nENV COG_MAX_CONCURRENCY=4\n", dockerfile) +} + +type recordingCommand struct { + *dockertest.MockCommand + builds []command.ImageBuildOptions +} + +func (c *recordingCommand) ImageBuild(ctx context.Context, options command.ImageBuildOptions) (string, error) { + c.builds = append(c.builds, options) + return "sha256:test", nil +} + +func TestAddConcurrencyToCustomDockerfileImageBuildsWrapperLayer(t *testing.T) { + dockerCommand := &recordingCommand{MockCommand: dockertest.NewMockCommand()} + concurrency := &config.Concurrency{Max: 4} + + err := addConcurrencyToCustomDockerfileImage(t.Context(), dockerCommand, "my-image", concurrency, "plain", "/tmp/build-cache") + + require.NoError(t, err) + require.Len(t, dockerCommand.builds, 1) + require.Equal(t, "FROM my-image\nENV COG_MAX_CONCURRENCY=4\n", dockerCommand.builds[0].DockerfileContents) + require.Equal(t, "my-image", dockerCommand.builds[0].ImageName) + require.Equal(t, "plain", dockerCommand.builds[0].ProgressOutput) + require.Equal(t, "/tmp/build-cache", dockerCommand.builds[0].BuildCacheDir) +} + +func TestGeneratePredictorMetadataDoesNotRequireValidOutputSchema(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "predict.py"), []byte(` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(max=3) + async def predict(self) -> NotARealType: + return "hello" +`), 0o644)) + cfg := &config.Config{Predict: "predict.py:Predictor"} + + info, err := generatePredictorMetadata(cfg, dir) + + require.NoError(t, err) + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 3, *info.ConcurrencyMax) + require.True(t, info.IsAsync) +} + func TestGitHead(t *testing.T) { t.Run("via github env", func(t *testing.T) { t.Setenv("GITHUB_SHA", "fafafaf") diff --git a/pkg/schema/generator.go b/pkg/schema/generator.go index efdddebfd0..790873beb2 100644 --- a/pkg/schema/generator.go +++ b/pkg/schema/generator.go @@ -40,6 +40,15 @@ func Generate(predictRef string, sourceDir string, mode Mode, parse any) ([]byte return data, nil } + info, err := GenerateInfo(predictRef, sourceDir, mode, parse) + if err != nil { + return nil, err + } + return GenerateOpenAPISchema(info) +} + +// GenerateInfo parses predictor source and returns the extracted predictor info. +func GenerateInfo(predictRef string, sourceDir string, mode Mode, parse any) (*PredictorInfo, error) { filePath, className, err := parsePredictRef(predictRef) if err != nil { return nil, err @@ -59,7 +68,7 @@ func Generate(predictRef string, sourceDir string, mode Mode, parse any) ([]byte return nil, fmt.Errorf("failed to read predictor source %s: %w", fullPath, err) } - return generateFromSource(source, className, mode, parse, sourceDir, filePath) + return parsePredictorInfo(parse, source, className, mode, sourceDir, filePath) } func cleanSourceFilePath(filePath string) (string, error) { diff --git a/pkg/schema/python/annotations.go b/pkg/schema/python/annotations.go index 5310da4c21..05ca7148a6 100644 --- a/pkg/schema/python/annotations.go +++ b/pkg/schema/python/annotations.go @@ -243,6 +243,69 @@ func functionSupportsStreaming(node *sitter.Node, source []byte, imports *schema return false } +func functionIsAsync(node *sitter.Node, source []byte) bool { + return strings.HasPrefix(strings.TrimSpace(Content(node, source)), "async def ") +} + +func functionConcurrencyMax(node *sitter.Node, source []byte, imports *schema.ImportContext) (*int, error) { + decorated := decoratedFunctionNode(node) + if decorated == nil { + return nil, nil + } + + for _, child := range NamedChildren(decorated) { + if child.Type() != "decorator" { + continue + } + concurrencyMax, ok, err := decoratorCogConcurrentMax(child, source, imports) + if err != nil { + return nil, err + } + if ok { + return concurrencyMax, nil + } + } + return nil, nil +} + +func decoratorCogConcurrentMax(node *sitter.Node, source []byte, imports *schema.ImportContext) (*int, bool, error) { + expr := decoratorExpression(node) + if expr == nil || !expressionIsCogConcurrent(expr, source, imports) { + return nil, false, nil + } + + call := decoratorCall(node) + if call == nil { + concurrencyMax := 1 + return &concurrencyMax, true, nil + } + + argList := callArgumentList(call) + if argList == nil || len(NamedChildren(argList)) == 0 { + concurrencyMax := 1 + return &concurrencyMax, true, nil + } + + args := NamedChildren(argList) + if len(args) != 1 || args[0].Type() != "keyword_argument" { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent only supports literal integer max=... arguments", nil) + } + nameNode := args[0].ChildByFieldName("name") + valueNode := args[0].ChildByFieldName("value") + if nameNode == nil || valueNode == nil || Content(nameNode, source) != "max" { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent only supports literal integer max=... arguments", nil) + } + val, ok := parseDefaultValue(valueNode, source) + if !ok || val.Kind != schema.DefaultInt { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max must be an integer literal", nil) + } + if val.Int < 1 { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max must be at least 1", nil) + } + concurrencyMax := int(val.Int) + return &concurrencyMax, true, nil +} + func decoratedFunctionNode(node *sitter.Node) *sitter.Node { if node.Type() == "decorated_definition" { return node @@ -277,6 +340,23 @@ func decoratorExpression(node *sitter.Node) *sitter.Node { return child } +func decoratorCall(node *sitter.Node) *sitter.Node { + children := NamedChildren(node) + if len(children) == 0 || children[0].Type() != "call" { + return nil + } + return children[0] +} + +func callArgumentList(node *sitter.Node) *sitter.Node { + for _, child := range NamedChildren(node) { + if child.Type() == "argument_list" { + return child + } + } + return nil +} + func expressionIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { switch node.Type() { case "attribute": @@ -288,6 +368,17 @@ func expressionIsCogStreaming(node *sitter.Node, source []byte, imports *schema. } } +func expressionIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + switch node.Type() { + case "attribute": + return attributeIsCogConcurrent(node, source, imports) + case "identifier": + return identifierIsCogConcurrent(node, source, imports) + default: + return false + } +} + func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { name, attr, ok := strings.Cut(Content(node, source), ".") if !ok || attr != "streaming" { @@ -297,11 +388,25 @@ func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.I return ok && entry.Module == "cog" && entry.Original == "cog" } +func attributeIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + name, attr, ok := strings.Cut(Content(node, source), ".") + if !ok || attr != "concurrent" { + return false + } + entry, ok := imports.Names.Get(name) + return ok && entry.Module == "cog" && entry.Original == "cog" +} + func identifierIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { entry, ok := imports.Names.Get(Content(node, source)) return ok && entry.Module == "cog" && entry.Original == "streaming" } +func identifierIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + entry, ok := imports.Names.Get(Content(node, source)) + return ok && entry.Module == "cog" && entry.Original == "concurrent" +} + func supportsStreamingOutput(output schema.SchemaType) bool { return output.Kind == schema.SchemaIterator || output.Kind == schema.SchemaConcatIterator } diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index 4ee6557d33..093a1e6e85 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -35,6 +35,47 @@ func ParsePredictorWithSourcePath(source []byte, targetRef string, mode schema.M return ParseWithOptions(opts) } +func ParsePredictorMetadataWithSourcePath(source []byte, targetRef string, mode schema.Mode, sourceDir string, sourcePath string) (*schema.PredictorInfo, error) { + opts := defaultParserOptions(source, targetRef, mode, sourceDir) + opts.SourcePath = sourcePath + return ParseMetadataWithOptions(opts) +} + +func ParseMetadataWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { + sourcePath, err := normalizeSourcePath(opts.SourceDir, opts.SourcePath) + if err != nil { + return nil, err + } + opts.SourcePath = sourcePath + state := newParseState(opts) + phases := []parserPhase{ + {Name: "parse module", From: phaseInitial, To: phaseModuleParsed, Run: parseModulePhase}, + {Name: "collect imports", From: phaseModuleParsed, To: phaseImportsCollected, Run: collectImportsPhase}, + {Name: "collect module scope", From: phaseImportsCollected, To: phaseModuleScopeCollected, Run: collectModuleScopePhase}, + {Name: "collect local models", From: phaseModuleScopeCollected, To: phaseLocalModelsCollected, Run: collectLocalModelsPhase}, + {Name: "resolve imported models", From: phaseLocalModelsCollected, To: phaseImportedModelsResolved, Run: resolveImportedModelsPhase}, + {Name: "collect input registry", From: phaseImportedModelsResolved, To: phaseInputRegistryCollected, Run: collectInputRegistryPhase}, + {Name: "find target callable", From: phaseInputRegistryCollected, To: phaseTargetFound, Run: findTargetCallablePhase}, + } + if err := runPhases(state, phases); err != nil { + return nil, err + } + targetSource := state.TargetFunc.file.source + isAsync := functionIsAsync(state.TargetFunc.node, targetSource) + concurrencyMax, err := functionConcurrencyMax(state.TargetFunc.node, targetSource, state.TargetFunc.file.imports) + if err != nil { + return nil, err + } + if concurrencyMax != nil && *concurrencyMax > 1 && !isAsync { + return nil, schema.WrapError(schema.ErrUnsupportedType, "@concurrent(max > 1) requires an async run() or predict() method", nil) + } + return &schema.PredictorInfo{ + Mode: opts.Mode, + ConcurrencyMax: concurrencyMax, + IsAsync: isAsync, + }, nil +} + func ParseWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { sourcePath, err := normalizeSourcePath(opts.SourceDir, opts.SourcePath) if err != nil { @@ -62,6 +103,8 @@ func ParseWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { Output: state.Output, Mode: opts.Mode, SupportsStreaming: state.SupportsStreaming, + ConcurrencyMax: state.ConcurrencyMax, + IsAsync: state.IsAsync, }, nil } @@ -208,6 +251,7 @@ func extractInputsPhase(state *ParseState) error { } state.Target = &TargetCallable{MethodName: actualMethodName, Node: funcNode, IsMethod: isMethod} state.Inputs = inputs + state.IsAsync = functionIsAsync(funcNode, targetSource) return nil } @@ -231,8 +275,16 @@ func resolveOutputPhase(state *ParseState) error { if supportsStreaming && !supportsStreamingOutput(output) { return schema.WrapError(schema.ErrUnsupportedType, "@streaming requires Iterator[...], AsyncIterator[...], ConcatenateIterator[...] or AsyncConcatenateIterator[...] return type", nil) } + concurrencyMax, err := functionConcurrencyMax(state.TargetFunc.node, state.TargetFunc.file.source, state.TargetFunc.file.imports) + if err != nil { + return err + } + if concurrencyMax != nil && *concurrencyMax > 1 && !state.IsAsync { + return schema.WrapError(schema.ErrUnsupportedType, "@concurrent(max > 1) requires an async run() or predict() method", nil) + } state.Output = output state.SupportsStreaming = supportsStreaming + state.ConcurrencyMax = concurrencyMax state.OutputSet = true return nil } diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index de0a71f267..93182a3acf 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -963,6 +963,163 @@ class Predictor(cog.BasePredictor): require.False(t, info.SupportsStreaming) } +func TestConcurrentDecoratorQualifiedMax(t *testing.T) { + source := ` +import cog + +class Predictor(cog.BasePredictor): + @cog.concurrent(max=4) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 4, *info.ConcurrencyMax) + require.True(t, info.IsAsync) +} + +func TestConcurrentDecoratorImportedMax(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(max=3) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 3, *info.ConcurrencyMax) +} + +func TestConcurrentDecoratorBareDefaultsToOne(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent + def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 1, *info.ConcurrencyMax) + require.False(t, info.IsAsync) +} + +func TestConcurrentDecoratorCallDefaultsToOne(t *testing.T) { + source := ` +import cog + +class Predictor(cog.BasePredictor): + @cog.concurrent() + def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 1, *info.ConcurrencyMax) +} + +func TestConcurrentDecoratorImportedAliasMax(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent as concurrent_predictions + +class Predictor(BasePredictor): + @concurrent_predictions(max=2) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 2, *info.ConcurrencyMax) +} + +func TestConcurrentDecoratorIgnoredWhenNotFromCog(t *testing.T) { + source := ` +from other import concurrent +from cog import BasePredictor + +class Predictor(BasePredictor): + @concurrent(max=4) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.Nil(t, info.ConcurrencyMax) +} + +func TestConcurrentDecoratorRequiresAsyncForMaxGreaterThanOne(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(max=2) + def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "requires an async") +} + +func TestConcurrentDecoratorRejectsNonLiteralMax(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +MAX_CONCURRENCY = 2 + +class Predictor(BasePredictor): + @concurrent(max=MAX_CONCURRENCY) + async def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "integer literal") +} + +func TestConcurrentDecoratorRejectsPositionalArgument(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(2) + async def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "max=...") +} + +func TestConcurrentDecoratorRejectsStringMax(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(max="2") + async def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "integer literal") +} + +func TestConcurrentDecoratorUndecoratedHasNoMax(t *testing.T) { + source := ` +from cog import BasePredictor + +class Predictor(BasePredictor): + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.Nil(t, info.ConcurrencyMax) + require.True(t, info.IsAsync) +} + func TestListOutput(t *testing.T) { source := ` from cog import BasePredictor, Path diff --git a/pkg/schema/python/state.go b/pkg/schema/python/state.go index 17533067ec..dc0421ea16 100644 --- a/pkg/schema/python/state.go +++ b/pkg/schema/python/state.go @@ -104,6 +104,8 @@ type ParseState struct { Inputs *schema.OrderedMap[string, schema.InputField] Output schema.SchemaType SupportsStreaming bool + ConcurrencyMax *int + IsAsync bool OutputSet bool } diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 0d86535801..38d7878f5d 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -254,6 +254,8 @@ type PredictorInfo struct { Output SchemaType Mode Mode SupportsStreaming bool + ConcurrencyMax *int + IsAsync bool } // TypeAnnotation is a parsed Python type annotation (intermediate, before resolution). diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 775e2ba0e9..ba55c6902f 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -19,6 +19,7 @@ def run( return self.model.generate(prompt, image) """ +import inspect as _inspect import sys as _sys from collections.abc import Callable from typing import TypeVar, overload @@ -64,6 +65,41 @@ def decorate(inner: _F) -> _F: return decorate(fn) +@overload +def concurrent(fn: _F) -> _F: + pass + + +@overload +def concurrent(fn: None = None, *, max: int = 1) -> Callable[[_F], _F]: # noqa: A002 + pass + + +def concurrent( + fn: _F | None = None, + *, + max: int = 1, # noqa: A002 +) -> _F | Callable[[_F], _F]: + """Configure the maximum concurrency for an async predict handler.""" + + if isinstance(max, bool) or not isinstance(max, int): + raise TypeError("concurrent max must be an integer") + if max < 1: + raise ValueError("concurrent max must be at least 1") + + def decorate(inner: _F) -> _F: + if max > 1 and not ( + _inspect.iscoroutinefunction(inner) or _inspect.isasyncgenfunction(inner) + ): + raise TypeError("concurrent max greater than 1 requires an async function") + inner.__cog_concurrent_max__ = max # type: ignore[attr-defined] + return inner + + if fn is None: + return decorate + return decorate(fn) + + # --------------------------------------------------------------------------- # Backwards-compatibility shim: ExperimentalFeatureWarning # @@ -161,6 +197,7 @@ def current_scope() -> object: "current_scope", # Decorators "streaming", + "concurrent", # Deprecated compat shims "ExperimentalFeatureWarning", "emit_metric", diff --git a/python/tests/test_concurrent.py b/python/tests/test_concurrent.py new file mode 100644 index 0000000000..399d47a241 --- /dev/null +++ b/python/tests/test_concurrent.py @@ -0,0 +1,74 @@ +"""Tests for the cog.concurrent decorator.""" + +import sys +import types +from collections.abc import AsyncIterator + +import pytest + +if "coglet" not in sys.modules: + coglet = types.ModuleType("coglet") + coglet.CancelationException = type("CancelationException", (BaseException,), {}) + sys.modules["coglet"] = coglet + +from cog import concurrent + + +def test_bare_concurrent_sets_default_max_on_sync_function() -> None: + @concurrent + def predict() -> str: + return "ok" + + assert predict.__cog_concurrent_max__ == 1 + + +def test_concurrent_call_form_sets_default_max_on_sync_function() -> None: + @concurrent() + def predict() -> str: + return "ok" + + assert predict.__cog_concurrent_max__ == 1 + + +def test_concurrent_max_one_allows_sync_function() -> None: + @concurrent(max=1) + def predict() -> str: + return "ok" + + assert predict.__cog_concurrent_max__ == 1 + + +def test_concurrent_max_allows_async_function() -> None: + @concurrent(max=3) + async def predict() -> str: + return "ok" + + assert predict.__cog_concurrent_max__ == 3 + + +def test_concurrent_max_allows_async_generator() -> None: + @concurrent(max=3) + async def predict() -> AsyncIterator[str]: + yield "ok" + + assert predict.__cog_concurrent_max__ == 3 + + +def test_concurrent_max_rejects_sync_function() -> None: + with pytest.raises(TypeError, match="requires an async function"): + + @concurrent(max=2) + def predict() -> str: + return "ok" + + +@pytest.mark.parametrize("value", [0, -1]) +def test_concurrent_rejects_max_less_than_one(value: int) -> None: + with pytest.raises(ValueError, match="at least 1"): + concurrent(max=value) + + +@pytest.mark.parametrize("value", [True, 1.5, "2"]) +def test_concurrent_rejects_non_integer_max(value: object) -> None: + with pytest.raises(TypeError, match="must be an integer"): + concurrent(max=value) # type: ignore[arg-type] From 0b3f79249107362fca09bd10b2adc7c6472ac4d1 Mon Sep 17 00:00:00 2001 From: asahoo Date: Wed, 1 Jul 2026 18:41:43 -0500 Subject: [PATCH 4/7] fix: bound concurrent max conversion --- pkg/schema/python/annotations.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/schema/python/annotations.go b/pkg/schema/python/annotations.go index 05ca7148a6..90f940ff01 100644 --- a/pkg/schema/python/annotations.go +++ b/pkg/schema/python/annotations.go @@ -2,6 +2,7 @@ package python import ( "fmt" + "math" "strings" sitter "github.com/smacker/go-tree-sitter" @@ -302,6 +303,9 @@ func decoratorCogConcurrentMax(node *sitter.Node, source []byte, imports *schema if val.Int < 1 { return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max must be at least 1", nil) } + if val.Int > int64(math.MaxInt) { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max is too large", nil) + } concurrencyMax := int(val.Int) return &concurrencyMax, true, nil } From bc66ab7c9814af12515e782522961f556e458d2d Mon Sep 17 00:00:00 2001 From: Anish Sahoo Date: Wed, 1 Jul 2026 19:03:19 -0500 Subject: [PATCH 5/7] Update pkg/schema/python/parser_test.go Co-authored-by: ask-bonk[bot] <249159057+ask-bonk[bot]@users.noreply.github.com> Signed-off-by: Anish Sahoo --- pkg/schema/python/parser_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 93182a3acf..2fd1a1eee9 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -1035,6 +1035,23 @@ class Predictor(BasePredictor): require.Equal(t, 2, *info.ConcurrencyMax) } +func TestConcurrentDecoratorQualifiedAliasMax(t *testing.T) { + source := ` +import cog as c + +class Predictor(c.BasePredictor): + @c.concurrent(max=4) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 4, *info.ConcurrencyMax) + require.True(t, info.IsAsync) +} + +func TestConcurrentDecoratorIgnoredWhenNotFromCog(t *testing.T) { + func TestConcurrentDecoratorIgnoredWhenNotFromCog(t *testing.T) { source := ` from other import concurrent From e0e1b47f4aae743dd2f74656b254618d84e845d1 Mon Sep 17 00:00:00 2001 From: asahoo Date: Wed, 1 Jul 2026 19:05:09 -0500 Subject: [PATCH 6/7] fix: remove duplicate function declaration in parser_test.go Merge from #3089 introduced a duplicate TestConcurrentDecoratorIgnoredWhenNotFromCog function signature that caused a syntax error. --- README.md | 88 ++++++++++++++++---------------- pkg/schema/python/parser_test.go | 2 - 2 files changed, 44 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index 9e753f46d2..21fc5d87a3 100644 --- a/README.md +++ b/README.md @@ -127,83 +127,83 @@ Choose your platform for installation instructions.
macOS - The easiest way to install Cog on macOS is with Homebrew: +The easiest way to install Cog on macOS is with Homebrew: - ```console - brew install replicate/tap/cog - ``` +```console +brew install replicate/tap/cog +``` - You can also use the install script: +You can also use the install script: - ```sh - # bash, zsh, and other shells - sh <(curl -fsSL https://cog.run/install.sh) +```sh +# bash, zsh, and other shells +sh <(curl -fsSL https://cog.run/install.sh) - # fish shell - sh (curl -fsSL https://cog.run/install.sh | psub) +# fish shell +sh (curl -fsSL https://cog.run/install.sh | psub) - # download with wget and run in a separate command - wget -qO- https://cog.run/install.sh - sh ./install.sh - ``` +# download with wget and run in a separate command +wget -qO- https://cog.run/install.sh +sh ./install.sh +``` - Or install manually: +Or install manually: - ```console - sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m | sed 's/aarch64/arm64/')" - sudo chmod +x /usr/local/bin/cog - sudo xattr -d com.apple.quarantine /usr/local/bin/cog 2>/dev/null || true - ``` +```console +sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m | sed 's/aarch64/arm64/')" +sudo chmod +x /usr/local/bin/cog +sudo xattr -d com.apple.quarantine /usr/local/bin/cog 2>/dev/null || true +``` - If you see a Gatekeeper warning saying the binary "cannot be opened because the developer cannot be verified", run: +If you see a Gatekeeper warning saying the binary "cannot be opened because the developer cannot be verified", run: - ```console - sudo xattr -d com.apple.quarantine /usr/local/bin/cog - ``` +```console +sudo xattr -d com.apple.quarantine /usr/local/bin/cog +```
Linux - You can install Cog using the install script: +You can install Cog using the install script: - ```sh - # bash, zsh, and other shells - sh <(curl -fsSL https://cog.run/install.sh) +```sh +# bash, zsh, and other shells +sh <(curl -fsSL https://cog.run/install.sh) - # fish shell - sh (curl -fsSL https://cog.run/install.sh | psub) +# fish shell +sh (curl -fsSL https://cog.run/install.sh | psub) - # download with wget and run in a separate command - wget -qO- https://cog.run/install.sh - sh ./install.sh - ``` +# download with wget and run in a separate command +wget -qO- https://cog.run/install.sh +sh ./install.sh +``` - Or install manually: +Or install manually: - ```console - sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m | sed 's/aarch64/arm64/')" - sudo chmod +x /usr/local/bin/cog - ``` +```console +sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m | sed 's/aarch64/arm64/')" +sudo chmod +x /usr/local/bin/cog +```
Windows - Cog does not natively support Windows, but you can run it on Windows 11 using [WSL 2](docs/wsl2/wsl2.md). Once WSL 2 is set up, follow the Linux installation instructions above. +Cog does not natively support Windows, but you can run it on Windows 11 using [WSL 2](docs/wsl2/wsl2.md). Once WSL 2 is set up, follow the Linux installation instructions above.
Docker - To install Cog inside a Docker image: +To install Cog inside a Docker image: - ```dockerfile - RUN sh -c "INSTALL_DIR=\"/usr/local/bin\" SUDO=\"\" $(curl -fsSL https://cog.run/install.sh)" - ``` +```dockerfile +RUN sh -c "INSTALL_DIR=\"/usr/local/bin\" SUDO=\"\" $(curl -fsSL https://cog.run/install.sh)" +```
diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 2fd1a1eee9..7f094cceda 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -1050,8 +1050,6 @@ class Predictor(c.BasePredictor): require.True(t, info.IsAsync) } -func TestConcurrentDecoratorIgnoredWhenNotFromCog(t *testing.T) { - func TestConcurrentDecoratorIgnoredWhenNotFromCog(t *testing.T) { source := ` from other import concurrent From 4af3a18c623a1d83b8ce0bc5c076c2d49ccada64 Mon Sep 17 00:00:00 2001 From: asahoo Date: Thu, 2 Jul 2026 10:04:18 -0500 Subject: [PATCH 7/7] revert: remove concurrency branch changes --- crates/coglet-python/src/lib.rs | 40 +--- crates/coglet-python/src/worker_bridge.rs | 12 +- docs/deploy.md | 16 +- docs/environment.md | 13 +- docs/examples.md | 2 +- docs/http.md | 2 +- docs/llms.txt | 139 ++++++-------- docs/python.md | 15 +- docs/yaml.md | 3 +- examples/hello-concurrency/README.md | 15 +- examples/hello-concurrency/cog.yaml | 2 + examples/hello-concurrency/run.py | 2 - .../concurrent/concurrent_test.go | 10 +- pkg/config/validate.go | 8 - pkg/config/validate_test.go | 19 -- pkg/image/build.go | 96 +--------- pkg/image/build_test.go | 100 ---------- pkg/schema/generator.go | 11 +- pkg/schema/python/annotations.go | 109 ----------- pkg/schema/python/parser.go | 52 ------ pkg/schema/python/parser_test.go | 172 ------------------ pkg/schema/python/state.go | 2 - pkg/schema/types.go | 2 - python/cog/__init__.py | 37 ---- python/tests/test_concurrent.py | 74 -------- 25 files changed, 89 insertions(+), 864 deletions(-) delete mode 100644 python/tests/test_concurrent.py diff --git a/crates/coglet-python/src/lib.rs b/crates/coglet-python/src/lib.rs index 980b6c9cfd..35130783ba 100644 --- a/crates/coglet-python/src/lib.rs +++ b/crates/coglet-python/src/lib.rs @@ -188,25 +188,11 @@ fn detect_version(py: Python<'_>, build: &BuildInfo) -> VersionInfo { fn read_max_concurrency() -> usize { match std::env::var("COG_MAX_CONCURRENCY") { - Ok(val) => match parse_max_concurrency(&val) { - Some(n) => n, - None => { - warn!(value = %val, "Invalid COG_MAX_CONCURRENCY value, defaulting to 1"); - 1 - } - }, + Ok(val) => val.parse::().unwrap_or(1), Err(_) => 1, } } -fn parse_max_concurrency(val: &str) -> Option { - match val.parse::() { - Ok(0) => None, - Ok(n) => Some(n), - Err(_) => None, - } -} - fn read_setup_timeout() -> Option { match std::env::var("COG_SETUP_TIMEOUT") { Ok(val) => match val.parse::() { @@ -556,10 +542,10 @@ async fn run_worker_with_init() -> Result<(), String> { info!(predictor_ref = %predictor_ref, num_slots, is_train, "Init received, connecting to transport"); let handler = Arc::new(if is_train { - worker_bridge::PythonPredictHandler::new_train(predictor_ref, num_slots) + worker_bridge::PythonPredictHandler::new_train(predictor_ref) .map_err(|e| format!("Failed to create handler: {}", e))? } else { - worker_bridge::PythonPredictHandler::new(predictor_ref, num_slots) + worker_bridge::PythonPredictHandler::new(predictor_ref) .map_err(|e| format!("Failed to create handler: {}", e))? }); @@ -582,26 +568,6 @@ async fn run_worker_with_init() -> Result<(), String> { .map_err(|e| format!("Worker error: {}", e)) } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_max_concurrency_reads_valid_value() { - assert_eq!(parse_max_concurrency("4"), Some(4)); - } - - #[test] - fn parse_max_concurrency_rejects_invalid_value() { - assert_eq!(parse_max_concurrency("wat"), None); - } - - #[test] - fn parse_max_concurrency_rejects_zero() { - assert_eq!(parse_max_concurrency("0"), None); - } -} - // ============================================================================= // Module init // ============================================================================= diff --git a/crates/coglet-python/src/worker_bridge.rs b/crates/coglet-python/src/worker_bridge.rs index 3e697b364b..4378e729f8 100644 --- a/crates/coglet-python/src/worker_bridge.rs +++ b/crates/coglet-python/src/worker_bridge.rs @@ -101,12 +101,11 @@ pub struct PythonPredictHandler { async_loop: Mutex>>, /// Handle to the asyncio loop thread for joining on shutdown. async_thread: Mutex>>, - max_concurrency: usize, } impl PythonPredictHandler { /// Create a handler in prediction mode. - pub fn new(predictor_ref: String, max_concurrency: usize) -> Result { + pub fn new(predictor_ref: String) -> Result { let (loop_obj, thread) = Self::init_async_loop()?; Ok(Self { predictor_ref, @@ -115,7 +114,6 @@ impl PythonPredictHandler { mode: HandlerMode::Predict, async_loop: Mutex::new(Some(loop_obj)), async_thread: Mutex::new(Some(thread)), - max_concurrency, }) } @@ -124,7 +122,7 @@ impl PythonPredictHandler { /// NOTE: For bug-for-bug compatibility with cog mainline, use new() instead. /// Cog mainline's training routes incorrectly use a predict-mode worker. #[allow(dead_code)] - pub fn new_train(predictor_ref: String, max_concurrency: usize) -> Result { + pub fn new_train(predictor_ref: String) -> Result { let (loop_obj, thread) = Self::init_async_loop()?; Ok(Self { predictor_ref, @@ -133,7 +131,6 @@ impl PythonPredictHandler { mode: HandlerMode::Train, async_loop: Mutex::new(Some(loop_obj)), async_thread: Mutex::new(Some(thread)), - max_concurrency, }) } @@ -284,11 +281,6 @@ impl PredictHandler for PythonPredictHandler { let pred = PythonPredictor::load(py, &self.predictor_ref) .map_err(|e| SetupError::load(e.to_string()))?; - if self.max_concurrency > 1 && !pred.is_async() { - return Err(SetupError::setup( - "COG_MAX_CONCURRENCY > 1 requires an async run() or predict() method", - )); - } // Detect SDK implementation let sdk_impl = match py.import("cog") { diff --git a/docs/deploy.md b/docs/deploy.md index a1e833786f..f6cfcf685d 100644 --- a/docs/deploy.md +++ b/docs/deploy.md @@ -215,25 +215,21 @@ to stop.) ## Concurrency -By default, the server processes one run at a time. To enable concurrent runs, make your `run()` method async and decorate it with `@cog.concurrent(max=N)`: +By default, the server processes one run at a time. To enable concurrent runs, set the `concurrency.max` option in `cog.yaml`: -```py -import cog - -class Runner(cog.BaseRunner): - @cog.concurrent(max=4) - async def run(self) -> str: - return "hello world" +```yaml +concurrency: + max: 4 ``` -The deprecated [`concurrency.max`](yaml.md#concurrency) field in `cog.yaml` is still supported and takes precedence over the decorator by baking `COG_MAX_CONCURRENCY` into the image. +See the [`cog.yaml` reference](yaml.md#concurrency) for more details. ## Environment variables You can configure runtime behavior with environment variables: - `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout). -- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). Overrides both `@cog.concurrent` and deprecated `cog.yaml` concurrency. +- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). See the [environment variables reference](environment.md) for the full list. diff --git a/docs/environment.md b/docs/environment.md index ee9b3afc6b..bbf02e2a90 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -173,18 +173,9 @@ These variables affect a running model server. Set them in `cog.yaml` under `env ### `COG_MAX_CONCURRENCY` -Controls how many predictions the model server can run concurrently. This overrides both `@cog.concurrent(max=N)` and the deprecated `concurrency.max` field in `cog.yaml`. +Controls how many predictions the model server can run concurrently. -By default, Cog runs one prediction at a time unless the model uses `@cog.concurrent(max=N)`. Invalid values are ignored and the default of `1` is used. - -Values greater than `1` require an async `run()` method. This applies even when `COG_MAX_CONCURRENCY` is set as a runtime operator override. - -Concurrency is resolved in this order, from highest to lowest precedence: - -1. `COG_MAX_CONCURRENCY` set at runtime -2. Deprecated `concurrency.max` in `cog.yaml`, which is baked into the image as `COG_MAX_CONCURRENCY` -3. `@cog.concurrent(max=N)` on the async `run()` method -4. Default: `1` +By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. ```console $ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model diff --git a/docs/examples.md b/docs/examples.md index 31a930f35a..973c09093b 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -12,7 +12,7 @@ If you want a working project to copy from, start with one of these: - [`hello-context`](https://github.com/replicate/cog/tree/main/examples/hello-context): shows how to read prediction context - [`hello-concurrency`](https://github.com/replicate/cog/tree/main/examples/hello-concurrency): - demonstrates the `@cog.concurrent` decorator + demonstrates the `concurrency` setting in `cog.yaml` - [`hello-train`](https://github.com/replicate/cog/tree/main/examples/hello-train): defines a training interface - [`hello-replicate`](https://github.com/replicate/cog/tree/main/examples/hello-replicate): diff --git a/docs/http.md b/docs/http.md index de2a36dac4..592c4cee59 100644 --- a/docs/http.md +++ b/docs/http.md @@ -255,7 +255,7 @@ Prefer: respond-async Endpoints for creating and canceling a prediction idempotently accept a `prediction_id` parameter in their path. By default, the server runs one prediction at a time, -but this can be increased with [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency). +but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting. When all prediction slots are in use, the server returns `409 Conflict`. The client should ensure prediction slots are available before creating a new prediction with a different ID. diff --git a/docs/llms.txt b/docs/llms.txt index cd4941c586..e3d4ed462a 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -127,83 +127,83 @@ Choose your platform for installation instructions.
macOS - The easiest way to install Cog on macOS is with Homebrew: +The easiest way to install Cog on macOS is with Homebrew: - ```console - brew install replicate/tap/cog - ``` +```console +brew install replicate/tap/cog +``` - You can also use the install script: +You can also use the install script: - ```sh - # bash, zsh, and other shells - sh <(curl -fsSL https://cog.run/install.sh) +```sh +# bash, zsh, and other shells +sh <(curl -fsSL https://cog.run/install.sh) - # fish shell - sh (curl -fsSL https://cog.run/install.sh | psub) +# fish shell +sh (curl -fsSL https://cog.run/install.sh | psub) - # download with wget and run in a separate command - wget -qO- https://cog.run/install.sh - sh ./install.sh - ``` +# download with wget and run in a separate command +wget -qO- https://cog.run/install.sh +sh ./install.sh +``` - Or install manually: +Or install manually: - ```console - sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m | sed 's/aarch64/arm64/')" - sudo chmod +x /usr/local/bin/cog - sudo xattr -d com.apple.quarantine /usr/local/bin/cog 2>/dev/null || true - ``` +```console +sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m | sed 's/aarch64/arm64/')" +sudo chmod +x /usr/local/bin/cog +sudo xattr -d com.apple.quarantine /usr/local/bin/cog 2>/dev/null || true +``` - If you see a Gatekeeper warning saying the binary "cannot be opened because the developer cannot be verified", run: +If you see a Gatekeeper warning saying the binary "cannot be opened because the developer cannot be verified", run: - ```console - sudo xattr -d com.apple.quarantine /usr/local/bin/cog - ``` +```console +sudo xattr -d com.apple.quarantine /usr/local/bin/cog +```
Linux - You can install Cog using the install script: +You can install Cog using the install script: - ```sh - # bash, zsh, and other shells - sh <(curl -fsSL https://cog.run/install.sh) +```sh +# bash, zsh, and other shells +sh <(curl -fsSL https://cog.run/install.sh) - # fish shell - sh (curl -fsSL https://cog.run/install.sh | psub) +# fish shell +sh (curl -fsSL https://cog.run/install.sh | psub) - # download with wget and run in a separate command - wget -qO- https://cog.run/install.sh - sh ./install.sh - ``` +# download with wget and run in a separate command +wget -qO- https://cog.run/install.sh +sh ./install.sh +``` - Or install manually: +Or install manually: - ```console - sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m | sed 's/aarch64/arm64/')" - sudo chmod +x /usr/local/bin/cog - ``` +```console +sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m | sed 's/aarch64/arm64/')" +sudo chmod +x /usr/local/bin/cog +```
Windows - Cog does not natively support Windows, but you can run it on Windows 11 using [WSL 2](docs/wsl2/wsl2.md). Once WSL 2 is set up, follow the Linux installation instructions above. +Cog does not natively support Windows, but you can run it on Windows 11 using [WSL 2](docs/wsl2/wsl2.md). Once WSL 2 is set up, follow the Linux installation instructions above.
Docker - To install Cog inside a Docker image: +To install Cog inside a Docker image: - ```dockerfile - RUN sh -c "INSTALL_DIR=\"/usr/local/bin\" SUDO=\"\" $(curl -fsSL https://cog.run/install.sh)" - ``` +```dockerfile +RUN sh -c "INSTALL_DIR=\"/usr/local/bin\" SUDO=\"\" $(curl -fsSL https://cog.run/install.sh)" +```
@@ -848,25 +848,21 @@ to stop.) ## Concurrency -By default, the server processes one run at a time. To enable concurrent runs, make your `run()` method async and decorate it with `@cog.concurrent(max=N)`: - -```py -import cog +By default, the server processes one run at a time. To enable concurrent runs, set the `concurrency.max` option in `cog.yaml`: -class Runner(cog.BaseRunner): - @cog.concurrent(max=4) - async def run(self) -> str: - return "hello world" +```yaml +concurrency: + max: 4 ``` -The deprecated [`concurrency.max`](yaml.md#concurrency) field in `cog.yaml` is still supported and takes precedence over the decorator by baking `COG_MAX_CONCURRENCY` into the image. +See the [`cog.yaml` reference](yaml.md#concurrency) for more details. ## Environment variables You can configure runtime behavior with environment variables: - `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout). -- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). Overrides both `@cog.concurrent` and deprecated `cog.yaml` concurrency. +- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). See the [environment variables reference](environment.md) for the full list. @@ -1054,18 +1050,9 @@ These variables affect a running model server. Set them in `cog.yaml` under `env ### `COG_MAX_CONCURRENCY` -Controls how many predictions the model server can run concurrently. This overrides both `@cog.concurrent(max=N)` and the deprecated `concurrency.max` field in `cog.yaml`. - -By default, Cog runs one prediction at a time unless the model uses `@cog.concurrent(max=N)`. Invalid values are ignored and the default of `1` is used. - -Values greater than `1` require an async `run()` method. This applies even when `COG_MAX_CONCURRENCY` is set as a runtime operator override. - -Concurrency is resolved in this order, from highest to lowest precedence: +Controls how many predictions the model server can run concurrently. -1. `COG_MAX_CONCURRENCY` set at runtime -2. Deprecated `concurrency.max` in `cog.yaml`, which is baked into the image as `COG_MAX_CONCURRENCY` -3. `@cog.concurrent(max=N)` on the async `run()` method -4. Default: `1` +By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. ```console $ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model @@ -1185,7 +1172,7 @@ If you want a working project to copy from, start with one of these: - [`hello-context`](https://github.com/replicate/cog/tree/main/examples/hello-context): shows how to read prediction context - [`hello-concurrency`](https://github.com/replicate/cog/tree/main/examples/hello-concurrency): - demonstrates the `@cog.concurrent` decorator + demonstrates the `concurrency` setting in `cog.yaml` - [`hello-train`](https://github.com/replicate/cog/tree/main/examples/hello-train): defines a training interface - [`hello-replicate`](https://github.com/replicate/cog/tree/main/examples/hello-replicate): @@ -1855,7 +1842,7 @@ Prefer: respond-async Endpoints for creating and canceling a prediction idempotently accept a `prediction_id` parameter in their path. By default, the server runs one prediction at a time, -but this can be increased with [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency). +but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting. When all prediction slots are in use, the server returns `409 Conflict`. The client should ensure prediction slots are available before creating a new prediction with a different ID. @@ -2445,20 +2432,7 @@ class Runner(BaseRunner): return "hello world"; ``` -Models that have an async `run()` function can run concurrently. Use `@cog.concurrent(max=N)` to configure the default maximum concurrency for the function: - -```py -import cog - -class Runner(cog.BaseRunner): - @cog.concurrent(max=4) - async def run(self) -> str: - return "hello world" -``` - -Attempting to exceed this limit will return a 409 Conflict response. `max` values greater than `1` require an async `run()` method. The `COG_MAX_CONCURRENCY` environment variable can override the decorator at runtime. - -The `max` value must be an integer literal so Cog can configure the model at build time. Use `COG_MAX_CONCURRENCY` when you need to set concurrency dynamically at runtime. +Models that have an async `run()` function can run concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. ## `Input(**kwargs)` @@ -3683,9 +3657,8 @@ build: ## `concurrency` > Added in cog 0.14.0. -> Deprecated: use [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency) on your async `run()` method instead. -This stanza describes the concurrency capabilities of the model. It is still supported for backwards compatibility, but new models should use `@cog.concurrent(max=N)`. It has one option: +This stanza describes the concurrency capabilities of the model. It has one option: ### `max` diff --git a/docs/python.md b/docs/python.md index 218af5f33f..fe9307aa3b 100644 --- a/docs/python.md +++ b/docs/python.md @@ -128,20 +128,7 @@ class Runner(BaseRunner): return "hello world"; ``` -Models that have an async `run()` function can run concurrently. Use `@cog.concurrent(max=N)` to configure the default maximum concurrency for the function: - -```py -import cog - -class Runner(cog.BaseRunner): - @cog.concurrent(max=4) - async def run(self) -> str: - return "hello world" -``` - -Attempting to exceed this limit will return a 409 Conflict response. `max` values greater than `1` require an async `run()` method. The `COG_MAX_CONCURRENCY` environment variable can override the decorator at runtime. - -The `max` value must be an integer literal so Cog can configure the model at build time. Use `COG_MAX_CONCURRENCY` when you need to set concurrency dynamically at runtime. +Models that have an async `run()` function can run concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. ## `Input(**kwargs)` diff --git a/docs/yaml.md b/docs/yaml.md index 4c0000f8bb..535b9aafa4 100644 --- a/docs/yaml.md +++ b/docs/yaml.md @@ -189,9 +189,8 @@ build: ## `concurrency` > Added in cog 0.14.0. -> Deprecated: use [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency) on your async `run()` method instead. -This stanza describes the concurrency capabilities of the model. It is still supported for backwards compatibility, but new models should use `@cog.concurrent(max=N)`. It has one option: +This stanza describes the concurrency capabilities of the model. It has one option: ### `max` diff --git a/examples/hello-concurrency/README.md b/examples/hello-concurrency/README.md index 68c8cfed3e..a3e85c8552 100644 --- a/examples/hello-concurrency/README.md +++ b/examples/hello-concurrency/README.md @@ -1,16 +1,13 @@ # hello-concurrency -This is an example Cog project that demonstrates concurrency support within Cog. +This is an example Cog project that demonstrates the newly added concurrency support within +cog >= 0.14.0. -The key piece is the `@concurrent(max=4)` decorator on the async `run()` method. +The key piece is the new `concurrency` field in the cog.yaml. -```py -from cog import BaseRunner, concurrent - -class Runner(BaseRunner): - @concurrent(max=4) - async def run(self) -> str: - return "hello" +```yaml +concurrency: + max: 4 ``` This combined with the async setup and run methods in `run.py` allows Cog to run up to diff --git a/examples/hello-concurrency/cog.yaml b/examples/hello-concurrency/cog.yaml index 83e0e89397..9be8201142 100644 --- a/examples/hello-concurrency/cog.yaml +++ b/examples/hello-concurrency/cog.yaml @@ -5,3 +5,5 @@ build: python_version: "3.12" python_requirements: requirements.txt run: "run.py:Runner" +concurrency: + max: 4 diff --git a/examples/hello-concurrency/run.py b/examples/hello-concurrency/run.py index f5616bbabc..090f5d562e 100644 --- a/examples/hello-concurrency/run.py +++ b/examples/hello-concurrency/run.py @@ -19,7 +19,6 @@ BaseRunner, Input, __version__, - concurrent, current_scope, ) @@ -73,7 +72,6 @@ async def setup(self) -> None: logging.info(f"completed setup in {duration} seconds") span.set_attribute("model.setup_time_seconds", duration) - @concurrent(max=4) async def run( # pyright: ignore self, total: int = Input(default=5), diff --git a/integration-tests/concurrent/concurrent_test.go b/integration-tests/concurrent/concurrent_test.go index b5ce645ce7..bb1cc2ece6 100644 --- a/integration-tests/concurrent/concurrent_test.go +++ b/integration-tests/concurrent/concurrent_test.go @@ -259,14 +259,15 @@ func allocatePort() (int, error) { const cogYAML = `build: python_version: "3.11" predict: "predict.py:Predictor" +concurrency: + max: 5 ` const predictPy = `import asyncio -from cog import BasePredictor, concurrent +from cog import BasePredictor class Predictor(BasePredictor): - @concurrent(max=5) async def predict(self, s: str, sleep: float) -> str: await asyncio.sleep(sleep) return f"wake up {s}" @@ -398,13 +399,14 @@ func TestConcurrentAsyncMetrics(t *testing.T) { metricsCogYAML := `build: python_version: "3.12" predict: "predict.py:Predictor" +concurrency: + max: 5 ` metricsPredictPy := `import asyncio -from cog import BasePredictor, concurrent, current_scope +from cog import BasePredictor, current_scope class Predictor(BasePredictor): - @concurrent(max=5) async def predict(self, idx: int = 0, sleep: float = 0.5) -> str: scope = current_scope() scope.record_metric("prediction_index", idx) diff --git a/pkg/config/validate.go b/pkg/config/validate.go index cd16752af0..71aaaa2437 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -688,14 +688,6 @@ func checkDeprecatedFields(cfg *configFile, result *ValidationResult) { }) } - if cfg.Concurrency != nil && cfg.Concurrency.Max != nil { - result.AddWarning(DeprecationWarning{ - Field: "concurrency.max", - Replacement: "@cog.concurrent(max=...)", - Message: "configure prediction concurrency with @cog.concurrent on your async run() method instead", - }) - } - if cfg.Build == nil { return } diff --git a/pkg/config/validate_test.go b/pkg/config/validate_test.go index 316698ecf0..25ea320695 100644 --- a/pkg/config/validate_test.go +++ b/pkg/config/validate_test.go @@ -159,25 +159,6 @@ func TestValidateConfigFileConcurrencyType(t *testing.T) { result := ValidateConfigFile(cfg) require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors) - require.Contains(t, result.Warnings, DeprecationWarning{ - Field: "concurrency.max", - Replacement: "@cog.concurrent(max=...)", - Message: "configure prediction concurrency with @cog.concurrent on your async run() method instead", - }) -} - -func TestValidateConfigFileConcurrencyDeprecationWithoutBuild(t *testing.T) { - cfg := &configFile{ - Run: new("run.py:Runner"), - Concurrency: &concurrencyFile{ - Max: new(2), - }, - } - - result := ValidateConfigFile(cfg) - require.False(t, result.HasErrors()) - require.Len(t, result.Warnings, 1) - require.Equal(t, "concurrency.max", result.Warnings[0].Field) } func TestValidateConfigFileDeprecatedPythonPackages(t *testing.T) { diff --git a/pkg/image/build.go b/pkg/image/build.go index 57362f4423..eb8a98477e 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -13,7 +13,6 @@ import ( "os/exec" "path/filepath" "slices" - "strconv" "strings" "time" @@ -125,10 +124,6 @@ func Build( // Generate schema before the Docker build so schema errors fail fast and the // schema file is available in the build context. var schemaJSON []byte - predictorInfo, err := generatePredictorMetadata(cfg, dir) - if err != nil { - return "", fmt.Errorf("image build failed: %w", err) - } switch { case needsSchema: if err := validateStaticSchemaSDKVersion(cfg); err != nil { @@ -164,11 +159,6 @@ func Build( } } - dockerfileCfg, err := configWithDecoratorConcurrency(cfg, predictorInfo) - if err != nil { - return "", err - } - // --- Runtime weights manifest (/.cog/weights.json) --- // When managed weights are configured and a lockfile exists, project the // lockfile to the minimal runtime manifest (spec §3.3) and write it into @@ -216,11 +206,8 @@ func Build( if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil { return "", fmt.Errorf("Failed to build Docker image: %w", err) } - if err := addConcurrencyToCustomDockerfileImage(ctx, dockerCommand, tmpImageId, dockerfileCfg.Concurrency, progressOutput, bp.buildDir); err != nil { - return "", err - } } else { - generator, err := dockerfile.NewStandardGenerator(dockerfileCfg, dir, bp.buildDir, configFilename, dockerCommand, client, true) + generator, err := dockerfile.NewStandardGenerator(cfg, dir, bp.buildDir, configFilename, dockerCommand, client, true) if err != nil { return "", fmt.Errorf("Error creating Dockerfile generator: %w", err) } @@ -459,83 +446,6 @@ func generateStaticSchema(cfg *config.Config, dir string) ([]byte, error) { return nil, fmt.Errorf("no predict or train reference found in cog.yaml") } return schema.GenerateCombined(dir, cfg.Predict, cfg.Train, schema.PathAwareParser(python.ParsePredictorWithSourcePath)) -} - -func generatePredictorMetadata(cfg *config.Config, dir string) (*schema.PredictorInfo, error) { - if cfg.Predict == "" && cfg.Train == "" { - return nil, nil - } - if cfg.Predict != "" { - return schema.GenerateInfo(cfg.Predict, dir, schema.ModePredict, schema.PathAwareParser(python.ParsePredictorMetadataWithSourcePath)) - } - return schema.GenerateInfo(cfg.Train, dir, schema.ModeTrain, schema.PathAwareParser(python.ParsePredictorMetadataWithSourcePath)) -} - -func configWithDecoratorConcurrency(cfg *config.Config, info *schema.PredictorInfo) (*config.Config, error) { - if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && info != nil && !info.IsAsync { - return nil, fmt.Errorf("concurrency.max > 1 requires an async run() or predict() method") - } - if cfg.Concurrency != nil || info == nil || info.ConcurrencyMax == nil { - return cfg, validateEffectiveConcurrency(cfg, info) - } - if *info.ConcurrencyMax > 1 && !info.IsAsync { - return nil, fmt.Errorf("@cog.concurrent(max > 1) requires an async run() or predict() method") - } - cfgCopy := *cfg - cfgCopy.Concurrency = &config.Concurrency{Max: *info.ConcurrencyMax} - return &cfgCopy, validateEffectiveConcurrency(&cfgCopy, info) -} - -func addConcurrencyToCustomDockerfileImage(ctx context.Context, dockerCommand command.Command, imageName string, concurrency *config.Concurrency, progressOutput string, buildCacheDir string) error { - if concurrency == nil || concurrency.Max <= 0 { - return nil - } - buildOpts := command.ImageBuildOptions{ - DockerfileContents: concurrencyDockerfile(imageName, concurrency.Max), - ImageName: imageName, - ProgressOutput: progressOutput, - BuildCacheDir: buildCacheDir, - } - if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil { - return fmt.Errorf("Failed to add concurrency configuration to Docker image: %w", err) - } - return nil -} - -func validateEffectiveConcurrency(cfg *config.Config, info *schema.PredictorInfo) error { - if cfg.Concurrency == nil || cfg.Concurrency.Max <= 1 { - return nil - } - if info != nil && !info.IsAsync { - return fmt.Errorf("concurrency.max > 1 requires an async run() or predict() method") - } - if cfg.Build == nil || cfg.Build.PythonVersion == "" { - return nil - } - major, minor, err := splitPythonVersionForBuild(cfg.Build.PythonVersion) - if err != nil { - return nil - } - if major == config.MinimumMajorPythonVersion && minor < config.MinimumMinorPythonVersionForConcurrency { - return fmt.Errorf("concurrency requires Python %d.%d or higher", config.MinimumMajorPythonVersion, config.MinimumMinorPythonVersionForConcurrency) - } - return nil -} - -func splitPythonVersionForBuild(version string) (major int, minor int, err error) { - parts := strings.Split(version, ".") - if len(parts) < 2 { - return 0, 0, fmt.Errorf("invalid Python version %q", version) - } - major, err = strconv.Atoi(parts[0]) - if err != nil { - return 0, 0, err - } - minor, err = strconv.Atoi(parts[1]) - if err != nil { - return 0, 0, err - } - return major, minor, nil } @@ -636,10 +546,6 @@ func bundleDockerfile(baseImage string, files []string) string { return b.String() } -func concurrencyDockerfile(baseImage string, maxConcurrency int) string { - return fmt.Sprintf("FROM %s\nENV COG_MAX_CONCURRENCY=%d\n", baseImage, maxConcurrency) -} - func isGitWorkTree(ctx context.Context, dir string) bool { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() diff --git a/pkg/image/build_test.go b/pkg/image/build_test.go index 1089e22e51..bf2af1e229 100644 --- a/pkg/image/build_test.go +++ b/pkg/image/build_test.go @@ -14,10 +14,7 @@ import ( "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" - "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/docker/dockertest" "github.com/replicate/cog/pkg/dotcog" - "github.com/replicate/cog/pkg/schema" "github.com/replicate/cog/pkg/weights/lockfile" ) @@ -65,103 +62,6 @@ func TestIsGitWorkTree(t *testing.T) { r.True(isGitWorkTree(ctx, setupGitWorkTree(t))) } -func TestConfigWithDecoratorConcurrencyAppliesWhenYAMLAbsent(t *testing.T) { - concurrencyMax := 4 - cfg := &config.Config{} - info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} - - got, err := configWithDecoratorConcurrency(cfg, info) - - require.NoError(t, err) - require.NotSame(t, cfg, got) - require.Nil(t, cfg.Concurrency) - require.NotNil(t, got.Concurrency) - require.Equal(t, 4, got.Concurrency.Max) -} - -func TestConfigWithDecoratorConcurrencyPreservesYAMLPrecedence(t *testing.T) { - concurrencyMax := 4 - cfg := &config.Config{Concurrency: &config.Concurrency{Max: 2}} - info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} - - got, err := configWithDecoratorConcurrency(cfg, info) - - require.NoError(t, err) - require.Same(t, cfg, got) - require.Equal(t, 2, got.Concurrency.Max) -} - -func TestConfigWithDecoratorConcurrencyRejectsSyncYAMLConcurrency(t *testing.T) { - cfg := &config.Config{Concurrency: &config.Concurrency{Max: 2}} - info := &schema.PredictorInfo{IsAsync: false} - - _, err := configWithDecoratorConcurrency(cfg, info) - - require.Error(t, err) - require.Contains(t, err.Error(), "requires an async") -} - -func TestConfigWithDecoratorConcurrencyRejectsOldPythonVersion(t *testing.T) { - concurrencyMax := 2 - cfg := &config.Config{Build: &config.Build{PythonVersion: "3.10"}} - info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} - - _, err := configWithDecoratorConcurrency(cfg, info) - - require.Error(t, err) - require.Contains(t, err.Error(), "Python 3.11 or higher") -} - -func TestConcurrencyDockerfileSetsEnv(t *testing.T) { - dockerfile := concurrencyDockerfile("my-image", 4) - - require.Equal(t, "FROM my-image\nENV COG_MAX_CONCURRENCY=4\n", dockerfile) -} - -type recordingCommand struct { - *dockertest.MockCommand - builds []command.ImageBuildOptions -} - -func (c *recordingCommand) ImageBuild(ctx context.Context, options command.ImageBuildOptions) (string, error) { - c.builds = append(c.builds, options) - return "sha256:test", nil -} - -func TestAddConcurrencyToCustomDockerfileImageBuildsWrapperLayer(t *testing.T) { - dockerCommand := &recordingCommand{MockCommand: dockertest.NewMockCommand()} - concurrency := &config.Concurrency{Max: 4} - - err := addConcurrencyToCustomDockerfileImage(t.Context(), dockerCommand, "my-image", concurrency, "plain", "/tmp/build-cache") - - require.NoError(t, err) - require.Len(t, dockerCommand.builds, 1) - require.Equal(t, "FROM my-image\nENV COG_MAX_CONCURRENCY=4\n", dockerCommand.builds[0].DockerfileContents) - require.Equal(t, "my-image", dockerCommand.builds[0].ImageName) - require.Equal(t, "plain", dockerCommand.builds[0].ProgressOutput) - require.Equal(t, "/tmp/build-cache", dockerCommand.builds[0].BuildCacheDir) -} - -func TestGeneratePredictorMetadataDoesNotRequireValidOutputSchema(t *testing.T) { - dir := t.TempDir() - require.NoError(t, os.WriteFile(filepath.Join(dir, "predict.py"), []byte(` -from cog import BasePredictor, concurrent - -class Predictor(BasePredictor): - @concurrent(max=3) - async def predict(self) -> NotARealType: - return "hello" -`), 0o644)) - cfg := &config.Config{Predict: "predict.py:Predictor"} - - info, err := generatePredictorMetadata(cfg, dir) - - require.NoError(t, err) - require.NotNil(t, info.ConcurrencyMax) - require.Equal(t, 3, *info.ConcurrencyMax) - require.True(t, info.IsAsync) -} - func TestGitHead(t *testing.T) { t.Run("via github env", func(t *testing.T) { t.Setenv("GITHUB_SHA", "fafafaf") diff --git a/pkg/schema/generator.go b/pkg/schema/generator.go index 790873beb2..efdddebfd0 100644 --- a/pkg/schema/generator.go +++ b/pkg/schema/generator.go @@ -40,15 +40,6 @@ func Generate(predictRef string, sourceDir string, mode Mode, parse any) ([]byte return data, nil } - info, err := GenerateInfo(predictRef, sourceDir, mode, parse) - if err != nil { - return nil, err - } - return GenerateOpenAPISchema(info) -} - -// GenerateInfo parses predictor source and returns the extracted predictor info. -func GenerateInfo(predictRef string, sourceDir string, mode Mode, parse any) (*PredictorInfo, error) { filePath, className, err := parsePredictRef(predictRef) if err != nil { return nil, err @@ -68,7 +59,7 @@ func GenerateInfo(predictRef string, sourceDir string, mode Mode, parse any) (*P return nil, fmt.Errorf("failed to read predictor source %s: %w", fullPath, err) } - return parsePredictorInfo(parse, source, className, mode, sourceDir, filePath) + return generateFromSource(source, className, mode, parse, sourceDir, filePath) } func cleanSourceFilePath(filePath string) (string, error) { diff --git a/pkg/schema/python/annotations.go b/pkg/schema/python/annotations.go index 90f940ff01..5310da4c21 100644 --- a/pkg/schema/python/annotations.go +++ b/pkg/schema/python/annotations.go @@ -2,7 +2,6 @@ package python import ( "fmt" - "math" "strings" sitter "github.com/smacker/go-tree-sitter" @@ -244,72 +243,6 @@ func functionSupportsStreaming(node *sitter.Node, source []byte, imports *schema return false } -func functionIsAsync(node *sitter.Node, source []byte) bool { - return strings.HasPrefix(strings.TrimSpace(Content(node, source)), "async def ") -} - -func functionConcurrencyMax(node *sitter.Node, source []byte, imports *schema.ImportContext) (*int, error) { - decorated := decoratedFunctionNode(node) - if decorated == nil { - return nil, nil - } - - for _, child := range NamedChildren(decorated) { - if child.Type() != "decorator" { - continue - } - concurrencyMax, ok, err := decoratorCogConcurrentMax(child, source, imports) - if err != nil { - return nil, err - } - if ok { - return concurrencyMax, nil - } - } - return nil, nil -} - -func decoratorCogConcurrentMax(node *sitter.Node, source []byte, imports *schema.ImportContext) (*int, bool, error) { - expr := decoratorExpression(node) - if expr == nil || !expressionIsCogConcurrent(expr, source, imports) { - return nil, false, nil - } - - call := decoratorCall(node) - if call == nil { - concurrencyMax := 1 - return &concurrencyMax, true, nil - } - - argList := callArgumentList(call) - if argList == nil || len(NamedChildren(argList)) == 0 { - concurrencyMax := 1 - return &concurrencyMax, true, nil - } - - args := NamedChildren(argList) - if len(args) != 1 || args[0].Type() != "keyword_argument" { - return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent only supports literal integer max=... arguments", nil) - } - nameNode := args[0].ChildByFieldName("name") - valueNode := args[0].ChildByFieldName("value") - if nameNode == nil || valueNode == nil || Content(nameNode, source) != "max" { - return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent only supports literal integer max=... arguments", nil) - } - val, ok := parseDefaultValue(valueNode, source) - if !ok || val.Kind != schema.DefaultInt { - return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max must be an integer literal", nil) - } - if val.Int < 1 { - return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max must be at least 1", nil) - } - if val.Int > int64(math.MaxInt) { - return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max is too large", nil) - } - concurrencyMax := int(val.Int) - return &concurrencyMax, true, nil -} - func decoratedFunctionNode(node *sitter.Node) *sitter.Node { if node.Type() == "decorated_definition" { return node @@ -344,23 +277,6 @@ func decoratorExpression(node *sitter.Node) *sitter.Node { return child } -func decoratorCall(node *sitter.Node) *sitter.Node { - children := NamedChildren(node) - if len(children) == 0 || children[0].Type() != "call" { - return nil - } - return children[0] -} - -func callArgumentList(node *sitter.Node) *sitter.Node { - for _, child := range NamedChildren(node) { - if child.Type() == "argument_list" { - return child - } - } - return nil -} - func expressionIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { switch node.Type() { case "attribute": @@ -372,17 +288,6 @@ func expressionIsCogStreaming(node *sitter.Node, source []byte, imports *schema. } } -func expressionIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { - switch node.Type() { - case "attribute": - return attributeIsCogConcurrent(node, source, imports) - case "identifier": - return identifierIsCogConcurrent(node, source, imports) - default: - return false - } -} - func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { name, attr, ok := strings.Cut(Content(node, source), ".") if !ok || attr != "streaming" { @@ -392,25 +297,11 @@ func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.I return ok && entry.Module == "cog" && entry.Original == "cog" } -func attributeIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { - name, attr, ok := strings.Cut(Content(node, source), ".") - if !ok || attr != "concurrent" { - return false - } - entry, ok := imports.Names.Get(name) - return ok && entry.Module == "cog" && entry.Original == "cog" -} - func identifierIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { entry, ok := imports.Names.Get(Content(node, source)) return ok && entry.Module == "cog" && entry.Original == "streaming" } -func identifierIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { - entry, ok := imports.Names.Get(Content(node, source)) - return ok && entry.Module == "cog" && entry.Original == "concurrent" -} - func supportsStreamingOutput(output schema.SchemaType) bool { return output.Kind == schema.SchemaIterator || output.Kind == schema.SchemaConcatIterator } diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index 093a1e6e85..4ee6557d33 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -35,47 +35,6 @@ func ParsePredictorWithSourcePath(source []byte, targetRef string, mode schema.M return ParseWithOptions(opts) } -func ParsePredictorMetadataWithSourcePath(source []byte, targetRef string, mode schema.Mode, sourceDir string, sourcePath string) (*schema.PredictorInfo, error) { - opts := defaultParserOptions(source, targetRef, mode, sourceDir) - opts.SourcePath = sourcePath - return ParseMetadataWithOptions(opts) -} - -func ParseMetadataWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { - sourcePath, err := normalizeSourcePath(opts.SourceDir, opts.SourcePath) - if err != nil { - return nil, err - } - opts.SourcePath = sourcePath - state := newParseState(opts) - phases := []parserPhase{ - {Name: "parse module", From: phaseInitial, To: phaseModuleParsed, Run: parseModulePhase}, - {Name: "collect imports", From: phaseModuleParsed, To: phaseImportsCollected, Run: collectImportsPhase}, - {Name: "collect module scope", From: phaseImportsCollected, To: phaseModuleScopeCollected, Run: collectModuleScopePhase}, - {Name: "collect local models", From: phaseModuleScopeCollected, To: phaseLocalModelsCollected, Run: collectLocalModelsPhase}, - {Name: "resolve imported models", From: phaseLocalModelsCollected, To: phaseImportedModelsResolved, Run: resolveImportedModelsPhase}, - {Name: "collect input registry", From: phaseImportedModelsResolved, To: phaseInputRegistryCollected, Run: collectInputRegistryPhase}, - {Name: "find target callable", From: phaseInputRegistryCollected, To: phaseTargetFound, Run: findTargetCallablePhase}, - } - if err := runPhases(state, phases); err != nil { - return nil, err - } - targetSource := state.TargetFunc.file.source - isAsync := functionIsAsync(state.TargetFunc.node, targetSource) - concurrencyMax, err := functionConcurrencyMax(state.TargetFunc.node, targetSource, state.TargetFunc.file.imports) - if err != nil { - return nil, err - } - if concurrencyMax != nil && *concurrencyMax > 1 && !isAsync { - return nil, schema.WrapError(schema.ErrUnsupportedType, "@concurrent(max > 1) requires an async run() or predict() method", nil) - } - return &schema.PredictorInfo{ - Mode: opts.Mode, - ConcurrencyMax: concurrencyMax, - IsAsync: isAsync, - }, nil -} - func ParseWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { sourcePath, err := normalizeSourcePath(opts.SourceDir, opts.SourcePath) if err != nil { @@ -103,8 +62,6 @@ func ParseWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { Output: state.Output, Mode: opts.Mode, SupportsStreaming: state.SupportsStreaming, - ConcurrencyMax: state.ConcurrencyMax, - IsAsync: state.IsAsync, }, nil } @@ -251,7 +208,6 @@ func extractInputsPhase(state *ParseState) error { } state.Target = &TargetCallable{MethodName: actualMethodName, Node: funcNode, IsMethod: isMethod} state.Inputs = inputs - state.IsAsync = functionIsAsync(funcNode, targetSource) return nil } @@ -275,16 +231,8 @@ func resolveOutputPhase(state *ParseState) error { if supportsStreaming && !supportsStreamingOutput(output) { return schema.WrapError(schema.ErrUnsupportedType, "@streaming requires Iterator[...], AsyncIterator[...], ConcatenateIterator[...] or AsyncConcatenateIterator[...] return type", nil) } - concurrencyMax, err := functionConcurrencyMax(state.TargetFunc.node, state.TargetFunc.file.source, state.TargetFunc.file.imports) - if err != nil { - return err - } - if concurrencyMax != nil && *concurrencyMax > 1 && !state.IsAsync { - return schema.WrapError(schema.ErrUnsupportedType, "@concurrent(max > 1) requires an async run() or predict() method", nil) - } state.Output = output state.SupportsStreaming = supportsStreaming - state.ConcurrencyMax = concurrencyMax state.OutputSet = true return nil } diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 7f094cceda..de0a71f267 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -963,178 +963,6 @@ class Predictor(cog.BasePredictor): require.False(t, info.SupportsStreaming) } -func TestConcurrentDecoratorQualifiedMax(t *testing.T) { - source := ` -import cog - -class Predictor(cog.BasePredictor): - @cog.concurrent(max=4) - async def predict(self) -> str: - return "hello" -` - info := parse(t, source, "Predictor") - require.NotNil(t, info.ConcurrencyMax) - require.Equal(t, 4, *info.ConcurrencyMax) - require.True(t, info.IsAsync) -} - -func TestConcurrentDecoratorImportedMax(t *testing.T) { - source := ` -from cog import BasePredictor, concurrent - -class Predictor(BasePredictor): - @concurrent(max=3) - async def predict(self) -> str: - return "hello" -` - info := parse(t, source, "Predictor") - require.NotNil(t, info.ConcurrencyMax) - require.Equal(t, 3, *info.ConcurrencyMax) -} - -func TestConcurrentDecoratorBareDefaultsToOne(t *testing.T) { - source := ` -from cog import BasePredictor, concurrent - -class Predictor(BasePredictor): - @concurrent - def predict(self) -> str: - return "hello" -` - info := parse(t, source, "Predictor") - require.NotNil(t, info.ConcurrencyMax) - require.Equal(t, 1, *info.ConcurrencyMax) - require.False(t, info.IsAsync) -} - -func TestConcurrentDecoratorCallDefaultsToOne(t *testing.T) { - source := ` -import cog - -class Predictor(cog.BasePredictor): - @cog.concurrent() - def predict(self) -> str: - return "hello" -` - info := parse(t, source, "Predictor") - require.NotNil(t, info.ConcurrencyMax) - require.Equal(t, 1, *info.ConcurrencyMax) -} - -func TestConcurrentDecoratorImportedAliasMax(t *testing.T) { - source := ` -from cog import BasePredictor, concurrent as concurrent_predictions - -class Predictor(BasePredictor): - @concurrent_predictions(max=2) - async def predict(self) -> str: - return "hello" -` - info := parse(t, source, "Predictor") - require.NotNil(t, info.ConcurrencyMax) - require.Equal(t, 2, *info.ConcurrencyMax) -} - -func TestConcurrentDecoratorQualifiedAliasMax(t *testing.T) { - source := ` -import cog as c - -class Predictor(c.BasePredictor): - @c.concurrent(max=4) - async def predict(self) -> str: - return "hello" -` - info := parse(t, source, "Predictor") - require.NotNil(t, info.ConcurrencyMax) - require.Equal(t, 4, *info.ConcurrencyMax) - require.True(t, info.IsAsync) -} - -func TestConcurrentDecoratorIgnoredWhenNotFromCog(t *testing.T) { - source := ` -from other import concurrent -from cog import BasePredictor - -class Predictor(BasePredictor): - @concurrent(max=4) - async def predict(self) -> str: - return "hello" -` - info := parse(t, source, "Predictor") - require.Nil(t, info.ConcurrencyMax) -} - -func TestConcurrentDecoratorRequiresAsyncForMaxGreaterThanOne(t *testing.T) { - source := ` -from cog import BasePredictor, concurrent - -class Predictor(BasePredictor): - @concurrent(max=2) - def predict(self) -> str: - return "hello" -` - se := parseErr(t, source, "Predictor", schema.ModePredict) - require.Equal(t, schema.ErrUnsupportedType, se.Kind) - require.Contains(t, se.Message, "requires an async") -} - -func TestConcurrentDecoratorRejectsNonLiteralMax(t *testing.T) { - source := ` -from cog import BasePredictor, concurrent - -MAX_CONCURRENCY = 2 - -class Predictor(BasePredictor): - @concurrent(max=MAX_CONCURRENCY) - async def predict(self) -> str: - return "hello" -` - se := parseErr(t, source, "Predictor", schema.ModePredict) - require.Equal(t, schema.ErrUnsupportedType, se.Kind) - require.Contains(t, se.Message, "integer literal") -} - -func TestConcurrentDecoratorRejectsPositionalArgument(t *testing.T) { - source := ` -from cog import BasePredictor, concurrent - -class Predictor(BasePredictor): - @concurrent(2) - async def predict(self) -> str: - return "hello" -` - se := parseErr(t, source, "Predictor", schema.ModePredict) - require.Equal(t, schema.ErrUnsupportedType, se.Kind) - require.Contains(t, se.Message, "max=...") -} - -func TestConcurrentDecoratorRejectsStringMax(t *testing.T) { - source := ` -from cog import BasePredictor, concurrent - -class Predictor(BasePredictor): - @concurrent(max="2") - async def predict(self) -> str: - return "hello" -` - se := parseErr(t, source, "Predictor", schema.ModePredict) - require.Equal(t, schema.ErrUnsupportedType, se.Kind) - require.Contains(t, se.Message, "integer literal") -} - -func TestConcurrentDecoratorUndecoratedHasNoMax(t *testing.T) { - source := ` -from cog import BasePredictor - -class Predictor(BasePredictor): - async def predict(self) -> str: - return "hello" -` - info := parse(t, source, "Predictor") - require.Nil(t, info.ConcurrencyMax) - require.True(t, info.IsAsync) -} - func TestListOutput(t *testing.T) { source := ` from cog import BasePredictor, Path diff --git a/pkg/schema/python/state.go b/pkg/schema/python/state.go index dc0421ea16..17533067ec 100644 --- a/pkg/schema/python/state.go +++ b/pkg/schema/python/state.go @@ -104,8 +104,6 @@ type ParseState struct { Inputs *schema.OrderedMap[string, schema.InputField] Output schema.SchemaType SupportsStreaming bool - ConcurrencyMax *int - IsAsync bool OutputSet bool } diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 38d7878f5d..0d86535801 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -254,8 +254,6 @@ type PredictorInfo struct { Output SchemaType Mode Mode SupportsStreaming bool - ConcurrencyMax *int - IsAsync bool } // TypeAnnotation is a parsed Python type annotation (intermediate, before resolution). diff --git a/python/cog/__init__.py b/python/cog/__init__.py index ba55c6902f..775e2ba0e9 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -19,7 +19,6 @@ def run( return self.model.generate(prompt, image) """ -import inspect as _inspect import sys as _sys from collections.abc import Callable from typing import TypeVar, overload @@ -65,41 +64,6 @@ def decorate(inner: _F) -> _F: return decorate(fn) -@overload -def concurrent(fn: _F) -> _F: - pass - - -@overload -def concurrent(fn: None = None, *, max: int = 1) -> Callable[[_F], _F]: # noqa: A002 - pass - - -def concurrent( - fn: _F | None = None, - *, - max: int = 1, # noqa: A002 -) -> _F | Callable[[_F], _F]: - """Configure the maximum concurrency for an async predict handler.""" - - if isinstance(max, bool) or not isinstance(max, int): - raise TypeError("concurrent max must be an integer") - if max < 1: - raise ValueError("concurrent max must be at least 1") - - def decorate(inner: _F) -> _F: - if max > 1 and not ( - _inspect.iscoroutinefunction(inner) or _inspect.isasyncgenfunction(inner) - ): - raise TypeError("concurrent max greater than 1 requires an async function") - inner.__cog_concurrent_max__ = max # type: ignore[attr-defined] - return inner - - if fn is None: - return decorate - return decorate(fn) - - # --------------------------------------------------------------------------- # Backwards-compatibility shim: ExperimentalFeatureWarning # @@ -197,7 +161,6 @@ def current_scope() -> object: "current_scope", # Decorators "streaming", - "concurrent", # Deprecated compat shims "ExperimentalFeatureWarning", "emit_metric", diff --git a/python/tests/test_concurrent.py b/python/tests/test_concurrent.py deleted file mode 100644 index 399d47a241..0000000000 --- a/python/tests/test_concurrent.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Tests for the cog.concurrent decorator.""" - -import sys -import types -from collections.abc import AsyncIterator - -import pytest - -if "coglet" not in sys.modules: - coglet = types.ModuleType("coglet") - coglet.CancelationException = type("CancelationException", (BaseException,), {}) - sys.modules["coglet"] = coglet - -from cog import concurrent - - -def test_bare_concurrent_sets_default_max_on_sync_function() -> None: - @concurrent - def predict() -> str: - return "ok" - - assert predict.__cog_concurrent_max__ == 1 - - -def test_concurrent_call_form_sets_default_max_on_sync_function() -> None: - @concurrent() - def predict() -> str: - return "ok" - - assert predict.__cog_concurrent_max__ == 1 - - -def test_concurrent_max_one_allows_sync_function() -> None: - @concurrent(max=1) - def predict() -> str: - return "ok" - - assert predict.__cog_concurrent_max__ == 1 - - -def test_concurrent_max_allows_async_function() -> None: - @concurrent(max=3) - async def predict() -> str: - return "ok" - - assert predict.__cog_concurrent_max__ == 3 - - -def test_concurrent_max_allows_async_generator() -> None: - @concurrent(max=3) - async def predict() -> AsyncIterator[str]: - yield "ok" - - assert predict.__cog_concurrent_max__ == 3 - - -def test_concurrent_max_rejects_sync_function() -> None: - with pytest.raises(TypeError, match="requires an async function"): - - @concurrent(max=2) - def predict() -> str: - return "ok" - - -@pytest.mark.parametrize("value", [0, -1]) -def test_concurrent_rejects_max_less_than_one(value: int) -> None: - with pytest.raises(ValueError, match="at least 1"): - concurrent(max=value) - - -@pytest.mark.parametrize("value", [True, 1.5, "2"]) -def test_concurrent_rejects_non_integer_max(value: object) -> None: - with pytest.raises(TypeError, match="must be an integer"): - concurrent(max=value) # type: ignore[arg-type]