turicreate.drawing_classifier.create

turicreate.drawing_classifier.create(input_dataset, target, feature=None, validation_set='auto', warm_start='auto', batch_size=256, max_iterations=500, verbose=True, random_seed=None)

Create a DrawingClassifier model.

Parameters:
input_dataset : SFrame

Input data. The columns named by the feature and target parameters will be extracted for training the drawing classifier.

target : string

Name of the column containing the target variable. The values in this column must be of string or integer type.

feature : string, optional

Name of the column containing the input drawings. The column can contain either bitmap-based or stroke-based drawings. A bitmap-based drawing can be a grayscale tc.Image of any size. A stroke-based drawing must be a list of strokes, where each stroke is a list of points in the order in which they were drawn on the canvas. Each point must be a dictionary with two keys, “x” and “y”, whose values are numeric (integer or float).
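The stroke-based format described above can be sketched in plain Python. The helper name is_valid_stroke_drawing is hypothetical and not part of Turi Create; it only mirrors the rules stated in this parameter description, and the sample coordinates are made up for illustration.

```python
from numbers import Real


def is_valid_stroke_drawing(drawing):
    """Check the documented shape: a list of strokes, each a list of
    {"x": number, "y": number} points. Hypothetical helper, not a
    Turi Create function."""
    if not isinstance(drawing, list):
        return False
    for stroke in drawing:
        if not isinstance(stroke, list):
            return False
        for point in stroke:
            # Each point must be a dict with exactly the keys "x" and "y".
            if not isinstance(point, dict) or set(point) != {"x", "y"}:
                return False
            # Values must be integers or floats (bool is excluded on purpose).
            for key in ("x", "y"):
                value = point[key]
                if isinstance(value, bool) or not isinstance(value, Real):
                    return False
    return True


# One drawing made of two strokes; points are listed in drawing order.
drawing = [
    [{"x": 10, "y": 20}, {"x": 12.5, "y": 24.0}, {"x": 15, "y": 30}],
    [{"x": 40, "y": 5}, {"x": 40, "y": 35}],
]
```

A column of such drawings, one list per row, is what the feature parameter would point to for stroke-based input.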

validation_set : SFrame, optional

A dataset for monitoring the model’s generalization performance. The format of this SFrame must be the same as the training set. By default, this argument is set to ‘auto’, and a validation set is automatically sampled and used for progress printing. If validation_set is set to None, no additional metrics are computed.

warm_start : string, optional

A string to denote which pretrained model to use. Set to “auto” by default, which uses a model trained on 245 of the 345 classes in the Quick, Draw! dataset. To disable warm start, pass in None to this argument. The following values can be passed in as this argument:

“auto”: Uses quickdraw_245_v0.

“quickdraw_245_v0”: Uses a model trained on 245 of the 345 classes in the Quick, Draw! dataset.

None: No warm start.

batch_size : int, optional

The number of drawings per training step. If not set, a default value of 256 will be used. If you are getting memory errors, try decreasing this value. If you have a powerful computer, increasing this value may improve performance.

max_iterations : int, optional

The maximum number of allowed passes through the data. More passes over the data can result in a more accurately trained model. The default is 500.

verbose : bool, optional

If True, print progress updates and model details.

random_seed : int, optional

Seed for random number generation. Training with the same seed makes the results reproducible.

Returns:
out : DrawingClassifier

A trained DrawingClassifier model.

Examples

# Train a drawing classifier model ('label' is an assumed name for the target column)
>>> model = turicreate.drawing_classifier.create(data, target='label')

# Make predictions on the training set and add them as a column to the SFrame
>>> data['predictions'] = model.predict(data)
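
As a fuller illustration, the sketch below passes most of the documented parameters explicitly and scores the model on held-out data. The function name train_and_evaluate and the column names "drawing" and "label" are assumptions made for this example, and the specific values for batch_size and max_iterations are arbitrary. It also assumes that evaluate() returns a dictionary with an "accuracy" entry.

```python
def train_and_evaluate(data, feature="drawing", target="label"):
    """Train on roughly 80% of `data` and return the model together with
    its accuracy on the remaining rows. Hypothetical helper for illustration."""
    # Imported lazily so this sketch can be loaded without Turi Create installed.
    import turicreate as tc

    # Hold out part of the SFrame so evaluation uses unseen drawings.
    train, test = data.random_split(0.8, seed=42)

    model = tc.drawing_classifier.create(
        train,
        target=target,          # column with string or integer labels
        feature=feature,        # column with bitmap or stroke-based drawings
        validation_set=None,    # skip the automatically sampled validation set
        warm_start="auto",      # start from the pretrained Quick, Draw! model
        batch_size=128,         # lowered from 256, e.g. to reduce memory use
        max_iterations=100,
        random_seed=42,         # fixed seed for reproducible results
    )

    # Assumption: the returned metrics dictionary includes "accuracy".
    metrics = model.evaluate(test)
    return model, metrics["accuracy"]
```

Calling train_and_evaluate(data) with an SFrame that has the assumed columns would return the trained model and a single accuracy figure.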