Turi Create 4.0
mps_image_augmentation.hpp
1 /* Copyright © 2018 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 
7 #ifndef TURI_NEURAL_NET_CORE_IMAGE_AUGMENTATION_HPP_
8 #define TURI_NEURAL_NET_CORE_IMAGE_AUGMENTATION_HPP_
9 
#include <functional>
#include <vector>

#include <ml/neural_net/image_augmentation.hpp>

#import <CoreImage/CoreImage.h>
#import <ml/neural_net/TCMPSImageAugmenting.h>
14 
namespace turi {
namespace neural_net {

/**
 * Implementation of image_augmenter that uses Core Image, for use with the
 * MPS-based neural-net backend.
 *
 * The augmentation pipeline is held as an array of objects conforming to
 * TCMPSImageAugmenting, together with a matching array of random-number
 * generators and a CIContext. NOTE(review): presumably the augmentations are
 * applied in array order and the CIContext renders the resulting CIImages --
 * confirm against the .mm implementation.
 */
class API_AVAILABLE(macos(10.13)) mps_image_augmenter : public image_augmenter {
public:
  /**
   * Creates an augmenter configured by `opts`.
   *
   * NOTE(review): presumably uses a default random number generator; confirm
   * in the implementation.
   */
  explicit mps_image_augmenter(const options& opts);

  /**
   * Variant constructor allowing injection of the random number generator,
   * largely for testing.
   *
   * @param opts Augmentation options.
   * @param rng  Callable taking (lower_bound, upper_bound) and returning a
   *             float. Presumably expected to sample uniformly from that
   *             range -- TODO confirm.
   */
  mps_image_augmenter(
      const options& opts,
      std::function<float(float lower_bound, float upper_bound)> rng);

  /** Returns the options held by this augmenter. */
  const options& get_options() const override { return opts_; }

  /**
   * Overrides image_augmenter::prepare_images. Takes the batch by value and
   * returns a `result`; presumably applies the configured augmentations to
   * each labeled image -- see the base-class contract.
   */
  result prepare_images(std::vector<labeled_image> source_batch) override;

private:
  // Constructor taking the per-augmentation random number generators
  // directly. NOTE(review): presumably the public constructors delegate here
  // after building `rng_batch` -- confirm in the implementation.
  mps_image_augmenter(const options& opts,
                      NSArray<TCMPSUniformRandomNumberGenerator> *rng_batch);

  // Options returned by get_options().
  options opts_;

  // Core Image context owned by this augmenter.
  CIContext *context_ = nil;

  // Ordered collection of augmentation steps.
  NSArray<id<TCMPSImageAugmenting>> *augmentations_ = nil;

  // Random number generators consumed by the augmentation steps.
  NSArray<TCMPSUniformRandomNumberGenerator> *rng_batch_ = nil;
};

}  // namespace neural_net
}  // namespace turi
50 
51 #endif // TURI_NEURAL_NET_CORE_IMAGE_AUGMENTATION_HPP_