Turi Create 4.0
loss_model_profiles.hpp
/* Copyright © 2017 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 */
#ifndef TURI_FACTORIZATION_LOSS_MODEL_PROFILES_H_
#define TURI_FACTORIZATION_LOSS_MODEL_PROFILES_H_

#include <core/logging/assertions.hpp>
#include <core/util/logit_math.hpp>
#include <core/util/code_optimization.hpp>
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>

namespace turi { namespace factorization {

/** The base class for the generative models. These models
 * encapsulate the part of the problem surrounding the translation of
 * the underlying linear model to the target/response variable. Thus
 * it encapsulates (1) the translation function from linear model to
 * response and (2) the loss function used to fit the coefficients of
 * the linear model to predict the response.
 *
 * To make reporting easier, reported_loss_value translates a cumulative
 * loss value -- the sum of loss(...) over all data points -- to a
 * standard error measure. Its name is returned along with the
 * reportable value.
 */
class loss_model_profile {
 public:
  virtual double loss(double fx, double y) const = 0;
  virtual double loss_grad(double fx, double y) const = 0;
  virtual double translate_fx_to_prediction(double f_x) const = 0;
  virtual bool prediction_is_translated() const = 0;
  virtual std::string reported_loss_name() const = 0;
  virtual double reported_loss_value(double cumulative_loss_value) const = 0;
};

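// Illustrative sketch, not part of the original header: one way a caller
// might drive the interface above to score a batch of (fx, y) pairs. The
// helper name and the raw-pointer signature are hypothetical; only the
// virtual methods declared on loss_model_profile are assumed.
inline std::pair<std::string, double> example_reported_loss(
    const loss_model_profile& profile, const double* fx, const double* y,
    size_t n) {
  // Sum the per-observation loss over all data points.
  double cumulative = 0;
  for(size_t i = 0; i < n; ++i)
    cumulative += profile.loss(fx[i], y[i]);

  // Translate the cumulative loss into the metric the profile reports,
  // paired with the metric's human-readable name.
  return {profile.reported_loss_name(),
          profile.reported_loss_value(cumulative)};
}
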
////////////////////////////////////////////////////////////////////////////////

/** Implements squared error loss for the linear models.
 */
class loss_squared_error final : public loss_model_profile {
 public:

  static std::string name() { return "squared_error"; }

  double loss(double fx, double y) const GL_HOT_INLINE_FLATTEN {
    return sq(fx - y);
  }

  double loss_grad(double fx, double y) const GL_HOT_INLINE_FLATTEN {
    return 2*(fx - y);
  }

  double translate_fx_to_prediction(double f_x) const GL_HOT_INLINE_FLATTEN {
    return f_x;
  }

  bool prediction_is_translated() const { return false; }

  std::string reported_loss_name() const { return "RMSE"; }

  double reported_loss_value(double cumulative_loss_value) const {
    return std::sqrt(cumulative_loss_value);
  }
};

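// Worked note, not part of the original header: for this profile,
// loss(3.0, 2.5) == 0.25, loss_grad(3.0, 2.5) == 1.0, and
// reported_loss_value(c) == std::sqrt(c), so the reported "RMSE" figure is
// the square root of whatever cumulative value the caller supplies (a true
// RMSE when that value is the mean of the squared errors).
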
////////////////////////////////////////////////////////////////////////////////

void _logistic_loss_value_is_bad(double) __attribute__((noinline, cold));

/** Implements logistic loss for the linear models.
 */
class loss_logistic final : public loss_model_profile {
 public:

  static std::string name() { return "logistic"; }

  double loss(double fx, double y) const GL_HOT_INLINE_FLATTEN {
    if(y < 0 || y > 1.0)
      _logistic_loss_value_is_bad(y);

    return (1 - y) * fx + log1pen(fx);
  }

  double loss_grad(double fx, double y) const GL_HOT_INLINE_FLATTEN {
    return (1 - y) + log1pen_deriviative(fx);
  }

  double translate_fx_to_prediction(double fx) const GL_HOT_INLINE_FLATTEN {
    return sigmoid(fx);
  }

  bool prediction_is_translated() const { return true; }

  std::string reported_loss_name() const { return "Predictive Error"; }

  double reported_loss_value(double cumulative_loss_value) const {
    return cumulative_loss_value;
  }
};

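// Worked derivation, not part of the original header, assuming
// log1pen(x) == log(1 + exp(-x)) and sigmoid(x) == 1 / (1 + exp(-x)):
//
//   loss(fx, y) = (1 - y) * fx + log(1 + exp(-fx))
//               = -[ y * log(sigmoid(fx)) + (1 - y) * log(1 - sigmoid(fx)) ],
//
// i.e. the Bernoulli negative log-likelihood with P(y = 1) = sigmoid(fx), and
//
//   loss_grad(fx, y) = (1 - y) + (sigmoid(fx) - 1) = sigmoid(fx) - y,
//
// the familiar logistic-regression gradient with respect to fx.
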
////////////////////////////////////////////////////////////////////////////////

/** Implements hinge ranking loss for the model. The loss is applied to
 * the difference fx_diff between two scores and penalizes differences
 * that fall below a margin of 1.
 */
class loss_ranking_hinge final : public loss_model_profile {
 public:

  static std::string name() { return "hinge_ranking"; }

  double loss(double fx_diff, double) const GL_HOT_INLINE_FLATTEN {
    return std::max(0.0, 1 - fx_diff);
  }

  double loss_grad(double fx_diff, double) const GL_HOT_INLINE_FLATTEN {
    return (fx_diff < 1) ? -1 : 0;
  }

  double translate_fx_to_prediction(double fx) const GL_HOT_INLINE_FLATTEN {
    return fx;
  }

  bool prediction_is_translated() const { return false; }

  std::string reported_loss_name() const { return "Hinge Loss"; }

  double reported_loss_value(double cumulative_loss_value) const {
    return cumulative_loss_value;
  }
};

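// Worked note, not part of the original header: with the margin of 1,
// loss(1.5, 0) == 0 (the pair is already separated by more than the margin),
// loss(0.25, 0) == 0.75, and loss(-2.0, 0) == 3.0. The gradient with respect
// to fx_diff is -1 whenever fx_diff < 1 and 0 otherwise, so only pairs inside
// the margin push the scores apart; the second argument is ignored.
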
/** Implements logistic ranking loss for the model; a smooth alternative
 * to the hinge ranking loss above, applied to the same score difference.
 */
class loss_ranking_logit final : public loss_model_profile {
 public:

  static std::string name() { return "logit rank"; }

  double loss(double fx_diff, double) const GL_HOT_INLINE_FLATTEN {
    return log1pen(fx_diff);
  }

  double loss_grad(double fx_diff, double) const GL_HOT_INLINE_FLATTEN {
    return log1pen_deriviative(fx_diff);
  }

  double translate_fx_to_prediction(double fx) const GL_HOT_INLINE_FLATTEN {
    return sigmoid(fx);
  }

  bool prediction_is_translated() const { return true; }

  std::string reported_loss_name() const { return "Logistic Rank Loss"; }

  double reported_loss_value(double cumulative_loss_value) const {
    return cumulative_loss_value;
  }
};

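// Worked note, not part of the original header, again assuming
// log1pen(x) == log(1 + exp(-x)): the loss here is log(1 + exp(-fx_diff)),
// an everywhere-differentiable counterpart of the hinge ranking loss that
// still decreases monotonically as the score difference grows.
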

/// A quick helper function to retrieve the correct profile by name.
std::shared_ptr<loss_model_profile> get_loss_model_profile(
    const std::string& name);

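// Illustrative sketch, not part of the original header: retrieving a profile
// by name and evaluating one prediction through the shared base-class
// interface. Which names the factory accepts is an assumption here, inferred
// from the static name() functions above.
inline double example_logistic_prediction(double fx) {
  std::shared_ptr<loss_model_profile> profile =
      get_loss_model_profile(loss_logistic::name());
  // For the logistic profile, predictions are translated through sigmoid.
  return profile->translate_fx_to_prediction(fx);
}
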
}}

#endif