/home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/spsc/producer.hpp Source File

HiCR: /home/runner/work/HiCR/HiCR/include/hicr/frontends/channel/variableSize/spsc/producer.hpp Source File
HiCR
producer.hpp
Go to the documentation of this file.
1/*
2 * Copyright 2025 Huawei Technologies Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
24#pragma once
25
27#include <utility>
28
29namespace HiCR::channel::variableSize::SPSC
30{
31
39{
40 public:
41
62 Producer(CommunicationManager &communicationManager,
63 std::shared_ptr<LocalMemorySlot> sizeInfoBuffer,
64 std::shared_ptr<GlobalMemorySlot> payloadBuffer,
65 std::shared_ptr<GlobalMemorySlot> tokenBuffer,
66 const std::shared_ptr<LocalMemorySlot> &internalCoordinationBufferForCounts,
67 const std::shared_ptr<LocalMemorySlot> &internalCoordinationBufferForPayloads,
68 std::shared_ptr<GlobalMemorySlot> consumerCoordinationBufferForCounts,
69 std::shared_ptr<GlobalMemorySlot> consumerCoordinationBufferForPayloads,
70 const size_t payloadCapacity,
71 const size_t payloadSize,
72 const size_t capacity)
73 : variableSize::Base(communicationManager, internalCoordinationBufferForCounts, internalCoordinationBufferForPayloads, capacity, payloadCapacity),
74 _payloadBuffer(std::move(payloadBuffer)),
75 _sizeInfoBuffer(std::move(sizeInfoBuffer)),
76 _payloadSize(payloadSize),
77 _tokenBuffer(std::move(tokenBuffer)),
78 _consumerCoordinationBufferForCounts(std::move(consumerCoordinationBufferForCounts)),
79 _consumerCoordinationBufferForPayloads(std::move(consumerCoordinationBufferForPayloads))
80 {}
81
/** Defaulted destructor; the shared_ptr members drop their references on destruction. */
82 ~Producer() = default;
83
88 __INLINE__ void updateDepth() {}
89
94 [[nodiscard]] __INLINE__ size_t getPayloadHeadPosition() const noexcept { return getCircularBufferForPayloads()->getHeadPosition(); }
95
100 __INLINE__ size_t getPayloadSize() { return _payloadSize; }
101
106 __INLINE__ size_t getPayloadDepth() { return getCircularBufferForPayloads()->getDepth(); }
107
112 __INLINE__ size_t getPayloadCapacity() { return getCircularBufferForPayloads()->getCapacity(); }
113
/**
 * Pushes the whole contents of sourceSlot into the channel as one variable-size message.
 *
 * Part 1 writes the payload bytes into the payload buffer at the payload head position,
 * splitting the copy in two when it wraps past the end of the circular buffer, advances
 * the payload head, and updates the size_t at index _HICR_CHANNEL_HEAD_ADVANCE_COUNT_IDX
 * of the consumer's payload coordination buffer.
 * Part 2 stages the message size in _sizeInfoBuffer, copies it into the token buffer at the
 * counts head position, advances the counts head, and updates the same index of the
 * consumer's counts coordination buffer.
 *
 * @param[in] sourceSlot Local slot whose full size (getSize()) is pushed as a single message
 * @param[in] n          Number of messages to push; only n == 1 is supported
 *
 * Throws via HICR_THROW_RUNTIME if n != 1, if the payload would exceed the payload capacity,
 * or if the token (message-size) buffer is already full.
 */
137 __INLINE__ void push(const std::shared_ptr<LocalMemorySlot> &sourceSlot, const size_t n = 1)
138 {
139 if (n != 1) HICR_THROW_RUNTIME("HiCR currently has no implementation for n != 1 with push(sourceSlot, n) for variable size version.");
140
141 // Make sure source slot is big enough to satisfy the operation
142 size_t requiredBufferSize = sourceSlot->getSize();
143 size_t providedBufferCapacity = getPayloadCapacity();
144
145 // Updating depth of token (message sizes) and payload buffers
146 updateDepth();
147 auto currentPayloadDepth = getCircularBufferForPayloads()->getDepth();
148 auto currentDepth = getDepth();
149
150 /*
151 * Part 1: Copy the payload data
152 */
153 if (currentPayloadDepth + requiredBufferSize > providedBufferCapacity)
154 HICR_THROW_RUNTIME("Attempting to push (%lu) bytes while the channel currently has payload depth (%lu). This would exceed capacity (%lu).\n",
155 requiredBufferSize,
156 currentPayloadDepth,
157 providedBufferCapacity);
158
159 /*
160 * Payload copy:
161 * - We have checked (currentPayloadDepth + requiredBufferSize <= capacity)
162 * that the payload fits into available circular buffer,
163 * but it is possible it spills over the end into the
164 * beginning. Cover this corner case below
165 */
166 if (requiredBufferSize + getPayloadHeadPosition() > getPayloadCapacity())
167 {
168 size_t first_chunk = getPayloadCapacity() - getPayloadHeadPosition();
169 size_t second_chunk = requiredBufferSize - first_chunk;
170
171 // copy first part to end of buffer
172 getCommunicationManager()->memcpy(_payloadBuffer, /* destination */
173 getPayloadHeadPosition(), /* dst_offset */
174 sourceSlot, /* source */
175 0, /* src_offset */
176 first_chunk); /* size */
177 // copy second part to beginning of buffer
178 getCommunicationManager()->memcpy(_payloadBuffer, /* destination */
179 0, /* dst_offset */
180 sourceSlot, /* source */
181 first_chunk, /* src_offset */
182 second_chunk); /* size */
// Fence on sourceSlot expecting the two copies issued above (presumably blocks until both complete — confirm against CommunicationManager::fence)
183 getCommunicationManager()->fence(sourceSlot, 2, 0);
184 }
185 else
186 {
// Payload fits without wrapping: a single contiguous copy
187 getCommunicationManager()->memcpy(_payloadBuffer, getPayloadHeadPosition(), sourceSlot, 0, requiredBufferSize);
188 getCommunicationManager()->fence(sourceSlot, 1, 0);
189 }
190
191 getCircularBufferForPayloads()->advanceHead(requiredBufferSize);
192
193 // update the consumer coordination buffers (consumer does not update
194 // its own coordination head positions)
195 getCommunicationManager()->memcpy(_consumerCoordinationBufferForPayloads,
196 _HICR_CHANNEL_HEAD_ADVANCE_COUNT_IDX * sizeof(size_t),
198 _HICR_CHANNEL_HEAD_ADVANCE_COUNT_IDX * sizeof(size_t),
199 sizeof(size_t));
201
202 /*
203 * Part 2: Copy the message size
204 */
205
// Stage the message size in the local size-info slot so it can be copied as a token
206 auto *sizeInfoBufferPtr = static_cast<size_t *>(_sizeInfoBuffer->getPointer());
207 sizeInfoBufferPtr[0] = requiredBufferSize;
208
209 // If the exchange buffer does not have n free slots, reject the operation
// NOTE(review): this token-capacity check only runs after Part 1 has already copied the
// payload, advanced the payload head and updated the consumer's payload coordination
// buffer. If it throws, the payload remains in the channel without a matching size token.
// Consider moving this check before Part 1, next to the payload-capacity check.
// NOTE(review): the literal 1 below is an int while the format string uses %lu; if
// HICR_THROW_RUNTIME forwards to a printf-style formatter this is a variadic type mismatch
// (pass 1UL / size_t{1}). %zu is also the portable conversion for size_t arguments.
210 if (currentDepth + 1 > getCircularBufferForCounts()->getCapacity())
211 HICR_THROW_RUNTIME("Attempting to push with (%lu) tokens while the channel has (%lu) tokens and this would exceed capacity (%lu).\n",
212 1,
213 getDepth(),
214 getCircularBufferForCounts()->getCapacity());
215
216 getCommunicationManager()->memcpy(_tokenBuffer, /* destination */
217 getTokenSize() * getCircularBufferForCounts()->getHeadPosition(), /* dst_offset */
218 _sizeInfoBuffer, /* source */
219 0, /* src_offset */
220 getTokenSize()); /* size */
221 getCommunicationManager()->fence(_sizeInfoBuffer, 1, 0);
222 getCircularBufferForCounts()->advanceHead(1);
223
// Publish the new counts head-advance value to the consumer
224 getCommunicationManager()->memcpy(_consumerCoordinationBufferForCounts,
225 _HICR_CHANNEL_HEAD_ADVANCE_COUNT_IDX * sizeof(size_t),
227 _HICR_CHANNEL_HEAD_ADVANCE_COUNT_IDX * sizeof(size_t),
228 sizeof(size_t));
230 }
231
240 size_t getDepth() { return getCircularBufferForCounts()->getDepth(); }
241
250 bool isEmpty() { return getDepth() == 0; }
251
252 private:
253
// Global slot backing the payload circular buffer; push() writes message bytes here
257 std::shared_ptr<GlobalMemorySlot> _payloadBuffer;
// Local staging slot: push() writes the message size here before copying it into _tokenBuffer
261 std::shared_ptr<LocalMemorySlot> _sizeInfoBuffer;
// Value passed to the constructor and returned by getPayloadSize(); not read by push()
265 size_t _payloadSize;
266
// Global slot backing the circular buffer of message sizes (one token per pushed message)
270 const std::shared_ptr<GlobalMemorySlot> _tokenBuffer;
271
// Consumer-side coordination buffer for message sizes; push() updates its head-advance entry
275 const std::shared_ptr<GlobalMemorySlot> _consumerCoordinationBufferForCounts;
276
// Consumer-side coordination buffer for payloads; push() updates its head-advance entry
280 const std::shared_ptr<GlobalMemorySlot> _consumerCoordinationBufferForPayloads;
281};
282
283} // namespace HiCR::channel::variableSize::SPSC
Definition communicationManager.hpp:54
__INLINE__ void memcpy(const std::shared_ptr< LocalMemorySlot > &destination, size_t dst_offset, const std::shared_ptr< LocalMemorySlot > &source, size_t src_offset, size_t size)
Definition communicationManager.hpp:250
__INLINE__ void fence(GlobalMemorySlot::tag_t tag)
Definition communicationManager.hpp:360
__INLINE__ CommunicationManager * getCommunicationManager() const
Definition base.hpp:217
__INLINE__ size_t getTokenSize() const noexcept
Definition base.hpp:84
Definition base.hpp:41
__INLINE__ auto getCircularBufferForPayloads() const
Definition base.hpp:99
__INLINE__ auto getCoordinationBufferForPayloads() const
Definition base.hpp:111
__INLINE__ auto getCoordinationBufferForCounts() const
Definition base.hpp:105
__INLINE__ auto getCircularBufferForCounts() const
Definition base.hpp:93
__INLINE__ size_t getPayloadDepth()
Definition producer.hpp:106
size_t getDepth()
Definition producer.hpp:240
__INLINE__ void push(const std::shared_ptr< LocalMemorySlot > &sourceSlot, const size_t n=1)
Definition producer.hpp:137
__INLINE__ void updateDepth()
Definition producer.hpp:88
__INLINE__ size_t getPayloadCapacity()
Definition producer.hpp:112
bool isEmpty()
Definition producer.hpp:250
__INLINE__ size_t getPayloadSize()
Definition producer.hpp:100
__INLINE__ size_t getPayloadHeadPosition() const noexcept
Definition producer.hpp:94
Producer(CommunicationManager &communicationManager, std::shared_ptr< LocalMemorySlot > sizeInfoBuffer, std::shared_ptr< GlobalMemorySlot > payloadBuffer, std::shared_ptr< GlobalMemorySlot > tokenBuffer, const std::shared_ptr< LocalMemorySlot > &internalCoordinationBufferForCounts, const std::shared_ptr< LocalMemorySlot > &internalCoordinationBufferForPayloads, std::shared_ptr< GlobalMemorySlot > consumerCoordinationBufferForCounts, std::shared_ptr< GlobalMemorySlot > consumerCoordinationBufferForPayloads, const size_t payloadCapacity, const size_t payloadSize, const size_t capacity)
Definition producer.hpp:62
#define HICR_THROW_RUNTIME(...)
Definition exceptions.hpp:74
Extends channel::Base into a base class that supports variable-size messages