diff --git a/crates/coglet-python/src/lib.rs b/crates/coglet-python/src/lib.rs index 35130783ba..980b6c9cfd 100644 --- a/crates/coglet-python/src/lib.rs +++ b/crates/coglet-python/src/lib.rs @@ -188,11 +188,25 @@ fn detect_version(py: Python<'_>, build: &BuildInfo) -> VersionInfo { fn read_max_concurrency() -> usize { match std::env::var("COG_MAX_CONCURRENCY") { - Ok(val) => val.parse::().unwrap_or(1), + Ok(val) => match parse_max_concurrency(&val) { + Some(n) => n, + None => { + warn!(value = %val, "Invalid COG_MAX_CONCURRENCY value, defaulting to 1"); + 1 + } + }, Err(_) => 1, } } +fn parse_max_concurrency(val: &str) -> Option { + match val.parse::() { + Ok(0) => None, + Ok(n) => Some(n), + Err(_) => None, + } +} + fn read_setup_timeout() -> Option { match std::env::var("COG_SETUP_TIMEOUT") { Ok(val) => match val.parse::() { @@ -542,10 +556,10 @@ async fn run_worker_with_init() -> Result<(), String> { info!(predictor_ref = %predictor_ref, num_slots, is_train, "Init received, connecting to transport"); let handler = Arc::new(if is_train { - worker_bridge::PythonPredictHandler::new_train(predictor_ref) + worker_bridge::PythonPredictHandler::new_train(predictor_ref, num_slots) .map_err(|e| format!("Failed to create handler: {}", e))? } else { - worker_bridge::PythonPredictHandler::new(predictor_ref) + worker_bridge::PythonPredictHandler::new(predictor_ref, num_slots) .map_err(|e| format!("Failed to create handler: {}", e))? }); @@ -568,6 +582,26 @@ async fn run_worker_with_init() -> Result<(), String> { .map_err(|e| format!("Worker error: {}", e)) } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_max_concurrency_reads_valid_value() { + assert_eq!(parse_max_concurrency("4"), Some(4)); + } + + #[test] + fn parse_max_concurrency_rejects_invalid_value() { + assert_eq!(parse_max_concurrency("wat"), None); + } + + #[test] + fn parse_max_concurrency_rejects_zero() { + assert_eq!(parse_max_concurrency("0"), None); + } +} + // ============================================================================= // Module init // ============================================================================= diff --git a/crates/coglet-python/src/worker_bridge.rs b/crates/coglet-python/src/worker_bridge.rs index 4378e729f8..3e697b364b 100644 --- a/crates/coglet-python/src/worker_bridge.rs +++ b/crates/coglet-python/src/worker_bridge.rs @@ -101,11 +101,12 @@ pub struct PythonPredictHandler { async_loop: Mutex>>, /// Handle to the asyncio loop thread for joining on shutdown. async_thread: Mutex>>, + max_concurrency: usize, } impl PythonPredictHandler { /// Create a handler in prediction mode. - pub fn new(predictor_ref: String) -> Result { + pub fn new(predictor_ref: String, max_concurrency: usize) -> Result { let (loop_obj, thread) = Self::init_async_loop()?; Ok(Self { predictor_ref, @@ -114,6 +115,7 @@ impl PythonPredictHandler { mode: HandlerMode::Predict, async_loop: Mutex::new(Some(loop_obj)), async_thread: Mutex::new(Some(thread)), + max_concurrency, }) } @@ -122,7 +124,7 @@ impl PythonPredictHandler { /// NOTE: For bug-for-bug compatibility with cog mainline, use new() instead. /// Cog mainline's training routes incorrectly use a predict-mode worker. #[allow(dead_code)] - pub fn new_train(predictor_ref: String) -> Result { + pub fn new_train(predictor_ref: String, max_concurrency: usize) -> Result { let (loop_obj, thread) = Self::init_async_loop()?; Ok(Self { predictor_ref, @@ -131,6 +133,7 @@ impl PythonPredictHandler { mode: HandlerMode::Train, async_loop: Mutex::new(Some(loop_obj)), async_thread: Mutex::new(Some(thread)), + max_concurrency, }) } @@ -281,6 +284,11 @@ impl PredictHandler for PythonPredictHandler { let pred = PythonPredictor::load(py, &self.predictor_ref) .map_err(|e| SetupError::load(e.to_string()))?; + if self.max_concurrency > 1 && !pred.is_async() { + return Err(SetupError::setup( + "COG_MAX_CONCURRENCY > 1 requires an async run() or predict() method", + )); + } // Detect SDK implementation let sdk_impl = match py.import("cog") { diff --git a/docs/deploy.md b/docs/deploy.md index 63a5fa83c7..a1e833786f 100644 --- a/docs/deploy.md +++ b/docs/deploy.md @@ -76,11 +76,11 @@ curl http://localhost:5001/predictions -X POST \ ```json { - "status": "succeeded", - "output": "data:image/png;base64,...", - "metrics": { - "predict_time": 4.52 - } + "status": "succeeded", + "output": "data:image/png;base64,...", + "metrics": { + "predict_time": 4.52 + } } ``` @@ -144,8 +144,8 @@ the response contains a base64-encoded data URL by default: ```json { - "status": "succeeded", - "output": "data:image/png;base64,iVBORw0KGgo..." + "status": "succeeded", + "output": "data:image/png;base64,iVBORw0KGgo..." } ``` @@ -172,8 +172,8 @@ contains the uploaded URL instead of a data URL: ```json { - "status": "succeeded", - "output": "https://example.com/upload/image.png" + "status": "succeeded", + "output": "https://example.com/upload/image.png" } ``` @@ -215,21 +215,25 @@ to stop.) ## Concurrency -By default, the server processes one run at a time. To enable concurrent runs, set the `concurrency.max` option in `cog.yaml`: +By default, the server processes one run at a time. To enable concurrent runs, make your `run()` method async and decorate it with `@cog.concurrent(max=N)`: -```yaml -concurrency: - max: 4 +```py +import cog + +class Runner(cog.BaseRunner): + @cog.concurrent(max=4) + async def run(self) -> str: + return "hello world" ``` -See the [`cog.yaml` reference](yaml.md#concurrency) for more details. +The deprecated [`concurrency.max`](yaml.md#concurrency) field in `cog.yaml` is still supported and takes precedence over the decorator by baking `COG_MAX_CONCURRENCY` into the image. ## Environment variables You can configure runtime behavior with environment variables: - `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout). -- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). +- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). Overrides both `@cog.concurrent` and deprecated `cog.yaml` concurrency. See the [environment variables reference](environment.md) for the full list. diff --git a/docs/environment.md b/docs/environment.md index bbf02e2a90..ee9b3afc6b 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -173,9 +173,18 @@ These variables affect a running model server. Set them in `cog.yaml` under `env ### `COG_MAX_CONCURRENCY` -Controls how many predictions the model server can run concurrently. +Controls how many predictions the model server can run concurrently. This overrides both `@cog.concurrent(max=N)` and the deprecated `concurrency.max` field in `cog.yaml`. -By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. +By default, Cog runs one prediction at a time unless the model uses `@cog.concurrent(max=N)`. Invalid values are ignored and the default of `1` is used. + +Values greater than `1` require an async `run()` method. This applies even when `COG_MAX_CONCURRENCY` is set as a runtime operator override. + +Concurrency is resolved in this order, from highest to lowest precedence: + +1. `COG_MAX_CONCURRENCY` set at runtime +2. Deprecated `concurrency.max` in `cog.yaml`, which is baked into the image as `COG_MAX_CONCURRENCY` +3. `@cog.concurrent(max=N)` on the async `run()` method +4. Default: `1` ```console $ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model diff --git a/docs/examples.md b/docs/examples.md index 973c09093b..31a930f35a 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -12,7 +12,7 @@ If you want a working project to copy from, start with one of these: - [`hello-context`](https://github.com/replicate/cog/tree/main/examples/hello-context): shows how to read prediction context - [`hello-concurrency`](https://github.com/replicate/cog/tree/main/examples/hello-concurrency): - demonstrates the `concurrency` setting in `cog.yaml` + demonstrates the `@cog.concurrent` decorator - [`hello-train`](https://github.com/replicate/cog/tree/main/examples/hello-train): defines a training interface - [`hello-replicate`](https://github.com/replicate/cog/tree/main/examples/hello-replicate): diff --git a/docs/http.md b/docs/http.md index 592c4cee59..de2a36dac4 100644 --- a/docs/http.md +++ b/docs/http.md @@ -255,7 +255,7 @@ Prefer: respond-async Endpoints for creating and canceling a prediction idempotently accept a `prediction_id` parameter in their path. By default, the server runs one prediction at a time, -but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting. +but this can be increased with [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency). When all prediction slots are in use, the server returns `409 Conflict`. The client should ensure prediction slots are available before creating a new prediction with a different ID. diff --git a/docs/llms.txt b/docs/llms.txt index 11cb30b033..1c61c7c8cf 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -633,11 +633,11 @@ curl http://localhost:5001/predictions -X POST \ ```json { - "status": "succeeded", - "output": "data:image/png;base64,...", - "metrics": { - "predict_time": 4.52 - } + "status": "succeeded", + "output": "data:image/png;base64,...", + "metrics": { + "predict_time": 4.52 + } } ``` @@ -701,8 +701,8 @@ the response contains a base64-encoded data URL by default: ```json { - "status": "succeeded", - "output": "data:image/png;base64,iVBORw0KGgo..." + "status": "succeeded", + "output": "data:image/png;base64,iVBORw0KGgo..." } ``` @@ -729,8 +729,8 @@ contains the uploaded URL instead of a data URL: ```json { - "status": "succeeded", - "output": "https://example.com/upload/image.png" + "status": "succeeded", + "output": "https://example.com/upload/image.png" } ``` @@ -772,21 +772,25 @@ to stop.) ## Concurrency -By default, the server processes one run at a time. To enable concurrent runs, set the `concurrency.max` option in `cog.yaml`: +By default, the server processes one run at a time. To enable concurrent runs, make your `run()` method async and decorate it with `@cog.concurrent(max=N)`: -```yaml -concurrency: - max: 4 +```py +import cog + +class Runner(cog.BaseRunner): + @cog.concurrent(max=4) + async def run(self) -> str: + return "hello world" ``` -See the [`cog.yaml` reference](yaml.md#concurrency) for more details. +The deprecated [`concurrency.max`](yaml.md#concurrency) field in `cog.yaml` is still supported and takes precedence over the decorator by baking `COG_MAX_CONCURRENCY` into the image. ## Environment variables You can configure runtime behavior with environment variables: - `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout). -- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). +- `COG_MAX_CONCURRENCY`: Number of concurrent prediction slots (default: 1). Overrides both `@cog.concurrent` and deprecated `cog.yaml` concurrency. See the [environment variables reference](environment.md) for the full list. @@ -974,9 +978,18 @@ These variables affect a running model server. Set them in `cog.yaml` under `env ### `COG_MAX_CONCURRENCY` -Controls how many predictions the model server can run concurrently. +Controls how many predictions the model server can run concurrently. This overrides both `@cog.concurrent(max=N)` and the deprecated `concurrency.max` field in `cog.yaml`. + +By default, Cog runs one prediction at a time unless the model uses `@cog.concurrent(max=N)`. Invalid values are ignored and the default of `1` is used. + +Values greater than `1` require an async `run()` method. This applies even when `COG_MAX_CONCURRENCY` is set as a runtime operator override. + +Concurrency is resolved in this order, from highest to lowest precedence: -By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. +1. `COG_MAX_CONCURRENCY` set at runtime +2. Deprecated `concurrency.max` in `cog.yaml`, which is baked into the image as `COG_MAX_CONCURRENCY` +3. `@cog.concurrent(max=N)` on the async `run()` method +4. Default: `1` ```console $ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model @@ -1096,7 +1109,7 @@ If you want a working project to copy from, start with one of these: - [`hello-context`](https://github.com/replicate/cog/tree/main/examples/hello-context): shows how to read prediction context - [`hello-concurrency`](https://github.com/replicate/cog/tree/main/examples/hello-concurrency): - demonstrates the `concurrency` setting in `cog.yaml` + demonstrates the `@cog.concurrent` decorator - [`hello-train`](https://github.com/replicate/cog/tree/main/examples/hello-train): defines a training interface - [`hello-replicate`](https://github.com/replicate/cog/tree/main/examples/hello-replicate): @@ -1759,7 +1772,7 @@ Prefer: respond-async Endpoints for creating and canceling a prediction idempotently accept a `prediction_id` parameter in their path. By default, the server runs one prediction at a time, -but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting. +but this can be increased with [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency). When all prediction slots are in use, the server returns `409 Conflict`. The client should ensure prediction slots are available before creating a new prediction with a different ID. @@ -2349,7 +2362,20 @@ class Runner(BaseRunner): return "hello world"; ``` -Models that have an async `run()` function can run concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. +Models that have an async `run()` function can run concurrently. Use `@cog.concurrent(max=N)` to configure the default maximum concurrency for the function: + +```py +import cog + +class Runner(cog.BaseRunner): + @cog.concurrent(max=4) + async def run(self) -> str: + return "hello world" +``` + +Attempting to exceed this limit will return a 409 Conflict response. `max` values greater than `1` require an async `run()` method. The `COG_MAX_CONCURRENCY` environment variable can override the decorator at runtime. + +The `max` value must be an integer literal so Cog can configure the model at build time. Use `COG_MAX_CONCURRENCY` when you need to set concurrency dynamically at runtime. ## `Input(**kwargs)` @@ -3574,8 +3600,9 @@ build: ## `concurrency` > Added in cog 0.14.0. +> Deprecated: use [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency) on your async `run()` method instead. -This stanza describes the concurrency capabilities of the model. It has one option: +This stanza describes the concurrency capabilities of the model. It is still supported for backwards compatibility, but new models should use `@cog.concurrent(max=N)`. It has one option: ### `max` diff --git a/docs/python.md b/docs/python.md index fe9307aa3b..218af5f33f 100644 --- a/docs/python.md +++ b/docs/python.md @@ -128,7 +128,20 @@ class Runner(BaseRunner): return "hello world"; ``` -Models that have an async `run()` function can run concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. +Models that have an async `run()` function can run concurrently. Use `@cog.concurrent(max=N)` to configure the default maximum concurrency for the function: + +```py +import cog + +class Runner(cog.BaseRunner): + @cog.concurrent(max=4) + async def run(self) -> str: + return "hello world" +``` + +Attempting to exceed this limit will return a 409 Conflict response. `max` values greater than `1` require an async `run()` method. The `COG_MAX_CONCURRENCY` environment variable can override the decorator at runtime. + +The `max` value must be an integer literal so Cog can configure the model at build time. Use `COG_MAX_CONCURRENCY` when you need to set concurrency dynamically at runtime. ## `Input(**kwargs)` diff --git a/docs/yaml.md b/docs/yaml.md index 535b9aafa4..4c0000f8bb 100644 --- a/docs/yaml.md +++ b/docs/yaml.md @@ -189,8 +189,9 @@ build: ## `concurrency` > Added in cog 0.14.0. +> Deprecated: use [`@cog.concurrent(max=N)`](python.md#async-runners-and-concurrency) on your async `run()` method instead. -This stanza describes the concurrency capabilities of the model. It has one option: +This stanza describes the concurrency capabilities of the model. It is still supported for backwards compatibility, but new models should use `@cog.concurrent(max=N)`. It has one option: ### `max` diff --git a/examples/hello-concurrency/README.md b/examples/hello-concurrency/README.md index a3e85c8552..68c8cfed3e 100644 --- a/examples/hello-concurrency/README.md +++ b/examples/hello-concurrency/README.md @@ -1,13 +1,16 @@ # hello-concurrency -This is an example Cog project that demonstrates the newly added concurrency support within -cog >= 0.14.0. +This is an example Cog project that demonstrates concurrency support within Cog. -The key piece is the new `concurrency` field in the cog.yaml. +The key piece is the `@concurrent(max=4)` decorator on the async `run()` method. -```yaml -concurrency: - max: 4 +```py +from cog import BaseRunner, concurrent + +class Runner(BaseRunner): + @concurrent(max=4) + async def run(self) -> str: + return "hello" ``` This combined with the async setup and run methods in `run.py` allows Cog to run up to diff --git a/examples/hello-concurrency/cog.yaml b/examples/hello-concurrency/cog.yaml index 9be8201142..83e0e89397 100644 --- a/examples/hello-concurrency/cog.yaml +++ b/examples/hello-concurrency/cog.yaml @@ -5,5 +5,3 @@ build: python_version: "3.12" python_requirements: requirements.txt run: "run.py:Runner" -concurrency: - max: 4 diff --git a/examples/hello-concurrency/run.py b/examples/hello-concurrency/run.py index 090f5d562e..f5616bbabc 100644 --- a/examples/hello-concurrency/run.py +++ b/examples/hello-concurrency/run.py @@ -19,6 +19,7 @@ BaseRunner, Input, __version__, + concurrent, current_scope, ) @@ -72,6 +73,7 @@ async def setup(self) -> None: logging.info(f"completed setup in {duration} seconds") span.set_attribute("model.setup_time_seconds", duration) + @concurrent(max=4) async def run( # pyright: ignore self, total: int = Input(default=5), diff --git a/integration-tests/concurrent/concurrent_test.go b/integration-tests/concurrent/concurrent_test.go index bb1cc2ece6..b5ce645ce7 100644 --- a/integration-tests/concurrent/concurrent_test.go +++ b/integration-tests/concurrent/concurrent_test.go @@ -259,15 +259,14 @@ func allocatePort() (int, error) { const cogYAML = `build: python_version: "3.11" predict: "predict.py:Predictor" -concurrency: - max: 5 ` const predictPy = `import asyncio -from cog import BasePredictor +from cog import BasePredictor, concurrent class Predictor(BasePredictor): + @concurrent(max=5) async def predict(self, s: str, sleep: float) -> str: await asyncio.sleep(sleep) return f"wake up {s}" @@ -399,14 +398,13 @@ func TestConcurrentAsyncMetrics(t *testing.T) { metricsCogYAML := `build: python_version: "3.12" predict: "predict.py:Predictor" -concurrency: - max: 5 ` metricsPredictPy := `import asyncio -from cog import BasePredictor, current_scope +from cog import BasePredictor, concurrent, current_scope class Predictor(BasePredictor): + @concurrent(max=5) async def predict(self, idx: int = 0, sleep: float = 0.5) -> str: scope = current_scope() scope.record_metric("prediction_index", idx) diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 71aaaa2437..cd16752af0 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -688,6 +688,14 @@ func checkDeprecatedFields(cfg *configFile, result *ValidationResult) { }) } + if cfg.Concurrency != nil && cfg.Concurrency.Max != nil { + result.AddWarning(DeprecationWarning{ + Field: "concurrency.max", + Replacement: "@cog.concurrent(max=...)", + Message: "configure prediction concurrency with @cog.concurrent on your async run() method instead", + }) + } + if cfg.Build == nil { return } diff --git a/pkg/config/validate_test.go b/pkg/config/validate_test.go index 25ea320695..316698ecf0 100644 --- a/pkg/config/validate_test.go +++ b/pkg/config/validate_test.go @@ -159,6 +159,25 @@ func TestValidateConfigFileConcurrencyType(t *testing.T) { result := ValidateConfigFile(cfg) require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors) + require.Contains(t, result.Warnings, DeprecationWarning{ + Field: "concurrency.max", + Replacement: "@cog.concurrent(max=...)", + Message: "configure prediction concurrency with @cog.concurrent on your async run() method instead", + }) +} + +func TestValidateConfigFileConcurrencyDeprecationWithoutBuild(t *testing.T) { + cfg := &configFile{ + Run: new("run.py:Runner"), + Concurrency: &concurrencyFile{ + Max: new(2), + }, + } + + result := ValidateConfigFile(cfg) + require.False(t, result.HasErrors()) + require.Len(t, result.Warnings, 1) + require.Equal(t, "concurrency.max", result.Warnings[0].Field) } func TestValidateConfigFileDeprecatedPythonPackages(t *testing.T) { diff --git a/pkg/image/build.go b/pkg/image/build.go index eb8a98477e..57362f4423 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -13,6 +13,7 @@ import ( "os/exec" "path/filepath" "slices" + "strconv" "strings" "time" @@ -124,6 +125,10 @@ func Build( // Generate schema before the Docker build so schema errors fail fast and the // schema file is available in the build context. var schemaJSON []byte + predictorInfo, err := generatePredictorMetadata(cfg, dir) + if err != nil { + return "", fmt.Errorf("image build failed: %w", err) + } switch { case needsSchema: if err := validateStaticSchemaSDKVersion(cfg); err != nil { @@ -159,6 +164,11 @@ func Build( } } + dockerfileCfg, err := configWithDecoratorConcurrency(cfg, predictorInfo) + if err != nil { + return "", err + } + // --- Runtime weights manifest (/.cog/weights.json) --- // When managed weights are configured and a lockfile exists, project the // lockfile to the minimal runtime manifest (spec ยง3.3) and write it into @@ -206,8 +216,11 @@ func Build( if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil { return "", fmt.Errorf("Failed to build Docker image: %w", err) } + if err := addConcurrencyToCustomDockerfileImage(ctx, dockerCommand, tmpImageId, dockerfileCfg.Concurrency, progressOutput, bp.buildDir); err != nil { + return "", err + } } else { - generator, err := dockerfile.NewStandardGenerator(cfg, dir, bp.buildDir, configFilename, dockerCommand, client, true) + generator, err := dockerfile.NewStandardGenerator(dockerfileCfg, dir, bp.buildDir, configFilename, dockerCommand, client, true) if err != nil { return "", fmt.Errorf("Error creating Dockerfile generator: %w", err) } @@ -446,6 +459,83 @@ func generateStaticSchema(cfg *config.Config, dir string) ([]byte, error) { return nil, fmt.Errorf("no predict or train reference found in cog.yaml") } return schema.GenerateCombined(dir, cfg.Predict, cfg.Train, schema.PathAwareParser(python.ParsePredictorWithSourcePath)) +} + +func generatePredictorMetadata(cfg *config.Config, dir string) (*schema.PredictorInfo, error) { + if cfg.Predict == "" && cfg.Train == "" { + return nil, nil + } + if cfg.Predict != "" { + return schema.GenerateInfo(cfg.Predict, dir, schema.ModePredict, schema.PathAwareParser(python.ParsePredictorMetadataWithSourcePath)) + } + return schema.GenerateInfo(cfg.Train, dir, schema.ModeTrain, schema.PathAwareParser(python.ParsePredictorMetadataWithSourcePath)) +} + +func configWithDecoratorConcurrency(cfg *config.Config, info *schema.PredictorInfo) (*config.Config, error) { + if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && info != nil && !info.IsAsync { + return nil, fmt.Errorf("concurrency.max > 1 requires an async run() or predict() method") + } + if cfg.Concurrency != nil || info == nil || info.ConcurrencyMax == nil { + return cfg, validateEffectiveConcurrency(cfg, info) + } + if *info.ConcurrencyMax > 1 && !info.IsAsync { + return nil, fmt.Errorf("@cog.concurrent(max > 1) requires an async run() or predict() method") + } + cfgCopy := *cfg + cfgCopy.Concurrency = &config.Concurrency{Max: *info.ConcurrencyMax} + return &cfgCopy, validateEffectiveConcurrency(&cfgCopy, info) +} + +func addConcurrencyToCustomDockerfileImage(ctx context.Context, dockerCommand command.Command, imageName string, concurrency *config.Concurrency, progressOutput string, buildCacheDir string) error { + if concurrency == nil || concurrency.Max <= 0 { + return nil + } + buildOpts := command.ImageBuildOptions{ + DockerfileContents: concurrencyDockerfile(imageName, concurrency.Max), + ImageName: imageName, + ProgressOutput: progressOutput, + BuildCacheDir: buildCacheDir, + } + if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil { + return fmt.Errorf("Failed to add concurrency configuration to Docker image: %w", err) + } + return nil +} + +func validateEffectiveConcurrency(cfg *config.Config, info *schema.PredictorInfo) error { + if cfg.Concurrency == nil || cfg.Concurrency.Max <= 1 { + return nil + } + if info != nil && !info.IsAsync { + return fmt.Errorf("concurrency.max > 1 requires an async run() or predict() method") + } + if cfg.Build == nil || cfg.Build.PythonVersion == "" { + return nil + } + major, minor, err := splitPythonVersionForBuild(cfg.Build.PythonVersion) + if err != nil { + return nil + } + if major == config.MinimumMajorPythonVersion && minor < config.MinimumMinorPythonVersionForConcurrency { + return fmt.Errorf("concurrency requires Python %d.%d or higher", config.MinimumMajorPythonVersion, config.MinimumMinorPythonVersionForConcurrency) + } + return nil +} + +func splitPythonVersionForBuild(version string) (major int, minor int, err error) { + parts := strings.Split(version, ".") + if len(parts) < 2 { + return 0, 0, fmt.Errorf("invalid Python version %q", version) + } + major, err = strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, err + } + minor, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, err + } + return major, minor, nil } @@ -546,6 +636,10 @@ func bundleDockerfile(baseImage string, files []string) string { return b.String() } +func concurrencyDockerfile(baseImage string, maxConcurrency int) string { + return fmt.Sprintf("FROM %s\nENV COG_MAX_CONCURRENCY=%d\n", baseImage, maxConcurrency) +} + func isGitWorkTree(ctx context.Context, dir string) bool { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() diff --git a/pkg/image/build_test.go b/pkg/image/build_test.go index bf2af1e229..1089e22e51 100644 --- a/pkg/image/build_test.go +++ b/pkg/image/build_test.go @@ -14,7 +14,10 @@ import ( "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/docker/dockertest" "github.com/replicate/cog/pkg/dotcog" + "github.com/replicate/cog/pkg/schema" "github.com/replicate/cog/pkg/weights/lockfile" ) @@ -62,6 +65,103 @@ func TestIsGitWorkTree(t *testing.T) { r.True(isGitWorkTree(ctx, setupGitWorkTree(t))) } +func TestConfigWithDecoratorConcurrencyAppliesWhenYAMLAbsent(t *testing.T) { + concurrencyMax := 4 + cfg := &config.Config{} + info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} + + got, err := configWithDecoratorConcurrency(cfg, info) + + require.NoError(t, err) + require.NotSame(t, cfg, got) + require.Nil(t, cfg.Concurrency) + require.NotNil(t, got.Concurrency) + require.Equal(t, 4, got.Concurrency.Max) +} + +func TestConfigWithDecoratorConcurrencyPreservesYAMLPrecedence(t *testing.T) { + concurrencyMax := 4 + cfg := &config.Config{Concurrency: &config.Concurrency{Max: 2}} + info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} + + got, err := configWithDecoratorConcurrency(cfg, info) + + require.NoError(t, err) + require.Same(t, cfg, got) + require.Equal(t, 2, got.Concurrency.Max) +} + +func TestConfigWithDecoratorConcurrencyRejectsSyncYAMLConcurrency(t *testing.T) { + cfg := &config.Config{Concurrency: &config.Concurrency{Max: 2}} + info := &schema.PredictorInfo{IsAsync: false} + + _, err := configWithDecoratorConcurrency(cfg, info) + + require.Error(t, err) + require.Contains(t, err.Error(), "requires an async") +} + +func TestConfigWithDecoratorConcurrencyRejectsOldPythonVersion(t *testing.T) { + concurrencyMax := 2 + cfg := &config.Config{Build: &config.Build{PythonVersion: "3.10"}} + info := &schema.PredictorInfo{ConcurrencyMax: &concurrencyMax, IsAsync: true} + + _, err := configWithDecoratorConcurrency(cfg, info) + + require.Error(t, err) + require.Contains(t, err.Error(), "Python 3.11 or higher") +} + +func TestConcurrencyDockerfileSetsEnv(t *testing.T) { + dockerfile := concurrencyDockerfile("my-image", 4) + + require.Equal(t, "FROM my-image\nENV COG_MAX_CONCURRENCY=4\n", dockerfile) +} + +type recordingCommand struct { + *dockertest.MockCommand + builds []command.ImageBuildOptions +} + +func (c *recordingCommand) ImageBuild(ctx context.Context, options command.ImageBuildOptions) (string, error) { + c.builds = append(c.builds, options) + return "sha256:test", nil +} + +func TestAddConcurrencyToCustomDockerfileImageBuildsWrapperLayer(t *testing.T) { + dockerCommand := &recordingCommand{MockCommand: dockertest.NewMockCommand()} + concurrency := &config.Concurrency{Max: 4} + + err := addConcurrencyToCustomDockerfileImage(t.Context(), dockerCommand, "my-image", concurrency, "plain", "/tmp/build-cache") + + require.NoError(t, err) + require.Len(t, dockerCommand.builds, 1) + require.Equal(t, "FROM my-image\nENV COG_MAX_CONCURRENCY=4\n", dockerCommand.builds[0].DockerfileContents) + require.Equal(t, "my-image", dockerCommand.builds[0].ImageName) + require.Equal(t, "plain", dockerCommand.builds[0].ProgressOutput) + require.Equal(t, "/tmp/build-cache", dockerCommand.builds[0].BuildCacheDir) +} + +func TestGeneratePredictorMetadataDoesNotRequireValidOutputSchema(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "predict.py"), []byte(` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(max=3) + async def predict(self) -> NotARealType: + return "hello" +`), 0o644)) + cfg := &config.Config{Predict: "predict.py:Predictor"} + + info, err := generatePredictorMetadata(cfg, dir) + + require.NoError(t, err) + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 3, *info.ConcurrencyMax) + require.True(t, info.IsAsync) +} + func TestGitHead(t *testing.T) { t.Run("via github env", func(t *testing.T) { t.Setenv("GITHUB_SHA", "fafafaf") diff --git a/pkg/schema/generator.go b/pkg/schema/generator.go index efdddebfd0..790873beb2 100644 --- a/pkg/schema/generator.go +++ b/pkg/schema/generator.go @@ -40,6 +40,15 @@ func Generate(predictRef string, sourceDir string, mode Mode, parse any) ([]byte return data, nil } + info, err := GenerateInfo(predictRef, sourceDir, mode, parse) + if err != nil { + return nil, err + } + return GenerateOpenAPISchema(info) +} + +// GenerateInfo parses predictor source and returns the extracted predictor info. +func GenerateInfo(predictRef string, sourceDir string, mode Mode, parse any) (*PredictorInfo, error) { filePath, className, err := parsePredictRef(predictRef) if err != nil { return nil, err @@ -59,7 +68,7 @@ func Generate(predictRef string, sourceDir string, mode Mode, parse any) ([]byte return nil, fmt.Errorf("failed to read predictor source %s: %w", fullPath, err) } - return generateFromSource(source, className, mode, parse, sourceDir, filePath) + return parsePredictorInfo(parse, source, className, mode, sourceDir, filePath) } func cleanSourceFilePath(filePath string) (string, error) { diff --git a/pkg/schema/python/annotations.go b/pkg/schema/python/annotations.go index 5310da4c21..af791c899a 100644 --- a/pkg/schema/python/annotations.go +++ b/pkg/schema/python/annotations.go @@ -2,6 +2,7 @@ package python import ( "fmt" + "math" "strings" sitter "github.com/smacker/go-tree-sitter" @@ -243,6 +244,73 @@ func functionSupportsStreaming(node *sitter.Node, source []byte, imports *schema return false } +func functionIsAsync(node *sitter.Node, source []byte) bool { + return strings.HasPrefix(strings.TrimSpace(Content(node, source)), "async def ") +} + +func functionConcurrencyMax(node *sitter.Node, source []byte, imports *schema.ImportContext) (*int, error) { + decorated := decoratedFunctionNode(node) + if decorated == nil { + return nil, nil + } + + for _, child := range NamedChildren(decorated) { + if child.Type() != "decorator" { + continue + } + concurrencyMax, ok, err := decoratorCogConcurrentMax(child, source, imports) + if err != nil { + return nil, err + } + if ok { + return concurrencyMax, nil + } + } + return nil, nil +} + +func decoratorCogConcurrentMax(node *sitter.Node, source []byte, imports *schema.ImportContext) (*int, bool, error) { + expr := decoratorExpression(node) + if expr == nil || !expressionIsCogConcurrent(expr, source, imports) { + return nil, false, nil + } + + call := decoratorCall(node) + if call == nil { + concurrencyMax := 1 + return &concurrencyMax, true, nil + } + + argList := callArgumentList(call) + if argList == nil || len(NamedChildren(argList)) == 0 { + concurrencyMax := 1 + return &concurrencyMax, true, nil + } + + args := NamedChildren(argList) + if len(args) != 1 || args[0].Type() != "keyword_argument" { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent only supports literal integer max=... arguments", nil) + } + nameNode := args[0].ChildByFieldName("name") + valueNode := args[0].ChildByFieldName("value") + if nameNode == nil || valueNode == nil || Content(nameNode, source) != "max" { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent only supports literal integer max=... arguments", nil) + } + val, ok := parseDefaultValue(valueNode, source) + if !ok || val.Kind != schema.DefaultInt { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max must be an integer literal", nil) + } + if val.Int < 1 { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max must be at least 1", nil) + } + maxSupportedInt := int64(math.MaxInt) + if val.Int > maxSupportedInt { + return nil, true, schema.WrapError(schema.ErrUnsupportedType, "@concurrent max is too large", nil) + } + concurrencyMax := int(val.Int) + return &concurrencyMax, true, nil +} + func decoratedFunctionNode(node *sitter.Node) *sitter.Node { if node.Type() == "decorated_definition" { return node @@ -277,6 +345,23 @@ func decoratorExpression(node *sitter.Node) *sitter.Node { return child } +func decoratorCall(node *sitter.Node) *sitter.Node { + children := NamedChildren(node) + if len(children) == 0 || children[0].Type() != "call" { + return nil + } + return children[0] +} + +func callArgumentList(node *sitter.Node) *sitter.Node { + for _, child := range NamedChildren(node) { + if child.Type() == "argument_list" { + return child + } + } + return nil +} + func expressionIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { switch node.Type() { case "attribute": @@ -288,6 +373,17 @@ func expressionIsCogStreaming(node *sitter.Node, source []byte, imports *schema. } } +func expressionIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + switch node.Type() { + case "attribute": + return attributeIsCogConcurrent(node, source, imports) + case "identifier": + return identifierIsCogConcurrent(node, source, imports) + default: + return false + } +} + func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { name, attr, ok := strings.Cut(Content(node, source), ".") if !ok || attr != "streaming" { @@ -297,11 +393,25 @@ func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.I return ok && entry.Module == "cog" && entry.Original == "cog" } +func attributeIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + name, attr, ok := strings.Cut(Content(node, source), ".") + if !ok || attr != "concurrent" { + return false + } + entry, ok := imports.Names.Get(name) + return ok && entry.Module == "cog" && entry.Original == "cog" +} + func identifierIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { entry, ok := imports.Names.Get(Content(node, source)) return ok && entry.Module == "cog" && entry.Original == "streaming" } +func identifierIsCogConcurrent(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + entry, ok := imports.Names.Get(Content(node, source)) + return ok && entry.Module == "cog" && entry.Original == "concurrent" +} + func supportsStreamingOutput(output schema.SchemaType) bool { return output.Kind == schema.SchemaIterator || output.Kind == schema.SchemaConcatIterator } diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index 4ee6557d33..093a1e6e85 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -35,6 +35,47 @@ func ParsePredictorWithSourcePath(source []byte, targetRef string, mode schema.M return ParseWithOptions(opts) } +func ParsePredictorMetadataWithSourcePath(source []byte, targetRef string, mode schema.Mode, sourceDir string, sourcePath string) (*schema.PredictorInfo, error) { + opts := defaultParserOptions(source, targetRef, mode, sourceDir) + opts.SourcePath = sourcePath + return ParseMetadataWithOptions(opts) +} + +func ParseMetadataWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { + sourcePath, err := normalizeSourcePath(opts.SourceDir, opts.SourcePath) + if err != nil { + return nil, err + } + opts.SourcePath = sourcePath + state := newParseState(opts) + phases := []parserPhase{ + {Name: "parse module", From: phaseInitial, To: phaseModuleParsed, Run: parseModulePhase}, + {Name: "collect imports", From: phaseModuleParsed, To: phaseImportsCollected, Run: collectImportsPhase}, + {Name: "collect module scope", From: phaseImportsCollected, To: phaseModuleScopeCollected, Run: collectModuleScopePhase}, + {Name: "collect local models", From: phaseModuleScopeCollected, To: phaseLocalModelsCollected, Run: collectLocalModelsPhase}, + {Name: "resolve imported models", From: phaseLocalModelsCollected, To: phaseImportedModelsResolved, Run: resolveImportedModelsPhase}, + {Name: "collect input registry", From: phaseImportedModelsResolved, To: phaseInputRegistryCollected, Run: collectInputRegistryPhase}, + {Name: "find target callable", From: phaseInputRegistryCollected, To: phaseTargetFound, Run: findTargetCallablePhase}, + } + if err := runPhases(state, phases); err != nil { + return nil, err + } + targetSource := state.TargetFunc.file.source + isAsync := functionIsAsync(state.TargetFunc.node, targetSource) + concurrencyMax, err := functionConcurrencyMax(state.TargetFunc.node, targetSource, state.TargetFunc.file.imports) + if err != nil { + return nil, err + } + if concurrencyMax != nil && *concurrencyMax > 1 && !isAsync { + return nil, schema.WrapError(schema.ErrUnsupportedType, "@concurrent(max > 1) requires an async run() or predict() method", nil) + } + return &schema.PredictorInfo{ + Mode: opts.Mode, + ConcurrencyMax: concurrencyMax, + IsAsync: isAsync, + }, nil +} + func ParseWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { sourcePath, err := normalizeSourcePath(opts.SourceDir, opts.SourcePath) if err != nil { @@ -62,6 +103,8 @@ func ParseWithOptions(opts ParserOptions) (*schema.PredictorInfo, error) { Output: state.Output, Mode: opts.Mode, SupportsStreaming: state.SupportsStreaming, + ConcurrencyMax: state.ConcurrencyMax, + IsAsync: state.IsAsync, }, nil } @@ -208,6 +251,7 @@ func extractInputsPhase(state *ParseState) error { } state.Target = &TargetCallable{MethodName: actualMethodName, Node: funcNode, IsMethod: isMethod} state.Inputs = inputs + state.IsAsync = functionIsAsync(funcNode, targetSource) return nil } @@ -231,8 +275,16 @@ func resolveOutputPhase(state *ParseState) error { if supportsStreaming && !supportsStreamingOutput(output) { return schema.WrapError(schema.ErrUnsupportedType, "@streaming requires Iterator[...], AsyncIterator[...], ConcatenateIterator[...] or AsyncConcatenateIterator[...] return type", nil) } + concurrencyMax, err := functionConcurrencyMax(state.TargetFunc.node, state.TargetFunc.file.source, state.TargetFunc.file.imports) + if err != nil { + return err + } + if concurrencyMax != nil && *concurrencyMax > 1 && !state.IsAsync { + return schema.WrapError(schema.ErrUnsupportedType, "@concurrent(max > 1) requires an async run() or predict() method", nil) + } state.Output = output state.SupportsStreaming = supportsStreaming + state.ConcurrencyMax = concurrencyMax state.OutputSet = true return nil } diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index de0a71f267..7f094cceda 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -963,6 +963,178 @@ class Predictor(cog.BasePredictor): require.False(t, info.SupportsStreaming) } +func TestConcurrentDecoratorQualifiedMax(t *testing.T) { + source := ` +import cog + +class Predictor(cog.BasePredictor): + @cog.concurrent(max=4) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 4, *info.ConcurrencyMax) + require.True(t, info.IsAsync) +} + +func TestConcurrentDecoratorImportedMax(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(max=3) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 3, *info.ConcurrencyMax) +} + +func TestConcurrentDecoratorBareDefaultsToOne(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent + def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 1, *info.ConcurrencyMax) + require.False(t, info.IsAsync) +} + +func TestConcurrentDecoratorCallDefaultsToOne(t *testing.T) { + source := ` +import cog + +class Predictor(cog.BasePredictor): + @cog.concurrent() + def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 1, *info.ConcurrencyMax) +} + +func TestConcurrentDecoratorImportedAliasMax(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent as concurrent_predictions + +class Predictor(BasePredictor): + @concurrent_predictions(max=2) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 2, *info.ConcurrencyMax) +} + +func TestConcurrentDecoratorQualifiedAliasMax(t *testing.T) { + source := ` +import cog as c + +class Predictor(c.BasePredictor): + @c.concurrent(max=4) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.NotNil(t, info.ConcurrencyMax) + require.Equal(t, 4, *info.ConcurrencyMax) + require.True(t, info.IsAsync) +} + +func TestConcurrentDecoratorIgnoredWhenNotFromCog(t *testing.T) { + source := ` +from other import concurrent +from cog import BasePredictor + +class Predictor(BasePredictor): + @concurrent(max=4) + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.Nil(t, info.ConcurrencyMax) +} + +func TestConcurrentDecoratorRequiresAsyncForMaxGreaterThanOne(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(max=2) + def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "requires an async") +} + +func TestConcurrentDecoratorRejectsNonLiteralMax(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +MAX_CONCURRENCY = 2 + +class Predictor(BasePredictor): + @concurrent(max=MAX_CONCURRENCY) + async def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "integer literal") +} + +func TestConcurrentDecoratorRejectsPositionalArgument(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(2) + async def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "max=...") +} + +func TestConcurrentDecoratorRejectsStringMax(t *testing.T) { + source := ` +from cog import BasePredictor, concurrent + +class Predictor(BasePredictor): + @concurrent(max="2") + async def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "integer literal") +} + +func TestConcurrentDecoratorUndecoratedHasNoMax(t *testing.T) { + source := ` +from cog import BasePredictor + +class Predictor(BasePredictor): + async def predict(self) -> str: + return "hello" +` + info := parse(t, source, "Predictor") + require.Nil(t, info.ConcurrencyMax) + require.True(t, info.IsAsync) +} + func TestListOutput(t *testing.T) { source := ` from cog import BasePredictor, Path diff --git a/pkg/schema/python/state.go b/pkg/schema/python/state.go index 17533067ec..dc0421ea16 100644 --- a/pkg/schema/python/state.go +++ b/pkg/schema/python/state.go @@ -104,6 +104,8 @@ type ParseState struct { Inputs *schema.OrderedMap[string, schema.InputField] Output schema.SchemaType SupportsStreaming bool + ConcurrencyMax *int + IsAsync bool OutputSet bool } diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 0d86535801..38d7878f5d 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -254,6 +254,8 @@ type PredictorInfo struct { Output SchemaType Mode Mode SupportsStreaming bool + ConcurrencyMax *int + IsAsync bool } // TypeAnnotation is a parsed Python type annotation (intermediate, before resolution). diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 775e2ba0e9..ba55c6902f 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -19,6 +19,7 @@ def run( return self.model.generate(prompt, image) """ +import inspect as _inspect import sys as _sys from collections.abc import Callable from typing import TypeVar, overload @@ -64,6 +65,41 @@ def decorate(inner: _F) -> _F: return decorate(fn) +@overload +def concurrent(fn: _F) -> _F: + pass + + +@overload +def concurrent(fn: None = None, *, max: int = 1) -> Callable[[_F], _F]: # noqa: A002 + pass + + +def concurrent( + fn: _F | None = None, + *, + max: int = 1, # noqa: A002 +) -> _F | Callable[[_F], _F]: + """Configure the maximum concurrency for an async predict handler.""" + + if isinstance(max, bool) or not isinstance(max, int): + raise TypeError("concurrent max must be an integer") + if max < 1: + raise ValueError("concurrent max must be at least 1") + + def decorate(inner: _F) -> _F: + if max > 1 and not ( + _inspect.iscoroutinefunction(inner) or _inspect.isasyncgenfunction(inner) + ): + raise TypeError("concurrent max greater than 1 requires an async function") + inner.__cog_concurrent_max__ = max # type: ignore[attr-defined] + return inner + + if fn is None: + return decorate + return decorate(fn) + + # --------------------------------------------------------------------------- # Backwards-compatibility shim: ExperimentalFeatureWarning # @@ -161,6 +197,7 @@ def current_scope() -> object: "current_scope", # Decorators "streaming", + "concurrent", # Deprecated compat shims "ExperimentalFeatureWarning", "emit_metric", diff --git a/python/tests/test_concurrent.py b/python/tests/test_concurrent.py new file mode 100644 index 0000000000..399d47a241 --- /dev/null +++ b/python/tests/test_concurrent.py @@ -0,0 +1,74 @@ +"""Tests for the cog.concurrent decorator.""" + +import sys +import types +from collections.abc import AsyncIterator + +import pytest + +if "coglet" not in sys.modules: + coglet = types.ModuleType("coglet") + coglet.CancelationException = type("CancelationException", (BaseException,), {}) + sys.modules["coglet"] = coglet + +from cog import concurrent + + +def test_bare_concurrent_sets_default_max_on_sync_function() -> None: + @concurrent + def predict() -> str: + return "ok" + + assert predict.__cog_concurrent_max__ == 1 + + +def test_concurrent_call_form_sets_default_max_on_sync_function() -> None: + @concurrent() + def predict() -> str: + return "ok" + + assert predict.__cog_concurrent_max__ == 1 + + +def test_concurrent_max_one_allows_sync_function() -> None: + @concurrent(max=1) + def predict() -> str: + return "ok" + + assert predict.__cog_concurrent_max__ == 1 + + +def test_concurrent_max_allows_async_function() -> None: + @concurrent(max=3) + async def predict() -> str: + return "ok" + + assert predict.__cog_concurrent_max__ == 3 + + +def test_concurrent_max_allows_async_generator() -> None: + @concurrent(max=3) + async def predict() -> AsyncIterator[str]: + yield "ok" + + assert predict.__cog_concurrent_max__ == 3 + + +def test_concurrent_max_rejects_sync_function() -> None: + with pytest.raises(TypeError, match="requires an async function"): + + @concurrent(max=2) + def predict() -> str: + return "ok" + + +@pytest.mark.parametrize("value", [0, -1]) +def test_concurrent_rejects_max_less_than_one(value: int) -> None: + with pytest.raises(ValueError, match="at least 1"): + concurrent(max=value) + + +@pytest.mark.parametrize("value", [True, 1.5, "2"]) +def test_concurrent_rejects_non_integer_max(value: object) -> None: + with pytest.raises(TypeError, match="must be an integer"): + concurrent(max=value) # type: ignore[arg-type]