diff --git a/app/controlplane/pkg/biz/workflowcontract.go b/app/controlplane/pkg/biz/workflowcontract.go index e538cf37c..0615678a9 100644 --- a/app/controlplane/pkg/biz/workflowcontract.go +++ b/app/controlplane/pkg/biz/workflowcontract.go @@ -1,5 +1,5 @@ // -// Copyright 2024-2025 The Chainloop Authors. +// Copyright 2024-2026 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -500,6 +500,9 @@ func (uc *WorkflowContractUseCase) ValidatePolicyAttachment(providerName string, } if err = provider.ValidateAttachment(att, token); err != nil { + if errors.Is(err, policies.ErrUnauthorized) { + return NewErrUnauthorized(fmt.Errorf("invalid attachment: %w", err)) + } return fmt.Errorf("invalid attachment: %w", err) } @@ -695,6 +698,9 @@ func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, policyOrg if errors.Is(err, policies.ErrNotFound) { return nil, NewErrNotFound(fmt.Sprintf("policy %q", policyName)) } + if errors.Is(err, policies.ErrUnauthorized) { + return nil, NewErrUnauthorized(fmt.Errorf("failed to resolve policy %q: %w", policyName, err)) + } return nil, fmt.Errorf("failed to resolve policy: %w. Available providers: %s", err, uc.policyRegistry.GetProviderNames()) } @@ -716,6 +722,9 @@ func (uc *WorkflowContractUseCase) GetPolicyGroup(providerName, groupName, group if errors.Is(err, policies.ErrNotFound) { return nil, NewErrNotFound(fmt.Sprintf("policy group %q", groupName)) } + if errors.Is(err, policies.ErrUnauthorized) { + return nil, NewErrUnauthorized(fmt.Errorf("failed to resolve policy group %q: %w", groupName, err)) + } return nil, fmt.Errorf("failed to resolve policy: %w. Available providers: %s", err, uc.policyRegistry.GetProviderNames()) } diff --git a/app/controlplane/pkg/policies/policyprovider.go b/app/controlplane/pkg/policies/policyprovider.go index 88a8fc57e..b76d0cbb8 100644 --- a/app/controlplane/pkg/policies/policyprovider.go +++ b/app/controlplane/pkg/policies/policyprovider.go @@ -1,5 +1,5 @@ // -// Copyright 2024-2025 The Chainloop Authors. +// Copyright 2024-2026 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -78,7 +78,10 @@ type ProviderAuthOpts struct { OrgName string } -var ErrNotFound = fmt.Errorf("policy not found") +var ( + ErrNotFound = fmt.Errorf("policy not found") + ErrUnauthorized = fmt.Errorf("unauthorized request to policy provider") +) // Resolve calls the remote provider for retrieving a policy func (p *PolicyProvider) Resolve(policyName, policyOrgName string, authOpts ProviderAuthOpts) (*schemaapi.Policy, *PolicyReference, error) { @@ -147,12 +150,15 @@ func (p *PolicyProvider) ValidateAttachment(att *schemaapi.PolicyAttachment, tok } if resp.StatusCode != http.StatusOK { - if resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusMethodNotAllowed { + switch resp.StatusCode { + case http.StatusNotFound, http.StatusMethodNotAllowed: // Ignore endpoint not found as it might not be implemented by the provider return nil + case http.StatusUnauthorized, http.StatusForbidden: + return fmt.Errorf("%w: %s", ErrUnauthorized, readBodyMsg(resp)) + default: + return fmt.Errorf("expected status code 200 but got %d", resp.StatusCode) } - - return fmt.Errorf("expected status code 200 but got %d", resp.StatusCode) } resBytes, err := io.ReadAll(resp.Body) @@ -233,11 +239,14 @@ func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName string, aut } if resp.StatusCode != http.StatusOK { - if resp.StatusCode == http.StatusNotFound { + switch resp.StatusCode { + case http.StatusNotFound: return "", "", ErrNotFound + case http.StatusUnauthorized, http.StatusForbidden: + return "", "", fmt.Errorf("%w: %s", ErrUnauthorized, readBodyMsg(resp)) + default: + return "", "", fmt.Errorf("expected status code 200 but got %d", resp.StatusCode) } - - return "", "", fmt.Errorf("expected status code 200 but got %d", resp.StatusCode) } resBytes, err := io.ReadAll(resp.Body) @@ -278,6 +287,26 @@ func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName string, aut return response.Digest, orgName, nil } +// readBodyMsg reads the response body and extracts a human-readable error message. +// It tries to parse the body as JSON and extract the "reason" field (common in +// Connect/gRPC error responses). Falls back to the raw body, or the HTTP status text. +func readBodyMsg(resp *http.Response) string { + defer resp.Body.Close() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1024)) + if err != nil || len(body) == 0 { + return resp.Status + } + + var structured struct { + Reason string `json:"reason"` + } + if json.Unmarshal(body, &structured) == nil && structured.Reason != "" { + return structured.Reason + } + + return string(body) +} + func unmarshalFromRaw(raw *RawMessage, out proto.Message) error { var format unmarshal.RawFormat switch raw.Format { diff --git a/app/controlplane/pkg/policies/policyprovider_http_test.go b/app/controlplane/pkg/policies/policyprovider_http_test.go new file mode 100644 index 000000000..7e7327f3f --- /dev/null +++ b/app/controlplane/pkg/policies/policyprovider_http_test.go @@ -0,0 +1,235 @@ +// +// Copyright 2026 The Chainloop Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package policies + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolveHTTPStatusHandling(t *testing.T) { + testCases := []struct { + name string + statusCode int + body string + wantErr error + wantMsg string // substring expected in error message + }{ + { + name: "401 returns ErrUnauthorized with upstream message", + statusCode: http.StatusUnauthorized, + body: "token expired", + wantErr: ErrUnauthorized, + wantMsg: "token expired", + }, + { + name: "403 returns ErrUnauthorized with upstream message", + statusCode: http.StatusForbidden, + body: "insufficient permissions", + wantErr: ErrUnauthorized, + wantMsg: "insufficient permissions", + }, + { + name: "401 with JSON reason extracts reason field", + statusCode: http.StatusUnauthorized, + body: `{"level":"info","code":"unauthenticated","reason":"repository has no linked projects"}`, + wantErr: ErrUnauthorized, + wantMsg: "repository has no linked projects", + }, + { + name: "401 with empty body falls back to status text", + statusCode: http.StatusUnauthorized, + body: "", + wantErr: ErrUnauthorized, + wantMsg: "401 Unauthorized", + }, + { + name: "404 returns ErrNotFound", + statusCode: http.StatusNotFound, + wantErr: ErrNotFound, + }, + { + name: "500 returns generic error", + statusCode: http.StatusInternalServerError, + wantErr: nil, // generic error, not a sentinel + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tc.statusCode) + if tc.body != "" { + _, _ = w.Write([]byte(tc.body)) + } + })) + defer server.Close() + + provider := &PolicyProvider{ + name: "test", + url: server.URL, + } + + _, _, err := provider.Resolve("test-policy", "", ProviderAuthOpts{Token: "test-token"}) + require.Error(t, err) + + if tc.wantErr != nil { + assert.ErrorIs(t, err, tc.wantErr) + } else { + assert.NotErrorIs(t, err, ErrUnauthorized) + assert.NotErrorIs(t, err, ErrNotFound) + assert.Contains(t, err.Error(), fmt.Sprintf("got %d", tc.statusCode)) + } + if tc.wantMsg != "" { + assert.Contains(t, err.Error(), tc.wantMsg) + } + }) + } +} + +func TestResolveGroupHTTPStatusHandling(t *testing.T) { + testCases := []struct { + name string + statusCode int + body string + wantErr error + wantMsg string + }{ + { + name: "401 returns ErrUnauthorized with upstream message", + statusCode: http.StatusUnauthorized, + body: "invalid token", + wantErr: ErrUnauthorized, + wantMsg: "invalid token", + }, + { + name: "403 returns ErrUnauthorized with upstream message", + statusCode: http.StatusForbidden, + body: "access denied", + wantErr: ErrUnauthorized, + wantMsg: "access denied", + }, + { + name: "404 returns ErrNotFound", + statusCode: http.StatusNotFound, + wantErr: ErrNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tc.statusCode) + if tc.body != "" { + _, _ = w.Write([]byte(tc.body)) + } + })) + defer server.Close() + + provider := &PolicyProvider{ + name: "test", + url: server.URL, + } + + _, _, err := provider.ResolveGroup("test-group", "", ProviderAuthOpts{Token: "test-token"}) + require.Error(t, err) + assert.ErrorIs(t, err, tc.wantErr) + if tc.wantMsg != "" { + assert.Contains(t, err.Error(), tc.wantMsg) + } + }) + } +} + +func TestValidateAttachmentHTTPStatusHandling(t *testing.T) { + testCases := []struct { + name string + statusCode int + body string + wantErr error + wantMsg string + errNil bool + }{ + { + name: "401 returns ErrUnauthorized with upstream message", + statusCode: http.StatusUnauthorized, + body: "token revoked", + wantErr: ErrUnauthorized, + wantMsg: "token revoked", + }, + { + name: "403 returns ErrUnauthorized with upstream message", + statusCode: http.StatusForbidden, + body: "org mismatch", + wantErr: ErrUnauthorized, + wantMsg: "org mismatch", + }, + { + name: "404 is ignored", + statusCode: http.StatusNotFound, + errNil: true, + }, + { + name: "405 is ignored", + statusCode: http.StatusMethodNotAllowed, + errNil: true, + }, + { + name: "500 returns generic error", + statusCode: http.StatusInternalServerError, + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tc.statusCode) + if tc.body != "" { + _, _ = w.Write([]byte(tc.body)) + } + })) + defer server.Close() + + provider := &PolicyProvider{ + name: "test", + url: server.URL, + } + + err := provider.ValidateAttachment(nil, "test-token") + if tc.errNil { + assert.NoError(t, err) + return + } + + require.Error(t, err) + if tc.wantErr != nil { + assert.ErrorIs(t, err, tc.wantErr) + } else { + assert.NotErrorIs(t, err, ErrUnauthorized) + assert.Contains(t, err.Error(), fmt.Sprintf("got %d", tc.statusCode)) + } + if tc.wantMsg != "" { + assert.Contains(t, err.Error(), tc.wantMsg) + } + }) + } +}