Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions src/CommonLib/Ntlm/HttpClientFactory.cs

This file was deleted.

25 changes: 7 additions & 18 deletions src/CommonLib/Ntlm/HttpNtlmAuthenticationService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ namespace SharpHoundCommonLib.Ntlm;
/// </summary>
public class HttpNtlmAuthenticationService {
private readonly ILogger _logger;
private readonly IHttpClientFactory _httpClientFactory;
private readonly INtlmHttpClientFactory _ntlmHttpClientFactory;
private readonly AdaptiveTimeout _getSupportedNTLMAuthSchemesAdaptiveTimeout;
private readonly AdaptiveTimeout _ntlmAuthAdaptiveTimeout;
private readonly AdaptiveTimeout _authWithChannelBindingAdaptiveTimeout;

public HttpNtlmAuthenticationService(IHttpClientFactory httpClientFactory, ILogger logger = null) {
public HttpNtlmAuthenticationService(INtlmHttpClientFactory ntlmHttpClientFactory, ILogger logger = null) {
_logger = logger ?? Logging.LogProvider.CreateLogger(nameof(HttpNtlmAuthenticationService));
_httpClientFactory = httpClientFactory;
_ntlmHttpClientFactory = ntlmHttpClientFactory ?? throw new ArgumentNullException(nameof(ntlmHttpClientFactory));
_getSupportedNTLMAuthSchemesAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(GetSupportedNtlmAuthSchemesAsync)));
_ntlmAuthAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(NtlmAuthenticationHandler.PerformNtlmAuthenticationAsync)));
_authWithChannelBindingAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger(nameof(AuthWithBadChannelBindingsAsync)));
Expand Down Expand Up @@ -57,7 +57,7 @@ public async Task EnsureRequiresAuth(Uri url, bool? useBadChannelBindings) {
}

private async Task<string[]> GetSupportedNtlmAuthSchemesAsync(Uri url) {
var httpClient = _httpClientFactory.CreateUnauthenticatedClient();
var httpClient = _ntlmHttpClientFactory.CreateUnauthenticatedClient();
using var getRequest = new HttpRequestMessage(HttpMethod.Get, url);

var result = await _getSupportedNTLMAuthSchemesAdaptiveTimeout.ExecuteWithTimeout(async (timeoutToken) => {
Expand Down Expand Up @@ -105,7 +105,7 @@ internal string[] ExtractAuthSchemes(HttpResponseMessage response) {
}

private async Task AuthWithBadChannelBindingsAsync(Uri url, string authScheme, NtlmAuthenticationHandler ntlmAuth = null) {
var httpClient = _httpClientFactory.CreateUnauthenticatedClient();
var httpClient = _ntlmHttpClientFactory.CreateUnauthenticatedClient();
var transport = new HttpTransport(httpClient, url, authScheme, _logger);
var ntlmAuthHandler = ntlmAuth ?? new NtlmAuthenticationHandler($"HTTP/{url.Host}");

Expand Down Expand Up @@ -143,18 +143,7 @@ private async Task AuthWithBadChannelBindingsAsync(Uri url, string authScheme, N
}

private async Task<bool> AuthWithChannelBindingAsync(Uri url, string authScheme) {
var handler = new HttpClientHandler {
ServerCertificateCustomValidationCallback = (httpRequestMessage, cert, cetChain, policyErrors) => true,
};

var credentialCache = new CredentialCache {
{ url, authScheme, CredentialCache.DefaultNetworkCredentials }
};

handler.Credentials = credentialCache;
handler.PreAuthenticate = true;

using var client = new HttpClient(handler);
using var client = _ntlmHttpClientFactory.CreateAuthenticatedHttpClient(url, authScheme);

var result = await _authWithChannelBindingAdaptiveTimeout.ExecuteWithTimeout(async (timeoutToken) => {
try {
Expand Down Expand Up @@ -217,4 +206,4 @@ public AuthNotRequiredException() {

public AuthNotRequiredException(string message) : base(message) {
}
}
}
62 changes: 62 additions & 0 deletions src/CommonLib/Ntlm/NtlmHttpClientFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using System;
using System.Net;
using System.Net.Http;
using System.Security.Authentication;

namespace SharpHoundCommonLib.Ntlm;

public interface INtlmHttpClientFactory {
HttpClient CreateUnauthenticatedClient();
HttpClient CreateAuthenticatedHttpClient(Uri Url, string authPackage = "Kerberos");
}

public class NtlmHttpClientFactory : INtlmHttpClientFactory {
private readonly SslProtocols _sslProtocols;

/// <summary>
/// Creates an HttpClientFactory whose handlers will negotiate TLS using OS/framework defaults.
/// </summary>
public NtlmHttpClientFactory() : this(SslProtocols.None) { }

/// <summary>
/// Creates an HttpClientFactory whose handlers will restrict TLS negotiation to the specified protocols.
/// Use this overload when a specific set of legacy protocols must be supported for a target service,
/// rather than setting <see cref="System.Net.ServicePointManager.SecurityProtocol"/> process-wide.
/// </summary>
/// <param name="sslProtocols">
/// The SSL/TLS protocols to allow. Pass <see cref="SslProtocols.None"/> to defer to OS/framework defaults.
/// </param>
public NtlmHttpClientFactory(SslProtocols sslProtocols) {
_sslProtocols = sslProtocols;
}

public HttpClient CreateUnauthenticatedClient() {
var handler = new HttpClientHandler {
ServerCertificateCustomValidationCallback =
(httpRequestMessage, cert, cetChain, policyErrors) => true,
UseDefaultCredentials = false
};

if (_sslProtocols != SslProtocols.None)
handler.SslProtocols = _sslProtocols;

return new HttpClient(handler);
}

public HttpClient CreateAuthenticatedHttpClient(Uri Url, string authPackage = "Kerberos") {
var handler = new HttpClientHandler {
Credentials = new CredentialCache() {
{ Url, authPackage, CredentialCache.DefaultNetworkCredentials }
},

PreAuthenticate = true,
ServerCertificateCustomValidationCallback =
(httpRequestMessage, cert, cetChain, policyErrors) => true,
};

if (_sslProtocols != SslProtocols.None)
handler.SslProtocols = _sslProtocols;

return new HttpClient(handler);
}
}
19 changes: 9 additions & 10 deletions src/CommonLib/Processors/CAEnrollmentProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Threading.Tasks;

namespace SharpHoundCommonLib.Processors {
Expand All @@ -18,13 +19,11 @@ public class CAEnrollmentProcessor {
private readonly string _caName;
private readonly ILogger _logger;

public CAEnrollmentProcessor(string caDnsHostname, string caName, ILogger log = null) {
ServicePointManager.SecurityProtocol |=
SecurityProtocolType.Ssl3
| SecurityProtocolType.Tls12
| SecurityProtocolType.Tls11
| SecurityProtocolType.Tls;
// TLS1.3 is not available in .Net Framework 4.7.2, but the enum can still be assigned.
private const SslProtocols CaEnrollmentSslProtocols =
SslProtocols.Ssl3 | SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | (SslProtocols)12288;

public CAEnrollmentProcessor(string caDnsHostname, string caName, ILogger log = null) {
_caDnsHostname = caDnsHostname;
_caName = caName;
_logger = log ?? Logging.LogProvider.CreateLogger("CAEnrollmentProcessor");
Expand All @@ -48,7 +47,7 @@ await Task.WhenAll(
} catch (Exception ex) {
_logger.LogError(ex, "An error occurred while scanning enrollment endpoints");
}

endpoints = TagEndpoints(endpoints).ToList();

return endpoints;
Expand All @@ -59,7 +58,7 @@ private IEnumerable<APIResult<CAEnrollmentEndpoint>> TagEndpoints(IEnumerable<AP
foreach (var endpoint in tagEndpoints) {
if (!endpoint.Collected)
continue;

var enrollmentEndpoint = endpoint.Result;
if (enrollmentEndpoint.Url.Scheme != Uri.UriSchemeHttps) {
switch (enrollmentEndpoint.Status) {
Expand Down Expand Up @@ -126,7 +125,7 @@ private async Task<IEnumerable<APIResult<CAEnrollmentEndpoint>>>
private async Task<APIResult<CAEnrollmentEndpoint>> GetNtlmEndpoint(Uri url, bool? useBadChannelBinding,
CAEnrollmentEndpointType type, CAEnrollmentEndpointScanResult scanResult) {
var authService = new HttpNtlmAuthenticationService(
new HttpClientFactory()
new NtlmHttpClientFactory(CaEnrollmentSslProtocols)
);

var output = new CAEnrollmentEndpoint(url, type, scanResult);
Expand Down Expand Up @@ -228,4 +227,4 @@ private async Task<APIResult<CAEnrollmentEndpoint>> GetNtlmEndpoint(Uri url, boo
}
}
}
}
}
8 changes: 4 additions & 4 deletions test/unit/HttpNtlmAuthenticationServiceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public void Dispose() {

[Fact]
public void HttpNtlmAuthenticationService_ExtractAuthSchemes_AuthNotRequiredException() {
var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null);
var service = new HttpNtlmAuthenticationService(new NtlmHttpClientFactory(), null);
var httpResponseMessage = new HttpResponseMessage {
StatusCode = HttpStatusCode.OK,
};
Expand All @@ -37,7 +37,7 @@ public void HttpNtlmAuthenticationService_ExtractAuthSchemes_AuthNotRequiredExce

[Fact]
public void HttpNtlmAuthenticationService_ExtractAuthSchemes_HttpForbiddenException() {
var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null);
var service = new HttpNtlmAuthenticationService(new NtlmHttpClientFactory(), null);
var httpResponseMessage = new HttpResponseMessage {
StatusCode = HttpStatusCode.Forbidden,
};
Expand All @@ -49,7 +49,7 @@ public void HttpNtlmAuthenticationService_ExtractAuthSchemes_HttpForbiddenExcept

[Fact]
public void HttpNtlmAuthenticationService_ExtractAuthSchemes_HttpServerErrorException() {
var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null);
var service = new HttpNtlmAuthenticationService(new NtlmHttpClientFactory(), null);
var httpResponseMessage = new HttpResponseMessage {
StatusCode = HttpStatusCode.InternalServerError,
};
Expand All @@ -61,7 +61,7 @@ public void HttpNtlmAuthenticationService_ExtractAuthSchemes_HttpServerErrorExce

[Fact]
public void HttpNtlmAuthenticationService_ExtractAuthSchemes_Success() {
var service = new HttpNtlmAuthenticationService(new HttpClientFactory(), null);
var service = new HttpNtlmAuthenticationService(new NtlmHttpClientFactory(), null);
var httpResponseMessage = new HttpResponseMessage();
httpResponseMessage.StatusCode = HttpStatusCode.Accepted;
httpResponseMessage.Headers.WwwAuthenticate.Add(
Expand Down
Loading