Skip to content

Commit dda497e

Browse files
authored
Merge pull request #420 from dotnet/bugfix/419-ping-blocked
Put ping check behind configuration
2 parents 958ef49 + 8a66990 commit dda497e

7 files changed

Lines changed: 258 additions & 58 deletions

File tree

Kerberos.NET/Client/Transport/ClientDomainService.cs

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
using System.Collections.Generic;
44
using System.Globalization;
55
using System.Linq;
6+
using System.Net.NetworkInformation;
7+
using System.Threading;
68
using System.Threading.Tasks;
79
using Kerberos.NET.Configuration;
810
using Kerberos.NET.Dns;
@@ -13,6 +15,8 @@ namespace Kerberos.NET.Transport
1315
{
1416
public class ClientDomainService
1517
{
18+
private static readonly Random Random = new();
19+
1620
public ClientDomainService(ILoggerFactory logger)
1721
{
1822
this.logger = logger.CreateLoggerSafe<ClientDomainService>();
@@ -47,6 +51,13 @@ static ClientDomainService()
4751

4852
public Krb5Config Configuration { get; set; }
4953

54+
public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(2);
55+
56+
public TimeSpan SendTimeout { get; set; } = TimeSpan.FromSeconds(10);
57+
58+
public TimeSpan ReceiveTimeout { get; set; } = TimeSpan.FromSeconds(10);
59+
60+
5061
public void ResetConnections()
5162
{
5263
DomainCache.Clear();
@@ -59,7 +70,37 @@ public virtual async Task<IEnumerable<DnsRecord>> LocateKdc(string domain, strin
5970
{
6071
var results = await this.Query(domain, servicePrefix, DefaultKerberosPort);
6172

62-
return ParseQuerySrvReply(results);
73+
results = ParseQuerySrvReply(results);
74+
75+
return await WeightResults(results);
76+
}
77+
78+
private async Task<IEnumerable<DnsRecord>> WeightResults(IEnumerable<DnsRecord> results)
79+
{
80+
SortedList<int, DnsRecord> fastest = new();
81+
82+
if (this.Configuration.Defaults.PrioritizeKdcByPing)
83+
{
84+
try
85+
{
86+
using var cts = new CancellationTokenSource(this.ConnectTimeout);
87+
88+
fastest = await results.GetFastestAsync(PingAsync, cts.Token);
89+
}
90+
catch (Exception ex)
91+
{
92+
this.logger.LogWarning(ex, "Ping failed for all found services");
93+
}
94+
}
95+
96+
foreach (var r in results)
97+
{
98+
var speed = fastest.FirstOrDefault(f => string.Equals(f.Value.Target, r.Target, StringComparison.OrdinalIgnoreCase));
99+
100+
r.PingResponseTime = speed.Value != null ? speed.Key : Random.Next(fastest.Count, int.MaxValue);
101+
}
102+
103+
return results;
63104
}
64105

65106
public virtual async Task<IEnumerable<DnsRecord>> LocateKpasswd(string domain, string servicePrefix)
@@ -153,6 +194,43 @@ protected virtual async Task<IEnumerable<DnsRecord>> Query(string domain, string
153194
return records;
154195
}
155196

197+
protected virtual async Task<DnsRecord> PingAsync(DnsRecord record, CancellationToken cancellationToken)
198+
{
199+
using var ping = new Ping();
200+
201+
cancellationToken.Register(() => ping.SendAsyncCancel());
202+
203+
var reply = await ping.SendPingAsync(record.Target, Convert.ToInt32(this.ConnectTimeout.TotalMilliseconds));
204+
205+
return reply.Status == IPStatus.Success ? record : throw new PingException($"Ping {record.Target} returned {reply.Status}");
206+
}
207+
208+
private class DnsRecordComparer : IEqualityComparer<DnsRecord>
209+
{
210+
public static readonly DnsRecordComparer Instance = new();
211+
212+
private DnsRecordComparer()
213+
{
214+
}
215+
216+
public bool Equals(DnsRecord x, DnsRecord y)
217+
{
218+
if (ReferenceEquals(x, y)) return true;
219+
if (x is null) return false;
220+
if (y is null) return false;
221+
if (x.GetType() != y.GetType()) return false;
222+
return x.Target == y.Target && x.Port == y.Port;
223+
}
224+
225+
public int GetHashCode(DnsRecord obj)
226+
{
227+
unchecked
228+
{
229+
return ((obj.Target != null ? obj.Target.GetHashCode() : 0) * 397) ^ obj.Port;
230+
}
231+
}
232+
}
233+
156234
private async Task QueryDns(string domain, string servicePrefix, List<DnsRecord> records)
157235
{
158236
var lookup = Invariant($"{servicePrefix}.{domain}");

Kerberos.NET/Client/Transport/HttpsKerberosTransport.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ namespace Kerberos.NET.Transport
1919
{
2020
public class HttpsKerberosTransport : KerberosTransportBase
2121
{
22-
private static readonly Random Random = new Random();
23-
2422
private readonly ILogger logger;
2523

2624
public HttpsKerberosTransport(ILoggerFactory logger = null)

Kerberos.NET/Client/Transport/KerberosTransportBase.cs

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
using System;
77
using System.Collections.Generic;
88
using System.Linq;
9-
using System.Net.NetworkInformation;
109
using System.Threading;
1110
using System.Threading.Tasks;
1211
using Kerberos.NET.Asn1;
@@ -20,26 +19,41 @@ namespace Kerberos.NET.Transport
2019
{
2120
public abstract class KerberosTransportBase : IKerberosTransport2, IDisposable
2221
{
22+
protected static readonly Random Random = new();
23+
24+
private bool disposedValue;
25+
2326
protected KerberosTransportBase(ILoggerFactory logger)
2427
{
2528
this.ClientRealmService = new ClientDomainService(logger);
29+
this.Logger = logger.CreateLoggerSafe<KerberosTransportBase>();
2630
}
2731

28-
private bool disposedValue;
29-
30-
private DnsRecord fastest;
32+
protected ILogger Logger { get; }
3133

3234
public virtual bool TransportFailed { get; set; }
3335

3436
public virtual KerberosTransportException LastError { get; set; }
3537

3638
public bool Enabled { get; set; }
3739

38-
public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(2);
40+
public TimeSpan ConnectTimeout
41+
{
42+
get => this.ClientRealmService.ConnectTimeout;
43+
set => this.ClientRealmService.ConnectTimeout = value;
44+
}
3945

40-
public TimeSpan SendTimeout { get; set; } = TimeSpan.FromSeconds(10);
46+
public TimeSpan SendTimeout
47+
{
48+
get => this.ClientRealmService.SendTimeout;
49+
set => this.ClientRealmService.SendTimeout = value;
50+
}
4151

42-
public TimeSpan ReceiveTimeout { get; set; } = TimeSpan.FromSeconds(10);
52+
public TimeSpan ReceiveTimeout
53+
{
54+
get => this.ClientRealmService.ReceiveTimeout;
55+
set => this.ClientRealmService.ReceiveTimeout = value;
56+
}
4357

4458
public int MaximumAttempts { get; set; } = 30;
4559

@@ -166,58 +180,20 @@ public void Dispose()
166180
protected virtual async Task<DnsRecord> LocatePreferredKdc(string domain, string servicePrefix)
167181
{
168182
var results = await this.LocateKdc(domain, servicePrefix);
169-
return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort);
183+
return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort);
170184
}
171185

172186
protected virtual async Task<DnsRecord> LocatePreferredKpasswd(string domain, string servicePrefix)
173187
{
174188
var results = await this.LocateKpasswd(domain, servicePrefix);
175-
return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort);
176-
}
177-
178-
protected virtual async Task<DnsRecord> SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable<DnsRecord> results, int defaultPort)
179-
{
180-
if (results.Contains(fastest, DnsRecordComparer.Instance))
181-
{
182-
return fastest;
183-
}
184-
185-
fastest = await results.Where(r => r.Name.StartsWith(servicePrefix)).GetFastestAsync(PingAsync);
186-
return fastest ?? throw new KerberosTransportException($"Cannot locate SRV record for {domain}");
187-
}
188-
189-
private async Task<DnsRecord> PingAsync(DnsRecord record, CancellationToken cancellationToken)
190-
{
191-
using var ping = new Ping();
192-
cancellationToken.Register(() => ping.SendAsyncCancel());
193-
var reply = await ping.SendPingAsync(record.Target, Convert.ToInt32(ConnectTimeout.TotalMilliseconds));
194-
return reply.Status == IPStatus.Success ? record : throw new PingException($"Ping {record.Target} returned {reply.Status}");
189+
return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort);
195190
}
196191

197-
private class DnsRecordComparer : IEqualityComparer<DnsRecord>
192+
protected virtual DnsRecord SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable<DnsRecord> results, int defaultPort)
198193
{
199-
public static readonly DnsRecordComparer Instance = new();
200-
201-
private DnsRecordComparer()
202-
{
203-
}
194+
results = results.Where(r => r.Name.StartsWith(servicePrefix)).OrderBy(r => r.PingResponseTime);
204195

205-
public bool Equals(DnsRecord x, DnsRecord y)
206-
{
207-
if (ReferenceEquals(x, y)) return true;
208-
if (x is null) return false;
209-
if (y is null) return false;
210-
if (x.GetType() != y.GetType()) return false;
211-
return x.Target == y.Target && x.Port == y.Port;
212-
}
213-
214-
public int GetHashCode(DnsRecord obj)
215-
{
216-
unchecked
217-
{
218-
return ((obj.Target != null ? obj.Target.GetHashCode() : 0) * 397) ^ obj.Port;
219-
}
220-
}
196+
return results.FirstOrDefault() ?? throw new KerberosTransportException($"Cannot locate SRV record for {domain}");
221197
}
222198
}
223199
}

Kerberos.NET/Configuration/Krb5ConfigDefaults.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,5 +346,12 @@ public class Krb5ConfigDefaults : Krb5ConfigObject
346346
[DefaultValue(PrincipalNameType.NT_ENTERPRISE)]
347347
[DisplayName("default_name_type")]
348348
public PrincipalNameType DefaultNameType { get; set; }
349+
350+
/// <summary>
351+
/// Indicates whether the client should try to find and sort KDCs by how long it takes for them to respond by ping.
352+
/// </summary>
353+
[DefaultValue(true)]
354+
[DisplayName("prioritize_by_response_time")]
355+
public bool PrioritizeKdcByPing { get; set; }
349356
}
350357
}

Kerberos.NET/Dns/DnsRecord.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,7 @@ public string Address
4949
return this.Target;
5050
}
5151
}
52+
53+
public int PingResponseTime { get; set; } = int.MaxValue;
5254
}
5355
}

Kerberos.NET/TaskExtensions.cs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// -----------------------------------------------------------------------
1+
// -----------------------------------------------------------------------
22
// Licensed to The .NET Foundation under one or more agreements.
33
// The .NET Foundation licenses this file to you under the MIT license.
44
// -----------------------------------------------------------------------
@@ -11,31 +11,50 @@
1111

1212
internal static class TaskExtensions
1313
{
14-
public static async Task<TResult> GetFastestAsync<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, CancellationToken, Task<TResult>> task, CancellationToken cancellationToken = default)
14+
public static async Task<SortedList<int, TResult>> GetFastestAsync<TSource, TResult>(
15+
this IEnumerable<TSource> source,
16+
Func<TSource, CancellationToken, Task<TResult>> task,
17+
CancellationToken cancellationToken = default
18+
)
1519
{
1620
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
1721
var tasks = new HashSet<Task<TResult>>(source.Select(e => task(e, cts.Token)));
22+
1823
if (tasks.Count == 0)
1924
{
20-
return default;
25+
return new();
2126
}
2227

28+
int next = 0;
29+
SortedList<int, TResult> results = new();
30+
2331
var exceptions = new List<Exception>();
32+
2433
do
2534
{
2635
var completedTask = await Task.WhenAny(tasks);
36+
2737
if (completedTask.Status == TaskStatus.RanToCompletion)
2838
{
2939
cts.Cancel();
30-
return completedTask.Result;
40+
41+
results.Add(++next, completedTask.Result);
3142
}
3243

3344
if (completedTask.Exception != null)
3445
{
3546
exceptions.AddRange(completedTask.Exception.InnerExceptions);
3647
}
48+
3749
tasks.Remove(completedTask);
38-
} while (tasks.Count > 0);
50+
51+
}
52+
while (tasks.Count > 0);
53+
54+
if (results.Count > 0)
55+
{
56+
return results;
57+
}
3958

4059
throw new AggregateException(exceptions);
4160
}

0 commit comments

Comments
 (0)