Rust:编译时计算与const generics进阶

Rust的const generics和const fn让编译时计算成为一种强大的元编程手段。本文深入探讨如何利用这些特性实现固定大小数组抽象、编译时矩阵运算和类型级断言。

const fn:从基础到进阶

Rust的const fn允许函数在编译时被求值。从Rust 1.46开始,const fn内部支持了ifmatch、循环等控制流,使其实用性大幅提升。

const fn fibonacci(n: u32) -> u64 {
    match n {
        0 => 0,
        1 => 1,
        _ => {
            let mut a: u64 = 0;
            let mut b: u64 = 1;
            let mut i = 2;
            while i <= n {
                let tmp = a + b;
                a = b;
                b = tmp;
                i += 1;
            }
            b
        }
    }
}

// 编译时就计算好了
const FIB_20: u64 = fibonacci(20);

fn main() {
    assert_eq!(FIB_20, 6765);
}

更复杂的场景——编译时构建查找表:

const fn build_ascii_lookup() -> [bool; 128] {
    let mut table = [false; 128];
    let mut i = 0;
    while i < 128 {
        table[i] = matches!(i as u8, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_');
        i += 1;
    }
    table
}

const IS_IDENT_CHAR: [bool; 128] = build_ascii_lookup();

fn is_identifier_char(c: u8) -> bool {
    (c as usize) < 128 && IS_IDENT_CHAR[c as usize]
}

这种模式在解析器、编解码器中非常常见——零运行时开销的查找表。

const generics:固定大小的类型

const generics(const N: usize)让我们可以在类型层面表达数组大小,替代以前只能用宏或手动为每个大小实现trait的尴尬局面。

固定大小向量

use std::ops::{Add, Mul};

#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec<const N: usize> {
    data: [f64; N],
}

impl<const N: usize> Vec<N> {
    const fn zero() -> Self {
        Vec { data: [0.0; N] }
    }

    fn dot(&self, other: &Self) -> f64 {
        let mut sum = 0.0;
        let mut i = 0;
        while i < N {
            sum += self.data[i] * other.data[i];
            i += 1;
        }
        sum
    }
}

impl<const N: usize> Add for Vec<N> {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        let mut data = [0.0; N];
        let mut i = 0;
        while i < N {
            data[i] = self.data[i] + rhs.data[i];
            i += 1;
        }
        Vec { data }
    }
}

编译器会在类型层面阻止不同维度的向量做运算:

let v3 = Vec::<3> { data: [1.0, 2.0, 3.0] };
let v4 = Vec::<4> { data: [1.0, 2.0, 3.0, 4.0] };
// let bad = v3 + v4;  // 编译错误:类型不匹配

编译时矩阵

更强大的例子——带维度信息的矩阵:

#[derive(Debug, Clone)]
struct Matrix<const ROWS: usize, const COLS: usize> {
    data: [[f64; COLS]; ROWS],
}

impl<const ROWS: usize, const COLS: usize> Matrix<ROWS, COLS> {
    fn transpose(&self) -> Matrix<COLS, ROWS> {
        let mut result = [[0.0; ROWS]; COLS];
        for i in 0..ROWS {
            for j in 0..COLS {
                result[j][i] = self.data[i][j];
            }
        }
        Matrix { data: result }
    }
}

// 矩阵乘法:(M×K) * (K×N) -> (M×N)
// 注意K必须相同,这在类型层面就保证了
impl<const M: usize, const K: usize> Matrix<M, K> {
    fn mul<const N: usize>(&self, rhs: &Matrix<K, N>) -> Matrix<M, N> {
        let mut result = [[0.0; N]; M];
        for i in 0..M {
            for j in 0..N {
                for p in 0..K {
                    result[i][j] += self.data[i][p] * rhs.data[p][j];
                }
            }
        }
        Matrix { data: result }
    }
}

维度不匹配时的乘法在编译期就会报错,而不是留到运行时panic。

编译时断言

利用const计算实现编译时断言,把不变量检查从运行时前移到编译时:

/// 编译时断言:条件为false时触发编译错误
macro_rules! const_assert {
    ($cond:expr) => {
        const _: () = assert!($cond);
    };
    ($cond:expr, $msg:expr) => {
        const _: () = assert!($cond, $msg);
    };
}

// 确保平台假设
const_assert!(std::mem::size_of::<usize>() >= 4, "需要至少32位平台");

// 确保配置合理
const MAX_CONNECTIONS: usize = 1024;
const BUFFER_SIZE: usize = 4096;
const_assert!(BUFFER_SIZE.is_power_of_two());
const_assert!(MAX_CONNECTIONS <= 65535, "连接数不能超过u16范围");

结合const generics实现更精细的编译时检查:

struct BoundedArray<T, const N: usize, const MAX: usize> {
    data: [T; N],
}

impl<T: Default + Copy, const N: usize, const MAX: usize> BoundedArray<T, N, MAX> {
    fn new() -> Self {
        const { assert!(N <= MAX, "数组大小超过上限") };
        BoundedArray {
            data: [T::default(); N],
        }
    }
}

const { ... } 块(Rust 1.79+)让内联的编译时断言写起来更自然。

实际应用:编译时哈希

一个实用例子——编译时FNV-1a哈希,用于字符串常量的快速比较:

const fn fnv1a_hash(bytes: &[u8]) -> u64 {
    let mut hash: u64 = 0xcbf29ce484222325;
    let mut i = 0;
    while i < bytes.len() {
        hash ^= bytes[i] as u64;
        hash = hash.wrapping_mul(0x100000001b3);
        i += 1;
    }
    hash
}

const HASH_GET: u64 = fnv1a_hash(b"GET");
const HASH_POST: u64 = fnv1a_hash(b"POST");
const HASH_PUT: u64 = fnv1a_hash(b"PUT");

fn match_method(method: &[u8]) -> &'static str {
    match fnv1a_hash(method) {
        HASH_GET => "GET",
        HASH_POST => "POST",
        HASH_PUT => "PUT",
        _ => "UNKNOWN",
    }
}

小结

Rust的编译时计算正在变得越来越强大。const fn + const generics + 编译时断言组合起来,可以把大量检查和计算提前到编译期完成。这不只是性能优化——更重要的是把运行时的"可能出错"变成编译时的"不可能出错"。随着const trait、const closure等特性逐步稳定,这个方向还会走得更远。