Turi Create  4.0
turi::neural_net::tf_compute_context Class Reference

#include <ml/neural_net/tf_compute_context.hpp>

Public Types

using factory = std::function< std::unique_ptr< compute_context >()>
 

Public Member Functions

 tf_compute_context ()
 
void print_training_device_info () const override
 
size_t memory_budget () const override
 
std::unique_ptr< model_backend > create_object_detector (int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out, const float_array_map &config, const float_array_map &weights) override
 
std::unique_ptr< model_backend > create_activity_classifier (const ac_parameters &ac_params) override
 
std::unique_ptr< model_backend > create_style_transfer (const float_array_map &config, const float_array_map &weights) override
 
std::unique_ptr< model_backend > create_drawing_classifier (const float_array_map &weights, size_t batch_size, size_t num_classes) override
 
std::unique_ptr< image_augmenter > create_image_augmenter (const image_augmenter::options &opts) override
 
virtual std::unique_ptr< turi::neural_net::model_backend > create_multilayer_perceptron_classifier (int n, int c_in, int c_out, const std::vector< size_t > &layer_sizes, const turi::neural_net::float_array_map &config)
 

Static Public Member Functions

static std::unique_ptr< compute_context > create ()
 

Detailed Description

A compute_context implementation backed by TensorFlow for neural network computation and for data augmentation.

Definition at line 23 of file tf_compute_context.hpp.

Member Typedef Documentation

◆ factory

using turi::neural_net::compute_context::factory = std::function<std::unique_ptr<compute_context>()>
[inherited]

Function that yields a compute context.

Definition at line 58 of file compute_context.hpp.

Constructor & Destructor Documentation

◆ tf_compute_context()

turi::neural_net::tf_compute_context::tf_compute_context ( )

Constructs a context wrapping devices.

Member Function Documentation

◆ create()

static std::unique_ptr<compute_context> turi::neural_net::compute_context::create ( )
[static, inherited]

Requests a compute_context from each registered compute_context::factory, in ascending order by "priority", until one returns a non-null context. Factories should be registered so that this function yields a backend appropriate to the current platform and hardware.

◆ create_activity_classifier()

std::unique_ptr<model_backend> turi::neural_net::tf_compute_context::create_activity_classifier ( const ac_parameters &  ac_params)
[override, virtual]

Creates an activity classification network.

Reimplemented from turi::neural_net::compute_context.

◆ create_drawing_classifier()

std::unique_ptr<model_backend> turi::neural_net::tf_compute_context::create_drawing_classifier ( const float_array_map &  weights,
size_t  batch_size,
size_t  num_classes 
)
[override, virtual]

Creates a drawing classification network.

Reimplemented from turi::neural_net::compute_context.

◆ create_image_augmenter()

std::unique_ptr<image_augmenter> turi::neural_net::tf_compute_context::create_image_augmenter ( const image_augmenter::options &  opts)
[override, virtual]

Creates an image augmenter.

Reimplemented from turi::neural_net::compute_context.

◆ create_multilayer_perceptron_classifier()

virtual std::unique_ptr<turi::neural_net::model_backend> turi::neural_net::compute_context::create_multilayer_perceptron_classifier ( int  n,
int  c_in,
int  c_out,
const std::vector< size_t > &  layer_sizes,
const turi::neural_net::float_array_map &  config 
)
[inline, virtual, inherited]

Creates a multilayer perceptron classifier.

Definition at line 191 of file compute_context.hpp.

◆ create_object_detector()

std::unique_ptr<model_backend> turi::neural_net::tf_compute_context::create_object_detector ( int  n,
int  c_in,
int  h_in,
int  w_in,
int  c_out,
int  h_out,
int  w_out,
const float_array_map &  config,
const float_array_map &  weights 
)
[override, virtual]

Creates an object detection network.

Reimplemented from turi::neural_net::compute_context.

◆ create_style_transfer()

std::unique_ptr<model_backend> turi::neural_net::tf_compute_context::create_style_transfer ( const float_array_map &  config,
const float_array_map &  weights 
)
[override, virtual]

Creates a style transfer network.

Reimplemented from turi::neural_net::compute_context.

◆ memory_budget()

size_t turi::neural_net::tf_compute_context::memory_budget ( ) const
[override, virtual]

Provides a measure of the memory resources available.

Returns the maximum memory size in bytes that neural networks should allocate, typically used to determine batch sizes (often heuristically).

Implements turi::neural_net::compute_context.

◆ print_training_device_info()

void turi::neural_net::tf_compute_context::print_training_device_info ( ) const
[override, virtual]

Prints (human readable) device information.

Implements turi::neural_net::compute_context.


The documentation for this class was generated from the following file: