Recognize Digits Using ML in Elixir

Here's Philip Brown, founder of Prise, with a tutorial on building a fullstack machine learning application using Nx, Axon, and LiveView. Coincidentally, Fly.io is the perfect place to run your LiveView apps.

Machine learning allows you to solve problems that were once totally unimaginable. The ability for a computer to take an image and tell you what it sees was once only possible in science fiction.

Now, it's possible to build machine learning models that can do amazing things. However, part of the challenge of machine learning is that there are a lot of moving parts to learn. This means that solving a problem with machine learning can be a difficult task for an individual engineer.

One of the big advantages Elixir has over similar programming languages is how integrated its ecosystem is. You can do a lot in Elixir without ever leaving the comfort of the language you love.

What Are We Going to Build?

In this tutorial we're going to look at building out an end-to-end machine learning project using only Elixir. Boom! As if that wasn't enough, we're going to build a machine learning model that can recognize a handwritten digit. We'll train the model so that it will predict the digit from an image. We'll also build an application that can accept new handwritten digits from the user, and then display the prediction.

Here's a preview of what it looks like:

Let's get started!

Setting Up the Project

We're going to build this project using Phoenix, so the first thing we need is to create a new Phoenix project.

If you don't already have Elixir installed on your computer, you can find instructions for your operating system on the Elixir Website.

Once you have Elixir installed, run the following command in a terminal to install the Phoenix project generator:

$ mix archive.install hex phx_new

With Elixir and Phoenix installed, we can create a new Phoenix project:

$ mix phx.new digits --no-ecto

I'm including the --no-ecto flag because we don't need a database for this project. This command should prompt you to install the project's dependencies. Hit Y on that prompt and wait for the dependencies to be installed.

Once the dependencies are installed, follow the onscreen instructions to run your new Phoenix application and verify that everything was set up correctly.

I'm also going to add the Tailwind package for styling the application. If you want to add Tailwind to your project, add the following to the list of dependencies in your mix.exs file:

{:tailwind, "~> 0.1", runtime: Mix.env() == :dev}

Then follow the configuration instructions in the tailwind package's documentation.

Where Will We Get Our Training Data?

One of the most important aspects of machine learning is having good-quality data to train on. When working on real-life machine learning projects, expect to spend the majority of your time on the data.

Fortunately for us, there's already a ready-made dataset we can use. The MNIST Database is a large dataset of handwritten digits from 0 - 9 that have already been prepared and labeled. It's commonly used for training image recognition machine learning models.

Prepare the Project for Machine Learning

Next, we need to set up the machine learning model. The Elixir ecosystem has a number of exciting packages that can be used for training machine learning models.

The Nx package is the foundation of machine learning in Elixir. Nx allows us to manipulate our data using tensors, which are essentially efficient multi-dimensional arrays. When we say "tensor" below, just think "multi-dimensional array".

Next, we have EXLA, which provides hardware acceleration by compiling Nx code with Google's XLA compiler so that it runs efficiently on the CPU or GPU. Crunching the numbers for machine learning is very computationally intensive, and EXLA makes it much faster.

Axon builds on top of Nx and makes it possible for us to create neural networks in Elixir.

Finally we have the Scidata package, which provides conveniences for working with machine learning datasets, including MNIST.

So, the first thing we need to do is to add those dependencies to our mix.exs file:

{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"},
{:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
{:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
{:scidata, "~> 0.1.5"}

Then we can install our new dependencies from a terminal:

$ mix deps.get

Working With Our Training Data

There are a couple of steps required for getting and transforming the training data, so we'll start building a module to encapsulate everything:

defmodule Digits.Model do
  @moduledoc """
  The Digits Machine Learning model
  """
end

First up we'll add a download/0 function that downloads the training data for us. We're just delegating to the Scidata package for that.

def download do
  Scidata.MNIST.download()
end

This function returns a tuple of {images, labels}. However, we want to transform the images and labels so we can use them in our model.

First, we'll use the following function to transform the images:

def transform_images({binary, type, shape}) do
  binary
  |> Nx.from_binary(type)
  |> Nx.reshape(shape)
  |> Nx.divide(255)
end

The image data from the download is a tuple containing:

  • Binary data - the raw pixel data as a binary.
  • The type of the data - here {:u, 8}, an unsigned 8-bit integer.
  • The shape of the data - here {60000, 1, 28, 28}. This means there are 60000 images, each with 1 channel (i.e. they're grayscale) and a size of 28x28 pixels.

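transform_images/1 converts the binary into a tensor with Nx.from_binary/2, reshapes it into the shape above, and then divides every value by 255. The division scales each pixel from the 0 to 255 range of an unsigned 8-bit integer down to 0.0 to 1.0; neural networks generally train more reliably on inputs in a small, consistent range.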
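To make that pipeline concrete without Nx, here's a small plain-Elixir sketch. The module PixelSketch is hypothetical and only for illustration: it decodes unsigned 8-bit pixels from a binary and scales them the same way Nx.divide(255) does, then groups a flat list of pixels into rows, similar in spirit to reshaping.

```elixir
defmodule PixelSketch do
  # Hypothetical helper: decode a binary of unsigned 8-bit pixels ({:u, 8})
  # into floats between 0.0 and 1.0, mirroring Nx.from_binary/2 + Nx.divide/2.
  def normalize(binary) do
    for <<pixel::unsigned-8 <- binary>>, do: pixel / 255
  end

  # Group a flat list of pixels into rows of `width` values, similar in
  # spirit to reshaping a 784-element vector into a 28x28 image.
  def to_rows(pixels, width), do: Enum.chunk_every(pixels, width)
end

PixelSketch.normalize(<<0, 51, 255>>)
# => [0.0, 0.2, 1.0]
```

In the real pipeline, Nx does the same work on the whole {60000, 1, 28, 28} tensor at once, which is far faster than iterating over lists.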

If we open up iex we can visualize the image data. Run the following command in a terminal to open up iex with our project loaded:

$ iex -S mix

Next, we run the following code:

{images, labels} = Digits.Model.download()

images
|> Digits.Model.transform_images()
|> Nx.slice_axis(0, 1, 0)
|> Nx.reshape({1, 1, 28, 28})
|> Nx.to_heatmap()

You should see the first handwritten digit of the dataset. This is what it looks like:

Sample image of number 5 digit heatmap

We can see the corresponding label for the image too. Let's see how to do that.

First, we pattern match the binary data and type from the downloaded label data.

{binary, type, _} = labels

Then we convert the binary to a tensor and "slice" off the first item as our example.

binary
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.slice_axis(0, 1, 0)

The first label should be a 5. Let's turn that code into a transform_labels/1 function in Digits.Model:

def transform_labels({binary, type, _}) do
  binary
  |> Nx.from_binary(type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
end

The labels of the training data are used as targets for the model's predictions. For each image, we know the correct digit. During training, the model compares its predictions with these labels and adjusts its weights so that future predictions are closer to the correct result.

Currently, the labels are integers from 0 - 9. You can think of them as 10 different categories. In our case, the categories are integers, but when training a machine learning model, you might have categories such as colors, sizes, types of animals, etc.

So we need to convert our categories into something the machine learning model can understand. We do this by converting each label into a tensor of shape {1, 10}, where 10 is the number of categories.

For example:

#Nx.Tensor<
  u8[1][10]
  [
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  ]
>

In this example, the long list of numbers has a 1 in the first position. This represents the first category. In our case, that is the number "0", but it could also be the color "red", the size "small", or the type of animal "dog".

The second category would be:

#Nx.Tensor<
  u8[1][10]
  [
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
  ]
>

And so on.

This process is called one-hot encoding.
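As a minimal illustration of one-hot encoding with plain lists (the module OneHotSketch is hypothetical), each position is compared against the label, which is the same idea as transform_labels/1 comparing every label against Nx.tensor(Enum.to_list(0..9)) with Nx.equal/2:

```elixir
defmodule OneHotSketch do
  # Hypothetical helper: one-hot encode an integer label over `classes`
  # categories. Each position is 1 if it equals the label, 0 otherwise.
  def encode(label, classes \\ 10) do
    for category <- 0..(classes - 1), do: if(category == label, do: 1, else: 0)
  end
end

OneHotSketch.encode(5)
# => [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
```

transform_labels/1 does this for all 60000 labels in one step: Nx.new_axis/2 turns the labels into a {60000, 1} tensor, which Nx broadcasts against the ten categories to produce a {60000, 10} result.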

You can see what the first label of the training data looks like once it's been one-hot encoded by running the following code (still in iex):

labels
|> Digits.Model.transform_labels()
|> Nx.slice_axis(0, 1, 0)

This should output the following tensor:

#Nx.Tensor<
  u8[1][10]
  [
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
  ]
>

Remember, we're working with the number 5 right now. This tensor is all zeros except for a 1 at index 5 (counting from 0), which represents the digit 5.

Next, we convert the images and labels into batches. During training, we feed the data into the model in batches rather than all at once. In this example we're using a batch size of 32, which means each batch includes 32 examples.

batch_size = 32

images =
  images
  |> Digits.Model.transform_images()
  |> Nx.to_batched_list(batch_size)

labels =
  labels
  |> Digits.Model.transform_labels()
  |> Nx.to_batched_list(batch_size)
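Here's a plain-list analogue of that batching step, with the 60,000 training examples stood in for by their indices. Enum.chunk_every/2 splits a list into consecutive groups, much like Nx.to_batched_list/2 splits a tensor along its first axis:

```elixir
# Stand-in for the 60,000 MNIST training examples: just their indices
examples = Enum.to_list(0..59_999)

# Split into consecutive groups of 32, like Nx.to_batched_list(tensor, 32)
batches = Enum.chunk_every(examples, 32)

length(batches)
# => 1875
```

Because 60,000 divides evenly by 32, every one of the 1,875 batches is full, and the image batches and label batches line up one-to-one.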

Next, we zip the images and labels together using Enum.zip/2, so that each batch of images is paired with its batch of labels. Then we split the dataset into training, validation, and test datasets. Most of the data is used for training, and a portion is held back to measure the accuracy of the model. In this example we're using 80% of the data for training and validation (with 20% of that 80% set aside as validation data), and the remaining 20% is unseen data used for testing.

data = Enum.zip(images, labels)

training_count = floor(0.8 * Enum.count(data))
validation_count = floor(0.2 * training_count)

{training_data, test_data} = Enum.split(data, training_count)
{validation_data, training_data} = Enum.split(training_data, validation_count)
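To see what those percentages mean in practice, here's the same split run on a stand-in list of 1,875 items, one per batch (the numbers below assume the 60,000 training images batched by 32):

```elixir
# Hypothetical stand-in for the 1,875 {images, labels} batches
data = Enum.to_list(1..1875)

training_count = floor(0.8 * Enum.count(data))  # 1500 batches for training + validation
validation_count = floor(0.2 * training_count)  # 300 of those held out for validation

{training_data, test_data} = Enum.split(data, training_count)
{validation_data, training_data} = Enum.split(training_data, validation_count)

{length(training_data), length(validation_data), length(test_data)}
# => {1200, 300, 375}
```

In images, that's 38,400 for training, 9,600 for validation, and 12,000 for testing.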

Phew! That may seem pretty heavy, but we've already achieved a lot! We've downloaded our training data, preprocessed it, and got it ready for building the model. During a real-life machine learning project you will likely spend a lot of time acquiring, cleaning, and manipulating the data. We're now in a great position to build and train the model!

Building the Model

Next up we'll use Axon to build the machine learning model. Add a new function to the Digits.Model module with the following code:

def new({channels, height, width}) do
  Axon.input({nil, channels, height, width})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
end

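First, we set the input shape of the model to match our training data; the nil is the batch dimension, which can be any size. Next, Axon.flatten/1 turns each 1x28x28 image into a vector of 784 values, which feeds a dense layer of 128 units that uses ReLU as its activation function. Finally, the output layer is a dense layer of 10 units with a softmax activation, which produces a probability for each of the 10 labels (0 - 9).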
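To get a feel for the size of this model, here's a quick parameter count. It's a sketch that assumes each dense layer uses Axon's default bias term, so a layer has one weight per input/output pair plus one bias per output unit (ParamSketch is a hypothetical module):

```elixir
defmodule ParamSketch do
  # A dense layer with a bias: inputs * outputs weights + outputs biases
  def dense_params(inputs, outputs), do: inputs * outputs + outputs
end

flattened = 1 * 28 * 28                            # size after Axon.flatten/1: 784
hidden = ParamSketch.dense_params(flattened, 128)  # 100_480
output = ParamSketch.dense_params(128, 10)         # 1_290

hidden + output
# => 101770
```

Nearly all of the roughly 100,000 trainable parameters live in the first dense layer, so changing its width is the main lever on model size when you experiment.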

You can experiment with different model configurations to get different results.

Training the Model

Now that we have the data and the model, we can start training. Add another function to Digits.Model to train the model:

def train(model, training_data, validation_data) do
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.01))
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.validate(model, validation_data)
  |> Axon.Loop.run(training_data, compiler: EXLA, epochs: 10)
end

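We're using categorical cross-entropy as the loss function because this is a multi-class classification problem: each image belongs to exactly one of 10 classes, and our labels are one-hot encoded. We use the Adam optimizer with a learning rate of 0.01 because it generally gives good results without much tuning. We'll track a single accuracy metric, and we'll also validate the model with our validation data during training to check that it isn't overfitting the training data.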

Finally, we use EXLA as the compiler and train for 10 epochs. An epoch is one full pass through the training data, so the model sees every training batch 10 times during training.
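For intuition about what the loss measures, here's a plain-Elixir sketch of categorical cross-entropy for a single example, -sum(target * log(predicted)). LossSketch is hypothetical and not Axon's implementation; with a one-hot target, only the predicted probability of the correct class matters:

```elixir
defmodule LossSketch do
  # Categorical cross-entropy for one example:
  #   loss = -sum(target_i * log(predicted_i))
  # Terms whose target is 0 contribute nothing, so they are skipped
  # (which also avoids :math.log(0.0), since that raises in Erlang).
  def categorical_cross_entropy(target, predicted) do
    terms =
      for {t, p} <- Enum.zip(target, predicted), t != 0 do
        t * :math.log(p)
      end

    -Enum.sum(terms)
  end
end

# One-hot target for the digit 5
target = for i <- 0..9, do: if(i == 5, do: 1, else: 0)

uniform = List.duplicate(0.1, 10)                                # unsure model
confident = List.duplicate(0.01, 10) |> List.replace_at(5, 0.91) # favors 5

LossSketch.categorical_cross_entropy(target, uniform)    # about 2.303
LossSketch.categorical_cross_entropy(target, confident)  # about 0.094
```

The more probability the model puts on the correct digit, the lower the loss, which is exactly what the optimizer pushes toward during training.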

Testing Our Model

We can also test our model after training to get an idea of how well it performs. Add the following function to Digits.Model:

def test(model, state, test_data) do
  model
  |> Axon.Loop.evaluator(state)
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.run(test_data)
end

This tests the model using previously unseen data to check the accuracy of the predictions.

Saving and Loading Our Model

The final thing to do is to add the ability to save and load the model. The model is just an Elixir struct and its trained state is plain Elixir data too, so saving is simply a case of serializing both with Erlang's :erlang.term_to_binary/1, and loading is reading them back with :erlang.binary_to_term/1:

def save!(model, state) do
  contents = :erlang.term_to_binary({model, state})

  File.write!(path(), contents)
end

def load! do
  path()
  |> File.read!()
  |> :erlang.binary_to_term()
end

def path do
  Path.join(Application.app_dir(:digits, "priv"), "model.axon")
end
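Here's a self-contained sketch of the same round trip using stand-in data. The model and state maps below are hypothetical placeholders, not real Axon structures; the point is that any Elixir term survives term_to_binary/binary_to_term unchanged:

```elixir
# Hypothetical stand-ins for the Axon model struct and its trained state
model = %{layers: [:flatten, {:dense, 128}, {:dense, 10}]}
state = %{"dense_0" => %{"kernel" => [0.1, -0.2], "bias" => [0.0]}}

path = Path.join(System.tmp_dir!(), "model_sketch.axon")

# Save: serialize the {model, state} tuple to Erlang's external term format
File.write!(path, :erlang.term_to_binary({model, state}))

# Load: read the bytes back and deserialize them into the same terms
{loaded_model, loaded_state} = path |> File.read!() |> :erlang.binary_to_term()
```

Only load files you wrote yourself: :erlang.binary_to_term/1 should not be used on untrusted input, and :erlang.binary_to_term/2 with the :safe option refuses to create new atoms.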

Running the Model

Now that we've written all the code to transform the data, train, and test our machine learning model, we'll write a Mix task in lib/mix/tasks/train.ex to put it all together:

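defmodule Mix.Tasks.Train do
  @moduledoc "Trains the handwritten digits model and saves it to the app's priv directory."
  use Mix.Task

  # Start the application and its dependencies before the task runs
  @requirements ["app.start"]

  def run(_args) do
    # Use EXLA for Nx, picking the first available client:
    # TPU, then CUDA or ROCm GPUs, then the CPU (host)
    EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])

    {images, labels} = Digits.Model.download()

    # Normalize the images and split them into batches of 32
    images =
      images
      |> Digits.Model.transform_images()
      |> Nx.to_batched_list(32)

    # One-hot encode the labels and batch them the same way
    labels =
      labels
      |> Digits.Model.transform_labels()
      |> Nx.to_batched_list(32)

    data = Enum.zip(images, labels)

    # 80% for training + validation (20% of that for validation), 20% for testing
    training_count = floor(0.8 * Enum.count(data))
    validation_count = floor(0.2 * training_count)

    {training_data, test_data} = Enum.split(data, training_count)
    {validation_data, training_data} = Enum.split(training_data, validation_count)

    model = Digits.Model.new({1, 28, 28})

    Mix.Shell.IO.info("training...")

    state = Digits.Model.train(model, training_data, validation_data)

    Mix.Shell.IO.info("testing...")

    Digits.Model.test(model, state, test_data)

    # Persist the model and its trained state for the LiveView to load later
    Digits.Model.save!(model, state)

    :ok
  end
end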

We can run the training with the following command:

$ mix train

Setting Up the LiveView

Now that we have a trained machine learning model, we can set up a LiveView to accept new handwritten digits, and then display the predicted results.

First, we replace the default root route in lib/digits_web/router.ex with a live route:

scope "/", DigitsWeb do
  pipe_through :browser

  live "/", PageLive, :index
end

Next, we create a new file under lib/digits_web/live called page_live.ex. This is our LiveView module where all the interactivity happens:

defmodule DigitsWeb.PageLive do
  @moduledoc """
  PageLive LiveView
  """

  use DigitsWeb, :live_view
end

When a user submits a new handwritten digit, the machine learning model makes a prediction on what digit was written, and then the LiveView displays the prediction to the user. However, when the LiveView is first loaded, there isn't a prediction to display. So first, we initialize the prediction assign to nil in the mount/3 callback:

def mount(_params, _session, socket) do
  {:ok, assign(socket, %{prediction: nil})}
end

Next, the render/1 function is responsible for rendering the LiveView:

def render(assigns) do
  ~H"""
  <div id="wrapper" phx-update="ignore">
    <div id="canvas" phx-hook="Draw"></div>
  </div>

  <div>
    <button phx-click="reset">Reset</button>
    <button phx-click="predict">Predict</button>
  </div>

  <%= if @prediction do %>
  <div>
    <div>
      Prediction:
    </div>
    <div>
      <%= @prediction %>
    </div>
  </div>
  <% end %>
  """
end

Notice that we have a div with the id "canvas". The phx-hook="Draw" attribute attaches a JavaScript hook, which we'll write shortly, that creates an HTML canvas inside this div and lets us interact with it. The canvas div is wrapped in another div with phx-update="ignore" so that LiveView doesn't overwrite the canvas when it re-renders the page.

Next are two buttons: one to reset the canvas and one to make a prediction from what the user drew. Each button uses a phx-click binding to send an event to the server.

Finally, if we have a prediction, it is displayed.


Adding the Canvas

Next, we need some input from the user. We could let the user upload images using Phoenix's LiveView upload functionality, but a better (and way cooler) experience is to let the user draw new examples directly into the LiveView.

There's a handy NPM package called draw-on-canvas that makes this part easy.

To install it, cd into the assets directory and run the following command in a terminal:

$ npm i draw-on-canvas

This installs draw-on-canvas into assets/node_modules and adds it to assets/package.json.

Now we connect the draw-on-canvas package to our LiveView via a hook. Open up assets/js/app.js and import the draw-on-canvas package:

import Draw from 'draw-on-canvas'

Let's create a new Hooks object:

let Hooks = {}

Remember to register the hook object in the LiveSocket:

let liveSocket = new LiveSocket("/live", Socket, {
  params: {_csrf_token: csrfToken},
  hooks: Hooks
})

Next we add a new Draw hook:

Hooks.Draw = {}

We need to implement the mounted function, which is called when the hook is mounted. This is where we set up the canvas:

Hooks.Draw = {
  mounted() {
    this.draw = new Draw(this.el, 384, 384, {
      backgroundColor: "black",
      strokeColor: "white",
      strokeWeight: 10
    })
  }
}

When we open the app in a browser, we should see a black square canvas that we can draw on!

Interacting With the Canvas

Remember back in our PageLive module, we added two buttons for interacting with the canvas.

The first button is used to reset the canvas. When the button is pressed we send a message to the client to reset the canvas. The push_event function makes this easy.

Our new "reset" event handler in PageLive looks like this:

def handle_event("reset", _params, socket) do
  {:noreply,
   socket
   |> assign(prediction: nil)
   |> push_event("reset", %{})}
end

When the reset button is clicked, the phx-click trigger sends the reset event to the server. We then push an event called reset to the client. We also set the prediction to nil in the socket assigns.

On the JavaScript side, inside the hook's mounted() callback, we add a handleEvent that listens for the reset event and resets the canvas:

this.handleEvent("reset", () => {
  this.draw.reset()
})

Next, let's make our "predict" button work. We want to grab the contents of the canvas as an image. Again, we send a message to the client from the PageLive LiveView module:

def handle_event("predict", _params, socket) do
  {:noreply, push_event(socket, "predict", %{})}
end

In the mounted callback, we add another handleEvent. This grabs the contents of the canvas as a data URL and sends it to the server using pushEvent:

this.handleEvent("predict", () => {
  this.pushEvent("image", this.draw.canvas.toDataURL('image/png'))
})

Making Predictions

Now that we've hooked up the buttons to reset the canvas and send the canvas contents to the server, we can use the image from the canvas as a new input to our machine learning model.

We can accept the image data URL from the client using another handle_event/3 callback function:

def handle_event("image", "data:image/png;base64," <> raw, socket) do
  name = Base.url_encode64(:crypto.strong_rand_bytes(10), padding: false)
  path = Path.join(System.tmp_dir!(), "#{name}.png")

  File.write!(path, Base.decode64!(raw))

  prediction = Digits.Model.predict(path)

  File.rm!(path)

  {:noreply, assign(socket, prediction: prediction)}
end

In this function, we use binary pattern matching on the params to strip the data:image/png;base64, prefix and get the Base64-encoded image data. Next, we generate a random file name and build a path in the system's temporary directory for storing the image. Then we decode the image data and write it to that path.

Next we pass the path into the Digits.Model.predict/1 function and return a prediction. The prediction result is a number between 0 and 9. We'll write that function next.

Finally, we delete the image file and assign the prediction to the socket for display in our LiveView.
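The prefix-stripping trick on its own looks like this. DataUrlSketch is a hypothetical module for illustration: it matches the data URL prefix with <> in a function head and Base64-decodes the rest, returning :error for anything else instead of raising:

```elixir
defmodule DataUrlSketch do
  @prefix "data:image/png;base64,"

  # Hypothetical helper: strip the data URL prefix with a binary pattern
  # match, then Base64-decode the remainder ({:ok, bytes} or :error).
  def decode(@prefix <> encoded), do: Base.decode64(encoded)
  def decode(_other), do: :error
end

DataUrlSketch.decode("data:image/png;base64," <> Base.encode64(<<137, 80, 78, 71>>))
# => {:ok, <<137, 80, 78, 71>>}
```

The handle_event/3 clause in PageLive has only the one matching clause and uses Base.decode64!/1, so an unexpected payload there raises and crashes the LiveView process, which LiveView then remounts. Returning :error as in this sketch is an alternative if you'd rather handle bad input gracefully.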

Before we can use the user's drawing with our model, we need to prepare the image. We need to:

  • Convert it to grayscale to reduce the number of channels from 3 to 1
  • Resize it from 384 x 384 to 28 x 28

The Evision library can do these changes for us. Let's add it as a dependency in our mix.exs file now:

{:evision, "~> 0.1.0-dev", github: "cocoa-xu/evision", branch: "main"}

Install the dependency using:

$ mix deps.get

In the Digits.Model module, let's add a new function for making a prediction.

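def predict(path) do
  # Read the PNG as a single-channel (grayscale) image and shrink it to 28x28
  {:ok, mat} = Evision.imread(path, flags: Evision.cv_IMREAD_GRAYSCALE)
  {:ok, mat} = Evision.resize(mat, [28, 28])

  data =
    Evision.Nx.to_nx(mat)
    # Scale pixels to 0.0..1.0, matching transform_images/1 used in training
    |> Nx.divide(255)
    |> Nx.reshape({1, 28, 28})
    # Add the batch dimension: {1, 28, 28} becomes {1, 1, 28, 28}
    |> List.wrap()
    |> Nx.stack()

  {model, state} = load!()

  model
  |> Axon.predict(state, data, compiler: EXLA)
  |> Nx.argmax()
  |> Nx.to_number()
end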

First, we read the image from the path as grayscale, which reduces the number of channels from 3 to 1. Then we resize the image to 28 x 28.

We also need to convert the image data to an Nx tensor and reshape it to {1, 28, 28}, the shape of a single training image. Our machine learning model expects a batch of inputs, so we wrap the tensor in a list using List.wrap/1 and then stack it using Nx.stack/1, giving a batch of one image with shape {1, 1, 28, 28}.

Next, we load the model and the state, using the load!/0 function from earlier. Ideally you wouldn't be loading the model and state for each prediction, but it's fine for our basic example.

We pass the model, state, and data into the Axon.predict/4 function. Note that you'll need to add require Axon to the Digits.Model module, because Axon.predict/4 is a macro.

Axon.predict/4 returns a tensor of 10 probabilities, one per digit, produced by the softmax output layer. We use Nx.argmax/1 to get the index of the highest probability as a tensor holding a single scalar value between 0 and 9, and then we use Nx.to_number/1 to return that value as a plain number.

Our predicted number is then set as the prediction in the LiveView assigns and displayed to the user.
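The argmax step is easy to picture with a plain list. Here's a hypothetical ArgmaxSketch module that mirrors what Nx.argmax/1 does to the model's 10 output probabilities: it returns the index of the largest value, and that index is the predicted digit:

```elixir
defmodule ArgmaxSketch do
  # Hypothetical helper mirroring Nx.argmax/1 on a list: return the index
  # of the largest probability (the first one wins if there is a tie).
  def argmax(probabilities) do
    probabilities
    |> Enum.with_index()
    |> Enum.max_by(fn {probability, _index} -> probability end)
    |> elem(1)
  end
end

ArgmaxSketch.argmax([0.01, 0.02, 0.01, 0.05, 0.03, 0.80, 0.02, 0.02, 0.03, 0.01])
# => 5
```

Because the output layer uses softmax, the values sum to 1, so the winning value can also be read as the model's confidence in its prediction.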

We Built an End-to-end Machine Learning Application in Elixir!

Wow! Check out what we just did!

We built an end-to-end machine learning application using Elixir! We trained a model from scratch. We used LiveView for interactive, real-time application input from the user. We ran predictions and displayed the results interactively.

One of the most amazing things here was that we did it all using Elixir and didn't need external machine learning tools or languages. Machine learning in Elixir is still maturing, but I hope this inspires you to try something new in your own project.

Full code for this tutorial is at philipbrown/handwritten-digits-elixir.