Turi Create  4.0
combine_base.hpp
Go to the documentation of this file.
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef ML_NEURAL_NET_COMBINE_BASE_HPP_
9 #define ML_NEURAL_NET_COMBINE_BASE_HPP_
10 
11 /**
12  * \file combine_base.hpp
13  *
14  * Defines the core data types for a reactive-streams library inspired by the
15  * Swift Combine framework. Client code should generally import combine.hpp.
16  */
17 
#include <exception>
#include <limits>
#include <memory>
#include <utility>

#include <ml/neural_net/TaskQueue.hpp>
22 
23 namespace turi {
24 namespace neural_net {
25 
// Forward declarations for types defined by other headers included by
// combine.hpp. Publisher's convenience methods below name these templates, so
// they must be declared before Publisher is defined.

template <typename T, typename Callable>
class CallableTransform;

template <typename T>
class FuturesStream;

template <typename T>
class FuturesSubscriber;

template <typename T, typename U>
class MapPublisher;

template <typename T>
class ReceiveOnQueuePublisher;

template <typename T>
class SubscribeOnQueuePublisher;

template <typename T, typename U>
class Transform;
49 
/**
 * Simple type expressing how many values a Subscriber is ready to receive
 * from its Publisher.
 *
 * A negative stored value means "unlimited"; zero means "none"; a positive
 * value is a finite count.
 */
class Demand {
 public:
  /** Returns a Demand that places no bound on the number of values. */
  static Demand Unlimited() { return Demand(-1); }

  /** Returns a Demand for zero values. */
  static Demand None() { return Demand(0); }

  /** Any negative value is interpreted as "unlimited". */
  explicit Demand(int max) : max_(max) {}

  bool IsUnlimited() const { return max_ < 0; }
  bool IsNone() const { return max_ == 0; }

  /** Returns a negative number to indicate "unlimited." */
  int max() const { return max_; }

  /**
   * Additively combines another Demand value into this one.
   *
   * If either operand is unlimited, the result is unlimited (stored as -1).
   * Otherwise the finite counts are summed, saturating at the largest int
   * instead of overflowing (signed overflow is undefined behavior).
   *
   * \return A reference to this instance, to allow chaining.
   */
  Demand& Add(Demand other) {
    if (IsUnlimited() || other.IsUnlimited()) {
      max_ = -1;
    } else if (other.max_ > std::numeric_limits<int>::max() - max_) {
      // Both operands are non-negative here, so the subtraction above cannot
      // overflow. Clamp so that a huge finite demand stays finite.
      max_ = std::numeric_limits<int>::max();
    } else {
      max_ += other.max_;
    }
    return *this;
  }

  /**
   * Decrease this demand by one if the current max is positive and finite.
   *
   * Unlimited (negative) and empty (zero) demands are left unchanged.
   *
   * \return A reference to this instance, to allow chaining.
   */
  Demand& Decrement() {
    if (max_ > 0) {
      --max_;
    }
    return *this;
  }

 private:
  int max_ = 0;
};
89 
90 /**
91  * Interface for objects that Publishers send to Subscribers to allow the
92  * Subscribers to (potentially asynchronously) control the flow of values that
93  * the Subscriber receives from the Publisher.
94  */
95 class Subscription {
96  public:
97  virtual ~Subscription() = default;
98 
99  /**
100  * Requests the Publisher to stop sending anything to the Subscriber.
101  *
102  * After receiving Cancel() from a Subscriber, a Publisher should thereafter
103  * ignore all future messages from that Subscriber, including future calls to
104  * Cancel.
105  *
106  * Publishers must support Subscribers calling Cancel() from inside
107  * Subscriber::Receive(...).
108  */
109  virtual void Cancel() = 0;
110 
111  /**
112  * Requests the Publisher to send the indicated number of values to the
113  * Subscriber.
114  *
115  * Publishers must support Subscribers calling Request(Demand) from inside
116  * Subscriber::Receive(Subscription), but Subscribers should avoid calling
117  * Request(Demand) inside Subscriber::Receive(Input). Instead, they should
118  * send additional Demand via the return value of Subscriber::Receive(Input)
119  * (to help prevent infinite recursion).
120  */
121  virtual void Request(Demand demand) = 0;
122 };
123 
/**
 * Type representing a message from a Publisher to a Subscriber indicating that
 * the Subscriber will no longer receive any further messages.
 *
 * A Completion either signals success (no stored exception) or failure (a
 * non-null stored exception).
 */
class Completion {
 public:
  /** Returns an instance that signals successful completion. */
  static Completion Finished() { return Completion(); }

  /**
   * Returns an instance that signals failure, described by the given
   * exception.
   *
   * Note that success is represented by a null exception pointer, so passing
   * a null pointer here yields an instance for which IsFinished() is true.
   */
  static Completion Failure(std::exception_ptr e) {
    return Completion(std::move(e));
  }

  /** Returns true if and only if no failure is stored. */
  bool IsFinished() const { return failure_ == nullptr; }

  /** Returns the exception if a failure and a null pointer otherwise. */
  std::exception_ptr failure() const { return failure_; }

 private:
  // The pointer is taken by value and moved into place to avoid an extra
  // reference-count round trip on the shared exception object.
  explicit Completion(std::exception_ptr e = nullptr)
      : failure_(std::move(e)) {}

  std::exception_ptr failure_;
};
149 
150 /**
151  * Interface for objects that consume values from a Publisher.
152  *
153  * Unless otherwise specified by the concrete implementation, external
154  * synchronization must be used to avoid concurrent calls the Subscriber
155  * interface from different threads.
156  */
157 template <typename T>
158 class Subscriber {
159  public:
160  /** The type of the values that this Subscriber consumes. */
161  using Input = T;
162 
163  virtual ~Subscriber() = default;
164 
165  /**
166  * The first signal that a Subscriber receives from a Publisher, passing the
167  * Subscription that the Subscriber can use to control the flow of values.
168  *
169  * A Subscriber may only have one Publisher. If it somehow receives more than
170  * one Subscription, it should call Subscription::Cancel() on any instances
171  * received after the first.
172  *
173  * A Subscriber is explictly allowed to demand values synchronously from
174  * within its implementation of this method.
175  */
176  virtual void Receive(std::shared_ptr<Subscription> subscription) = 0;
177 
178  /**
179  * Transmits a value from the Publisher to this Subscriber.
180  *
181  * A Subcriber should never receive more calls to this method than the total
182  * Demand it has requested from its publisher. Subscribers should only demand
183  * more elements from within this method via its return value.
184  */
185  virtual Demand Receive(Input element) = 0;
186 
187  /**
188  * Signals completion of the stream of values from the Publisher.
189  *
190  * A Subscriber should not receive any further signals of any kind after
191  * receiving a Completion.
192  */
193  virtual void Receive(Completion completion) = 0;
194 };
195 
196 /**
197  * Interface for objects that produce values on demand from its Subscribers.
198  *
199  * Unless otherwise specified by the concrete implementation, external
200  * synchronization must be used to avoid concurrent calls on multiple threads to
201  * a Publisher, including via the Subscriptions that it passes to its
202  * Subscribers.
203  *
204  * Each concrete implementation defines whether it is unicast or multicast:
205  * whether multiple Subscribers observe the same values or not. (An
206  * implementation might only support one Subscriber, by passing an immediate
207  * Completion to each Subscriber after the first.)
208  *
209  * Note: instances of this class are intended to be stored using shared_ptr.
210  * Many of the operators rely on generating strong references to the instance
211  * being augmented.
212  */
213 template <typename T>
214 class Publisher : public std::enable_shared_from_this<Publisher<T>> {
215  public:
216  /** The type of values that this Publisher produces. */
217  using Output = T;
218 
219  virtual ~Publisher() = default;
220 
221  /**
222  * Establishes a connection between this Publisher and the given Subcriber.
223  *
224  * The Publisher must eventually call Subscriber::Receive(Subscription) on the
225  * given Subscriber (and may do so synchronously). The Publisher must then
226  * conform to the protocol established by the Subscription.
227  */
228  virtual void Receive(std::shared_ptr<Subscriber<Output>> subscriber) = 0;
229 
230  // Convenienience methods, supporting the chaining together of operations.
231  // Many of these rely on the forward declarations above. Client code should
232  // include combine.hpp to ensure these are defined before they are used.
233 
234  void Subscribe(std::shared_ptr<Subscriber<Output>> subscriber) {
235  Receive(std::move(subscriber));
236  }
237 
238  std::shared_ptr<FuturesStream<Output>> AsFutures() {
239  auto subscriber = std::make_shared<FuturesSubscriber<Output>>();
240  Subscribe(subscriber);
241  return std::make_shared<FuturesStream<Output>>(std::move(subscriber));
242  }
243 
244  template <typename TransformType>
245  std::shared_ptr<Publisher<typename TransformType::Output>> Map(
246  std::shared_ptr<TransformType> transform) {
247  using TransformInput = typename TransformType::Input;
248  using TransformOutput = typename TransformType::Output;
249  return std::make_shared<MapPublisher<TransformInput, TransformOutput>>(
250  this->shared_from_this(), std::move(transform));
251  }
252 
253  template <typename Callable>
254  std::shared_ptr<Publisher<typename std::result_of<Callable(Output)>::type>>
255  Map(Callable fn) {
256  using TransformType = CallableTransform<Output, Callable>;
257  auto transform = std::make_shared<TransformType>(std::move(fn));
258  return Map(std::move(transform));
259  }
260 
261  std::shared_ptr<Publisher<Output>> SubscribeOn(
262  std::shared_ptr<TaskQueue> queue) {
263  return std::make_shared<SubscribeOnQueuePublisher<Output>>(
264  this->shared_from_this(), std::move(queue));
265  }
266 
267  std::shared_ptr<Publisher<Output>> ReceiveOn(
268  std::shared_ptr<TaskQueue> queue) {
269  return std::make_shared<ReceiveOnQueuePublisher<Output>>(
270  this->shared_from_this(), std::move(queue));
271  }
272 };
273 
274 } // namespace neural_net
275 } // namespace turi
276 
277 #endif // ML_NEURAL_NET_COMBINE_BASE_HPP_
static Completion Finished()
std::exception_ptr failure() const
Demand & Add(Demand other)
static Completion Failure(std::exception_ptr e)
void transform(S &&input, T &&output, TransformFn transformfn, std::set< size_t > constraint_segments=std::set< size_t >())
Definition: algorithm.hpp:64