/home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/spsc/consumer.hpp Source File

HiCR: /home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/spsc/consumer.hpp Source File
HiCR
consumer.hpp
Go to the documentation of this file.
1/*
2 * Copyright 2025 Huawei Technologies Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
24#pragma once
25
26#include <array>
27#include <numeric>
28#include <cassert>
30#include <utility>
31
32namespace HiCR::channel::variableSize::SPSC
33{
34
41class Consumer final : public variableSize::Base
42{
43 public:
44
68 Consumer(CommunicationManager &communicationManager,
69 std::shared_ptr<GlobalMemorySlot> payloadBuffer,
70 std::shared_ptr<GlobalMemorySlot> tokenBuffer,
71 const std::shared_ptr<LocalMemorySlot> &internalCoordinationBufferForCounts,
72 const std::shared_ptr<LocalMemorySlot> &internalCoordinationBufferForPayloads,
73 const std::shared_ptr<GlobalMemorySlot> &producerCoordinationBufferForCounts,
74 std::shared_ptr<GlobalMemorySlot> producerCoordinationBufferForPayloads,
75 const size_t payloadCapacity,
76 const size_t capacity)
77 : variableSize::Base(communicationManager, internalCoordinationBufferForCounts, internalCoordinationBufferForPayloads, capacity, payloadCapacity),
78
79 _payloadBuffer(std::move(payloadBuffer)),
80 _tokenSizeBuffer(std::move(tokenBuffer)),
81 _producerCoordinationBufferForCounts(producerCoordinationBufferForCounts),
82 _producerCoordinationBufferForPayloads(std::move(producerCoordinationBufferForPayloads))
83 {
84 assert(internalCoordinationBufferForCounts != nullptr);
85 assert(internalCoordinationBufferForPayloads != nullptr);
86 assert(producerCoordinationBufferForCounts != nullptr);
87 assert(producerCoordinationBufferForCounts != nullptr);
88 }
89
110 __INLINE__ size_t basePeek(const size_t pos = 0)
111 {
112 // Check if the requested position exceeds the capacity of the channel
113 if (pos >= getCircularBufferForCounts()->getCapacity())
114 HICR_THROW_LOGIC("Attempting to peek for a token with position (%lu), which is beyond than the channel capacity (%lu)", pos, getCircularBufferForCounts()->getCapacity());
115
116 // Updating channel depth
117 updateDepth();
118
119 // Check if there are enough tokens in the buffer to satisfy the request
120 if (pos >= getCircularBufferForCounts()->getDepth())
121 HICR_THROW_RUNTIME("Attempting to peek position (%lu) but not enough tokens (%lu) are in the buffer", pos, getCircularBufferForCounts()->getDepth());
122
123 // Calculating buffer position
124 const size_t bufferPos = (getCircularBufferForCounts()->getTailPosition() + pos) % getCircularBufferForCounts()->getCapacity();
125
126 // Succeeded in pushing the token(s)
127 return bufferPos;
128 }
129
140 __INLINE__ std::array<size_t, 2> peek(const size_t pos = 0)
141 {
142 if (pos != 0) { HICR_THROW_FATAL("peek only implemented for n = 0 at the moment!"); }
143 updateDepth();
144
145 if (pos >= getCircularBufferForCounts()->getDepth())
146 {
147 HICR_THROW_RUNTIME("Attempting to peek position (%lu) but not enough tokens (%lu) are in the buffer", pos, getCircularBufferForCounts()->getDepth());
148 }
149
150 std::array<size_t, 2> result{};
151 result[0] = getCircularBufferForPayloads()->getTailPosition() % getCircularBufferForPayloads()->getCapacity();
152 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
153 auto tokenPos = basePeek(pos);
154 result[1] = tokenBufferPtr[tokenPos];
155 return result;
156 }
157
163 size_t getOldPayloadBytes(size_t n)
164 {
165 if (n == 0) return 0;
166 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
167
168 size_t payloadBytes = 0;
169 for (size_t i = 0; i < n; i++)
170 {
171 assert(i >= 0);
172 size_t pos = basePeek(i);
173 auto payloadSize = tokenBufferPtr[pos];
174 payloadBytes += payloadSize;
175 }
176 return payloadBytes;
177 }
178
184 size_t getNewPayloadBytes(size_t n)
185 {
186 if (n == 0) return 0;
187 size_t *tokenBufferPtr = static_cast<size_t *>(_tokenSizeBuffer->getSourceLocalMemorySlot()->getPointer());
188 size_t payloadBytes = 0;
189 for (size_t i = 0; i < n; i++)
190 {
191 size_t ind = getCircularBufferForCounts()->getDepth() - 1 - i;
192 assert(ind >= 0);
193 size_t pos = basePeek(ind);
194 auto payloadSize = tokenBufferPtr[pos];
195 payloadBytes += payloadSize;
196 }
197
198 return payloadBytes;
199 }
200
211 __INLINE__ void pop(const size_t n = 1)
212 {
213 if (n > getCircularBufferForCounts()->getCapacity())
214 HICR_THROW_LOGIC("Attempting to pop (%lu) tokens, which is larger than the channel capacity (%lu)", n, getCircularBufferForCounts()->getCapacity());
215
216 // Updating channel depth
217 updateDepth();
218
219 // If the exchange buffer does not have n tokens pushed, reject operation
221 HICR_THROW_RUNTIME("Attempting to pop (%lu) tokens, which is more than the number of current tokens in the channel (%lu)", n, getCircularBufferForCounts()->getDepth());
222 auto payloadBytes = getOldPayloadBytes(n);
223 getCircularBufferForCounts()->advanceTail(n);
224 getCircularBufferForPayloads()->advanceTail(payloadBytes);
225
226 const auto coordBuffElemSize = sizeof(_HICR_CHANNEL_COORDINATION_BUFFER_ELEMENT_TYPE);
227 // Notifying producer(s) of buffer liberation
228 getCommunicationManager()->memcpy(_producerCoordinationBufferForCounts, /* destination */
229 _HICR_CHANNEL_TAIL_ADVANCE_COUNT_IDX * coordBuffElemSize,
231 _HICR_CHANNEL_TAIL_ADVANCE_COUNT_IDX * coordBuffElemSize,
232 coordBuffElemSize);
233
234 getCommunicationManager()->memcpy(_producerCoordinationBufferForPayloads, /* destination */
235 _HICR_CHANNEL_TAIL_ADVANCE_COUNT_IDX * coordBuffElemSize,
237 _HICR_CHANNEL_TAIL_ADVANCE_COUNT_IDX * coordBuffElemSize,
238 coordBuffElemSize);
239
242 }
243
250 __INLINE__ void updateDepth() {}
251
265 size_t getDepth() { return getCircularBufferForCounts()->getDepth(); }
266
277 size_t getPayloadDepth() { return getCircularBufferForPayloads()->getDepth(); }
278
287 bool isEmpty() { return (getDepth() == 0); }
288
294 [[nodiscard]] std::shared_ptr<GlobalMemorySlot> getPayloadBufferMemorySlot() const { return _payloadBuffer; }
295
296 private:
297
301 std::shared_ptr<GlobalMemorySlot> _payloadBuffer;
302
308 const std::shared_ptr<GlobalMemorySlot> _tokenSizeBuffer;
309
314 const std::shared_ptr<GlobalMemorySlot> _producerCoordinationBufferForCounts;
315
320 const std::shared_ptr<GlobalMemorySlot> _producerCoordinationBufferForPayloads;
321};
322
323} // namespace HiCR::channel::variableSize::SPSC
Definition communicationManager.hpp:54
__INLINE__ void memcpy(const std::shared_ptr< LocalMemorySlot > &destination, size_t dst_offset, const std::shared_ptr< LocalMemorySlot > &source, size_t src_offset, size_t size)
Definition communicationManager.hpp:267
__INLINE__ void fence(GlobalMemorySlot::tag_t tag)
Definition communicationManager.hpp:377
__INLINE__ CommunicationManager * getCommunicationManager() const
Definition base.hpp:217
Definition base.hpp:41
__INLINE__ auto getCircularBufferForPayloads() const
Definition base.hpp:99
__INLINE__ auto getCoordinationBufferForPayloads() const
Definition base.hpp:111
__INLINE__ auto getCoordinationBufferForCounts() const
Definition base.hpp:105
__INLINE__ auto getCircularBufferForCounts() const
Definition base.hpp:93
__INLINE__ void pop(const size_t n=1)
Definition consumer.hpp:211
bool isEmpty()
Definition consumer.hpp:287
std::shared_ptr< GlobalMemorySlot > getPayloadBufferMemorySlot() const
Definition consumer.hpp:294
size_t getOldPayloadBytes(size_t n)
Definition consumer.hpp:163
__INLINE__ void updateDepth()
Definition consumer.hpp:250
__INLINE__ size_t basePeek(const size_t pos=0)
Definition consumer.hpp:110
size_t getDepth()
Definition consumer.hpp:265
Consumer(CommunicationManager &communicationManager, std::shared_ptr< GlobalMemorySlot > payloadBuffer, std::shared_ptr< GlobalMemorySlot > tokenBuffer, const std::shared_ptr< LocalMemorySlot > &internalCoordinationBufferForCounts, const std::shared_ptr< LocalMemorySlot > &internalCoordinationBufferForPayloads, const std::shared_ptr< GlobalMemorySlot > &producerCoordinationBufferForCounts, std::shared_ptr< GlobalMemorySlot > producerCoordinationBufferForPayloads, const size_t payloadCapacity, const size_t capacity)
Definition consumer.hpp:68
size_t getNewPayloadBytes(size_t n)
Definition consumer.hpp:184
size_t getPayloadDepth()
Definition consumer.hpp:277
__INLINE__ std::array< size_t, 2 > peek(const size_t pos=0)
Definition consumer.hpp:140
#define HICR_THROW_RUNTIME(...)
Definition exceptions.hpp:74
#define HICR_THROW_LOGIC(...)
Definition exceptions.hpp:67
#define HICR_THROW_FATAL(...)
Definition exceptions.hpp:81
extends channel::Base into a base enabling var-size messages