Skip to content

Commit f05a995

Browse files
[BED-5972] Add cancellation tokens to timeout logic (#213)
* WIP new Timeout replacement * Trying to mitigate task failures from timeout expiring before task makes it into the thread pool * feat: Timeout with cancellation token * Adding Cancellation Tokens where appropriate * Put timeout on LookupPrincipalBySid; correct tests; mark WindowsOnlyFacts * Add new ExecuteWithTimeout tests, fix DCLdapProcessor_CheckScan_Timeout * chore: add tests for early exit of ExecuteWithTimeout tasks; formatting * Add test failure message to 'Timeout_Cancel' tests for clarity of intent * Add comments to describe the double-ExecuteWithTimeout decision * Add LongRunning hint to synchronous ExecuteWithTimeout * Add LongRunning hint to async ExecuteWithTimeout as well * Reverting last commit: Too many potential pitfalls * Decided to add cancellation token to Task.Factory.StartNew after all * I think technically this will more comprehensively capture completed work tasks
1 parent 56c4aca commit f05a995

27 files changed

Lines changed: 448 additions & 266 deletions

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
FROM mono:6.12.0
22

33
# Install .NET SDK
4-
ENV DOTNET_VERSION=7.0
4+
ENV DOTNET_VERSION=8.0
55

66
RUN curl -sSL https://dot.net/v1/dotnet-install.sh \
77
| bash -s -- -Channel $DOTNET_VERSION -InstallDir /usr/share/dotnet \

src/CommonLib/Helpers.cs

Lines changed: 122 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
using System.Security;
1212
using SharpHoundCommonLib.Processors;
1313
using Microsoft.Win32;
14+
using System.Threading.Tasks;
15+
using System.Threading;
16+
using SharpHoundRPC.NetAPINative;
17+
using SharpHoundRPC.Shared;
1418

1519
namespace SharpHoundCommonLib {
1620
public static class Helpers {
@@ -135,7 +139,8 @@ public static string DistinguishedNameToDomain(string distinguishedName) {
135139
int idx;
136140
if (distinguishedName.ToUpper().Contains("DELETED OBJECTS")) {
137141
idx = distinguishedName.IndexOf("DC=", 3, StringComparison.Ordinal);
138-
} else {
142+
}
143+
else {
139144
idx = distinguishedName.IndexOf("DC=",
140145
StringComparison.CurrentCultureIgnoreCase);
141146
}
@@ -193,7 +198,8 @@ public static long ConvertFileTimeToUnixEpoch(string ldapTime) {
193198

194199
try {
195200
toReturn = (long)Math.Floor(DateTime.FromFileTimeUtc(time).Subtract(EpochDiff).TotalSeconds);
196-
} catch {
201+
}
202+
catch {
197203
toReturn = -1;
198204
}
199205

@@ -209,7 +215,8 @@ public static long ConvertTimestampToUnixEpoch(string ldapTime) {
209215
try {
210216
var dt = DateTime.ParseExact(ldapTime, "yyyyMMddHHmmss.0K", CultureInfo.CurrentCulture).ToUniversalTime();
211217
return (long)dt.Subtract(EpochDiff).TotalSeconds;
212-
} catch {
218+
}
219+
catch {
213220
return 0;
214221
}
215222
}
@@ -263,19 +270,23 @@ public static RegistryResult GetRegistryKeyData(string target, string subkey, st
263270
data.Value = value;
264271

265272
data.Collected = true;
266-
} catch (IOException e) {
273+
}
274+
catch (IOException e) {
267275
log.LogDebug(e, "Error getting data from registry for {Target}: {RegSubKey}:{RegValue}",
268276
target, subkey, subvalue);
269277
data.FailureReason = "Target machine was not found or not connectable";
270-
} catch (SecurityException e) {
278+
}
279+
catch (SecurityException e) {
271280
log.LogDebug(e, "Error getting data from registry for {Target}: {RegSubKey}:{RegValue}",
272281
target, subkey, subvalue);
273282
data.FailureReason = "User does not have the proper permissions to perform this operation";
274-
} catch (UnauthorizedAccessException e) {
283+
}
284+
catch (UnauthorizedAccessException e) {
275285
log.LogDebug(e, "Error getting data from registry for {Target}: {RegSubKey}:{RegValue}",
276286
target, subkey, subvalue);
277287
data.FailureReason = "User does not have the necessary registry rights";
278-
} catch (Exception e) {
288+
}
289+
catch (Exception e) {
279290
log.LogDebug(e, "Error getting data from registry for {Target}: {RegSubKey}:{RegValue}",
280291
target, subkey, subvalue);
281292
data.FailureReason = e.Message;
@@ -300,7 +311,7 @@ public static IRegistryKey OpenRemoteRegistry(string target) {
300311
CommonOids.ClientAuthentication,
301312
CommonOids.AnyPurpose
302313
};
303-
314+
304315
public static string DumpDirectoryObject(this IDirectoryObject directoryObject) {
305316
var builder = new StringBuilder();
306317
builder.AppendLine("PropertyName : PropertyValue");
@@ -310,6 +321,109 @@ public static string DumpDirectoryObject(this IDirectoryObject directoryObject)
310321

311322
return builder.ToString();
312323
}
324+
325+
/// <summary>
326+
/// Returns a Fail result if a task runs longer than its budgeted time.
327+
/// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
328+
/// </summary>
329+
/// <typeparam name="T"></typeparam>
330+
/// <param name="timeout"></param>
331+
/// <param name="func"></param>
332+
/// <returns></returns>
333+
public static async Task<Result<T>> ExecuteWithTimeout<T>(TimeSpan timeout, Func<CancellationToken, T> func) {
334+
var cts = new CancellationTokenSource();
335+
var task = Task.Factory.StartNew(() => func(cts.Token), cts.Token, TaskCreationOptions.LongRunning, TaskScheduler.Default);
336+
await Task.WhenAny(task, Task.Delay(timeout, cts.Token));
337+
cts.Cancel();
338+
339+
if (task.IsCompleted) {
340+
try {
341+
return Result<T>.Ok(await task);
342+
}
343+
catch (OperationCanceledException) { }
344+
}
345+
346+
return Result<T>.Fail("Timeout");
347+
}
348+
349+
// These two ExecuteWithTimeout functions should perform equivalently -
350+
// they both create a new task from a function arg
351+
// But where the one below can invoke an async function directly to spawn the Task
352+
// The one above spawns a Task from a synchronous function.
353+
// The caller shouldn't have to worry about which they're using however,
354+
// the compiler should figure it out intrinsically
355+
356+
/// <summary>
357+
/// Returns a Fail result if a task runs longer than its budgeted time.
358+
/// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
359+
/// </summary>
360+
/// <typeparam name="T"></typeparam>
361+
/// <param name="timeout"></param>
362+
/// <param name="func"></param>
363+
/// <returns></returns>
364+
public static async Task<Result<T>> ExecuteWithTimeout<T>(TimeSpan timeout, Func<CancellationToken, Task<T>> func) {
365+
var cts = new CancellationTokenSource();
366+
var task = func.Invoke(cts.Token);
367+
await Task.WhenAny(task, Task.Delay(timeout, cts.Token));
368+
cts.Cancel();
369+
370+
if (task.IsCompleted) {
371+
try {
372+
return Result<T>.Ok(await task);
373+
}
374+
catch (OperationCanceledException) { }
375+
}
376+
377+
return Result<T>.Fail("Timeout");
378+
}
379+
380+
/// <summary>
381+
/// Returns a Fail result if a task runs longer than its budgeted time.
382+
/// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
383+
/// </summary>
384+
/// <typeparam name="T"></typeparam>
385+
/// <param name="timeout"></param>
386+
/// <param name="func"></param>
387+
/// <returns></returns>
388+
public static async Task<NetAPIResult<T>> ExecuteNetAPIWithTimeout<T>(TimeSpan timeout, Func<CancellationToken, NetAPIResult<T>> func) {
389+
var result = await ExecuteWithTimeout(timeout, func);
390+
if (result.IsSuccess)
391+
return result.Value;
392+
else
393+
return NetAPIResult<T>.Fail(result.Error);
394+
}
395+
396+
/// <summary>
397+
/// Returns a Fail result if a task runs longer than its budgeted time.
398+
/// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
399+
/// </summary>
400+
/// <typeparam name="T"></typeparam>
401+
/// <param name="timeout"></param>
402+
/// <param name="func"></param>
403+
/// <returns></returns>
404+
public static async Task<SharpHoundRPC.Result<T>> ExecuteRPCWithTimeout<T>(TimeSpan timeout, Func<CancellationToken, SharpHoundRPC.Result<T>> func) {
405+
var result = await ExecuteWithTimeout(timeout, func);
406+
if (result.IsSuccess)
407+
return result.Value;
408+
else
409+
return SharpHoundRPC.Result<T>.Fail(result.Error);
410+
}
411+
412+
/// <summary>
413+
/// Returns a Fail result if a task runs longer than its budgeted time.
414+
/// A cancellation token is passed to the executing function so it may exit cleanly if timeout is reached.
415+
/// </summary>
416+
/// <typeparam name="T"></typeparam>
417+
/// <param name="timeout"></param>
418+
/// <param name="func"></param>
419+
/// <returns></returns>
420+
public static async Task<SharpHoundRPC.Result<T>> ExecuteRPCWithTimeout<T>(TimeSpan timeout, Func<CancellationToken, Task<SharpHoundRPC.Result<T>>> func) {
421+
var result = await ExecuteWithTimeout(timeout, func);
422+
if (result.IsSuccess)
423+
return result.Value;
424+
else
425+
return SharpHoundRPC.Result<T>.Fail(result.Error);
426+
}
313427
}
314428

315429
public class ParsedGPLink {

src/CommonLib/Ntlm/NtlmAuthenticationHandler.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
using Microsoft.Extensions.Logging;
22
using SharpHoundCommonLib.Processors;
33
using SharpHoundCommonLib.ThirdParty.PSOpenAD;
4-
using System;
4+
using System.Threading;
55
using System.Threading.Tasks;
66

77
namespace SharpHoundCommonLib.Ntlm;
88

99
interface INtlmAuthenticationHandler {
10-
Task<object> PerformNtlmAuthenticationAsync(INtlmTransport transport);
10+
Task<object> PerformNtlmAuthenticationAsync(INtlmTransport transport, CancellationToken cancellationToken = default);
1111
}
1212

1313
/// <summary>
@@ -28,7 +28,7 @@ public NtlmAuthenticationHandler(string targetService, ILogger logger = null) {
2828
};
2929
}
3030

31-
public virtual async Task<object> PerformNtlmAuthenticationAsync(INtlmTransport transport) {
31+
public virtual async Task<object> PerformNtlmAuthenticationAsync(INtlmTransport transport, CancellationToken cancellationToken = default) {
3232
using var context = new SspiContext(
3333
null,
3434
null,
@@ -39,12 +39,16 @@ public virtual async Task<object> PerformNtlmAuthenticationAsync(INtlmTransport
3939
Options.Signing
4040
);
4141

42+
cancellationToken.ThrowIfCancellationRequested();
43+
4244
// NEGOTIATE
4345
var negotiateMsgBytes = context.Step();
4446

4547
// CHALLENGE
4648
var challengeMessageBytes = await transport.NegotiateAsync(negotiateMsgBytes);
4749

50+
cancellationToken.ThrowIfCancellationRequested();
51+
4852
// AUTHENTICATE
4953
var authenticateMsgBytes = context.Step(challengeMessageBytes);
5054

src/CommonLib/Processors/ComputerSessionProcessor.cs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
using Microsoft.Extensions.Logging;
99
using Microsoft.Win32;
1010
using SharpHoundCommonLib.OutputTypes;
11-
using SharpHoundRPC;
1211
using SharpHoundRPC.NetAPINative;
1312

1413
namespace SharpHoundCommonLib.Processors {
@@ -58,7 +57,7 @@ public async Task<SessionAPIResult> ReadUserSessions(string computerName, string
5857

5958
_log.LogDebug("Running NetSessionEnum for {ObjectName}", computerName);
6059

61-
var result = await Task.Run(() => {
60+
var result = await Helpers.ExecuteNetAPIWithTimeout(timeout, (timeoutToken) => {
6261
NetAPIResult<IEnumerable<NetSessionEnumResults>> result;
6362
if (_doLocalAdminSessionEnum) {
6463
// If we are authenticating using a local admin, we need to impersonate for this
@@ -67,19 +66,22 @@ public async Task<SessionAPIResult> ReadUserSessions(string computerName, string
6766
result = _nativeMethods.NetSessionEnum(computerName);
6867
}
6968

69+
timeoutToken.ThrowIfCancellationRequested();
70+
7071
if (result.IsFailed) {
7172
// Fall back to default User
7273
_log.LogDebug(
7374
"NetSessionEnum failed on {ComputerName} with local admin credentials: {Status}. Fallback to default user.",
7475
computerName, result.Status);
7576
result = _nativeMethods.NetSessionEnum(computerName);
7677
}
77-
} else {
78+
}
79+
else {
7880
result = _nativeMethods.NetSessionEnum(computerName);
7981
}
8082

8183
return result;
82-
}).TimeoutAfter(timeout);
84+
});
8385

8486
if (result.IsFailed) {
8587
await SendComputerStatus(new CSVComputerStatus {
@@ -121,7 +123,7 @@ await SendComputerStatus(new CSVComputerStatus {
121123
username.Equals("anonymous logon", StringComparison.CurrentCultureIgnoreCase)) {
122124
continue;
123125
}
124-
126+
125127
//Filter out domains that are "."
126128
if (computerDomain.Equals(".")) {
127129
continue;
@@ -147,7 +149,8 @@ await SendComputerStatus(new CSVComputerStatus {
147149
if (matchSuccess) {
148150
results.AddRange(
149151
sids.Select(s => new Session { ComputerSID = resolvedComputerSID, UserSID = s }));
150-
} else {
152+
}
153+
else {
151154
var res = await _utils.ResolveAccountName(username, computerDomain);
152155
if (res.Success)
153156
results.Add(new Session {
@@ -180,7 +183,7 @@ public async Task<SessionAPIResult> ReadUserSessionsPrivileged(string computerNa
180183

181184
_log.LogDebug("Running NetWkstaUserEnum for {ObjectName}", computerName);
182185

183-
var result = await Task.Run(() => {
186+
var result = await Helpers.ExecuteNetAPIWithTimeout(timeout, (timeoutToken) => {
184187
NetAPIResult<IEnumerable<NetWkstaUserEnumResults>>
185188
result;
186189
if (_doLocalAdminSessionEnum) {
@@ -190,19 +193,22 @@ public async Task<SessionAPIResult> ReadUserSessionsPrivileged(string computerNa
190193
result = _nativeMethods.NetWkstaUserEnum(computerName);
191194
}
192195

196+
timeoutToken.ThrowIfCancellationRequested();
197+
193198
if (result.IsFailed) {
194199
// Fall back to default User
195200
_log.LogDebug(
196201
"NetWkstaUserEnum failed on {ComputerName} with local admin credentials: {Status}. Fallback to default user.",
197202
computerName, result.Status);
198203
result = _nativeMethods.NetWkstaUserEnum(computerName);
199204
}
200-
} else {
205+
}
206+
else {
201207
result = _nativeMethods.NetWkstaUserEnum(computerName);
202208
}
203209

204210
return result;
205-
}).TimeoutAfter(timeout);
211+
});
206212

207213
if (result.IsFailed) {
208214
await SendComputerStatus(new CSVComputerStatus {
@@ -244,7 +250,7 @@ await SendComputerStatus(new CSVComputerStatus {
244250
if (string.IsNullOrWhiteSpace(username) || username.EndsWith("$", StringComparison.Ordinal)) {
245251
continue;
246252
}
247-
253+
248254
//Filter out domains that are "."
249255
if (domain.Equals(".")) {
250256
continue;
@@ -312,7 +318,8 @@ await SendComputerStatus(new CSVComputerStatus {
312318
ret.Results = results.ToArray();
313319

314320
return ret;
315-
} catch (Exception e) {
321+
}
322+
catch (Exception e) {
316323
_log.LogTrace("Registry session enum failed on {ComputerName}: {Status}", computerName, e.Message);
317324
await SendComputerStatus(new CSVComputerStatus {
318325
Status = e.Message,
@@ -322,7 +329,8 @@ await SendComputerStatus(new CSVComputerStatus {
322329
ret.Collected = false;
323330
ret.FailureReason = e.Message;
324331
return ret;
325-
} finally {
332+
}
333+
finally {
326334
key?.Dispose();
327335
}
328336
}

0 commit comments

Comments
 (0)