More parallelism in Rust
The information in the main chapter covers most of what you'll need for easy parallel tasks in Rust: those with no interdependencies or shared resources. Below we cover some more advanced use cases, although the examples become a little more contrived.
Not covered is communication between threads. See the std::sync library for examples, particularly the mpsc module.
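As a very small taste of that topic, here is a minimal sketch (not from the chapter) in which worker threads send their results back to the main thread over an mpsc channel:

use std::sync::mpsc;
use std::thread;

fn main() {
    let (tx, rx) = mpsc::channel();
    for id in 0..4_u32 {
        // Each worker gets its own clone of the sending end of the channel
        let tx = tx.clone();
        thread::spawn(move || {
            tx.send(id * id).unwrap();
        });
    }
    // Drop the original sender so the receiver knows when all workers have finished
    drop(tx);
    for result in rx {
        println!("received {result}");
    }
}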
Using shared memory
A real advantage of using compiled languages for parallel code is that you can easily control the memory used, and whether or not it is copied. In Python and R it is often the case that each new thread copies the memory it needs (often the entire program's memory), which can be slow and can multiply memory use by the number of cores. On the other hand, you need to be more careful that multiple threads aren't trying to write to the same data, and that their results don't differ if their execution order isn't deterministic.
Reading from shared memory without copying it is typically easy:
use std::time::Instant;
use rayon::prelude::*;
use rand::prelude::*;

fn is_prime(int: u16) -> bool {
    let upper = (int as f64).sqrt().floor() as u16;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}

fn main() {
    let size = 20000;
    let mut rng = thread_rng();
    let mut arr = vec![0_u16; size];
    rng.fill(&mut arr[..]);

    let threads = 4;
    rayon::ThreadPoolBuilder::new().num_threads(threads).build_global().unwrap();

    let start = Instant::now();
    let integers: Vec<usize> = (0..size).collect();
    let primes: Vec<(u16, bool)> = integers.par_iter().map(|i| {
        let prime = is_prime(arr[*i]);
        (arr[*i], prime)
    }).collect();

    // Timing
    let end = Instant::now();
    let parallel_time = end.duration_since(start).as_micros();
    println!("Parallel done in {}us", parallel_time);

    // Check results look ok
    println!("{:?}", &primes[0..10]);
}
In this slightly contrived example we first create random integers between 0 and \(2^{16}\) and store them in a Vec called arr. Each iteration then reads a position in this list and processes it. At no point is the list copied (we didn't use .clone()).
It is also possible to write to shared memory, if you know that every thread will write to different positions (no race conditions). One way of doing this is with .par_chunks_mut(), which splits the shared memory into separate non-overlapping chunks for each task to write to:
#![allow(unused)]
fn main() {
    use rayon::prelude::*;
    use rand::prelude::*;

    fn is_prime(int: u16) -> bool {
        let upper = (int as f64).sqrt().floor() as u16;
        for div in 2..=upper {
            if int % div == 0 {
                return false;
            }
        }
        true
    }

    // We make a list of random numbers to check whether they are prime
    let size = 20000;
    let mut rng = thread_rng();
    let mut arr = vec![0u16; size];
    rng.fill(&mut arr[..]);

    // Use `par_chunks_mut()` to split this list into non-overlapping chunks of 1000,
    // which we call 'slice'
    let chunk_size = 1000;
    arr.par_chunks_mut(chunk_size).for_each(|slice| {
        // Each thread iterates over its own slice; if the random number is not prime
        // then we overwrite it with zero
        for int in slice.iter_mut() {
            if !is_prime(*int) {
                *int = 0;
            }
        }
    });

    // Check results look ok
    println!("{:?}", &arr[0..10]);
}
For more complex cases where threads may write to the same location, we will need to manage these shared resources more carefully.
Managing shared memory: atomics
Let's say we now want to calculate the sum of the prime numbers up to n. With serial code this is simple enough:
use std::time::Instant;

fn is_prime(int: u32) -> bool {
    let upper = (int as f64).sqrt().floor() as u32;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}

fn main() {
    let size = 20000;
    let start = Instant::now();
    let integers: Vec<usize> = (2..size).collect();
    let mut sum = 0;
    integers.iter().for_each(|i| {
        if is_prime(*i as u32) {
            sum += *i;
        }
    });

    // Timing
    let end = Instant::now();
    let serial_time = end.duration_since(start).as_micros();
    println!("Serial done in {}us", serial_time);

    // Check results look ok
    println!("{sum}");
}
But if we try to replace .iter() with .par_iter() in the above code we get the following error:
error[E0594]: cannot assign to `sum`, as it is a captured variable in a `Fn` closure
   --> src/main.rs:117:13
    |
117 |             sum += *i;
    |             ^^^^^^^^^ cannot assign

For more information about this error, try `rustc --explain E0594`.
Rust is actually stopping us from getting this wrong. In an unsafe language we could probably use sum in this way without a compiler error, but the sum would be wrong because threads would load and store it at the same time, overwriting each other's updates.
This would look something like this:
// thread 1             thread 2
// read sum
// add prime1 to sum    read sum
// write sum            add prime2 to sum
// sum = prime1         write sum
//                      sum = prime2
This is a type of race condition, where the order of operations affects the result, but the order is not guaranteed. Here we end up with sum having missed the addition of prime1 from thread 1.
One way of dealing with this is to make sum 'atomic', so that the load, addition, and store back to memory happen as a single unbreakable unit:
use std::time::Instant;
use std::sync::atomic::{Ordering, AtomicU32};
use rayon::prelude::*;

fn is_prime(int: u32) -> bool {
    let upper = (int as f64).sqrt().floor() as u32;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}

fn main() {
    let size = 20000;
    let start = Instant::now();
    let integers: Vec<usize> = (2..size).collect();
    let sum = AtomicU32::new(0);
    integers.par_iter().for_each(|i| {
        if is_prime(*i as u32) {
            // fetch_add loads the current sum, adds the value to it, then writes
            // back to sum -- all in a single uninterruptable operation
            sum.fetch_add(*i as u32, Ordering::SeqCst);
        }
    });

    // Timing
    let end = Instant::now();
    let atomic_time = end.duration_since(start).as_micros();
    println!("Atomic done in {}us", atomic_time);

    // We need to use load to read from the atomic
    println!("{}", sum.load(Ordering::SeqCst));
}
If multiple threads are trying to use the atomic at the same time this will decrease efficiency as they have to wait, but typically there is a low overhead to using atomics when threads aren't in conflict.
Sums are commutative, i.e. A + B = B + A, so we can add the results in any order. They are also associative, i.e. (A + B) + C = A + (B + C), so no matter how we group the elements, we can sum each group and then sum those partial results to get the correct final answer.
This means a sum can be computed as a reduction: we could either have each thread return its own non-atomic partial sum, which we then add together, or simply use the built-in .sum() method on the parallel iterator.
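As a sketch of the second option (reusing the is_prime function from above, and not shown in the original), the whole calculation becomes a short chain on a parallel iterator:

use rayon::prelude::*;

fn is_prime(int: u32) -> bool {
    let upper = (int as f64).sqrt().floor() as u32;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}

fn main() {
    let size: u32 = 20000;
    // Each thread sums the primes in its own part of the range, and rayon
    // adds the per-thread partial sums together at the end -- no atomics needed
    let sum: u32 = (2..size)
        .into_par_iter()
        .filter(|i| is_prime(*i))
        .sum();
    println!("{sum}");
}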
Managing shared resources: mutex
What if we want to have a shared Vec which keeps track of which numbers are prime? There is no atomic for more complex types, only for basic types (integers, booleans and pointers). Instead we can use a mutual exclusion lock, a 'mutex', to protect shared data when we write to it, which works with any type:
use std::sync::{Arc, Mutex};
use std::thread;

fn is_prime(int: u32) -> bool {
    let upper = (int as f64).sqrt().floor() as u32;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}

fn main() {
    let size = 30;
    let primes = Arc::new(Mutex::new(Vec::new()));
    for i in 2..size {
        let list = Arc::clone(&primes);
        thread::spawn(move || {
            if is_prime(i as u32) {
                let mut list_mutex = list.lock().unwrap();
                list_mutex.push(i);
            }
        });
    }
    println!("{:?}", primes.lock().unwrap());
}
There are a few things going on here:
- We use a Mutex to wrap the Vec.
- Before we can use the mutex we call .lock() on it; this makes the thread wait until it can acquire the lock, which only a single thread can hold at a time.
- Calling .lock().unwrap() then lets us write to the contained object.
- The mutex is automatically unlocked when the lock guard goes out of scope at the end of the if statement.
- We also need to use Arc to share (in a thread-safe way) a reference to the mutex between threads. If we don't use this we'll get a compile error, as each spawned thread would try to take ownership of the mutex, but only one of them can.
The results might be out of order, and because we never join the spawned threads before printing, some may even be missing (note that 29 is absent from this run):

[2, 3, 5, 7, 13, 11, 17, 19, 23]
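If we need every result to be present before we read the Vec, one option (a sketch of a fix, not part of the original example) is to keep the JoinHandle returned by each thread::spawn() call and join them all before printing:

use std::sync::{Arc, Mutex};
use std::thread;

fn is_prime(int: u32) -> bool {
    let upper = (int as f64).sqrt().floor() as u32;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}

fn main() {
    let size = 30;
    let primes = Arc::new(Mutex::new(Vec::new()));
    let mut handles = Vec::new();
    for i in 2..size {
        let list = Arc::clone(&primes);
        // Keep the handle so we can wait for this thread to finish later
        handles.push(thread::spawn(move || {
            if is_prime(i as u32) {
                list.lock().unwrap().push(i);
            }
        }));
    }
    // Block until every spawned thread has finished before reading the results
    for handle in handles {
        handle.join().unwrap();
    }
    println!("{:?}", primes.lock().unwrap());
}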
We are also likely to get a reduction in parallel efficiency: locking and unlocking the mutex takes some time, but more importantly threads block while they wait to acquire a lock held by another thread.
A typical example of a mutex might be stopping multiple threads that write output to the terminal from writing over the top of each other, but println!() actually already guards against this with a lock on stdout internally.
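If you do want several lines from one thread to stay together, a sketch of one way to do it (an illustration, not from the chapter) is to hold the stdout lock yourself across the writes:

use std::io::Write;
use std::thread;

fn main() {
    let mut handles = Vec::new();
    for id in 0..4 {
        handles.push(thread::spawn(move || {
            // Hold the stdout lock so these two lines are always printed together,
            // rather than interleaved with output from other threads
            let stdout = std::io::stdout();
            let mut out = stdout.lock();
            writeln!(out, "thread {id} starting").unwrap();
            writeln!(out, "thread {id} finishing").unwrap();
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
}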
Another use for a mutex might be where every thread writes to the same file. In this case you may want to use RwLock, which still allows multiple threads to read at the same time, but only one thread at a time to have write access.
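A minimal sketch of RwLock (using a shared Vec rather than a file, to keep the example self-contained):

use std::sync::{Arc, RwLock};
use std::thread;

fn main() {
    let data = Arc::new(RwLock::new(vec![1, 2, 3]));
    let mut handles = Vec::new();

    // Any number of readers can hold the lock at the same time
    for _ in 0..2 {
        let data = Arc::clone(&data);
        handles.push(thread::spawn(move || {
            let values = data.read().unwrap();
            println!("sum is {}", values.iter().sum::<i32>());
        }));
    }

    // Only one writer can hold the lock, and it excludes all readers while it does
    {
        let data = Arc::clone(&data);
        handles.push(thread::spawn(move || {
            data.write().unwrap().push(4);
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }
    println!("{:?}", data.read().unwrap());
}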