Turi Create  4.0
combine_queue.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #pragma once
9 
10 #include <core/util/Verify.hpp>
11 #include <ml/neural_net/TaskQueue.hpp>
13 
14 namespace turi {
15 namespace neural_net {
16 
17 /**
18  * Publisher that implements the Publisher::SubscribeOn operator.
19  *
20  * The resulting Publisher simply dispatches subscription requests, demands, and
21  * cancellations to a specified task queue. It inherits the semantics of the
22  * upstream Publisher that it wraps, with regard to the behavior with multiple
23  * downstream subscribers.
24  */
25 template <typename T>
26 class SubscribeOnQueuePublisher : public Publisher<T> {
27  public:
28  using Output = T;
29 
30  SubscribeOnQueuePublisher(std::shared_ptr<Publisher<Output>> upstream,
31  std::shared_ptr<TaskQueue> queue)
32  : upstream_(std::move(upstream)), queue_(std::move(queue)) {}
33  ~SubscribeOnQueuePublisher() override = default;
34 
35  // Any attempt to copy or move instances of this class is likely an error. All
36  // instances should be allocated with std::make_shared.
37  SubscribeOnQueuePublisher(const SubscribeOnQueuePublisher&) = delete;
38  SubscribeOnQueuePublisher(SubscribeOnQueuePublisher&&) = delete;
39  SubscribeOnQueuePublisher& operator=(const SubscribeOnQueuePublisher&) =
40  delete;
41  SubscribeOnQueuePublisher& operator=(SubscribeOnQueuePublisher&&) = delete;
42 
43  void Receive(std::shared_ptr<Subscriber<Output>> subscriber) override {
44  // Pass a proxy for this subscriber to the upstream publisher, but do so on
45  // the requested task queue.
46  std::shared_ptr<Publisher<Output>> upstream = upstream_;
47  auto impl = std::make_shared<Proxy>(std::move(subscriber), queue_);
48  queue_->DispatchAsync([upstream, impl] { upstream->Subscribe(impl); });
49  }
50 
51  private:
52  // This class serves as an intermediary between the upstream publisher and the
53  // downstream subscriber.
54  class Proxy : public std::enable_shared_from_this<Proxy>,
55  public Subscriber<Output>,
56  public Subscription {
57  public:
58  Proxy(std::shared_ptr<Subscriber<Output>> downstream,
59  std::shared_ptr<TaskQueue> queue)
60  : downstream_(std::move(downstream)), queue_(std::move(queue)) {}
61  ~Proxy() override = default;
62 
63  // Any attempt to copy or move instances of this class is likely an error.
64  // All instances should be allocated with std::make_shared.
65  Proxy(const Proxy&) = delete;
66  Proxy(Proxy&&) = delete;
67  Proxy& operator=(const Proxy&) = delete;
68  Proxy& operator=(Proxy&&) = delete;
69 
70  // Subscriber<Output> interface, for upstream publisher
71 
72  void Receive(std::shared_ptr<Subscription> subscription) override {
73  VerifyIsTrue(downstream_, TuriErrorCode::LogicError); // We cannot have been canceled yet.
74 
75  // Intercept (and store) the subscription we receive from the upstream
76  // publisher.
77  subscription_ = std::move(subscription);
78 
79  // Pass ourselves to the downstream subscriber. We will serve as a
80  // proxy. From here on out, we can be canceled at any time.
81  downstream_->Receive(this->shared_from_this());
82  }
83 
84  Demand Receive(Output element) override {
85  // Do nothing if we are already cancelled.
86  if (!downstream_) return Demand::None();
87 
88  return downstream_->Receive(std::move(element));
89  }
90 
91  void Receive(Completion completion) override {
92  // Do nothing if we are already cancelled.
93  if (!downstream_) return;
94 
95  downstream_->Receive(std::move(completion));
96  }
97 
98  // Subscription interface, for downstream subscriber
99 
100  void Cancel() override {
101  // Do nothing if we are already cancelled.
102  if (!downstream_) return;
103 
104  // Ensure that we send no further signals to the downstream subscriber.
105  downstream_ = nullptr;
106 
107  // Forward the cancel request to the upstream publisher, but do so on the
108  // requested task queue.
109  std::shared_ptr<Subscription> subscription = subscription_;
110  queue_->DispatchAsync([subscription] { subscription->Cancel(); });
111  }
112 
113  void Request(Demand demand) override {
114  // Do nothing if we are already cancelled.
115  if (!downstream_) return;
116 
117  // Forward the request to the upstream publisher, but do so on the
118  // requested task queue.
119  std::shared_ptr<Subscription> subscription = subscription_;
120  queue_->DispatchAsync(
121  [subscription, demand] { subscription->Request(demand); });
122  }
123 
124  private:
125  std::shared_ptr<Subscriber<Output>> downstream_;
126  std::shared_ptr<TaskQueue> queue_;
127  std::shared_ptr<Subscription> subscription_;
128  };
129 
130  std::shared_ptr<Publisher<Output>> upstream_;
131  std::shared_ptr<TaskQueue> queue_;
132 };
133 
134 /**
135  * Publisher that implements the Publisher::ReceiveOn operator.
136  *
137  * The resulting Publisher simply dispatches subscriptions, values, and
138  * completions to a specified task queue. It inherits the semantics of the
139  * upstream Publisher that it wraps, with regard to the behavior with multiple
140  * downstream subscribers.
141  */
142 template <typename T>
143 class ReceiveOnQueuePublisher : public Publisher<T> {
144  public:
145  using Output = T;
146 
147  ReceiveOnQueuePublisher(std::shared_ptr<Publisher<Output>> upstream,
148  std::shared_ptr<TaskQueue> queue)
149  : upstream_(std::move(upstream)), queue_(std::move(queue)) {}
150  ~ReceiveOnQueuePublisher() override = default;
151 
152  // Any attempt to copy or move instances of this class is likely an error. All
153  // instances should be allocated with std::make_shared.
156  ReceiveOnQueuePublisher& operator=(const ReceiveOnQueuePublisher&) = delete;
157  ReceiveOnQueuePublisher& operator=(ReceiveOnQueuePublisher&&) = delete;
158 
159  void Receive(std::shared_ptr<Subscriber<Output>> subscriber) override {
160  auto proxy = std::make_shared<Proxy>(std::move(subscriber), queue_);
161  upstream_->Subscribe(std::move(proxy));
162  }
163 
164  private:
165  // This class serves as an intermediary between the upstream publisher and the
166  // downstream subscriber.
167  class Proxy : public Subscriber<Output> {
168  public:
169  Proxy(std::shared_ptr<Subscriber<Output>> downstream,
170  std::shared_ptr<TaskQueue> queue)
171  : downstream_(std::move(downstream)), queue_(std::move(queue)) {}
172  ~Proxy() override = default;
173 
174  // Any attempt to copy or move instances of this class is likely an error.
175  // All instances should be allocated with std::make_shared.
176  Proxy(const Proxy&) = delete;
177  Proxy(Proxy&&) = delete;
178  Proxy& operator=(const Proxy&) = delete;
179  Proxy& operator=(Proxy&&) = delete;
180 
181  void Receive(std::shared_ptr<Subscription> subscription) override {
182  // Store a reference to the subscription so we can request incremental
183  // demands resulting from async delivery of values.
184  subscription_ = subscription;
185 
186  // Send the subscription to the downstream subscriber on the requested
187  // task queue.
188  std::shared_ptr<Subscriber<Output>> downstream = downstream_;
189  queue_->DispatchAsync(
190  [downstream, subscription] { downstream->Receive(subscription); });
191  }
192 
193  Demand Receive(Output element) override {
194  // Send the element to the downstream subscriber on the requested task
195  // queue.
196  std::shared_ptr<Subscriber<Output>> downstream = downstream_;
197  std::shared_ptr<Output> shared_element =
198  std::make_shared<Output>(std::move(element));
199  std::shared_ptr<Subscription> subscription = subscription_;
200  queue_->DispatchAsync([downstream, subscription, shared_element] {
201  Demand demand = downstream->Receive(std::move(*shared_element));
202 
203  // If the subscriber immediately demands more, dispatch a new request.
204  if (!demand.IsNone()) {
205  subscription->Request(demand);
206  }
207  });
208 
209  // Don't wait for the subscriber to respond.
210  return Demand::None();
211  }
212 
213  void Receive(Completion completion) override {
214  // Send the completion to the downstream subscriber on the requested task
215  // queue.
216  std::shared_ptr<Subscriber<Output>> downstream = downstream_;
217  queue_->DispatchAsync(
218  [downstream, completion] { downstream->Receive(completion); });
219  }
220 
221  private:
222  std::shared_ptr<Subscriber<Output>> downstream_;
223  std::shared_ptr<TaskQueue> queue_;
224  std::shared_ptr<Subscription> subscription_;
225  };
226 
227  std::shared_ptr<Publisher<Output>> upstream_;
228  std::shared_ptr<TaskQueue> queue_;
229 };
230 
231 } // namespace neural_net
232 } // namespace turi
void Receive(std::shared_ptr< Subscriber< Output >> subscriber) override
STL namespace.
void Receive(std::shared_ptr< Subscriber< Output >> subscriber) override