Turi Create  4.0
combine_iterator.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef ML_NEURAL_NET_COMBINE_ITERATOR_HPP_
9 #define ML_NEURAL_NET_COMBINE_ITERATOR_HPP_
10 
11 #include <exception>
12 #include <memory>
13 #include <type_traits>
14 
16 
17 namespace turi {
18 namespace neural_net {
19 
template <typename T>
class IteratorPublisher;

/**
 * Interface for objects that produce a sequence of values, using the
 * conventional iterator interface.
 *
 * This type facilitates wrapping traditional iterator-style code as a
 * Publisher: implement HasNext() and Next(), then call AsPublisher().
 */
template <typename T>
class Iterator : public std::enable_shared_from_this<Iterator<T>> {
 public:
  using Output = T;

  virtual ~Iterator() = default;

  /**
   * Returns true as long as the underlying sequence contains more values.
   *
   * \todo If we have Optional<T>, we can remove this method and have Next()
   *       instead return Optional<Output>.
   */
  virtual bool HasNext() const = 0;

  /** Returns the next value in the sequence. May throw on error. */
  virtual Output Next() = 0;

  /**
   * Returns a Publisher wrapping this Iterator.
   *
   * The returned Publisher shares ownership of this Iterator. Because this
   * relies on shared_from_this(), this object must already be owned by a
   * std::shared_ptr when AsPublisher() is called. Otherwise the behavior is
   * undefined before C++17, and std::bad_weak_ptr is thrown since C++17.
   */
  std::shared_ptr<IteratorPublisher<T>> AsPublisher() {
    return std::make_shared<IteratorPublisher<T>>(this->shared_from_this());
  }
};
53 
54 /**
55  * Templated implementation of Iterator that wraps an arbitrary callable type.
56  */
57 template <typename Callable>
59  : public Iterator<typename std::result_of<Callable()>::type> {
60  public:
61  using Output = typename std::result_of<Callable()>::type;
62 
63  CallableIterator(Callable impl) : impl_(std::move(impl)) {}
64 
65  bool HasNext() const override { return true; }
66 
67  Output Next() override { return impl_(); }
68 
69  private:
70  Callable impl_;
71 };
72 
73 template <typename Callable>
74 std::shared_ptr<IteratorPublisher<typename std::result_of<Callable()>::type>>
75 CreatePublisherFromCallable(Callable impl) {
76  return std::make_shared<CallableIterator<Callable>>(std::move(impl))
77  ->AsPublisher();
78 }
79 
80 /**
81  * Concrete Publisher that wraps an Iterator.
82  *
83  * The resulting Publisher is unicast: each iterated value will go only to
84  * whichever Subscriber triggered the iteration.
85  */
86 template <typename T>
87 class IteratorPublisher : public Publisher<T> {
88  public:
89  using Output = T;
90 
91  explicit IteratorPublisher(std::shared_ptr<Iterator<Output>> iterator)
92  : iterator_(std::move(iterator)) {}
93 
94  void Receive(std::shared_ptr<Subscriber<Output>> subscriber) override {
95  auto subscription =
96  std::make_shared<IteratorSubscription>(subscriber, iterator_);
97  subscriber->Receive(std::move(subscription));
98  }
99 
100  private:
101  // All of the logic lives in the implementation of Subscription, which relies
102  // on the assumption that only one Subscription at a time will access the
103  // shared Iterator.
104  class IteratorSubscription : public Subscription {
105  public:
106  IteratorSubscription(std::shared_ptr<Subscriber<Output>> subscriber,
107  std::shared_ptr<Iterator<Output>> iterator)
108  : subscriber_(std::move(subscriber)), iterator_(std::move(iterator)) {}
109 
110  bool IsActive() const { return subscriber_ != nullptr; }
111 
112  void Cancel() override { subscriber_.reset(); }
113 
114  void Request(Demand demand) override {
115  // Keep sending signals to the Subscriber until we're cancelled or we
116  // exhaust the demand.
117  while (IsActive() && !demand.IsNone()) {
118  // Invoke the iterator to determine what signal we'll send.
119 
120  // Don't assume that Output has a (cheap) default constructor.
121  // TODO: Use an Optional type instead.
122  using OutputStorage =
123  typename std::aligned_storage<sizeof(Output),
124  alignof(Output)>::type;
125  OutputStorage value_storage;
126  Output* value = nullptr; // Track whether we have a value.
127  std::exception_ptr failure; // Track whether we have a failure.
128  try {
129  if (iterator_->HasNext()) {
130  // Use placement new to initialize our value from the iterator.
131  value = reinterpret_cast<Output*>(&value_storage);
132  new (value) Output(iterator_->Next());
133  }
134  } catch (...) {
135  // On any exception, *value was not initialized, since it was the last
136  // statement in the try block.
137  value = nullptr;
138  failure = std::current_exception();
139  }
140 
141  // Send the appropriate signal.
142  if (failure) {
143  // Signal failure and ensure we don't send any more signals.
144  subscriber_->Receive(Completion::Failure(failure));
145  Cancel();
146  } else if (!value) {
147  // Signal finished and ensure we don't send any more signals.
148  subscriber_->Receive(Completion::Finished());
149  Cancel();
150  } else {
151  // Pass the value to the Subscriber, adding any new demand.
152  demand.Decrement();
153  Demand new_demand = subscriber_->Receive(std::move(*value));
154  demand.Add(new_demand);
155 
156  // We must manually destroy the (moved-from) value.
157  value->~Output();
158  }
159  }
160  }
161 
162  private:
163  std::shared_ptr<Subscriber<Output>> subscriber_;
164  std::shared_ptr<Iterator<Output>> iterator_;
165  };
166 
167  std::shared_ptr<Iterator<Output>> iterator_;
168 };
169 
170 } // namespace neural_net
171 } // namespace turi
172 
173 #endif // ML_NEURAL_NET_COMBINE_ITERATOR_HPP_
virtual bool HasNext() const =0
static Completion Finished()
virtual Output Next()=0
void Receive(std::shared_ptr< Subscriber< Output >> subscriber) override
Demand & Add(Demand other)
std::shared_ptr< IteratorPublisher< T > > AsPublisher()
static Completion Failure(std::exception_ptr e)