Turi Create  4.0
combine_map.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef ML_NEURAL_NET_COMBINE_MAP_HPP_
9 #define ML_NEURAL_NET_COMBINE_MAP_HPP_
10 
11 #include <exception>
12 #include <functional>
13 #include <memory>
14 #include <type_traits>
15 
17 
18 namespace turi {
19 namespace neural_net {
20 
// Forward declaration only; no definition of this class template appears in
// this header.
template <typename T> class IteratorPublisher;
23 
/**
 * Abstract interface for an object that maps a value of type T to a value of
 * type U.
 *
 * Concrete implementations override Invoke() to supply the mapping.
 */
template <typename T, typename U>
class Transform {
 public:
  /** Type of the value accepted by Invoke(). */
  using Input = T;

  /** Type of the value produced by Invoke(). */
  using Output = U;

  virtual ~Transform() = default;

  /**
   * Applies the transform to a single value.
   *
   * \param value The value to transform (taken by value).
   * \return The transformed value.
   *
   * May throw on error.
   */
  virtual Output Invoke(Input value) = 0;
};
38 
39 /**
40  * Templated implementation of Transform that wraps an arbitrary callable type.
41  */
42 template <typename T, typename Callable>
43 class CallableTransform
44  : public Transform<T, typename std::result_of<Callable(T)>::type> {
45  public:
46  using Input = T;
47  using Output = typename std::result_of<Callable(T)>::type;
48 
49  CallableTransform(Callable impl) : impl_(std::move(impl)) {}
50 
51  Output Invoke(Input input) override { return impl_(std::move(input)); }
52 
53  private:
54  Callable impl_;
55 };
56 
57 /**
58  * Concrete operator Publisher that wraps a Transform.
59  *
60  * The resulting Publisher inherits the semantics of the upstream Publisher that
61  * it subscribes to, with regard to the semantics of multiple downstream
62  * subscribers. It simply applies the Transform to each value from the upstream.
63  */
64 template <typename T, typename U>
65 class MapPublisher : public Publisher<U> {
66  public:
67  using Input = T;
68  using Output = U;
69 
70  MapPublisher(std::shared_ptr<Publisher<T>> upstream,
71  std::shared_ptr<Transform<T, U>> transform)
72  : upstream_(std::move(upstream)), transform_(std::move(transform)) {}
73 
74  void Receive(std::shared_ptr<Subscriber<Output>> subscriber) override {
75  auto impl =
76  std::make_shared<MapSubscriber>(transform_, std::move(subscriber));
77  upstream_->Subscribe(std::move(impl));
78  }
79 
80  private:
81  class MapSubscriber : public Subscriber<Input> {
82  public:
83  MapSubscriber(std::shared_ptr<Transform<Input, Output>> transform,
84  std::shared_ptr<Subscriber<Output>> downstream)
85  : transform_(std::move(transform)),
86  downstream_(std::move(downstream)) {}
87 
88  void Receive(std::shared_ptr<Subscription> subscription) override {
89  if (downstream_) {
90  downstream_->Receive(std::move(subscription));
91  }
92  }
93 
94  Demand Receive(Input element) override {
95  // Do nothing if we are already cancelled.
96  if (!downstream_) return Demand::None();
97 
98  // TODO: Define Optional<T>. (Or require C++17 for std::optional?)
99  using OutputStorage =
100  typename std::aligned_storage<sizeof(Output), alignof(Output)>::type;
101  OutputStorage value_storage;
102  Output* value = nullptr; // Track whether we have a value.
103  std::exception_ptr failure; // Track whether we have a failure.
104  try {
105  // Use placement new to initialize the value from the transform output.
106  // Placement new must be the last statement of the try block, so that
107  // the catch block can assume the value is not initialized.
108  value = reinterpret_cast<Output*>(&value_storage);
109  new (value) Output(transform_->Invoke(std::move(element)));
110  } catch (...) {
111  value = nullptr;
112  failure = std::current_exception();
113  }
114 
115  Demand demand = Demand::None();
116  if (failure) {
117  // Leave downstream_ as nullptr to avoid sending any further signals.
118  auto downstream = std::move(downstream_);
119  downstream->Receive(Completion::Failure(failure));
120  } else {
121  demand = downstream_->Receive(std::move(*value));
122  value->~Output();
123  }
124 
125  return demand;
126  }
127 
128  void Receive(Completion completion) override {
129  if (downstream_) {
130  downstream_->Receive(std::move(completion));
131  }
132  }
133 
134  private:
135  std::shared_ptr<Transform<Input, Output>> transform_;
136  std::shared_ptr<Subscriber<Output>> downstream_;
137  std::exception_ptr failure_;
138  };
139 
140  std::shared_ptr<Publisher<Input>> upstream_;
141  std::shared_ptr<Transform<Input, Output>> transform_;
142 };
143 
144 } // namespace neural_net
145 } // namespace turi
146 
147 #endif // ML_NEURAL_NET_COMBINE_MAP_HPP_
STL namespace.
virtual Output Invoke(Input value)=0
void Receive(std::shared_ptr< Subscriber< Output >> subscriber) override
Definition: combine_map.hpp:74
static Completion Failure(std::exception_ptr e)
Output Invoke(Input input) override
Definition: combine_map.hpp:51
void transform(S &&input, T &&output, TransformFn transformfn, std::set< size_t > constraint_segments=std::set< size_t >())
Definition: algorithm.hpp:64