feat(agent): classify model families#2525
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a ModelFamily enum and a model_family helper function to classify model IDs into their respective families (such as OpenAI, Google, Meta, and Mistral), along with comprehensive unit tests. The review feedback suggests several improvements to the classification logic: removing a redundant check for "/openai/" since "openai/" already covers it, adding a check for "gemma" to correctly classify Google's open-weights models, and adding checks for "mixtral" and "codestral" to ensure Mistral AI's other prominent models are not incorrectly classified as Inferencer.
| if normalized.starts_with("gpt-") | ||
| || normalized.contains("/gpt-") | ||
| || normalized.contains("openai/") | ||
| || normalized.contains("/openai/") | ||
| { |
There was a problem hiding this comment.
| if normalized.contains("gemini") || normalized.contains("google/") { | ||
| return ModelFamily::Google; | ||
| } |
There was a problem hiding this comment.
Google's open-weights model family Gemma (e.g., gemma-2-9b) is widely used, especially in local setups (like Ollama) where the google/ prefix might be omitted. Adding a check for "gemma" ensures these models are correctly classified under the Google family.
if normalized.contains("gemini")
|| normalized.contains("gemma")
|| normalized.contains("google/")
{
return ModelFamily::Google;
}| if normalized.contains("mistral") { | ||
| return ModelFamily::Mistral; | ||
| } |
There was a problem hiding this comment.
Mistral AI's model family includes prominent models like Mixtral (e.g., mixtral-8x7b) and Codestral (e.g., codestral-22b). Since these names do not contain the substring "mistral", they will currently fall back to Inferencer. Adding checks for "mixtral" and "codestral" ensures they are correctly classified.
| if normalized.contains("mistral") { | |
| return ModelFamily::Mistral; | |
| } | |
| if normalized.contains("mistral") | |
| || normalized.contains("mixtral") | |
| || normalized.contains("codestral") | |
| { | |
| return ModelFamily::Mistral; | |
| } |
874f22a to
fff82bc
Compare
fff82bc to
bfcbe9d
Compare
| #[must_use] | ||
| /// Classify a model identifier by its underlying model family. | ||
| pub fn model_family(model_id: &str) -> ModelFamily { |
There was a problem hiding this comment.
#[must_use] placed before the doc comment
In Rust, the conventional (and rustfmt-enforced) order is doc comment first, then outer attributes. Placing #[must_use] before /// can cause rustfmt --edition 2024 --check to report a diff on some toolchain versions and diverges from the pattern used by every other documented item in this file.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| if normalized.starts_with("gpt-") | ||
| || normalized.contains("/gpt-") | ||
| || normalized.contains("openai/") | ||
| { | ||
| return ModelFamily::OpenAI; | ||
| } |
There was a problem hiding this comment.
OpenAI o-series models not matched with bare IDs
o1-mini, o1-preview, o3, o3-mini, o4-mini and related IDs don't start with gpt-, don't contain /gpt-, and don't contain openai/, so model_family("o1-mini") currently returns Inferencer. They're reachable via the openai/o1-mini routed form, but direct/bare invocations silently fall through. Adding normalized.starts_with("o1") || normalized.starts_with("o3") || normalized.starts_with("o4") (or similar) alongside the existing gpt- prefix check would close this gap.
|
Cherry-picked to codex/v0.8.53 in 4556282. |
Problem
Model-family identity needs a shared agent-crate primitive before TUI, desktop, and runtime API surfaces can render consistent model affordances.
Change
ModelFamilytocodewhale-agent.model_family(model_id)classification for common first-party, routed, and self-hosted model ids.Verification
cargo test -p codewhale-agent model_family --lockedrustfmt crates/agent/src/lib.rs --edition 2024 --checkcargo test -p codewhale-agent --lockedgit diff --checkPartially addresses #2081.
Greptile Summary
This PR introduces a
ModelFamilyenum and amodel_family()classification function tocodewhale-agent, providing a shared primitive for identifying model families across TUI, desktop, and runtime API surfaces. Classification is done via substring/prefix matching on a normalized (trimmed, lowercased) model ID string.ModelFamilyvariants covering major first-party and self-hosted model providers, with anInferencerfallback for unknown or custom gateway IDs.model_family()applies ordered substring checks, correctly prioritisingGptOssbefore the broadergpt-/openai/OpenAI check.Inferencerfallback, but bare OpenAI o-series IDs (o1-mini,o3,o4-mini) are not handled and silently land inInferencer.Confidence Score: 3/5
Mostly safe to merge, but bare OpenAI o-series model IDs are silently misclassified as Inferencer rather than OpenAI, which will produce wrong affordances in any client that consumes this function.
The o1/o3/o4 classification gap is a real behavioral defect: any caller passing a bare
o1-mini,o3, oro4-miniID will receiveInferencerinstead ofOpenAI. Since this function is being introduced precisely to drive UI affordances across multiple surfaces, misclassifying a prominent and widely-used model family will have visible downstream impact the moment a client consumes it.crates/agent/src/lib.rs — specifically the OpenAI matching block in
model_family()and the missing o-series test coverage.Important Files Changed
ModelFamilyenum andmodel_family()classifier with substring/prefix matching; OpenAI o-series bare IDs (o1, o3, o4) fall through toInferencer, and#[must_use]is placed before the doc comment rather than after it.Flowchart
%%{init: {'theme': 'neutral'}}%% flowchart TD A([model_id: &str]) --> B[normalize: trim + lowercase] B --> C{is_empty?} C -- yes --> Z([Inferencer]) C -- no --> D{contains 'deepseek'?} D -- yes --> E([DeepSeek]) D -- no --> F{contains 'claude' or 'anthropic'?} F -- yes --> G([Anthropic]) F -- no --> H{contains 'gpt-oss' or 'gpt_oss'?} H -- yes --> I([GptOss]) H -- no --> J{starts_with 'gpt-' OR contains '/gpt-' OR 'openai/'?} J -- yes --> K([OpenAI]) J -- no --> L{contains 'gemini', 'gemma', or 'google/'?} L -- yes --> M([Google]) L -- no --> N{contains 'llama', 'meta-', or 'meta/'?} N -- yes --> O([Meta]) N -- no --> P{contains 'mistral', 'mixtral', or 'codestral'?} P -- yes --> Q([Mistral]) P -- no --> R{contains 'qwen'?} R -- yes --> S([Qwen]) R -- no --> T{contains 'grok'?} T -- yes --> U([Grok]) T -- no --> V{contains 'cohere' or 'command-r'?} V -- yes --> W([Cohere]) V -- no --> ZReviews (1): Last reviewed commit: "feat(agent): classify model families" | Re-trigger Greptile