Skip to content

Commit a2456d9

Browse files
[BED-5988] Update MergeFilter to accomodate more than 1 mandatory filter (#211)
* GetFilterList changes to manage multiple mandatory filters * Update MergeFilter test * Let's also check that exactly 2 filters are returned by our MergeFilter test
1 parent 801f2c5 commit a2456d9

2 files changed

Lines changed: 27 additions & 15 deletions

File tree

src/CommonLib/LdapQueries/LdapFilter.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,17 +255,19 @@ public string GetFilter() {
255255
return filterPartsDistinct;
256256
}
257257

258-
private string MergeFilter(string filterA, string filterB) {
259-
return $"(&{filterA}{filterB})";
258+
private string MergeFilters(params string[] filters) {
259+
return $"(&{string.Join("", filters)})";
260260
}
261261

262262
public IEnumerable<string> GetFilterList() {
263263
foreach (var filter in _filterParts.Distinct())
264264
{
265265
if (_mandatory.Count > 0) {
266-
foreach (var mandatory in _mandatory) {
267-
yield return MergeFilter(filter, mandatory);
268-
}
266+
var filters = new List<string>(_mandatory)
267+
{
268+
filter
269+
};
270+
yield return MergeFilters(filters.ToArray());
269271
} else {
270272
yield return filter;
271273
}

test/unit/LDAPFilterTest.cs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using SharpHoundCommonLib.LDAPQueries;
45
using Xunit;
56
using Xunit.Abstractions;
@@ -73,27 +74,36 @@ public void LDAPFilter_GetFilterList()
7374
i++;
7475
}
7576
}
76-
77+
7778
[Fact]
7879
public void LDAPFilter_GetFilterList_MergeFilter()
7980
{
8081
var test = new LdapFilter();
8182
test.AddUsers();
8283
test.AddComputers();
83-
string mergeFilter = "(objectclass=*)";
84-
test.AddFilter(mergeFilter, true);
85-
84+
string mandatoryFilter1 = "(objectclass=*)";
85+
string mandatoryFilter2 = "(iamamandatoryfilter=1)";
86+
test.AddFilter(mandatoryFilter1, true);
87+
test.AddFilter(mandatoryFilter2, true);
88+
8689
IEnumerable<string> filters = test.GetFilterList();
87-
88-
int i = 0;
90+
8991
string computerFilter = "(samaccounttype=805306369)";
9092
string userFilter = "(|(samaccounttype=805306368)(samaccounttype=805306370))";
91-
string[] expected = {$"(&{userFilter}{mergeFilter})", $"(&{computerFilter}{mergeFilter})"};
9293

93-
foreach (var filter in filters) {
94-
Assert.Equal(expected[i], filter);
95-
i++;
94+
// Check that each filter includes all mandatory filters
95+
foreach (var filter in filters)
96+
{
97+
Assert.StartsWith("(&", filter);
98+
Assert.Contains(mandatoryFilter1, filter);
99+
Assert.Contains(mandatoryFilter2, filter);
96100
}
101+
102+
// Check that each of userFilter and computerFilter are accounted for
103+
Assert.Single(filters.Where(f => f.Contains(userFilter)));
104+
Assert.Single(filters.Where(f => f.Contains(computerFilter)));
105+
106+
Assert.Equal(2, filters.Count());
97107
}
98108

99109
#endregion

0 commit comments

Comments
 (0)