7 #ifdef YAKL_ENABLE_STREAMS
14 #if defined(YAKL_ARCH_CUDA)
21 cudaStream_t my_stream;
24 void nullify() { my_stream = 0; refCount =
nullptr; }
28 Stream() { nullify(); }
29 Stream(cudaStream_t cuda_stream) { nullify(); my_stream = cuda_stream; }
32 Stream(Stream
const &rhs) {
33 my_stream = rhs.my_stream;
34 refCount = rhs.refCount;
35 if (refCount !=
nullptr) (*refCount)++;
37 Stream(Stream &&rhs) {
38 my_stream = rhs.my_stream;
39 refCount = rhs.refCount;
42 Stream & operator=(Stream
const &rhs) {
45 my_stream = rhs.my_stream;
46 refCount = rhs.refCount;
47 if (refCount !=
nullptr) (*refCount)++;
51 Stream & operator=(Stream &&rhs) {
54 my_stream = rhs.my_stream;
55 refCount = rhs.refCount;
62 if (refCount ==
nullptr) {
70 if (refCount !=
nullptr) {
72 if ( (*refCount) == 0 ) {
80 cudaStream_t get_real_stream() {
return my_stream; }
81 bool operator==(Stream stream)
const {
return my_stream == stream.get_real_stream(); }
84 bool completed() {
return cudaStreamQuery( my_stream ) == cudaSuccess; }
85 void fence() {
if(!
completed()) cudaStreamSynchronize(my_stream); }
94 void nullify() { my_event = 0; refCount =
nullptr; }
98 Event() { nullify(); }
101 Event(Event
const &rhs) {
102 my_event = rhs.my_event;
103 refCount = rhs.refCount;
104 if (refCount !=
nullptr) (*refCount)++;
107 my_event = rhs.my_event;
108 refCount = rhs.refCount;
111 Event & operator=(Event
const &rhs) {
114 my_event = rhs.my_event;
115 refCount = rhs.refCount;
116 if (refCount !=
nullptr) (*refCount)++;
120 Event & operator=(Event &&rhs) {
123 my_event = rhs.my_event;
124 refCount = rhs.refCount;
131 if (refCount ==
nullptr) {
134 cudaEventCreateWithFlags( &my_event, cudaEventDisableTiming );
139 if (refCount !=
nullptr) {
141 if ( (*refCount) == 0 ) { cudaEventDestroy( my_event );
delete refCount; nullify(); }
145 inline void record(Stream stream);
146 cudaEvent_t get_real_event() {
return my_event; }
147 bool operator==(Event event)
const {
return my_event ==
event.get_real_event(); }
148 bool completed() {
return cudaEventQuery( my_event ) == cudaSuccess; }
155 cudaEventRecord( my_event , stream.get_real_stream() );
160 cudaStreamWaitEvent( my_stream , event.get_real_event() , 0 );
163 #elif defined(YAKL_ARCH_HIP)
170 hipStream_t my_stream;
173 void nullify() { my_stream = 0; refCount =
nullptr; }
177 Stream() { nullify(); }
178 Stream(hipStream_t hip_stream) { nullify(); my_stream = hip_stream; }
181 Stream(Stream
const &rhs) {
182 my_stream = rhs.my_stream;
183 refCount = rhs.refCount;
184 if (refCount !=
nullptr) (*refCount)++;
186 Stream(Stream &&rhs) {
187 my_stream = rhs.my_stream;
188 refCount = rhs.refCount;
191 Stream & operator=(Stream
const &rhs) {
194 my_stream = rhs.my_stream;
195 refCount = rhs.refCount;
196 if (refCount !=
nullptr) (*refCount)++;
200 Stream & operator=(Stream &&rhs) {
203 my_stream = rhs.my_stream;
204 refCount = rhs.refCount;
211 if (refCount ==
nullptr) {
214 if constexpr (
streams_enabled) hipStreamCreateWithFlags( &my_stream, hipStreamNonBlocking );
219 if (refCount !=
nullptr) {
221 if ( (*refCount) == 0 ) {
229 hipStream_t get_real_stream() {
return my_stream; }
230 bool operator==(Stream stream)
const {
return my_stream == stream.get_real_stream(); }
233 bool completed() {
return hipStreamQuery( my_stream ) == hipSuccess; }
234 void fence() {
if(!
completed()) hipStreamSynchronize(my_stream); }
243 void nullify() { my_event = 0; refCount =
nullptr; }
247 Event() { nullify(); }
250 Event(Event
const &rhs) {
251 my_event = rhs.my_event;
252 refCount = rhs.refCount;
253 if (refCount !=
nullptr) (*refCount)++;
256 my_event = rhs.my_event;
257 refCount = rhs.refCount;
260 Event & operator=(Event
const &rhs) {
263 my_event = rhs.my_event;
264 refCount = rhs.refCount;
265 if (refCount !=
nullptr) (*refCount)++;
269 Event & operator=(Event &&rhs) {
272 my_event = rhs.my_event;
273 refCount = rhs.refCount;
280 if (refCount ==
nullptr) {
283 hipEventCreateWithFlags( &my_event, hipEventDisableTiming );
288 if (refCount !=
nullptr) {
290 if ( (*refCount) == 0 ) { hipEventDestroy( my_event );
delete refCount; nullify(); }
294 inline void record(Stream stream);
295 hipEvent_t get_real_event() {
return my_event; }
296 bool operator==(Event event)
const {
return my_event ==
event.get_real_event(); }
297 bool completed() {
return hipEventQuery( my_event ) == hipSuccess; }
304 hipEventRecord( my_event , stream.get_real_stream() );
309 hipStreamWaitEvent( my_stream , event.get_real_event() , 0 );
312 #elif defined(YAKL_ARCH_SYCL)
319 std::shared_ptr<sycl::queue> my_stream{
nullptr};
324 Stream(sycl::queue &sycl_queue) { my_stream = std::make_shared<sycl::queue>(sycl_queue); }
326 Stream(
const Stream& ) =
default;
327 Stream( Stream&& ) noexcept = default;
328 Stream& operator=( const Stream& ) = default;
329 Stream& operator=( Stream&& ) noexcept = default;
334 my_stream = std::make_shared<sycl::queue>( sycl_default_stream().get_context() ,
335 sycl_default_stream().get_device() ,
337 sycl::property_list{sycl::property::queue::in_order{}} );
341 sycl::queue & get_real_stream()
const {
return (my_stream !=
nullptr) ? *my_stream : sycl_default_stream(); }
342 bool operator==(Stream stream)
const {
return get_real_stream() == stream.get_real_stream(); }
344 bool is_default_stream()
const {
return get_real_stream() == sycl_default_stream(); }
347 #if defined(SYCL_EXT_ONEAPI_QUEUE_EMPTY)
348 return get_real_stream().ext_oneapi_empty();
359 sycl::event my_event;
366 Event(Event
const &rhs) { my_event = rhs.my_event; }
367 Event(Event &&rhs) { my_event = rhs.my_event; }
368 Event & operator=(Event
const &rhs) {
if (
this != &rhs) { my_event = rhs.my_event; };
return *
this; }
369 Event & operator=(Event &&rhs) {
if (
this != &rhs) { my_event = rhs.my_event; };
return *
this; }
374 inline void record(Stream stream);
375 sycl::event & get_real_event() {
return my_event; }
376 bool operator==(Event event)
const {
return my_event ==
event.get_real_event(); }
377 bool completed() {
return my_event.get_info<sycl::info::event::command_execution_status>() == sycl::info::event_command_status::complete; }
382 inline void Event::record(Stream stream) { my_event = stream.get_real_stream().ext_oneapi_submit_barrier(); }
384 inline void Stream::wait_on_event(Event event) { this->get_real_stream().ext_oneapi_submit_barrier({
event.get_real_event()}); }
448 std::vector<Stream> *list;
460 list->push_back(stream);
464 int size()
const {
return list->size(); }
466 bool empty()
const {
return list->empty(); }