@@ -29,18 +29,16 @@ Runtime *ContextImpl::getCurrentRuntime() {
2929 return current_runtime_;
3030}
3131
32- Runtime *ContextImpl::getCpuRuntime () {
33- return runtime_table_[int (Device::Type::CPU)][0 ].get ();
34- }
35-
3632void ContextImpl::setDevice (Device device) {
3733 if (device == getCurrentRuntime ()->device ()) {
3834 // Do nothing if the device is already set.
3935 return ;
4036 }
4137
42- if (getCurrentRuntime ()->isGraphRecording ()) {
38+ thread_local bool warn_switch_runtime = false ;
39+ if (getCurrentRuntime ()->isGraphRecording () && !warn_switch_runtime) {
4340 spdlog::warn (" Switching device runtime during graph recording may break the graph!" );
41+ warn_switch_runtime = true ;
4442 }
4543
4644 if (runtime_table_[int (device.getType ())][device.getIndex ()] == nullptr ) {
@@ -104,11 +102,8 @@ infinirtStream_t getStream() {
104102}
105103
106104infiniopHandle_t getInfiniopHandle (Device device) {
107- if (device.getType () == Device::Type::CPU) {
108- return ContextImpl::singleton ().getCpuRuntime ()->infiniopHandle ();
109- }
110105 if (device != getDevice ()) {
111- throw std::runtime_error ( " Requested device doesn't match current runtime. " );
106+ setDevice ( device);
112107 }
113108 return ContextImpl::singleton ().getCurrentRuntime ()->infiniopHandle ();
114109}
@@ -127,7 +122,7 @@ std::shared_ptr<Memory> allocateMemory(size_t size) {
127122
128123std::shared_ptr<Memory> allocateHostMemory (size_t size) {
129124 setDevice (Device::cpu ());
130- return ContextImpl::singleton (). getCpuRuntime ()-> allocateMemory (size);
125+ return allocateMemory (size);
131126}
132127
133128std::shared_ptr<Memory> allocatePinnedHostMemory (size_t size) {
@@ -147,7 +142,8 @@ void memcpyD2D(void *dst, const void *src, size_t size, bool async) {
147142}
148143
149144void memcpyH2H (void *dst, const void *src, size_t size) {
150- return ContextImpl::singleton ().getCpuRuntime ()->memcpyD2D (dst, src, size);
145+ setDevice (Device::cpu ());
146+ return ContextImpl::singleton ().getCurrentRuntime ()->memcpyD2D (dst, src, size);
151147}
152148
153149// Timing API implementations
0 commit comments