8 #ifndef ML_NEURAL_NET_COMBINE_MAP_HPP_ 9 #define ML_NEURAL_NET_COMBINE_MAP_HPP_ 14 #include <type_traits> 19 namespace neural_net {
22 class IteratorPublisher;
27 template <
typename T,
typename U>
33 virtual ~Transform() =
default;
36 virtual Output
Invoke(Input value) = 0;
42 template <
typename T,
typename Callable>
43 class CallableTransform
44 :
public Transform<T, typename std::result_of<Callable(T)>::type> {
47 using Output =
typename std::result_of<Callable(T)>::type;
49 CallableTransform(Callable impl) : impl_(
std::move(impl)) {}
51 Output
Invoke(Input input)
override {
return impl_(std::move(input)); }
64 template <
typename T,
typename U>
72 : upstream_(std::move(upstream)), transform_(std::move(
transform)) {}
76 std::make_shared<MapSubscriber>(transform_, std::move(subscriber));
77 upstream_->Subscribe(std::move(impl));
81 class MapSubscriber :
public Subscriber<Input> {
86 downstream_(std::move(downstream)) {}
88 void Receive(std::shared_ptr<Subscription> subscription)
override {
90 downstream_->Receive(std::move(subscription));
94 Demand Receive(Input element)
override {
96 if (!downstream_)
return Demand::None();
100 typename std::aligned_storage<sizeof(Output), alignof(Output)>::type;
101 OutputStorage value_storage;
102 Output* value =
nullptr;
103 std::exception_ptr failure;
108 value =
reinterpret_cast<Output*
>(&value_storage);
109 new (value) Output(transform_->Invoke(std::move(element)));
112 failure = std::current_exception();
115 Demand demand = Demand::None();
118 auto downstream = std::move(downstream_);
121 demand = downstream_->Receive(std::move(*value));
128 void Receive(
Completion completion)
override {
130 downstream_->Receive(std::move(completion));
135 std::shared_ptr<Transform<Input, Output>> transform_;
136 std::shared_ptr<Subscriber<Output>> downstream_;
137 std::exception_ptr failure_;
140 std::shared_ptr<Publisher<Input>> upstream_;
141 std::shared_ptr<Transform<Input, Output>> transform_;
147 #endif // ML_NEURAL_NET_COMBINE_MAP_HPP_
void Receive(std::shared_ptr< Subscriber< Output >> subscriber) override
static Completion Failure(std::exception_ptr e)
void transform(S &&input, T &&output, TransformFn transformfn, std::set< size_t > constraint_segments=std::set< size_t >())