Skip to content

Commit 4d5e22e

Browse files
JeffreyCArajeshkamal5050
authored andcommitted
Address PR review feedback
1 parent 2649bab commit 4d5e22e

7 files changed

Lines changed: 261 additions & 32 deletions

File tree

cli/azd/extensions/azure.ai.agents/internal/cmd/init.go

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,40 @@ type InitAction struct {
5858
//azureClient *azure.AzureClient
5959
azureContext *azdext.AzureContext
6060
//composedResources []*azdext.ComposedResource
61-
console input.Console
62-
credential azcore.TokenCredential
61+
console input.Console
62+
credential azcore.TokenCredential
63+
projectConfig *azdext.ProjectConfig
64+
environment *azdext.Environment
65+
flags *initFlags
66+
models *modelSelector
67+
68+
deploymentDetails []project.Deployment
69+
httpClient *http.Client
70+
}
71+
72+
// modelSelector encapsulates the dependencies needed for model selection and
73+
// deployment resolution during init. It avoids constructing partial InitAction
74+
// structs when only the model-selection call chain is needed.
75+
type modelSelector struct {
76+
azdClient *azdext.AzdClient
77+
azureContext *azdext.AzureContext
78+
environment *azdext.Environment
79+
flags *initFlags
80+
6381
modelCatalog map[string]*azdext.AiModel
6482
locationWarningShown bool
65-
projectConfig *azdext.ProjectConfig
66-
environment *azdext.Environment
67-
flags *initFlags
68-
deploymentDetails []project.Deployment
69-
httpClient *http.Client
83+
}
84+
85+
func (a *InitAction) getModelSelector() *modelSelector {
86+
if a.models == nil {
87+
a.models = &modelSelector{
88+
azdClient: a.azdClient,
89+
azureContext: a.azureContext,
90+
environment: a.environment,
91+
flags: a.flags,
92+
}
93+
}
94+
return a.models
7095
}
7196

7297
// GitHubUrlInfo holds parsed information from a GitHub URL

cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -779,13 +779,9 @@ func locationAllowed(location string, allowedLocations []string) bool {
779779
}
780780

781781
normalized := normalizeLocationName(location)
782-
for _, allowed := range allowedLocations {
783-
if normalized == normalizeLocationName(allowed) {
784-
return true
785-
}
786-
}
787-
788-
return false
782+
return slices.ContainsFunc(allowedLocations, func(allowed string) bool {
783+
return normalized == normalizeLocationName(allowed)
784+
})
789785
}
790786

791787
func promptLocationForInit(
@@ -859,9 +855,8 @@ func selectNewModel(
859855
return modelResp.Model, nil
860856
}
861857

862-
// resolveModelDeployment resolves a model deployment without prompting, selecting the best
863-
// candidate based on default versions, SKU priority, and available quota. Both init flows
864-
// use this for the "deploy new model" path in non-interactive mode.
858+
// resolveModelDeployments resolves model deployments without prompting, returning all candidates
859+
// filtered by location and quota. Both init flows use this for deployment resolution.
865860
func resolveModelDeployments(
866861
ctx context.Context,
867862
azdClient *azdext.AzdClient,

cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go

Lines changed: 124 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,11 +298,131 @@ func TestFoundryProjectInfoFromResource(t *testing.T) {
298298
func TestAgentModelFilter(t *testing.T) {
299299
t.Parallel()
300300

301-
filter := agentModelFilter([]string{"eastus2"}, []string{"gpt-4.1-mini"})
301+
tests := []struct {
302+
name string
303+
locations []string
304+
excludeModelNames []string
305+
wantLocations []string
306+
wantExclude []string
307+
}{
308+
{
309+
name: "AllPopulated",
310+
locations: []string{"eastus2"},
311+
excludeModelNames: []string{"gpt-4.1-mini"},
312+
wantLocations: []string{"eastus2"},
313+
wantExclude: []string{"gpt-4.1-mini"},
314+
},
315+
{
316+
name: "BothNil",
317+
locations: nil,
318+
excludeModelNames: nil,
319+
wantLocations: nil,
320+
wantExclude: nil,
321+
},
322+
{
323+
name: "EmptySlices",
324+
locations: []string{},
325+
excludeModelNames: []string{},
326+
wantLocations: nil,
327+
wantExclude: nil,
328+
},
329+
{
330+
name: "OnlyLocations",
331+
locations: []string{"westus", "eastus"},
332+
excludeModelNames: nil,
333+
wantLocations: []string{"westus", "eastus"},
334+
wantExclude: nil,
335+
},
336+
}
337+
338+
for _, tt := range tests {
339+
t.Run(tt.name, func(t *testing.T) {
340+
t.Parallel()
341+
342+
filter := agentModelFilter(tt.locations, tt.excludeModelNames)
343+
344+
require.Equal(t, []string{agentsV2ModelCapability}, filter.Capabilities)
345+
require.Equal(t, tt.wantLocations, filter.Locations)
346+
require.Equal(t, tt.wantExclude, filter.ExcludeModelNames)
347+
})
348+
}
349+
}
350+
351+
func TestLocationAllowed(t *testing.T) {
352+
t.Parallel()
302353

303-
require.Equal(t, []string{agentsV2ModelCapability}, filter.Capabilities)
304-
require.Equal(t, []string{"eastus2"}, filter.Locations)
305-
require.Equal(t, []string{"gpt-4.1-mini"}, filter.ExcludeModelNames)
354+
tests := []struct {
355+
name string
356+
location string
357+
allowedLocations []string
358+
want bool
359+
}{
360+
{
361+
name: "EmptyAllowedMeansAllowAll",
362+
location: "anyregion",
363+
allowedLocations: nil,
364+
want: true,
365+
},
366+
{
367+
name: "EmptySliceAllowsAll",
368+
location: "anyregion",
369+
allowedLocations: []string{},
370+
want: true,
371+
},
372+
{
373+
name: "ExactMatch",
374+
location: "eastus",
375+
allowedLocations: []string{"eastus", "westus"},
376+
want: true,
377+
},
378+
{
379+
name: "CaseInsensitiveMatch",
380+
location: "EastUS",
381+
allowedLocations: []string{"eastus", "westus"},
382+
want: true,
383+
},
384+
{
385+
name: "WhitespaceHandled",
386+
location: " eastus ",
387+
allowedLocations: []string{"eastus"},
388+
want: true,
389+
},
390+
{
391+
name: "NotInList",
392+
location: "northeurope",
393+
allowedLocations: []string{"eastus", "westus"},
394+
want: false,
395+
},
396+
}
397+
398+
for _, tt := range tests {
399+
t.Run(tt.name, func(t *testing.T) {
400+
t.Parallel()
401+
require.Equal(t, tt.want, locationAllowed(tt.location, tt.allowedLocations))
402+
})
403+
}
404+
}
405+
406+
func TestNormalizeLocationName(t *testing.T) {
407+
t.Parallel()
408+
409+
tests := []struct {
410+
input string
411+
want string
412+
}{
413+
{"EastUS", "eastus"},
414+
{" westus ", "westus"},
415+
{"NORTHEUROPE", "northeurope"},
416+
{"eastus2", "eastus2"},
417+
{"", ""},
418+
}
419+
420+
for _, tt := range tests {
421+
t.Run(tt.input, func(t *testing.T) {
422+
t.Parallel()
423+
require.Equal(t, tt.want, normalizeLocationName(tt.input))
424+
})
425+
}
306426
}
307427

308428
func TestUpdateFoundryProjectInfo(t *testing.T) {

cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ func (a *InitFromCodeAction) resolveSelectedModelDeployment(
640640
return nil, exterrors.FromAiService(err, exterrors.CodeModelResolutionFailed)
641641
}
642642

643-
selector := &InitAction{
643+
selector := &modelSelector{
644644
azdClient: a.azdClient,
645645
azureContext: a.azureContext,
646646
environment: a.environment,

cli/azd/extensions/azure.ai.agents/internal/cmd/init_locations.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,12 @@ var supportedHostedAgentRegions = []string{
3939
func supportedRegionsForInit() []string {
4040
return slices.Clone(supportedHostedAgentRegions)
4141
}
42+
43+
// supportedModelLocations returns the intersection of a model's available locations
44+
// with the supported hosted agent regions.
45+
func supportedModelLocations(modelLocations []string) []string {
46+
supported := supportedRegionsForInit()
47+
return slices.DeleteFunc(slices.Clone(modelLocations), func(loc string) bool {
48+
return !locationAllowed(loc, supported)
49+
})
50+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package cmd
5+
6+
import (
7+
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestSupportedModelLocations(t *testing.T) {
13+
t.Parallel()
14+
15+
tests := []struct {
16+
name string
17+
modelLocations []string
18+
wantSubset bool
19+
wantLen int
20+
}{
21+
{
22+
name: "AllSupported",
23+
modelLocations: []string{"eastus", "westus"},
24+
wantSubset: true,
25+
wantLen: 2,
26+
},
27+
{
28+
name: "SomeUnsupported",
29+
modelLocations: []string{"eastus", "unsupportedregion"},
30+
wantSubset: true,
31+
wantLen: 1,
32+
},
33+
{
34+
name: "NoneSupported",
35+
modelLocations: []string{"unsupportedregion1", "unsupportedregion2"},
36+
wantSubset: true,
37+
wantLen: 0,
38+
},
39+
{
40+
name: "EmptyInput",
41+
modelLocations: []string{},
42+
wantSubset: true,
43+
wantLen: 0,
44+
},
45+
{
46+
name: "NilInput",
47+
modelLocations: nil,
48+
wantSubset: true,
49+
wantLen: 0,
50+
},
51+
}
52+
53+
supported := supportedRegionsForInit()
54+
55+
for _, tt := range tests {
56+
t.Run(tt.name, func(t *testing.T) {
57+
t.Parallel()
58+
59+
result := supportedModelLocations(tt.modelLocations)
60+
require.Len(t, result, tt.wantLen)
61+
62+
// Every returned location must be in the supported regions list
63+
for _, loc := range result {
64+
require.True(t, locationAllowed(loc, supported),
65+
"returned location %q should be in supported regions", loc)
66+
}
67+
})
68+
}
69+
}
70+
71+
func TestSupportedModelLocationsDoesNotMutateInput(t *testing.T) {
72+
t.Parallel()
73+
74+
input := []string{"eastus", "unsupportedregion", "westus"}
75+
original := make([]string, len(input))
76+
copy(original, input)
77+
78+
_ = supportedModelLocations(input)
79+
80+
require.Equal(t, original, input, "input slice should not be mutated")
81+
}

cli/azd/extensions/azure.ai.agents/internal/cmd/init_models.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626

2727
var defaultSkuPriority = []string{"GlobalStandard", "DataZoneStandard", "Standard"}
2828

29-
func (a *InitAction) loadAiCatalog(ctx context.Context) error {
29+
func (a *modelSelector) loadAiCatalog(ctx context.Context) error {
3030
if a.modelCatalog != nil {
3131
return nil
3232
}
@@ -69,9 +69,8 @@ func mapModelsByName(models []*azdext.AiModel) map[string]*azdext.AiModel {
6969
return modelMap
7070
}
7171

72-
func (a *InitAction) updateEnvLocation(ctx context.Context, selectedLocation string) error {
72+
func (a *modelSelector) updateEnvLocation(ctx context.Context, selectedLocation string) error {
7373
envName := ""
74-
var err error
7574
if a.environment != nil {
7675
envName = a.environment.Name
7776
} else {
@@ -82,7 +81,7 @@ func (a *InitAction) updateEnvLocation(ctx context.Context, selectedLocation str
8281
envName = envResponse.Environment.Name
8382
}
8483

85-
_, err = a.azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{
84+
_, err := a.azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{
8685
EnvName: envName,
8786
Key: "AZURE_LOCATION",
8887
Value: selectedLocation,
@@ -288,7 +287,7 @@ func (a *InitAction) getModelDeploymentDetails(ctx context.Context, model agent_
288287
}
289288
}
290289

291-
modelDetails, err := a.getModelDetails(ctx, model.Id)
290+
modelDetails, err := a.getModelSelector().getModelDetails(ctx, model.Id)
292291
if err != nil {
293292
return nil, fmt.Errorf("failed to get model details: %w", err)
294293
}
@@ -322,7 +321,7 @@ func (a *InitAction) getModelDeploymentDetails(ctx context.Context, model agent_
322321
}, nil
323322
}
324323

325-
func (a *InitAction) getModelDetails(ctx context.Context, modelName string) (*azdext.AiModelDeployment, error) {
324+
func (a *modelSelector) getModelDetails(ctx context.Context, modelName string) (*azdext.AiModelDeployment, error) {
326325
if err := a.loadAiCatalog(ctx); err != nil {
327326
return nil, err
328327
}
@@ -501,7 +500,7 @@ const (
501500
"2) Create a new Foundry project after changing regions."
502501
)
503502

504-
func (a *InitAction) promptForAlternativeModel(
503+
func (a *modelSelector) promptForAlternativeModel(
505504
ctx context.Context,
506505
originalModelName string,
507506
) (*azdext.AiModel, error) {
@@ -565,7 +564,7 @@ func (a *InitAction) promptForAlternativeModel(
565564
return modelResp.Model, nil
566565
}
567566

568-
func (a *InitAction) promptForModelLocationMismatch(
567+
func (a *modelSelector) promptForModelLocationMismatch(
569568
ctx context.Context,
570569
model *azdext.AiModel,
571570
currentLocation string,
@@ -626,7 +625,7 @@ func (a *InitAction) promptForModelLocationMismatch(
626625
&azdext.PromptAiModelLocationWithQuotaRequest{
627626
AzureContext: a.azureContext,
628627
ModelName: currentModel.Name,
629-
AllowedLocations: currentModel.Locations,
628+
AllowedLocations: supportedModelLocations(currentModel.Locations),
630629
Quota: &azdext.QuotaCheckOptions{
631630
MinRemainingCapacity: 1,
632631
},
@@ -677,7 +676,7 @@ func (a *InitAction) promptForModelLocationMismatch(
677676
&azdext.PromptAiModelLocationWithQuotaRequest{
678677
AzureContext: a.azureContext,
679678
ModelName: selectedModel.Name,
680-
AllowedLocations: selectedModel.Locations,
679+
AllowedLocations: supportedModelLocations(selectedModel.Locations),
681680
Quota: &azdext.QuotaCheckOptions{
682681
MinRemainingCapacity: 1,
683682
},

0 commit comments

Comments
 (0)