Skip to content

Commit 1041f32

Browse files
authored
fix(policies): handle 401/403 from policy provider as unauthorized error (#2962)
Signed-off-by: Miguel Martinez Trivino <miguel@chainloop.dev>
1 parent 81d3701 commit 1041f32

3 files changed

Lines changed: 282 additions & 9 deletions

File tree

app/controlplane/pkg/biz/workflowcontract.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// Copyright 2024-2025 The Chainloop Authors.
2+
// Copyright 2024-2026 The Chainloop Authors.
33
//
44
// Licensed under the Apache License, Version 2.0 (the "License");
55
// you may not use this file except in compliance with the License.
@@ -500,6 +500,9 @@ func (uc *WorkflowContractUseCase) ValidatePolicyAttachment(providerName string,
500500
}
501501

502502
if err = provider.ValidateAttachment(att, token); err != nil {
503+
if errors.Is(err, policies.ErrUnauthorized) {
504+
return NewErrUnauthorized(fmt.Errorf("invalid attachment: %w", err))
505+
}
503506
return fmt.Errorf("invalid attachment: %w", err)
504507
}
505508

@@ -695,6 +698,9 @@ func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, policyOrg
695698
if errors.Is(err, policies.ErrNotFound) {
696699
return nil, NewErrNotFound(fmt.Sprintf("policy %q", policyName))
697700
}
701+
if errors.Is(err, policies.ErrUnauthorized) {
702+
return nil, NewErrUnauthorized(fmt.Errorf("failed to resolve policy %q: %w", policyName, err))
703+
}
698704

699705
return nil, fmt.Errorf("failed to resolve policy: %w. Available providers: %s", err, uc.policyRegistry.GetProviderNames())
700706
}
@@ -716,6 +722,9 @@ func (uc *WorkflowContractUseCase) GetPolicyGroup(providerName, groupName, group
716722
if errors.Is(err, policies.ErrNotFound) {
717723
return nil, NewErrNotFound(fmt.Sprintf("policy group %q", groupName))
718724
}
725+
if errors.Is(err, policies.ErrUnauthorized) {
726+
return nil, NewErrUnauthorized(fmt.Errorf("failed to resolve policy group %q: %w", groupName, err))
727+
}
719728

720729
return nil, fmt.Errorf("failed to resolve policy: %w. Available providers: %s", err, uc.policyRegistry.GetProviderNames())
721730
}

app/controlplane/pkg/policies/policyprovider.go

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// Copyright 2024-2025 The Chainloop Authors.
2+
// Copyright 2024-2026 The Chainloop Authors.
33
//
44
// Licensed under the Apache License, Version 2.0 (the "License");
55
// you may not use this file except in compliance with the License.
@@ -78,7 +78,10 @@ type ProviderAuthOpts struct {
7878
OrgName string
7979
}
8080

81-
var ErrNotFound = fmt.Errorf("policy not found")
81+
var (
82+
ErrNotFound = fmt.Errorf("policy not found")
83+
ErrUnauthorized = fmt.Errorf("unauthorized request to policy provider")
84+
)
8285

8386
// Resolve calls the remote provider for retrieving a policy
8487
func (p *PolicyProvider) Resolve(policyName, policyOrgName string, authOpts ProviderAuthOpts) (*schemaapi.Policy, *PolicyReference, error) {
@@ -147,12 +150,15 @@ func (p *PolicyProvider) ValidateAttachment(att *schemaapi.PolicyAttachment, tok
147150
}
148151

149152
if resp.StatusCode != http.StatusOK {
150-
if resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusMethodNotAllowed {
153+
switch resp.StatusCode {
154+
case http.StatusNotFound, http.StatusMethodNotAllowed:
151155
// Ignore endpoint not found as it might not be implemented by the provider
152156
return nil
157+
case http.StatusUnauthorized, http.StatusForbidden:
158+
return fmt.Errorf("%w: %s", ErrUnauthorized, readBodyMsg(resp))
159+
default:
160+
return fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
153161
}
154-
155-
return fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
156162
}
157163

158164
resBytes, err := io.ReadAll(resp.Body)
@@ -233,11 +239,14 @@ func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName string, aut
233239
}
234240

235241
if resp.StatusCode != http.StatusOK {
236-
if resp.StatusCode == http.StatusNotFound {
242+
switch resp.StatusCode {
243+
case http.StatusNotFound:
237244
return "", "", ErrNotFound
245+
case http.StatusUnauthorized, http.StatusForbidden:
246+
return "", "", fmt.Errorf("%w: %s", ErrUnauthorized, readBodyMsg(resp))
247+
default:
248+
return "", "", fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
238249
}
239-
240-
return "", "", fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
241250
}
242251

243252
resBytes, err := io.ReadAll(resp.Body)
@@ -278,6 +287,26 @@ func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName string, aut
278287
return response.Digest, orgName, nil
279288
}
280289

290+
// readBodyMsg reads the response body and extracts a human-readable error message.
291+
// It tries to parse the body as JSON and extract the "reason" field (common in
292+
// Connect/gRPC error responses). Falls back to the raw body, or the HTTP status text.
293+
func readBodyMsg(resp *http.Response) string {
294+
defer resp.Body.Close()
295+
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024))
296+
if err != nil || len(body) == 0 {
297+
return resp.Status
298+
}
299+
300+
var structured struct {
301+
Reason string `json:"reason"`
302+
}
303+
if json.Unmarshal(body, &structured) == nil && structured.Reason != "" {
304+
return structured.Reason
305+
}
306+
307+
return string(body)
308+
}
309+
281310
func unmarshalFromRaw(raw *RawMessage, out proto.Message) error {
282311
var format unmarshal.RawFormat
283312
switch raw.Format {
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
//
2+
// Copyright 2026 The Chainloop Authors.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
package policies
17+
18+
import (
19+
"fmt"
20+
"net/http"
21+
"net/http/httptest"
22+
"testing"
23+
24+
"github.com/stretchr/testify/assert"
25+
"github.com/stretchr/testify/require"
26+
)
27+
28+
func TestResolveHTTPStatusHandling(t *testing.T) {
29+
testCases := []struct {
30+
name string
31+
statusCode int
32+
body string
33+
wantErr error
34+
wantMsg string // substring expected in error message
35+
}{
36+
{
37+
name: "401 returns ErrUnauthorized with upstream message",
38+
statusCode: http.StatusUnauthorized,
39+
body: "token expired",
40+
wantErr: ErrUnauthorized,
41+
wantMsg: "token expired",
42+
},
43+
{
44+
name: "403 returns ErrUnauthorized with upstream message",
45+
statusCode: http.StatusForbidden,
46+
body: "insufficient permissions",
47+
wantErr: ErrUnauthorized,
48+
wantMsg: "insufficient permissions",
49+
},
50+
{
51+
name: "401 with JSON reason extracts reason field",
52+
statusCode: http.StatusUnauthorized,
53+
body: `{"level":"info","code":"unauthenticated","reason":"repository has no linked projects"}`,
54+
wantErr: ErrUnauthorized,
55+
wantMsg: "repository has no linked projects",
56+
},
57+
{
58+
name: "401 with empty body falls back to status text",
59+
statusCode: http.StatusUnauthorized,
60+
body: "",
61+
wantErr: ErrUnauthorized,
62+
wantMsg: "401 Unauthorized",
63+
},
64+
{
65+
name: "404 returns ErrNotFound",
66+
statusCode: http.StatusNotFound,
67+
wantErr: ErrNotFound,
68+
},
69+
{
70+
name: "500 returns generic error",
71+
statusCode: http.StatusInternalServerError,
72+
wantErr: nil, // generic error, not a sentinel
73+
},
74+
}
75+
76+
for _, tc := range testCases {
77+
t.Run(tc.name, func(t *testing.T) {
78+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
79+
w.WriteHeader(tc.statusCode)
80+
if tc.body != "" {
81+
_, _ = w.Write([]byte(tc.body))
82+
}
83+
}))
84+
defer server.Close()
85+
86+
provider := &PolicyProvider{
87+
name: "test",
88+
url: server.URL,
89+
}
90+
91+
_, _, err := provider.Resolve("test-policy", "", ProviderAuthOpts{Token: "test-token"})
92+
require.Error(t, err)
93+
94+
if tc.wantErr != nil {
95+
assert.ErrorIs(t, err, tc.wantErr)
96+
} else {
97+
assert.NotErrorIs(t, err, ErrUnauthorized)
98+
assert.NotErrorIs(t, err, ErrNotFound)
99+
assert.Contains(t, err.Error(), fmt.Sprintf("got %d", tc.statusCode))
100+
}
101+
if tc.wantMsg != "" {
102+
assert.Contains(t, err.Error(), tc.wantMsg)
103+
}
104+
})
105+
}
106+
}
107+
108+
func TestResolveGroupHTTPStatusHandling(t *testing.T) {
109+
testCases := []struct {
110+
name string
111+
statusCode int
112+
body string
113+
wantErr error
114+
wantMsg string
115+
}{
116+
{
117+
name: "401 returns ErrUnauthorized with upstream message",
118+
statusCode: http.StatusUnauthorized,
119+
body: "invalid token",
120+
wantErr: ErrUnauthorized,
121+
wantMsg: "invalid token",
122+
},
123+
{
124+
name: "403 returns ErrUnauthorized with upstream message",
125+
statusCode: http.StatusForbidden,
126+
body: "access denied",
127+
wantErr: ErrUnauthorized,
128+
wantMsg: "access denied",
129+
},
130+
{
131+
name: "404 returns ErrNotFound",
132+
statusCode: http.StatusNotFound,
133+
wantErr: ErrNotFound,
134+
},
135+
}
136+
137+
for _, tc := range testCases {
138+
t.Run(tc.name, func(t *testing.T) {
139+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
140+
w.WriteHeader(tc.statusCode)
141+
if tc.body != "" {
142+
_, _ = w.Write([]byte(tc.body))
143+
}
144+
}))
145+
defer server.Close()
146+
147+
provider := &PolicyProvider{
148+
name: "test",
149+
url: server.URL,
150+
}
151+
152+
_, _, err := provider.ResolveGroup("test-group", "", ProviderAuthOpts{Token: "test-token"})
153+
require.Error(t, err)
154+
assert.ErrorIs(t, err, tc.wantErr)
155+
if tc.wantMsg != "" {
156+
assert.Contains(t, err.Error(), tc.wantMsg)
157+
}
158+
})
159+
}
160+
}
161+
162+
func TestValidateAttachmentHTTPStatusHandling(t *testing.T) {
163+
testCases := []struct {
164+
name string
165+
statusCode int
166+
body string
167+
wantErr error
168+
wantMsg string
169+
errNil bool
170+
}{
171+
{
172+
name: "401 returns ErrUnauthorized with upstream message",
173+
statusCode: http.StatusUnauthorized,
174+
body: "token revoked",
175+
wantErr: ErrUnauthorized,
176+
wantMsg: "token revoked",
177+
},
178+
{
179+
name: "403 returns ErrUnauthorized with upstream message",
180+
statusCode: http.StatusForbidden,
181+
body: "org mismatch",
182+
wantErr: ErrUnauthorized,
183+
wantMsg: "org mismatch",
184+
},
185+
{
186+
name: "404 is ignored",
187+
statusCode: http.StatusNotFound,
188+
errNil: true,
189+
},
190+
{
191+
name: "405 is ignored",
192+
statusCode: http.StatusMethodNotAllowed,
193+
errNil: true,
194+
},
195+
{
196+
name: "500 returns generic error",
197+
statusCode: http.StatusInternalServerError,
198+
wantErr: nil,
199+
},
200+
}
201+
202+
for _, tc := range testCases {
203+
t.Run(tc.name, func(t *testing.T) {
204+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
205+
w.WriteHeader(tc.statusCode)
206+
if tc.body != "" {
207+
_, _ = w.Write([]byte(tc.body))
208+
}
209+
}))
210+
defer server.Close()
211+
212+
provider := &PolicyProvider{
213+
name: "test",
214+
url: server.URL,
215+
}
216+
217+
err := provider.ValidateAttachment(nil, "test-token")
218+
if tc.errNil {
219+
assert.NoError(t, err)
220+
return
221+
}
222+
223+
require.Error(t, err)
224+
if tc.wantErr != nil {
225+
assert.ErrorIs(t, err, tc.wantErr)
226+
} else {
227+
assert.NotErrorIs(t, err, ErrUnauthorized)
228+
assert.Contains(t, err.Error(), fmt.Sprintf("got %d", tc.statusCode))
229+
}
230+
if tc.wantMsg != "" {
231+
assert.Contains(t, err.Error(), tc.wantMsg)
232+
}
233+
})
234+
}
235+
}

0 commit comments

Comments
 (0)