-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathclient_manager.go
More file actions
382 lines (316 loc) · 11.3 KB
/
client_manager.go
File metadata and controls
382 lines (316 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package azureappconfiguration
import (
	"context"
	"errors"
	"fmt"
	"log"
	"math"
	"math/rand"
	"net"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig"
)
// configurationClientManager owns the App Configuration clients used by the
// provider: one static client for the configured (primary) endpoint and, when
// replica discovery is enabled, a list of dynamic clients for replica hosts
// discovered through DNS SRV records.
//
// NOTE(review): dynamicClients and lastFallbackClientRefresh are assigned by
// processSrvTargetHosts, which runs on the goroutine started in
// discoverFallbackClients, while getClients reads them on the caller's
// goroutine. No lock guards these fields, so concurrent use is a data race.
// A mutex or atomic pointer would be needed to fix this.
type configurationClientManager struct {
// replicaDiscoveryEnabled gates SRV-based replica discovery. It is true unless
// Options.ReplicaDiscoveryEnabled is explicitly set to false.
replicaDiscoveryEnabled bool
// clientOptions is applied to every client created by this manager (static
// and dynamic). It already carries the telemetry defaults from setTelemetry.
clientOptions *azappconfig.ClientOptions
// staticClient is the client for the primary endpoint from the auth options.
staticClient *configurationClientWrapper
// dynamicClients holds replica clients. It is nil until the first discovery
// completes; after that it is replaced wholesale on each discovery.
dynamicClients []*configurationClientWrapper
// endpoint is the primary endpoint URL (from the connection string or
// AuthenticationOptions.Endpoint).
endpoint string
// validDomain is the trusted domain suffix derived from endpoint. Discovered
// replica hosts must end with it. An empty value rejects every replica.
validDomain string
// credential is set only when the manager was built from Endpoint+Credential
// (nil for connection strings). newConfigurationClient prefers it for replicas.
credential azcore.TokenCredential
// secret and id come from the connection string. They are reused to build
// replica connection strings when no credential is set.
secret string
id string
// lastFallbackClientAttempt is when replica discovery was last started.
// Attempts are throttled to one per minimalClientRefreshInterval.
lastFallbackClientAttempt time.Time
// lastFallbackClientRefresh is when dynamicClients was last replaced.
lastFallbackClientRefresh time.Time
}
// configurationClientWrapper pairs an Azure App Configuration client with the
// endpoint it targets and its per-client failure/backoff state.
type configurationClientWrapper struct {
// endpoint is the URL the client was created for.
endpoint string
// client is the underlying SDK client.
client *azappconfig.Client
// backOffEndTime is when the client becomes eligible again. getClients skips
// the client until the current time is after it. The zero value means no
// backoff.
backOffEndTime time.Time
// failedAttempts counts consecutive failures reported via
// updateBackoffStatus. A success resets it to 0.
failedAttempts int
}
// clientManager is the set of client-management operations the provider
// relies on. *configurationClientManager in this file has methods matching
// both signatures.
type clientManager interface {
// getClients returns the clients that are currently eligible for requests.
getClients(ctx context.Context) ([]*configurationClientWrapper, error)
// refreshClients asks the manager to refresh its client list. For
// configurationClientManager this starts a background replica rediscovery.
refreshClients(ctx context.Context)
}
// newConfigurationClientManager creates a client manager for the store
// described by authOptions.
//
// Replica discovery is enabled unless options.ReplicaDiscoveryEnabled is
// explicitly false. options.ClientOptions, with the telemetry defaults added
// by setTelemetry, is used for every client the manager creates. A nil
// options is treated as the zero Options value instead of panicking.
//
// The returned error wraps any failure from initializeClient.
func newConfigurationClientManager(authOptions AuthenticationOptions, options *Options) (*configurationClientManager, error) {
	if options == nil {
		options = &Options{}
	}

	manager := &configurationClientManager{
		clientOptions:           setTelemetry(options.ClientOptions),
		replicaDiscoveryEnabled: options.ReplicaDiscoveryEnabled == nil || *options.ReplicaDiscoveryEnabled,
	}

	if err := manager.initializeClient(authOptions); err != nil {
		return nil, fmt.Errorf("failed to initialize configuration client: %w", err)
	}
	return manager, nil
}
// initializeClient creates the primary ("static") client from authOptions. It
// also records the authentication material that newConfigurationClient later
// reuses for replica clients.
//
// A non-empty authOptions.ConnectionString takes precedence. Its Endpoint,
// Secret and Id values are extracted, in that order, and the client is built
// from the connection string. Otherwise the client is built from
// authOptions.Endpoint and authOptions.Credential, and those two values are
// stored on the manager. Parse and client-construction errors are returned
// unwrapped.
func (manager *configurationClientManager) initializeClient(authOptions AuthenticationOptions) error {
	var (
		primary *azappconfig.Client
		err     error
	)

	if connectionString := authOptions.ConnectionString; connectionString != "" {
		// Each value is parsed into its manager field. The first missing one
		// aborts initialization.
		fields := []struct {
			key    string
			target *string
		}{
			{endpointKey, &manager.endpoint},
			{secretKey, &manager.secret},
			{idKey, &manager.id},
		}
		for _, field := range fields {
			if *field.target, err = parseConnectionString(connectionString, field.key); err != nil {
				return err
			}
		}
		if primary, err = azappconfig.NewClientFromConnectionString(connectionString, manager.clientOptions); err != nil {
			return err
		}
	} else {
		if primary, err = azappconfig.NewClient(authOptions.Endpoint, authOptions.Credential, manager.clientOptions); err != nil {
			return err
		}
		manager.endpoint = authOptions.Endpoint
		manager.credential = authOptions.Credential
	}

	manager.validDomain = getValidDomain(manager.endpoint)
	manager.staticClient = &configurationClientWrapper{endpoint: manager.endpoint, client: primary}
	return nil
}
// getClients returns the clients that may currently be used, in priority
// order:
//   - the primary (static) client first;
//   - then any discovered replica (dynamic) clients.
//
// A client whose backOffEndTime has not yet passed is omitted. The returned
// error is always nil.
//
// When replica discovery is enabled, the call may also start a background SRV
// discovery via discoverFallbackClients. It does so at most once per
// minimalClientRefreshInterval, and only when either:
//   - no replica list has been produced yet, or
//   - the current list is older than fallbackClientRefreshExpireInterval.
//
// Discovery is asynchronous, so its result is only visible to later calls.
// If the endpoint cannot be parsed or has no host, discovery is skipped and
// logged rather than panicking. The port, if any, is stripped before the SRV
// lookup.
//
// NOTE(review): dynamicClients and lastFallbackClientRefresh are written by
// the discovery goroutine and read here without synchronization.
func (manager *configurationClientManager) getClients(ctx context.Context) ([]*configurationClientWrapper, error) {
	now := time.Now()
	clients := make([]*configurationClientWrapper, 0, 1+len(manager.dynamicClients))

	// The primary client is always preferred unless it is backing off.
	if now.After(manager.staticClient.backOffEndTime) {
		clients = append(clients, manager.staticClient)
	}

	if !manager.replicaDiscoveryEnabled {
		return clients, nil
	}

	attemptAllowed := now.After(manager.lastFallbackClientAttempt.Add(minimalClientRefreshInterval))
	// A nil slice means discovery has never completed. An empty non-nil slice
	// means it completed and found no usable replicas.
	replicasStale := manager.dynamicClients == nil ||
		now.After(manager.lastFallbackClientRefresh.Add(fallbackClientRefreshExpireInterval))
	if attemptAllowed && replicasStale {
		manager.lastFallbackClientAttempt = now
		endpointURL, err := url.Parse(manager.endpoint)
		switch {
		case err != nil:
			log.Printf("skipping replica discovery: invalid endpoint %q: %v", manager.endpoint, err)
		case endpointURL.Hostname() == "":
			log.Printf("skipping replica discovery: endpoint %q has no host", manager.endpoint)
		default:
			manager.discoverFallbackClients(endpointURL.Hostname())
		}
	}

	for _, replica := range manager.dynamicClients {
		if now.After(replica.backOffEndTime) {
			clients = append(clients, replica)
		}
	}
	return clients, nil
}
// refreshClients starts a replica rediscovery when replica discovery is
// enabled and at least minimalClientRefreshInterval has passed since the last
// attempt. Unlike getClients, it does not wait for the current replica list to
// expire. Discovery runs on a background goroutine, so this call does not
// block on it.
//
// If the endpoint cannot be parsed or has no host, discovery is skipped and
// logged rather than panicking. The port, if any, is stripped before the SRV
// lookup.
func (manager *configurationClientManager) refreshClients(ctx context.Context) {
	if !manager.replicaDiscoveryEnabled {
		return
	}

	now := time.Now()
	if !now.After(manager.lastFallbackClientAttempt.Add(minimalClientRefreshInterval)) {
		return
	}
	manager.lastFallbackClientAttempt = now

	endpointURL, err := url.Parse(manager.endpoint)
	switch {
	case err != nil:
		log.Printf("skipping replica discovery: invalid endpoint %q: %v", manager.endpoint, err)
	case endpointURL.Hostname() == "":
		log.Printf("skipping replica discovery: endpoint %q has no host", manager.endpoint)
	default:
		manager.discoverFallbackClients(endpointURL.Hostname())
	}
}
// discoverFallbackClients launches a background goroutine that resolves the
// replica hosts for host and installs them via processSrvTargetHosts. The
// lookup is bounded by failoverTimeout. Lookup errors are logged and leave the
// existing replica list untouched. A panic inside the goroutine is recovered
// and logged so it cannot crash the process. The call itself returns
// immediately.
//
// NOTE(review): processSrvTargetHosts mutates manager fields from this
// goroutine without a lock; see getClients.
func (manager *configurationClientManager) discoverFallbackClients(host string) {
	discover := func() {
		defer func() {
			if recovered := recover(); recovered != nil {
				log.Printf("panic in replica discovery: %v", recovered)
			}
		}()

		lookupCtx, cancelLookup := context.WithTimeout(context.Background(), failoverTimeout)
		defer cancelLookup()

		targets, err := querySrvTargetHost(lookupCtx, host)
		if err != nil {
			log.Printf("failed to discover fallback clients for %s: %v", host, err)
			return
		}
		manager.processSrvTargetHosts(targets)
	}

	go discover()
}
// processSrvTargetHosts replaces manager.dynamicClients with one client per
// usable replica host.
//
// The host slice is shuffled in place first, to spread load across replicas.
// Hosts are then filtered as follows:
//   - hosts outside manager.validDomain are skipped;
//   - the primary endpoint itself is skipped;
//   - a replica whose client cannot be created is logged and skipped.
//
// lastFallbackClientRefresh is stamped after the new list is installed, even
// when the list is empty.
func (manager *configurationClientManager) processSrvTargetHosts(srvTargetHosts []string) {
	rand.Shuffle(len(srvTargetHosts), func(a, b int) {
		srvTargetHosts[a], srvTargetHosts[b] = srvTargetHosts[b], srvTargetHosts[a]
	})

	replicas := make([]*configurationClientWrapper, 0, len(srvTargetHosts))
	for _, host := range srvTargetHosts {
		if !isValidEndpoint(host, manager.validDomain) {
			continue // not on a trusted domain
		}

		replicaEndpoint := "https://" + host
		if strings.EqualFold(replicaEndpoint, manager.endpoint) {
			continue // the primary is already served by staticClient
		}

		replicaClient, err := manager.newConfigurationClient(replicaEndpoint)
		if err != nil {
			// One bad replica should not prevent using the others.
			log.Printf("failed to create client for replica %s: %v", replicaEndpoint, err)
			continue
		}

		replicas = append(replicas, &configurationClientWrapper{
			endpoint: replicaEndpoint,
			client:   replicaClient,
		})
	}

	manager.dynamicClients = replicas
	manager.lastFallbackClientRefresh = time.Now()
}
// querySrvTargetHost resolves the replica hosts advertised for host through
// DNS SRV records.
//
// The lookup proceeds in two stages:
//   - First it looks up the origin record (service originKey, protocol tcpKey)
//     on host. If there is none, an empty slice is returned.
//   - Otherwise the origin target becomes the first result. Alternate records
//     altKey+"0", altKey+"1", ... are then queried on the origin host until a
//     lookup reports not-found.
//
// Trailing dots are trimmed from targets, and empty alternate targets are
// dropped.
//
// A not-found DNS answer means "no (more) records" rather than an error. Any
// other lookup failure is returned together with the hosts collected so far.
// Every lookup is issued with ctx, so a deadline on ctx bounds the whole walk.
func querySrvTargetHost(ctx context.Context, host string) ([]string, error) {
	results := make([]string, 0)
	var dnsErr *net.DNSError

	_, originRecords, err := net.DefaultResolver.LookupSRV(ctx, originKey, tcpKey, host)
	if err != nil {
		// No origin SRV record: the store has no replicas.
		if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
			return results, nil
		}
		return results, err
	}
	if len(originRecords) == 0 {
		return results, nil
	}

	originHost := strings.TrimSuffix(originRecords[0].Target, ".")
	results = append(results, originHost)

	for index := 0; ; index++ {
		_, altRecords, err := net.DefaultResolver.LookupSRV(ctx, altKey+strconv.Itoa(index), tcpKey, originHost)
		if err != nil {
			// The first missing index terminates the sequence.
			if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
				break
			}
			return results, err
		}
		for _, record := range altRecords {
			if altHost := strings.TrimSuffix(record.Target, "."); altHost != "" {
				results = append(results, altHost)
			}
		}
	}
	return results, nil
}
// newConfigurationClient creates a client for a replica endpoint with the same
// authentication material as the primary client:
//   - the stored token credential, when one was supplied;
//   - otherwise, a connection string rebuilt from the stored Id and Secret.
//
// It returns an error when no credential is set and the Id or Secret is empty.
func (manager *configurationClientManager) newConfigurationClient(endpoint string) (*azappconfig.Client, error) {
	if manager.credential == nil {
		connectionStr := buildConnectionString(endpoint, manager.secret, manager.id)
		if connectionStr == "" {
			return nil, fmt.Errorf("failed to build connection string for fallback client")
		}
		return azappconfig.NewClientFromConnectionString(connectionStr, manager.clientOptions)
	}
	return azappconfig.NewClient(endpoint, manager.credential, manager.clientOptions)
}
func buildConnectionString(endpoint string, secret string, id string) string {
if secret == "" || id == "" {
return ""
}
return fmt.Sprintf("%s=%s;%s=%s;%s=%s",
endpointKey, endpoint,
idKey, id,
secretKey, secret)
}
// parseConnectionString returns the value for key token in a connection string
// of the form "Key1=Value1;Key2=Value2;...".
//
// The string is split on ';' and each segment is split on its first '=', so
// values may themselves contain '=' (e.g. base64 padding in Secret). Keys are
// compared exactly, after trimming surrounding whitespace. The first matching
// segment wins. Matching whole keys, rather than searching for the substring
// "token=", prevents a key such as "ClientId" from satisfying a lookup for
// "Id".
//
// It returns an error if connectionString is empty or no segment has the
// requested key.
func parseConnectionString(connectionString string, token string) (string, error) {
	if connectionString == "" {
		return "", fmt.Errorf("connectionString cannot be empty")
	}

	for _, segment := range strings.Split(connectionString, ";") {
		key, value, found := strings.Cut(segment, "=")
		if found && strings.TrimSpace(key) == token {
			return value, nil
		}
	}
	return "", fmt.Errorf("missing %s in connection string", token)
}
// setTelemetry returns client options with a default telemetry ApplicationID
// of "<moduleName>/<moduleVersion>". The default is applied only when
// telemetry is enabled and no ApplicationID was supplied. A nil options is
// treated as the zero value.
//
// The caller's struct is never modified. A shallow copy is returned, so a
// ClientOptions value shared with other SDK clients does not silently pick up
// this module's ApplicationID. Slice and pointer fields in the copy still alias
// the caller's.
func setTelemetry(options *azappconfig.ClientOptions) *azappconfig.ClientOptions {
	var opts azappconfig.ClientOptions
	if options != nil {
		opts = *options
	}
	if !opts.Telemetry.Disabled && opts.Telemetry.ApplicationID == "" {
		opts.Telemetry = policy.TelemetryOptions{
			ApplicationID: fmt.Sprintf("%s/%s", moduleName, moduleVersion),
		}
	}
	return &opts
}
// getValidDomain returns the trusted domain suffix of endpoint's host. The
// suffix is lowercased and starts at the last occurrence of azConfigDomainLabel
// or appConfigDomainLabel, checked in that order. Replica hosts are later
// required to end with this suffix (see isValidEndpoint).
//
// The port is ignored. It returns "" when endpoint cannot be parsed or its host
// contains neither label. Previously a parse error caused a nil-pointer panic.
// Slicing the lowercased host keeps the index consistent even when lowercasing
// changes byte lengths.
func getValidDomain(endpoint string) string {
	parsed, err := url.Parse(endpoint)
	if err != nil {
		return ""
	}
	host := strings.ToLower(parsed.Hostname())
	for _, label := range []string{azConfigDomainLabel, appConfigDomainLabel} {
		if index := strings.LastIndex(host, strings.ToLower(label)); index != -1 {
			return host[index:]
		}
	}
	return ""
}
// isValidEndpoint reports whether host lies under validDomain, compared
// case-insensitively. An empty validDomain means the primary endpoint was not
// on a trusted domain, so no host is accepted.
func isValidEndpoint(host string, validDomain string) bool {
	return validDomain != "" &&
		strings.HasSuffix(strings.ToLower(host), strings.ToLower(validDomain))
}
// updateBackoffStatus records the outcome of a request made with client.
//   - On failure it increments failedAttempts and pushes backOffEndTime to now
//     plus calculateBackoffDuration(failedAttempts).
//   - On success it clears both, making the client immediately eligible again.
func (client *configurationClientWrapper) updateBackoffStatus(success bool) {
	if !success {
		client.failedAttempts++
		client.backOffEndTime = time.Now().Add(calculateBackoffDuration(client.failedAttempts))
		return
	}
	client.failedAttempts = 0
	client.backOffEndTime = time.Time{}
}
// calculateBackoffDuration returns the backoff after failedAttempts
// consecutive failures.
//   - One failure (or fewer) yields minBackoffDuration exactly, without jitter.
//   - Otherwise the delay is minBackoffDuration * 2^(failedAttempts-1),
//     computed in milliseconds with the exponent capped at safeShiftLimit.
//   - That delay is clamped to maxBackoffDuration (also when the computation
//     yields a non-positive value) and then randomized by jitter.
func calculateBackoffDuration(failedAttempts int) time.Duration {
	if failedAttempts <= 1 {
		return minBackoffDuration
	}

	ceilingMs := float64(maxBackoffDuration.Milliseconds())
	exponent := math.Min(float64(failedAttempts-1), float64(safeShiftLimit))
	backoffMs := float64(minBackoffDuration.Milliseconds()) * math.Pow(2, exponent)
	if backoffMs <= 0 || backoffMs > ceilingMs {
		backoffMs = ceilingMs
	}

	return jitter(time.Duration(backoffMs) * time.Millisecond)
}
// getFixedBackoffDuration returns a stepped backoff based on how long the
// caller has been retrying:
//   - under 100s: 5s
//   - under 200s: 10s
//   - under 600s: minBackoffDuration
//   - otherwise: 0
//
// NOTE(review): presumably a 0 result tells callers to switch to another
// backoff strategy; confirm at the call site.
func getFixedBackoffDuration(timeElapsed time.Duration) time.Duration {
	switch {
	case timeElapsed < 100*time.Second:
		return 5 * time.Second
	case timeElapsed < 200*time.Second:
		return 10 * time.Second
	case timeElapsed < 600*time.Second:
		return minBackoffDuration
	default:
		return 0
	}
}
// jitter randomizes duration by up to ±jitterRatio of its value. The offset is
// drawn uniformly from [-duration*jitterRatio, +duration*jitterRatio) using the
// math/rand global source.
func jitter(duration time.Duration) time.Duration {
	spread := float64(duration) * jitterRatio
	offset := rand.Float64()*(2*spread) - spread
	return duration + time.Duration(offset)
}