Turi Create  4.0
st_resnet16_model_trainer.hpp
/* Copyright © 2020 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at
 * https://opensource.org/licenses/BSD-3-Clause
 */

8 #ifndef TOOLKITS_STYLE_TRANSFER_ST_RESNET16_MODEL_TRAINER_HPP_
9 #define TOOLKITS_STYLE_TRANSFER_ST_RESNET16_MODEL_TRAINER_HPP_
10 
12 
13 namespace turi {
14 namespace style_transfer {
15 
16 /**
17  * Subclass of Checkpoint that generates ResNet16ModelTrainer instances.
18  */
20  public:
21  /**
22  * Loads a pretrained model to use as a starting point.
23  */
24  ResNet16Checkpoint(Config config, const std::string& resnet_mlmodel_path);
25 
26  /** Loads weights saved from a ResNet16ModelTrainer. */
27  ResNet16Checkpoint(Config config, neural_net::float_array_map weights);
28 
29  std::unique_ptr<ModelTrainer> CreateModelTrainer() const override;
30 
31  neural_net::model_spec ExportToCoreML() const override;
32 };
33 
34 /** Subclass of ModelTrainer encapsulating the resnet-16 architecture. */
36  public:
37  /**
38  * Initializes a model from a checkpoint.
39  */
40  ResNet16ModelTrainer(Config config, neural_net::float_array_map weights);
41 
42  bool SupportsLossComponents() const override { return false; }
43 
44  /** Returns a publisher that can be used to request checkpoints. */
45  std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>>
46  AsCheckpointPublisher() override;
47 
48  protected:
49  std::shared_ptr<neural_net::model_backend> CreateTrainingBackend(
50  const std::string& vgg_mlmodel_path,
51  neural_net::compute_context* context) override;
52 
53  std::shared_ptr<neural_net::model_backend> CreateInferenceBackend(
54  neural_net::compute_context* context) override;
55 
56  private:
57  struct ModelState {
58  // Non-null if a training backend has been created.
59  std::shared_ptr<neural_net::model_backend> training_backend;
60 
61  // Only used until a training backend is created.
62  neural_net::float_array_map weights;
63  };
64 
65  static neural_net::float_array_map GetWeights(const ModelState& state);
66 
67  // This state is shared with the publishers we create
68  std::shared_ptr<ModelState> state_;
69 };
70 
71 } // namespace style_transfer
72 } // namespace turi
73 
74 #endif // TOOLKITS_STYLE_TRANSFER_ST_RESNET16_MODEL_TRAINER_HPP_
neural_net::model_spec ExportToCoreML() const override
std::unique_ptr< ModelTrainer > CreateModelTrainer() const override
ResNet16Checkpoint(Config config, const std::string &resnet_mlmodel_path)