Quantitative Trading: Building a High-Performance Backtesting Engine in Rust

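The backtesting engine is a core component of any quantitative trading system. This article builds an event-driven backtesting engine in Rust, covering data loading, signal generation, order-matching simulation, and performance statistics.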

1. Why Rust

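Backtesting is extremely performance-sensitive. A typical scenario: running 1,000 parameter combinations over 10 years of minute bars (roughly 1.2 million bars) means processing about 1.2 billion data points. At this scale, Python frameworks such as backtrader and zipline can take hours, whereas a Rust engine can finish in a matter of minutes.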

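Rust also offers:

  • Memory safety: in a trading system, a memory bug can silently corrupt trading signals, with serious consequences
  • Zero-cost abstractions: generics and statically dispatched traits are monomorphized, so they add no runtime overhead
  • Safe parallelism: when backtests run in parallel with Rayon, the compiler guarantees there are no data races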

2. Core Data Structures

use chrono::NaiveDateTime; // deriving Deserialize for NaiveDateTime needs chrono's "serde" feature
use serde::Deserialize;

/// One OHLCV bar (candlestick)
#[derive(Debug, Clone, Deserialize)]
pub struct Bar {
    pub datetime: NaiveDateTime,
    pub open: f64,
    pub high: f64,
    pub low: f64,
    pub close: f64,
    pub volume: f64,
}

/// Trading signal produced by a strategy
#[derive(Debug, Clone)]
pub enum Signal {
    Buy { symbol: String, price: f64, quantity: f64 },
    Sell { symbol: String, price: f64, quantity: f64 },
    Hold,
}

/// Order
#[derive(Debug, Clone)]
pub struct Order {
    pub id: u64,
    pub symbol: String,
    pub side: OrderSide,
    pub price: f64,
    pub quantity: f64,
    pub timestamp: NaiveDateTime,
    pub status: OrderStatus,
}

#[derive(Debug, Clone, PartialEq)]
pub enum OrderSide { Buy, Sell }

#[derive(Debug, Clone, PartialEq)]
pub enum OrderStatus { Pending, Filled, Cancelled }

/// Open position in one symbol
#[derive(Debug, Clone)]
pub struct Position {
    pub symbol: String,
    pub quantity: f64,
    pub avg_cost: f64,
    pub unrealized_pnl: f64,
}

/// Account state
#[derive(Debug, Clone)]
pub struct Account {
    pub cash: f64,
    pub positions: std::collections::HashMap<String, Position>,
    pub total_value: f64,
    pub order_count: u64,
}

impl Account {
    pub fn new(initial_cash: f64) -> Self {
        Account {
            cash: initial_cash,
            positions: std::collections::HashMap::new(),
            total_value: initial_cash,
            order_count: 0,
        }
    }

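    /// Mark positions to market. `prices` maps symbol -> latest price;
    /// positions without a quote are valued at cost.
    pub fn update_value(&mut self, prices: &std::collections::HashMap<String, f64>) {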
        let position_value: f64 = self.positions.iter_mut().map(|(sym, pos)| {
            if let Some(&price) = prices.get(sym) {
                pos.unrealized_pnl = (price - pos.avg_cost) * pos.quantity;
                price * pos.quantity
            } else {
                pos.avg_cost * pos.quantity
            }
        }).sum();
        self.total_value = self.cash + position_value;
    }
}
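A quick worked example of the bookkeeping these structures support. The std-only sketch below is a simplification: the `Holding` type and the `add_fill` / `mark_to_market` functions are hypothetical stand-ins, not part of the engine. It applies the same quantity-weighted average-cost rule used when buy orders fill, and the same mark-to-market formula as `Account::update_value`.

```rust
/// Minimal stand-in for `Position` (hypothetical, for illustration only)
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Holding {
    pub quantity: f64,
    pub avg_cost: f64,
}

/// Add a buy fill: the new average cost is the quantity-weighted mean of old and new prices
pub fn add_fill(h: Holding, price: f64, qty: f64) -> Holding {
    let total_cost = h.avg_cost * h.quantity + price * qty;
    let quantity = h.quantity + qty;
    Holding { quantity, avg_cost: total_cost / quantity }
}

/// Mark to market: returns (unrealized PnL, total account value)
pub fn mark_to_market(cash: f64, h: Holding, last_price: f64) -> (f64, f64) {
    let unrealized = (last_price - h.avg_cost) * h.quantity;
    (unrealized, cash + last_price * h.quantity)
}
```

For example, buying 100 shares at 10 and another 100 at 12 from 10,000 in cash leaves 7,800 in cash and an average cost of 11. At a last price of 13, that is 400 of unrealized PnL and a total value of 10,400.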

3. Event-Driven Architecture

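The heart of the engine is an event loop. On every bar the steps run in this order: fill the orders queued on the previous bar (at this bar's open), mark the account to market, run the strategy to generate a signal, and turn that signal into an order for the next bar. Filling on the next bar rather than at the current close keeps the strategy from trading at a price it could not have acted on (look-ahead bias). The `Event` enum below names these event types; the simplified loop in section 4 runs the stages directly in sequence rather than through an event queue.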
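To make the event-loop idea concrete, here is a minimal std-only sketch that drives the stages through an explicit FIFO queue (`VecDeque`). The simplified `Ev` enum, the `run_event_loop` function, and the toy "signal whenever close > threshold" rule are hypothetical. The point it illustrates is the timing: a signal on bar t becomes an order that is filled only when bar t+1 arrives, before that bar's market data is processed.

```rust
use std::collections::VecDeque;

/// Simplified event types (hypothetical; single symbol, prices only)
#[derive(Debug, Clone, PartialEq)]
pub enum Ev {
    MarketData { bar: usize, close: f64 },
    Signal { bar: usize },
    OrderSubmitted { bar: usize },
    OrderFilled { signal_bar: usize, fill_bar: usize },
}

/// Drive a toy strategy ("signal when close > threshold") through an event queue.
/// Returns every event in the order it was handled.
pub fn run_event_loop(closes: &[f64], threshold: f64) -> Vec<Ev> {
    let mut queue: VecDeque<Ev> = VecDeque::new();
    let mut pending: Vec<usize> = Vec::new(); // bars whose orders wait for the next bar
    let mut log = Vec::new();

    for (bar, &close) in closes.iter().enumerate() {
        // Orders submitted on earlier bars fill as soon as this new bar arrives
        for signal_bar in pending.drain(..) {
            queue.push_back(Ev::OrderFilled { signal_bar, fill_bar: bar });
        }
        queue.push_back(Ev::MarketData { bar, close });

        // Each handled event may enqueue the next stage
        while let Some(ev) = queue.pop_front() {
            match &ev {
                Ev::MarketData { bar, close } if *close > threshold => {
                    queue.push_back(Ev::Signal { bar: *bar });
                }
                Ev::Signal { bar } => queue.push_back(Ev::OrderSubmitted { bar: *bar }),
                Ev::OrderSubmitted { bar } => pending.push(*bar),
                _ => {} // quiet market data or a fill: nothing further to do
            }
            log.push(ev);
        }
    }
    log
}
```

With closes `[1.0, 3.0, 2.0]` and a threshold of 2.5, the signal appears on bar 1 and its order is filled on bar 2. A queue like this makes it easy to add new event types later (cancels, partial fills) without restructuring the loop.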

/// Event types
#[derive(Debug)]
pub enum Event {
    MarketData(Bar),
    SignalGenerated(Signal),
    OrderSubmitted(Order),
    OrderFilled(Order),
}

/// Strategy interface: `on_bar` is called once per bar and returns a signal
pub trait Strategy {
    fn on_bar(&mut self, bar: &Bar, account: &Account) -> Signal;
    fn name(&self) -> &str;
}

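/// Order-matching simulator
pub struct MatchingEngine {
    slippage: f64,       // slippage as a fraction of price (0.001 = 0.1%)
    commission: f64,     // commission as a fraction of traded value (0.0003 = 0.03%)
}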

impl MatchingEngine {
    pub fn new(slippage: f64, commission: f64) -> Self {
        MatchingEngine { slippage, commission }
    }

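    /// Try to fill an order against `bar`, the bar after the one that produced the order.
    /// Returns the commission charged when the order fills (always, for market orders).
    pub fn try_fill(&self, order: &mut Order, bar: &Bar) -> Option<f64> {
        // Simple model: market order filled at this bar's open, shifted by slippage
        let raw_price = match order.side {
            OrderSide::Buy => bar.open * (1.0 + self.slippage),
            OrderSide::Sell => bar.open * (1.0 - self.slippage),
        };

        // Clamp the fill into the bar's [low, high] range. Rejecting it instead leaves the
        // market order pending, and since the strategy still sees no position it keeps
        // submitting duplicate buys. (f64::clamp panics if low > high: bars must be valid.)
        let fill_price = raw_price.clamp(bar.low, bar.high);
        let commission = fill_price * order.quantity * self.commission;
        order.status = OrderStatus::Filled;
        order.price = fill_price;
        Some(commission)
    }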
}
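A worked example of the cost model, assuming slippage and commission are fractions (0.001 = 0.1%, 0.0003 = 0.03%, the values used in section 6). The standalone `fill_cost` function below is a hypothetical helper that mirrors the matching engine's arithmetic for a market order filled at a bar's open; it clamps the price into the bar's [low, high] range, which is one possible policy (rejecting the fill is another).

```rust
/// Fill price and commission for a market order filled at a bar's open (hypothetical helper).
/// `slippage` and `commission_rate` are fractions, e.g. 0.001 = 0.1%.
pub fn fill_cost(
    is_buy: bool,
    open: f64,
    low: f64,
    high: f64,
    quantity: f64,
    slippage: f64,
    commission_rate: f64,
) -> (f64, f64) {
    // Slippage works against the trader: pay more when buying, receive less when selling
    let raw = if is_buy { open * (1.0 + slippage) } else { open * (1.0 - slippage) };
    // Keep the simulated price inside the range the bar actually traded
    let price = raw.clamp(low, high);
    // Commission is charged on the traded value
    (price, price * quantity * commission_rate)
}
```

Buying 1,000 shares at an open of 100 with 0.1% slippage fills at about 100.1 and costs about 30.03 in commission. If the bar opened at its high of 100, the clamp caps the fill at 100.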

4. The Backtester

/// Backtester (single-symbol: every bar is booked under the symbol "default")
pub struct Backtester<S: Strategy> {
    strategy: S,
    matching_engine: MatchingEngine,
    account: Account,
    pending_orders: Vec<Order>,
    equity_curve: Vec<(NaiveDateTime, f64)>,
    trades: Vec<Trade>,
    // When each open position was first filled, so closed trades record a real entry time
    entry_times: std::collections::HashMap<String, NaiveDateTime>,
}

/// A closed trade (a sell fill that reduces or exits a position)
#[derive(Debug, Clone)]
pub struct Trade {
    pub symbol: String,
    pub side: OrderSide,
    pub entry_price: f64,
    pub exit_price: f64,
    pub quantity: f64,
    pub pnl: f64, // gross PnL: commissions are not deducted
    pub entry_time: NaiveDateTime,
    pub exit_time: NaiveDateTime,
}

impl<S: Strategy> Backtester<S> {
    pub fn new(strategy: S, initial_cash: f64, slippage: f64, commission: f64) -> Self {
        Backtester {
            strategy,
            matching_engine: MatchingEngine::new(slippage, commission),
            account: Account::new(initial_cash),
            pending_orders: Vec::new(),
            equity_curve: Vec::new(),
            trades: Vec::new(),
            entry_times: std::collections::HashMap::new(),
        }
    }

    pub fn run(&mut self, data: &[Bar]) -> BacktestResult {
        for bar in data {
            // 1. Fill orders queued on the previous bar, at this bar's open (no look-ahead)
            self.process_pending_orders(bar);

            // 2. Mark the account to market at this bar's close
            let mut prices = std::collections::HashMap::new();
            prices.insert("default".to_string(), bar.close);
            self.account.update_value(&prices);

            // 3. Record the equity curve
            self.equity_curve.push((bar.datetime, self.account.total_value));

            // 4. Let the strategy generate a signal from this bar
            let signal = self.strategy.on_bar(bar, &self.account);

            // 5. Turn the signal into an order, to be filled on the next bar
            match signal {
                Signal::Buy { symbol, price, quantity } => {
                    // Rough affordability check at the signal price; slippage and commission
                    // are not reserved, so strategies should leave a cash buffer
                    let cost = price * quantity;
                    if self.account.cash >= cost {
                        self.account.order_count += 1;
                        self.pending_orders.push(Order {
                            id: self.account.order_count,
                            symbol,
                            side: OrderSide::Buy,
                            price,
                            quantity,
                            timestamp: bar.datetime,
                            status: OrderStatus::Pending,
                        });
                    }
                }
                Signal::Sell { symbol, price, quantity } => {
                    self.account.order_count += 1;
                    self.pending_orders.push(Order {
                        id: self.account.order_count,
                        symbol,
                        side: OrderSide::Sell,
                        price,
                        quantity,
                        timestamp: bar.datetime,
                        status: OrderStatus::Pending,
                    });
                }
                Signal::Hold => {}
            }
        }

        self.calculate_result()
    }

    fn process_pending_orders(&mut self, bar: &Bar) {
        for order in &mut self.pending_orders {
            if order.status != OrderStatus::Pending {
                continue;
            }
            if let Some(commission) = self.matching_engine.try_fill(order, bar) {
                match order.side {
                    OrderSide::Buy => {
                        self.account.cash -= order.price * order.quantity + commission;
                        let pos = self.account.positions
                            .entry(order.symbol.clone())
                            .or_insert(Position {
                                symbol: order.symbol.clone(),
                                quantity: 0.0,
                                avg_cost: 0.0,
                                unrealized_pnl: 0.0,
                            });
                        // Quantity-weighted average cost across all buys
                        let total_cost = pos.avg_cost * pos.quantity + order.price * order.quantity;
                        pos.quantity += order.quantity;
                        pos.avg_cost = total_cost / pos.quantity;
                        // Remember when the position was opened (first buy only)
                        self.entry_times.entry(order.symbol.clone()).or_insert(bar.datetime);
                    }
                    OrderSide::Sell => {
                        self.account.cash += order.price * order.quantity - commission;
                        if let Some(pos) = self.account.positions.get_mut(&order.symbol) {
                            let pnl = (order.price - pos.avg_cost) * order.quantity;
                            pos.quantity -= order.quantity;
                            self.trades.push(Trade {
                                symbol: order.symbol.clone(),
                                side: OrderSide::Sell,
                                entry_price: pos.avg_cost,
                                exit_price: order.price,
                                quantity: order.quantity,
                                pnl,
                                // Entry time of the position, not of the sell order
                                entry_time: self.entry_times.get(&order.symbol).copied().unwrap_or(bar.datetime),
                                exit_time: bar.datetime,
                            });
                            if pos.quantity <= 0.0 {
                                self.account.positions.remove(&order.symbol);
                                self.entry_times.remove(&order.symbol);
                            }
                        }
                    }
                }
            }
        }

        // Drop filled orders; anything still pending is retried on the next bar
        self.pending_orders.retain(|o| o.status == OrderStatus::Pending);
    }

    fn calculate_result(&self) -> BacktestResult {
        let total_trades = self.trades.len();
        let winning_trades = self.trades.iter().filter(|t| t.pnl > 0.0).count();
        // Gross PnL of closed trades only (excludes commissions and any still-open position)
        let total_pnl: f64 = self.trades.iter().map(|t| t.pnl).sum();

        // Maximum drawdown: largest fractional drop from a running equity peak
        let mut peak = 0.0_f64;
        let mut max_drawdown = 0.0_f64;
        for &(_, value) in &self.equity_curve {
            peak = peak.max(value);
            if peak > 0.0 {
                max_drawdown = max_drawdown.max((peak - value) / peak);
            }
        }

        // Annualized return (CAGR) over elapsed calendar time. Treating each bar as one of
        // 252 trading days is only valid for daily bars; this works for any bar size.
        let (initial, final_val, years) = match (self.equity_curve.first(), self.equity_curve.last()) {
            (Some(&(t0, v0)), Some(&(t1, v1))) => {
                (v0, v1, (t1 - t0).num_seconds() as f64 / (365.25 * 86_400.0))
            }
            _ => (0.0, 0.0, 0.0),
        };
        let annual_return = if years > 0.0 && initial > 0.0 {
            ((final_val / initial).powf(1.0 / years) - 1.0) * 100.0
        } else {
            0.0
        };

        BacktestResult {
            total_trades,
            winning_trades,
            win_rate: if total_trades > 0 { winning_trades as f64 / total_trades as f64 * 100.0 } else { 0.0 },
            total_pnl,
            max_drawdown: max_drawdown * 100.0,
            annual_return,
            final_equity: final_val,
        }
    }
}

#[derive(Debug)]
pub struct BacktestResult {
    pub total_trades: usize,
    pub winning_trades: usize,
    pub win_rate: f64,
    pub total_pnl: f64,
    pub max_drawdown: f64,
    pub annual_return: f64,
    pub final_equity: f64,
}

impl std::fmt::Display for BacktestResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "=== Backtest Result ===")?;
        writeln!(f, "Total Trades:   {}", self.total_trades)?;
        writeln!(f, "Win Rate:       {:.1}%", self.win_rate)?;
        writeln!(f, "Total PnL:      {:.2}", self.total_pnl)?;
        writeln!(f, "Max Drawdown:   {:.2}%", self.max_drawdown)?;
        writeln!(f, "Annual Return:  {:.2}%", self.annual_return)?;
        writeln!(f, "Final Equity:   {:.2}", self.final_equity)
    }
}
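The two headline metrics are easy to check by hand on a tiny equity series. The std-only sketch below uses plain values and a year count instead of timestamps (an assumption for brevity); `max_drawdown` and `cagr` are hypothetical helper names. Maximum drawdown is the largest fractional drop from a running peak, and the annualized return is the compound annual growth rate, (final / initial)^(1 / years) - 1.

```rust
/// Largest peak-to-trough decline as a fraction of the peak (0.25 = 25%)
pub fn max_drawdown(equity: &[f64]) -> f64 {
    let mut peak = f64::MIN;
    let mut max_dd = 0.0_f64;
    for &v in equity {
        // Track the highest equity seen so far, then measure the drop from it
        peak = peak.max(v);
        if peak > 0.0 {
            max_dd = max_dd.max((peak - v) / peak);
        }
    }
    max_dd
}

/// Compound annual growth rate between two equity values `years` apart
pub fn cagr(initial: f64, final_value: f64, years: f64) -> Option<f64> {
    if initial <= 0.0 || years <= 0.0 {
        return None;
    }
    Some((final_value / initial).powf(1.0 / years) - 1.0)
}
```

On the series 100, 120, 90, 130, 117 the drawdowns are 25% (120 to 90) and 10% (130 to 117), so the maximum is 25%. Growing from 100 to 121 over two years is a 10% CAGR.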

5. An Example Strategy

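Dual moving-average crossover: buy when the short MA crosses above the long MA, sell when it crosses below. The implementation is state-based: it buys whenever the short MA is above the long MA and no position is held, and sells the whole position whenever the short MA is below the long MA. This matches a strict crossover rule except at the start of the data, where it buys immediately if the short MA is already above the long MA: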

pub struct DualMACross {
    short_period: usize,
    long_period: usize,
    prices: Vec<f64>,
}

impl DualMACross {
    pub fn new(short_period: usize, long_period: usize) -> Self {
        DualMACross {
            short_period,
            long_period,
            prices: Vec::new(),
        }
    }

    fn sma(&self, period: usize) -> Option<f64> {
        if self.prices.len() < period {
            return None;
        }
        let sum: f64 = self.prices[self.prices.len()-period..].iter().sum();
        Some(sum / period as f64)
    }
}

impl Strategy for DualMACross {
    fn on_bar(&mut self, bar: &Bar, account: &Account) -> Signal {
        self.prices.push(bar.close);

        let short_ma = match self.sma(self.short_period) { Some(v) => v, None => return Signal::Hold };
        let long_ma = match self.sma(self.long_period) { Some(v) => v, None => return Signal::Hold };

        let has_position = account.positions.contains_key("default");

        if short_ma > long_ma && !has_position {
            let quantity = (account.cash * 0.95 / bar.close).floor();
            if quantity > 0.0 {
                return Signal::Buy {
                    symbol: "default".to_string(),
                    price: bar.close,
                    quantity,
                };
            }
        } else if short_ma < long_ma && has_position {
            if let Some(pos) = account.positions.get("default") {
                return Signal::Sell {
                    symbol: "default".to_string(),
                    price: bar.close,
                    quantity: pos.quantity,
                };
            }
        }

        Signal::Hold
    }

    fn name(&self) -> &str { "DualMA Cross" }
}
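The `sma` method re-sums the last `period` prices on every bar, and `prices` grows for the whole run. A common alternative, sketched below with std only, keeps a running sum over a `VecDeque` window so each update is O(1) and memory is bounded by the period. `RollingSma` and `crossed_above` are hypothetical names; `crossed_above` shows strict crossover detection between two consecutive bars, as opposed to the state-based check in `on_bar`.

```rust
use std::collections::VecDeque;

/// Simple moving average with O(1) updates and a bounded window (hypothetical helper)
pub struct RollingSma {
    period: usize,
    window: VecDeque<f64>,
    sum: f64,
}

impl RollingSma {
    pub fn new(period: usize) -> Self {
        assert!(period > 0, "period must be positive");
        RollingSma { period, window: VecDeque::with_capacity(period + 1), sum: 0.0 }
    }

    /// Push a new price; returns the average once `period` prices have been seen
    pub fn update(&mut self, price: f64) -> Option<f64> {
        self.window.push_back(price);
        self.sum += price;
        if self.window.len() > self.period {
            // Drop the oldest price from both the window and the running sum
            self.sum -= self.window.pop_front().unwrap();
        }
        if self.window.len() == self.period {
            Some(self.sum / self.period as f64)
        } else {
            None
        }
    }
}

/// True when the short MA moves from at-or-below the long MA to strictly above it
pub fn crossed_above(prev_short: f64, prev_long: f64, short: f64, long: f64) -> bool {
    prev_short <= prev_long && short > long
}
```

Over very long runs the running sum can accumulate a little floating-point drift; recomputing it from the window every so often is a simple safeguard.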

6. Parallel Parameter Optimization

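Because every backtest is independent, parameter optimization parallelizes naturally with Rayon. Each worker builds its own strategy and `Backtester`, while the bar data is shared read-only across threads without copying: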

use rayon::prelude::*;

fn optimize_parameters(data: &[Bar]) {
    let params: Vec<(usize, usize)> = (5..=30)
        .flat_map(|short| ((short+10)..=120).map(move |long| (short, long)))
        .collect();

    let mut results: Vec<_> = params.par_iter().map(|&(short, long)| {
        let strategy = DualMACross::new(short, long);
        // 1,000,000 initial cash, 0.1% slippage, 0.03% commission
        let mut bt = Backtester::new(strategy, 1_000_000.0, 0.001, 0.0003);
        let result = bt.run(data);
        (short, long, result.annual_return, result.max_drawdown)
    }).collect();

    // Sort by annualized return, best first (total_cmp does not panic on NaN)
    results.sort_by(|a, b| b.2.total_cmp(&a.2));

    println!("Top 10 parameter combinations:");
    for (short, long, ret, dd) in results.iter().take(10) {
        println!("  MA({},{}) -> Annual: {:.2}%, MaxDD: {:.2}%", short, long, ret, dd);
    }
}

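The grid above contains 2,431 parameter combinations, and Rayon spreads them across all CPU cores automatically. Since the backtests are independent and CPU-bound, the speedup on a 16-core machine can come close to linear relative to a single thread.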
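The 2,431 figure follows from the grid: for each short period s in 5..=30 there are 120 - (s + 10) + 1 = 111 - s long periods, and summing 111 - s over the 26 short periods gives 2,886 - 455 = 2,431. The std-only sketch below builds the same grid and evaluates it with `std::thread::scope` as a stand-in for Rayon, one chunk per thread. `param_grid` and `parallel_sweep` are hypothetical names, and the `evaluate` closure is a placeholder for "run one backtest and return its score".

```rust
use std::thread;

/// Same grid as `optimize_parameters`: short in 5..=30, long in short+10..=120
pub fn param_grid() -> Vec<(usize, usize)> {
    (5..=30)
        .flat_map(|short| ((short + 10)..=120).map(move |long| (short, long)))
        .collect()
}

/// Evaluate every (short, long) pair on scoped threads; results sorted best score first.
pub fn parallel_sweep<F>(params: &[(usize, usize)], threads: usize, evaluate: F) -> Vec<((usize, usize), f64)>
where
    F: Fn(usize, usize) -> f64 + Sync,
{
    let t = threads.max(1);
    let chunk_size = ((params.len() + t - 1) / t).max(1); // ceil(len / t), never 0
    let mut results: Vec<((usize, usize), f64)> = thread::scope(|s| {
        // Spawn all workers first; each borrows its chunk and the shared closure
        let handles: Vec<_> = params
            .chunks(chunk_size)
            .map(|chunk| {
                let evaluate = &evaluate;
                s.spawn(move || {
                    chunk.iter().map(|&(a, b)| ((a, b), evaluate(a, b))).collect::<Vec<_>>()
                })
            })
            .collect();
        // Then join them and concatenate the partial results
        handles.into_iter().flat_map(|h| h.join().unwrap()).collect::<Vec<_>>()
    });
    // Descending by score; total_cmp never panics, unlike partial_cmp(..).unwrap()
    results.sort_by(|x, y| y.1.total_cmp(&x.1));
    results
}
```

Note that `total_cmp` orders positive NaN above every number, so a descending sort puts NaN scores first; filtering them out before sorting may be preferable.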

7. Memory-Mapped Data Files

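For large histories, memmap2 can memory-map the file instead of reading it through buffered I/O: the OS pages data in on demand and no intermediate read buffer is needed. Note that the function below still decodes every record into a `Vec<Bar>`, so the decoded bars do occupy memory; to keep the footprint bounded, process records straight from the mapping instead: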

use memmap2::Mmap;
use std::fs::File;
use std::io;

/// Parse bars from a fixed-layout little-endian binary file.
/// Record layout (48 bytes): timestamp i64 (Unix seconds, UTC) + open/high/low/close/volume as f64.
fn load_bars_mmap(path: &str) -> io::Result<Vec<Bar>> {
    let file = File::open(path)?;
    // SAFETY: the file must not be modified or truncated by another process while it is mapped
    let mmap = unsafe { Mmap::map(&file)? };

    const RECORD_SIZE: usize = 48;
    // Read the i64 / f64 stored at byte offset `at` within a record
    let le_i64 = |r: &[u8], at: usize| i64::from_le_bytes(r[at..at + 8].try_into().unwrap());
    let le_f64 = |r: &[u8], at: usize| f64::from_le_bytes(r[at..at + 8].try_into().unwrap());

    let mut bars = Vec::with_capacity(mmap.len() / RECORD_SIZE);
    // chunks_exact skips a trailing partial record instead of panicking
    for rec in mmap.chunks_exact(RECORD_SIZE) {
        let ts = le_i64(rec, 0);
        let datetime = chrono::DateTime::from_timestamp(ts, 0)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "timestamp out of range"))?
            .naive_utc();
        bars.push(Bar {
            datetime,
            open: le_f64(rec, 8),
            high: le_f64(rec, 16),
            low: le_f64(rec, 24),
            close: le_f64(rec, 32),
            volume: le_f64(rec, 40),
        });
    }
    Ok(bars)
}
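To avoid materializing the whole history, records can be decoded lazily straight from the mapped bytes, since `Mmap` dereferences to `&[u8]`. The std-only sketch below assumes the same 48-byte little-endian layout as above; the `RawBar` type and `iter_records` function are hypothetical names. It yields one record at a time, so a backtest loop can consume bars without allocating a `Vec`.

```rust
/// One decoded record; the timestamp is kept as raw Unix seconds (hypothetical type)
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RawBar {
    pub ts: i64,
    pub open: f64,
    pub high: f64,
    pub low: f64,
    pub close: f64,
    pub volume: f64,
}

pub const RECORD_SIZE: usize = 48;

/// Lazily decode fixed-size records from a byte slice (for example `&mmap[..]`).
/// A trailing partial record is ignored.
pub fn iter_records(bytes: &[u8]) -> impl Iterator<Item = RawBar> + '_ {
    bytes.chunks_exact(RECORD_SIZE).map(|r| {
        // Copy the i-th 8-byte field of this record into an array
        let word = |i: usize| -> [u8; 8] {
            let mut b = [0u8; 8];
            b.copy_from_slice(&r[i * 8..i * 8 + 8]);
            b
        };
        RawBar {
            ts: i64::from_le_bytes(word(0)),
            open: f64::from_le_bytes(word(1)),
            high: f64::from_le_bytes(word(2)),
            low: f64::from_le_bytes(word(3)),
            close: f64::from_le_bytes(word(4)),
            volume: f64::from_le_bytes(word(5)),
        }
    })
}
```

Usage with a mapping would look like `for bar in iter_records(&mmap) { ... }`. Only the pages actually touched are read from disk, and each record is decoded on demand.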

Summary

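This engine implements the core event-driven backtesting loop. A production system would also need multi-symbol support, limit and stop orders, a money-management module, a more realistic slippage model, and interfaces for live trading. As a skeleton, though, it shows why Rust suits quantitative work: type safety rules out many subtle numeric mistakes, and the performance makes large-scale parameter searches practical.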