Rust Concurrency

Table of Contents

1. Rust 线程

Rust 中线程和操作系统中的线程是一一对应的。在标准库中没有实现 M:N 线程(协程)模型,如果你需要这样的模型可以使用第三方库,如 tokio

1.1. 创建线程

使用 thread::spawn 可以创建线程,如果操作系统创建线程失败,thread::spawn 会 panic,如果你想捕获这个错误,可以使用 std::thread::Builder,参考节 1.4

下面是使用 thread::spawn 创建线程的例子:

use std::thread;
use std::time::Duration;

fn main() {
    thread::spawn(|| {                  // 这里创建了一个子线程,spawn 参数是个 closure
        for i in 1..10 {                // 这是子线程执行的代码
            println!("hi number {} from the spawned thread!", i);
            thread::sleep(Duration::from_millis(3));
        }
    });

    for i in 1..5 {                     // 这里主线程执行的代码
        println!("hi number {} from the main thread!", i);
        thread::sleep(Duration::from_millis(3));
    }
}

下面是一个可能的输出(注:线程的调度由操作系统负责,多次运行得到的结果很可能不一样):

hi number 1 from the main thread!
hi number 1 from the spawned thread!
hi number 2 from the main thread!
hi number 2 from the spawned thread!
hi number 3 from the main thread!
hi number 3 from the spawned thread!
hi number 4 from the main thread!
hi number 4 from the spawned thread!

在这个例子中,由于主线程先执行完代码退出,从而导致子线程还没执行完代码也退出了。如何让主线程等待子线程呢?请看下节。

1.2. join 等待

thread::spawn 会返回一个 JoinHandle 对象,在 JoinHandle 对象上调用 join 可以等待关联的线程执行完,如:

use std::thread;
use std::time::Duration;

fn main() {
    let handle = thread::spawn(|| {
        for i in 1..10 {
            println!("hi number {} from the spawned thread!", i);
            thread::sleep(Duration::from_millis(3));
        }
    });

    for i in 1..5 {
        println!("hi number {} from the main thread!", i);
        thread::sleep(Duration::from_millis(3));
    }

    handle.join().unwrap();                       // join 表示等待 handle 关联线程执行完
}

下面是一个可能的输出(多次运行可能输出顺序不一样,但输出内容不会少):

hi number 1 from the main thread!
hi number 1 from the spawned thread!
hi number 2 from the main thread!
hi number 2 from the spawned thread!
hi number 3 from the main thread!
hi number 3 from the spawned thread!
hi number 4 from the main thread!
hi number 4 from the spawned thread!
hi number 5 from the spawned thread!
hi number 6 from the spawned thread!
hi number 7 from the spawned thread!
hi number 8 from the spawned thread!
hi number 9 from the spawned thread!

1.3. move 使用主线程数据

如果子线程要使用主线程定义中变量,需要在 thread::spawn 创建线程时指定 move 关键字,如:

use std::thread;

fn main() {
    let i: i32 = 100;
    let v: Vec<i32> = vec![1, 2, 3];

    let handle = thread::spawn(move || {                 // 指定 move,子线程会“复制变量”或者“拿走变量 ownership”
        println!("Here's a int: {:?}", i);               // i32 类型实现了 Copy Trait,i 会直接复制进来
        println!("Here's a vector: {:?}", v);            // Vec<i32> 没有实现 Copy Trait,其 ownership 会转移到子线程
    });

    println!("Main thread, here's a int: {:?}", i);      // 这里还可以使用 i,因为类型 i32 实现了 Copy Trait,它会被复制
    //println!("Main thread, here's a vector: {:?}", v); // 主线程不能使用 v 了,因为类型 Vec<i32> 没有实现 Copy Trait,
                                                         // 从而,ownership 会转移进子线程,主线程中 v 不再可用

    handle.join().unwrap();
}

1.4. 定制线程

如果你想定制线程,比如修改线程名称或者修改线程栈的大小,可以使用 std::thread::Builder。下面是它的使用例子:

use std::thread;
use std::time::Duration;
use std::thread::current;

fn main() {
    let builder = thread::Builder::new()
        .name("foo".into())           // 定制线程名称为 foo
        .stack_size(32 * 1024);       // 设置线程栈大小

    let handle = builder.spawn(|| {
        for i in 1..10 {
            println!("hi number {} from the spawned thread {}!", i, current().name().unwrap());
            thread::sleep(Duration::from_millis(3));
        }
    }).unwrap();

    handle.join().unwrap();
}

上面代码会输出:

hi number 1 from the spawned thread foo!
hi number 2 from the spawned thread foo!
hi number 3 from the spawned thread foo!
hi number 4 from the spawned thread foo!
hi number 5 from the spawned thread foo!
hi number 6 from the spawned thread foo!
hi number 7 from the spawned thread foo!
hi number 8 from the spawned thread foo!
hi number 9 from the spawned thread foo!

2. 同步原语

2.1. Mutex

互斥器(Mutex)保证任意时刻,只允许一个线程访问某些数据。它的使用步骤是:

  1. 使用数据之前需要“获取锁”。
  2. 使用完数据之后需要“释放锁”,这样其他线程才能够获取锁。

下面是互斥器的使用例子:

use std::sync::{Mutex, MutexGuard};

fn main() {                       // 这是个单线程例子,Mutex没啥意义,仅是为了介绍 Mutex API 使用
    let m = Mutex::new(5);

    {
        let mut num: MutexGuard<i32> = m.lock().unwrap();    // 使用数据前需要获取锁
        *num = 6;
    }                             // num 在这里会离开作用域,自动释放锁

    println!("m = {:?}", m);      // 输出 m = Mutex { data: 6 }
}

Mutex 对象上调用 lock 会返回名为 MutexGuard 的智能指针。这个智能指针实现了 Deref 来指向其内部数据;其也提供了一个 Drop 实现当 MutexGuard 离开作用域时自动释放锁。 这样,我们不需要手动释放锁了,也不用担心忘记释放锁了。

2.1.1. 多线程下使用 Mutex, Arc

介绍完了 Mutex 的基本 API,下面介绍一下 Mutex 在多线程下的使用例子:

use std::sync::Mutex;
use std::thread;

fn main() {                                           // 这是个错误的例子,后面会修正
    let counter = Mutex::new(0);
    let mut handles = vec![];

    for _ in 0..10 {                                  // 启动十个线程,并在各个线程中对同一个计数器值加一
        let handle = thread::spawn(move || {
            let mut num = counter.lock().unwrap();    // 这里会编译出错!因为 counter 从主线程 move 到“首个”子线程了,
                                                      // 后续的子线程无法再使用 counter 了
            *num += 1;
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("Result: {}", *counter.lock().unwrap());  // 这里会编译出错!因为 counter 已经被 move 了“首个”子线程
}

上面代码,会编译出错。 我们不能将 counter 的所有权移动到多个线程中。显然,counter 需要“多个所有者”,这可以通过“引用计数智能指针”来实现。 请看下面代码:

use std::sync::{Mutex, Arc};
use std::thread;

fn main() {
    let counter = Arc::new(Mutex::new(0));
    let mut handles = vec![];

    for _ in 0..10 {                                      // 启动十个线程,并在各个线程中对同一个计数器值加一
        let counter_new = Arc::clone(&counter);           // counter 引用计算会增加 1
        let handle = thread::spawn(move || {              // 每个子线程,都是从主线程中 move 一个新的 counter_new
            let mut num = counter_new.lock().unwrap();

            *num += 1;
        });                             // 子线程中 num 超过作用域后,会释放 Mutex 锁,同时 counter 引用计算会减 1
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("Result: {}", *counter.lock().unwrap());     // Result: 10
}

上面代码中使用了智能指针 Arc<T> 来让 counter 具备多个所有者。Arc 是“Atomically Reference Counted”的缩写。

需要注意的是,上面场景中我们不能使用智能指针 Rc<T>,因为这个智能指针的实现中没有使用锁保护引用计数,多个线程都可能同时更新引用计数,这可能导致问题,Rc<T> 被设计为仅适用于单线程的场景,它的性能比 Arc<T> 要好。

如果我们把上面代码中的 Arc 修改为 Rc,会发现无法通过编译,编译器会报错:

`Rc<Mutex<i32>>` cannot be sent between threads safely
the trait `Send` is not implemented for `Rc<Mutex<i32>>`

可见,Rust 考虑很周道,不给你犯错的机会。上面错误中提到 Rc<Mutex<i32>> 没有实现 Send trait,关于 Send trait,节 2.2.1 中会介绍。

2.2. Send trait 和 Sync trait

本节介绍 std::marker 中的两个 trait:SendSync

2.2.1. Send(该类型的所有权可在线程间转移)

实现了 Send trait 的类型的所有权可以在线程间传递。

几乎所有的 Rust 类型都是 Send 的,不过有一些例外,如 Rc<T> 没有实现 Send trait,这一点在节 2.1.1 中提到过。之所以不让 Rc<T> 在线程中转移所有权的原因是:为了性能的考虑 Rc<T> 并没有使用锁保护引用计算,这样多个线程同时更新引用计数时可能出错。

Rust 把“所有权可以在线程间传递”这个事情抽象为 Send trait,这样 Rust 类型系统和 trait bound 就可以确保永远也不会意外的将不安全的对象(如 Rc<T>)在线程间发送。

2.2.2. Sync(该类型的引用可在多个线程之间被共享)

实现了 Sync trait 的类型可以安全地在多个线程中拥有其值的引用。 换一种方式来说,对于任意类型 T ,如果 &T 是 Send 的话,那么 T 就是 Sync 的,这意味着其 引用就可以安全的发送到另一个线程。

2.2.3. Send vs. Sync

前面的解释还是有点难理解了,下面换用更直白一点的方式来解释这两类约束:

Send:
满足 Send 约束的类型,能在多线程之间安全的排它使用(Exclusive access is thread-safe)。
满足 Send 约束的类型 T,表示 T 和 &mut T(mut 表示能修改这个引用,甚至于删除即 drop 这个数据)这两种类型的数据能在多个线程之间传递,说得直白些: 能在多个线程之间 move 值以及修改引用到的值。

Sync:
满足 Sync 约束的类型,能在多线程之间安全的共享使用(Shared access is thread-safe)。
满足 Sync 约束的类型 T,只表示 该类型能在多个线程中读共享,即:不能 move,也不能修改,仅仅只能通过引用 &T 来读取这个值。

摘自:https://www.codedump.info/post/20220619-weekly-19

2.3. RwLock

下面是 Rust 中读写锁 RwLock 的使用例子:

// From https://riptutorial.com/rust/example/24527/read-write-locks
use std::time::Duration;
use std::thread;
use std::thread::sleep;
use std::sync::{Arc, RwLock };

fn main() {
    // Create an u32 with an inital value of 0
    let initial_value = 0u32;

    // Move the initial value into the read-write lock which is wrapped into an atomic reference
    // counter in order to allow safe sharing.
    let rw_lock = Arc::new(RwLock::new(initial_value));

    let producer_lock = rw_lock.clone();
    let producer_thread = thread::spawn(move || {
        loop {
            if let Ok(mut write_guard) = producer_lock.write() { // 用 write 获得写锁,没拿到锁就 block 直接拿到
                // the returned write_guard implements `Deref` giving us easy access to the target value
                *write_guard += 1;

                sleep(Duration::from_millis(600));
                println!("Updated value: {}", *write_guard);
            }  // write_guard 超过作用域会自动释放锁

            sleep(Duration::from_millis(1000));
        }
    });

    let consumer_id_lock = rw_lock.clone();
    let consumer_id_thread = thread::spawn(move || {
        loop {
            if let Ok(read_guard) = consumer_id_lock.read() { // 用 read 获得读锁,没拿到锁就 block 直接拿到
                // the returned read_guard also implements `Deref`
                println!("Read value: {}", *read_guard);
            }  // read_guard 超过作用域会自动释放锁

            sleep(Duration::from_millis(1000));
        }
    });

    let consumer_square_lock = rw_lock.clone();
    let consumer_square_thread = thread::spawn(move || {
        loop {
            if let Ok(read_guard) = consumer_square_lock.try_read() { // try_read 获得读锁,没拿到锁就返回错
                let value = *read_guard;
                println!("Read value squared: {}", value * value);
            } else { // read_guard 超过作用域会自动释放锁
                println!("The try_read() get read lock failed");
            }

            sleep(Duration::from_millis(1000));
        }
    });

    let _ = producer_thread.join();
    let _ = consumer_id_thread.join();
    let _ = consumer_square_thread.join();
}

上面程序中,创建了 3 个子线程,1 个线程中使用 write() 拿写锁,另外 2 个线程分别使用 read()/try_read() 拿读锁。

2.4. Once

如果你想要“一次性地全局初始化”,可以使用 Once,如:

use std::sync::Once;

static mut VAL: usize = 0;
static INIT: Once = Once::new();

// Accessing a `static mut` is unsafe much of the time, but if we do so
// in a synchronized fashion (e.g., write once or read all) then we're
// good to go!
//
// This function will only call `expensive_computation` once, and will
// otherwise always return the value returned from the first invocation.
fn get_cached_val() -> usize {
    unsafe {
        INIT.call_once(|| {
            VAL = expensive_computation();
        });
        VAL
    }
}

fn expensive_computation() -> usize {
    // ...
    println!("expensive_computation is called");
    100
}

fn main() {
    println!("{}", get_cached_val());
    println!("{}", get_cached_val());
}

2.5. Barrier

如果你想要“所有线程运行到同一点后,才往下执行”,可以使用 Barrier,如:

use std::sync::{Arc, Barrier};
use std::thread;

fn main() {
    let nthreads = 10;
    let mut handles = Vec::with_capacity(nthreads);
    let barrier = Arc::new(Barrier::new(nthreads));

    for _ in 0..nthreads {
        let b = Arc::clone(&barrier);
        // The same messages will be printed together.
        // You will NOT see any interleaving.
        handles.push(thread::spawn(move || {
            println!("before wait");
            b.wait();                       // barrier 上调用 wait 表示所有线程都运行到这个点后,才往下运行
            println!("after wait");
        }));
    }

    // Wait for other threads to finish.
    for handle in handles {
        handle.join().unwrap();
    }
}

2.6. Condvar

Rust 中也实现了条件变量,参考 Condvar

2.7. mpsc(通道中消息传递来实现在线程间传送数据)

Rust 中也实现了类似 Golang 中用通道(channel)传递消息来实现同步的机制。

我们可以把通道(channel)想象为“河流”。如果你将诸如橡皮鸭或小船之类的东西放入其中,它们会顺流而下到达下游。编程中的通道有两部分组成,一个发送者(transmitter)和一个接收者(receiver)。 发送者位于上游位置 ,在这里可以将橡皮鸭放入河中, 接收者则位于下游 ,橡皮鸭最终会漂流至此。代码中的一部分调用发送者的方法以及希望发送的数据,另一部分则检查接收端收到的消息。

Rust 中,std::sync::mpsc::channel 函数用于创建一个新的通道。mpsc 是 Multiple Producer, Single Consumer 的缩写,表达的意思是 一个通道可以有多个发送者,但只能有一个消费这些值的接收者。

下面是 Rust 中通道的使用例子:

use std::thread;
use std::sync::mpsc;

fn main() {
    let (tx, rx) = mpsc::channel();           // 创建一个通道,tx 为发送者,rx 为接收者

    thread::spawn(move || {
        let val = String::from("hi");
        tx.send(val).unwrap();                // 调用 send 方法往通道中发送数据
    });

    let received = rx.recv().unwrap();        // 调用 recv 方法从通道中读数据
    println!("Got: {}", received);            // Got: hi
}

接收者调用 recv 方法会阻塞当前线程执行直到从通道中接收一个值;如果希望通道中没有值时不阻塞,则可以使用 try_recv ,它不会阻塞,总是立刻返回,没有值就返回错误。

从通道中读取数据,不一定要显式调用 recv 函数,使用迭代器也行,当通道被关闭时,迭代器也将结束。如:

    for received in rx {                      // 不用显式调用 recv 函数,而是将 rx 当作一个迭代器
        println!("Got: {}", received);
    }

2.7.1. 通道中的所有权转移

一旦把某个值发送到通道中,这个值对应类型如果实现了 Copy trait,那么该值会被复制到通道中;如果值对应类型没有实现 Copy Trait,那么该值会转移所有权到通道中。 如:

use std::thread;
use std::sync::mpsc;

fn main() {
    let (tx, rx) = mpsc::channel();

    thread::spawn(move || {
        let val = String::from("hi");
        tx.send(val).unwrap();          // String 没有实现 Copy trait,val 的所有权转移到通道中
        print!("{}", val);              // 这里会报错!不能再使用 val 了
    });

    let received = rx.recv().unwrap();
    println!("Got: {}", received);
}

2.7.2. 多个生产者

前面介绍过 mpsc 是 Multiple Producer, Single Consumer 的缩写,下面介绍一下如何实现多个生产者。

通过调用 mpsc::Sender::clone 克隆通道的发送者可以实现多个生产者。 如:

use std::thread;
use std::sync::mpsc;
use std::time::Duration;

fn main() {
    let (tx, rx) = mpsc::channel();

    let tx1 = mpsc::Sender::clone(&tx);     // 调用 clone 克隆通道的发送者
    thread::spawn(move || {                 // 这个线程中将使用发送者 tx1
        let vals = vec![
            String::from("hello from tx1"),
            String::from("world from tx1"),
        ];

        for val in vals {
            tx1.send(val).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });

    thread::spawn(move || {                 // 这个线程中将使用发送者 tx
        let vals = vec![
            String::from("hello from tx"),
            String::from("world from tx"),
        ];

        for val in vals {
            tx.send(val).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });

    for received in rx {
        println!("Got: {}", received);
    }
}

下面是上面程序的一个可能输出(多次运行上面程序可能得到不一样的输出):

Got: hello from tx1
Got: hello from tx
Got: world from tx
Got: world from tx1

3. 实现 Web Server

3.1. 单线程实现 Web Server

下面是一个简单的 Web Server 实现:

use std::net::{TcpListener, TcpStream};
use std::io::{Read, Write};

fn handle_connection(mut stream: TcpStream) {
    let mut buffer = [0; 512];
    stream.read(&mut buffer).unwrap();             // 从流中读取数据
    println!("Get request: [{}]", String::from_utf8_lossy(&buffer[..])); // 打印出请求数据

    let response = "HTTP/1.1 200 OK\r\n\r\nThis is a Rust Web Server";
    stream.write(response.as_bytes()).unwrap();    // 往流中写入数据
    stream.flush().unwrap();
}

fn main() {
    let listener = TcpListener::bind("127.0.0.1:7878").unwrap();

    for stream in listener.incoming() {
        let stream: TcpStream = stream.unwrap();

        handle_connection(stream);
    }
}

上面代码中,调用 TcpListener::bind 可以监听在某个端口。TcpListener 的 incoming 方法返回一个“迭代器”,从这个迭代器中可以得到一系列的流,每个流代表一个客户端和服务端之间打开的连接。

启动上面程序(服务器)后,我们使用 curl 作为客户端测试一下:

$ curl 127.0.0.1:7878
This is a Rust Web Server

而服务器端会显示:

Get request: [GET / HTTP/1.1
Host: 127.0.0.1:7878
User-Agent: curl/7.56.0
Accept: */*

]

3.2. 线程池实现 Web Server

下面将介绍如何使用线程池实现 Web Server,主要思想为:

  1. ThreadPool 会初始化指定数量的工作线程,并创建一个通道,作为暂存任务的队列;
  2. 新建一个 Job 结构体来存放用于向通道中发送的闭包(任务)。
  3. ThreadPool 的 execute 方法的作用就是往通道中写入任务;
  4. 每个 Worker(对应一个线程)遍历通道的接收端并执行任何接收到的任务;

线程池版本的 Web Server 的完整代码如下:

use std::net::{TcpListener, TcpStream};
use std::io::{Read, Write};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;

pub struct ThreadPool {
    workers: Vec<Worker>,            // 每个 worker 对应一个线程
    sender: mpsc::Sender<Message>,   // 通道的发送端,提交任务就是往通道的发送端写入数据
}

type Job = Box<dyn FnOnce() + Send + 'static>;

enum Message {
    NewJob(Job),  // 表示新任务
    Terminate,    // 表示没有任务了,终止执行。暂时没用
}

impl ThreadPool {
    /// Create a new ThreadPool.
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        // 创建一个通道
        // 通道“发送端”给 execute,通道“接收端”给每个 worker(一个 worker 对应一个线程)
        // 通道“接收端”被多个线程访问,所以需要使用使用 Mutex 和 Arc
        // Arc 使得多个 worker 拥有接收端,而 Mutex 可确保一次只有一个 worker 能从接收端得到任务。
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);
        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }
        ThreadPool { workers, sender }
    }

    /// 往通道中写入任务
    pub fn execute<F>(&self, f: F)
        where
            F: FnOnce() + Send + 'static,         // 参考 thread::spawn 的签名
    {
        let job = Box::new(f);
        self.sender.send(Message::NewJob(job)).unwrap();
    }
}

struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    // 创建 worker,一个 worker 对应一个线程,线程在死循环中不停查看通道中是否有任务,有就执行
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
        let thread = thread::spawn(move ||
            loop {
                let message = receiver.lock().unwrap().recv().unwrap(); // lock: 访问通道前先拿锁
                                                                        // recv: 通道中有数据就接收,没数据就 block

                match message {
                    Message::NewJob(job) => {
                        println!("Worker {} got a job; executing.", id);

                        job();
                    }
                    Message::Terminate => {
                        println!("Worker {} was told to terminate.", id);

                        break;
                    }
                }
            });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}

fn handle_connection(mut stream: TcpStream) {
    let mut buffer = [0; 512];
    stream.read(&mut buffer).unwrap();             // 从流中读取数据
    println!("Get request: [{}]", String::from_utf8_lossy(&buffer[..]));

    let response = "HTTP/1.1 200 OK\r\n\r\nThis is a Rust Web Server";
    stream.write(response.as_bytes()).unwrap();    // 往流中写入数据
    stream.flush().unwrap();
}

fn main() {
    let listener = TcpListener::bind("127.0.0.1:7878").unwrap();
    let pool = ThreadPool::new(4);

    for stream in listener.incoming() {
        let stream = stream.unwrap();

        pool.execute(|| {               // 提交闭包到线程池
            handle_connection(stream);
        });
    }
}

上面代码中, execute 的参数的约束为 FnOnce() + Send + 'static ,和 thread::spawn 的签名是一样的。下面依次介绍一下这 3 个约束:

  1. 因为处理请求的线程只会执行闭包一次,所以 FnOnce 是我们需要的 trait,这里符合 FnOnce 中 Once 的意思;
  2. 由于闭包从一个线程转移到另一个线程,所以需要实现 Send trait;
  3. 而指定 'static 是因为并不知道线程会执行多久。

4. Future, async/await

Future 是并发编程中经常使用的一种结构,它代表着“某个未来计算的结果”,它有两种状态:一是计算还没有完成(Pending),二是计算已经完成(Ready)。状态只能从 Pending 状态转换为 Ready 状态;一旦处于 Ready 状态,就不再变化了。

Rust 中 trait Future 的定义如下:

pub trait Future {
    type Output;                                                            // 计算完成后,结果值的类型

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>;  // 调用 poll 可以让它从 Pending 状态转换为 Ready 状态,当然可能需要调用多次
}

调用 poll 可以尝试使 Future 从 Pending 转换为 Ready,它返回 Poll 枚举类型,它的定义如下:

pub enum Poll<T> {
    Ready(T),
    Pending,
}

调用 poll 方法比较繁琐,因为准备 poll 的参数比较麻烦,下面是 poll 的一个例子:

use std::future::Future;
use std::pin::Pin;
use std::task::{Poll, Context, Waker, RawWaker, RawWakerVTable};

unsafe fn vt_clone(data: *const ()) -> RawWaker {
    RawWaker::new(data, &VTABLE)
}

unsafe fn vt_wake(_data: *const ()) {
}

unsafe fn vt_wake_by_ref(_data: *const ()) {
}

unsafe fn vt_drop(_data: *const ()) {
}

static VTABLE: RawWakerVTable = RawWakerVTable::new(
    vt_clone,
    vt_wake,
    vt_wake_by_ref,
    vt_drop
);

struct IntFuture {
    x: i32,
}

impl Future for IntFuture {
    type Output = i32;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<i32> {
        Poll::Ready(self.x + 100)
    }
}

fn main() {
    let mut f = IntFuture{ x: 30 };

    let rw = RawWaker::new(&(), &VTABLE);
    let w = unsafe { Waker::from_raw(rw) };
    let mut cx = Context::from_waker(&w);

    let p = unsafe { Pin::new_unchecked(&mut f) };

    assert_eq!(p.poll(&mut cx), Poll::Ready(130));
}

上面的代码很是繁琐。不过, 我们一般不用自己显示地调用 poll 函数,而是通过 block_on/.await 等待 Future 到达 Ready 状态,并获得结果。

4.1. block_on(阻塞线程直到 Future 到达 Ready 状态)

使用 block_on 可以阻塞当前线程直到 Future 到达 Ready 状态, 下面是它的一个例子:

use std::future::Future;
use std::pin::Pin;
use std::task::{Poll, Context};
use futures::executor::block_on;

struct IntFuture {
    x: i32,
}

impl Future for IntFuture {
    type Output = i32;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<i32> {
        Poll::Ready(self.x + 100)
    }
}

fn main() {
    let f = IntFuture{ x: 30 };

    let result = block_on(f);             // 阻塞当前线程直到 Future f 到达 Ready 状态

    assert_eq!(result, 130);
}

4.2. async(返回 Future)

函数使用 async 修饰可以得到一个 Future。比如,前面介绍的代码:

struct IntFuture {
    x: i32,
}

impl Future for IntFuture {
    type Output = i32;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<i32> {
        Poll::Ready(self.x + 100)
    }
}

let f = IntFuture{ x: 30 };

可以使用 async 简化为:

async fn int_future(x: i32) -> i32 {
    x + 100
}

let f = int_future(30);

4.1 的例子可以简化为:

use std::future::Future;
use futures::executor::block_on;

async fn int_future(x: i32) -> i32 {
    x + 100
}

fn main() {
    let f = int_future(30);

    let result = block_on(f);

    assert_eq!(result, 130);
}

4.3. await(只能在 async 中使用)

4.1 中介绍过,使用 block_on 可以阻塞当前线程直到 Future 处于 Ready 状态。我们再看一个 block_on 的使用例子:

use std::future::Future;
use futures::executor::block_on;

struct Song {}

async fn learn_song() -> Song { Song{} }
async fn sing_song(song: Song) { /* ... */ }
async fn dance() { /* ... */ }

fn main() {
    let song = block_on(learn_song());
    block_on(sing_song(song));
    block_on(dance());
}

上面代码有个缺陷:使用 block_on 会导致 learn_song/sing_song/dance 这 3 个操作只能依次执行。但是,从现实上说,learn_song/sing_song 需要按序执行;而 dance 应该可以和 learn_song/sing_song 并行执行。如何能实现这个要求呢?

使用 .awaitfutures::join! 可以完成上面的任务:

use std::future::Future;
use futures::executor::block_on;

struct Song {}

async fn learn_song() -> Song { Song{} }
async fn sing_song(song: Song) { /* ... */ }
async fn dance() { /* ... */ }

async fn learn_and_sing() {
    // Wait until the song has been learned before singing it.
    // We use `.await` here rather than `block_on` to prevent blocking the
    // thread, which makes it possible to `dance` at the same time.
    let song = learn_song().await;
    sing_song(song).await;
}

async fn async_main() {
    let f1 = learn_and_sing();
    let f2 = dance();

    // `join!` is like `.await` but can wait for multiple futures concurrently.
    // If we're temporarily blocked in the `learn_and_sing` future, the `dance`
    // future will take over the current thread. If `dance` becomes blocked,
    // `learn_and_sing` can take back over. If both futures are blocked, then
    // `async_main` is blocked and will yield to the executor.
    futures::join!(f1, f2);
}

fn main() {
    block_on(async_main());
}

4.4. Tips

4.4.1. async 函数中调用阻塞函数(需要放到 spawn_blocking 中)

Rust 生态中,有异步代码和同步代码这两种风格,一些库同时提供异步 API 和同步 API,也有很多库只提供同步 API。

在 async 函数中调用“长时间运行的阻塞函数”是异步编程的一个禁忌,因为这会导致当前执行器(工作线程)挂起,不能执行其它异步任务了,从而影响并发性能。

如果在某些场景下,async 函数不可避免地需要调用“长时间运行的阻塞函数”,那应该把这个阻塞函数移动到不同的线程池中,这样执行器就可以继续运行其它异步任务。tokio 运行时提供了 spawn_blocking 函数来可以把阻塞函数放到独立的线程池中运行。如:

use std::thread;
use std::time::Duration;
use tokio::task::spawn_blocking;

fn long_running_task() -> u32 {
    // 下面是模拟长时间运行的任务
    thread::sleep(Duration::from_secs(5));
    5
}

async fn my_task() {                // my_task 是个异步函数,不要直接调用长时间运行的阻塞函数 long_running_task
    let res = spawn_blocking(|| {   // 需要把 long_running_task 移动到不同的线程池中运行
        long_running_task()
    }).await.unwrap();

    println!("The answer was: {}", res);
}

#[tokio::main]
async fn main() {
    my_task().await;
}

参考:https://thomask.sdf.org/blog/2021/03/08/bridging-sync-async-code-in-rust.html

5. Pinning

5.1. 自引用结构

当一个结构体的字段是指针,且指向自己的另一个字段,这样的结构是“自引用结构”。“自引用结构”有个难题:对象 move 时,其指针指向的数据还是旧对象的字段,而旧对象可能失效了,这会出现内存安全问题。

下面例子说明了对象 move 时,结构体指针字段所指内容并没有变:

#[derive(Debug)]
struct Test {
    a: String,
    b: *const String,
}

impl Test {
    fn new(txt: &str) -> Self {
        Test {
            a: String::from(txt),
            b: std::ptr::null(),
        }
    }

    fn init(&mut self) {
        let self_ref: *const String = &self.a;
        self.b = self_ref;    // 这样,b 指向了 a,Test 是自引用结构
    }

    fn a(&self) -> &str {
        &self.a
    }

    fn b(&self) -> &String {
        assert!(!self.b.is_null(), "Test::b called without Test::init being called first");
        unsafe { &*(self.b) }
    }
}

fn main() {
    let mut test1 = Test::new("test1");
    test1.init();
    let mut test2 = Test::new("test2");
    test2.init();

    println!("a: {}, b: {}", test1.a(), test1.b());
    println!("a: {}, b: {}", test2.a(), test2.b());
    std::mem::swap(&mut test1, &mut test2);             // 交换 test1 和 test2
    println!("a: {}, b: {}", test1.a(), test1.b());
    println!("a: {}, b: {}", test2.a(), test2.b());
}

执行上面程序,会输出:

a: test1, b: test1
a: test2, b: test2
a: test2, b: test1      // 已经交换了 test1 和 test2,这时访问 test1.b(),竟然还输出 test1,这说明 swap 后,b 所指向的内容并没有变
a: test1, b: test2

参考:https://rust-lang.github.io/async-book/04_pinning/01_chapter.html

5.2. Pin 总结

Pin 是一个智能指针(它实现了 Deref 和 DerefMut),它包裹了另外一个指针 P(我们把 P 指针所指向的内容称为 T), 如果 P 没有实现 Unpin,则 Pin 保证 T 不会被移动。 从而保证了内存的安全性。Pin 可以用 Pin<P<T>> 表示(P 是 Pointer 的缩写,T 是 Type 的缩写)。

Pin(它的含义是“固定”)具有让 T 不能移动的能力,这个能力是否生效取决于 T 是否实现 Unpin。简单的说,如果 T 实现了 Unpin,则 Pin 的能力就失效了,这时候 Pin<P<T>> 就等价于 P<T> ;如果 T 没有实现 Unpin,则 Pin 就发挥作用了。

哪些类型实现了 Unpin 呢?Unpin 是一个 auto trait, 编译器默认会给绝大部分类型(如数字、字符串、 布尔值等,也包括完全由它们组成的结构体或枚举类型)实现 Unpin。

6. Stream

Stream 和 Future 有点像,不过 Stream 在完成状态(Ready)之前可以“多次产生值”。Stream 的定义如下:

trait Stream {
    /// The type of the value yielded by the stream.
    type Item;

    /// Attempt to resolve the next item in the stream.
    /// Returns `Poll::Pending` if not ready, `Poll::Ready(Some(x))` if a value
    /// is ready, and `Poll::Ready(None)` if the stream has completed.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>)
        -> Poll<Option<Self::Item>>;
}

7. 组合多个 Futures

7.1. join!

使用宏 futures::join 可以并行地等待两个 Future,即两个 Future 同时处于 Ready 状态才往下运行,它只能在 async 函数中使用,节 4.3 中介绍过它的用法。

7.2. select!

使用宏 futures::select 可以实现任意一个 Future 处于 Ready 状态就往下运行的效果,它只能在 async 函数中使用。下面是它的一个例子:

use futures::future;
use futures::select;
use futures::executor::block_on;

async fn async_main() {
    let mut a = future::ready(4);
    let mut b = future::pending::<()>();

    let res = select! {
    a_res = a => { println!("future a is ready, {:?}", a_res); a_res + 1 },
    b_res = b => { println!("future b is ready, {:?}", b_res); 0 },
    };

    assert_eq!(res, 5);
}

fn main() {
    block_on(async_main())
}

8. 参考

Author: cig01

Created: <2020-12-03 Thu>

Last updated: <2022-06-20 Mon>

Creator: Emacs 27.1 (Org mode 9.4)