Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/automated/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# @author Tyson Jones

add_all_local_examples()

# Compile the BMI2 benchmark with -mbmi2 when the C++ compiler accepts that flag.
# The TARGET test avoids a hard configure error (target_compile_options on an
# unknown target) should add_all_local_examples() not have created this target.
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-mbmi2" QUEST_COMPILER_SUPPORTS_MBMI2)
if (QUEST_COMPILER_SUPPORTS_MBMI2 AND TARGET benchmark_bmi2_bitwise_cpp)
    target_compile_options(benchmark_bmi2_bitwise_cpp PRIVATE -mbmi2)
endif()
113 changes: 113 additions & 0 deletions examples/automated/benchmark_bmi2_bitwise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/** @file
* Quick benchmark for BMI2-assisted bit-index helpers.
*
* @author tzh476
*/

#include "quest/src/core/bitwise.hpp"

#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

// Receives the final accumulator of every timed repetition (and is printed by main).
// It is volatile so that these writes, and hence the loops producing them, must be kept.
static volatile qindex sinkValue = 0;

/**
 * Builds a bit mask from a selection of the given bit indices.
 *
 * Bit i of pattern decides whether indices[i] contributes to the mask; the
 * returned mask has bit indices[i] set for every i where that pattern bit is 1.
 */
template <size_t N>
qindex makeMask(const std::array<int, N>& indices, qindex pattern) {
    qindex result = 0;
    size_t pos = 0;

    for (const int bitIndex : indices) {
        const bool isSelected = ((pattern >> pos) & 1) != 0;
        if (isSelected)
            result |= QINDEX_ONE << bitIndex;
        ++pos;
    }

    return result;
}

template <size_t N>
double benchGet(const std::string& name, const std::array<int, N>& indices, const std::vector<qindex>& inputs, qindex ampMask) {
constexpr qindex numIterations = 5000000;
constexpr int numReps = 5;

size_t inputMask = inputs.size() - 1;
double best = std::numeric_limits<double>::max();

for (int r=0; r<numReps; r++) {
qindex acc = static_cast<qindex>(0x13579BDF);
auto start = std::chrono::steady_clock::now();

for (qindex i=0; i<numIterations; i++) {
qindex n = (inputs[static_cast<size_t>(i) & inputMask] + acc) & ampMask;
acc ^= getValueOfBits(n, indices.data(), static_cast<int>(N)) + (i & 7);
}

auto end = std::chrono::steady_clock::now();
sinkValue ^= acc;

double nsPerCall = std::chrono::duration<double, std::nano>(end - start).count() / static_cast<double>(numIterations);
best = std::min(best, nsPerCall);
}

std::cout << std::left << std::setw(30) << name << " " << std::fixed << std::setprecision(3) << best << " ns/call\n";
return best;
}

template <size_t N>
double benchInsert(const std::string& name, const std::array<int, N>& indices, const std::vector<qindex>& inputs, qindex valueMask, qindex insertedMask) {
constexpr qindex numIterations = 5000000;
constexpr int numReps = 5;

size_t inputMask = inputs.size() - 1;
double best = std::numeric_limits<double>::max();

for (int r=0; r<numReps; r++) {
qindex acc = static_cast<qindex>(0x2468ACE0);
auto start = std::chrono::steady_clock::now();

for (qindex i=0; i<numIterations; i++) {
qindex n = (inputs[static_cast<size_t>(i) & inputMask] + acc) & valueMask;
acc ^= insertBitsWithMaskedValues(n, indices.data(), static_cast<int>(N), insertedMask) + (i & 15);
}

auto end = std::chrono::steady_clock::now();
sinkValue ^= acc;

double nsPerCall = std::chrono::duration<double, std::nano>(end - start).count() / static_cast<double>(numIterations);
best = std::min(best, nsPerCall);
}

std::cout << std::left << std::setw(30) << name << " " << std::fixed << std::setprecision(3) << best << " ns/call\n";
return best;
}

/**
 * Benchmarks getValueOfBits() and insertBitsWithMaskedValues() for 2, 5 and 6 bit
 * indices, reporting first whether the BMI2 code paths were compiled in.
 */
int main() {
#if defined(QUEST_USE_BMI2_INTRINSICS)
    std::cout << "BMI2 intrinsics: enabled\n";
#else
    std::cout << "BMI2 intrinsics: disabled\n";
#endif

    // Fill the operand pool from a 64-bit linear congruential generator. The
    // recurrence is evaluated in unsigned arithmetic so its wrap-around is well
    // defined even if qindex is configured as a signed type (its definition is
    // not visible from this file); the stored bit patterns are unchanged.
    std::vector<qindex> inputs(1 << 15);
    unsigned long long state = 0x123456789ABCDEFULL;
    for (qindex& input : inputs) {
        state = state * 0x5851F42D4C957F2DULL + 0x14057B7EF767814FULL;
        input = static_cast<qindex>(state);
    }

    // operands of getValueOfBits() are confined to the lowest 9 bits
    qindex nineQubitMask = (QINDEX_ONE << 9) - QINDEX_ONE;
    const std::array<int, 2> inds2 = {2, 7};
    const std::array<int, 5> inds5 = {0, 2, 4, 6, 8};
    const std::array<int, 6> inds6 = {0, 1, 3, 5, 7, 8};

    benchGet("getValueOfBits 2 bits", inds2, inputs, nineQubitMask);
    benchGet("getValueOfBits 5 bits", inds5, inputs, nineQubitMask);
    benchGet("getValueOfBits 6 bits", inds6, inputs, nineQubitMask);

    // operands shrink by one bit per inserted index (7+2, 4+5 and 3+6 bits), and
    // each inserted mask sets only some of the inserted bits
    benchInsert("insertBitsWithMask 2 bits", inds2, inputs, (QINDEX_ONE << 7) - QINDEX_ONE, makeMask(inds2, 0b01));
    benchInsert("insertBitsWithMask 5 bits", inds5, inputs, (QINDEX_ONE << 4) - QINDEX_ONE, makeMask(inds5, 0b10101));
    benchInsert("insertBitsWithMask 6 bits", inds6, inputs, (QINDEX_ONE << 3) - QINDEX_ONE, makeMask(inds6, 0b101011));

    std::cout << "sink: " << sinkValue << "\n";
    return 0;
}
54 changes: 53 additions & 1 deletion quest/src/core/bitwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
#include <intrin.h>
#endif

// The 64-bit PDEP/PEXT intrinsics (_pdep_u64, _pext_u64) used by this file are only
// declared when targeting x86-64, so 32-bit x86 builds must not enable them even
// when __BMI2__ is defined. CUDA and HIP device compilation passes are excluded too.
#if defined(__BMI2__) && (defined(__x86_64__) || defined(_M_X64)) && !defined(__CUDA_ARCH__) && !defined(__HIP_DEVICE_COMPILE__)
    #include <immintrin.h>
    #define QUEST_USE_BMI2_INTRINSICS
#endif

#include "quest/include/types.h"

#include "quest/src/core/inliner.hpp"
Expand Down Expand Up @@ -116,6 +121,35 @@ INLINE qindex setBit(qindex number, int bitIndex, int bitValue) {
}


INLINE bool getBitMaskAndCheckIsIncreasing(qindex* maskPtr, const int* bitIndices, int numIndices) {

    // Writes to maskPtr the mask with a 1 at every given index, and returns whether
    // the indices are strictly increasing (trivially true for fewer than 2 indices).
    // The caller needs this because PEXT gathers bits in increasing index order,
    // whereas bitIndices may be arbitrarily ordered.
    qindex combined = 0;
    bool strictlyAscending = true;

    for (int pos = 0; pos < numIndices; ++pos) {
        const int current = bitIndices[pos];
        combined |= QINDEX_ONE << current;

        // once an out-of-order pair is seen, the verdict can no longer change
        if (pos != 0 && strictlyAscending)
            strictlyAscending = bitIndices[pos-1] < current;
    }

    *maskPtr = combined;
    return strictlyAscending;
}


INLINE qindex getBitMaskOfIndices(const int* bitIndices, int numIndices) {

    // returns the mask which is 1 at each of the given indices and 0 elsewhere
    // (the indices are visited last-to-first, which does not affect the OR)
    qindex combined = 0;
    int remaining = numIndices;

    while (remaining-- > 0)
        combined |= QINDEX_ONE << bitIndices[remaining];

    return combined;
}


INLINE int getBitMaskParity(qindex mask) {

// Try a builtin if on GCC/Clang and it is available
Expand Down Expand Up @@ -164,6 +198,12 @@ INLINE int getBitMaskParity(qindex mask) {


INLINE qindex insertBits(qindex number, const int* bitIndices, int numIndices, int bitValue) {

#if defined(QUEST_USE_BMI2_INTRINSICS)
qindex mask = getBitMaskOfIndices(bitIndices, numIndices);
qindex result = static_cast<qindex>(_pdep_u64(static_cast<unsigned long long>(number), ~static_cast<unsigned long long>(mask)));
return bitValue? result | mask : result;
#endif

// bitIndices must be strictly increasing
for (int i=0; i<numIndices; i++)
Expand All @@ -190,6 +230,14 @@ INLINE qindex getValueOfBits(qindex number, const int* bitIndices, int numIndice
// bits are arbitrarily ordered, which affects value
qindex value = 0;

#if defined(QUEST_USE_BMI2_INTRINSICS)
qindex mask;
bool isIncreasing = getBitMaskAndCheckIsIncreasing(&mask, bitIndices, numIndices);

if (isIncreasing)
return static_cast<qindex>(_pext_u64(static_cast<unsigned long long>(number), static_cast<unsigned long long>(mask)));
#endif

for (int i=0; i<numIndices; i++)
value |= getBit(number, bitIndices[i]) << i;

Expand All @@ -208,6 +256,10 @@ INLINE qindex getValueOfBits(qindex number, const int* bitIndices, int numIndice
INLINE qindex insertBitsWithMaskedValues(qindex number, const int* bitInds, int numBits, qindex mask) {

    // bitInds must be sorted (increasing), and mask must be zero everywhere except
    // at bitInds, where it holds the value to give to each inserted bit.

    // We must not shortcut this with PDEP(number, ~mask); mask may be 0 at some of
    // bitInds (an inserted 0 bit), whereupon ~mask would wrongly offer that position
    // as a destination for the bits of number. Instead, zeros are inserted at every
    // one of bitInds (insertBits() itself uses PDEP when it is available) and the
    // requested bit values are then overlaid.
    return mask | insertBits(number, bitInds, numBits, 0);
}

Expand Down Expand Up @@ -379,4 +431,4 @@ INLINE void setToBitsOfInteger(int* bits, qindex number, int numBits) {



#endif // BITWISE_HPP
#endif // BITWISE_HPP
3 changes: 2 additions & 1 deletion tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

target_sources(tests
PUBLIC
bitwise.cpp
calculations.cpp
channels.cpp
debug.cpp
Expand All @@ -16,4 +17,4 @@ target_sources(tests
qureg.cpp
trotterisation.cpp
types.cpp
)
)
126 changes: 126 additions & 0 deletions tests/unit/bitwise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/** @file
* Unit tests of internal bitwise helpers.
*
* @defgroup unitbitwise Bitwise
* @ingroup unittests
*/

#include "quest/src/core/bitwise.hpp"

#include <catch2/catch_test_macros.hpp>

#include "tests/utils/macros.hpp"



/*
* UTILITIES
*/

// the tag string passed as the second argument of every TEST_CASE in this file
#define TEST_CATEGORY \
LABEL_UNIT_TAG "[bitwise]"


// Reference for insertBits(): inserts bitValue at each of bitIndices, one index
// at a time and in the order given, via repeated insertBit() calls.
static qindex getReferenceInsertBits(qindex number, const int* bitIndices, int numIndices, int bitValue) {

    qindex result = number;
    int pos = 0;

    while (pos < numIndices) {
        result = insertBit(result, bitIndices[pos], bitValue);
        ++pos;
    }

    return result;
}


// Reference for getValueOfBits(): bit i of the result is the bit of number at
// position bitIndices[i], so the order of bitIndices affects the value.
static qindex getReferenceValueOfBits(qindex number, const int* bitIndices, int numIndices) {

    qindex gathered = 0;
    int pos = 0;

    while (pos < numIndices) {
        gathered |= getBit(number, bitIndices[pos]) << pos;
        ++pos;
    }

    return gathered;
}



/**
* TESTS
*
* @ingroup unitbitwise
* @{
*/


// Compares insertBits() against the one-index-at-a-time reference, for both
// possible values of the inserted bits.
TEST_CASE( "insertBits", TEST_CATEGORY ) {

SECTION( LABEL_CORRECTNESS ) {

// indices are strictly increasing, as insertBits() requires
int bitInds[] = {1, 3, 6, 9, 12, 20};
int numInds = 6;
qindex number = 0b101101001011;

// inserting 0s, then inserting 1s
REQUIRE( insertBits(number, bitInds, numInds, 0) == getReferenceInsertBits(number, bitInds, numInds, 0) );
REQUIRE( insertBits(number, bitInds, numInds, 1) == getReferenceInsertBits(number, bitInds, numInds, 1) );
}

SECTION( LABEL_VALIDATION ) {

// no validation!
SUCCEED( );
}
}


// Compares getValueOfBits() against the bit-by-bit reference, for both sorted
// and unsorted index lists.
TEST_CASE( "getValueOfBits", TEST_CATEGORY ) {

SECTION( LABEL_CORRECTNESS ) {

qindex number = 0b101101101001011;

// strictly increasing indices are eligible for the PEXT path of
// getValueOfBits() when BMI2 intrinsics are enabled
SECTION( "increasing indices" ) {

int bitInds[] = {0, 2, 5, 8, 11, 14};
int numInds = 6;
REQUIRE( getValueOfBits(number, bitInds, numInds) == getReferenceValueOfBits(number, bitInds, numInds) );
}

// the same indices but shuffled, which must take the generic loop since
// the index order determines the returned value
SECTION( "arbitrarily ordered indices" ) {

int bitInds[] = {14, 0, 8, 2, 11, 5};
int numInds = 6;
REQUIRE( getValueOfBits(number, bitInds, numInds) == getReferenceValueOfBits(number, bitInds, numInds) );
}
}

SECTION( LABEL_VALIDATION ) {

// no validation!
SUCCEED( );
}
}


// Compares insertBitsWithMaskedValues() against a reference which inserts 0 at
// every index and then overlays the mask. Masks which leave some (or all) of the
// inserted bits at 0 are included, since an all-ones mask alone cannot reveal an
// implementation that confuses "where to insert" with "which inserted bits are 1".
TEST_CASE( "insertBitsWithMaskedValues", TEST_CATEGORY ) {

    SECTION( LABEL_CORRECTNESS ) {

        int bitInds[] = {2, 4, 7, 10, 13, 21};
        int numInds = 6;
        qindex number = 0b1101001011;

        // number with a 0 inserted at every one of bitInds
        qindex zeroInserted = getReferenceInsertBits(number, bitInds, numInds, 0);

        SECTION( "all inserted bits set" ) {

            qindex mask = 0;
            for (int i=0; i<numInds; i++)
                mask |= QINDEX_ONE << bitInds[i];

            REQUIRE( insertBitsWithMaskedValues(number, bitInds, numInds, mask) == (mask | zeroInserted) );
        }

        SECTION( "some inserted bits set" ) {

            // only the bits inserted at indices 2, 7 and 21 become 1
            qindex mask = (QINDEX_ONE << bitInds[0]) | (QINDEX_ONE << bitInds[2]) | (QINDEX_ONE << bitInds[5]);

            REQUIRE( insertBitsWithMaskedValues(number, bitInds, numInds, mask) == (mask | zeroInserted) );
        }

        SECTION( "no inserted bits set" ) {

            qindex mask = 0;

            REQUIRE( insertBitsWithMaskedValues(number, bitInds, numInds, mask) == zeroInserted );
        }
    }

    SECTION( LABEL_VALIDATION ) {

        // no validation!
        SUCCEED( );
    }
}


/** @} (end defgroup) */
Loading