turicreate.drawing_classifier.create
turicreate.drawing_classifier.create(input_dataset, target, feature=None, validation_set='auto', warm_start='auto', batch_size=256, max_iterations=500, verbose=True, random_seed=None)

Create a DrawingClassifier model.

Parameters:
- input_dataset : SFrame
Input data. The columns named by the feature and target parameters will be extracted for training the drawing classifier.
- target : string
Name of the column containing the target variable. The values in this column must be of string or integer type.
- feature : string, optional
Name of the column containing the input drawings. The feature column can contain either bitmap-based drawings or stroke-based drawings. Bitmap-based drawing input can be a grayscale tc.Image of any size. Stroke-based drawing input must use the following format: every drawing is a list of strokes, each stroke is a list of points in the order in which they were drawn on the canvas, and each point is a dictionary with two keys, “x” and “y”, whose values are numeric (integer or float).
- validation_set : SFrame, optional
A dataset for monitoring the model’s generalization performance. The format of this SFrame must be the same as the training set. By default this argument is set to ‘auto’, and a validation set is automatically sampled and used for progress printing. If validation_set is set to None, no additional metrics are computed.
- warm_start : string, optional
A string denoting which pretrained model to use. Defaults to “auto”, which uses a model trained on 245 of the 345 classes in the Quick, Draw! dataset. To disable warm start, pass None. The accepted values are:
  - “auto”: Uses quickdraw_245_v0.
  - “quickdraw_245_v0”: Uses a model trained on 245 of the 345 classes in the Quick, Draw! dataset.
  - None: No warm start.
- batch_size : int, optional
The number of drawings per training step. Defaults to 256. If you get out-of-memory errors, try decreasing this value; on a more powerful machine, increasing it may improve training throughput.
- max_iterations : int, optional
The maximum number of allowed passes through the data. Defaults to 500. More passes over the data can result in a more accurately trained model.
- verbose : bool, optional
If True, print progress updates and model details.
- random_seed : int, optional
Seed for random number generation. Training with the same seed produces reproducible results.
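The stroke-based format described under feature can be checked in pure Python before training. The sketch below is illustrative only: is_valid_stroke_drawing is a hypothetical helper, not part of Turi Create, and Turi Create's own input validation may accept or reject different inputs. It checks only the shape and value types, not whether strokes are empty.

```python
def is_valid_stroke_drawing(drawing):
    """Hypothetical helper (not a Turi Create API).

    Return True if `drawing` is a list of strokes, each stroke is a list of
    points, and each point is a dict with exactly the keys "x" and "y" whose
    values are int or float.
    """
    if not isinstance(drawing, list):
        return False
    for stroke in drawing:
        if not isinstance(stroke, list):
            return False
        for point in stroke:
            if not isinstance(point, dict) or set(point) != {"x", "y"}:
                return False
            for value in point.values():
                # bool is a subclass of int in Python, so reject it explicitly
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    return False
    return True


# A two-stroke drawing of an "L" shape, points listed in drawing order
drawing = [
    [{"x": 0, "y": 0}, {"x": 0, "y": 10}],
    [{"x": 0, "y": 10}, {"x": 5.5, "y": 10}],
]
print(is_valid_stroke_drawing(drawing))       # True
print(is_valid_stroke_drawing([[{"x": 1}]]))  # False: the point has no "y"
```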
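The warm_start values listed above can be summarized as a small lookup table. This is a sketch that mirrors the documented options only; resolve_warm_start and _WARM_START_ALIASES are hypothetical names rather than Turi Create internals, and raising ValueError for an unknown value is an assumption about reasonable behavior, not documented library behavior.

```python
# Hypothetical lookup mirroring the documented warm_start options.
# This is NOT Turi Create's internal code.
_WARM_START_ALIASES = {
    "auto": "quickdraw_245_v0",
    "quickdraw_245_v0": "quickdraw_245_v0",
}


def resolve_warm_start(warm_start="auto"):
    """Return the pretrained model name selected by `warm_start`,
    or None when warm start is disabled."""
    if warm_start is None:
        return None
    try:
        return _WARM_START_ALIASES[warm_start]
    except KeyError:
        # Assumption: unknown values are treated as errors
        raise ValueError(
            "Unsupported warm_start %r; expected one of %s or None"
            % (warm_start, sorted(_WARM_START_ALIASES))
        ) from None


print(resolve_warm_start())      # quickdraw_245_v0
print(resolve_warm_start(None))  # None
```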
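To see how batch_size and max_iterations interact, the arithmetic below counts training steps. It is a back-of-the-envelope sketch that takes the descriptions above at face value: each iteration is assumed to be one full pass over the data in batches of batch_size, with a final partial batch counted as one step. training_steps is a hypothetical helper, and because max_iterations is a maximum, the total is an upper bound rather than what Turi Create will necessarily run.

```python
import math


def training_steps(num_drawings, batch_size=256, max_iterations=500):
    """Hypothetical helper: return (steps per pass, upper bound on total steps),
    assuming every pass visits each drawing once in batches of `batch_size`
    and a final partial batch counts as a full step."""
    if num_drawings <= 0 or batch_size <= 0 or max_iterations < 0:
        raise ValueError(
            "num_drawings and batch_size must be positive; "
            "max_iterations must be non-negative"
        )
    steps_per_pass = math.ceil(num_drawings / batch_size)
    return steps_per_pass, steps_per_pass * max_iterations


# 10,000 drawings with the default batch_size (256) and max_iterations (500)
per_pass, total = training_steps(10_000)
print(per_pass, total)  # 40 20000
```

Halving batch_size roughly doubles the number of steps per pass, which is why lowering it reduces memory use per step at the cost of more steps.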
Returns:
- out : DrawingClassifier
A trained DrawingClassifier model.
Examples
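# Train a drawing classifier model ('label' is a hypothetical name for the target column in data)
>>> import turicreate
>>> model = turicreate.drawing_classifier.create(data, 'label')

# Make predictions on the training set and add them as a column to the SFrame
>>> data['predictions'] = model.predict(data)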