Turi Create  4.0
tf_compute_context.hpp
1 /* Copyright © 2019 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 
7 #ifndef UNITY_TOOLKITS_NEURAL_NET_TF_COMPUTE_CONTEXT_HPP_
8 #define UNITY_TOOLKITS_NEURAL_NET_TF_COMPUTE_CONTEXT_HPP_
9 
10 #include <core/export.hpp>
11 #include <ml/neural_net/compute_context.hpp>
12 
13 namespace turi {
14 namespace neural_net {
15 
16 
17 
18 /**
19  * A compute_context implementation backed by TensorFlow
20  * for neural network computation and for data
21  * augmentation.
22  */
23 EXPORT class tf_compute_context: public compute_context {
24 public:
25 
26  /**
27  * Constructs a context wrapping devices.
28  */
30 
31  virtual ~tf_compute_context();
32 
33  void print_training_device_info() const override;
34  size_t memory_budget() const override;
35 
36  std::unique_ptr<model_backend> create_object_detector(
37  int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out,
38  const float_array_map& config, const float_array_map& weights) override;
39 
40  std::unique_ptr<model_backend> create_activity_classifier(
41  const ac_parameters& ac_params) override;
42 
43  std::unique_ptr<model_backend> create_style_transfer(
44  const float_array_map& config, const float_array_map& weights) override;
45 
46  std::unique_ptr<model_backend> create_drawing_classifier(
47  /* TODO: const float_array_map& config if needed */
48  const float_array_map& weights, size_t batch_size,
49  size_t num_classes) override;
50 
51  std::unique_ptr<image_augmenter> create_image_augmenter(
52  const image_augmenter::options &opts) override;
53 
54 };
55 
56 } // namespace neural_net
57 } // namespace turi
58 
59 #endif // UNITY_TOOLKITS_NEURAL_NET_TF_COMPUTE_CONTEXT_HPP_
size_t memory_budget() const override
std::unique_ptr< model_backend > create_drawing_classifier(const float_array_map &weights, size_t batch_size, size_t num_classes) override
std::unique_ptr< image_augmenter > create_image_augmenter(const image_augmenter::options &opts) override
std::unique_ptr< model_backend > create_object_detector(int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out, const float_array_map &config, const float_array_map &weights) override
std::unique_ptr< model_backend > create_style_transfer(const float_array_map &config, const float_array_map &weights) override
void print_training_device_info() const override
std::unique_ptr< model_backend > create_activity_classifier(const ac_parameters &ac_params) override