diff --git a/src/shammodels/gsph/include/shammodels/gsph/modules/GSPHUtilities.hpp b/src/shammodels/gsph/include/shammodels/gsph/modules/GSPHUtilities.hpp index f96ae01d2..9ea0f3e07 100644 --- a/src/shammodels/gsph/include/shammodels/gsph/modules/GSPHUtilities.hpp +++ b/src/shammodels/gsph/include/shammodels/gsph/modules/GSPHUtilities.hpp @@ -77,7 +77,7 @@ namespace shammodels::gsph { PatchtreeField interactR_mpi_tree = sptree.make_patch_tree_field( sched, - shamsys::instance::get_compute_queue(), + shamsys::instance::get_compute_scheduler_ptr(), interactR_patch, [](Tscal h0, Tscal h1, Tscal h2, Tscal h3, Tscal h4, Tscal h5, Tscal h6, Tscal h7) { return sham::max_8points(h0, h1, h2, h3, h4, h5, h6, h7); diff --git a/src/shammodels/gsph/src/modules/GSPHGhostHandler.cpp b/src/shammodels/gsph/src/modules/GSPHGhostHandler.cpp index 4200bef3c..644b933e9 100644 --- a/src/shammodels/gsph/src/modules/GSPHGhostHandler.cpp +++ b/src/shammodels/gsph/src/modules/GSPHGhostHandler.cpp @@ -133,8 +133,7 @@ auto GSPHGhostHandler::find_interfaces( using BCShearingPeriodic = typename CfgClass::ShearingPeriodic; if (BCPeriodic *cfg = std::get_if(&ghost_config)) { - sycl::host_accessor acc_tf{ - shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only}; + auto acc_tf = int_range_max_tree.internal_buf->template mirror_to(); for (i32 xoff = -repetition_x; xoff <= repetition_x; xoff++) { for (i32 yoff = -repetition_y; yoff <= repetition_y; yoff++) { @@ -188,8 +187,7 @@ auto GSPHGhostHandler::find_interfaces( } } } else if (BCShearingPeriodic *cfg = std::get_if(&ghost_config)) { - sycl::host_accessor acc_tf{ - shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only}; + auto acc_tf = int_range_max_tree.internal_buf->template mirror_to(); for_each_patch_shift(*cfg, bsize, [&](i32_3 ioff, ShiftInfo shift) { i32 xoff = ioff.x(); @@ -242,8 +240,7 @@ auto GSPHGhostHandler::find_interfaces( }); } else { - sycl::host_accessor acc_tf{ - shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only}; + auto acc_tf = int_range_max_tree.internal_buf->template mirror_to(); // sender translation vec periodic_offset = vec{0, 0, 0}; diff --git a/src/shammodels/sph/include/shammodels/sph/SPHUtilities.hpp b/src/shammodels/sph/include/shammodels/sph/SPHUtilities.hpp index af5f17a9b..2b8d0c835 100644 --- a/src/shammodels/sph/include/shammodels/sph/SPHUtilities.hpp +++ b/src/shammodels/sph/include/shammodels/sph/SPHUtilities.hpp @@ -94,7 +94,7 @@ namespace shammodels::sph { PatchtreeField interactR_mpi_tree = sptree.make_patch_tree_field( sched, - shamsys::instance::get_compute_queue(), + shamsys::instance::get_compute_scheduler_ptr(), interactR_patch, [](flt h0, flt h1, flt h2, flt h3, flt h4, flt h5, flt h6, flt h7) { return sham::max_8points(h0, h1, h2, h3, h4, h5, h6, h7); diff --git a/src/shammodels/sph/src/BasicSPHGhosts.cpp b/src/shammodels/sph/src/BasicSPHGhosts.cpp index 515d6a454..17a6755be 100644 --- a/src/shammodels/sph/src/BasicSPHGhosts.cpp +++ b/src/shammodels/sph/src/BasicSPHGhosts.cpp @@ -290,8 +290,7 @@ auto BasicSPHGhostHandler::find_interfaces( base_timer.start(); if (BCPeriodic *cfg = std::get_if(&ghost_config)) { - sycl::host_accessor acc_tf{ - shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only}; + auto acc_tf = int_range_max_tree.internal_buf->template mirror_to(); for (i32 xoff = -repetition_x; xoff <= repetition_x; xoff++) { for (i32 yoff = -repetition_y; yoff <= repetition_y; yoff++) { @@ -300,10 +299,8 @@ auto BasicSPHGhostHandler::find_interfaces( // sender translation vec periodic_offset = vec{xoff * bsize.x(), yoff * bsize.y(), zoff * bsize.z()}; - sycl::host_accessor tree{ - shambase::get_check_ref(sptree.serial_tree_buf), sycl::read_only}; - sycl::host_accessor lpid{ - shambase::get_check_ref(sptree.linked_patch_ids_buf), sycl::read_only}; + auto tree = sptree.serial_tree_buf->template mirror_to(); + auto lpid = sptree.linked_patch_ids_buf->template mirror_to(); #pragma omp parallel for for (u32 i = 0; i < sched.patch_list.local.size(); i++) { @@ -359,8 +356,7 @@ auto BasicSPHGhostHandler::find_interfaces( } } } else if (BCShearingPeriodic *cfg = std::get_if(&ghost_config)) { - sycl::host_accessor acc_tf{ - shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only}; + auto acc_tf = int_range_max_tree.internal_buf->template mirror_to(); for_each_patch_shift(*cfg, bsize, [&](i32_3 ioff, ShiftInfo shift) { i32 xoff = ioff.x(); @@ -369,10 +365,8 @@ auto BasicSPHGhostHandler::find_interfaces( vec offset = shift.shift; - sycl::host_accessor tree{ - shambase::get_check_ref(sptree.serial_tree_buf), sycl::read_only}; - sycl::host_accessor lpid{ - shambase::get_check_ref(sptree.linked_patch_ids_buf), sycl::read_only}; + auto tree = sptree.serial_tree_buf->template mirror_to(); + auto lpid = sptree.linked_patch_ids_buf->template mirror_to(); #pragma omp parallel for for (u32 i = 0; i < sched.patch_list.local.size(); i++) { @@ -429,14 +423,12 @@ auto BasicSPHGhostHandler::find_interfaces( }); } else { - sycl::host_accessor acc_tf{ - shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only}; + auto acc_tf = int_range_max_tree.internal_buf->template mirror_to(); // sender translation vec periodic_offset = vec{0, 0, 0}; - sycl::host_accessor tree{shambase::get_check_ref(sptree.serial_tree_buf), sycl::read_only}; - sycl::host_accessor lpid{ - shambase::get_check_ref(sptree.linked_patch_ids_buf), sycl::read_only}; + auto tree = sptree.serial_tree_buf->template mirror_to(); + auto lpid = sptree.linked_patch_ids_buf->template mirror_to(); #pragma omp parallel for for (u32 i = 0; i < sched.patch_list.local.size(); i++) { diff --git a/src/shammodels/sph/src/modules/SPHSetup.cpp b/src/shammodels/sph/src/modules/SPHSetup.cpp index 051fb9042..2ec4b7599 100644 --- a/src/shammodels/sph/src/modules/SPHSetup.cpp +++ b/src/shammodels/sph/src/modules/SPHSetup.cpp @@ -652,7 +652,7 @@ void shammodels::sph::modules::SPHSetup::apply_setup_new( shambase::throw_unimplemented(); } - sycl::buffer new_id_buf = sptree.compute_patch_owner( + sham::DeviceBuffer new_id_buf = sptree.compute_patch_owner( shamsys::instance::get_compute_scheduler_ptr(), pos_field.get_buf(), pos_field.get_obj_cnt()); @@ -660,7 +660,7 @@ void shammodels::sph::modules::SPHSetup::apply_setup_new( std::unordered_map> index_per_ranks; bool err_id_in_newid = false; { - sycl::host_accessor nid{new_id_buf, sycl::read_only}; + std::vector nid = new_id_buf.copy_to_stdvec(); for (u32 i = 0; i < pos_field.get_obj_cnt(); i++) { u64 patch_id = nid[i]; bool err = patch_id == u64_max; diff --git a/src/shamrock/include/shamrock/patch/PatchField.hpp b/src/shamrock/include/shamrock/patch/PatchField.hpp index ce010143e..f68b3edcb 100644 --- a/src/shamrock/include/shamrock/patch/PatchField.hpp +++ b/src/shamrock/include/shamrock/patch/PatchField.hpp @@ -16,8 +16,11 @@ */ #include "shambase/DistributedData.hpp" +#include "shambackends/DeviceBuffer.hpp" +#include "shambackends/DeviceScheduler.hpp" #include "shambackends/sycl.hpp" #include +#include namespace shamrock::patch { template @@ -33,10 +36,12 @@ namespace shamrock::patch { template class PatchtreeField { public: - std::unique_ptr> internal_buf; + std::optional> internal_buf; inline void reset() { internal_buf.reset(); } - inline void allocate(u32 size) { internal_buf = std::make_unique>(size); } + inline void allocate(u32 size, sham::DeviceScheduler_ptr sched) { + internal_buf.emplace(size, sched); + } }; } // namespace shamrock::patch diff --git a/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp b/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp index 18d9b396b..4aa1f0c44 100644 --- a/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp +++ b/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp @@ -59,12 +59,12 @@ namespace shamrock { * bound). */ template - shambase::DistributedData> compute_new_pid( + shambase::DistributedData> compute_new_pid( SerialPatchTree &sptree, u32 ipos) { StackEntry stack_loc{}; - shambase::DistributedData> newid_buf_map; + shambase::DistributedData> newid_buf_map; sched.patch_data.for_each_patchdata([&](u64 id, shamrock::patch::PatchDataLayer &pdat) { if (!pdat.is_empty()) { @@ -84,7 +84,7 @@ namespace shamrock { bool err_id_in_newid = false; { - sycl::host_accessor nid{newid_buf_map.get(id), sycl::read_only}; + std::vector nid = newid_buf_map.get(id).copy_to_stdvec(); for (u32 i = 0; i < pdat.get_obj_cnt(); i++) { bool err = nid[i] == u64_max; err_id_in_newid = err_id_in_newid || (err); @@ -114,7 +114,7 @@ namespace shamrock { * @return A shared distributed data object containing the extracted patch data. */ inline shambase::DistributedDataShared extract_elements( - shambase::DistributedData> new_pid) { + shambase::DistributedData> &&new_pid) { shambase::DistributedDataShared part_exchange; StackEntry stack_loc{}; @@ -128,7 +128,7 @@ namespace shamrock { histogram_extract[current_pid] = 0; if (!pdat.is_empty()) { - sycl::host_accessor nid{new_pid.get(current_pid), sycl::read_only}; + std::vector nid = new_pid.get(current_pid).copy_to_stdvec(); if (false) { @@ -222,9 +222,10 @@ namespace shamrock { u32 ipos = sched.pdl_old().get_field_idx(position_field); - DistributedData> new_pid = compute_new_pid(sptree, ipos); + DistributedData> new_pid = compute_new_pid(sptree, ipos); - DistributedDataShared part_exchange = extract_elements(new_pid); + DistributedDataShared part_exchange + = extract_elements(std::move(new_pid)); part_exchange.for_each([](u64 sender, u64 receiver, PatchDataLayer &pdat) { shamlog_debug_ln("ReattributeDataUtility", sender, receiver, pdat.get_obj_cnt()); diff --git a/src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp b/src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp index 19f186be5..d9fa86142 100644 --- a/src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp +++ b/src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp @@ -21,13 +21,16 @@ #include "shambase/memory.hpp" #include "shambase/stacktrace.hpp" -#include "shamrock/legacy/patch/utility/patch_field.hpp" +#include "shambackends/BufferMirror.hpp" +#include "shambackends/DeviceBuffer.hpp" +#include "shambackends/kernel_call.hpp" #include "shamrock/patch/PatchField.hpp" #include "shamrock/scheduler/PatchScheduler.hpp" #include "shamrock/scheduler/PatchTree.hpp" #include "shamsys/legacy/log.hpp" #include "shamsys/legacy/sycl_handler.hpp" #include +#include #include #include @@ -38,30 +41,32 @@ class SerialPatchTree { using PatchTree = shamrock::scheduler::PatchTree; - // TODO use unique pointer instead u32 root_count = 0; - std::unique_ptr> serial_tree_buf; - std::unique_ptr> linked_patch_ids_buf; + std::optional> serial_tree_buf; + std::optional> linked_patch_ids_buf; inline void attach_buf() { - if (bool(serial_tree_buf)) + if (serial_tree_buf.has_value()) throw shambase::make_except_with_loc( "serial_tree_buf is already allocated"); - if (bool(linked_patch_ids_buf)) + if (linked_patch_ids_buf.has_value()) throw shambase::make_except_with_loc( "linked_patch_ids_buf is already allocated"); - serial_tree_buf - = std::make_unique>(serial_tree.data(), serial_tree.size()); - linked_patch_ids_buf - = std::make_unique>(linked_patch_ids.data(), linked_patch_ids.size()); + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + serial_tree_buf.emplace(serial_tree.size(), dev_sched); + serial_tree_buf->copy_from_stdvec(serial_tree); + + linked_patch_ids_buf.emplace(linked_patch_ids.size(), dev_sched); + linked_patch_ids_buf->copy_from_stdvec(linked_patch_ids); } inline void detach_buf() { - if (!bool(serial_tree_buf)) + if (!serial_tree_buf.has_value()) throw shambase::make_except_with_loc( "serial_tree_buf wasn't allocated"); - if (!bool(linked_patch_ids_buf)) + if (!linked_patch_ids_buf.has_value()) throw shambase::make_except_with_loc( "linked_patch_ids_buf wasn't allocated"); @@ -149,8 +154,8 @@ class SerialPatchTree { std::function found_case) { StackEntry stack_loc{false}; - sycl::host_accessor tree{shambase::get_check_ref(serial_tree_buf), sycl::read_only}; - sycl::host_accessor lpid{shambase::get_check_ref(linked_patch_ids_buf), sycl::read_only}; + auto tree = shambase::get_check_ref(serial_tree_buf).template mirror_to(); + auto lpid = shambase::get_check_ref(linked_patch_ids_buf).template mirror_to(); host_for_each_leafs_internal(interact_cd, found_case, tree, lpid); } @@ -177,17 +182,17 @@ class SerialPatchTree { template inline shamrock::patch::PatchtreeField make_patch_tree_field( PatchScheduler &sched, - sycl::queue &queue, + sham::DeviceScheduler_ptr dev_sched, shamrock::patch::PatchField pfield, Func &&reducer) { shamrock::patch::PatchtreeField ptfield; - ptfield.allocate(get_element_count()); + ptfield.allocate(get_element_count(), dev_sched); { - sycl::host_accessor lpid{ - shambase::get_check_ref(linked_patch_ids_buf), sycl::read_only}; - sycl::host_accessor tree_field{ - shambase::get_check_ref(ptfield.internal_buf), sycl::write_only, sycl::no_init}; + auto lpid + = shambase::get_check_ref(linked_patch_ids_buf).template mirror_to(); + auto tree_field + = shambase::get_check_ref(ptfield.internal_buf).template mirror_to(); // init reduction std::unordered_map &idp_to_gid = sched.patch_list.id_patch_to_global_idx; @@ -196,18 +201,17 @@ class SerialPatchTree { } } - sycl::range<1> range{get_element_count()}; + auto &q = shambase::get_check_ref(dev_sched).get_queue(); u32 end_loop = get_level_count(); + u32 elem_cnt = get_element_count(); for (u32 level = 0; level < end_loop; level++) { - queue.submit([&](sycl::handler &cgh) { - sycl::accessor tree{shambase::get_check_ref(serial_tree_buf), cgh, sycl::read_only}; - sycl::accessor f{ - shambase::get_check_ref(ptfield.internal_buf), cgh, sycl::read_write}; - - cgh.parallel_for(range, [=](sycl::item<1> item) { - u64 i = (u64) item.get_id(0); - + sham::kernel_call( + q, + sham::MultiRef{shambase::get_check_ref(serial_tree_buf)}, + sham::MultiRef{shambase::get_check_ref(ptfield.internal_buf)}, + elem_cnt, + [reducer](u32 i, const PtNode *tree, T *f) { std::array n = tree[i].childs_id; if (n[0] != u64_max) { @@ -215,7 +219,6 @@ class SerialPatchTree { f[n[0]], f[n[1]], f[n[2]], f[n[3]], f[n[4]], f[n[5]], f[n[6]], f[n[7]]); } }); - }); } return ptfield; } @@ -243,41 +246,48 @@ class SerialPatchTree { } } - sycl::buffer compute_patch_owner( + sham::DeviceBuffer compute_patch_owner( sham::DeviceScheduler_ptr dev_sched, sham::DeviceBuffer &position_buffer, u32 len); }; template -sycl::buffer SerialPatchTree::compute_patch_owner( +sham::DeviceBuffer SerialPatchTree::compute_patch_owner( sham::DeviceScheduler_ptr dev_sched, sham::DeviceBuffer &position_buffer, u32 len) { - sycl::buffer new_owned_id(len); - - using namespace shamrock::patch; - - sycl::buffer roots = shamalgs::vec_to_buf(roots_ids); + sham::DeviceBuffer new_owned_id(len, dev_sched); - auto &q = dev_sched->get_queue(); - - sham::EventList depends_list; - auto pos = position_buffer.get_read_access(depends_list); - - auto e = q.submit(depends_list, [&](sycl::handler &cgh) { - sycl::accessor tnode{shambase::get_check_ref(serial_tree_buf), cgh, sycl::read_only}; - sycl::accessor linked_node_id{ - shambase::get_check_ref(linked_patch_ids_buf), cgh, sycl::read_only}; - sycl::accessor roots_id{roots, cgh, sycl::read_only}; - sycl::accessor new_id{new_owned_id, cgh, sycl::write_only, sycl::no_init}; - - u32 root_cnt = roots_id.size(); - auto max_lev = get_level_count(); - - using PtNode = shamrock::scheduler::SerialPatchNode; + if (len == 0) { + return new_owned_id; + } - cgh.parallel_for(sycl::range(len), [=](sycl::item<1> item) { - u32 i = (u32) item.get_id(0); + using namespace shamrock::patch; + sham::DeviceBuffer roots(roots_ids.size(), dev_sched); + roots.copy_from_stdvec(roots_ids); + + auto &q = shambase::get_check_ref(dev_sched).get_queue(); + u32 root_cnt = roots_ids.size(); + auto max_lev = get_level_count(); + + using PtNode = shamrock::scheduler::SerialPatchNode; + + sham::kernel_call( + q, + sham::MultiRef{ + position_buffer, + shambase::get_check_ref(serial_tree_buf), + shambase::get_check_ref(linked_patch_ids_buf), + roots}, + sham::MultiRef{new_owned_id}, + len, + [root_cnt, max_lev]( + u32 i, + const vec *pos, + const PtNode *tnode, + const u64 *linked_node_id, + const u64 *roots_id, + u64 *new_id) { auto xyz = pos[i]; u64 current_node = 0; @@ -358,9 +368,6 @@ sycl::buffer SerialPatchTree::compute_patch_owner( new_id[i] = result_node; }); - }); - - position_buffer.complete_event_state(e); return new_owned_id; }