|
9 | 9 | #include <vector> |
10 | 10 | #include <mutex> |
11 | 11 | #include <shared_mutex> |
| 12 | +#include <atomic> |
| 13 | +#include <omp.h> |
12 | 14 | #include "inc/Core/VectorIndex.h" |
13 | 15 |
|
14 | 16 | #include "CommonUtils.h" |
@@ -655,6 +657,188 @@ break; |
655 | 657 | } |
656 | 658 | } |
657 | 659 |
|
| 660 | + // Parallel BKTree Build - processes sibling nodes in parallel |
| 661 | + template <typename T> |
| 662 | + void BuildTreesParallel(const Dataset<T>& data, DistCalcMethod distMethod, int numOfThreads, |
| 663 | + std::vector<SizeType>* indices = nullptr, std::vector<SizeType>* reverseIndices = nullptr, |
| 664 | + bool dynamicK = false, IAbortOperation* abort = nullptr) |
| 665 | + { |
| 666 | + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using PARALLEL BKTree build with %d threads.\n", numOfThreads); |
| 667 | + |
| 668 | + // Helper struct for collecting parallel results |
| 669 | + struct ParallelNodeResult { |
| 670 | + SizeType parentIndex; |
| 671 | + SizeType first, last; |
| 672 | + std::vector<SizeType> childCenters; |
| 673 | + std::vector<SizeType> childCounts; |
| 674 | + bool isLeaf; |
| 675 | + bool singleCluster; |
| 676 | + SizeType singleClusterCenter; |
| 677 | + }; |
| 678 | + |
| 679 | + struct BKTStackItem { |
| 680 | + SizeType index, first, last; |
| 681 | + bool debug; |
| 682 | + BKTStackItem(SizeType index_ = -1, SizeType first_ = 0, SizeType last_ = 0, bool debug_ = false) |
| 683 | + : index(index_), first(first_), last(last_), debug(debug_) {} |
| 684 | + }; |
| 685 | + |
| 686 | + std::vector<SizeType> localindices; |
| 687 | + if (indices == nullptr) { |
| 688 | + localindices.resize(data.R()); |
| 689 | + for (SizeType i = 0; i < (SizeType)localindices.size(); i++) localindices[i] = i; |
| 690 | + } |
| 691 | + else { |
| 692 | + localindices.assign(indices->begin(), indices->end()); |
| 693 | + } |
| 694 | + |
| 695 | + // Create a shared KmeansArgs for DynamicFactorSelect (uses all threads) |
| 696 | + KmeansArgs<T> sharedArgs(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), numOfThreads, distMethod, m_pQuantizer); |
| 697 | + |
| 698 | + if (m_fBalanceFactor < 0) { |
| 699 | + m_fBalanceFactor = DynamicFactorSelect(data, localindices, 0, (SizeType)localindices.size(), sharedArgs, m_iSamples); |
| 700 | + } |
| 701 | + |
| 702 | + std::mt19937 rg; |
| 703 | + m_pSampleCenterMap.clear(); |
| 704 | + |
| 705 | + for (char treeIdx = 0; treeIdx < m_iTreeNumber; treeIdx++) |
| 706 | + { |
| 707 | + std::shuffle(localindices.begin(), localindices.end(), rg); |
| 708 | + |
| 709 | + m_pTreeStart.push_back((SizeType)m_pTreeRoots.size()); |
| 710 | + m_pTreeRoots.emplace_back((SizeType)localindices.size()); |
| 711 | + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to build BKTree %d (parallel)\n", treeIdx + 1); |
| 712 | + |
| 713 | + // Level-order processing |
| 714 | + std::vector<BKTStackItem> currentLevel, nextLevel; |
| 715 | + currentLevel.push_back(BKTStackItem(m_pTreeStart[treeIdx], 0, (SizeType)localindices.size(), true)); |
| 716 | + |
| 717 | + int level = 0; |
| 718 | + while (!currentLevel.empty()) { |
| 719 | + if (abort && abort->ShouldAbort()) { |
| 720 | + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "Abort!!!\n"); |
| 721 | + return; |
| 722 | + } |
| 723 | + |
| 724 | + size_t levelSize = currentLevel.size(); |
| 725 | + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Processing level %d with %zu nodes...\n", level, levelSize); |
| 726 | + |
| 727 | + std::vector<ParallelNodeResult> results(levelSize); |
| 728 | + |
| 729 | + // Parallel phase: Run k-means for all nodes in this level |
| 730 | + std::atomic_int nextidx(0); |
| 731 | + auto func = [&]() { |
| 732 | + while (true) { |
| 733 | + int idx = nextidx.fetch_add(1); |
| 734 | + if (idx < (int)levelSize) { |
| 735 | + BKTStackItem& item = currentLevel[idx]; |
| 736 | + ParallelNodeResult& result = results[idx]; |
| 737 | + result.parentIndex = item.index; |
| 738 | + result.first = item.first; |
| 739 | + result.last = item.last; |
| 740 | + result.isLeaf = false; |
| 741 | + result.singleCluster = false; |
| 742 | + |
| 743 | + if (item.last - item.first <= m_iBKTLeafSize) { |
| 744 | + // Leaf node |
| 745 | + result.isLeaf = true; |
| 746 | + for (SizeType j = item.first; j < item.last; j++) { |
| 747 | + SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]); |
| 748 | + result.childCenters.push_back(cid); |
| 749 | + } |
| 750 | + } else { |
| 751 | + // K-means clustering - use thread-local args with 1 thread |
| 752 | + // (parallelism is at the node level, not within k-means) |
| 753 | + // IMPORTANT: Must use full dataset size because KmeansAssign uses absolute indices |
| 754 | + // (args.label[i] where i ranges from first to last, not 0 to rangeSize) |
| 755 | + KmeansArgs<T> localArgs(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), 1, distMethod, m_pQuantizer); |
| 756 | + |
| 757 | + int dk = m_iBKTKmeansK; |
| 758 | + if (dynamicK) { |
| 759 | + dk = std::min<int>((item.last - item.first) / m_iBKTLeafSize + 1, m_iBKTKmeansK); |
| 760 | + dk = std::max<int>(dk, 2); |
| 761 | + localArgs._DK = dk; |
| 762 | + } |
| 763 | + |
| 764 | + int numClusters = KmeansClustering(data, localindices, item.first, item.last, localArgs, |
| 765 | + m_iSamples, m_fBalanceFactor, false, abort); |
| 766 | + |
| 767 | + if (numClusters <= 1) { |
| 768 | + result.singleCluster = true; |
| 769 | + SizeType end = min(item.last + 1, (SizeType)localindices.size()); |
| 770 | + std::sort(localindices.begin() + item.first, localindices.begin() + end); |
| 771 | + result.singleClusterCenter = (reverseIndices == nullptr) ? localindices[item.first] : reverseIndices->at(localindices[item.first]); |
| 772 | + for (SizeType j = item.first + 1; j < end; j++) { |
| 773 | + SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]); |
| 774 | + result.childCenters.push_back(cid); |
| 775 | + } |
| 776 | + } else { |
| 777 | + SizeType pos = item.first; |
| 778 | + for (int k = 0; k < m_iBKTKmeansK; k++) { |
| 779 | + if (localArgs.counts[k] == 0) continue; |
| 780 | + SizeType cid = (reverseIndices == nullptr) ? localindices[pos + localArgs.counts[k] - 1] : reverseIndices->at(localindices[pos + localArgs.counts[k] - 1]); |
| 781 | + result.childCenters.push_back(cid); |
| 782 | + result.childCounts.push_back(localArgs.counts[k]); |
| 783 | + pos += localArgs.counts[k]; |
| 784 | + } |
| 785 | + } |
| 786 | + } |
| 787 | + } else { |
| 788 | + return; |
| 789 | + } |
| 790 | + } |
| 791 | + }; |
| 792 | + |
| 793 | + std::vector<std::thread> mythreads; |
| 794 | + mythreads.reserve(numOfThreads); |
| 795 | + for (int tid = 0; tid < numOfThreads; tid++) |
| 796 | + { |
| 797 | + mythreads.emplace_back(func); |
| 798 | + } |
| 799 | + for (auto& thread : mythreads) { thread.join(); } |
| 800 | + |
| 801 | + // Sequential phase: Build tree structure and prepare next level |
| 802 | + nextLevel.clear(); |
| 803 | + for (size_t idx = 0; idx < levelSize; idx++) { |
| 804 | + ParallelNodeResult& result = results[idx]; |
| 805 | + m_pTreeRoots[result.parentIndex].childStart = (SizeType)m_pTreeRoots.size(); |
| 806 | + |
| 807 | + if (result.isLeaf) { |
| 808 | + for (SizeType cid : result.childCenters) { |
| 809 | + m_pTreeRoots.emplace_back(cid); |
| 810 | + } |
| 811 | + } else if (result.singleCluster) { |
| 812 | + m_pTreeRoots[result.parentIndex].centerid = result.singleClusterCenter; |
| 813 | + m_pTreeRoots[result.parentIndex].childStart = -m_pTreeRoots[result.parentIndex].childStart; |
| 814 | + for (SizeType cid : result.childCenters) { |
| 815 | + m_pTreeRoots.emplace_back(cid); |
| 816 | + m_pSampleCenterMap[cid] = result.singleClusterCenter; |
| 817 | + } |
| 818 | + m_pSampleCenterMap[-1 - result.singleClusterCenter] = result.parentIndex; |
| 819 | + } else { |
| 820 | + SizeType pos = result.first; |
| 821 | + for (size_t c = 0; c < result.childCenters.size(); c++) { |
| 822 | + SizeType nodeIdx = (SizeType)m_pTreeRoots.size(); |
| 823 | + m_pTreeRoots.emplace_back(result.childCenters[c]); |
| 824 | + if (result.childCounts[c] > 1) { |
| 825 | + nextLevel.push_back(BKTStackItem(nodeIdx, pos, pos + result.childCounts[c] - 1, false)); |
| 826 | + } |
| 827 | + pos += result.childCounts[c]; |
| 828 | + } |
| 829 | + } |
| 830 | + m_pTreeRoots[result.parentIndex].childEnd = (SizeType)m_pTreeRoots.size(); |
| 831 | + } |
| 832 | + |
| 833 | + currentLevel.swap(nextLevel); |
| 834 | + level++; |
| 835 | + } |
| 836 | + |
| 837 | + m_pTreeRoots.emplace_back(-1); |
| 838 | + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d BKTree built (parallel), %zu %zu\n", treeIdx + 1, m_pTreeRoots.size() - m_pTreeStart[treeIdx], localindices.size()); |
| 839 | + } |
| 840 | + } |
| 841 | + |
658 | 842 | inline std::uint64_t BufferSize() const |
659 | 843 | { |
660 | 844 | return sizeof(int) + sizeof(SizeType) * m_iTreeNumber + |
@@ -863,6 +1047,7 @@ break; |
863 | 1047 | int m_iTreeNumber, m_iBKTKmeansK, m_iBKTLeafSize, m_iSamples, m_bfs; |
864 | 1048 | float m_fBalanceFactor; |
865 | 1049 | std::shared_ptr<SPTAG::COMMON::IQuantizer> m_pQuantizer; |
| 1050 | + bool m_parallelBuild = false; |
866 | 1051 | }; |
867 | 1052 | } |
868 | 1053 | } |
|
0 commit comments