10 #include <core/util/Verify.hpp> 11 #include <ml/neural_net/TaskQueue.hpp> 15 namespace neural_net {
26 class SubscribeOnQueuePublisher :
public Publisher<T> {
30 SubscribeOnQueuePublisher(std::shared_ptr<Publisher<Output>> upstream,
31 std::shared_ptr<TaskQueue> queue)
32 : upstream_(
std::move(upstream)), queue_(
std::move(queue)) {}
33 ~SubscribeOnQueuePublisher()
override =
default;
37 SubscribeOnQueuePublisher(
const SubscribeOnQueuePublisher&) =
delete;
38 SubscribeOnQueuePublisher(SubscribeOnQueuePublisher&&) =
delete;
39 SubscribeOnQueuePublisher& operator=(
const SubscribeOnQueuePublisher&) =
41 SubscribeOnQueuePublisher& operator=(SubscribeOnQueuePublisher&&) =
delete;
46 std::shared_ptr<Publisher<Output>> upstream = upstream_;
47 auto impl = std::make_shared<Proxy>(std::move(subscriber), queue_);
48 queue_->DispatchAsync([upstream, impl] { upstream->Subscribe(impl); });
54 class Proxy :
public std::enable_shared_from_this<Proxy>,
59 std::shared_ptr<TaskQueue> queue)
60 : downstream_(std::move(downstream)), queue_(std::move(queue)) {}
61 ~Proxy()
override =
default;
65 Proxy(
const Proxy&) =
delete;
66 Proxy(Proxy&&) =
delete;
67 Proxy& operator=(
const Proxy&) =
delete;
68 Proxy& operator=(Proxy&&) =
delete;
72 void Receive(std::shared_ptr<Subscription> subscription)
override {
73 VerifyIsTrue(downstream_, TuriErrorCode::LogicError);
77 subscription_ = std::move(subscription);
81 downstream_->Receive(this->shared_from_this());
86 if (!downstream_)
return Demand::None();
88 return downstream_->Receive(std::move(element));
93 if (!downstream_)
return;
95 downstream_->Receive(std::move(completion));
100 void Cancel()
override {
102 if (!downstream_)
return;
105 downstream_ =
nullptr;
109 std::shared_ptr<Subscription> subscription = subscription_;
110 queue_->DispatchAsync([subscription] { subscription->Cancel(); });
113 void Request(
Demand demand)
override {
115 if (!downstream_)
return;
119 std::shared_ptr<Subscription> subscription = subscription_;
120 queue_->DispatchAsync(
121 [subscription, demand] { subscription->Request(demand); });
125 std::shared_ptr<Subscriber<Output>> downstream_;
126 std::shared_ptr<TaskQueue> queue_;
127 std::shared_ptr<Subscription> subscription_;
130 std::shared_ptr<Publisher<Output>> upstream_;
131 std::shared_ptr<TaskQueue> queue_;
142 template <
typename T>
148 std::shared_ptr<TaskQueue> queue)
149 : upstream_(std::move(upstream)), queue_(std::move(queue)) {}
160 auto proxy = std::make_shared<Proxy>(std::move(subscriber), queue_);
161 upstream_->Subscribe(std::move(proxy));
170 std::shared_ptr<TaskQueue> queue)
171 : downstream_(std::move(downstream)), queue_(std::move(queue)) {}
172 ~Proxy()
override =
default;
176 Proxy(
const Proxy&) =
delete;
177 Proxy(Proxy&&) =
delete;
178 Proxy& operator=(
const Proxy&) =
delete;
179 Proxy& operator=(Proxy&&) =
delete;
181 void Receive(std::shared_ptr<Subscription> subscription)
override {
184 subscription_ = subscription;
188 std::shared_ptr<Subscriber<Output>> downstream = downstream_;
189 queue_->DispatchAsync(
190 [downstream, subscription] { downstream->Receive(subscription); });
196 std::shared_ptr<Subscriber<Output>> downstream = downstream_;
197 std::shared_ptr<Output> shared_element =
198 std::make_shared<Output>(std::move(element));
199 std::shared_ptr<Subscription> subscription = subscription_;
200 queue_->DispatchAsync([downstream, subscription, shared_element] {
201 Demand demand = downstream->Receive(std::move(*shared_element));
204 if (!demand.IsNone()) {
205 subscription->Request(demand);
210 return Demand::None();
216 std::shared_ptr<Subscriber<Output>> downstream = downstream_;
217 queue_->DispatchAsync(
218 [downstream, completion] { downstream->Receive(completion); });
222 std::shared_ptr<Subscriber<Output>> downstream_;
223 std::shared_ptr<TaskQueue> queue_;
224 std::shared_ptr<Subscription> subscription_;
227 std::shared_ptr<Publisher<Output>> upstream_;
228 std::shared_ptr<TaskQueue> queue_;
void Receive(std::shared_ptr< Subscriber< Output >> subscriber) override
void Receive(std::shared_ptr< Subscriber< Output >> subscriber) override