/home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/spsc/consumer.hpp Source File

HiCR: /home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/spsc/consumer.hpp Source File
HiCR
consumer.hpp
Go to the documentation of this file.
/*
 * Copyright 2025 Huawei Technologies Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
24#pragma once
25
26#include <array>
27#include <numeric>
28#include <cassert>
30#include <utility>
31
32namespace HiCR::channel::variableSize::SPSC
33{
34
41class Consumer final : public variableSize::Base
42{
43 public:
44
69 Consumer(CommunicationManager &coordinationCommunicationManager,
70 CommunicationManager &payloadCommunicationManager,
71 std::shared_ptr<GlobalMemorySlot> payloadBuffer,
72 std::shared_ptr<GlobalMemorySlot> tokenBuffer,
73 const std::shared_ptr<LocalMemorySlot> &internalCoordinationBufferForCounts,
74 const std::shared_ptr<LocalMemorySlot> &internalCoordinationBufferForPayloads,
75 const std::shared_ptr<GlobalMemorySlot> &producerCoordinationBufferForCounts,
76 std::shared_ptr<GlobalMemorySlot> producerCoordinationBufferForPayloads,
77 const size_t payloadCapacity,
78 const size_t capacity)
79 : variableSize::Base(coordinationCommunicationManager,
80 payloadCommunicationManager,
81 internalCoordinationBufferForCounts,
82 internalCoordinationBufferForPayloads,
83 capacity,
84 payloadCapacity),
85
86 _payloadBuffer(std::move(payloadBuffer)),
87 _tokenSizeBuffer(std::move(tokenBuffer)),
88 _producerCoordinationBufferForCounts(producerCoordinationBufferForCounts),
89 _producerCoordinationBufferForPayloads(std::move(producerCoordinationBufferForPayloads))
90 {
91 assert(internalCoordinationBufferForCounts != nullptr);
92 assert(internalCoordinationBufferForPayloads != nullptr);
93 assert(producerCoordinationBufferForCounts != nullptr);
94 assert(producerCoordinationBufferForCounts != nullptr);
95 }
96
117 __INLINE__ size_t basePeek(const size_t pos = 0)
118 {
119 // Check if the requested position exceeds the capacity of the channel
120 if (pos >= getCircularBufferForCounts()->getCapacity())
121 HICR_THROW_LOGIC("Attempting to peek for a token with position (%lu), which is beyond than the channel capacity (%lu)", pos, getCircularBufferForCounts()->getCapacity());
122
123 // Updating channel depth
124 updateDepth();
125
126 // Check if there are enough tokens in the buffer to satisfy the request
127 if (pos >= getCircularBufferForCounts()->getDepth())
128 HICR_THROW_RUNTIME("Attempting to peek position (%lu) but not enough tokens (%lu) are in the buffer", pos, getCircularBufferForCounts()->getDepth());
129
130 // Calculating buffer position
131 const size_t bufferPos = (getCircularBufferForCounts()->getTailPosition() + pos) % getCircularBufferForCounts()->getCapacity();
132
133 // Succeeded in pushing the token(s)
134 return bufferPos;
135 }
136
145 __INLINE__ static size_t getPayloadBufferSize(const size_t payloadSize) noexcept { return payloadSize * 2; }
146
157 __INLINE__ std::array<size_t, 2> peek(const size_t pos = 0)
158 {
159 if (pos != 0) { HICR_THROW_FATAL("peek only implemented for n = 0 at the moment!"); }
160 updateDepth();
161
162 if (pos >= getCircularBufferForCounts()->getDepth())
163 {
164 HICR_THROW_RUNTIME("Attempting to peek position (%lu) but not enough tokens (%lu) are in the buffer", pos, getCircularBufferForCounts()->getDepth());
165 }
166
167 std::array<size_t, 2> result{};
168 result[0] = getCircularBufferForPayloads()->getTailPosition() % getCircularBufferForPayloads()->getCapacity();
169 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
170 auto tokenPos = basePeek(pos);
171 result[1] = tokenBufferPtr[tokenPos];
172 return result;
173 }
174
180 size_t getOldPayloadBytes(size_t n)
181 {
182 if (n == 0) return 0;
183 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
184
185 size_t payloadBytes = 0;
186 for (size_t i = 0; i < n; i++)
187 {
188 assert(i >= 0);
189 size_t pos = basePeek(i);
190 auto payloadSize = tokenBufferPtr[pos];
191 payloadBytes += payloadSize;
192 }
193 return payloadBytes;
194 }
195
206 __INLINE__ void pop(const size_t n = 1)
207 {
208 if (n > getCircularBufferForCounts()->getCapacity())
209 HICR_THROW_LOGIC("Attempting to pop (%lu) tokens, which is larger than the channel capacity (%lu)", n, getCircularBufferForCounts()->getCapacity());
210
211 // Updating channel depth
212 updateDepth();
213
214 // If the exchange buffer does not have n tokens pushed, reject operation
216 HICR_THROW_RUNTIME("Attempting to pop (%lu) tokens, which is more than the number of current tokens in the channel (%lu)", n, getCircularBufferForCounts()->getDepth());
217 auto payloadBytes = getOldPayloadBytes(n);
218 getCircularBufferForCounts()->advanceTail(n);
219 getCircularBufferForPayloads()->advanceTail(payloadBytes);
220
221 auto coordinationCommunicationManager = getCoordinationCommunicationManager();
222
223 const auto coordBuffElemSize = sizeof(_HICR_CHANNEL_COORDINATION_BUFFER_ELEMENT_TYPE);
224 // Notifying producer(s) of buffer liberation
225 coordinationCommunicationManager->memcpy(_producerCoordinationBufferForCounts, /* destination */
226 _HICR_CHANNEL_TAIL_ADVANCE_COUNT_IDX * coordBuffElemSize,
228 _HICR_CHANNEL_TAIL_ADVANCE_COUNT_IDX * coordBuffElemSize,
229 coordBuffElemSize);
230
231 coordinationCommunicationManager->memcpy(_producerCoordinationBufferForPayloads, /* destination */
232 _HICR_CHANNEL_TAIL_ADVANCE_COUNT_IDX * coordBuffElemSize,
234 _HICR_CHANNEL_TAIL_ADVANCE_COUNT_IDX * coordBuffElemSize,
235 coordBuffElemSize);
236
237 coordinationCommunicationManager->fence(getCoordinationBufferForCounts(), 1, 0);
238 coordinationCommunicationManager->fence(getCoordinationBufferForPayloads(), 1, 0);
239 }
240
247 __INLINE__ void updateDepth() {}
248
265 size_t getCoordinationDepth() { return getCircularBufferForCounts()->getDepth(); }
266
279 size_t getPayloadDepth() { return getCircularBufferForPayloads()->getDepth(); }
280
289 bool isEmpty() { return getCoordinationDepth() == 0; }
290
302 bool isFull(size_t requiredBufferSize)
303 {
304 auto coordinationCircularBuffer = getCircularBufferForCounts();
305 if (coordinationCircularBuffer->getDepth() == coordinationCircularBuffer->getCapacity()) return true;
306 auto payloadCircularBuffer = getCircularBufferForPayloads();
307 if (payloadCircularBuffer->getDepth() + requiredBufferSize > payloadCircularBuffer->getCapacity()) return true;
308
309 return false;
310 }
311
317 [[nodiscard]] std::shared_ptr<GlobalMemorySlot> getPayloadBufferMemorySlot() const { return _payloadBuffer; }
318
319 private:
320
324 std::shared_ptr<GlobalMemorySlot> _payloadBuffer;
325
331 const std::shared_ptr<GlobalMemorySlot> _tokenSizeBuffer;
332
337 const std::shared_ptr<GlobalMemorySlot> _producerCoordinationBufferForCounts;
338
343 const std::shared_ptr<GlobalMemorySlot> _producerCoordinationBufferForPayloads;
344};
345
346} // namespace HiCR::channel::variableSize::SPSC
Definition communicationManager.hpp:54
__INLINE__ size_t getDepth() const noexcept
Definition base.hpp:141
__INLINE__ CommunicationManager * getCoordinationCommunicationManager() const
Definition base.hpp:229
Definition base.hpp:41
__INLINE__ auto getCircularBufferForPayloads() const
Definition base.hpp:101
__INLINE__ auto getCoordinationBufferForPayloads() const
Definition base.hpp:113
__INLINE__ auto getCoordinationBufferForCounts() const
Definition base.hpp:107
__INLINE__ auto getCircularBufferForCounts() const
Definition base.hpp:95
__INLINE__ void pop(const size_t n=1)
Definition consumer.hpp:206
bool isEmpty()
Definition consumer.hpp:289
Consumer(CommunicationManager &coordinationCommunicationManager, CommunicationManager &payloadCommunicationManager, std::shared_ptr< GlobalMemorySlot > payloadBuffer, std::shared_ptr< GlobalMemorySlot > tokenBuffer, const std::shared_ptr< LocalMemorySlot > &internalCoordinationBufferForCounts, const std::shared_ptr< LocalMemorySlot > &internalCoordinationBufferForPayloads, const std::shared_ptr< GlobalMemorySlot > &producerCoordinationBufferForCounts, std::shared_ptr< GlobalMemorySlot > producerCoordinationBufferForPayloads, const size_t payloadCapacity, const size_t capacity)
Definition consumer.hpp:69
bool isFull(size_t requiredBufferSize)
Definition consumer.hpp:302
std::shared_ptr< GlobalMemorySlot > getPayloadBufferMemorySlot() const
Definition consumer.hpp:317
size_t getOldPayloadBytes(size_t n)
Definition consumer.hpp:180
__INLINE__ void updateDepth()
Definition consumer.hpp:247
__INLINE__ size_t basePeek(const size_t pos=0)
Definition consumer.hpp:117
size_t getCoordinationDepth()
Definition consumer.hpp:265
static __INLINE__ size_t getPayloadBufferSize(const size_t payloadSize) noexcept
Definition consumer.hpp:145
size_t getPayloadDepth()
Definition consumer.hpp:279
__INLINE__ std::array< size_t, 2 > peek(const size_t pos=0)
Definition consumer.hpp:157
#define HICR_THROW_RUNTIME(...)
Definition exceptions.hpp:74
#define HICR_THROW_LOGIC(...)
Definition exceptions.hpp:67
#define HICR_THROW_FATAL(...)
Definition exceptions.hpp:81
extends channel::Base into a base enabling var-size messages