/home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/mpsc/locking/consumer.hpp Source File

HiCR: /home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/mpsc/locking/consumer.hpp Source File
HiCR
consumer.hpp
Go to the documentation of this file.
1/*
2 * Copyright 2025 Huawei Technologies Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
24#pragma once
25
27#include <utility>
28
29namespace HiCR::channel::variableSize::MPSC::locking
30{
31
38class Consumer final : public variableSize::Base
39{
40 public:
41
66 Consumer(CommunicationManager &coordinationCommunicationManager,
67 CommunicationManager &payloadCommunicationManager,
68 std::shared_ptr<GlobalMemorySlot> payloadBuffer,
69 std::shared_ptr<GlobalMemorySlot> tokenBuffer,
70 const std::shared_ptr<LocalMemorySlot> &internalCoordinationBufferForCounts,
71 const std::shared_ptr<LocalMemorySlot> &internalCoordinationBufferForPayloads,
72 const std::shared_ptr<GlobalMemorySlot> &consumerCoordinationBufferForCounts,
73 std::shared_ptr<GlobalMemorySlot> consumerCoordinationBufferForPayloads,
74 const size_t payloadCapacity,
75 const size_t capacity)
76 : variableSize::Base(coordinationCommunicationManager,
77 payloadCommunicationManager,
78 internalCoordinationBufferForCounts,
79 internalCoordinationBufferForPayloads,
80 capacity,
81 payloadCapacity),
82 _payloadBuffer(std::move(payloadBuffer)),
83 _tokenSizeBuffer(std::move(tokenBuffer)),
84 _consumerCoordinationBufferForCounts(consumerCoordinationBufferForCounts),
85 _consumerCoordinationBufferForPayloads(std::move(consumerCoordinationBufferForPayloads))
86 {
87 assert(internalCoordinationBufferForCounts != nullptr);
88 assert(internalCoordinationBufferForPayloads != nullptr);
89 assert(consumerCoordinationBufferForCounts != nullptr);
90 getCoordinationCommunicationManager()->queryMemorySlotUpdates(_tokenSizeBuffer->getSourceLocalMemorySlot());
91 getPayloadCommunicationManager()->queryMemorySlotUpdates(_payloadBuffer->getSourceLocalMemorySlot());
92 }
93
114 __INLINE__ size_t basePeek(const size_t pos = 0)
115 {
116 // Check if the requested position exceeds the capacity of the channel
117 if (pos >= getCircularBufferForCounts()->getCapacity())
118 HICR_THROW_LOGIC("Attempting to peek for a token with position (%lu), which is beyond than the channel capacity (%lu)", pos, getCircularBufferForCounts()->getCapacity());
119
120 // Check if there are enough tokens in the buffer to satisfy the request
121 if (pos >= getCircularBufferForCounts()->getDepth())
122 HICR_THROW_RUNTIME("Attempting to peek position (%lu) but not enough tokens (%lu) are in the buffer", pos, getCircularBufferForCounts()->getDepth());
123
124 // Calculating buffer position
125 const size_t bufferPos = (getCircularBufferForCounts()->getTailPosition() + pos) % getCircularBufferForCounts()->getCapacity();
126
127 // Succeeded in pushing the token(s)
128 return bufferPos;
129 }
130
144 __INLINE__ std::array<size_t, 2> peek(const size_t pos = 0)
145 {
148 std::array<size_t, 2> result{};
149 if (pos != 0) { HICR_THROW_FATAL("peek only implemented for n = 0 at the moment!"); }
150 if (pos >= getCircularBufferForCounts()->getDepth())
151 {
152 HICR_THROW_RUNTIME("Attempting to peek position (%lu) but not enough tokens (%lu) are in the buffer", pos, getCircularBufferForCounts()->getDepth());
153 }
154
155 result[0] = getCircularBufferForPayloads()->getTailPosition() % getCircularBufferForPayloads()->getCapacity();
156 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
157 auto tokenPos = basePeek(pos);
158 result[1] = tokenBufferPtr[tokenPos];
159 return result;
160 }
161
168 __INLINE__ size_t getOldPayloadBytes(size_t n)
169 {
170 if (n == 0) return 0;
171 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
172
173 size_t payloadBytes = 0;
174 for (size_t i = 0; i < n; i++)
175 {
176 assert(i >= 0);
177 size_t pos = basePeek(i);
178 auto payloadSize = tokenBufferPtr[pos];
179 payloadBytes += payloadSize;
180 }
181 return payloadBytes;
182 }
183
189 __INLINE__ size_t getNewPayloadBytes(size_t n)
190 {
191 if (n == 0) return 0;
192 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
193 size_t payloadBytes = 0;
194
195 for (size_t i = 0; i < n; i++)
196 {
197 size_t ind = getCircularBufferForCounts()->getDepth() - 1 - i;
198 assert(ind >= 0);
199 size_t pos = basePeek(ind);
200 auto payloadSize = tokenBufferPtr[pos];
201 payloadBytes += payloadSize;
202 }
203
204 return payloadBytes;
205 }
206
218 __INLINE__ bool pop(const size_t n = 1)
219 {
220 bool successFlag = false;
221
222 auto coordinationCommunicationManager = getCoordinationCommunicationManager();
223
224 // Locking remote coordination buffer slot
225 if (coordinationCommunicationManager->acquireGlobalLock(_consumerCoordinationBufferForCounts) == false) return successFlag;
226
227 if (n > getCircularBufferForCounts()->getCapacity())
228 HICR_THROW_LOGIC("Attempting to pop (%lu) tokens, which is larger than the channel capacity (%lu)", n, getCircularBufferForCounts()->getCapacity());
229 // If the exchange buffer does not have n tokens pushed, reject operation
231 HICR_THROW_RUNTIME("Attempting to pop (%lu) tokens, which is more than the number of current tokens in the channel (%lu)", n, getCircularBufferForCounts()->getDepth());
232
233 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
234
235 size_t bytesOldestEntry = tokenBufferPtr[getCircularBufferForCounts()->getTailPosition()];
236
237 getCircularBufferForCounts()->advanceTail(n);
238 getCircularBufferForPayloads()->advanceTail(bytesOldestEntry);
239
240 coordinationCommunicationManager->releaseGlobalLock(_consumerCoordinationBufferForCounts);
241 successFlag = true;
242 return successFlag;
243 }
244
261 size_t getDepth() { return getCircularBufferForCounts()->getDepth(); }
262
271 bool isEmpty() { return (getDepth() == 0); }
272
278 [[nodiscard]] std::shared_ptr<GlobalMemorySlot> getPayloadBufferMemorySlot() const { return _payloadBuffer; }
279
280 private:
281
285 std::shared_ptr<GlobalMemorySlot> _payloadBuffer;
286
292 const std::shared_ptr<GlobalMemorySlot> _tokenSizeBuffer;
293
298 const std::shared_ptr<GlobalMemorySlot> _consumerCoordinationBufferForCounts;
299
304 const std::shared_ptr<GlobalMemorySlot> _consumerCoordinationBufferForPayloads;
305};
306
307} // namespace HiCR::channel::variableSize::MPSC::locking
Definition communicationManager.hpp:54
virtual __INLINE__ void flushReceived()
Definition communicationManager.hpp:469
__INLINE__ void queryMemorySlotUpdates(std::shared_ptr< LocalMemorySlot > memorySlot)
Definition communicationManager.hpp:229
__INLINE__ CommunicationManager * getPayloadCommunicationManager() const
Definition base.hpp:223
__INLINE__ CommunicationManager * getCoordinationCommunicationManager() const
Definition base.hpp:229
Definition base.hpp:41
__INLINE__ auto getCircularBufferForPayloads() const
Definition base.hpp:101
__INLINE__ auto getCircularBufferForCounts() const
Definition base.hpp:95
__INLINE__ size_t basePeek(const size_t pos=0)
Definition consumer.hpp:114
std::shared_ptr< GlobalMemorySlot > getPayloadBufferMemorySlot() const
Definition consumer.hpp:278
__INLINE__ bool pop(const size_t n=1)
Definition consumer.hpp:218
Consumer(CommunicationManager &coordinationCommunicationManager, CommunicationManager &payloadCommunicationManager, std::shared_ptr< GlobalMemorySlot > payloadBuffer, std::shared_ptr< GlobalMemorySlot > tokenBuffer, const std::shared_ptr< LocalMemorySlot > &internalCoordinationBufferForCounts, const std::shared_ptr< LocalMemorySlot > &internalCoordinationBufferForPayloads, const std::shared_ptr< GlobalMemorySlot > &consumerCoordinationBufferForCounts, std::shared_ptr< GlobalMemorySlot > consumerCoordinationBufferForPayloads, const size_t payloadCapacity, const size_t capacity)
Definition consumer.hpp:66
size_t getDepth()
Definition consumer.hpp:261
__INLINE__ std::array< size_t, 2 > peek(const size_t pos=0)
Definition consumer.hpp:144
__INLINE__ size_t getNewPayloadBytes(size_t n)
Definition consumer.hpp:189
bool isEmpty()
Definition consumer.hpp:271
__INLINE__ size_t getOldPayloadBytes(size_t n)
Definition consumer.hpp:168
#define HICR_THROW_RUNTIME(...)
Definition exceptions.hpp:74
#define HICR_THROW_LOGIC(...)
Definition exceptions.hpp:67
#define HICR_THROW_FATAL(...)
Definition exceptions.hpp:81
extends channel::Base into a base enabling var-size messages