More parallelism in Rust

The main chapter covers most of what you'll need for simple parallel tasks in Rust, which have no interdependencies or shared resources. Below we cover some more advanced use cases, although the examples become a little more contrived.

Communication between threads is not covered here. See the std::sync module documentation for examples, particularly the std::sync::mpsc module.

Using shared memory

A real advantage of using compiled languages for parallel code is that you can easily control the memory used, and whether it is copied or not. In Python and R, parallel workers are often separate processes, each of which may receive its own copy of the data (sometimes of the entire program's memory). This can be slow, and it multiplies memory use by the number of cores. On the other hand, you need to take more care that multiple threads aren't writing to the same data, and that results don't change when the execution order isn't deterministic.

Reading from shared memory without copying it is typically easy:

use std::time::Instant;
use rayon::prelude::*;
use rand::prelude::*;

fn is_prime(int: u16) -> bool {
    // 0 and 1 are not prime (the loop below would otherwise accept them)
    if int < 2 {
        return false;
    }
    let upper = (int as f64).sqrt().floor() as u16;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}
fn main() {

    let size = 20000;
    let mut rng = thread_rng();
    let mut arr = vec![0_u16; size];
    rng.fill(&mut arr[..]);

    let threads = 4;
    rayon::ThreadPoolBuilder::new().num_threads(threads).build_global().unwrap();

    let start = Instant::now();
    let integers: Vec<usize> = (0..size).collect();
    let primes: Vec<(u16, bool)> = integers.par_iter().map(|i| {
        let prime = is_prime(arr[*i]);
        (arr[*i], prime)
    }).collect();

    // Timing
    let end = Instant::now();
    let parallel_time = end.duration_since(start).as_micros();
    println!("Parallel done in {}us", parallel_time);

    // Check results look ok
    println!("{:?}", &primes[0..10]);
}

In this slightly contrived example we first create random integers in the range 0 to \(2^{16} - 1\) and store them in a Vec called arr. Each parallel iteration then reads one position in this Vec and tests whether it is prime. At no point is the Vec copied (we never call .clone()), so every thread reads the same memory.
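For comparison, here is a minimal sketch that uses neither rayon nor rand. It relies on the standard library's scoped threads (std::thread::scope, stable since Rust 1.63). Each thread receives a shared reference to the Vec and reports the address of the data it sees. Every thread reports the same address as the original, whereas an explicit .clone() allocates a new buffer. The helper name addresses_seen_by_threads is hypothetical.

```rust
use std::thread;

/// Hypothetical helper: spawns `n_threads` scoped threads that each borrow
/// `data` and report the address of its first element.
fn addresses_seen_by_threads(data: &[u16], n_threads: usize) -> Vec<usize> {
    thread::scope(|s| {
        let handles: Vec<_> = (0..n_threads)
            // `move` copies the reference `data` into each thread,
            // not the elements it points to
            .map(|_| s.spawn(move || data.as_ptr() as usize))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    })
}

fn main() {
    let arr: Vec<u16> = (0..20000u32).map(|x| x as u16).collect();
    let original = arr.as_ptr() as usize;

    // Every thread reads the very same buffer
    let seen = addresses_seen_by_threads(&arr, 4);
    assert!(seen.iter().all(|&addr| addr == original));

    // An explicit clone, by contrast, allocates a new buffer
    let copy = arr.clone();
    assert_ne!(copy.as_ptr() as usize, original);

    println!("{} threads all read the buffer at {:#x}", seen.len(), original);
}
```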

It is also possible to write to shared memory, as long as every thread writes to different positions (so there are no race conditions). One way of doing this is with .par_chunks_mut(), which splits the data into separate, non-overlapping mutable chunks:

use std::time::Instant;
use rayon::prelude::*;
use rand::prelude::*;

fn is_prime(int: u16) -> bool {
    // 0 and 1 are not prime
    if int < 2 {
        return false;
    }
    let upper = (int as f64).sqrt().floor() as u16;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}

fn main() {
    // We make a list of random numbers to check whether they are prime
    let size = 20000;
    let mut rng = thread_rng();
    let mut arr = vec![0_u16; size];
    rng.fill(&mut arr[..]);

    // Use `par_chunks_mut()` to split this list into non-overlapping
    // mutable chunks of 1000, which we call 'slice'
    let start = Instant::now();
    let chunk_size = 1000;
    arr.par_chunks_mut(chunk_size).for_each(|slice| {
        // Each thread iterates over its slice; if the random number is not
        // prime then we overwrite it with zero
        for int in slice.iter_mut() {
            if !is_prime(*int) {
                *int = 0;
            }
        }
    });

    // Timing
    let end = Instant::now();
    let parallel_time = end.duration_since(start).as_micros();
    println!("Parallel done in {}us", parallel_time);

    // Check results look ok
    println!("{:?}", &arr[0..10]);
}
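If you'd rather not depend on rayon, the same disjoint-write pattern can be sketched with the standard library's chunks_mut() and scoped threads (std::thread::scope, Rust 1.63+). Unlike rayon's thread pool, this spawns one OS thread per chunk, so it is only sensible for a small number of large chunks. The function name zero_non_primes and the small input are illustrative.

```rust
use std::thread;

fn is_prime(int: u16) -> bool {
    if int < 2 {
        return false;
    }
    let upper = (int as f64).sqrt().floor() as u16;
    (2..=upper).all(|div| int % div != 0)
}

/// Overwrites every non-prime in `arr` with zero, giving each scoped
/// thread its own non-overlapping mutable chunk.
fn zero_non_primes(arr: &mut [u16], chunk_size: usize) {
    thread::scope(|s| {
        for slice in arr.chunks_mut(chunk_size) {
            // `slice` is a `&mut [u16]` that no other thread can see,
            // so no locking is needed
            s.spawn(move || {
                for int in slice.iter_mut() {
                    if !is_prime(*int) {
                        *int = 0;
                    }
                }
            });
        }
        // All spawned threads are joined automatically when the scope ends
    });
}

fn main() {
    let mut arr: Vec<u16> = (0..20).collect();
    zero_non_primes(&mut arr, 5);
    // prints [0, 0, 2, 3, 0, 5, 0, 7, 0, 0, 0, 11, 0, 13, 0, 0, 0, 17, 0, 19]
    println!("{:?}", arr);
}
```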

For more complex cases where threads may write to the same location, we will need to manage these shared resources more carefully.

Managing shared memory: atomics

Let's say we now want to calculate the sum of the prime numbers below some limit (size in the code). With serial code this is simple enough:

use std::time::Instant;

fn is_prime(int: u32) -> bool {
    let upper = (int as f64).sqrt().floor() as u32;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}
fn main() {
    let size = 20000;

    let start = Instant::now();

    let integers: Vec<usize> = (2..size).collect();
    let mut sum = 0;
    integers.iter().for_each(|i| {
        if is_prime(*i as u32) {
            sum += *i;
        }
    });

    // Timing
    let end = Instant::now();
    let serial_time = end.duration_since(start).as_micros();
    println!("Serial done in {}us", serial_time);

    // Check results look ok
    println!("{sum}");
}

But if we replace .iter() with .par_iter() in the code above (with rayon::prelude::* imported), we get the following error:

error[E0594]: cannot assign to `sum`, as it is a captured variable in a `Fn` closure
   --> src/main.rs:117:13
    |
117 |             sum += *i;
    |             ^^^^^^^^^ cannot assign

For more information about this error, try `rustc --explain E0594`.
warning: `week7` (bin "week7") generated 1 warning

Rust is actually stopping us from making a mistake here. In a language without these compile-time checks (C, for example) we could use sum this way without a compiler error, but the result would likely be wrong, because several threads would read and write it at the same time, overwriting each other's updates.

This would look something like this:

// thread 1                 thread 2
// read sum
// add prime1 to sum        // read sum
// write sum                // add prime2 to sum
// sum = prime1             // write sum
                            // sum = prime2

This is a type of race condition: the order of operations affects the result, but that order is not guaranteed. Here sum ends up missing the addition of prime1 from thread 1.

One way of dealing with this is to make sum 'atomic', so that the load, the addition and the store back to memory happen as a single indivisible operation:

use std::time::Instant;
use std::sync::atomic::{Ordering, AtomicU32};
use rayon::prelude::*;

fn is_prime(int: u32) -> bool {
    let upper = (int as f64).sqrt().floor() as u32;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}
fn main() {
    let size = 20000;

    let start = Instant::now();

    let integers: Vec<usize> = (2..size).collect();
    let sum = AtomicU32::new(0);
    // The closure only needs a shared reference to `sum`, so par_iter()
    // is now accepted by the compiler
    integers.par_iter().for_each(|i| {
        if is_prime(*i as u32) {
            // fetch_add loads the current sum, adds the value to it, then writes
            // back to sum -- all in a single uninterruptible operation
            sum.fetch_add(*i as u32, Ordering::SeqCst);
        }
    });

    // Timing
    let end = Instant::now();
    let atomic_time = end.duration_since(start).as_micros();
    println!("Atomic done in {}us", atomic_time);

    // We need to use load to read from the atomic
    println!("{}", sum.load(Ordering::SeqCst));
}

If multiple threads try to update the atomic at the same time, efficiency drops because they have to wait for each other. Atomics typically have low overhead, though, when threads aren't contending for them.
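The following sketch shows why the load, add and store must happen as one unit. It recreates the interleaving from the race-condition diagram using only the standard library, and the function names are hypothetical. Splitting the update into a separate load() and store() is still memory-safe because each individual access is atomic. The read-modify-write as a whole is not atomic, though, so updates can be lost. How many go missing depends on the machine and the run. fetch_add always gives the full count.

```rust
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread;

/// Each of `n_threads` threads adds 1 to a shared counter `n_adds` times,
/// but reads and writes in two separate steps, like the race diagram.
fn racy_count(n_threads: u64, n_adds: u64) -> u64 {
    let sum = AtomicU64::new(0);
    thread::scope(|s| {
        for _ in 0..n_threads {
            s.spawn(|| {
                for _ in 0..n_adds {
                    let current = sum.load(Ordering::SeqCst); // read sum
                    // another thread may store a new value here...
                    sum.store(current + 1, Ordering::SeqCst); // ...which we then overwrite
                }
            });
        }
    });
    sum.load(Ordering::SeqCst)
}

/// The same work using fetch_add, which reads, adds and writes as one unit.
fn atomic_count(n_threads: u64, n_adds: u64) -> u64 {
    let sum = AtomicU64::new(0);
    thread::scope(|s| {
        for _ in 0..n_threads {
            s.spawn(|| {
                for _ in 0..n_adds {
                    sum.fetch_add(1, Ordering::SeqCst);
                }
            });
        }
    });
    sum.load(Ordering::SeqCst)
}

fn main() {
    let (threads, adds) = (4, 100_000);
    println!("expected:     {}", threads * adds);
    // Can be less than expected, and varies between runs
    println!("load + store: {}", racy_count(threads, adds));
    // Always equal to expected
    println!("fetch_add:    {}", atomic_count(threads, adds));
}
```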

Note

Sums are commutative, i.e. A + B = B + A, so we can add the results in any order. Sums are also associative, i.e. (A + B) + C = A + (B + C), so however we group the elements into partial sums, adding those partial sums together gives the correct total. (With floating-point numbers, addition is only approximately associative because of rounding, so a parallel sum may differ very slightly from a serial one.)

This means a sum can be computed as a reduction. We could either have each thread return its own non-atomic partial sum and then add these together, or simply use rayon's built-in .sum() method.
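Here is a sketch of the first option. It assumes the standard library's scoped threads rather than rayon. Each thread sums the primes in its own chunk of the range and returns that partial sum from its closure, and the main thread then adds up the partial sums. No atomics or locks are needed because nothing is shared while the threads run. parallel_prime_sum is a hypothetical name, and is_prime here also returns false for 0 and 1.

```rust
use std::thread;

fn is_prime(int: u32) -> bool {
    if int < 2 {
        return false;
    }
    let upper = (int as f64).sqrt().floor() as u32;
    (2..=upper).all(|div| int % div != 0)
}

/// Sums the primes below `size` by giving each of `n_threads` threads
/// (at least 1) its own chunk and its own local sum.
fn parallel_prime_sum(size: u32, n_threads: u32) -> u64 {
    // Round up so the chunks cover all of 0..size
    let chunk = (size + n_threads - 1) / n_threads;
    thread::scope(|s| {
        let handles: Vec<_> = (0..n_threads)
            .map(|t| {
                let start = t * chunk;
                let end = ((t + 1) * chunk).min(size);
                // Each thread returns its own partial sum: no shared state
                s.spawn(move || {
                    (start..end)
                        .filter(|&i| is_prime(i))
                        .map(|i| i as u64)
                        .sum::<u64>()
                })
            })
            .collect();
        // The reduction step: add the partial sums together
        handles.into_iter().map(|h| h.join().unwrap()).sum()
    })
}

fn main() {
    let size = 20000;
    let serial: u64 = (2..size).filter(|&i| is_prime(i)).map(|i| i as u64).sum();
    let parallel = parallel_prime_sum(size, 4);
    assert_eq!(serial, parallel);
    println!("{}", parallel);
}
```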

Managing shared resources: mutex

What if we want a shared Vec which keeps track of which numbers are prime? There are no atomics for more complex types, only for basic types (integers, booleans and pointers). Instead we can use a mutual exclusion lock ('mutex') to protect shared data when we write to it, which works with any type:

use std::sync::{Arc, Mutex};
use std::thread;

fn is_prime(int: u32) -> bool {
    let upper = (int as f64).sqrt().floor() as u32;
    for div in 2..=upper {
        if int % div == 0 {
            return false;
        }
    }
    true
}
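fn main() {
    let size = 30;
    let primes = Arc::new(Mutex::new(Vec::new()));
    let mut handles = Vec::new();
    for i in 2..size {
        // Each thread gets its own reference-counted handle to the same mutex
        let list = Arc::clone(&primes);
        handles.push(thread::spawn(move || {
            if is_prime(i as u32) {
                let mut list_mutex = list.lock().unwrap();
                list_mutex.push(i);
            }
        }));
    }
    // Wait for every thread to finish before reading the result,
    // otherwise some primes may not have been pushed yet
    for handle in handles {
        handle.join().unwrap();
    }
    println!("{:?}", primes.lock().unwrap());
}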

There are a few things going on here:

  • We wrap the Vec in a Mutex.
  • Before we can use the mutex we call .lock() on it. This makes the thread wait until it can acquire the lock, which only a single thread can hold at a time.
  • .lock() returns a guard (here list_mutex), and calling .unwrap() on the result gives us mutable access to the contained Vec. The unwrap() only fails if another thread panicked while holding the lock.
  • The mutex is automatically unlocked when the guard goes out of scope at the end of the if block.
  • We also need Arc (an atomically reference-counted pointer) so that every thread can share ownership of the same mutex in a thread-safe way. Without it we get a compile error: the move closure would move the mutex into the first thread, leaving nothing for the other threads, and a plain reference won't do either because thread::spawn requires the closure to own everything it uses.
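The locking rules in these bullet points can be checked in a single thread with Mutex::try_lock(), which returns an error instead of waiting when the lock is already held. In this minimal sketch the function name is hypothetical. try_lock() fails while the guard is alive and succeeds once the guard has gone out of scope.

```rust
use std::sync::Mutex;

/// Returns whether `try_lock` succeeded (a) while a guard is alive and
/// (b) after that guard has gone out of scope.
fn lock_while_held_and_after() -> (bool, bool) {
    let primes = Mutex::new(vec![2, 3, 5]);
    let while_held;
    {
        let mut guard = primes.lock().unwrap();
        guard.push(7); // the guard gives mutable access to the Vec
        // A second attempt to lock fails while `guard` is alive
        while_held = primes.try_lock().is_ok();
    } // `guard` is dropped here, which unlocks the mutex
    let after = primes.try_lock().is_ok();
    (while_held, after)
}

fn main() {
    let (while_held, after) = lock_while_held_and_after();
    // prints "try_lock while held: false, after scope: true"
    println!("try_lock while held: {}, after scope: {}", while_held, after);
}
```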

The results are correct, but might be out of order:

[2, 3, 5, 7, 13, 11, 17, 19, 23]

We are also likely to lose some parallel efficiency, as locking and unlocking the mutex takes time. More importantly, threads sit idle (blocked) while waiting for another thread to release the lock.
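One way to reduce both costs is sketched below. It assumes the order of results doesn't matter until the end. A few threads each collect their primes in a local Vec and lock the shared Vec only once to append them. Sorting at the end also fixes the out-of-order output. The name primes_below and the strided split of the numbers between threads are illustrative choices, not the only way to divide the work.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

fn is_prime(int: u32) -> bool {
    if int < 2 {
        return false;
    }
    let upper = (int as f64).sqrt().floor() as u32;
    (2..=upper).all(|div| int % div != 0)
}

/// Finds the primes below `size` using `n_threads` threads (at least 1).
fn primes_below(size: u32, n_threads: u32) -> Vec<u32> {
    let primes = Arc::new(Mutex::new(Vec::<u32>::new()));
    let mut handles = Vec::new();
    for t in 0..n_threads {
        let list = Arc::clone(&primes);
        handles.push(thread::spawn(move || {
            // Thread t checks t, t + n_threads, t + 2 * n_threads, ...
            let local: Vec<u32> = (t..size)
                .step_by(n_threads as usize)
                .filter(|&i| is_prime(i))
                .collect();
            // One short critical section per thread instead of one per prime
            list.lock().unwrap().extend(local);
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
    // All threads have finished, so copy the Vec out and sort it
    let mut result = primes.lock().unwrap().clone();
    result.sort_unstable();
    result
}

fn main() {
    // prints [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    println!("{:?}", primes_below(30, 4));
}
```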

Note

A typical use of a mutex is to stop multiple threads from writing output to the terminal over the top of each other, but println!() already guards against this internally by locking standard output for each call.

Another use for a mutex might be where each thread writes to the same file. In this case you may want to use RwLock, which allows any number of threads to read at the same time, but only one thread at a time to write.
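Here is a minimal sketch of RwLock, where the shared list of primes and the function name are hypothetical. Several reader threads hold read guards at the same time, while the writer's write() call waits until it has exclusive access. Whether each reader sees the list before or after the write depends on scheduling, so each reader may report either 4 or 5 entries.

```rust
use std::sync::RwLock;
use std::thread;

/// Runs `n_readers` reader threads and one writer thread against a shared
/// list; returns the lengths the readers saw and the final list.
fn readers_and_writer(n_readers: usize) -> (Vec<usize>, Vec<u32>) {
    let table = RwLock::new(vec![2u32, 3, 5, 7]);
    let table_ref = &table;
    let lens_seen = thread::scope(|s| {
        let readers: Vec<_> = (0..n_readers)
            // Any number of threads can hold a read guard at once
            .map(|_| s.spawn(move || table_ref.read().unwrap().len()))
            .collect();
        // A write guard is exclusive: it waits until no read guards are held
        s.spawn(move || table_ref.write().unwrap().push(11));
        readers.into_iter().map(|h| h.join().unwrap()).collect::<Vec<usize>>()
    });
    // The scope has joined every thread, so the write has happened
    (lens_seen, table.into_inner().unwrap())
}

fn main() {
    let (lens, table) = readers_and_writer(3);
    println!("readers saw lengths {:?}; final table {:?}", lens, table);
}
```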