6 #ifndef TURI_LSH_FAMILY_H_ 7 #define TURI_LSH_FAMILY_H_ 10 #include <toolkits/nearest_neighbors/nearest_neighbors.hpp> 11 #include <toolkits/nearest_neighbors/hash_map_container.hpp> 14 namespace nearest_neighbors {
// Abstract base for the locality-sensitive-hashing (LSH) families used by the
// nearest-neighbors toolkit. Subclasses supply is_asymmetric(),
// distance_type_name(), init_model() and (optionally) pre_lsh() and the
// hash_vector_to_codes() overloads. The base owns one lookup table per hash
// table, mapping a hash-bucket id to the reference ids stored in that bucket.
// NOTE(review): this is a source-browser extraction — access specifiers, some
// members (e.g. num_tables, which query() reads) and the closing `};` are not
// visible here.
17 class EXPORT lsh_family {
20 lsh_family() =
default;
21 virtual ~lsh_family() {}
// Factory: returns the concrete family for a distance name such as "cosine"
// or "jaccard"; unsupported names are reported through log_and_throw.
24 inline static std::shared_ptr<lsh_family> create_lsh_family(
const std::string& dist_name);
// Whether the family hashes reference and query data asymmetrically.
// Presumably consumed via the is_reference_data flag of hash_vector_to_codes
// — confirm in the implementations.
27 virtual bool is_asymmetric()
const = 0;
// Name of the distance this family approximates (e.g. "dot_product").
30 virtual std::string distance_type_name()
const = 0;
// Reads the user-supplied option map; definition not visible here.
33 virtual void init_options(
const std::map<std::string, flexible_type>& _opts);
// Optional pass over the reference ml_data before hashing. The base version
// is a no-op; is_sparse presumably selects the sparse code path — confirm.
38 virtual void pre_lsh(
const v2::ml_data& mld_ref,
bool is_sparse) {}
// Family-specific model setup for inputs of num_dimensions features
// (presumably draws the random state held by subclasses — confirm).
41 virtual void init_model(
size_t num_dimensions) = 0;
// Templated on the vector type T (the `template <typename T>` line is not
// visible here). Checks the dimension, hashes t with is_reference_data=true
// and files ref_id under the matching bucket of each lookup table.
46 void add_reference_data(
size_t ref_id,
const T& t);
// Templated on T. Hashes t with is_reference_data=false and returns the
// de-duplicated reference ids sharing a bucket with t in any table.
// The order of the returned ids is unspecified.
51 std::vector<size_t> query(
const T& t)
const;
// Map a dense / sparse vector to its integer hash codes. Callers slice the
// result into consecutive groups of num_projections_per_table codes, one
// group per table (the last group is clamped at num_projections).
59 virtual std::vector<int> hash_vector_to_codes(
const DenseVector& vec,
60 bool is_reference_data)
const;
61 virtual std::vector<int> hash_vector_to_codes(
const SparseVector& vec,
62 bool is_reference_data)
const;
// Expected vec.size(); checked by add_reference_data() and query().
65 size_t num_input_dimensions;
// Number of hash codes combined into one bucket id per table.
67 size_t num_projections_per_table;
// Total number of hash codes; upper bound for the last table's slice.
68 size_t num_projections;
// One container per table: bucket id -> reference ids hashed into it.
69 std::vector<hash_map_container<size_t, std::vector<size_t>>> lookup_table;
// NOTE(review): the class header is not visible in this extraction. Judging by
// its members (rand_mat, rand_vec) and the following "squared_euclidean"
// override, this is presumably the Euclidean LSH family — confirm.
// Symmetric: reference and query data are treated alike.
79 virtual bool is_asymmetric()
const {
return false; }
// NOTE(review): the returned name and closing brace are missing here.
81 virtual std::string distance_type_name()
const {
// Overrides the base no-op pre-pass over the reference data.
87 virtual void pre_lsh(
const v2::ml_data& mld_ref,
bool is_sparse);
89 virtual void init_model(
size_t num_dimensions);
// Dense and sparse hashing overloads (definitions not visible).
95 virtual std::vector<int> hash_vector_to_codes(
const DenseVector& vec,
96 bool is_reference_data)
const;
97 virtual std::vector<int> hash_vector_to_codes(
const SparseVector& vec,
98 bool is_reference_data)
const;
// Random state, presumably set by init_model(): a projection matrix and an
// offset vector — TODO confirm.
102 DenseMatrix rand_mat;
103 DenseVector rand_vec;
// Reports "squared_euclidean". No other override is visible for this class,
// which suggests it reuses its base family's hashing — confirm against the
// (missing) class header.
113 std::string distance_type_name()
const {
114 return "squared_euclidean";
// NOTE(review): the class header and returned distance name are missing here.
// By position this is presumably the Manhattan family, which the factory lists
// between "squared_euclidean" and "cosine" — confirm.
125 std::string distance_type_name()
const {
// Family-specific reference-data pre-pass and model setup.
131 void pre_lsh(
const v2::ml_data& mld_ref,
bool is_sparse);
133 void init_model(
size_t num_dimensions);
// NOTE(review): the class header is not visible. By position this is presumably
// the cosine family (the factory constructs lsh_cosine) — confirm.
// Symmetric: reference and query data are treated alike.
142 bool is_asymmetric()
const {
return false; }
// NOTE(review): the returned name and closing brace are missing here.
144 std::string distance_type_name()
const {
148 void init_model(
size_t num_dimensions);
// Dense and sparse hashing overloads (definitions not visible).
154 std::vector<int> hash_vector_to_codes(
const DenseVector& vec,
155 bool is_reference_data)
const;
156 std::vector<int> hash_vector_to_codes(
const SparseVector& vec,
157 bool is_reference_data)
const;
// Random projection matrix, presumably drawn in init_model() — TODO confirm.
160 DenseMatrix rand_mat;
// NOTE(review): the class header is not visible. By position and members this
// is presumably the Jaccard (permutation-based) family built by the factory as
// lsh_jaccard — confirm.
// Symmetric: reference and query data are treated alike.
169 bool is_asymmetric()
const {
return false; }
// NOTE(review): the returned name and closing brace are missing here.
171 std::string distance_type_name()
const {
175 void init_model(
size_t num_dimensions);
// Dense and sparse hashing overloads (definitions not visible).
181 std::vector<int> hash_vector_to_codes(
const DenseVector& vec,
182 bool is_reference_data)
const;
183 std::vector<int> hash_vector_to_codes(
const SparseVector& vec,
184 bool is_reference_data)
const;
// Rewrites the code vector in place; presumably fills bins that received no
// element — confirm in the definition.
187 void fill_empty_bins(std::vector<int>& vec)
const;
// Random state, presumably set by init_model().
// NOTE(review): rand_sign is stored as size_t, so it cannot hold -1;
// presumably a 0/1 encoding — confirm.
191 std::vector<size_t> rand_permutation;
192 std::vector<size_t> rand_sign;
// Dot-product family (class header not visible in this extraction).
// The only family here that reports itself asymmetric.
201 bool is_asymmetric()
const {
return true; }
203 virtual std::string distance_type_name()
const {
204 return "dot_product";
207 void init_model(
size_t num_dimensions);
// Overrides the base no-op pre-pass over the reference data.
209 void pre_lsh(
const v2::ml_data& mld_ref,
bool is_sparse);
// Dense and sparse hashing overloads; is_reference_data presumably selects the
// reference vs. query transform of this asymmetric family — confirm.
215 std::vector<int> hash_vector_to_codes(
const DenseVector& vec,
216 bool is_reference_data)
const;
217 std::vector<int> hash_vector_to_codes(
const SparseVector& vec,
218 bool is_reference_data)
const;
// Random state, presumably set by init_model() — TODO confirm.
222 DenseMatrix rand_mat;
223 DenseVector rand_vec;
// Reports "transformed_dot_product". Only this override is visible for the
// class the factory builds as lsh_transformed_dot_product.
227 std::string distance_type_name()
const {
228 return "transformed_dot_product";
// Factory mapping a distance name to a heap-allocated LSH family.
// Unsupported names fall through to log_and_throw with a descriptive message.
// NOTE(review): the return statements for "euclidean", "squared_euclidean",
// "manhattan" and "dot_product" are not visible in this extraction.
// NOTE(review): std::make_shared<...>() would save one allocation per family
// compared with shared_ptr(new ...).
232 std::shared_ptr<lsh_family> lsh_family::create_lsh_family(
const std::string& dist_name) {
233 if (dist_name ==
"euclidean") {
235 }
else if (dist_name ==
"squared_euclidean") {
237 }
else if (dist_name ==
"manhattan") {
239 }
else if (dist_name ==
"cosine") {
240 return std::shared_ptr<lsh_family>(
new lsh_cosine);
241 }
else if (dist_name ==
"jaccard") {
242 return std::shared_ptr<lsh_family>(
new lsh_jaccard);
243 }
else if (dist_name ==
"dot_product") {
245 }
else if (dist_name ==
"transformed_dot_product") {
246 return std::shared_ptr<lsh_family>(
new lsh_transformed_dot_product);
// Any other name is rejected.
248 log_and_throw(dist_name + std::string(
" is not supported by LSH! Try another distance or method!"));
// Inserts one reference vector into every lookup table.
// Parameters: ref_id is the id recorded in the buckets; vec is a vector whose
// size() must equal num_input_dimensions (enforced by ASSERT_MSG).
253 template <
typename T>
254 void lsh_family::add_reference_data(
size_t ref_id,
const T& vec) {
256 ASSERT_MSG(
size_t(vec.size()) == num_input_dimensions,
257 "The input dimension does not match the previous ones!");
// Reference-side hashing (is_reference_data = true).
259 auto hash_vec = hash_vector_to_codes(vec,
true);
// NOTE(review): the per-table loop header over table_idx is not visible here.
// Each table hashes its own slice of the codes into a single bucket id; the
// end of the slice is clamped to num_projections.
263 auto hash_bucket_id = boost::hash_range(
264 hash_vec.begin() + table_idx * num_projections_per_table,
265 hash_vec.begin() + std::min((table_idx + 1) * num_projections_per_table, num_projections));
// Update that bucket's id list in place; the lambda body (presumably appending
// ref_id) is not visible in this extraction.
267 lookup_table[table_idx].update(hash_bucket_id, [ref_id](std::vector<size_t>& v){
// Returns candidate reference ids for vec: the union, over all num_tables
// lookup tables, of the ids stored in the bucket vec hashes into.
// vec.size() must equal num_input_dimensions (enforced by ASSERT_MSG).
// The result has no duplicates, and its order is unspecified because it is
// copied out of a std::unordered_set.
// NOTE(review): callers that need deterministic output should sort it.
273 template <
typename T>
274 std::vector<size_t> lsh_family::query(
const T& vec)
const {
276 ASSERT_MSG(
size_t(vec.size()) == num_input_dimensions,
277 "The input num_dimensions does not match the reference data!");
// Set used to de-duplicate ids found in several tables.
279 std::unordered_set<size_t> ret;
// Query-side hashing (is_reference_data = false).
280 auto hash_vec = hash_vector_to_codes(vec,
false);
282 for (
size_t table_idx = 0; table_idx < num_tables; ++table_idx) {
// Same slicing and bucket-id scheme as add_reference_data.
283 auto hash_bucket_id = boost::hash_range(
284 hash_vec.begin() + table_idx * num_projections_per_table,
285 hash_vec.begin() + std::min((table_idx + 1) * num_projections_per_table, num_projections));
287 const auto& candidates = lookup_table[table_idx].get(hash_bucket_id);
288 ret.insert(candidates.cbegin(), candidates.cend());
// NOTE(review): the loop's closing brace and the final `return` are not
// visible in this extraction.
291 std::vector<size_t> ret_vec(ret.begin(), ret.end());
void parallel_for(size_t begin, size_t end, const FunctionType &fn)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
#define DASSERT_TRUE(cond)