Turi Create 4.0
combine_futures_subscriber.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #pragma once
9 
10 #include <future>
11 #include <memory>
12 #include <queue>
13 
14 #include <core/util/std/make_unique.hpp>
15 #include <core/util/Verify.hpp>
17 
18 namespace turi {
19 namespace neural_net {
20 
21 /**
22  * Subscriber that synchronously produces futures for promises to be fulfilled
23  * by its publisher.
24  *
25  * This type is useful for integrating Publishers into existing code bases that
26  * rely on synchronous behavior or futures.
27  *
28  * Client code MUST call FuturesSubscriber::Cancel() to tear down a
29  * FuturesSubscriber instance. This requirement can be handled automatically
30  * using the FuturesStream wrapper class below.
31  */
32 template <typename T>
33 class FuturesSubscriber : public Subscriber<T> {
34  public:
35  using Input = T;
36 
37  /**
38  * Submits a request for a value to the Publisher but immediately returns a
39  * future for that value.
40  *
41  * If the publisher returned a failure for this request or any previous
42  * request from this subscriber, then the future will store that exception. If
43  * the publisher returned Completion::IsFinished() for this request or any
44  * previous request, or if Cancel is called, then the future will store null.
45  *
46  * \todo Use an Optional<T> type instead of unique_ptr to avoid allocation.
47  */
48  std::future<std::unique_ptr<T>> Request() {
49  std::promise<std::unique_ptr<T>> promise;
50  auto future = promise.get_future();
51  if (failure_) {
52  // We've already observed an exception. Set it now.
53  promise.set_exception(failure_);
54  } else if (completed_) {
55  // We've already observed the end of the sequence. Signal completion now.
56  promise.set_value(nullptr);
57  } else {
58  // Enqueue this promise and submit a request to the Publisher.
59  promises_.push(std::move(promise));
60  if (subscription_) {
61  subscription_->Request(Demand(1));
62  }
63  }
64  return future;
65  }
66 
67  void Cancel() {
68  if (completed_) return;
69 
70  completed_ = true;
71 
72  // Cancel the subscription if active.
73  if (subscription_) {
74  subscription_->Cancel();
75  subscription_.reset();
76  }
77 
78  // Fulfill any outstanding promises.
79  while (!promises_.empty()) {
80  promises_.front().set_value(nullptr);
81  promises_.pop();
82  }
83  }
84 
85  void Receive(std::shared_ptr<Subscription> subscription) override {
86  // It is a programmer error to attach the same Subscriber to more than one
87  // Publisher.
88  VerifyIsTrue(subscription_ == nullptr, TuriErrorCode::LogicError);
89 
90  // Reject any subscriptions after the first. Reject the first subscription
91  // if we cancelled before it could start.
92  if (subscription_ || completed_) {
93  subscription->Cancel();
94  return;
95  }
96 
97  subscription_ = std::move(subscription);
98 
99  // If we already have promises queued, request their values now.
100  if (!promises_.empty()) {
101  Demand demand(static_cast<int>(promises_.size()));
102  subscription_->Request(demand);
103  }
104  }
105 
106  Demand Receive(Input element) override {
107  // Do nothing if we were cancelled.
108  if (completed_) return Demand::None();
109 
110  // Wrap the value in a unique_ptr and fulfill the promise.
111  auto input = std::make_unique<Input>(std::move(element));
112  promises_.front().set_value(std::move(input));
113  promises_.pop();
114  return Demand::None();
115  }
116 
117  void Receive(Completion completion) override {
118  completed_ = true;
119  if (!completion.IsFinished()) {
120  failure_ = completion.failure();
121  }
122 
123  // Fulfill any outstanding promises.
124  while (!promises_.empty()) {
125  auto promise = std::move(promises_.front());
126  promises_.pop();
127  if (failure_) {
128  promise.set_exception(completion.failure());
129  } else {
130  promise.set_value(nullptr);
131  }
132  }
133  }
134 
135  private:
136  std::shared_ptr<Subscription> subscription_;
137 
138  std::queue<std::promise<std::unique_ptr<T>>> promises_;
139 
140  bool completed_ = false;
141  std::exception_ptr failure_;
142 };
143 
144 /**
145  * Simple wrapper class around FuturesSubscriber that calls Cancel() on
146  * destruction of the wrapper.
147  */
148 template <typename T>
149 class FuturesStream {
150  public:
151  explicit FuturesStream(std::shared_ptr<FuturesSubscriber<T>> subscriber)
152  : subscriber_(std::move(subscriber)) {}
153 
154  ~FuturesStream() { subscriber_->Cancel(); }
155 
156  std::future<std::unique_ptr<T>> Next() { return subscriber_->Request(); }
157 
158  private:
159  std::shared_ptr<FuturesSubscriber<T>> subscriber_;
160 };
161 
162 } // namespace neural_net
163 } // namespace turi
void Receive(Completion completion) override
std::exception_ptr failure() const
std::future< std::unique_ptr< T > > Request()
void Receive(std::shared_ptr< Subscription > subscription) override