Turi Create  4.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
mps_style_transfer_backend.hpp
1 /* Copyright © 2019 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #pragma once
9 
10 #include <functional>
11 #include <map>
12 #include <memory>
13 
14 #include <core/export.hpp>
15 #include <ml/neural_net/float_array.hpp>
16 #include <ml/neural_net/model_backend.hpp>
17 #include <ml/neural_net/mps_command_queue.hpp>
18 
19 #ifdef HAS_MACOS_10_15
20 
21 namespace turi {
22 namespace style_transfer {
23 
24 class EXPORT mps_style_transfer : public turi::neural_net::model_backend {
25 public:
26  mps_style_transfer(const turi::neural_net::float_array_map &config,
27  const turi::neural_net::float_array_map &weights);
28 
29  mps_style_transfer(const turi::neural_net::float_array_map &config,
30  const turi::neural_net::float_array_map &weights,
31  const turi::neural_net::mps_command_queue& command_queue);
32 
33  ~mps_style_transfer();
34 
35  turi::neural_net::float_array_map export_weights() const override;
36  turi::neural_net::float_array_map predict(const turi::neural_net::float_array_map& inputs) const override;
37  void set_learning_rate(float lr) override;
38  turi::neural_net::float_array_map train(const turi::neural_net::float_array_map& inputs) override;
39 private:
40  struct impl;
41  std::unique_ptr<impl> m_impl;
42 
43  void init(const turi::neural_net::float_array_map &config,
44  const turi::neural_net::float_array_map &weights,
45  const turi::neural_net::mps_command_queue& command_queue);
46 };
47 
48 } // namespace style_transfer
49 } // namespace turi
50 
51 #endif // #ifdef HAS_MACOS_10_15