7 #ifndef TURI_NEURAL_NET_CORE_IMAGE_AUGMENTATION_HPP_ 8 #define TURI_NEURAL_NET_CORE_IMAGE_AUGMENTATION_HPP_ 10 #include <ml/neural_net/image_augmentation.hpp> 12 #import <CoreImage/CoreImage.h> 13 #import <ml/neural_net/TCMPSImageAugmenting.h> 16 namespace neural_net {
22 class API_AVAILABLE(macos(10.13)) mps_image_augmenter:
public image_augmenter {
/// Constructs an augmenter from `opts`.
/// Declared `explicit` so an `options` value never converts implicitly
/// into an augmenter.
explicit mps_image_augmenter(const options& opts);
31 std::function<
float(
float lower_bound,
float upper_bound)> rng);
/// Returns a const reference to the stored `opts_` (overrides the base
/// `image_augmenter::get_options`).
/// NOTE(review): presumably these are the options passed at construction;
/// `opts_` is not declared in the visible part of this header, so confirm.
const options& get_options() const override {
  return opts_;
}
/// Overrides `image_augmenter::prepare_images`; receives the batch by value.
/// NOTE(review): the implementation lives outside this header, and it
/// presumably applies the configured augmentations to each labeled image.
/// Confirm this in the corresponding .mm.
result prepare_images(std::vector<labeled_image> source_batch) override;
/// Constructs an augmenter from `opts` plus an array of
/// `TCMPSUniformRandomNumberGenerator` objects.
/// NOTE(review): presumably `rng_batch` is kept in `rng_batch_` and drives the
/// random choices of each augmentation; confirm in the implementation.
mps_image_augmenter(const options& opts,
                    NSArray<TCMPSUniformRandomNumberGenerator> *rng_batch);
// Core Image context; nil until assigned.
CIContext *context_ = nil;

// Objects conforming to TCMPSImageAugmenting; nil until assigned.
NSArray<id<TCMPSImageAugmenting>> *augmentations_ = nil;

// Random number generators; nil until assigned.
NSArray<TCMPSUniformRandomNumberGenerator> *rng_batch_ = nil;
51 #endif // TURI_NEURAL_NET_CORE_IMAGE_AUGMENTATION_HPP_