/home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/mpsc/nonlocking/consumer.hpp Source File

HiCR: /home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/mpsc/nonlocking/consumer.hpp Source File
HiCR
consumer.hpp
1/*
2 * Copyright 2025 Huawei Technologies Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
24#pragma once
25
26#include <queue>
28
29namespace HiCR::channel::variableSize::MPSC::nonlocking
30{
31
37{
38 public:
39
65 Consumer(CommunicationManager &coordinationCommunicationManager,
66 CommunicationManager &payloadCommunicationManager,
67 const std::vector<std::shared_ptr<GlobalMemorySlot>> &payloadBuffers,
68 const std::vector<std::shared_ptr<GlobalMemorySlot>> &tokenBuffers,
69 const std::vector<std::shared_ptr<LocalMemorySlot>> &internalCoordinationBufferForCounts,
70 const std::vector<std::shared_ptr<LocalMemorySlot>> &internalCoordinationBufferForPayloads,
71 const std::vector<std::shared_ptr<GlobalMemorySlot>> &producerCoordinationBufferForCounts,
72 const std::vector<std::shared_ptr<GlobalMemorySlot>> &producerCoordinationBufferForPayloads,
73 const size_t payloadCapacity,
74 const size_t payloadSize,
75 const size_t capacity)
76 : _coordinationCommunicationManager(&coordinationCommunicationManager),
77 _payloadCommunicationManager(&payloadCommunicationManager)
78 {
79 // make sure producer and consumer sides have the same element size
80 // the size is hopefully the producer count
81 assert(!internalCoordinationBufferForCounts.empty());
82 auto producerCount = internalCoordinationBufferForCounts.size();
83 assert(producerCount == internalCoordinationBufferForPayloads.size());
84 assert(producerCount == producerCoordinationBufferForCounts.size());
85 assert(producerCount == producerCoordinationBufferForPayloads.size());
86
87 // create p (= number of producers) SPSC channels
88 for (size_t i = 0; i < producerCount; i++)
89 {
90 std::shared_ptr<variableSize::SPSC::Consumer> consumerPtr(new variableSize::SPSC::Consumer(coordinationCommunicationManager,
91 payloadCommunicationManager,
92 payloadBuffers[i],
93 tokenBuffers[i],
94 internalCoordinationBufferForCounts[i],
95 internalCoordinationBufferForPayloads[i],
96 producerCoordinationBufferForCounts[i],
97 producerCoordinationBufferForPayloads[i],
98 payloadCapacity,
99 capacity));
100 _spscList.push_back(consumerPtr);
101
102 /*
103 * Note that it is important to record messages that might already have been received
104 * immediately upon creation of the SPSC channel. Therefore we do not reset
105 * _depths to zero, and check for "early" received messages
106 */
107 _depths.push_back(consumerPtr->getCoordinationDepth());
108 for (size_t j = 0; j < _depths.back(); j++) { _channelPushes.push(i); }
109 }
110 }
111
112 ~Consumer() = default;
113
129 __INLINE__ std::array<size_t, 3> peek(const size_t pos = 0)
130 {
131 std::array<size_t, 3> ret = {0};
132 // @ToDo: to support pos > 0, we need to modify _channelPushes to
133 // be of type std::vector instead of std::queue
134 if (pos > 0) HICR_THROW_LOGIC("Nonblocking MPSC not yet implemented for peek with n!=0");
135
136 _coordinationCommunicationManager->flushReceived();
137 _payloadCommunicationManager->flushReceived();
138 updateDepth();
139 if (_channelPushes.empty()) HICR_THROW_RUNTIME("Attempting to peek position (%lu) but supporting queue has size (%lu)", pos, _channelPushes.size());
140
141 size_t channelId = _channelPushes.front(); // front() returns the first (i.e. oldest) element
142 if (channelId >= _spscList.size()) { HICR_THROW_LOGIC("channelId (%lu) >= _spscList.size() (%lu)", channelId, _spscList.size()); }
143 ret[0] = channelId;
144 ret[1] = _spscList[channelId]->peek()[0];
145 ret[2] = _spscList[channelId]->peek()[1];
146
147 return ret;
148 }
149
156 __INLINE__ size_t getDepth()
157 {
158 size_t totalDepth = 0;
159 for (auto d : _depths) { totalDepth += d; }
160
161 if (totalDepth != _channelPushes.size())
162 {
163 HICR_THROW_LOGIC("Helper FIFO and channels are out of sync, implemenation issue! getDepth (%lu) != _channelPushes.size() (%lu)", totalDepth, _channelPushes.size());
164 }
165 return totalDepth;
166 }
167
172 __INLINE__ bool isEmpty() { return (getDepth() == 0); }
173
179 __INLINE__ void pop(const size_t n = 1)
180 {
181 updateDepth();
182 // If the exchange buffer does not have enough tokens, reject operation
183 if (n > getDepth())
184 HICR_THROW_RUNTIME("Attempting to pop (%lu) tokens, which is more than the number of current tokens in the channel (%lu)", n, getDepth());
185 else
186 {
187 size_t channelFirstPushed = _channelPushes.front();
188 if (channelFirstPushed >= _spscList.size())
189 HICR_THROW_LOGIC("Index of latest push channel incorrect!");
190 else
191 {
192 // pop n elements from the SPSCs in the order recorded in the helper
193 // FIFO _channelPushes, and also update the FIFO itself
194 for (size_t i = 0; i < n; i++)
195 {
196 _spscList[channelFirstPushed]->pop();
197 _depths[channelFirstPushed]--;
198 _channelPushes.pop();
199 }
200 }
201 }
202 }
203
212 __INLINE__ void updateDepth()
213 {
214 std::vector<size_t> newDepths(_spscList.size());
215 /*
216 * Note that after calling updateDepth() on each SPSC channel,
217 * we must accept this state as a new temporary snapshot in newDepths.
218 * It is possible that during our iterating through newDepths, producers have
219 * sent more elements already, which will be handled in later updateDepth calls.
220 */
221
222 for (size_t i = 0; i < _spscList.size(); i++)
223 {
224 _spscList[i]->updateDepth();
225 newDepths[i] = _spscList[i]->getCoordinationDepth();
226 }
227
228 for (size_t i = 0; i < _spscList.size(); i++)
229 {
230 assert(_depths[i] <= newDepths[i]);
231 for (size_t j = _depths[i]; j < newDepths[i]; j++) { _channelPushes.push(i); }
232 }
233 std::swap(_depths, newDepths);
234 if (getDepth() != _channelPushes.size()) { HICR_THROW_LOGIC("getDepth (%lu) != _channelPushes.size() (%lu)", getDepth(), _channelPushes.size()); }
235 }
236
237 private:
238
242 std::vector<std::shared_ptr<channel::variableSize::SPSC::Consumer>> _spscList;
246 std::queue<size_t> _channelPushes;
247
251 std::vector<size_t> _depths;
252
256 CommunicationManager *const _coordinationCommunicationManager;
257
261 CommunicationManager *const _payloadCommunicationManager;
262};
263
264} // namespace HiCR::channel::variableSize::MPSC::nonlocking
Definition communicationManager.hpp:54
virtual __INLINE__ void flushReceived()
Definition communicationManager.hpp:469
__INLINE__ size_t getDepth()
Definition consumer.hpp:156
__INLINE__ std::array< size_t, 3 > peek(const size_t pos=0)
Definition consumer.hpp:129
__INLINE__ bool isEmpty()
Definition consumer.hpp:172
Consumer(CommunicationManager &coordinationCommunicationManager, CommunicationManager &payloadCommunicationManager, const std::vector< std::shared_ptr< GlobalMemorySlot > > &payloadBuffers, const std::vector< std::shared_ptr< GlobalMemorySlot > > &tokenBuffers, const std::vector< std::shared_ptr< LocalMemorySlot > > &internalCoordinationBufferForCounts, const std::vector< std::shared_ptr< LocalMemorySlot > > &internalCoordinationBufferForPayloads, const std::vector< std::shared_ptr< GlobalMemorySlot > > &producerCoordinationBufferForCounts, const std::vector< std::shared_ptr< GlobalMemorySlot > > &producerCoordinationBufferForPayloads, const size_t payloadCapacity, const size_t payloadSize, const size_t capacity)
Definition consumer.hpp:65
__INLINE__ void updateDepth()
Definition consumer.hpp:212
__INLINE__ void pop(const size_t n=1)
Definition consumer.hpp:179
#define HICR_THROW_RUNTIME(...)
Definition exceptions.hpp:74
#define HICR_THROW_LOGIC(...)
Definition exceptions.hpp:67
Provides functionality for a var-size SPSC consumer channel.