Bernardo de Lemos

A Keras-like Deep Learning Crate in Rust Powered by Candle


Introducing candle: A Deep Learning Rust Crate

Rust is a powerful and fast programming language that offers many benefits for developing high-performance applications. However, Rust is not very popular among data scientists and machine learning practitioners, who often prefer Python or other languages with more mature and user-friendly libraries for deep learning.

candle is a deep learning Rust crate from Hugging Face that aims to provide a low-level API for building neural network models. candle leverages the speed and safety of Rust, while also exposing the flexibility and expressiveness of Tensor operations.

There are also some very promising crates that use candle under the hood to provide fast and easy inference for text embeddings and text generation. These crates are extremely fast and can enable serverless model inference.

text-embeddings-inference: This crate allows you to compute text embeddings from various pre-trained models, such as BERT, RoBERTa, DistilBERT, and more. You can use it as a library or through the command-line interface (CLI).

text-generation-inference: This crate allows you to generate text from various pre-trained models, such as GPT-2, T5, and more. You can use it as a library or through the CLI.

What is candle?

candle is a deep learning Rust crate that provides a low-level API for building neural network models. candle is based on candle_core, a crate that implements Tensor operations and automatic differentiation. candle also depends on candle_nn, a crate that implements common neural network layers and modules.

With candle, you can define your own custom neural network models by composing different layers and modules. You can also use candle to perform various operations on Tensors, such as slicing, reshaping, broadcasting, indexing, and arithmetic. candle supports both CPU and GPU devices, and allows you to switch between them easily.

How to use candle?

To use candle, you need to add it as a dependency in your Cargo.toml file:

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-datasets = { git = "https://github.com/huggingface/candle.git" }

Creating and Training a Neural Network Using candle

Defining the Neural Network

To define a neural network model with candle, you need to implement the Module trait for your custom structure. The Module trait requires you to implement a forward method, which takes a Tensor as input and returns a Tensor as output. For example, here is how you can define a simple linear regression model with candle:

// Device, DType, SGD, Optimizer, and VarMap are used by the training code further down.
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module, Optimizer, VarBuilder, VarMap, SGD};

struct LinearRegression {
    linear: Linear,
}

impl LinearRegression {
    fn new(input_dim: usize, output_dim: usize, vs: VarBuilder) -> Result<Self> {
        // Create a Linear layer with the weight and bias tensors
        let linear = candle_nn::linear(input_dim, output_dim, vs.pp("linear"))?;
        Ok(LinearRegression{ linear })
    }
}

impl Module for LinearRegression {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.linear.forward(x)
    }
}

Code Breakdown

The code uses two external crates: candle_core and candle_nn, which are libraries for tensor computation and neural network models.

The code defines a struct called LinearRegression, which represents a linear regression model with a single linear layer. The struct has one field: linear, which is an instance of the Linear struct from the candle_nn crate. The Linear struct implements the Module trait, which defines a common interface for neural network modules.

The LinearRegression struct also implements the Module trait, which requires it to define a forward method that takes a Tensor as input and returns a Tensor as output. The forward method simply calls the forward method of the linear field, which performs a linear transformation on the input tensor using the weight and bias tensors of the Linear struct.
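To make the linear transformation concrete, here is a minimal sketch in plain Rust (no candle involved) of what such a forward pass computes: y = x · Wᵀ + b, with the weight stored as one row per output unit. The function name linear_forward and the nested Vec representation are illustrative assumptions, not part of candle's API.

```rust
/// Computes y = x * W^T + b for a batch of input rows.
/// `weight` has shape (output_dim, input_dim) and `bias` has length output_dim.
/// Hypothetical helper for illustration only; candle operates on Tensors instead.
fn linear_forward(xs: &[Vec<f32>], weight: &[Vec<f32>], bias: &[f32]) -> Vec<Vec<f32>> {
    xs.iter()
        .map(|x| {
            weight
                .iter()
                .zip(bias)
                .map(|(w_row, b)| {
                    // Dot product of one input row with one weight row, plus the bias.
                    w_row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b
                })
                .collect()
        })
        .collect()
}
```

With the weight [[3, 1]] and bias [-2], the input row [2, 1] maps to 3·2 + 1·1 - 2 = 5.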

The LinearRegression struct has a new method that takes three arguments: input_dim, output_dim, and vs. The input_dim and output_dim are the sizes of the input and output tensors, respectively. vs is an instance of the VarBuilder struct from the candle_nn crate, which is a helper for creating variables (tensors that can be updated by an optimizer).

The new method uses the candle_nn::linear function to create a Linear struct with the given dimensions. The function also takes a VarBuilder, which names and organizes variables in a hierarchical structure. The vs.pp("linear") expression returns a new VarBuilder whose prefix is the current prefix with the string "linear" appended, so the layer's variables are created under that name.
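The hierarchical naming can be illustrated with a toy prefix tracker in plain Rust. This is only a sketch of the idea: NameScope and full_name are hypothetical names, the struct creates no tensors, and the dot-separated format is my understanding of how the prefixed names end up looking (for example "linear.weight"), not a guarantee about candle's internals.

```rust
/// A toy stand-in for the naming side of a variable builder.
/// It only tracks the prefix path; it does not create any tensors.
#[derive(Clone, Debug, Default)]
struct NameScope {
    path: Vec<String>,
}

impl NameScope {
    /// Returns a new scope with `segment` pushed onto the prefix.
    /// The original scope is left untouched, so it can be reused for sibling layers.
    fn pp(&self, segment: &str) -> NameScope {
        let mut path = self.path.clone();
        path.push(segment.to_string());
        NameScope { path }
    }

    /// Builds the fully qualified name of a variable created in this scope.
    fn full_name(&self, var_name: &str) -> String {
        if self.path.is_empty() {
            var_name.to_string()
        } else {
            format!("{}.{}", self.path.join("."), var_name)
        }
    }
}
```

Because pp returns a new scope instead of mutating the current one, the same root can hand out distinct prefixes to each layer of a larger model.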

The new method returns a Result<Self>, which is a type that can either hold a value of the Self type (in this case, LinearRegression) or an error. The ? operator after the candle_nn::linear function call is a shorthand for handling errors: if the function returns an error, the ? operator will return the error from the new method; otherwise, it will unwrap the value and assign it to the linear variable.
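For readers new to Rust, the behaviour of the ? operator can be seen in isolation with a small standard-library example. The parse_dims function and its "2x1" input format are made up for this illustration and have nothing to do with candle.

```rust
use std::num::ParseIntError;

/// Parses a string such as "2x1" into (input_dim, output_dim).
/// Hypothetical helper, used only to demonstrate the `?` operator.
fn parse_dims(input: &str) -> Result<(usize, usize), ParseIntError> {
    // If there is no 'x', the right-hand side is empty and will fail to parse.
    let (left, right) = input.split_once('x').unwrap_or((input, ""));
    // On error, `?` returns early from parse_dims with that error.
    // On success, it unwraps the value so it can be bound to a variable.
    let input_dim = left.trim().parse::<usize>()?;
    let output_dim = right.trim().parse::<usize>()?;
    Ok((input_dim, output_dim))
}
```

Calling parse_dims("2x1") yields Ok((2, 1)), while parse_dims("2xa") returns the parse error without reaching the final Ok.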

Training the Model

To train your model, you need to create an optimizer, such as SGD or Adam, and pass it the model parameters. Then, you can use a loop to iterate over your data, compute the loss, and update the model parameters. For example, here is how you can train the linear regression model on some dummy data:

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Generate some linear data, y = 3.x1 + x2 - 2.
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
    let gen = Linear::new(w_gen, Some(b_gen));
    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let sample_ys = gen.forward(&sample_xs)?;

    // Create variables; this is necessary to run the optimizer
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let model = LinearRegression::new(2, 1, vs.clone())?;
    
    // Create SGD optimizer
    let mut sgd = SGD::new(varmap.all_vars(), 0.004)?;

    // Train
    for epoch in 1..=10 {
        let y_pred = model.forward(&sample_xs)?;
        let loss = candle_nn::loss::mse(&sample_ys, &y_pred)?;
        println!("Epoch: {}, Loss: {}", epoch, loss);

        sgd.backward_step(&loss)?;

    }

    Ok(())
}

Code Breakdown

The code creates a device variable that holds the value Device::Cpu, which is an enum variant from the candle_core crate that represents the CPU device for tensor computation.

It generates some linear data using the Linear struct from the candle_nn crate. It does that by creating tensors using the Tensor::new method from the candle_core crate, and passes them to the Linear::new method to create a Linear struct. Finally it creates a tensor of input samples using the Tensor::new method, and passes it to the forward method of the Linear struct to get the corresponding output tensor.

We create a varmap variable that holds an instance of the VarMap struct from candle_nn, which is a container for variables. The vs variable holds an instance of the VarBuilder struct, which is a helper for creating variables. It is created by passing the varmap, the data type (DType::F32), and the device to the VarBuilder::from_varmap method.

LinearRegression::new is used to create a LinearRegression struct with the given dimensions and variable builder.

The sgd variable holds a mutable instance of the SGD struct from candle_nn, which is an implementation of the stochastic gradient descent optimizer. It’s instantiated using varmap and the learning rate (0.004) by calling the SGD::new method.

Finally, we train our model for 10 epochs. In each epoch, we perform the following steps:

  1. Do a forward pass by calling the forward method of the LinearRegression struct to get the predicted output tensor for the input samples.
  2. Compute the loss by calling the candle_nn::loss::mse function, which returns the mean squared error between the predicted output and the true output.
  3. Apply backpropagation via the backward_step method of the SGD struct to update the variables using the gradient of the loss.
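The three steps above can be sketched without candle, which makes the arithmetic behind mse and backward_step visible. This is a simplified stand-in under stated assumptions: the gradients are written out by hand for the two-feature linear model instead of being computed by automatic differentiation, the parameters start at zero (candle draws random initial values, so the loss values will differ from the output shown in this article), and the names mse and train are illustrative.

```rust
/// Mean squared error between predictions and targets.
fn mse(preds: &[f32], targets: &[f32]) -> f32 {
    let n = preds.len() as f32;
    preds.iter().zip(targets).map(|(p, t)| (p - t).powi(2)).sum::<f32>() / n
}

/// Trains y = w . x + b with full-batch gradient descent.
/// Returns the learned weights, the bias, and the loss recorded at each epoch.
fn train(xs: &[[f32; 2]], ys: &[f32], lr: f32, epochs: usize) -> ([f32; 2], f32, Vec<f32>) {
    // Assumption: parameters start at zero instead of a random initialization.
    let (mut w, mut b) = ([0f32; 2], 0f32);
    let n = xs.len() as f32;
    let mut losses = Vec::with_capacity(epochs);

    for _ in 0..epochs {
        // 1. Forward pass: predict an output for every sample.
        let preds: Vec<f32> = xs.iter().map(|x| w[0] * x[0] + w[1] * x[1] + b).collect();

        // 2. Loss: mean squared error against the true outputs.
        losses.push(mse(&preds, ys));

        // 3. Backward pass: gradients of the MSE with respect to w and b.
        let (mut grad_w, mut grad_b) = ([0f32; 2], 0f32);
        for ((x, p), y) in xs.iter().zip(&preds).zip(ys) {
            let err = 2.0 * (p - y) / n;
            grad_w[0] += err * x[0];
            grad_w[1] += err * x[1];
            grad_b += err;
        }

        // SGD update: move each parameter against its gradient.
        w[0] -= lr * grad_w[0];
        w[1] -= lr * grad_w[1];
        b -= lr * grad_b;
    }
    (w, b, losses)
}
```

Calling train with the four samples from the article, a learning rate of 0.004, and 10 epochs returns a loss history that shrinks at every epoch, which is the same qualitative behaviour as the candle version.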

Running the Model

To train the model we execute cargo run, which produces the following output:

Epoch: 1, Loss: [32.9661]
Epoch: 2, Loss: [22.2148]
Epoch: 3, Loss: [15.1691]
Epoch: 4, Loss: [10.4938]
Epoch: 5, Loss: [7.3743]
Epoch: 6, Loss: [5.2876]
Epoch: 7, Loss: [3.8900]
Epoch: 8, Loss: [2.9533]
Epoch: 9, Loss: [2.3252]
Epoch: 10, Loss: [1.9036]

We can see that the training loss decreases rapidly, from about 33.0 in the first epoch to about 1.9 in the tenth.

A Keras-Like API for candle

While candle provides a low-level API for building neural network models, some users (including myself) may prefer a more intuitive and user-friendly way to define, compile, and train their models. That’s why I propose the addition of a high-level Keras-like API to candle. This API would allow users to define models in a sequential manner by adding layers one after the other. It would also provide methods for model compilation, training, and evaluation.

Here is an example of what the Keras-like API could look like:

use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

struct Sequential {
    layers: Vec<Linear>,
}

impl Sequential {
    fn new() -> Self {
        Sequential {
            layers: Vec::new(),
        }
    }

    fn add(&mut self, layer: Linear) {
        self.layers.push(layer);
    }

    fn compile(&self) -> Model {
        Model::new( &self.layers)
    }
}

struct Model {
    layers: Vec<Linear>,
}

impl Model {
    fn new(layers: &[Linear]) -> Model {
        // Take a slice and copy the layers into the compiled model.
        Model { layers: layers.to_vec() }
    }

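    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let mut x = image.clone();
        let last = self.layers.len().saturating_sub(1);
        for (i, layer) in self.layers.iter().enumerate() {
            x = layer.forward(&x)?;
            // Apply ReLU between layers only, so the final layer returns raw scores.
            if i < last {
                x = x.relu()?;
            }
        }
        Ok(x)
    }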
}

fn main() -> Result<()> {
    let device = Device::Cpu;

    let mut model = Sequential::new();
    model.add(Linear::new(
        Tensor::randn(0f32, 1.0, (100, 784), &device)?,
        Some(Tensor::randn(0f32, 1.0, (100,), &device)?),
    ));
    model.add(Linear::new(
        Tensor::randn(0f32, 1.0, (10, 100), &device)?,
        Some(Tensor::randn(0f32, 1.0, (10,), &device)?),
    ));

    let compiled_model = model.compile();

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let digit = compiled_model.forward(&dummy_image)?;
    println!("Digit: {:?}", digit);
    Ok(())
}

As you can see, the Keras-like API is simpler and more concise than the low-level API. It also abstracts away some of the details, such as implementing the Module trait, and provides user-friendly methods for adding layers and compiling the model.
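One possible direction for such an API is to store layers behind a trait, so that a sequential model is not limited to Linear and activations become layers the user places explicitly. The sketch below uses plain Rust vectors instead of candle Tensors to stay self-contained. Every name in it (Layer, Dense, Relu, and this variant of Sequential) is hypothetical and only illustrates the design, not an existing candle interface.

```rust
/// Hypothetical layer interface: anything that maps a vector to a vector.
trait Layer {
    fn forward(&self, input: &[f32]) -> Vec<f32>;
}

/// Fully connected layer; `weight` has one row per output unit.
struct Dense {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl Layer for Dense {
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        self.weight
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>() + b)
            .collect()
    }
}

/// Element-wise ReLU, expressed as its own layer so the caller decides where it goes.
struct Relu;

impl Layer for Relu {
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        input.iter().map(|&x| x.max(0.0)).collect()
    }
}

/// Ordered stack of heterogeneous layers.
#[derive(Default)]
struct Sequential {
    layers: Vec<Box<dyn Layer>>,
}

impl Sequential {
    /// Builder-style `add`, so layers can be chained in one expression.
    fn add(mut self, layer: impl Layer + 'static) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Runs the input through every layer in insertion order.
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        self.layers
            .iter()
            .fold(input.to_vec(), |x, layer| layer.forward(&x))
    }
}
```

A model is then assembled as Sequential::default().add(Dense { .. }).add(Relu).add(Dense { .. }). Boxing the layers costs a dynamic dispatch per layer call, which is usually negligible next to the tensor operations themselves.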

I believe that implementing a Keras-like API would greatly enhance the usability and appeal of candle. This feature can make it easier for users to define and train neural network models, making our crate more accessible and user-friendly.

Conclusion

candle is a deep learning Rust crate that provides a low-level API for building neural network models. candle leverages the speed and safety of Rust, while also exposing the flexibility and expressiveness of Tensor operations. I propose to add a high-level Keras-like API to candle, which would provide a more intuitive and user-friendly way to define, compile, and train neural network models. I hope that this feature will make candle more attractive and useful for data scientists and machine learning practitioners who want to use Rust for deep learning.
