Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 216 additions & 92 deletions src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
use WordPress\AiClient\Providers\Http\DTO\RequestOptions;
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;
use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum;
use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface;
Expand Down Expand Up @@ -76,6 +75,16 @@ class PromptBuilder
*/
protected ?string $providerIdOrClassName = null;

/**
* @var list<string> Ordered provider IDs to prioritize during model discovery.
*/
protected array $providerPriority = [];

/**
* @var bool Whether to retry with the next discovered candidate model after runtime failures.
*/
protected bool $providerFallbackEnabled = true;

/**
* @var ModelConfig The model configuration.
*/
Expand Down Expand Up @@ -378,6 +387,45 @@ public function usingProvider(string $providerIdOrClassName): self
return $this;
}

/**
* Sets provider priority for cross-provider model discovery.
*
* @since n.e.x.t
*
* @param string ...$providerIds Provider IDs in preferred order.
* @return self The prompt builder instance.
*
* @throws InvalidArgumentException If no provider IDs are provided or an identifier is invalid.
*/
public function usingProviderPriority(string ...$providerIds): self
{
if ($providerIds === []) {
throw new InvalidArgumentException('At least one provider priority identifier must be provided.');
}

$normalizedProviderIds = [];
foreach ($providerIds as $providerId) {
$normalizedProviderIds[] = $this->normalizeProviderIdentifier($providerId);
}

$this->providerPriority = array_values(array_unique($normalizedProviderIds));
return $this;
}

/**
* Enables or disables fallback to subsequent discovered models after runtime failures.
*
* @since n.e.x.t
*
* @param bool $enabled Whether fallback should be enabled.
* @return self The prompt builder instance.
*/
public function usingProviderFallback(bool $enabled = true): self
{
$this->providerFallbackEnabled = $enabled;
return $this;
}

/**
* Sets the system instruction.
*
Expand Down Expand Up @@ -930,22 +978,57 @@ public function generateResult(?CapabilityEnum $capability = null): GenerativeAi
}
}

$model = $this->getConfiguredModel($capability);
if ($this->model !== null || !$this->providerFallbackEnabled) {
$model = $this->getConfiguredModel($capability);

// Dispatch BeforeGenerateResultEvent
$this->dispatchEvent(
new BeforeGenerateResultEvent($this->messages, $model, $capability)
);
// Dispatch BeforeGenerateResultEvent
$this->dispatchEvent(
new BeforeGenerateResultEvent($this->messages, $model, $capability)
);

// Route to the appropriate generation method based on capability
$result = $this->executeModelGeneration($model, $capability, $this->messages);
// Route to the appropriate generation method based on capability
$result = $this->executeModelGeneration($model, $capability, $this->messages);

// Dispatch AfterGenerateResultEvent
$this->dispatchEvent(
new AfterGenerateResultEvent($this->messages, $model, $capability, $result)
);
// Dispatch AfterGenerateResultEvent
$this->dispatchEvent(
new AfterGenerateResultEvent($this->messages, $model, $capability, $result)
);

return $result;
}

$requirements = ModelRequirements::fromPromptData($capability, $this->messages, $this->modelConfig);
$candidates = $this->getOrderedCandidateModels($requirements, $capability);

$lastException = null;

foreach ($candidates as [$providerId, $modelId]) {
$model = $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig);
$this->bindModelRequestOptions($model);

$this->dispatchEvent(
new BeforeGenerateResultEvent($this->messages, $model, $capability)
);

try {
$result = $this->executeModelGeneration($model, $capability, $this->messages);
} catch (RuntimeException $e) {
$lastException = $e;
continue;
}

return $result;
$this->dispatchEvent(
new AfterGenerateResultEvent($this->messages, $model, $capability, $result)
);

return $result;
}

if ($lastException !== null) {
throw $lastException;
}

throw new RuntimeException('No candidate models were available for generation.');
}

/**
Expand Down Expand Up @@ -1335,47 +1418,8 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
return $model;
}

// Retrieve the candidate models map which satisfies the requirements.
$candidateMap = $this->getCandidateModelsMap($requirements);

if (empty($candidateMap)) {
$message = sprintf(
'No models found that support %s for this prompt.',
$capability->value
);

if ($this->providerIdOrClassName !== null) {
$message = sprintf(
'No models found for provider "%s" that support %s for this prompt.',
$this->providerIdOrClassName,
$capability->value
);
}

throw new InvalidArgumentException($message);
}

// Check if any preferred models match the candidates, in priority order.
if (!empty($this->modelPreferenceKeys)) {
// Find preferences that match available candidates, preserving preference order.
$matchingPreferences = array_intersect_key(
array_flip($this->modelPreferenceKeys),
$candidateMap
);

if (!empty($matchingPreferences)) {
// Get the first matching preference key
$firstMatchKey = key($matchingPreferences);
[$providerId, $modelId] = $candidateMap[$firstMatchKey];

$model = $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig);
$this->bindModelRequestOptions($model);
return $model;
}
}

// No preference matched; fall back to the first candidate discovered.
[$providerId, $modelId] = reset($candidateMap);
$candidates = $this->getOrderedCandidateModels($requirements, $capability);
[$providerId, $modelId] = $candidates[0];

$model = $this->registry->getProviderModel($providerId, $modelId, $this->modelConfig);
$this->bindModelRequestOptions($model);
Expand All @@ -1400,69 +1444,130 @@ private function bindModelRequestOptions(ModelInterface $model): void
}

/**
* Builds a map of candidate models that satisfy the requirements for efficient lookup.
*
* @since 0.2.0
* Retrieves candidate models ordered by provider discovery and model preferences.
*
* @since n.e.x.t
*
* @param ModelRequirements $requirements The requirements derived from the prompt.
* @return array<string, array{0:string,1:string}> Map of preference keys to [providerId, modelId] tuples.
* @param CapabilityEnum $capability The capability being requested.
* @return list<array{0:string,1:string}> Ordered [providerId, modelId] tuples.
* @throws InvalidArgumentException If no candidate models satisfy the requirements.
*/
private function getCandidateModelsMap(ModelRequirements $requirements): array
private function getOrderedCandidateModels(ModelRequirements $requirements, CapabilityEnum $capability): array
{
$candidates = [];

if ($this->providerIdOrClassName === null) {
// No provider locked in, gather all models across providers that meet requirements.
$providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);
$providerModelsMetadata = $this->providerPriority === []
? $this->registry->findModelsMetadataForSupport($requirements)
: $this->registry->findModelsMetadataForSupport($requirements, $this->providerPriority);

$candidateMap = [];
foreach ($providerModelsMetadata as $providerModels) {
$providerId = $providerModels->getProvider()->getId();
$providerMap = $this->generateMapFromCandidates($providerId, $providerModels->getModels());

// Use + operator to merge, preserving keys from $candidateMap (first provider wins for model-only keys)
$candidateMap = $candidateMap + $providerMap;
foreach ($providerModels->getModels() as $modelMetadata) {
$candidates[] = [$providerId, $modelMetadata->getId()];
}
}
} else {
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
);

return $candidateMap;
$providerId = $this->registry->getProviderId($this->providerIdOrClassName);
foreach ($modelsMetadata as $modelMetadata) {
$candidates[] = [$providerId, $modelMetadata->getId()];
}
}

// Provider set, only consider models from that provider.
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
);

// Ensure we pass the provider ID, not the class name
$providerId = $this->registry->getProviderId($this->providerIdOrClassName);
if ($candidates === []) {
throw new InvalidArgumentException($this->createNoModelsFoundMessage($capability));
}

return $this->generateMapFromCandidates($providerId, $modelsMetadata);
return $this->applyModelPreferencesToCandidates($candidates);
}

/**
* Generates a candidate map from model metadata with both provider-specific and model-only keys.
* Reorders candidate tuples to honor model preferences while preserving discovery order for remaining entries.
*
* @since 0.2.0
*
* @param string $providerId The provider ID.
* @param list<ModelMetadata> $modelsMetadata The models metadata to map.
* @return array<string, array{0:string,1:string}> Map of preference keys to [providerId, modelId] tuples.
* @since n.e.x.t
*
* @param list<array{0:string,1:string}> $candidates Candidate tuples in discovery order.
* @return list<array{0:string,1:string}> Reordered candidate tuples.
*/
private function generateMapFromCandidates(string $providerId, array $modelsMetadata): array
private function applyModelPreferencesToCandidates(array $candidates): array
{
$map = [];

foreach ($modelsMetadata as $modelMetadata) {
$modelId = $modelMetadata->getId();
if ($this->modelPreferenceKeys === []) {
return $candidates;
}

// Add provider-specific key
$candidateMap = [];
foreach ($candidates as [$providerId, $modelId]) {
$providerModelKey = $this->createProviderModelPreferenceKey($providerId, $modelId);
$map[$providerModelKey] = [$providerId, $modelId];
if (!isset($candidateMap[$providerModelKey])) {
$candidateMap[$providerModelKey] = [$providerId, $modelId];
}

// Add model-only key
$modelKey = $this->createModelPreferenceKey($modelId);
$map[$modelKey] = [$providerId, $modelId];
if (!isset($candidateMap[$modelKey])) {
$candidateMap[$modelKey] = [$providerId, $modelId];
}
}

$orderedCandidates = [];
$seen = [];

foreach ($this->modelPreferenceKeys as $preferenceKey) {
if (!isset($candidateMap[$preferenceKey])) {
continue;
}

[$providerId, $modelId] = $candidateMap[$preferenceKey];
$candidateKey = $providerId . '::' . $modelId;

if (isset($seen[$candidateKey])) {
continue;
}

$seen[$candidateKey] = true;
$orderedCandidates[] = [$providerId, $modelId];
}

foreach ($candidates as [$providerId, $modelId]) {
$candidateKey = $providerId . '::' . $modelId;
if (isset($seen[$candidateKey])) {
continue;
}

$seen[$candidateKey] = true;
$orderedCandidates[] = [$providerId, $modelId];
}

return $orderedCandidates;
}

/**
* Builds the no-models-found message for capability/provider context.
*
* @since n.e.x.t
*
* @param CapabilityEnum $capability The required capability.
* @return string The no-models-found message.
*/
private function createNoModelsFoundMessage(CapabilityEnum $capability): string
{
if ($this->providerIdOrClassName !== null) {
return sprintf(
'No models found for provider "%s" that support %s for this prompt.',
$this->providerIdOrClassName,
$capability->value
);
}

return $map;
return sprintf(
'No models found that support %s for this prompt.',
$capability->value
);
}

/**
Expand Down Expand Up @@ -1492,6 +1597,25 @@ private function normalizePreferenceIdentifier(
return $trimmed;
}

/**
* Normalizes and validates a provider identifier string.
*
* @since n.e.x.t
*
* @param string $providerId The provider identifier.
* @return string The normalized provider identifier.
* @throws InvalidArgumentException If the identifier is empty.
*/
private function normalizeProviderIdentifier(string $providerId): string
{
$providerId = trim($providerId);
if ($providerId === '') {
throw new InvalidArgumentException('Provider priority identifiers cannot be empty.');
}

return $providerId;
}

/**
* Creates a preference key for a provider/model combination.
*
Expand Down
Loading
Loading