MATLAB Machine Learning

The Basics and a Quick Tutorial


MATLAB is a high-level language and interactive environment designed for numerical computation, visualization, and programming. Developed by MathWorks, MATLAB allows matrix manipulations, plotting of functions and data, implementation of algorithms, creation of user interfaces, and interfacing with programs written in other languages.

MATLAB stands for Matrix Laboratory, reflecting its strength in matrix operations, which are critical in scientific and engineering tasks. It's used widely in academia and industry for a range of applications, including video and image processing, control systems, signal processing and communication, testing and measurements, computational biology, and computational finance.

MATLAB provides a simple syntax and a desktop environment tuned for iterative analysis and design, which gives it a shorter learning curve than many general-purpose programming languages. Its extensive library of prebuilt functions also lets users build sophisticated programs without being expert programmers.

This is part of a series of articles about machine learning engineering.


Why Use MATLAB for Machine Learning?

Machine learning is a method of data analysis that automates analytical model building. It's a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns and make decisions with minimal human intervention. MATLAB has several benefits for machine learning applications:

  • Rich library of prebuilt functions and toolboxes: Toolboxes such as the Statistics and Machine Learning Toolbox and the Deep Learning Toolbox simplify model development, so users can focus on the problem rather than on low-level implementation details.
  • Robust visualization capabilities: Built-in plotting makes it possible to explore data, spot patterns and trends, validate models, and present results in a meaningful way.
  • Interface with other programming languages and hardware: MATLAB can call code written in Python, C/C++, and Java, exchange models with other machine learning frameworks (for example, through the ONNX format), and run training on specialized hardware such as graphics processing units (GPUs).

Key MATLAB Functions for Machine Learning Models


One of the key functions for MATLAB machine learning is fitcsvm, which trains a support vector machine (SVM) classifier for one-class or binary classification (multiclass problems use fitcecoc). An SVM finds the hyperplane that separates the two classes with the largest margin; with a kernel function such as the Gaussian (RBF) kernel, it can also learn non-linear decision boundaries.
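A minimal sketch of fitcsvm, assuming the Statistics and Machine Learning Toolbox and its bundled fisheriris sample data. It keeps only two of the three iris species so the problem is binary, uses the Gaussian kernel for a non-linear boundary, and estimates generalization error with cross-validation. The feature choice and the test point are illustrative, not part of the original tutorial.

```matlab
% Assumes Statistics and Machine Learning Toolbox (fitcsvm and the fisheriris data set)
load fisheriris                          % meas: 150x4 numeric, species: 150x1 cell array of names
isTwoClass = ~strcmp(species,'setosa');  % keep versicolor and virginica -> binary problem
X = meas(isTwoClass,3:4);                % petal length and petal width
Y = species(isTwoClass);

% Gaussian (RBF) kernel allows a non-linear boundary; Standardize scales each feature
svmModel = fitcsvm(X,Y, ...
    'KernelFunction','rbf', ...
    'Standardize',true, ...
    'ClassNames',{'versicolor','virginica'});

cvModel = crossval(svmModel);            % 10-fold cross-validation by default
cvError = kfoldLoss(cvModel)             % misclassification rate on held-out folds

label = predict(svmModel,[5.0 1.7])      % classify one hypothetical flower
```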


Another essential function is fitcnb, which trains a naive Bayes classifier. Naive Bayes applies Bayes' theorem under the strong ("naive") assumption that features are conditionally independent given the class. Despite that simplification, it is often remarkably effective, particularly for text classification and sentiment analysis.
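The independence assumption is visible in the fitted model: each class stores separate per-feature distribution parameters. This sketch assumes the Statistics and Machine Learning Toolbox and the fisheriris sample data; the query flower is hypothetical.

```matlab
% Assumes Statistics and Machine Learning Toolbox (fitcnb and the fisheriris data set)
load fisheriris
nbModel = fitcnb(meas,species);          % numeric predictors default to a normal distribution per class

% One [mean; std] pair per class (row) and feature (column): features are modeled independently
disp(nbModel.ClassNames)
disp(nbModel.DistributionParameters{1,1}) % class 1, feature 1 (sepal length)

% Posterior probabilities for a hypothetical flower, in the order of nbModel.ClassNames
[label,posterior] = predict(nbModel,[5.9 3.0 5.1 1.8])
resubError = resubLoss(nbModel)          % training-set misclassification rate
```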


The fitctree function trains a binary decision tree for classification; its counterpart fitrtree trains a regression tree. Decision trees are intuitive, easy-to-interpret models that can handle both categorical and numerical predictors. They make a prediction by following a sequence of if-then splits on the features from the root of the tree down to a leaf.
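A short sketch of both tree functions, assuming the Statistics and Machine Learning Toolbox and its fisheriris and carsmall sample data. The predictor names and query points are illustrative choices.

```matlab
% Assumes Statistics and Machine Learning Toolbox (fitctree, fitrtree, and the sample data sets)
load fisheriris
ctree = fitctree(meas,species, ...
    'PredictorNames',{'SepalLength','SepalWidth','PetalLength','PetalWidth'});
view(ctree)                              % prints the if-then split rules in the Command Window
species1 = predict(ctree,[5.9 3.0 5.1 1.8])

% Regression counterpart: predict fuel economy from weight and horsepower
load carsmall                            % MPG contains NaN values; rows with a missing response are not used
rtree = fitrtree([Weight Horsepower],MPG);
mpgHat = predict(rtree,[3000 130])       % hypothetical car: 3000 lb, 130 hp
```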


The fitrgam function fits a generalized additive model (GAM) for regression. GAMs can capture complex non-linear relationships between the response and the predictors while staying interpretable, because the prediction is a sum of one shape function per predictor, optionally plus shape functions for pairs of predictors (interaction terms). In MATLAB, each shape function is built from boosted trees.
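A minimal fitrgam sketch, assuming the Statistics and Machine Learning Toolbox (release R2021a or later, when fitrgam was introduced) and the carsmall sample data. With two predictors, 'Interactions','all' adds the single Horsepower-by-Weight interaction term on top of the two per-predictor shape functions.

```matlab
% Assumes Statistics and Machine Learning Toolbox R2021a or later (fitrgam)
load carsmall
tbl = rmmissing(table(Horsepower,Weight,MPG));   % drop rows with any missing value

% One shape function per predictor plus the pairwise interaction (Horsepower x Weight)
gamModel = fitrgam(tbl,'MPG','Interactions','all');

% Predict MPG for the first five cars using only the predictor columns
mpgHat = predict(gamModel,tbl(1:5,{'Horsepower','Weight'}))
```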


The fitglm function fits a generalized linear model (GLM). GLMs extend ordinary linear regression by letting the response follow a distribution other than the normal distribution (for example, binomial or Poisson) and relating its mean to a linear combination of the predictors through a link function. They are widely used in statistics, for example for logistic regression of binary outcomes and Poisson regression of counts.
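As a sketch of a non-normal response, the code below fits a logistic regression (binomial distribution, default logit link) with fitglm. It assumes the Statistics and Machine Learning Toolbox, and the data are synthetic, generated from a known model so the fitted coefficients can be compared with the true values up to sampling error.

```matlab
% Assumes Statistics and Machine Learning Toolbox (fitglm); the data below are synthetic
rng(0)                                   % reproducible random numbers
n = 200;
x = randn(n,1);
p = 1 ./ (1 + exp(-(0.5 + 2*x)));        % true model: logit(p) = 0.5 + 2x
y = double(rand(n,1) < p);               % 0/1 (Bernoulli) response

% Binomial distribution with the default logit link = logistic regression
glmModel = fitglm(x,y,'Distribution','binomial');
disp(glmModel.Coefficients)              % estimates roughly recover 0.5 and 2, up to sampling error

pAtZero = predict(glmModel,0)            % estimated probability that y = 1 when x = 0
```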


Finally, the trainNetwork function is used for training a neural network for classification, regression, or feature learning. Neural networks are powerful models inspired by the human brain that can model complex relationships and patterns in data.

Related content: Read our guide to machine learning infrastructure

Tutorial: Training a Neural Network with trainNetwork

This tutorial is based on the official MATLAB documentation.

Step 1: Load Data

Load the sample digit images as an ImageDatastore object:

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
   'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
   'IncludeSubfolders',true, ...
   'LabelSource','foldernames');

The datastore contains 10,000 synthetic images of digits from 0 to 9. The images are generated by applying random transformations to digit images created with different fonts. The datastore has an equal number of images per category (1,000 per digit).

Step 2: Display Images

Use this code to display 20 randomly selected images:


figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
   subplot(4,5,i);
   imshow(imds.Files{perm(i)});
   drawnow;
end

The output should look something like this:

[Figure: 20 sample images from the digit dataset. Source: MATLAB]

Step 3: Create Training and Testing Set

Divide the datastore so that each category in the training set has 750 images and the test set has the remaining 250 images from each label.

numTrainingFiles = 750;

[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');

splitEachLabel splits the image files in imds into two new datastores, imdsTrain and imdsTest; the 'randomize' option assigns files to the two sets at random within each label.
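To confirm the split, countEachLabel reports how many files each label has in a datastore. This self-contained sketch assumes the Deep Learning Toolbox sample DigitDataset used in Step 1 and repeats the loading and splitting so it runs on its own.

```matlab
% Assumes Deep Learning Toolbox, which ships the DigitDataset sample images
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsTest] = splitEachLabel(imds,750,'randomize');

% countEachLabel returns a table with variables Label and Count
trainCounts = countEachLabel(imdsTrain)  % expect 750 per digit
testCounts  = countEachLabel(imdsTest)   % expect the remaining 250 per digit
```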

Step 4: Define Convolutional Neural Network Architecture

Define the convolutional neural network architecture:

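layers = [ ...
   imageInputLayer([28 28 1])
   convolution2dLayer(5,20)
   reluLayer
   maxPooling2dLayer(2,'Stride',2)
   fullyConnectedLayer(10)
   softmaxLayer
   classificationLayer];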
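As a worked check on the architecture, assume the layer settings of the MATLAB example: a 5-by-5 convolution with 20 filters (default stride 1, no padding), then 2-by-2 max pooling with stride 2, then a fully connected layer with 10 outputs. The spatial size of a 28-by-28 grayscale input then shrinks as follows:

$$
\begin{aligned}
\text{conv:}\quad & \frac{28 - 5}{1} + 1 = 24 &&\Rightarrow\ 24 \times 24 \times 20 \\
\text{pool:}\quad & \frac{24 - 2}{2} + 1 = 12 &&\Rightarrow\ 12 \times 12 \times 20 \\
\text{FC inputs:}\quad & 12 \cdot 12 \cdot 20 = 2880 \\
\text{learnables:}\quad & \underbrace{5 \cdot 5 \cdot 1 \cdot 20 + 20}_{\text{conv} = 520} + \underbrace{2880 \cdot 10 + 10}_{\text{FC} = 28{,}810} = 29{,}330
\end{aligned}
$$

The ReLU, softmax, and classification layers have no learnable parameters. In MATLAB, analyzeNetwork(layers) displays the same activation sizes and learnable counts interactively.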







Step 5: Train the Neural Network

Use the default settings for stochastic gradient descent with momentum (SGDM), except set the maximum number of epochs to 20 and start training with an initial learning rate of 0.0001.

options = trainingOptions('sgdm', ...
   'MaxEpochs',20, ...
   'InitialLearnRate',1e-4, ...
   'Verbose',false, ...
   'Plots','training-progress');
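MathWorks documents the 'sgdm' update as θ(ℓ+1) = θ(ℓ) − α∇E(θ(ℓ)) + γ(θ(ℓ) − θ(ℓ−1)), where α is the learning rate and γ is the momentum (0.9 by default in trainingOptions). The sketch below applies that rule to a hypothetical one-dimensional objective E(θ) = (θ − 3)² with an illustrative learning rate of 0.1. It only demonstrates the update rule; it is not how trainNetwork is implemented internally.

```matlab
% Hypothetical 1-D objective E(theta) = (theta - 3)^2 with gradient 2*(theta - 3)
gradE = @(theta) 2*(theta - 3);
alpha = 0.1;          % illustrative learning rate (the tutorial uses 1e-4)
gamma = 0.9;          % default Momentum for 'sgdm' in trainingOptions
theta = 0;
thetaPrev = 0;        % no previous step at the start, so the momentum term is initially zero

for k = 1:200
    % Gradient step plus a fraction gamma of the previous step
    thetaNext = theta - alpha*gradE(theta) + gamma*(theta - thetaPrev);
    thetaPrev = theta;
    theta = thetaNext;
end
disp(theta)           % approaches the minimizer theta = 3
```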


To train the network, run this command:

net = trainNetwork(imdsTrain,layers,options);

Step 6: Predict Image Labels

Run the trained network on the test set, which was not used for training, to predict the digit label of each image.

YPred = classify(net,imdsTest);

YTest = imdsTest.Labels;

Calculate the accuracy: the number of test images whose predicted label matches the true label, divided by the total number of images in the test set.

accuracy = sum(YPred == YTest)/numel(YTest)
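The same accuracy formula, plus a per-digit breakdown, can be checked on small hand-made label vectors. This sketch assumes the Statistics and Machine Learning Toolbox for confusionmat; the eight labels below are hypothetical stand-ins for YTest and YPred.

```matlab
% Hypothetical true and predicted labels (categorical, like imdsTest.Labels)
YTest = categorical([0 0 1 1 2 2 2 3]');
YPred = categorical([0 1 1 1 2 2 0 3]');

% Accuracy = correctly predicted images / total images
accuracy = sum(YPred == YTest)/numel(YTest)   % 6 of 8 correct -> 0.75

% Rows = true class, columns = predicted class, both in the order given by 'order'
[C,order] = confusionmat(YTest,YPred);
perClassAccuracy = diag(C) ./ sum(C,2)         % fraction of each true digit predicted correctly
```

For the tutorial's real predictions, confusionchart(YTest,YPred) draws the same kind of matrix as a chart.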

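MATLAB displays the accuracy of the model, which should be around 0.94 (94%). The exact value varies from run to run because the training/test split and the network's initial weights are random.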

Optimizing Machine Learning Training with Run:ai

When running machine learning at scale, you'll need to manage a large number of computing resources and GPUs. Run:ai automates resource management and orchestration for machine learning infrastructure, so you can automatically run as many compute-intensive experiments as needed.

Here are some of the capabilities you gain when using Run:ai:

  • Advanced visibility: create an efficient pipeline of resource sharing by pooling GPU compute resources.
  • No more bottlenecks: set up guaranteed quotas of GPU resources to avoid bottlenecks and optimize billing.
  • A higher level of control: Run:ai enables you to dynamically change resource allocation, ensuring each job gets the resources it needs at any given time.

Run:ai simplifies machine learning infrastructure pipelines, helping data scientists increase their productivity and improve the quality of their models.

Learn more about the Run:ai GPU virtualization platform.