diff --git a/benchmarks/src/replace.cpp b/benchmarks/src/replace.cpp index 7934b1eb91..29a61936c4 100644 --- a/benchmarks/src/replace.cpp +++ b/benchmarks/src/replace.cpp @@ -35,7 +35,8 @@ void rc(benchmark::State& state) { } } -// replace() is vectorized for 4 and 8 bytes only. +BENCHMARK(r); +BENCHMARK(r); BENCHMARK(r); BENCHMARK(r); diff --git a/stl/inc/algorithm b/stl/inc/algorithm index 73e3b0d138..9714d100cd 100644 --- a/stl/inc/algorithm +++ b/stl/inc/algorithm @@ -121,6 +121,13 @@ __declspec(noalias) bool __stdcall __std_includes_less_8u( #endif // ^^^ _VECTORIZED_INCLUDES ^^^ #if _VECTORIZED_REPLACE +#if _VECTORIZED_REPLACE_1_2 +__declspec(noalias) void __stdcall __std_replace_1( + void* _First, void* _Last, uint8_t _Old_val, uint8_t _New_val) noexcept; +__declspec(noalias) void __stdcall __std_replace_2( + void* _First, void* _Last, uint16_t _Old_val, uint16_t _New_val) noexcept; +#endif // ^^^ _VECTORIZED_REPLACE_1_2 ^^^ + // TRANSITION, DevCom-10610477 __declspec(noalias) void __stdcall __std_replace_4( void* _First, void* _Last, uint32_t _Old_val, uint32_t _New_val) noexcept; @@ -383,14 +390,25 @@ bool _Includes_vectorized( template __declspec(noalias) void _Replace_vectorized( _Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept { - if constexpr (sizeof(_Ty) == 4) { - ::__std_replace_4( - _First, _Last, _STD _Find_arg_cast(_Old_val), _STD _Find_arg_cast(_New_val)); - } else if constexpr (sizeof(_Ty) == 8) { - ::__std_replace_8( - _First, _Last, _STD _Find_arg_cast(_Old_val), _STD _Find_arg_cast(_New_val)); - } else { - static_assert(false, "unexpected size"); +#if _VECTORIZED_REPLACE_1_2 + if constexpr (sizeof(_Ty) == 1) { + ::__std_replace_1( + _First, _Last, _STD _Find_arg_cast(_Old_val), _STD _Find_arg_cast(_New_val)); + } else if constexpr (sizeof(_Ty) == 2) { + ::__std_replace_2( + _First, _Last, _STD _Find_arg_cast(_Old_val), _STD _Find_arg_cast(_New_val)); + } else +#endif // ^^^ _VECTORIZED_REPLACE_1_2 ^^^ + { + if constexpr (sizeof(_Ty) == 4) { + ::__std_replace_4( + _First, _Last, _STD _Find_arg_cast(_Old_val), _STD _Find_arg_cast(_New_val)); + } else if constexpr (sizeof(_Ty) == 8) { + ::__std_replace_8( + _First, _Last, _STD _Find_arg_cast(_Old_val), _STD _Find_arg_cast(_New_val)); + } else { + static_assert(false, "unexpected size"); + } } } #endif // ^^^ _VECTORIZED_REPLACE ^^^ @@ -491,10 +509,18 @@ _Ty* _Unique_copy_vectorized(const _Ty* const _First, const _Ty* const _Last, _T #endif // ^^^ _VECTORIZED_UNIQUE_COPY ^^^ #if _VECTORIZED_REPLACE +#if _VECTORIZED_REPLACE_1_2 +template +constexpr bool _Have_masked_op_for_iter = true; +#else // ^^^ _VECTORIZED_REPLACE_1_2 / !_VECTORIZED_REPLACE_1_2 vvv +template +constexpr bool _Have_masked_op_for_iter = sizeof(_Iter_value_t<_Iter>) >= 4; // avx masked op compatible size +#endif // ^^^ !_VECTORIZED_REPLACE_1_2 ^^^ + // Can we activate the vector algorithms for replace? template constexpr bool _Vector_alg_in_replace_is_safe = _Vector_alg_in_find_is_safe<_Iter, _Ty1> // can search for the value - && sizeof(_Iter_value_t<_Iter>) >= 4; // avx masked op compatible size + && _Have_masked_op_for_iter<_Iter>; // Can we activate the vector algorithms for ranges::replace? template diff --git a/stl/inc/xutility b/stl/inc/xutility index 488bc52d04..5a0ae369a1 100644 --- a/stl/inc/xutility +++ b/stl/inc/xutility @@ -89,7 +89,7 @@ _STL_DISABLE_CLANG_WARNINGS #define _VECTORIZED_MISMATCH _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC #define _VECTORIZED_REMOVE _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC #define _VECTORIZED_REMOVE_COPY _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC -#define _VECTORIZED_REPLACE _VECTORIZED_FOR_X64_X86 +#define _VECTORIZED_REPLACE _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC #define _VECTORIZED_REPLACE_COPY _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC #define _VECTORIZED_REVERSE _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC #define _VECTORIZED_REVERSE_COPY _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC @@ -104,6 +104,12 @@ _STL_DISABLE_CLANG_WARNINGS // as this does not improve performance over the scalar code. #define _VECTORIZED_MINMAX_ELEMENT_64BIT_INT _VECTORIZED_FOR_X64_X86 +#if defined(_M_ARM64) || defined(_M_ARM64EC) +#define _VECTORIZED_REPLACE_1_2 1 +#else +#define _VECTORIZED_REPLACE_1_2 0 +#endif + #ifndef _USE_STD_VECTOR_FLOATING_ALGORITHMS #if _USE_STD_VECTOR_ALGORITHMS && !defined(_M_FP_EXCEPT) #define _USE_STD_VECTOR_FLOATING_ALGORITHMS 1 diff --git a/stl/src/vector_algorithms.cpp b/stl/src/vector_algorithms.cpp index 86456daf61..207d8f8a11 100644 --- a/stl/src/vector_algorithms.cpp +++ b/stl/src/vector_algorithms.cpp @@ -13,6 +13,7 @@ #if defined(_M_ARM64) || defined(_M_ARM64EC) #include +#include #include #else // ^^^ defined(_M_ARM64) || defined(_M_ARM64EC) / !defined(_M_ARM64) && !defined(_M_ARM64EC) vvv @@ -9612,6 +9613,119 @@ __declspec(noalias) size_t __stdcall __std_mismatch_8( namespace { namespace _Replacing { #if defined(_M_ARM64) || defined(_M_ARM64EC) + struct _Traits_1_sve { + static svuint8_t _Load(const svbool_t _Pred, const void* const _Ptr) noexcept { + return svld1(_Pred, static_cast(_Ptr)); + } + + static svuint8_t _Set(const uint8_t _Val) noexcept { + return svdup_n_u8(_Val); + } + + static svbool_t _Cmp(const svbool_t _Pred, const svuint8_t _Lhs, const svuint8_t _Rhs) noexcept { + return svcmpeq(_Pred, _Lhs, _Rhs); + } + + static void _Store(const svbool_t _Pred, void* const _Ptr, const svuint8_t _Val) noexcept { + svst1(_Pred, static_cast(_Ptr), _Val); + } + }; + + struct _Traits_2_sve { + static svuint16_t _Load(const svbool_t _Pred, const void* const _Ptr) noexcept { + return svld1(_Pred, static_cast(_Ptr)); + } + + static svuint16_t _Set(const uint16_t _Val) noexcept { + return svdup_n_u16(_Val); + } + + static svbool_t _Cmp(const svbool_t _Pred, const svuint16_t _Lhs, const svuint16_t _Rhs) noexcept { + return svcmpeq(_Pred, _Lhs, _Rhs); + } + + static void _Store(const svbool_t _Pred, void* const _Ptr, const svuint16_t _Val) noexcept { + svst1(_Pred, static_cast(_Ptr), _Val); + } + }; + + struct _Traits_4_sve { + static svuint32_t _Load(const svbool_t _Pred, const void* const _Ptr) noexcept { + return svld1(_Pred, static_cast(_Ptr)); + } + + static svuint32_t _Set(const uint32_t _Val) noexcept { + return svdup_n_u32(_Val); + } + + static svbool_t _Cmp(const svbool_t _Pred, const svuint32_t _Lhs, const svuint32_t _Rhs) noexcept { + return svcmpeq(_Pred, _Lhs, _Rhs); + } + + static void _Store(const svbool_t _Pred, void* const _Ptr, const svuint32_t _Val) noexcept { + svst1(_Pred, static_cast(_Ptr), _Val); + } + }; + + struct _Traits_8_sve { + static svuint64_t _Load(const svbool_t _Pred, const void* const _Ptr) noexcept { + return svld1(_Pred, static_cast(_Ptr)); + } + + static svuint64_t _Set(const uint64_t _Val) noexcept { + return svdup_n_u64(_Val); + } + + static svbool_t _Cmp(const svbool_t _Pred, const svuint64_t _Lhs, const svuint64_t _Rhs) noexcept { + return svcmpeq(_Pred, _Lhs, _Rhs); + } + + static void _Store(const svbool_t _Pred, void* const _Ptr, const svuint64_t _Val) noexcept { + svst1(_Pred, static_cast(_Ptr), _Val); + } + }; + + template + __declspec(noalias) void __stdcall _Replace_impl( + void* _First, void* const _Last, const _Ty _Old_val, const _Ty _New_val) noexcept { + + if (_Use_FEAT_SVE()) { + // Arm Architecture Reference Manual for A-profile architecture, + // B1.4.2 "Configurable SVE vector lengths": + // "The architecturally defined SVL set is all powers of two from 128 to 2048 bits inclusive." + const size_t _Sve_vl = svcntb(); + const size_t _Size_bytes = _Byte_length(_First, _Last); + const size_t _Full_vl_bytes = _Size_bytes & ~size_t{_Sve_vl - 1}; + + const void* _Stop_at = _First; + _Advance_bytes(_Stop_at, _Full_vl_bytes); + + const auto _Comparand = _Traits::_Set(_Old_val); + const auto _Replacement = _Traits::_Set(_New_val); + + const auto _True = svptrue_b8(); + while (_First != _Stop_at) { + const auto _Data = _Traits::_Load(_True, _First); + const auto _Mask = _Traits::_Cmp(_True, _Data, _Comparand); + _Traits::_Store(_Mask, _First, _Replacement); + _Advance_bytes(_First, _Sve_vl); + } + + if (const size_t _Tail_length = _Size_bytes & size_t{_Sve_vl - 1}; _Tail_length != 0) { + const auto _Tail_mask = svwhilelt_b8(size_t{0}, _Tail_length); + const auto _Data = _Traits::_Load(_Tail_mask, _First); + const auto _Mask = _Traits::_Cmp(_Tail_mask, _Data, _Comparand); + _Traits::_Store(_Mask, _First, _Replacement); + } + } else { + for (auto _Cur = static_cast<_Ty*>(_First); _Cur != _Last; ++_Cur) { + if (*_Cur == _Old_val) { + *_Cur = _New_val; + } + } + } + } + template __declspec(noalias) void __stdcall _Replace_copy_impl( const void* _First, const void* const _Last, void* _Dest, const _Ty _Old_val, const _Ty _New_val) noexcept { @@ -9750,10 +9864,29 @@ namespace { extern "C" { -#ifndef _M_ARM64 +#if defined(_M_ARM64) || defined(_M_ARM64EC) +__declspec(noalias) void __stdcall __std_replace_1( + void* const _First, void* const _Last, const uint8_t _Old_val, const uint8_t _New_val) noexcept { + _Replacing::_Replace_impl<_Replacing::_Traits_1_sve>(_First, _Last, _Old_val, _New_val); +} + +__declspec(noalias) void __stdcall __std_replace_2( + void* const _First, void* const _Last, const uint16_t _Old_val, const uint16_t _New_val) noexcept { + _Replacing::_Replace_impl<_Replacing::_Traits_2_sve>(_First, _Last, _Old_val, _New_val); +} + +__declspec(noalias) void __stdcall __std_replace_4( + void* const _First, void* const _Last, const uint32_t _Old_val, const uint32_t _New_val) noexcept { + _Replacing::_Replace_impl<_Replacing::_Traits_4_sve>(_First, _Last, _Old_val, _New_val); +} + +__declspec(noalias) void __stdcall __std_replace_8( + void* const _First, void* const _Last, const uint64_t _Old_val, const uint64_t _New_val) noexcept { + _Replacing::_Replace_impl<_Replacing::_Traits_8_sve>(_First, _Last, _Old_val, _New_val); +} +#else // ^^^ defined(_M_ARM64) || defined(_M_ARM64EC) / !defined(_M_ARM64) && !defined(_M_ARM64EC) vvv __declspec(noalias) void __stdcall __std_replace_4( void* _First, void* const _Last, const uint32_t _Old_val, const uint32_t _New_val) noexcept { -#ifndef _M_ARM64EC if (_Use_avx2()) { const __m256i _Comparand = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_Old_val)); const __m256i _Replacement = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_New_val)); @@ -9778,9 +9911,7 @@ __declspec(noalias) void __stdcall __std_replace_4( } _mm256_zeroupper(); // TRANSITION, DevCom-10331414 - } else -#endif // ^^^ !defined(_M_ARM64EC) ^^^ - { + } else { for (auto _Cur = reinterpret_cast(_First); _Cur != _Last; ++_Cur) { if (*_Cur == _Old_val) { *_Cur = _New_val; @@ -9791,7 +9922,6 @@ __declspec(noalias) void __stdcall __std_replace_4( __declspec(noalias) void __stdcall __std_replace_8( void* _First, void* const _Last, const uint64_t _Old_val, const uint64_t _New_val) noexcept { -#ifndef _M_ARM64EC if (_Use_avx2()) { #ifdef _WIN64 const __m256i _Comparand = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(_Old_val)); @@ -9821,9 +9951,7 @@ __declspec(noalias) void __stdcall __std_replace_8( } _mm256_zeroupper(); // TRANSITION, DevCom-10331414 - } else -#endif // ^^^ !defined(_M_ARM64EC) ^^^ - { + } else { for (auto _Cur = reinterpret_cast(_First); _Cur != _Last; ++_Cur) { if (*_Cur == _Old_val) { *_Cur = _New_val; @@ -9831,7 +9959,7 @@ __declspec(noalias) void __stdcall __std_replace_8( } } } -#endif // ^^^ !defined(_M_ARM64) ^^^ +#endif // ^^^ !defined(_M_ARM64) && !defined(_M_ARM64EC) ^^^ __declspec(noalias) void __stdcall __std_replace_copy_1(const void* const _First, const void* const _Last, void* const _Dest, const uint8_t _Old_val, const uint8_t _New_val) noexcept { diff --git a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp index 264b0c1926..77f9abe13b 100644 --- a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp +++ b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp @@ -767,8 +767,13 @@ void test_case_replace_copy(const vector& input, vector& out_expected, vec template void test_replace(mt19937_64& gen) { - // replace() is vectorized for 4 and 8 bytes only. +#if defined(_M_ARM64) || defined(_M_ARM64EC) + // For ARM64/ARM64EC, replace() is always vectorized. + constexpr bool replace_is_vectorized = true; +#else + // For x64/x86, replace() is vectorized for 4 and 8 bytes only. constexpr bool replace_is_vectorized = sizeof(T) >= 4; +#endif using TD = conditional_t; uniform_int_distribution dis(0, 9); @@ -1961,11 +1966,6 @@ int main() { test_min_max_element(gen); test_min_max_element_pointers(gen); - - test_replace(gen); - test_replace(gen); - test_replace(gen); - test_replace(gen); #else // ^^^ defined(_CALL_ALL_X64_VECTOR_ALGORITHMS_ON_ARM64EC) / normal test coverage vvv test_vector_algorithms(gen); test_various_containers();