/home/runner/work/HiCR/HiCR/include/hicr/frontends/tasking/worker.hpp Source File

HiCR: /home/runner/work/HiCR/HiCR/include/hicr/frontends/tasking/worker.hpp Source File
HiCR
worker.hpp
Go to the documentation of this file.
1/*
2 * Copyright 2025 Huawei Technologies Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
24#pragma once
25
26#include <thread>
27#include <memory>
28#include <utility>
29#include <vector>
30#include <set>
31#include <unistd.h>
32#include <hicr/core/definitions.hpp>
36#include "task.hpp"
37
38namespace HiCR::tasking
39{
40
/**
 * Default interval, in milliseconds, a suspended worker sleeps between two consecutive checks of its resume conditions.
 *
 * Declared 'inline' so that every translation unit including this header refers to the same entity.
 */
inline constexpr size_t _DEFAULT_SUSPEND_INTERVAL_MS = 1000;

/**
 * Factor the worker multiplies its suspend interval (milliseconds) by before passing it to usleep().
 *
 * NOTE(review): usleep() takes microseconds, so this is effectively "microseconds per millisecond";
 * the (misspelled) name is kept unchanged for backward compatibility.
 */
inline constexpr size_t _MILISECONDS_PER_SECOND = 1000;
50
54using pullFunction_t = std::function<HiCR::tasking::Task *()>;
55
63class Worker
64{
65 public:
66
97
102
148
157 Worker(HiCR::ComputeManager *executionStateComputeManager,
158 HiCR::ComputeManager *processingUnitComputeManager,
159 pullFunction_t pullFunction,
160 workerCallbackMap_t *callbackMap = nullptr)
161 : _executionStateComputeManager(executionStateComputeManager),
162 _processingUnitComputeManager(processingUnitComputeManager),
163 _pullFunction(std::move(pullFunction)),
164 _callbackMap(callbackMap)
165 {
166 _state = state_t::uninitialized;
167 }
168
169 virtual ~Worker() = default;
170
176 __INLINE__ const state_t getState() { return _state; }
177
183 __INLINE__ void setCallbackMap(workerCallbackMap_t *callbackMap) { _callbackMap = callbackMap; }
184
190 __INLINE__ workerCallbackMap_t *getCallbackMap() { return _callbackMap; }
191
197 __INLINE__ HiCR::tasking::Task *getCurrentTask() { return _currentTask; }
198
202 __INLINE__ void initialize()
203 {
204 // Grabbing state value
205 auto prevState = _state.load();
206
207 // Checking we have at least one assigned resource
208 if (_processingUnits.empty()) HICR_THROW_LOGIC("Attempting to initialize worker without any assigned resources");
209
210 // Checking state
211 if (prevState != state_t::uninitialized && prevState != state_t::terminated) HICR_THROW_RUNTIME("Attempting to initialize already initialized worker");
212
213 // Initializing all resources
214 for (auto &r : _processingUnits) _processingUnitComputeManager->initialize(r);
215
216 // Transitioning state
217 _state = state_t::ready;
218 }
219
223 __INLINE__ void start()
224 {
225 // Grabbing state value
226 auto prevState = _state.load();
227
228 // Checking state
229 if (prevState != state_t::ready) HICR_THROW_RUNTIME("Attempting to start worker that is not in the 'initialized' state");
230
231 // Setting state
232 _state = state_t::running;
233
234 // Creating new execution unit (the processing unit must support an execution unit of 'host' type)
235 auto executionUnit = _executionStateComputeManager->createExecutionUnit([](void *worker) { static_cast<HiCR::tasking::Worker *>(worker)->mainLoop(); });
236
237 // Creating worker's execution state
238 auto executionState = _executionStateComputeManager->createExecutionState(executionUnit, this);
239
240 // Launching worker in the lead resource (first one to be added)
241 _processingUnitComputeManager->start(_processingUnits[0], executionState);
242 }
243
250 __INLINE__ bool suspend()
251 {
252 // Doing an atomic exchange
253 state_t expected = state_t::running;
254 bool succeeded = _state.compare_exchange_weak(expected, state_t::suspending);
255
256 // Checking exchange
257 return succeeded;
258 }
259
265 __INLINE__ bool resume()
266 {
267 // Doing an atomic exchange
268 state_t expected = state_t::suspended;
269 bool succeeded = _state.compare_exchange_weak(expected, state_t::resuming);
270
271 // Checking exchange
272 return succeeded;
273 }
274
278 __INLINE__ void terminate()
279 {
280 // Transitioning state
281 auto prevState = _state.exchange(state_t::terminating);
282
283 // Checking state
284 if (prevState != state_t::running && prevState != state_t::suspending) HICR_THROW_RUNTIME("Attempting to stop worker that is not in a terminate-able state");
285 }
286
290 __INLINE__ void await()
291 {
292 // Getting state
293 auto prevState = _state.load();
294
295 if (prevState != state_t::terminating && prevState != state_t::running && prevState != state_t::suspended && prevState != state_t::suspending && prevState != state_t::resuming)
296 HICR_THROW_RUNTIME("Attempting to wait for a worker that has not yet started or has already terminated");
297
298 // Wait for the resources to free up
299 for (auto &p : _processingUnits) _processingUnitComputeManager->await(p);
300
301 // Transitioning state
302 _state = state_t::terminated;
303 }
304
310 __INLINE__ void addProcessingUnit(std::unique_ptr<HiCR::ProcessingUnit> pu) { _processingUnits.push_back(std::move(pu)); }
311
317 __INLINE__ std::vector<std::unique_ptr<HiCR::ProcessingUnit>> &getProcessingUnits() { return _processingUnits; }
318
324 __INLINE__ void setSuspendInterval(size_t suspendIntervalMs) { _suspendIntervalMs = suspendIntervalMs; }
325
326 protected:
327
333 __INLINE__ virtual bool checkResumeConditions() { return _state == state_t::resuming; }
334
335 private:
336
340 HiCR::ComputeManager *const _executionStateComputeManager;
341
345 HiCR::ComputeManager *const _processingUnitComputeManager;
346
350 HiCR::tasking::Task *_currentTask = nullptr;
351
355 const pullFunction_t _pullFunction;
356
360 size_t _suspendIntervalMs = _DEFAULT_SUSPEND_INTERVAL_MS;
361
365 std::atomic<state_t> _state;
366
370 std::vector<std::unique_ptr<HiCR::ProcessingUnit>> _processingUnits;
371
375 workerCallbackMap_t *_callbackMap = nullptr;
376
380 __INLINE__ void mainLoop()
381 {
382 // Calling appropriate callback
383 if (_callbackMap != nullptr) _callbackMap->trigger(this, callback_t::onWorkerStart);
384
385 // Start main worker loop (run until terminated)
386 while (true)
387 {
388 // Attempt to get a task by executing the pull function
389 _currentTask = _pullFunction();
390
391 // Calling appropriate callback
392 if (_callbackMap != nullptr) _callbackMap->trigger(this, callback_t::onWorkerTaskPulled);
393
394 // If a task was returned, then start or execute it
395 if (_currentTask != nullptr)
396 {
397 // If the task hasn't been initialized yet, we need to do it now
399 {
400 // First, create new execution state for the processing unit
401 auto executionState = _executionStateComputeManager->createExecutionState(_currentTask->getExecutionUnit(), _currentTask);
402
403 // Then initialize the task with the new execution state
404 _currentTask->initialize(std::move(executionState));
405 }
406
407 // Now actually run the task
408 _currentTask->run();
409 }
410
411 // Requesting processing units to terminate as soon as possible
412 if (_state == state_t::suspending)
413 {
414 // Setting state as suspended
415 _state = state_t::suspended;
416
417 // Calling appropriate callback
418 if (_callbackMap != nullptr) _callbackMap->trigger(this, callback_t::onWorkerSuspend);
419
420 // Suspending other processing units
421 for (size_t i = 1; i < _processingUnits.size(); i++) _processingUnitComputeManager->suspend(_processingUnits[i]);
422
423 // Putting current processing unit to check every so often
424 while (checkResumeConditions() == false) usleep(_suspendIntervalMs * _MILISECONDS_PER_SECOND);
425
426 // Calling appropriate callback
427 if (_callbackMap != nullptr) _callbackMap->trigger(this, callback_t::onWorkerResume);
428
429 // Resuming other processing units
430 for (size_t i = 1; i < _processingUnits.size(); i++) _processingUnitComputeManager->resume(_processingUnits[i]);
431
432 // Setting worker as running
433 _state = state_t::running;
434 }
435
436 // Requesting processing units to terminate as soon as possible
437 if (_state == state_t::terminating)
438 {
439 // Calling appropriate callback
440 if (_callbackMap != nullptr) _callbackMap->trigger(this, callback_t::onWorkerTerminate);
441
442 // Terminate secondary processing units first
443 for (size_t i = 1; i < _processingUnits.size(); i++) _processingUnitComputeManager->terminate(_processingUnits[i]);
444
445 // Then terminate current processing unit
446 _processingUnitComputeManager->terminate(_processingUnits[0]);
447
448 // Return immediately
449 return;
450 }
451 }
452 }
453}; // class Worker
454
455} // namespace HiCR::tasking
Definition computeManager.hpp:48
__INLINE__ void initialize(std::unique_ptr< HiCR::ProcessingUnit > &processingUnit)
Definition computeManager.hpp:93
__INLINE__ void suspend(std::unique_ptr< HiCR::ProcessingUnit > &processingUnit)
Definition computeManager.hpp:135
__INLINE__ void start(std::unique_ptr< HiCR::ProcessingUnit > &processingUnit, std::unique_ptr< HiCR::ExecutionState > &executionState)
Definition computeManager.hpp:115
virtual std::unique_ptr< HiCR::ExecutionState > createExecutionState(std::shared_ptr< HiCR::ExecutionUnit > executionUnit, void *const argument=nullptr) const =0
__INLINE__ void terminate(std::unique_ptr< HiCR::ProcessingUnit > &processingUnit)
Definition computeManager.hpp:175
virtual __INLINE__ std::shared_ptr< HiCR::ExecutionUnit > createExecutionUnit(const replicableFc_t &function)
Definition computeManager.hpp:63
__INLINE__ void await(std::unique_ptr< HiCR::ProcessingUnit > &processingUnit)
Definition computeManager.hpp:189
__INLINE__ void resume(std::unique_ptr< HiCR::ProcessingUnit > &processingUnit)
Definition computeManager.hpp:155
@ uninitialized
Definition executionState.hpp:49
Definition callbackMap.hpp:40
Definition task.hpp:57
__INLINE__ std::shared_ptr< HiCR::ExecutionUnit > getExecutionUnit() const
Definition task.hpp:152
__INLINE__ void run()
Definition task.hpp:173
__INLINE__ void initialize(std::unique_ptr< HiCR::ExecutionState > executionState)
Definition task.hpp:159
__INLINE__ const HiCR::ExecutionState::state_t getState()
Definition task.hpp:131
Definition worker.hpp:64
HiCR::tasking::CallbackMap< Worker *, callback_t > workerCallbackMap_t
Definition worker.hpp:101
__INLINE__ bool suspend()
Definition worker.hpp:250
__INLINE__ void setSuspendInterval(size_t suspendIntervalMs)
Definition worker.hpp:324
__INLINE__ std::vector< std::unique_ptr< HiCR::ProcessingUnit > > & getProcessingUnits()
Definition worker.hpp:317
Worker(HiCR::ComputeManager *executionStateComputeManager, HiCR::ComputeManager *processingUnitComputeManager, pullFunction_t pullFunction, workerCallbackMap_t *callbackMap=nullptr)
Definition worker.hpp:157
__INLINE__ HiCR::tasking::Task * getCurrentTask()
Definition worker.hpp:197
callback_t
Definition worker.hpp:71
@ onWorkerStart
Definition worker.hpp:75
@ onWorkerTaskPulled
Definition worker.hpp:80
@ onWorkerSuspend
Definition worker.hpp:85
@ onWorkerResume
Definition worker.hpp:90
@ onWorkerTerminate
Definition worker.hpp:95
__INLINE__ workerCallbackMap_t * getCallbackMap()
Definition worker.hpp:190
__INLINE__ void start()
Definition worker.hpp:223
__INLINE__ void initialize()
Definition worker.hpp:202
__INLINE__ void addProcessingUnit(std::unique_ptr< HiCR::ProcessingUnit > pu)
Definition worker.hpp:310
__INLINE__ void terminate()
Definition worker.hpp:278
state_t
Definition worker.hpp:107
@ suspending
Definition worker.hpp:126
@ running
Definition worker.hpp:121
@ uninitialized
Definition worker.hpp:111
@ suspended
Definition worker.hpp:131
@ terminated
Definition worker.hpp:146
@ resuming
Definition worker.hpp:136
@ ready
Definition worker.hpp:116
@ terminating
Definition worker.hpp:141
__INLINE__ void await()
Definition worker.hpp:290
__INLINE__ const state_t getState()
Definition worker.hpp:176
virtual __INLINE__ bool checkResumeConditions()
Definition worker.hpp:333
__INLINE__ void setCallbackMap(workerCallbackMap_t *callbackMap)
Definition worker.hpp:183
__INLINE__ bool resume()
Definition worker.hpp:265
Provides a definition for the abstract compute manager class.
Provides a definition for a HiCR ProcessingUnit class.
Provides a failure model and corresponding exception classes.
#define HICR_THROW_RUNTIME(...)
Definition exceptions.hpp:74
#define HICR_THROW_LOGIC(...)
Definition exceptions.hpp:67
This file implements the HiCR task class.
constexpr size_t _MILISECONDS_PER_SECOND
Definition worker.hpp:49
std::function< HiCR::tasking::Task *()> pullFunction_t
Definition worker.hpp:54
constexpr size_t _DEFAULT_SUSPEND_INTERVAL_MS
Definition worker.hpp:44