Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions crates/coglet-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,25 @@ fn detect_version(py: Python<'_>, build: &BuildInfo) -> VersionInfo {

fn read_max_concurrency() -> usize {
match std::env::var("COG_MAX_CONCURRENCY") {
Ok(val) => val.parse::<usize>().unwrap_or(1),
Ok(val) => match parse_max_concurrency(&val) {
Some(n) => n,
None => {
warn!(value = %val, "Invalid COG_MAX_CONCURRENCY value, defaulting to 1");
1
}
},
Err(_) => 1,
}
}

fn parse_max_concurrency(val: &str) -> Option<usize> {
match val.parse::<usize>() {
Ok(0) => None,
Ok(n) => Some(n),
Err(_) => None,
}
}

fn read_setup_timeout() -> Option<std::time::Duration> {
match std::env::var("COG_SETUP_TIMEOUT") {
Ok(val) => match val.parse::<u64>() {
Expand Down Expand Up @@ -542,10 +556,10 @@ async fn run_worker_with_init() -> Result<(), String> {
info!(predictor_ref = %predictor_ref, num_slots, is_train, "Init received, connecting to transport");

let handler = Arc::new(if is_train {
worker_bridge::PythonPredictHandler::new_train(predictor_ref)
worker_bridge::PythonPredictHandler::new_train(predictor_ref, num_slots)
.map_err(|e| format!("Failed to create handler: {}", e))?
} else {
worker_bridge::PythonPredictHandler::new(predictor_ref)
worker_bridge::PythonPredictHandler::new(predictor_ref, num_slots)
.map_err(|e| format!("Failed to create handler: {}", e))?
});

Expand All @@ -568,6 +582,26 @@ async fn run_worker_with_init() -> Result<(), String> {
.map_err(|e| format!("Worker error: {}", e))
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn parse_max_concurrency_reads_valid_value() {
assert_eq!(parse_max_concurrency("4"), Some(4));
}

#[test]
fn parse_max_concurrency_rejects_invalid_value() {
assert_eq!(parse_max_concurrency("wat"), None);
}

#[test]
fn parse_max_concurrency_rejects_zero() {
assert_eq!(parse_max_concurrency("0"), None);
}
}

// =============================================================================
// Module init
// =============================================================================
Expand Down
12 changes: 10 additions & 2 deletions crates/coglet-python/src/worker_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,12 @@ pub struct PythonPredictHandler {
async_loop: Mutex<Option<Py<PyAny>>>,
/// Handle to the asyncio loop thread for joining on shutdown.
async_thread: Mutex<Option<JoinHandle<()>>>,
max_concurrency: usize,
}

impl PythonPredictHandler {
/// Create a handler in prediction mode.
pub fn new(predictor_ref: String) -> Result<Self, SetupError> {
pub fn new(predictor_ref: String, max_concurrency: usize) -> Result<Self, SetupError> {
let (loop_obj, thread) = Self::init_async_loop()?;
Ok(Self {
predictor_ref,
Expand All @@ -114,6 +115,7 @@ impl PythonPredictHandler {
mode: HandlerMode::Predict,
async_loop: Mutex::new(Some(loop_obj)),
async_thread: Mutex::new(Some(thread)),
max_concurrency,
})
}

Expand All @@ -122,7 +124,7 @@ impl PythonPredictHandler {
/// NOTE: For bug-for-bug compatibility with cog mainline, use new() instead.
/// Cog mainline's training routes incorrectly use a predict-mode worker.
#[allow(dead_code)]
pub fn new_train(predictor_ref: String) -> Result<Self, SetupError> {
pub fn new_train(predictor_ref: String, max_concurrency: usize) -> Result<Self, SetupError> {
let (loop_obj, thread) = Self::init_async_loop()?;
Ok(Self {
predictor_ref,
Expand All @@ -131,6 +133,7 @@ impl PythonPredictHandler {
mode: HandlerMode::Train,
async_loop: Mutex::new(Some(loop_obj)),
async_thread: Mutex::new(Some(thread)),
max_concurrency,
})
}

Expand Down Expand Up @@ -281,6 +284,11 @@ impl PredictHandler for PythonPredictHandler {

let pred = PythonPredictor::load(py, &self.predictor_ref)
.map_err(|e| SetupError::load(e.to_string()))?;
if self.max_concurrency > 1 && !pred.is_async() {
return Err(SetupError::setup(
"COG_MAX_CONCURRENCY > 1 requires an async run() or predict() method",
));
}

// Detect SDK implementation
let sdk_impl = match py.import("cog") {
Expand Down
34 changes: 19 additions & 15 deletions docs/deploy.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ curl http://localhost:5001/predictions -X POST \

```json
{
"status": "succeeded",
"output": "data:image/png;base64,...",
"metrics": {
"predict_time": 4.52
}
"status": "succeeded",
"output": "data:image/png;base64,...",
"metrics": {
"predict_time": 4.52
}
}
```

Expand Down Expand Up @@ -144,8 +144,8 @@ the response contains a base64-encoded data URL by default:

```json
{
"status": "succeeded",
"output": "data:image/png;base64,iVBORw0KGgo..."
"status": "succeeded",
"output": "data:image/png;base64,iVBORw0KGgo..."
}
```

Expand All @@ -172,8 +172,8 @@ contains the uploaded URL instead of a data URL:

```json
{
"status": "succeeded",
"output": "https://example.com/upload/image.png"
"status": "succeeded",
"output": "https://example.com/upload/image.png"
}
```

Expand Down Expand Up @@ -215,21 +215,25 @@ to stop.)

## Concurrency

By default, the server processes one run at a time. To enable concurrent runs, set the `concurrency.max` option in `cog.yaml`:
By default, the server processes one run at a time. To enable concurrent runs, make your `run()` method async and decorate it with `@cog.concurrent(max=N)`:

```yaml
concurrency:
max: 4
```py
import cog

class Runner(cog.BaseRunner):
@cog.concurrent(max=4)
async def run(self) -> str:
return "hello world"
```

See the [`cog.yaml` reference](yaml.md#concurrency) for more details.
The deprecated [`concurrency.max`](yaml.md#concurrency) field in `cog.yaml` is still supported and takes precedence over the decorator by baking `COG_MAX_CONCURRENCY` into the image.

## Environment variables

You can configure runtime behavior with environment variables:

- `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout).
- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1).
- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). Overrides both `@cog.concurrent` and deprecated `cog.yaml` concurrency.

See the [environment variables reference](environment.md) for the full list.

Expand Down
13 changes: 11 additions & 2 deletions docs/environment.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,18 @@ These variables affect a running model server. Set them in `cog.yaml` under `env

### `COG_MAX_CONCURRENCY`

Controls how many predictions the model server can run concurrently.
Controls how many predictions the model server can run concurrently. This overrides both `@cog.concurrent(max=N)` and the deprecated `concurrency.max` field in `cog.yaml`.

By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used.
By default, Cog runs one prediction at a time unless the model uses `@cog.concurrent(max=N)`. Invalid values are ignored and the default of `1` is used.

Values greater than `1` require an async `run()` method. This applies even when `COG_MAX_CONCURRENCY` is set as a runtime operator override.

Concurrency is resolved in this order, from highest to lowest precedence:

1. `COG_MAX_CONCURRENCY` set at runtime
2. Deprecated `concurrency.max` in `cog.yaml`, which is baked into the image as `COG_MAX_CONCURRENCY`
3. `@cog.concurrent(max=N)` on the async `run()` method
4. Default: `1`

```console
$ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model
Expand Down
2 changes: 1 addition & 1 deletion docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ If you want a working project to copy from, start with one of these:
- [`hello-context`](https://github.com/replicate/cog/tree/main/examples/hello-context):
shows how to read prediction context
- [`hello-concurrency`](https://github.com/replicate/cog/tree/main/examples/hello-concurrency):
demonstrates the `concurrency` setting in `cog.yaml`
demonstrates the `@cog.concurrent` decorator
- [`hello-train`](https://github.com/replicate/cog/tree/main/examples/hello-train):
defines a training interface
- [`hello-replicate`](https://github.com/replicate/cog/tree/main/examples/hello-replicate):
Expand Down
2 changes: 1 addition & 1 deletion docs/http.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ Prefer: respond-async
Endpoints for creating and canceling a prediction idempotently
accept a `prediction_id` parameter in their path.
By default, the server runs one prediction at a time,
but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting.
but this can be increased with [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency).
When all prediction slots are in use, the server returns `409 Conflict`.
The client should ensure prediction slots are available
before creating a new prediction with a different ID.
Expand Down
69 changes: 48 additions & 21 deletions docs/llms.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 14 additions & 1 deletion docs/python.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,20 @@ class Runner(BaseRunner):
return "hello world";
```

Models that have an async `run()` function can run concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response.
Models that have an async `run()` function can run concurrently. Use `@cog.concurrent(max=N)` to configure the default maximum concurrency for the function:

```py
import cog

class Runner(cog.BaseRunner):
@cog.concurrent(max=4)
async def run(self) -> str:
return "hello world"
```

Attempting to exceed this limit will return a 409 Conflict response. `max` values greater than `1` require an async `run()` method. The `COG_MAX_CONCURRENCY` environment variable can override the decorator at runtime.

The `max` value must be an integer literal so Cog can configure the model at build time. Use `COG_MAX_CONCURRENCY` when you need to set concurrency dynamically at runtime.

## `Input(**kwargs)`

Expand Down
Loading