Python + Rust混合编程:PyO3实战

Python 写逻辑方便,但 CPU 密集型任务性能不行。用 Rust 重写热点路径,通过 PyO3 无缝暴露为 Python 模块,是目前最成熟的混合编程方案。本文从环境搭建到实际加速案例做一次完整实战。

PyO3 + maturin 环境搭建

PyO3 是 Rust 编写 Python 扩展的绑定库,maturin 是配套的构建/打包工具。

# 安装 maturin
pip install maturin

# 创建项目
mkdir py-rust-demo && cd py-rust-demo
maturin init --bindings pyo3

# 目录结构
# ├── Cargo.toml
# ├── pyproject.toml
# └── src/
#     └── lib.rs

Cargo.toml 关键配置:

[package]
name = "py_rust_demo"
version = "0.1.0"
edition = "2021"

[lib]
name = "py_rust_demo"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }

crate-type = ["cdylib"] 是关键——生成动态链接库供 Python 加载。

基础:#[pyfunction]

最简单的用法是导出函数:

use pyo3::prelude::*;

/// 计算斐波那契数列第 n 项
#[pyfunction]
fn fibonacci(n: u64) -> u64 {
    let (mut a, mut b) = (0u64, 1u64);
    for _ in 0..n {
        let t = b;
        b = a + b;
        a = t;
    }
    a
}

/// Python 模块入口
#[pymodule]
fn py_rust_demo(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(fibonacci, m)?)?;
    Ok(())
}

构建并测试:

maturin develop --release

python -c "from py_rust_demo import fibonacci; print(fibonacci(50))"
# 12586269025

maturin develop 会编译 Rust 代码并安装到当前 Python 环境,--release 开启优化。

#[pyclass] 导出类

use pyo3::prelude::*;

#[pyclass]
struct Matrix {
    data: Vec<Vec<f64>>,
    rows: usize,
    cols: usize,
}

#[pymethods]
impl Matrix {
    #[new]
    fn new(rows: usize, cols: usize) -> Self {
        Matrix {
            data: vec![vec![0.0; cols]; rows],
            rows,
            cols,
        }
    }

    fn set(&mut self, row: usize, col: usize, val: f64) -> PyResult<()> {
        if row >= self.rows || col >= self.cols {
            return Err(pyo3::exceptions::PyIndexError::new_err("Index out of bounds"));
        }
        self.data[row][col] = val;
        Ok(())
    }

    fn get(&self, row: usize, col: usize) -> PyResult<f64> {
        if row >= self.rows || col >= self.cols {
            return Err(pyo3::exceptions::PyIndexError::new_err("Index out of bounds"));
        }
        Ok(self.data[row][col])
    }

    fn multiply(&self, other: &Matrix) -> PyResult<Matrix> {
        if self.cols != other.rows {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Matrix dimensions mismatch"
            ));
        }
        let mut result = Matrix::new(self.rows, other.cols);
        for i in 0..self.rows {
            for j in 0..other.cols {
                let mut sum = 0.0;
                for k in 0..self.cols {
                    sum += self.data[i][k] * other.data[k][j];
                }
                result.data[i][j] = sum;
            }
        }
        Ok(result)
    }

    fn __repr__(&self) -> String {
        format!("Matrix({}x{})", self.rows, self.cols)
    }
}

Python 侧使用方式完全符合直觉:

from py_rust_demo import Matrix

a = Matrix(2, 3)
a.set(0, 0, 1.0)
a.set(0, 1, 2.0)
# ...
print(a)  # Matrix(2x3)

类型转换

PyO3 自动处理 Python <-> Rust 的类型转换:

Python 类型 Rust 类型 说明
int i32, i64, u64 自动转换,溢出会报错
float f32, f64 直接映射
str String, &str UTF-8 安全
bool bool 直接映射
list Vec<T> 自动递归转换
dict HashMap<K, V> 自动转换
bytes &[u8], Vec<u8> 零拷贝可能(&[u8]
None Option<T> None -> None

对于复杂类型或需要高性能的场景,可以使用 Bound<'_, PyList> 等引用类型避免拷贝:

#[pyfunction]
fn sum_list(list: &Bound<'_, PyList>) -> PyResult<f64> {
    let mut total = 0.0;
    for item in list.iter() {
        total += item.extract::<f64>()?;
    }
    Ok(total)
}

错误处理

Rust 的 Result 自动映射到 Python 异常:

use pyo3::exceptions::PyValueError;

#[pyfunction]
fn parse_config(raw: &str) -> PyResult<String> {
    let parsed: serde_json::Value = serde_json::from_str(raw)
        .map_err(|e| PyValueError::new_err(format!("Invalid JSON: {}", e)))?;

    Ok(parsed["name"]
        .as_str()
        .ok_or_else(|| PyValueError::new_err("Missing 'name' field"))?
        .to_string())
}

Python 侧会收到标准的 ValueError,traceback 正常工作。

实战案例:图像处理加速

用 Rust 加速图像灰度转换 + 高斯模糊:

use pyo3::prelude::*;
use pyo3::types::PyBytes;

#[pyfunction]
fn grayscale_blur(py: Python<'_>, data: &[u8], width: usize, height: usize) -> PyResult<Py<PyBytes>> {
    let mut gray = vec![0u8; width * height];

    // RGB -> 灰度
    for i in 0..width * height {
        let r = data[i * 3] as f32;
        let g = data[i * 3 + 1] as f32;
        let b = data[i * 3 + 2] as f32;
        gray[i] = (0.299 * r + 0.587 * g + 0.114 * b) as u8;
    }

    // 3x3 均值模糊
    let mut blurred = vec![0u8; width * height];
    for y in 1..height - 1 {
        for x in 1..width - 1 {
            let mut sum = 0u32;
            for dy in -1i32..=1 {
                for dx in -1i32..=1 {
                    sum += gray[((y as i32 + dy) as usize) * width + (x as i32 + dx) as usize] as u32;
                }
            }
            blurred[y * width + x] = (sum / 9) as u8;
        }
    }

    Ok(PyBytes::new(py, &blurred).into())
}

性能对比(1920x1080 图像):

实现 耗时
纯 Python (循环) ~3200 ms
NumPy 向量化 ~12 ms
Rust (PyO3) ~4 ms
Rust (PyO3 + rayon) ~1.2 ms

Rust 版本比纯 Python 快约 800 倍,比 NumPy 还快 3 倍。加上 rayon 并行后接近 3000 倍加速。

加入 rayon 并行

# Cargo.toml
[dependencies]
pyo3 = { version = "0.22", features = ["extension-module"] }
rayon = "1.10"
use rayon::prelude::*;

#[pyfunction]
fn grayscale_parallel(_py: Python<'_>, data: &[u8], width: usize, height: usize) -> PyResult<Vec<u8>> {
    let gray: Vec<u8> = (0..width * height)
        .into_par_iter()
        .map(|i| {
            let r = data[i * 3] as f32;
            let g = data[i * 3 + 1] as f32;
            let b = data[i * 3 + 2] as f32;
            (0.299 * r + 0.587 * g + 0.114 * b) as u8
        })
        .collect();
    Ok(gray)
}

实战案例:CSV 解析加速

处理大型 CSV 文件也是典型的 CPU 密集场景:

#[pyfunction]
fn fast_csv_stats(path: &str) -> PyResult<(f64, f64, f64, usize)> {
    let mut reader = csv::Reader::from_path(path)
        .map_err(|e| pyo3::exceptions::PyIOError::new_err(e.to_string()))?;

    let mut sum = 0.0f64;
    let mut min = f64::MAX;
    let mut max = f64::MIN;
    let mut count = 0usize;

    for result in reader.records() {
        let record = result.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
        if let Some(val_str) = record.get(0) {
            if let Ok(val) = val_str.parse::<f64>() {
                sum += val;
                if val < min { min = val; }
                if val > max { max = val; }
                count += 1;
            }
        }
    }

    let avg = if count > 0 { sum / count as f64 } else { 0.0 };
    Ok((avg, min, max, count))
}

100 万行 CSV 的性能对比:

实现 耗时
Python csv 模块 ~2100 ms
pandas read_csv ~280 ms
Rust (PyO3) ~95 ms

Rust 版本比 pandas 快 3 倍,比纯 Python 快 22 倍。

开发工作流

# 开发阶段:编译并安装到当前环境
maturin develop --release

# 构建 wheel 包
maturin build --release

# 发布到 PyPI
maturin publish

# 交叉编译(如为 Linux aarch64 构建)
maturin build --release --target aarch64-unknown-linux-gnu

maturin 自动处理 Python ABI 兼容性、wheel 打包、平台标签等细节。配合 GitHub Actions 可以做到一次 push、多平台自动构建。

何时该用 PyO3

适合的场景:

  • CPU 密集型热点路径(图像处理、数值计算、数据解析)
  • 需要调用 Rust 生态库(加密、压缩、网络协议)
  • NumPy/pandas 已经不够快

不适合的场景:

  • IO 密集型任务(Python async 就够了)
  • 逻辑频繁变动的业务代码(Rust 编译慢)
  • 团队没人会 Rust(学习成本不可忽略)