diff --git a/src/Builders/PromptBuilder.php b/src/Builders/PromptBuilder.php index 130fc574..58529328 100644 --- a/src/Builders/PromptBuilder.php +++ b/src/Builders/PromptBuilder.php @@ -21,7 +21,6 @@ use WordPress\AiClient\Providers\Http\DTO\RequestOptions; use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; -use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; @@ -76,6 +75,16 @@ class PromptBuilder */ protected ?string $providerIdOrClassName = null; + /** + * @var list Ordered provider IDs to prioritize during model discovery. + */ + protected array $providerPriority = []; + + /** + * @var bool Whether to retry with the next discovered candidate model after runtime failures. + */ + protected bool $providerFallbackEnabled = true; + /** * @var ModelConfig The model configuration. */ @@ -378,6 +387,45 @@ public function usingProvider(string $providerIdOrClassName): self return $this; } + /** + * Sets provider priority for cross-provider model discovery. + * + * @since n.e.x.t + * + * @param string ...$providerIds Provider IDs in preferred order. + * @return self The prompt builder instance. + * + * @throws InvalidArgumentException If no provider IDs are provided or an identifier is invalid. + */ + public function usingProviderPriority(string ...$providerIds): self + { + if ($providerIds === []) { + throw new InvalidArgumentException('At least one provider priority identifier must be provided.'); + } + + $normalizedProviderIds = []; + foreach ($providerIds as $providerId) { + $normalizedProviderIds[] = $this->normalizeProviderIdentifier($providerId); + } + + $this->providerPriority = array_values(array_unique($normalizedProviderIds)); + return $this; + } + + /** + * Enables or disables fallback to subsequent discovered models after runtime failures. + * + * @since n.e.x.t + * + * @param bool $enabled Whether fallback should be enabled. + * @return self The prompt builder instance. + */ + public function usingProviderFallback(bool $enabled = true): self + { + $this->providerFallbackEnabled = $enabled; + return $this; + } + /** * Sets the system instruction. * @@ -930,22 +978,57 @@ public function generateResult(?CapabilityEnum $capability = null): GenerativeAi } } - $model = $this->getConfiguredModel($capability); + if ($this->model !== null || !$this->providerFallbackEnabled) { + $model = $this->getConfiguredModel($capability); - // Dispatch BeforeGenerateResultEvent - $this->dispatchEvent( - new BeforeGenerateResultEvent($this->messages, $model, $capability) - ); + // Dispatch BeforeGenerateResultEvent + $this->dispatchEvent( + new BeforeGenerateResultEvent($this->messages, $model, $capability) + ); - // Route to the appropriate generation method based on capability - $result = $this->executeModelGeneration($model, $capability, $this->messages); + // Route to the appropriate generation method based on capability + $result = $this->executeModelGeneration($model, $capability, $this->messages); - // Dispatch AfterGenerateResultEvent - $this->dispatchEvent( - new AfterGenerateResultEvent($this->messages, $model, $capability, $result) - ); + // Dispatch AfterGenerateResultEvent + $this->dispatchEvent( + new AfterGenerateResultEvent($this->messages, $model, $capability, $result) + ); + + return $result; + } + + $requirements = ModelRequirements::fromPromptData($capability, $this->messages, $this->modelConfig); + $candidates = $this->getOrderedCandidateModels($requirements, $capability); + + $lastException = null; + + foreach ($candidates as [$providerId, $modelId]) { + $model = $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig); + $this->bindModelRequestOptions($model); + + $this->dispatchEvent( + new BeforeGenerateResultEvent($this->messages, $model, $capability) + ); + + try { + $result = $this->executeModelGeneration($model, $capability, $this->messages); + } catch (RuntimeException $e) { + $lastException = $e; + continue; + } - return $result; + $this->dispatchEvent( + new AfterGenerateResultEvent($this->messages, $model, $capability, $result) + ); + + return $result; + } + + if ($lastException !== null) { + throw $lastException; + } + + throw new RuntimeException('No candidate models were available for generation.'); } /** @@ -1335,47 +1418,8 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface return $model; } - // Retrieve the candidate models map which satisfies the requirements. - $candidateMap = $this->getCandidateModelsMap($requirements); - - if (empty($candidateMap)) { - $message = sprintf( - 'No models found that support %s for this prompt.', - $capability->value - ); - - if ($this->providerIdOrClassName !== null) { - $message = sprintf( - 'No models found for provider "%s" that support %s for this prompt.', - $this->providerIdOrClassName, - $capability->value - ); - } - - throw new InvalidArgumentException($message); - } - - // Check if any preferred models match the candidates, in priority order. - if (!empty($this->modelPreferenceKeys)) { - // Find preferences that match available candidates, preserving preference order. - $matchingPreferences = array_intersect_key( - array_flip($this->modelPreferenceKeys), - $candidateMap - ); - - if (!empty($matchingPreferences)) { - // Get the first matching preference key - $firstMatchKey = key($matchingPreferences); - [$providerId, $modelId] = $candidateMap[$firstMatchKey]; - - $model = $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig); - $this->bindModelRequestOptions($model); - return $model; - } - } - - // No preference matched; fall back to the first candidate discovered. - [$providerId, $modelId] = reset($candidateMap); + $candidates = $this->getOrderedCandidateModels($requirements, $capability); + [$providerId, $modelId] = $candidates[0]; $model = $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig); $this->bindModelRequestOptions($model); @@ -1400,69 +1444,130 @@ private function bindModelRequestOptions(ModelInterface $model): void } /** - * Builds a map of candidate models that satisfy the requirements for efficient lookup. - * - * @since 0.2.0 + * Retrieves candidate models ordered by provider discovery and model preferences. * + * @since n.e.x.t + * * @param ModelRequirements $requirements The requirements derived from the prompt. - * @return array Map of preference keys to [providerId, modelId] tuples. + * @param CapabilityEnum $capability The capability being requested. + * @return list Ordered [providerId, modelId] tuples. + * @throws InvalidArgumentException If no candidate models satisfy the requirements. */ - private function getCandidateModelsMap(ModelRequirements $requirements): array + private function getOrderedCandidateModels(ModelRequirements $requirements, CapabilityEnum $capability): array { + $candidates = []; + if ($this->providerIdOrClassName === null) { - // No provider locked in, gather all models across providers that meet requirements. - $providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements); + $providerModelsMetadata = $this->providerPriority === [] + ? $this->registry->findModelsMetadataForSupport($requirements) + : $this->registry->findModelsMetadataForSupport($requirements, $this->providerPriority); - $candidateMap = []; foreach ($providerModelsMetadata as $providerModels) { $providerId = $providerModels->getProvider()->getId(); - $providerMap = $this->generateMapFromCandidates($providerId, $providerModels->getModels()); - - // Use + operator to merge, preserving keys from $candidateMap (first provider wins for model-only keys) - $candidateMap = $candidateMap + $providerMap; + foreach ($providerModels->getModels() as $modelMetadata) { + $candidates[] = [$providerId, $modelMetadata->getId()]; + } } + } else { + $modelsMetadata = $this->registry->findProviderModelsMetadataForSupport( + $this->providerIdOrClassName, + $requirements + ); - return $candidateMap; + $providerId = $this->registry->getProviderId($this->providerIdOrClassName); + foreach ($modelsMetadata as $modelMetadata) { + $candidates[] = [$providerId, $modelMetadata->getId()]; + } } - // Provider set, only consider models from that provider. - $modelsMetadata = $this->registry->findProviderModelsMetadataForSupport( - $this->providerIdOrClassName, - $requirements - ); - - // Ensure we pass the provider ID, not the class name - $providerId = $this->registry->getProviderId($this->providerIdOrClassName); + if ($candidates === []) { + throw new InvalidArgumentException($this->createNoModelsFoundMessage($capability)); + } - return $this->generateMapFromCandidates($providerId, $modelsMetadata); + return $this->applyModelPreferencesToCandidates($candidates); } /** - * Generates a candidate map from model metadata with both provider-specific and model-only keys. + * Reorders candidate tuples to honor model preferences while preserving discovery order for remaining entries. * - * @since 0.2.0 - * - * @param string $providerId The provider ID. - * @param list $modelsMetadata The models metadata to map. - * @return array Map of preference keys to [providerId, modelId] tuples. + * @since n.e.x.t + * + * @param list $candidates Candidate tuples in discovery order. + * @return list Reordered candidate tuples. */ - private function generateMapFromCandidates(string $providerId, array $modelsMetadata): array + private function applyModelPreferencesToCandidates(array $candidates): array { - $map = []; - - foreach ($modelsMetadata as $modelMetadata) { - $modelId = $modelMetadata->getId(); + if ($this->modelPreferenceKeys === []) { + return $candidates; + } - // Add provider-specific key + $candidateMap = []; + foreach ($candidates as [$providerId, $modelId]) { $providerModelKey = $this->createProviderModelPreferenceKey($providerId, $modelId); - $map[$providerModelKey] = [$providerId, $modelId]; + if (!isset($candidateMap[$providerModelKey])) { + $candidateMap[$providerModelKey] = [$providerId, $modelId]; + } - // Add model-only key $modelKey = $this->createModelPreferenceKey($modelId); - $map[$modelKey] = [$providerId, $modelId]; + if (!isset($candidateMap[$modelKey])) { + $candidateMap[$modelKey] = [$providerId, $modelId]; + } + } + + $orderedCandidates = []; + $seen = []; + + foreach ($this->modelPreferenceKeys as $preferenceKey) { + if (!isset($candidateMap[$preferenceKey])) { + continue; + } + + [$providerId, $modelId] = $candidateMap[$preferenceKey]; + $candidateKey = $providerId . '::' . $modelId; + + if (isset($seen[$candidateKey])) { + continue; + } + + $seen[$candidateKey] = true; + $orderedCandidates[] = [$providerId, $modelId]; + } + + foreach ($candidates as [$providerId, $modelId]) { + $candidateKey = $providerId . '::' . $modelId; + if (isset($seen[$candidateKey])) { + continue; + } + + $seen[$candidateKey] = true; + $orderedCandidates[] = [$providerId, $modelId]; + } + + return $orderedCandidates; + } + + /** + * Builds the no-models-found message for capability/provider context. + * + * @since n.e.x.t + * + * @param CapabilityEnum $capability The required capability. + * @return string The no-models-found message. + */ + private function createNoModelsFoundMessage(CapabilityEnum $capability): string + { + if ($this->providerIdOrClassName !== null) { + return sprintf( + 'No models found for provider "%s" that support %s for this prompt.', + $this->providerIdOrClassName, + $capability->value + ); } - return $map; + return sprintf( + 'No models found that support %s for this prompt.', + $capability->value + ); } /** @@ -1492,6 +1597,25 @@ private function normalizePreferenceIdentifier( return $trimmed; } + /** + * Normalizes and validates a provider identifier string. + * + * @since n.e.x.t + * + * @param string $providerId The provider identifier. + * @return string The normalized provider identifier. + * @throws InvalidArgumentException If the identifier is empty. + */ + private function normalizeProviderIdentifier(string $providerId): string + { + $providerId = trim($providerId); + if ($providerId === '') { + throw new InvalidArgumentException('Provider priority identifiers cannot be empty.'); + } + + return $providerId; + } + /** * Creates a preference key for a provider/model combination. * diff --git a/src/Providers/ProviderRegistry.php b/src/Providers/ProviderRegistry.php index d5aa6e47..545aa307 100644 --- a/src/Providers/ProviderRegistry.php +++ b/src/Providers/ProviderRegistry.php @@ -234,15 +234,29 @@ public function isProviderConfigured(string $idOrClassName): bool * Finds models across all available providers that support the given requirements. * * @since 0.1.0 + * @since n.e.x.t Added support for preferred provider ordering. * * @param ModelRequirements $modelRequirements The requirements to match against. + * @param list $preferredProviderOrder Optional ordered provider IDs to prioritize first. * @return list List of provider models metadata that match requirements. */ - public function findModelsMetadataForSupport(ModelRequirements $modelRequirements): array - { + public function findModelsMetadataForSupport( + ModelRequirements $modelRequirements, + array $preferredProviderOrder = [] + ): array { $results = []; + $providerIds = array_keys($this->registeredIdsToClassNames); + + if ($preferredProviderOrder !== []) { + $preferredProviderOrder = array_values(array_unique($preferredProviderOrder)); + + $orderedProviderIds = array_values(array_intersect($preferredProviderOrder, $providerIds)); + $remainingProviderIds = array_values(array_diff($providerIds, $orderedProviderIds)); + $providerIds = array_merge($orderedProviderIds, $remainingProviderIds); + } - foreach ($this->registeredIdsToClassNames as $providerId => $className) { + foreach ($providerIds as $providerId) { + $className = $this->registeredIdsToClassNames[$providerId]; $providerResults = $this->findProviderModelsMetadataForSupport($providerId, $modelRequirements); if (!empty($providerResults)) { // Use static method from ProviderInterface diff --git a/tests/mocks/MockSecondaryProvider.php b/tests/mocks/MockSecondaryProvider.php new file mode 100644 index 00000000..1973de35 --- /dev/null +++ b/tests/mocks/MockSecondaryProvider.php @@ -0,0 +1,117 @@ +getModelMetadata($modelId); + + $config = $modelConfig ?? new ModelConfig(); + + return new MockModel($modelMetadata, $config); + } + + /** + * {@inheritDoc} + */ + public static function availability(): ProviderAvailabilityInterface + { + if (static::$availability === null) { + static::$availability = new MockProviderAvailability(true); + } + + return static::$availability; + } + + /** + * {@inheritDoc} + */ + public static function modelMetadataDirectory(): ModelMetadataDirectoryInterface + { + if (static::$modelMetadataDirectory === null) { + $mockModels = [ + 'mock-secondary-text-model' => new ModelMetadata( + 'mock-secondary-text-model', + 'Mock Secondary Text Model', + [CapabilityEnum::textGeneration()], + [] + ) + ]; + + static::$modelMetadataDirectory = new MockModelMetadataDirectory($mockModels); + } + + return static::$modelMetadataDirectory; + } + + /** + * Sets the availability checker for testing. + * + * @param MockProviderAvailability $availability The availability checker. + */ + public static function setAvailability(MockProviderAvailability $availability): void + { + static::$availability = $availability; + } + + /** + * Sets the model metadata directory for testing. + * + * @param MockModelMetadataDirectory $directory The model metadata directory. + */ + public static function setModelMetadataDirectory(MockModelMetadataDirectory $directory): void + { + static::$modelMetadataDirectory = $directory; + } + + /** + * Resets static state for testing. + */ + public static function reset(): void + { + static::$availability = null; + static::$modelMetadataDirectory = null; + } +} diff --git a/tests/unit/Builders/PromptBuilderTest.php b/tests/unit/Builders/PromptBuilderTest.php index ce68a223..8bb1a6c4 100644 --- a/tests/unit/Builders/PromptBuilderTest.php +++ b/tests/unit/Builders/PromptBuilderTest.php @@ -8,6 +8,7 @@ use PHPUnit\Framework\TestCase; use RuntimeException; use WordPress\AiClient\Builders\PromptBuilder; +use WordPress\AiClient\Common\Exception\RuntimeException as AiClientRuntimeException; use WordPress\AiClient\Files\DTO\File; use WordPress\AiClient\Files\Enums\FileTypeEnum; use WordPress\AiClient\Files\Enums\MediaOrientationEnum; @@ -1155,6 +1156,75 @@ public function testUsingProvider(): void $this->assertEquals('test-provider', $actualProvider); } + /** + * Tests usingProviderPriority method. + * + * @return void + */ + public function testUsingProviderPriority(): void + { + $builder = new PromptBuilder($this->registry); + $result = $builder->usingProviderPriority('provider-a', 'provider-b', 'provider-a'); + + $this->assertSame($builder, $result); + + $reflection = new \ReflectionClass($builder); + $priorityProperty = $reflection->getProperty('providerPriority'); + $priorityProperty->setAccessible(true); + + $actualPriority = $priorityProperty->getValue($builder); + $this->assertSame(['provider-a', 'provider-b'], $actualPriority); + } + + /** + * Tests usingProviderPriority throws when no identifiers are provided. + * + * @return void + */ + public function testUsingProviderPriorityWithoutArgumentsThrowsException(): void + { + $builder = new PromptBuilder($this->registry); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('At least one provider priority identifier must be provided.'); + + $builder->usingProviderPriority(); + } + + /** + * Tests usingProviderPriority throws for empty identifiers. + * + * @return void + */ + public function testUsingProviderPriorityWithEmptyIdentifierThrowsException(): void + { + $builder = new PromptBuilder($this->registry); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Provider priority identifiers cannot be empty.'); + + $builder->usingProviderPriority('provider-a', ' '); + } + + /** + * Tests usingProviderFallback updates the fallback-enabled flag. + * + * @return void + */ + public function testUsingProviderFallback(): void + { + $builder = new PromptBuilder($this->registry); + $result = $builder->usingProviderFallback(false); + + $this->assertSame($builder, $result); + + $fallbackEnabled = (function (): bool { + return $this->providerFallbackEnabled; + })->call($builder); + + $this->assertFalse($fallbackEnabled); + } + /** * Tests usingSystemInstruction method. * @@ -3504,6 +3574,406 @@ public function testGenerateResultWithProviderNoModelsThrowsException(): void $builder->generateResult(); } + /** + * Tests generateResult passes provider priority order to discovery. + * + * @return void + */ + public function testGenerateResultWithProviderPriority(): void + { + $result = $this->createMock(GenerativeAiResult::class); + $modelMetadata = $this->createMock(ModelMetadata::class); + $modelMetadata->method('getId')->willReturn('provider-priority-model'); + + $providerMetadata = new ProviderMetadata( + 'preferred-provider', + 'Preferred Provider', + ProviderTypeEnum::cloud() + ); + + $model = $this->createMockTextGenerationModel($result, $modelMetadata); + + $this->registry->expects($this->once()) + ->method('findModelsMetadataForSupport') + ->with( + $this->isInstanceOf(ModelRequirements::class), + ['preferred-provider', 'secondary-provider'] + ) + ->willReturn([new ProviderModelsMetadata($providerMetadata, [$modelMetadata])]); + + $this->registry->expects($this->once()) + ->method('getProviderModel') + ->with('preferred-provider', 'provider-priority-model', $this->isInstanceOf(ModelConfig::class)) + ->willReturn($model); + + $builder = new PromptBuilder($this->registry, 'Test prompt'); + $builder->usingProviderPriority('preferred-provider', 'secondary-provider'); + + $actualResult = $builder->generateResult(); + $this->assertSame($result, $actualResult); + } + + /** + * Tests generateResult falls back to the next discovered model after runtime failures. + * + * @return void + */ + public function testGenerateResultFallsBackToNextModelOnRuntimeException(): void + { + $successfulResult = $this->createTestResult('Fallback success'); + + $firstModelMetadata = $this->createMock(ModelMetadata::class); + $firstModelMetadata->method('getId')->willReturn('first-model'); + + $secondModelMetadata = $this->createMock(ModelMetadata::class); + $secondModelMetadata->method('getId')->willReturn('second-model'); + + $firstProviderMetadata = new ProviderMetadata( + 'first-provider', + 'First Provider', + ProviderTypeEnum::cloud() + ); + $secondProviderMetadata = new ProviderMetadata( + 'second-provider', + 'Second Provider', + ProviderTypeEnum::cloud() + ); + + $failingModel = new class ( + $firstModelMetadata, + $firstProviderMetadata + ) implements ModelInterface, TextGenerationModelInterface { + private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; + private ModelConfig $config; + + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata + ) { + $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->config = new ModelConfig(); + } + + public function metadata(): ModelMetadata + { + return $this->metadata; + } + + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + public function getConfig(): ModelConfig + { + return $this->config; + } + + public function generateTextResult(array $prompt): GenerativeAiResult + { + throw new AiClientRuntimeException('First provider failed'); + } + }; + + $successfulModel = $this->createMockTextGenerationModel($successfulResult, $secondModelMetadata); + + $this->registry->expects($this->once()) + ->method('findModelsMetadataForSupport') + ->with($this->isInstanceOf(ModelRequirements::class)) + ->willReturn([ + new ProviderModelsMetadata( + $firstProviderMetadata, + [$firstModelMetadata] + ), + new ProviderModelsMetadata($secondProviderMetadata, [$secondModelMetadata]), + ]); + + $callCount = 0; + $this->registry->expects($this->exactly(2)) + ->method('getProviderModel') + ->willReturnCallback( + function ( + string $providerId, + string $modelId, + ModelConfig $modelConfig + ) use ( + &$callCount, + $failingModel, + $successfulModel + ): ModelInterface { + $this->assertInstanceOf(ModelConfig::class, $modelConfig); + + if ($callCount === 0) { + $this->assertSame('first-provider', $providerId); + $this->assertSame('first-model', $modelId); + $callCount++; + return $failingModel; + } + + $this->assertSame('second-provider', $providerId); + $this->assertSame('second-model', $modelId); + $callCount++; + return $successfulModel; + } + ); + + $builder = new PromptBuilder($this->registry, 'Test prompt'); + + $actualResult = $builder->generateResult(); + $this->assertSame($successfulResult, $actualResult); + } + + /** + * Tests generateResult does not retry another model when fallback is disabled. + * + * @return void + */ + public function testGenerateResultWithFallbackDisabledDoesNotRetry(): void + { + $firstModelMetadata = $this->createMock(ModelMetadata::class); + $firstModelMetadata->method('getId')->willReturn('first-model'); + + $secondModelMetadata = $this->createMock(ModelMetadata::class); + $secondModelMetadata->method('getId')->willReturn('second-model'); + + $firstProviderMetadata = new ProviderMetadata( + 'first-provider', + 'First Provider', + ProviderTypeEnum::cloud() + ); + $secondProviderMetadata = new ProviderMetadata( + 'second-provider', + 'Second Provider', + ProviderTypeEnum::cloud() + ); + + $failingModel = new class ( + $firstModelMetadata, + $firstProviderMetadata + ) implements ModelInterface, TextGenerationModelInterface { + private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; + private ModelConfig $config; + + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata + ) { + $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->config = new ModelConfig(); + } + + public function metadata(): ModelMetadata + { + return $this->metadata; + } + + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + public function getConfig(): ModelConfig + { + return $this->config; + } + + public function generateTextResult(array $prompt): GenerativeAiResult + { + throw new AiClientRuntimeException('First provider failed'); + } + }; + + $this->registry->expects($this->once()) + ->method('findModelsMetadataForSupport') + ->with($this->isInstanceOf(ModelRequirements::class)) + ->willReturn([ + new ProviderModelsMetadata($firstProviderMetadata, [$firstModelMetadata]), + new ProviderModelsMetadata($secondProviderMetadata, [$secondModelMetadata]), + ]); + + $this->registry->expects($this->once()) + ->method('getProviderModel') + ->with('first-provider', 'first-model', $this->isInstanceOf(ModelConfig::class)) + ->willReturn($failingModel); + + $builder = new PromptBuilder($this->registry, 'Test prompt'); + $builder->usingProviderFallback(false); + + $this->expectException(AiClientRuntimeException::class); + $this->expectExceptionMessage('First provider failed'); + + $builder->generateResult(); + } + + /** + * Tests generateResult throws the final runtime exception when all fallback candidates fail. + * + * @return void + */ + public function testGenerateResultWithFallbackAllCandidatesFailThrowsLastException(): void + { + $firstModelMetadata = $this->createMock(ModelMetadata::class); + $firstModelMetadata->method('getId')->willReturn('first-model'); + + $secondModelMetadata = $this->createMock(ModelMetadata::class); + $secondModelMetadata->method('getId')->willReturn('second-model'); + + $firstProviderMetadata = new ProviderMetadata( + 'first-provider', + 'First Provider', + ProviderTypeEnum::cloud() + ); + $secondProviderMetadata = new ProviderMetadata( + 'second-provider', + 'Second Provider', + ProviderTypeEnum::cloud() + ); + + $firstFailingModel = new class ( + $firstModelMetadata, + $firstProviderMetadata + ) implements ModelInterface, TextGenerationModelInterface { + private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; + private ModelConfig $config; + + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata + ) { + $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->config = new ModelConfig(); + } + + public function metadata(): ModelMetadata + { + return $this->metadata; + } + + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + public function getConfig(): ModelConfig + { + return $this->config; + } + + public function generateTextResult(array $prompt): GenerativeAiResult + { + throw new AiClientRuntimeException('First provider failed'); + } + }; + + $secondFailingModel = new class ( + $secondModelMetadata, + $secondProviderMetadata + ) implements ModelInterface, TextGenerationModelInterface { + private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; + private ModelConfig $config; + + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata + ) { + $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->config = new ModelConfig(); + } + + public function metadata(): ModelMetadata + { + return $this->metadata; + } + + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + public function getConfig(): ModelConfig + { + return $this->config; + } + + public function generateTextResult(array $prompt): GenerativeAiResult + { + throw new AiClientRuntimeException('Second provider failed'); + } + }; + + $this->registry->expects($this->once()) + ->method('findModelsMetadataForSupport') + ->with($this->isInstanceOf(ModelRequirements::class)) + ->willReturn([ + new ProviderModelsMetadata($firstProviderMetadata, [$firstModelMetadata]), + new ProviderModelsMetadata($secondProviderMetadata, [$secondModelMetadata]), + ]); + + $callCount = 0; + $this->registry->expects($this->exactly(2)) + ->method('getProviderModel') + ->willReturnCallback( + function ( + string $providerId, + string $modelId, + ModelConfig $modelConfig + ) use ( + &$callCount, + $firstFailingModel, + $secondFailingModel + ): ModelInterface { + $this->assertInstanceOf(ModelConfig::class, $modelConfig); + + if ($callCount === 0) { + $this->assertSame('first-provider', $providerId); + $this->assertSame('first-model', $modelId); + $callCount++; + return $firstFailingModel; + } + + $this->assertSame('second-provider', $providerId); + $this->assertSame('second-model', $modelId); + $callCount++; + return $secondFailingModel; + } + ); + + $builder = new PromptBuilder($this->registry, 'Test prompt'); + + $this->expectException(AiClientRuntimeException::class); + $this->expectExceptionMessage('Second provider failed'); + + $builder->generateResult(); + } + /** * Tests that provider takes precedence when both provider and model are set. * diff --git a/tests/unit/Providers/ProviderRegistryTest.php b/tests/unit/Providers/ProviderRegistryTest.php index ff8197e6..6e95f8c9 100644 --- a/tests/unit/Providers/ProviderRegistryTest.php +++ b/tests/unit/Providers/ProviderRegistryTest.php @@ -20,6 +20,7 @@ use WordPress\AiClient\Tests\mocks\MockNoAuthProvider; use WordPress\AiClient\Tests\mocks\MockProvider; use WordPress\AiClient\Tests\mocks\MockProviderAvailability; +use WordPress\AiClient\Tests\mocks\MockSecondaryProvider; /** * @covers \WordPress\AiClient\Providers\ProviderRegistry @@ -33,11 +34,13 @@ protected function setUp(): void parent::setUp(); $this->registry = new ProviderRegistry(); MockProvider::reset(); // Reset static state of mock provider before each test. + MockSecondaryProvider::reset(); } protected function tearDown(): void { MockProvider::reset(); // Reset static state of mock provider after each test. + MockSecondaryProvider::reset(); parent::tearDown(); } @@ -180,6 +183,28 @@ public function testFindModelsMetadataForSupportWithRegisteredProvider(): void $this->assertCount(1, $results); } + /** + * Tests findModelsMetadataForSupport prioritizes preferred provider order. + * + * @return void + */ + public function testFindModelsMetadataForSupportWithPreferredProviderOrder(): void + { + $this->registry->registerProvider(MockProvider::class); + $this->registry->registerProvider(MockSecondaryProvider::class); + + $requirements = new ModelRequirements([CapabilityEnum::textGeneration()], []); + + $results = $this->registry->findModelsMetadataForSupport( + $requirements, + ['mock-secondary', 'unknown-provider'] + ); + + $this->assertCount(2, $results); + $this->assertSame('mock-secondary', $results[0]->getProvider()->getId()); + $this->assertSame('mock', $results[1]->getProvider()->getId()); + } + /** * Tests findProviderModelsMetadataForSupport with registered provider. *