Candle

Candle defines itself as a "minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use".

Let's see examples of using Candle for neural networks and LLMs in Rust.
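
To get a feel for the API before the full examples, here is a minimal sketch of basic tensor operations with Candle (it assumes the same candle = { version = "0.4", package = "candle-core" } dependency rename used in the Cargo.toml files below):

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    // A 2x3 matrix and a 3x1 column vector, both f32.
    let a = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
    let b = Tensor::new(&[[0.5f32], [1.0], [-1.0]], &device)?;
    // Matrix multiplication followed by an element-wise ReLU.
    let c = a.matmul(&b)?.relu()?;
    println!("{c}");
    Ok(())
}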

Training simple dense neural networks with Candle

Let's start by defining our Cargo.toml file:

[package]
name = "candle-nn"
version = "0.1.0"
edition = "2021"

[dependencies]
candle = { version = "0.4", package = "candle-core" }
candle-nn = { version = "0.4" }
tqdm = "0.6"

[profile.dev.package."*"]
opt-level = 3

[features]
cuda = ["candle/cuda"]

and then our src/main.rs:

use candle::{DType, Device, Tensor};
use candle_nn::{
    linear,
    loss::mse,
    optim::{AdamW, Optimizer, ParamsAdamW},
    Linear, Module, VarBuilder, VarMap,
};
use std::error::Error;
use tqdm::tqdm;

struct DenseNeuralNetwork {
    ln1: Linear,
    ln2: Linear,
}

impl DenseNeuralNetwork {
    fn new(vs: VarBuilder) -> Result<Self, Box<dyn Error>> {
        let ln1 = linear(10, 100, vs.pp("ln1"))?;
        let ln2 = linear(100, 1, vs.pp("ln2"))?;
        Ok(Self { ln1, ln2 })
    }
}

impl Module for DenseNeuralNetwork {
    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
        let xs = self.ln1.forward(xs)?;
        let xs = xs.relu()?;
        self.ln2.forward(&xs)
    }
}

fn main() -> Result<(), Box<dyn Error>> {
    let device = Device::cuda_if_available(0)?;
    println!("Using device: {:?}", device);
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    let dnn = DenseNeuralNetwork::new(vs)?;
    let mut optimizer =
        AdamW::new(varmap.all_vars(), ParamsAdamW::default())?;

    let beta = Tensor::randn(0f32, 1., (10, 1), &device)?;

    // Generate a training set
    let xs_train = Tensor::randn(0f32, 1., (100, 10), &device)?;
    let mu = xs_train.cos()?.matmul(&beta)?;
    let eps = Tensor::randn(0f32, 1., (100, 1), &device)?;
    let ys_train = (mu + eps)?;

    // Generate a validation set
    let xs_val = Tensor::randn(0f32, 1., (100, 10), &device)?;
    let mu = xs_val.cos()?.matmul(&beta)?;
    let eps = Tensor::randn(0f32, 1., (100, 1), &device)?;
    let ys_val = (mu + eps)?;

    let n_epochs = 10_000;
    let mut losses_val = Vec::<f32>::with_capacity(n_epochs);

    for epoch in tqdm(0..n_epochs) {
        if epoch % (n_epochs / 10) == 0 || epoch == n_epochs - 1 {
            losses_val
                .push(mse(&dnn.forward(&xs_val)?, &ys_val)?.to_scalar()?);
        }

        let gradients =
            mse(&dnn.forward(&xs_train)?, &ys_train)?.backward()?;
        optimizer.step(&gradients)?;
    }
    println!("Losses on validation set: {:?}", losses_val);
    Ok(())
}
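
Once training finishes, you will typically want to use the network on new data and keep the learned weights. Here is a minimal sketch of how that could look (the helper function and the file name are illustrative, not part of the listing above; it reuses the DenseNeuralNetwork struct and the VarMap from that listing, and candle-nn's VarMap::save writes the parameters to a safetensors file):

use candle::{Device, Tensor};
use candle_nn::{Module, VarMap};
use std::error::Error;

// Illustrative helper: persist the trained parameters and predict on new inputs.
fn save_and_predict(
    dnn: &DenseNeuralNetwork,
    varmap: &VarMap,
    device: &Device,
) -> Result<(), Box<dyn Error>> {
    // Save all variables tracked by the VarMap to a safetensors file.
    varmap.save("dense_nn.safetensors")?;

    // Run the trained network on a fresh batch of 5 inputs with 10 features each.
    let xs_new = Tensor::randn(0f32, 1., (5, 10), device)?;
    let preds = dnn.forward(&xs_new)?;
    println!("Predictions: {preds}");
    Ok(())
}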

For maximum performance, it's crucial to compile it with the native CPU flag, e.g.: RUSTFLAGS="-Ctarget-cpu=native" cargo run.

If you have a CUDA-enabled GPU and the drivers are properly installed on your computer (check with nvidia-smi), you can run it with cargo run --features cuda.

Mistral example with Candle

Let's start by defining our Cargo.toml file:

[package]
name = "candle-mistral"
version = "0.1.0"
edition = "2021"

[dependencies]
candle = { version = "0.4", package = "candle-core" }
candle-nn = { version = "0.4" }
candle-transformers = { version = "0.4" }
hf-hub = "0.3.2"
tokenizers = "0.15"

# Optional dependencies
accelerate-src = { version = "0.3.0", optional = true }
intel-mkl-src = { version = "0.8", optional = true, features = [
    "mkl-static-lp64-iomp",
] }

# Optional features
# E.g.: compile with
# cargo run --features mkl
[features]
default = []
accelerate = [
    "dep:accelerate-src",
    "candle/accelerate",
    "candle-nn/accelerate",
    "candle-transformers/accelerate",
]
mkl = [
    "dep:intel-mkl-src",
    "candle/mkl",
    "candle-nn/mkl",
    "candle-transformers/mkl",
]

[profile.dev.package."*"]
opt-level = 3

and then our src/main.rs:

// Adapted from https://github.com/huggingface/candle/blob/main/candle-examples/examples/quantized/main.rs
// which is licensed under
// https://github.com/huggingface/candle/blob/main/LICENSE-APACHE
// https://github.com/huggingface/candle/blob/main/LICENSE-MIT

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The prompt. If None, this will be an interactive chat.
    let prompt: Option<String> = None;

    // The length of the sample to generate (in tokens).
    let sample_len: usize = 100;

    // The temperature used to generate samples, use 0 for greedy sampling.
    let temperature: f64 = 0.8;

    // Nucleus sampling probability cutoff.
    let top_p: Option<f64> = None;

    // The seed to use when generating random samples.
    let seed: u64 = 299792458;

    // Display the tokens for the specified prompt.
    let verbose_prompt: bool = false;

    // Penalty to be applied for repeating tokens, 1. means no penalty.
    let repeat_penalty: f32 = 1.1;

    // The context size to consider for the repeat penalty.
    let repeat_last_n: usize = 64;

    let temperature = if temperature == 0. {
        None
    } else {
        Some(temperature)
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );

    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        temperature.unwrap_or_default(),
        repeat_penalty,
        repeat_last_n
    );

    //let repo = "TheBloke/Mistral-7B-v0.1-GGUF";
    let repo = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF";

    // let filename = "mistral-7b-instruct-v0.1.Q4_K_S.gguf";
    let filename = "mistral-7b-instruct-v0.1.Q2_K.gguf";

    let api = hf_hub::api::sync::Api::new()?;
    let api = api.model(repo.to_string());
    let model_path = api.get(filename)?;

    let mut file = std::fs::File::open(model_path)?;
    let start = std::time::Instant::now();
    let device = Device::Cpu;

    let mut model = {
        let model = gguf_file::Content::read(&mut file)?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes += elem_count
                * tensor.ggml_dtype.type_size()
                / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        ModelWeights::from_gguf(model, &mut file, &device)?
    };
    println!("model built");

    let api = hf_hub::api::sync::Api::new()?;
    let repo = "mistralai/Mistral-7B-v0.1";
    let api = api.model(repo.to_string());
    let tokenizer_path = api.get("tokenizer.json")?;

    let tokenizer = Tokenizer::from_file(tokenizer_path)
        .map_err(|e| format!("Error loading tokenizer: {e}"))?;

    let mut pre_prompt_tokens = vec![];
    loop {
        let prompt_str = {
            let prompt = if let Some(ref prompt) = prompt {
                prompt.to_owned()
            } else {
                print!("> ");
                std::io::stdout().flush()?;
                let mut prompt = String::new();
                std::io::stdin().read_line(&mut prompt)?;
                if prompt.ends_with('\n') {
                    prompt.pop();
                    if prompt.ends_with('\r') {
                        prompt.pop();
                    }
                }
                prompt
            };

            format!("[INST] {prompt} [/INST]")
        };
        print!("{}", &prompt_str);
        let tokens = tokenizer
            .encode(prompt_str, true)
            .map_err(|e| format!("Error encoding tokenizer: {e}"))?;
        if verbose_prompt {
            for (token, id) in
                tokens.get_tokens().iter().zip(tokens.get_ids().iter())
            {
                let token =
                    token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }

        let prompt_tokens =
            [&pre_prompt_tokens, tokens.get_ids()].concat();
        let to_sample = sample_len.saturating_sub(1);
        let prompt_tokens = if prompt_tokens.len() + to_sample
            > model::MAX_SEQ_LEN - 10
        {
            let to_remove =
                prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
            prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..]
                .to_vec()
        } else {
            prompt_tokens
        };
        let mut all_tokens = vec![];
        let mut logits_processor =
            LogitsProcessor::new(seed, temperature, top_p);

        let start_prompt_processing = std::time::Instant::now();
        let mut next_token = {
            let input = Tensor::new(prompt_tokens.as_slice(), &device)?
                .unsqueeze(0)?;
            let logits = model.forward(&input, 0)?;
            let logits = logits.squeeze(0)?;
            logits_processor.sample(&logits)?
        };
        let prompt_dt = start_prompt_processing.elapsed();
        all_tokens.push(next_token);
        print_token(next_token, &tokenizer);

        let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();

        let start_post_prompt = std::time::Instant::now();
        for index in 0..to_sample {
            let input =
                Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
            let logits =
                model.forward(&input, prompt_tokens.len() + index)?;
            let logits = logits.squeeze(0)?;
            let logits = if repeat_penalty == 1. {
                logits
            } else {
                let start_at =
                    all_tokens.len().saturating_sub(repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    repeat_penalty,
                    &all_tokens[start_at..],
                )?
            };
            next_token = logits_processor.sample(&logits)?;
            all_tokens.push(next_token);
            print_token(next_token, &tokenizer);
            if next_token == eos_token {
                break;
            };
        }
        let dt = start_post_prompt.elapsed();
        println!(
            "\n\n{:4} prompt tokens processed: {:.2} token/s",
            prompt_tokens.len(),
            prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
        );
        println!(
            "{:4} tokens generated: {:.2} token/s",
            to_sample,
            to_sample as f64 / dt.as_secs_f64(),
        );

        match prompt {
            Some(_) => break,
            None => {
                pre_prompt_tokens =
                    [prompt_tokens.as_slice(), all_tokens.as_slice()]
                        .concat()
            }
        }
    }

    Ok(())
}

fn print_token(next_token: u32, tokenizer: &Tokenizer) {
    // Extracting the last token as a string is complicated, here we just apply some simple
    // heuristics as it seems to work well enough for this example. See the following for more
    // details:
    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
    if let Some(text) = tokenizer.id_to_token(next_token) {
        let text = text.replace('▁', " ");
        let ascii = text
            .strip_prefix("<0x")
            .and_then(|t| t.strip_suffix('>'))
            .and_then(|t| u8::from_str_radix(t, 16).ok());
        match ascii {
            None => print!("{text}"),
            Some(ascii) => {
                if let Some(chr) = char::from_u32(ascii as u32) {
                    if chr.is_ascii() {
                        print!("{chr}")
                    }
                }
            }
        }
        let _ = std::io::stdout().flush();
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}
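
Running it starts an interactive chat loop: you type a prompt after the > and the model answers. If you prefer a single non-interactive run, it is enough to set the prompt variable at the top of main (the match prompt at the end of the loop then breaks after the first answer), e.g.:

// One-shot mode: answer a fixed prompt and exit instead of looping.
let prompt: Option<String> =
    Some("Explain Rust's ownership model in one paragraph.".to_string());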

For maximum performance, it's crucial to compile it with the native CPU flag, e.g.: RUSTFLAGS="-Ctarget-cpu=native" cargo run.

If you wish to deploy this as a Docker container, you can use this Dockerfile as a recipe:

FROM rust

ENV RUSTFLAGS="-Ctarget-cpu=native"

WORKDIR /app

COPY Cargo.toml Cargo.toml

COPY Cargo.lock Cargo.lock

RUN mkdir src && echo 'fn main() {panic!("not ready");}' > src/main.rs

RUN cargo build --release --locked && rm -rf src

COPY src src

RUN touch src/main.rs && cargo build --release --locked

CMD cargo run --release --locked
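
The dummy src/main.rs and the first cargo build exist only so that the dependencies are compiled, and therefore cached in their own Docker layer, before the real sources are copied in; the final cargo build then only recompiles the application itself. Assuming the Dockerfile sits next to Cargo.toml and Cargo.lock, you can build the image with docker build -t candle-mistral . and start the chat with docker run --rm -it candle-mistral.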

Llama example with Candle

Before running this, you need access to the Llama 2 model files: the code below downloads them from the meta-llama/Llama-2-7b-hf repository on Hugging Face, so you must have accepted Meta's license for that model on the Hugging Face Hub and have a Hugging Face access token available locally (the hf-hub crate picks it up from the Hugging Face cache, e.g. after logging in with huggingface-cli login).
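
If you prefer to pass the token explicitly rather than relying on the cached one, a small sketch could look like this (the HF_TOKEN environment variable and the helper name are illustrative; hf-hub's ApiBuilder accepts an optional token):

use hf_hub::api::sync::{Api, ApiBuilder};
use std::error::Error;

// Illustrative helper: build an hf-hub client with a token taken from the
// HF_TOKEN environment variable (anonymous access if it is not set).
fn hf_api_from_env() -> Result<Api, Box<dyn Error>> {
    let api = ApiBuilder::new()
        .with_token(std::env::var("HF_TOKEN").ok())
        .build()?;
    Ok(api)
}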

With all that done, let's start by defining our Cargo.toml file:

[package]
name = "candle-llama"
version = "0.1.0"
edition = "2021"

[dependencies]
candle = { version = "0.4", package = "candle-core" }
candle-nn = { version = "0.4" }
candle-transformers = { version = "0.4" }
hf-hub = "0.3.2"
tokenizers = "0.15"
serde_json = "1.0.107"

[profile.dev.package."*"]
opt-level = 3

and then our src/main.rs:

// Adapted from
// https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
// which is licensed under
// https://github.com/huggingface/candle/blob/main/LICENSE-APACHE
// https://github.com/huggingface/candle/blob/main/LICENSE-MIT
//
// An implementation of LLaMA https://github.com/facebookresearch/llama
//
// This is based on nanoGPT in a similar way to:
// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
    generation::LogitsProcessor, models::llama as model,
};
use hf_hub::{api::sync::Api, Repo, RepoType};
use model::{Llama, LlamaConfig};
use std::io::Write;
use tokenizers::Tokenizer;

const EOS_TOKEN: &str = "</s>";

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The initial prompt.
    let prompt = "I enjoy Rust because ";
    // The length of the sample to generate (in tokens).
    let sample_len: usize = 100;
    // The seed to use when generating random samples.
    let seed = 7;

    // The temperature used to generate samples.
    let temperature: Option<f64> = None;
    // Nucleus sampling probability cutoff.
    let top_p: Option<f64> = None;
    // Penalty to be applied for repeating tokens, 1. means no penalty.
    let repeat_penalty: f32 = 1.0;
    // The context size to consider for the repeat penalty.
    let repeat_last_n: usize = 64;

    let device = Device::Cpu;
    let dtype = DType::F16;
    let (llama, tokenizer_filename, mut cache) = {
        let api = Api::new()?;
        let model_id = "meta-llama/Llama-2-7b-hf".to_string();
        println!("loading the model weights from {model_id}");
        let revision = "main".to_string();
        let api = api.repo(Repo::with_revision(
            model_id,
            RepoType::Model,
            revision,
        ));

        let tokenizer_filename = api.get("tokenizer.json")?;

        let config_filename = api.get("config.json")?;
        let config: LlamaConfig =
            serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(false);

        let mut filenames = vec![];
        for rfilename in [
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ] {
            let filename = api.get(rfilename)?;
            filenames.push(filename);
        }

        println!("building the model");
        let cache = model::Cache::new(true, dtype, &config, &device)?;

        // Safety: the safetensors files are memory-mapped, so they must not
        // be modified while the model is in use.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(
                &filenames, dtype, &device,
            )?
        };
        (Llama::load(vb, &config)?, tokenizer_filename, cache)
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename)
        .map_err(|e| format!("Error loading tokenizer: {e}"))?;
    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(|e| format!("Error encoding tokenizer: {e}"))?
        .get_ids()
        .to_vec();

    println!("starting the inference loop");
    print!("{prompt}");
    let mut logits_processor =
        LogitsProcessor::new(seed, temperature, top_p);
    let start_gen = std::time::Instant::now();
    let mut index_pos = 0;
    let mut token_generated = 0;
    for index in 0..sample_len {
        let context_size = if cache.use_kv_cache && index > 0 {
            1
        } else {
            tokens.len()
        };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, index_pos, &mut cache)?;
        let logits = logits.squeeze(0)?;
        let logits = if repeat_penalty == 1. {
            logits
        } else {
            let start_at = tokens.len().saturating_sub(repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                repeat_penalty,
                &tokens[start_at..],
            )?
        };
        index_pos += ctxt.len();

        let next_token = logits_processor.sample(&logits)?;
        token_generated += 1;
        tokens.push(next_token);

        // Extracting the last token as a string is complicated, here we just apply some simple
        // heuristics as it seems to work well enough for this example. See the following for more
        // details:
        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
        if let Some(text) = tokenizer.id_to_token(next_token) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
        if Some(next_token) == eos_token_id {
            break;
        }
    }
    let dt = start_gen.elapsed();
    println!(
        "\n\n{} tokens generated ({} token/s)\n",
        token_generated,
        token_generated as f64 / dt.as_secs_f64(),
    );
    Ok(())
}

For maximum performance, it's crucial to compile it with the native CPU flag, e.g.: RUSTFLAGS="-Ctarget-cpu=native" cargo run (yes, this is the third time I say it, but it really is important: performance will suffer a lot if you don't do it).

Slack chatbot using an LLM with Candle

Another example application is a Slack chatbot powered by an LLM running on Candle, available at github.com/randommm/rust-slackbot-llm/.

