Skip to content

Commit 498006e

Browse files
MaggieQi, xufyan, and Qi Chen
authored
Parallel bkt build (#443)
* Add parallel BKT tree build support via level-order BFS with OpenMP. Add BuildTreesParallel() method to BKTree that parallelizes tree construction using a level-order (BFS) approach instead of the existing depth-first recursive method. At each level of the tree, all sibling nodes are processed in parallel using OpenMP, with each thread running independent k-means clustering. The tree structure is then assembled sequentially to maintain correctness. This is controlled by a new ParallelBKTBuild parameter (default: false) that can be enabled in both BKT index and SPANN select-head configurations. Benchmark on SIFT 50M (128-dim, L2) with 32 threads on Azure L32s_v2: - Select Head (BKT build): 16.6 hours -> 1.2 hours (13.6x speedup) - Build Head graph (RefineGraph): unchanged (~10 hours, memory-bound) - Total end-to-end build: ~30 hours -> ~15 hours - Recall@1: 91% -> 94% (slight improvement) - Query latency: comparable (P50 ~40ms) * Change OpenMP to std::thread --------- Co-authored-by: xufyan <nickyoung.fdu@gmail.com> Co-authored-by: Qi Chen <cheqi@microsoft.com>
1 parent b2748d9 commit 498006e

7 files changed

Lines changed: 346 additions & 6 deletions

File tree

AnnService/inc/Core/BKT/ParameterDefinitionList.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK")
1515
DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize")
1616
DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples")
1717
DefineBKTParameter(m_pTrees.m_fBalanceFactor, float, 100.0F, "BKTLambdaFactor")
18+
DefineBKTParameter(m_pTrees.m_parallelBuild, bool, false, "ParallelBKTBuild")
1819

1920
DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber")
2021
DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")

AnnService/inc/Core/Common/BKTree.h

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include <vector>
1010
#include <mutex>
1111
#include <shared_mutex>
12+
#include <atomic>
13+
#include <omp.h>
1214
#include "inc/Core/VectorIndex.h"
1315

1416
#include "CommonUtils.h"
@@ -655,6 +657,188 @@ break;
655657
}
656658
}
657659

660+
// Parallel BKTree Build - processes sibling nodes in parallel
661+
template <typename T>
662+
void BuildTreesParallel(const Dataset<T>& data, DistCalcMethod distMethod, int numOfThreads,
663+
std::vector<SizeType>* indices = nullptr, std::vector<SizeType>* reverseIndices = nullptr,
664+
bool dynamicK = false, IAbortOperation* abort = nullptr)
665+
{
666+
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using PARALLEL BKTree build with %d threads.\n", numOfThreads);
667+
668+
// Helper struct for collecting parallel results
669+
struct ParallelNodeResult {
670+
SizeType parentIndex;
671+
SizeType first, last;
672+
std::vector<SizeType> childCenters;
673+
std::vector<SizeType> childCounts;
674+
bool isLeaf;
675+
bool singleCluster;
676+
SizeType singleClusterCenter;
677+
};
678+
679+
struct BKTStackItem {
680+
SizeType index, first, last;
681+
bool debug;
682+
BKTStackItem(SizeType index_ = -1, SizeType first_ = 0, SizeType last_ = 0, bool debug_ = false)
683+
: index(index_), first(first_), last(last_), debug(debug_) {}
684+
};
685+
686+
std::vector<SizeType> localindices;
687+
if (indices == nullptr) {
688+
localindices.resize(data.R());
689+
for (SizeType i = 0; i < (SizeType)localindices.size(); i++) localindices[i] = i;
690+
}
691+
else {
692+
localindices.assign(indices->begin(), indices->end());
693+
}
694+
695+
// Create a shared KmeansArgs for DynamicFactorSelect (uses all threads)
696+
KmeansArgs<T> sharedArgs(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), numOfThreads, distMethod, m_pQuantizer);
697+
698+
if (m_fBalanceFactor < 0) {
699+
m_fBalanceFactor = DynamicFactorSelect(data, localindices, 0, (SizeType)localindices.size(), sharedArgs, m_iSamples);
700+
}
701+
702+
std::mt19937 rg;
703+
m_pSampleCenterMap.clear();
704+
705+
for (char treeIdx = 0; treeIdx < m_iTreeNumber; treeIdx++)
706+
{
707+
std::shuffle(localindices.begin(), localindices.end(), rg);
708+
709+
m_pTreeStart.push_back((SizeType)m_pTreeRoots.size());
710+
m_pTreeRoots.emplace_back((SizeType)localindices.size());
711+
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to build BKTree %d (parallel)\n", treeIdx + 1);
712+
713+
// Level-order processing
714+
std::vector<BKTStackItem> currentLevel, nextLevel;
715+
currentLevel.push_back(BKTStackItem(m_pTreeStart[treeIdx], 0, (SizeType)localindices.size(), true));
716+
717+
int level = 0;
718+
while (!currentLevel.empty()) {
719+
if (abort && abort->ShouldAbort()) {
720+
SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "Abort!!!\n");
721+
return;
722+
}
723+
724+
size_t levelSize = currentLevel.size();
725+
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Processing level %d with %zu nodes...\n", level, levelSize);
726+
727+
std::vector<ParallelNodeResult> results(levelSize);
728+
729+
// Parallel phase: Run k-means for all nodes in this level
730+
std::atomic_int nextidx(0);
731+
auto func = [&]() {
732+
while (true) {
733+
int idx = nextidx.fetch_add(1);
734+
if (idx < (int)levelSize) {
735+
BKTStackItem& item = currentLevel[idx];
736+
ParallelNodeResult& result = results[idx];
737+
result.parentIndex = item.index;
738+
result.first = item.first;
739+
result.last = item.last;
740+
result.isLeaf = false;
741+
result.singleCluster = false;
742+
743+
if (item.last - item.first <= m_iBKTLeafSize) {
744+
// Leaf node
745+
result.isLeaf = true;
746+
for (SizeType j = item.first; j < item.last; j++) {
747+
SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]);
748+
result.childCenters.push_back(cid);
749+
}
750+
} else {
751+
// K-means clustering - use thread-local args with 1 thread
752+
// (parallelism is at the node level, not within k-means)
753+
// IMPORTANT: Must use full dataset size because KmeansAssign uses absolute indices
754+
// (args.label[i] where i ranges from first to last, not 0 to rangeSize)
755+
KmeansArgs<T> localArgs(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), 1, distMethod, m_pQuantizer);
756+
757+
int dk = m_iBKTKmeansK;
758+
if (dynamicK) {
759+
dk = std::min<int>((item.last - item.first) / m_iBKTLeafSize + 1, m_iBKTKmeansK);
760+
dk = std::max<int>(dk, 2);
761+
localArgs._DK = dk;
762+
}
763+
764+
int numClusters = KmeansClustering(data, localindices, item.first, item.last, localArgs,
765+
m_iSamples, m_fBalanceFactor, false, abort);
766+
767+
if (numClusters <= 1) {
768+
result.singleCluster = true;
769+
SizeType end = min(item.last + 1, (SizeType)localindices.size());
770+
std::sort(localindices.begin() + item.first, localindices.begin() + end);
771+
result.singleClusterCenter = (reverseIndices == nullptr) ? localindices[item.first] : reverseIndices->at(localindices[item.first]);
772+
for (SizeType j = item.first + 1; j < end; j++) {
773+
SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]);
774+
result.childCenters.push_back(cid);
775+
}
776+
} else {
777+
SizeType pos = item.first;
778+
for (int k = 0; k < m_iBKTKmeansK; k++) {
779+
if (localArgs.counts[k] == 0) continue;
780+
SizeType cid = (reverseIndices == nullptr) ? localindices[pos + localArgs.counts[k] - 1] : reverseIndices->at(localindices[pos + localArgs.counts[k] - 1]);
781+
result.childCenters.push_back(cid);
782+
result.childCounts.push_back(localArgs.counts[k]);
783+
pos += localArgs.counts[k];
784+
}
785+
}
786+
}
787+
} else {
788+
return;
789+
}
790+
}
791+
};
792+
793+
std::vector<std::thread> mythreads;
794+
mythreads.reserve(numOfThreads);
795+
for (int tid = 0; tid < numOfThreads; tid++)
796+
{
797+
mythreads.emplace_back(func);
798+
}
799+
for (auto& thread : mythreads) { thread.join(); }
800+
801+
// Sequential phase: Build tree structure and prepare next level
802+
nextLevel.clear();
803+
for (size_t idx = 0; idx < levelSize; idx++) {
804+
ParallelNodeResult& result = results[idx];
805+
m_pTreeRoots[result.parentIndex].childStart = (SizeType)m_pTreeRoots.size();
806+
807+
if (result.isLeaf) {
808+
for (SizeType cid : result.childCenters) {
809+
m_pTreeRoots.emplace_back(cid);
810+
}
811+
} else if (result.singleCluster) {
812+
m_pTreeRoots[result.parentIndex].centerid = result.singleClusterCenter;
813+
m_pTreeRoots[result.parentIndex].childStart = -m_pTreeRoots[result.parentIndex].childStart;
814+
for (SizeType cid : result.childCenters) {
815+
m_pTreeRoots.emplace_back(cid);
816+
m_pSampleCenterMap[cid] = result.singleClusterCenter;
817+
}
818+
m_pSampleCenterMap[-1 - result.singleClusterCenter] = result.parentIndex;
819+
} else {
820+
SizeType pos = result.first;
821+
for (size_t c = 0; c < result.childCenters.size(); c++) {
822+
SizeType nodeIdx = (SizeType)m_pTreeRoots.size();
823+
m_pTreeRoots.emplace_back(result.childCenters[c]);
824+
if (result.childCounts[c] > 1) {
825+
nextLevel.push_back(BKTStackItem(nodeIdx, pos, pos + result.childCounts[c] - 1, false));
826+
}
827+
pos += result.childCounts[c];
828+
}
829+
}
830+
m_pTreeRoots[result.parentIndex].childEnd = (SizeType)m_pTreeRoots.size();
831+
}
832+
833+
currentLevel.swap(nextLevel);
834+
level++;
835+
}
836+
837+
m_pTreeRoots.emplace_back(-1);
838+
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d BKTree built (parallel), %zu %zu\n", treeIdx + 1, m_pTreeRoots.size() - m_pTreeStart[treeIdx], localindices.size());
839+
}
840+
}
841+
658842
inline std::uint64_t BufferSize() const
659843
{
660844
return sizeof(int) + sizeof(SizeType) * m_iTreeNumber +
@@ -863,6 +1047,7 @@ break;
8631047
int m_iTreeNumber, m_iBKTKmeansK, m_iBKTLeafSize, m_iSamples, m_bfs;
8641048
float m_fBalanceFactor;
8651049
std::shared_ptr<SPTAG::COMMON::IQuantizer> m_pQuantizer;
1050+
bool m_parallelBuild = false;
8661051
};
8671052
}
8681053
}

AnnService/inc/Core/SPANN/Options.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ namespace SPTAG {
7070
bool m_recursiveCheckSmallCluster;
7171
bool m_printSizeCount;
7272
std::string m_selectType;
73+
bool m_parallelBKTBuild;
7374

7475
// Section 3: for build head
7576
bool m_buildHead;

AnnService/inc/Core/SPANN/ParameterDefinitionList.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ DefineSelectHeadParameter(m_headVectorCount, int, 0, "Count")
6262
DefineSelectHeadParameter(m_recursiveCheckSmallCluster, bool, true, "RecursiveCheckSmallCluster")
6363
DefineSelectHeadParameter(m_printSizeCount, bool, true, "PrintSizeCount")
6464
DefineSelectHeadParameter(m_selectType, std::string, "BKT", "SelectHeadType")
65+
DefineSelectHeadParameter(m_parallelBKTBuild, bool, false, "ParallelBKTBuild")
6566
#endif
6667

6768
#ifdef DefineBuildHeadParameter

AnnService/src/Core/BKT/BKTIndex.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,11 @@ ErrorCode Index<T>::BuildIndex(const void *p_data, SizeType p_vectorNum, Dimensi
840840
m_threadPool.init();
841841

842842
auto t1 = std::chrono::high_resolution_clock::now();
843-
m_pTrees.BuildTrees<T>(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads);
843+
if (m_pTrees.m_parallelBuild) {
844+
m_pTrees.BuildTreesParallel<T>(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads);
845+
} else {
846+
m_pTrees.BuildTrees<T>(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads);
847+
}
844848
auto t2 = std::chrono::high_resolution_clock::now();
845849
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build Tree time (s): %lld\n",
846850
std::chrono::duration_cast<std::chrono::seconds>(t2 - t1).count());

AnnService/src/Core/SPANN/SPANNIndex.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -938,16 +938,22 @@ bool Index<T>::SelectHeadInternal(std::shared_ptr<Helper::VectorSetReader> &p_re
938938
bkt->m_iSamples = m_options.m_iSamples;
939939
bkt->m_iTreeNumber = m_options.m_iTreeNumber;
940940
bkt->m_fBalanceFactor = m_options.m_fBalanceFactor;
941+
bkt->m_parallelBuild = m_options.m_parallelBKTBuild;
941942
bkt->m_pQuantizer = m_pQuantizer;
942943
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start invoking BuildTrees.\n");
943944
SPTAGLIB_LOG(
944945
Helper::LogLevel::LL_Info,
945-
"BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d.\n",
946+
"BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d, ParallelBuild: %s.\n",
946947
bkt->m_iBKTKmeansK, bkt->m_iBKTLeafSize, bkt->m_iSamples, bkt->m_fBalanceFactor, bkt->m_iTreeNumber,
947-
m_options.m_iSelectHeadNumberOfThreads);
948-
949-
bkt->BuildTrees<InternalDataType>(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads,
950-
nullptr, nullptr, true);
948+
m_options.m_iSelectHeadNumberOfThreads, m_options.m_parallelBKTBuild ? "true" : "false");
949+
950+
if (bkt->m_parallelBuild) {
951+
bkt->BuildTreesParallel<InternalDataType>(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads,
952+
nullptr, nullptr, true);
953+
} else {
954+
bkt->BuildTrees<InternalDataType>(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads,
955+
nullptr, nullptr, true);
956+
}
951957
auto t2 = std::chrono::high_resolution_clock::now();
952958
double elapsedSeconds = std::chrono::duration_cast<std::chrono::seconds>(t2 - t1).count();
953959
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "End invoking BuildTrees.\n");

0 commit comments

Comments
 (0)