Turi Create  4.0
combine_mock.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef ML_NEURAL_NET_COMBINE_MOCK_HPP_
9 #define ML_NEURAL_NET_COMBINE_MOCK_HPP_
10 
12 
#include <functional>
#include <queue>
#include <stdexcept>
#include <utility>

#include <boost/test/unit_test.hpp>
#include <core/util/test_macros.hpp>
18 
19 namespace turi {
20 namespace neural_net {
21 
/**
 * Helper function to reduce verbosity of writing mocks.
 *
 * Pops the first callback from the given queue and invokes it with the provided
 * arguments, returning whatever the callback returns.
 *
 * The callback's parameter types (Params) are deduced from the queue, while the
 * argument types (Args) are deduced independently from the call site. Each
 * argument is perfectly forwarded and then converted to the corresponding
 * parameter type by std::function::operator(), so lvalues, rvalues and
 * convertible arguments are all accepted.
 *
 * If no callback is queued (an unexpected invocation), the failure is recorded
 * via TS_ASSERT and a std::logic_error is thrown. The explicit throw guarantees
 * front() is never called on an empty queue (undefined behavior) even when
 * TS_ASSERT only records the failure and lets execution continue.
 *
 * \param callbacks Queue of expectations; one element is consumed per call.
 * \param args      Arguments forwarded to the dequeued callback.
 * \return The dequeued callback's return value.
 * \throws std::logic_error if `callbacks` is empty.
 */
template <typename R, typename... Params, typename... Args>
R Call(std::queue<std::function<R(Params...)>>* callbacks, Args&&... args) {
  TS_ASSERT(!callbacks->empty());
  if (callbacks->empty()) {
    throw std::logic_error(
        "Call: unexpected mock invocation, no callback was queued");
  }

  // Move the callback out and pop it before invoking it, so the invocation
  // does not depend on the queue's storage (the callback may push more
  // callbacks onto the same queue).
  auto callback = std::move(callbacks->front());
  callbacks->pop();
  return callback(std::forward<Args>(args)...);
}
35 
36 class MockSubscription : public Subscription {
37  public:
38  ~MockSubscription() override {
39  TS_ASSERT(cancel_callbacks.empty());
40  TS_ASSERT(demand_callbacks.empty());
41  }
42 
43  void Cancel() override { return Call(&cancel_callbacks); }
44  std::queue<std::function<void()>> cancel_callbacks;
45 
46  void Request(Demand demand) override {
47  return Call(&demand_callbacks, std::move(demand));
48  }
49  std::queue<std::function<void(Demand)>> demand_callbacks;
50 };
51 
52 template <typename T>
53 class MockSubscriber : public Subscriber<T> {
54  public:
55  using Input = T;
56 
57  ~MockSubscriber() override {
58  TS_ASSERT(subscription_callbacks.empty());
59  TS_ASSERT(input_callbacks.empty());
60  TS_ASSERT(completion_callbacks.empty());
61  }
62 
63  void Receive(std::shared_ptr<Subscription> subscription) override {
64  return Call(&subscription_callbacks, std::move(subscription));
65  }
66  std::queue<std::function<void(std::shared_ptr<Subscription>)>>
67  subscription_callbacks;
68 
69  Demand Receive(Input element) override {
70  return Call(&input_callbacks, std::move(element));
71  }
72  std::queue<std::function<Demand(Input)>> input_callbacks;
73 
74  void Receive(Completion completion) override {
75  return Call(&completion_callbacks, std::move(completion));
76  }
77  std::queue<std::function<void(Completion)>> completion_callbacks;
78 };
79 
80 template <typename T>
81 class MockPublisher : public Publisher<T> {
82  public:
83  using Output = T;
84 
85  ~MockPublisher() override { TS_ASSERT(subscriber_callbacks.empty()); }
86 
87  void Receive(std::shared_ptr<Subscriber<Output>> subscriber) override {
88  return Call(&subscriber_callbacks, std::move(subscriber));
89  }
90  std::queue<std::function<void(std::shared_ptr<Subscriber<Output>>)>>
91  subscriber_callbacks;
92 };
93 
94 } // namespace neural_net
95 } // namespace turi
96 
97 #endif // ML_NEURAL_NET_COMBINE_MOCK_HPP_