Skip to content

Commit 0d08d86

Browse files
committed
adding hVaultErr
- custom error struct hVaultErr added. Vault errors processed and stored in this type - string comparison replaced with errors.Is - unit test added and dockerfiles updated
1 parent 2fa5117 commit 0d08d86

5 files changed

Lines changed: 394 additions & 6 deletions

File tree

Containerfile-kleidi-kms-hsm

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ RUN go mod download
1212

1313
RUN CGO_ENABLED=1 GO111MODULE=on go build -ldflags "-X main.kleidiVersion=$VERSION" -a -installsuffix cgo cmd/kleidi/main.go
1414

15+
RUN go test -v ./...
16+
1517
FROM quay.io/centos/centos:stream9
1618
RUN dnf -y install jq opensc softhsm; dnf clean all;
1719

Containerfile-kleidi-kms-vault

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ RUN go mod download
1212

1313
RUN CGO_ENABLED=1 GO111MODULE=on go build -ldflags "-X main.kleidiVersion=$VERSION" -a -installsuffix cgo cmd/kleidi/main.go
1414

15+
RUN go test -v ./...
16+
1517
FROM quay.io/centos/centos:stream9
1618

1719
LABEL org.opencontainers.image.source=https://github.com/beezy-dev/kleidi

internal/providers/hvault.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"fmt"
1414
"os"
1515
"strconv"
16-
"strings"
1716
"time"
1817

1918
"github.com/hashicorp/vault/api"
@@ -44,9 +43,10 @@ type hvaultRemoteService struct {
4443
func fatalOrErr(err error) error {
4544
// it can happen that token gets ivalidated - shutdown in these cases
4645
// for others it just "flows through"
47-
if strings.Contains(err.Error(), "invalid token") {
48-
zap.L().Fatal("EXIT:token: invalid token, restarting: " + err.Error())
49-
return err
46+
wrappedErr := WrapVaultError(err.Error())
47+
if errors.Is(wrappedErr, ErrInvalidToken) {
48+
zap.L().Fatal("EXIT:token: invalid token, restarting.")
49+
return err
5050
}
5151
return err
5252
}
@@ -360,13 +360,17 @@ func retryVaultOp[T any](s *hvaultRemoteService, ctx context.Context, amount int
360360
default:
361361
result, err = f()
362362
if err != nil {
363-
if strings.Contains(err.Error(), "invalid token") {
363+
zap.L().Error("Got error: " + err.Error())
364+
wrappedErr := WrapVaultError(err.Error())
365+
if errors.Is(wrappedErr, ErrInvalidToken) {
364366
// re-login
365367
_, err := s.Client.Auth().Login(ctx, s.ClientAuthMethod)
366368
if err != nil {
367369
zap.L().Error("Error: Could not relogin: " + err.Error())
370+
} else {
371+
// relogin OK
372+
zap.L().Debug("Relogin succesful.")
368373
}
369-
// relogin OK
370374
} // other error that cannot be solved by relogin: try calling f() again
371375
} else {
372376
// no error, no need to retry

internal/providers/hvaulterr.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package providers
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
"strings"
7+
"strconv"
8+
)
9+
10+
// The hVaultErr struct wraps the Vault string-based error,
11+
// providing structured fields for reliable error matching and inspection.
12+
// It implements the `error` interface.
13+
type hVaultErr struct {
14+
// originalError holds the full, original error string for context.
15+
originalError string
16+
17+
// Extracted fields from the error string.
18+
StatusCode int
19+
Method string
20+
URL string
21+
Namespace string
22+
Messages []string
23+
}
24+
25+
// Error implements the `error` interface, returning the original error string.
26+
func (e *hVaultErr) Error() string {
27+
return e.originalError
28+
}
29+
30+
// Unwrap returns a new error with the original wrapped error.
31+
func (e *hVaultErr) Unwrap() error {
32+
return fmt.Errorf("%s", e.originalError)
33+
}
34+
35+
// Is allows errors.Is to work directly with the hVaultErr type.
36+
// It checks if the wrapped error matches the target sentinel error based on
37+
// its status code and message.
38+
func (e *hVaultErr) Is(target error) bool {
39+
// Check if the target is a *hVaultErr.
40+
targetErr, ok := target.(*hVaultErr)
41+
if !ok {
42+
return false
43+
}
44+
45+
// Check if the status codes match.
46+
if e.StatusCode != targetErr.StatusCode {
47+
return false
48+
}
49+
50+
// If the target error doesn't have an original message,
51+
// a status code match is sufficient.
52+
if targetErr.originalError == "" {
53+
return true
54+
}
55+
56+
// If the target error has an original message,
57+
// check if any of the parsed messages in the receiver error
58+
// contain the target's original message.
59+
for _, msg := range e.Messages {
60+
if strings.Contains(strings.ToLower(msg), strings.ToLower(targetErr.originalError)) {
61+
return true
62+
}
63+
}
64+
65+
return false
66+
}
67+
68+
// Pre-defined constant errors for specific conditions.
69+
var (
70+
ErrInvalidToken = &hVaultErr{StatusCode: 403, originalError: "invalid token"}
71+
ErrVaultSealed = &hVaultErr{StatusCode: 503, originalError: "Vault is sealed"}
72+
)
73+
74+
// WrapVaultError parses a raw Vault error string and wraps it
75+
// in a structured hVaultErr. It uses regular expressions to
76+
// extract key information.
77+
func WrapVaultError(errString string) error {
78+
re := regexp.MustCompile(
79+
`URL: (\S+) (\S+)\s*` +
80+
`Code: (\d+)\. .*?:\s*` +
81+
`(?s)(.*)`)
82+
83+
match := re.FindStringSubmatch(errString)
84+
if len(match) < 4 {
85+
// If parsing fails, just return a generic wrapped error.
86+
return fmt.Errorf("failed to parse Vault error string: %w", fmt.Errorf("%s", errString))
87+
}
88+
89+
// Extract the components from the regex match.
90+
method := match[1]
91+
url := match[2]
92+
statusCode, err := strconv.Atoi(match[3])
93+
if err != nil {
94+
// If parsing fails, default the status code to 0.
95+
// This is a safer alternative to a potential Sscanf panic or unexpected behavior.
96+
statusCode = 0
97+
}
98+
99+
errorBody := match[4]
100+
101+
// Extract the namespace, which is optional.
102+
namespaceRe := regexp.MustCompile(`Namespace: (.+)\n`)
103+
namespaceMatch := namespaceRe.FindStringSubmatch(errString)
104+
namespace := ""
105+
if len(namespaceMatch) > 1 {
106+
namespace = strings.TrimSpace(namespaceMatch[1])
107+
}
108+
109+
// Parse the individual error messages from the error body.
110+
var messages []string
111+
if strings.Contains(errorBody, "* ") {
112+
// Handles the case with multiple errors
113+
messageLines := strings.Split(strings.TrimSpace(errorBody), "\n")
114+
// Filter out the "* n error(s) occurred:" line and trim "* "
115+
for _, line := range messageLines {
116+
trimmedLine := strings.TrimSpace(line)
117+
if trimmedLine == "" {
118+
continue
119+
}
120+
if strings.HasPrefix(trimmedLine, "* ") {
121+
// Check for the "n errors occurred" or "1 error occurred" line and skip it
122+
if strings.HasSuffix(trimmedLine, " errors occurred:") || strings.HasSuffix(trimmedLine, " error occurred:") {
123+
continue
124+
}
125+
messages = append(messages, strings.TrimPrefix(trimmedLine, "* "))
126+
} else {
127+
messages = append(messages, trimmedLine)
128+
}
129+
}
130+
} else {
131+
// Handles the case with a single raw message
132+
messages = append(messages, strings.TrimSpace(errorBody))
133+
}
134+
135+
return &hVaultErr{
136+
originalError: errString,
137+
StatusCode: statusCode,
138+
Method: method,
139+
URL: url,
140+
Namespace: namespace,
141+
Messages: messages,
142+
}
143+
}

0 commit comments

Comments
 (0)