Skip to content

Commit 5092af8

Browse files
committed
fix(policies): handle 401/403 from policy provider as unauthorized error
Previously, HTTP 401/403 responses from policy providers were returned as generic errors, falling through to LogAndMaskErr which masked them from the client. Now they surface as proper unauthorized errors with descriptive messages. Signed-off-by: Miguel Martinez Trivino <miguel@chainloop.dev>
1 parent b263756 commit 5092af8

3 files changed

Lines changed: 212 additions & 9 deletions

File tree

app/controlplane/pkg/biz/workflowcontract.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// Copyright 2024-2025 The Chainloop Authors.
2+
// Copyright 2024-2026 The Chainloop Authors.
33
//
44
// Licensed under the Apache License, Version 2.0 (the "License");
55
// you may not use this file except in compliance with the License.
@@ -500,6 +500,9 @@ func (uc *WorkflowContractUseCase) ValidatePolicyAttachment(providerName string,
500500
}
501501

502502
if err = provider.ValidateAttachment(att, token); err != nil {
503+
if errors.Is(err, policies.ErrUnauthorized) {
504+
return NewErrUnauthorized(fmt.Errorf("invalid attachment: %w", err))
505+
}
503506
return fmt.Errorf("invalid attachment: %w", err)
504507
}
505508

@@ -695,6 +698,9 @@ func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, policyOrg
695698
if errors.Is(err, policies.ErrNotFound) {
696699
return nil, NewErrNotFound(fmt.Sprintf("policy %q", policyName))
697700
}
701+
if errors.Is(err, policies.ErrUnauthorized) {
702+
return nil, NewErrUnauthorized(fmt.Errorf("failed to resolve policy %q: %w", policyName, err))
703+
}
698704

699705
return nil, fmt.Errorf("failed to resolve policy: %w. Available providers: %s", err, uc.policyRegistry.GetProviderNames())
700706
}
@@ -716,6 +722,9 @@ func (uc *WorkflowContractUseCase) GetPolicyGroup(providerName, groupName, group
716722
if errors.Is(err, policies.ErrNotFound) {
717723
return nil, NewErrNotFound(fmt.Sprintf("policy group %q", groupName))
718724
}
725+
if errors.Is(err, policies.ErrUnauthorized) {
726+
return nil, NewErrUnauthorized(fmt.Errorf("failed to resolve policy group %q: %w", groupName, err))
727+
}
719728

720729
return nil, fmt.Errorf("failed to resolve policy: %w. Available providers: %s", err, uc.policyRegistry.GetProviderNames())
721730
}

app/controlplane/pkg/policies/policyprovider.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// Copyright 2024-2025 The Chainloop Authors.
2+
// Copyright 2024-2026 The Chainloop Authors.
33
//
44
// Licensed under the Apache License, Version 2.0 (the "License");
55
// you may not use this file except in compliance with the License.
@@ -78,7 +78,10 @@ type ProviderAuthOpts struct {
7878
OrgName string
7979
}
8080

81-
var ErrNotFound = fmt.Errorf("policy not found")
81+
var (
82+
ErrNotFound = fmt.Errorf("policy not found")
83+
ErrUnauthorized = fmt.Errorf("unauthorized request to policy provider")
84+
)
8285

8386
// Resolve calls the remote provider for retrieving a policy
8487
func (p *PolicyProvider) Resolve(policyName, policyOrgName string, authOpts ProviderAuthOpts) (*schemaapi.Policy, *PolicyReference, error) {
@@ -147,12 +150,15 @@ func (p *PolicyProvider) ValidateAttachment(att *schemaapi.PolicyAttachment, tok
147150
}
148151

149152
if resp.StatusCode != http.StatusOK {
150-
if resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusMethodNotAllowed {
153+
switch resp.StatusCode {
154+
case http.StatusNotFound, http.StatusMethodNotAllowed:
151155
// Ignore endpoint not found as it might not be implemented by the provider
152156
return nil
157+
case http.StatusUnauthorized, http.StatusForbidden:
158+
return ErrUnauthorized
159+
default:
160+
return fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
153161
}
154-
155-
return fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
156162
}
157163

158164
resBytes, err := io.ReadAll(resp.Body)
@@ -233,11 +239,14 @@ func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName string, aut
233239
}
234240

235241
if resp.StatusCode != http.StatusOK {
236-
if resp.StatusCode == http.StatusNotFound {
242+
switch resp.StatusCode {
243+
case http.StatusNotFound:
237244
return "", "", ErrNotFound
245+
case http.StatusUnauthorized, http.StatusForbidden:
246+
return "", "", ErrUnauthorized
247+
default:
248+
return "", "", fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
238249
}
239-
240-
return "", "", fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
241250
}
242251

243252
resBytes, err := io.ReadAll(resp.Body)
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
//
2+
// Copyright 2026 The Chainloop Authors.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
package policies
17+
18+
import (
19+
"fmt"
20+
"net/http"
21+
"net/http/httptest"
22+
"testing"
23+
24+
"github.com/stretchr/testify/assert"
25+
"github.com/stretchr/testify/require"
26+
)
27+
28+
func TestResolveHTTPStatusHandling(t *testing.T) {
29+
testCases := []struct {
30+
name string
31+
statusCode int
32+
wantErr error
33+
}{
34+
{
35+
name: "401 returns ErrUnauthorized",
36+
statusCode: http.StatusUnauthorized,
37+
wantErr: ErrUnauthorized,
38+
},
39+
{
40+
name: "403 returns ErrUnauthorized",
41+
statusCode: http.StatusForbidden,
42+
wantErr: ErrUnauthorized,
43+
},
44+
{
45+
name: "404 returns ErrNotFound",
46+
statusCode: http.StatusNotFound,
47+
wantErr: ErrNotFound,
48+
},
49+
{
50+
name: "500 returns generic error",
51+
statusCode: http.StatusInternalServerError,
52+
wantErr: nil, // generic error, not a sentinel
53+
},
54+
}
55+
56+
for _, tc := range testCases {
57+
t.Run(tc.name, func(t *testing.T) {
58+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
59+
w.WriteHeader(tc.statusCode)
60+
}))
61+
defer server.Close()
62+
63+
provider := &PolicyProvider{
64+
name: "test",
65+
url: server.URL,
66+
}
67+
68+
_, _, err := provider.Resolve("test-policy", "", ProviderAuthOpts{Token: "test-token"})
69+
require.Error(t, err)
70+
71+
if tc.wantErr != nil {
72+
assert.ErrorIs(t, err, tc.wantErr)
73+
} else {
74+
assert.NotErrorIs(t, err, ErrUnauthorized)
75+
assert.NotErrorIs(t, err, ErrNotFound)
76+
assert.Contains(t, err.Error(), fmt.Sprintf("got %d", tc.statusCode))
77+
}
78+
})
79+
}
80+
}
81+
82+
func TestResolveGroupHTTPStatusHandling(t *testing.T) {
83+
testCases := []struct {
84+
name string
85+
statusCode int
86+
wantErr error
87+
}{
88+
{
89+
name: "401 returns ErrUnauthorized",
90+
statusCode: http.StatusUnauthorized,
91+
wantErr: ErrUnauthorized,
92+
},
93+
{
94+
name: "403 returns ErrUnauthorized",
95+
statusCode: http.StatusForbidden,
96+
wantErr: ErrUnauthorized,
97+
},
98+
{
99+
name: "404 returns ErrNotFound",
100+
statusCode: http.StatusNotFound,
101+
wantErr: ErrNotFound,
102+
},
103+
}
104+
105+
for _, tc := range testCases {
106+
t.Run(tc.name, func(t *testing.T) {
107+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
108+
w.WriteHeader(tc.statusCode)
109+
}))
110+
defer server.Close()
111+
112+
provider := &PolicyProvider{
113+
name: "test",
114+
url: server.URL,
115+
}
116+
117+
_, _, err := provider.ResolveGroup("test-group", "", ProviderAuthOpts{Token: "test-token"})
118+
require.Error(t, err)
119+
assert.ErrorIs(t, err, tc.wantErr)
120+
})
121+
}
122+
}
123+
124+
func TestValidateAttachmentHTTPStatusHandling(t *testing.T) {
125+
testCases := []struct {
126+
name string
127+
statusCode int
128+
wantErr error
129+
errNil bool
130+
}{
131+
{
132+
name: "401 returns ErrUnauthorized",
133+
statusCode: http.StatusUnauthorized,
134+
wantErr: ErrUnauthorized,
135+
},
136+
{
137+
name: "403 returns ErrUnauthorized",
138+
statusCode: http.StatusForbidden,
139+
wantErr: ErrUnauthorized,
140+
},
141+
{
142+
name: "404 is ignored",
143+
statusCode: http.StatusNotFound,
144+
errNil: true,
145+
},
146+
{
147+
name: "405 is ignored",
148+
statusCode: http.StatusMethodNotAllowed,
149+
errNil: true,
150+
},
151+
{
152+
name: "500 returns generic error",
153+
statusCode: http.StatusInternalServerError,
154+
wantErr: nil,
155+
},
156+
}
157+
158+
for _, tc := range testCases {
159+
t.Run(tc.name, func(t *testing.T) {
160+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
161+
w.WriteHeader(tc.statusCode)
162+
}))
163+
defer server.Close()
164+
165+
provider := &PolicyProvider{
166+
name: "test",
167+
url: server.URL,
168+
}
169+
170+
err := provider.ValidateAttachment(nil, "test-token")
171+
if tc.errNil {
172+
assert.NoError(t, err)
173+
return
174+
}
175+
176+
require.Error(t, err)
177+
if tc.wantErr != nil {
178+
assert.ErrorIs(t, err, tc.wantErr)
179+
} else {
180+
assert.NotErrorIs(t, err, ErrUnauthorized)
181+
assert.Contains(t, err.Error(), fmt.Sprintf("got %d", tc.statusCode))
182+
}
183+
})
184+
}
185+
}

0 commit comments

Comments
 (0)