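A backtesting engine is a core component of any quantitative trading system. This article implements an event-driven backtesting engine in Rust, covering data loading, signal generation, order-matching simulation, and performance statistics.
1. Why Rust
Backtesting is performance-critical. A typical scenario: backtesting 1,000 parameter combinations over 10 years of minute bars (roughly 1.2 million bars) means processing about 1.2 billion data points. Event-driven Python frameworks such as backtrader and zipline run their per-bar logic in the interpreter, so a job of this size can take hours. A compiled Rust engine can typically finish it in minutes.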
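The workload figure above is plain multiplication, sketched below. The session length used to reproduce a count of about 1.2 million bars (252 trading days a year, 480 trading minutes a day) is an assumption chosen for illustration, because trading hours differ by market.

```rust
/// Total bar visits in a parameter sweep: each combination replays every bar once.
fn total_data_points(bars: u64, combinations: u64) -> u64 {
    bars * combinations
}

/// Rough minute-bar count for a span of years. Trading days per year and
/// minutes per session are assumptions that differ by market.
fn estimate_minute_bars(years: u64, trading_days_per_year: u64, minutes_per_day: u64) -> u64 {
    years * trading_days_per_year * minutes_per_day
}

fn main() {
    // Assumed session: 252 trading days a year, 480 trading minutes a day.
    let bars = estimate_minute_bars(10, 252, 480);
    println!("{bars} bars"); // 1209600 bars

    // ~1.2 million bars x 1,000 parameter combinations
    println!("{} data points", total_data_points(1_200_000, 1_000)); // 1200000000 data points
}
```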
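Rust has further advantages:
- Memory safety: in a trading system, a memory bug can silently corrupt a trading signal, with serious consequences.
- Zero-cost abstractions: generics and statically dispatched traits are monomorphized at compile time, so they add no runtime overhead. Only `dyn Trait` objects pay for dynamic dispatch.
- Safe parallelism: when backtests run in parallel with Rayon, the compiler's `Send`/`Sync` checks guarantee that safe code has no data races.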
2. Core Data Structures
use chrono::NaiveDateTime;
use serde::Deserialize;
use std::collections::HashMap;

/// One bar (candlestick) of OHLCV data
/// (deriving Deserialize for NaiveDateTime requires chrono's "serde" feature)
#[derive(Debug, Clone, Deserialize)]
pub struct Bar {
    pub datetime: NaiveDateTime,
    pub open: f64,
    pub high: f64,
    pub low: f64,
    pub close: f64,
    pub volume: f64,
}

/// Trading signal emitted by a strategy
#[derive(Debug, Clone)]
pub enum Signal {
    Buy { symbol: String, price: f64, quantity: f64 },
    Sell { symbol: String, price: f64, quantity: f64 },
    Hold,
}

/// Order
#[derive(Debug, Clone)]
pub struct Order {
    pub id: u64,
    pub symbol: String,
    pub side: OrderSide,
    pub price: f64,
    pub quantity: f64,
    pub timestamp: NaiveDateTime,
    pub status: OrderStatus,
}

#[derive(Debug, Clone, PartialEq)]
pub enum OrderSide { Buy, Sell }

#[derive(Debug, Clone, PartialEq)]
pub enum OrderStatus { Pending, Filled, Cancelled }

/// Open position in one symbol
#[derive(Debug, Clone)]
pub struct Position {
    pub symbol: String,
    pub quantity: f64,
    pub avg_cost: f64,
    pub unrealized_pnl: f64,
}

/// Account state
#[derive(Debug, Clone)]
pub struct Account {
    pub cash: f64,
    pub positions: HashMap<String, Position>,
    pub total_value: f64,
    pub order_count: u64,
}

impl Account {
    pub fn new(initial_cash: f64) -> Self {
        Account {
            cash: initial_cash,
            positions: HashMap::new(),
            total_value: initial_cash,
            order_count: 0,
        }
    }

    /// Mark positions to market and recompute total equity.
    /// A position with no entry in `prices` is valued at its average cost.
    pub fn update_value(&mut self, prices: &HashMap<String, f64>) {
        let position_value: f64 = self.positions.iter_mut().map(|(sym, pos)| {
            if let Some(&price) = prices.get(sym) {
                pos.unrealized_pnl = (price - pos.avg_cost) * pos.quantity;
                price * pos.quantity
            } else {
                pos.avg_cost * pos.quantity
            }
        }).sum();
        self.total_value = self.cash + position_value;
    }
}
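The bookkeeping that `Backtester::process_pending_orders` performs on `Position` and `Account` can be checked on its own. The sketch below uses simplified, hypothetical types and functions (`Position` here has only `quantity` and `avg_cost`; `apply_buy`, `apply_sell` and `total_value` are illustrative names). It shows three things: the quantity-weighted average-cost update on a buy, realized PnL on a sell, and the mark-to-market rule from `update_value`.

```rust
use std::collections::HashMap;

/// Simplified position: quantity and quantity-weighted average cost.
#[derive(Debug, Clone, PartialEq)]
struct Position {
    quantity: f64,
    avg_cost: f64,
}

/// Apply a buy fill: the new average cost is the quantity-weighted mean
/// of the old cost basis and the fill price.
fn apply_buy(pos: &mut Position, price: f64, qty: f64) {
    let total_cost = pos.avg_cost * pos.quantity + price * qty;
    pos.quantity += qty;
    pos.avg_cost = total_cost / pos.quantity;
}

/// Apply a sell fill and return the realized PnL (before commissions).
/// Selling leaves the average cost of the remaining units unchanged.
fn apply_sell(pos: &mut Position, price: f64, qty: f64) -> f64 {
    let pnl = (price - pos.avg_cost) * qty;
    pos.quantity -= qty;
    pnl
}

/// Equity = cash + market value of positions (valued at cost if unpriced).
fn total_value(cash: f64, positions: &HashMap<String, Position>, prices: &HashMap<String, f64>) -> f64 {
    let position_value: f64 = positions
        .iter()
        .map(|(sym, p)| prices.get(sym).copied().unwrap_or(p.avg_cost) * p.quantity)
        .sum();
    cash + position_value
}

fn main() {
    let mut pos = Position { quantity: 0.0, avg_cost: 0.0 };
    apply_buy(&mut pos, 10.0, 100.0);
    apply_buy(&mut pos, 12.0, 100.0);
    println!("avg cost = {}", pos.avg_cost); // avg cost = 11

    let pnl = apply_sell(&mut pos, 13.0, 50.0);
    println!("realized = {pnl}, remaining = {}", pos.quantity); // realized = 100, remaining = 150

    let mut positions = HashMap::new();
    positions.insert("default".to_string(), pos);
    let prices = HashMap::from([("default".to_string(), 14.0)]);
    println!("equity = {}", total_value(1_000.0, &positions, &prices)); // equity = 3100
}
```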
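3. Event-Driven Architecture
The heart of the engine is an event loop over bars. For each new bar, `Backtester::run` (section 4) performs these steps in order: fill the orders left pending by the previous bar → mark the account to market → record the equity curve → let the strategy generate a signal → turn the signal into a pending order. An order created on bar t is matched only on bar t+1, at that bar's open. This means the strategy never trades at a price it has not yet seen, which avoids look-ahead bias. The `Event` enum below names these event types, but the simplified `run` loop calls each step directly rather than dispatching events through a queue.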
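For comparison, a queue-based version of this loop can be sketched with a `VecDeque`. Each handler pushes follow-up events, and the queue is drained before the next bar arrives. The event names, the `(open, close)` bar tuples and the toy "buy when the close is above 100" rule are hypothetical simplifications. The timing still matches the engine: an order created on one bar waits in `pending` and fills at the next bar's open.

```rust
use std::collections::VecDeque;

/// Hypothetical, simplified event types for this sketch.
#[derive(Debug)]
enum Event {
    MarketData { open: f64, close: f64 },
    Signal { quantity: f64 },
    OrderSubmitted { quantity: f64 },
    OrderFilled { price: f64, quantity: f64 },
}

/// Queue-driven loop over (open, close) bars. Returns a log of fills.
fn run(bars: &[(f64, f64)]) -> Vec<String> {
    let mut queue: VecDeque<Event> = VecDeque::new();
    let mut pending: Vec<f64> = Vec::new(); // quantities waiting for the next bar
    let mut log = Vec::new();
    for &(open, close) in bars {
        queue.push_back(Event::MarketData { open, close });
        // Drain every event triggered by this bar before feeding the next one.
        while let Some(event) = queue.pop_front() {
            match event {
                Event::MarketData { open, close } => {
                    // Orders from the previous bar fill at this bar's open.
                    for quantity in pending.drain(..) {
                        queue.push_back(Event::OrderFilled { price: open, quantity });
                    }
                    // Toy strategy (hypothetical): buy 1 unit when close > 100.
                    if close > 100.0 {
                        queue.push_back(Event::Signal { quantity: 1.0 });
                    }
                }
                Event::Signal { quantity } => queue.push_back(Event::OrderSubmitted { quantity }),
                Event::OrderSubmitted { quantity } => pending.push(quantity),
                Event::OrderFilled { price, quantity } => {
                    log.push(format!("filled {quantity} @ {price}"))
                }
            }
        }
    }
    log
}

fn main() {
    let bars = [(99.0, 101.0), (102.0, 98.0), (97.0, 99.0)];
    for line in run(&bars) {
        println!("{line}"); // filled 1 @ 102
    }
}
```

In this example, the signal raised on the first bar (close 101) is filled at the second bar's open (102), not at the close that triggered it.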
/// Event types
#[derive(Debug)]
pub enum Event {
    MarketData(Bar),
    SignalGenerated(Signal),
    OrderSubmitted(Order),
    OrderFilled(Order),
}

/// Strategy trait
pub trait Strategy {
    fn on_bar(&mut self, bar: &Bar, account: &Account) -> Signal;
    fn name(&self) -> &str;
}

/// Matching engine
pub struct MatchingEngine {
    slippage: f64,   // slippage as a fraction of price (0.001 = 0.1%)
    commission: f64, // commission as a fraction of notional (0.0003 = 0.03%)
}

impl MatchingEngine {
    pub fn new(slippage: f64, commission: f64) -> Self {
        MatchingEngine { slippage, commission }
    }

    /// Try to fill an order against `bar`; returns the commission on success.
    pub fn try_fill(&self, order: &mut Order, bar: &Bar) -> Option<f64> {
        // Simple model: the caller passes the bar after the one that generated
        // the order, so the fill is at that bar's open, moved against us by the slippage
        let fill_price = match order.side {
            OrderSide::Buy => bar.open * (1.0 + self.slippage),
            OrderSide::Sell => bar.open * (1.0 - self.slippage),
        };
        // Fill only if the slipped price lies within the bar's range; otherwise
        // the order stays pending and is retried on the next bar
        if fill_price >= bar.low && fill_price <= bar.high {
            let commission = fill_price * order.quantity * self.commission;
            order.status = OrderStatus::Filled;
            order.price = fill_price;
            Some(commission)
        } else {
            None
        }
    }
}
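Here is a worked example of the fill rule in `try_fill`, reduced to plain numbers. The free function `fill` and the `Side` enum are hypothetical stand-ins for `MatchingEngine` and `OrderSide`. The example uses 0.1% slippage and 0.03% commission, the values passed to `Backtester::new` in the optimization section. With those settings, a 1,000-unit buy at an open of 100.0 fills at 100.10 and pays 30.03 in commission.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Side {
    Buy,
    Sell,
}

/// Same rule as `MatchingEngine::try_fill`: slipped fill at the open, rejected
/// if the slipped price is outside [low, high]. Returns (fill_price, commission).
fn fill(side: Side, open: f64, low: f64, high: f64, qty: f64, slippage: f64, commission: f64) -> Option<(f64, f64)> {
    let price = match side {
        Side::Buy => open * (1.0 + slippage),
        Side::Sell => open * (1.0 - slippage),
    };
    if price >= low && price <= high {
        Some((price, price * qty * commission))
    } else {
        None
    }
}

fn main() {
    // Buy 1,000 units; the bar opens at 100.0 and trades between 99.0 and 101.0.
    if let Some((price, fee)) = fill(Side::Buy, 100.0, 99.0, 101.0, 1_000.0, 0.001, 0.0003) {
        println!("fill {price:.2}, commission {fee:.2}"); // fill 100.10, commission 30.03
    }
    // When the bar opens at its high, the slipped buy price (101.101) is above
    // the high, so this model leaves the order unfilled.
    assert!(fill(Side::Buy, 101.0, 99.0, 101.0, 1_000.0, 0.001, 0.0003).is_none());
}
```

The second case shows a quirk of this model. When a bar opens at its high, the slipped buy price lies above the high, so the order stays pending and is retried on the next bar.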
4. The Backtesting Engine
/// Backtesting engine
pub struct Backtester<S: Strategy> {
    strategy: S,
    matching_engine: MatchingEngine,
    account: Account,
    pending_orders: Vec<Order>,
    equity_curve: Vec<(NaiveDateTime, f64)>,
    trades: Vec<Trade>,
    /// Time of the fill that opened each currently held position
    entry_times: std::collections::HashMap<String, NaiveDateTime>,
}

/// A closed trade (entry fill to exit fill)
#[derive(Debug, Clone)]
pub struct Trade {
    pub symbol: String,
    pub side: OrderSide,
    pub entry_price: f64,
    pub exit_price: f64,
    pub quantity: f64,
    pub pnl: f64,
    pub entry_time: NaiveDateTime,
    pub exit_time: NaiveDateTime,
}

impl<S: Strategy> Backtester<S> {
    pub fn new(strategy: S, initial_cash: f64, slippage: f64, commission: f64) -> Self {
        Backtester {
            strategy,
            matching_engine: MatchingEngine::new(slippage, commission),
            account: Account::new(initial_cash),
            pending_orders: Vec::new(),
            equity_curve: Vec::new(),
            trades: Vec::new(),
            entry_times: std::collections::HashMap::new(),
        }
    }

    pub fn run(&mut self, data: &[Bar]) -> BacktestResult {
        for bar in data {
            // 1. Match pending orders: orders created on the previous bar fill against this one
            self.process_pending_orders(bar);

            // 2. Mark the account to market (this single-symbol engine prices everything as "default")
            let mut prices = std::collections::HashMap::new();
            prices.insert("default".to_string(), bar.close);
            self.account.update_value(&prices);

            // 3. Record the equity curve
            self.equity_curve.push((bar.datetime, self.account.total_value));

            // 4. Ask the strategy for a signal
            let signal = self.strategy.on_bar(bar, &self.account);

            // 5. Convert the signal into a pending order
            match signal {
                Signal::Buy { symbol, price, quantity } => {
                    // Rough affordability check at the signal price; slippage and
                    // commission are only charged at fill time
                    let cost = price * quantity;
                    if self.account.cash >= cost {
                        self.account.order_count += 1;
                        self.pending_orders.push(Order {
                            id: self.account.order_count,
                            symbol,
                            side: OrderSide::Buy,
                            price,
                            quantity,
                            timestamp: bar.datetime,
                            status: OrderStatus::Pending,
                        });
                    }
                }
                Signal::Sell { symbol, price, quantity } => {
                    self.account.order_count += 1;
                    self.pending_orders.push(Order {
                        id: self.account.order_count,
                        symbol,
                        side: OrderSide::Sell,
                        price,
                        quantity,
                        timestamp: bar.datetime,
                        status: OrderStatus::Pending,
                    });
                }
                Signal::Hold => {}
            }
        }
        self.calculate_result()
    }

    fn process_pending_orders(&mut self, bar: &Bar) {
        for order in &mut self.pending_orders {
            if order.status != OrderStatus::Pending {
                continue;
            }
            if let Some(commission) = self.matching_engine.try_fill(order, bar) {
                match order.side {
                    OrderSide::Buy => {
                        self.account.cash -= order.price * order.quantity + commission;
                        let pos = self.account.positions
                            .entry(order.symbol.clone())
                            .or_insert(Position {
                                symbol: order.symbol.clone(),
                                quantity: 0.0,
                                avg_cost: 0.0,
                                unrealized_pnl: 0.0,
                            });
                        // New average cost = quantity-weighted mean of old cost and fill price
                        let total_cost = pos.avg_cost * pos.quantity + order.price * order.quantity;
                        pos.quantity += order.quantity;
                        pos.avg_cost = total_cost / pos.quantity;
                        // Remember when the position was opened (first fill only)
                        self.entry_times.entry(order.symbol.clone()).or_insert(bar.datetime);
                    }
                    OrderSide::Sell => {
                        self.account.cash += order.price * order.quantity - commission;
                        if let Some(pos) = self.account.positions.get_mut(&order.symbol) {
                            // Gross PnL against the average cost (commissions excluded)
                            let pnl = (order.price - pos.avg_cost) * order.quantity;
                            pos.quantity -= order.quantity;
                            // Entry time comes from the opening fill, not from the sell order
                            let entry_time = self.entry_times
                                .get(&order.symbol)
                                .copied()
                                .unwrap_or(order.timestamp);
                            self.trades.push(Trade {
                                symbol: order.symbol.clone(),
                                side: OrderSide::Sell,
                                entry_price: pos.avg_cost,
                                exit_price: order.price,
                                quantity: order.quantity,
                                pnl,
                                entry_time,
                                exit_time: bar.datetime,
                            });
                            if pos.quantity <= 0.0 {
                                self.account.positions.remove(&order.symbol);
                                self.entry_times.remove(&order.symbol);
                            }
                        }
                    }
                }
            }
        }
        // Drop filled orders; unfilled ones stay pending for the next bar
        self.pending_orders.retain(|o| o.status == OrderStatus::Pending);
    }
    fn calculate_result(&self) -> BacktestResult {
        let total_trades = self.trades.len();
        let winning_trades = self.trades.iter().filter(|t| t.pnl > 0.0).count();
        let total_pnl: f64 = self.trades.iter().map(|t| t.pnl).sum();

        // Maximum drawdown: largest decline from a running peak, as a fraction of that peak
        let mut peak = 0.0_f64;
        let mut max_drawdown = 0.0_f64;
        for (_, value) in &self.equity_curve {
            peak = peak.max(*value);
            if peak > 0.0 {
                max_drawdown = max_drawdown.max((peak - value) / peak);
            }
        }

        // Annualized return, compounded over the calendar time between the first and
        // last bar. Annualizing by bar count with 252 per year is only valid for daily bars.
        let initial = self.equity_curve.first().map(|e| e.1).unwrap_or(0.0);
        let final_val = self.equity_curve.last().map(|e| e.1).unwrap_or(0.0);
        let years = match (self.equity_curve.first(), self.equity_curve.last()) {
            (Some(first), Some(last)) => {
                (last.0 - first.0).num_seconds() as f64 / (365.25 * 86_400.0)
            }
            _ => 0.0,
        };
        let annual_return = if years > 0.0 && initial > 0.0 {
            ((final_val / initial).powf(1.0 / years) - 1.0) * 100.0
        } else {
            0.0
        };

        BacktestResult {
            total_trades,
            winning_trades,
            win_rate: if total_trades > 0 { winning_trades as f64 / total_trades as f64 * 100.0 } else { 0.0 },
            total_pnl,
            max_drawdown: max_drawdown * 100.0,
            annual_return,
            final_equity: final_val,
        }
    }
}

#[derive(Debug)]
pub struct BacktestResult {
    pub total_trades: usize,
    pub winning_trades: usize,
    pub win_rate: f64,      // percent
    pub total_pnl: f64,
    pub max_drawdown: f64,  // percent
    pub annual_return: f64, // percent
    pub final_equity: f64,
}

impl std::fmt::Display for BacktestResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "=== Backtest Result ===")?;
        writeln!(f, "Total Trades: {}", self.total_trades)?;
        writeln!(f, "Win Rate: {:.1}%", self.win_rate)?;
        writeln!(f, "Total PnL: {:.2}", self.total_pnl)?;
        writeln!(f, "Max Drawdown: {:.2}%", self.max_drawdown)?;
        writeln!(f, "Annual Return: {:.2}%", self.annual_return)?;
        writeln!(f, "Final Equity: {:.2}", self.final_equity)
    }
}
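The two risk and return statistics in `calculate_result` can be checked on a toy equity series. This sketch pulls them out as free functions with hypothetical names (`max_drawdown`, `annualized_return`). The annualized return here is a compound annual growth rate computed from the elapsed time in years.

```rust
/// Maximum drawdown of an equity series, as a fraction of the running peak.
fn max_drawdown(equity: &[f64]) -> f64 {
    let mut peak = f64::MIN;
    let mut max_dd = 0.0_f64;
    for &v in equity {
        peak = peak.max(v);
        if peak > 0.0 {
            max_dd = max_dd.max((peak - v) / peak);
        }
    }
    max_dd
}

/// Compound annual growth rate from start/end equity and elapsed years.
fn annualized_return(initial: f64, final_value: f64, years: f64) -> f64 {
    if years > 0.0 && initial > 0.0 {
        (final_value / initial).powf(1.0 / years) - 1.0
    } else {
        0.0
    }
}

fn main() {
    let equity = [100.0, 120.0, 90.0, 130.0, 117.0];
    // Worst decline is 120 -> 90, i.e. 25% of the 120 peak (130 -> 117 is only 10%).
    println!("max drawdown = {:.1}%", max_drawdown(&equity) * 100.0); // max drawdown = 25.0%
    // Doubling over two years compounds to sqrt(2) - 1, about 41.4% per year.
    println!("annualized = {:.1}%", annualized_return(100.0, 200.0, 2.0) * 100.0); // annualized = 41.4%
}
```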
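5. An Example Strategy
The example is a dual moving-average crossover strategy: buy when the short-term moving average crosses above the long-term one, and sell when it crosses below. The implementation compares the two averages on every bar. It buys when flat and the short MA is above the long MA, and it sells the whole position when the short MA falls below the long MA. As a result, it can also buy on the first bar where both averages exist, if the short MA already starts out above the long MA: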
pub struct DualMACross {
    short_period: usize,
    long_period: usize,
    prices: Vec<f64>, // full close history, one entry per bar
}

impl DualMACross {
    pub fn new(short_period: usize, long_period: usize) -> Self {
        DualMACross {
            short_period,
            long_period,
            prices: Vec::new(),
        }
    }

    /// Simple moving average of the last `period` closes; `None` until enough data
    fn sma(&self, period: usize) -> Option<f64> {
        if self.prices.len() < period {
            return None;
        }
        let sum: f64 = self.prices[self.prices.len() - period..].iter().sum();
        Some(sum / period as f64)
    }
}

impl Strategy for DualMACross {
    fn on_bar(&mut self, bar: &Bar, account: &Account) -> Signal {
        self.prices.push(bar.close);
        // Hold until both averages have enough history
        let Some(short_ma) = self.sma(self.short_period) else { return Signal::Hold };
        let Some(long_ma) = self.sma(self.long_period) else { return Signal::Hold };
        let has_position = account.positions.contains_key("default");
        if short_ma > long_ma && !has_position {
            // Invest 95% of cash, keeping a reserve for slippage, commission
            // and price moves before the next-bar fill
            let quantity = (account.cash * 0.95 / bar.close).floor();
            if quantity > 0.0 {
                return Signal::Buy {
                    symbol: "default".to_string(),
                    price: bar.close,
                    quantity,
                };
            }
        } else if short_ma < long_ma && has_position {
            // Close the whole position
            if let Some(pos) = account.positions.get("default") {
                return Signal::Sell {
                    symbol: "default".to_string(),
                    price: bar.close,
                    quantity: pos.quantity,
                };
            }
        }
        Signal::Hold
    }

    fn name(&self) -> &str { "DualMA Cross" }
}
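`sma` re-sums the last `period` closes on every bar, which costs O(period) per call, and `prices` keeps the entire history. For long parameter sweeps, a common alternative is a fixed-size window with a running sum, which updates in O(1) per bar. `RollingSma` below is a hypothetical helper sketched under that approach, not part of the engine.

```rust
use std::collections::VecDeque;

/// Simple moving average maintained in O(1) per update with a
/// fixed-size window and a running sum.
struct RollingSma {
    period: usize,
    window: VecDeque<f64>,
    sum: f64,
}

impl RollingSma {
    fn new(period: usize) -> Self {
        assert!(period > 0, "period must be positive");
        RollingSma { period, window: VecDeque::with_capacity(period + 1), sum: 0.0 }
    }

    /// Push a new price; returns the SMA once `period` prices have been seen.
    fn update(&mut self, price: f64) -> Option<f64> {
        self.window.push_back(price);
        self.sum += price;
        if self.window.len() > self.period {
            // Drop the oldest price so the window stays `period` long.
            self.sum -= self.window.pop_front().unwrap();
        }
        if self.window.len() == self.period {
            Some(self.sum / self.period as f64)
        } else {
            None
        }
    }
}

fn main() {
    let mut sma = RollingSma::new(3);
    for p in [1.0, 2.0, 3.0, 4.0, 5.0] {
        println!("{:?}", sma.update(p)); // None, None, Some(2.0), Some(3.0), Some(4.0)
    }
}
```

Over millions of updates, the running sum accumulates floating-point rounding error. Recomputing `sum` from `window` every so often keeps that drift in check.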
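6. Parallel Parameter Optimization
Each backtest run owns its own strategy, account and order state, and only reads the shared bar data. Parameter combinations can therefore be evaluated in parallel with Rayon: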
use rayon::prelude::*;

fn optimize_parameters(data: &[Bar]) {
    // Grid: short period 5..=30, long period at least 10 bars longer, up to 120
    let params: Vec<(usize, usize)> = (5..=30)
        .flat_map(|short| ((short + 10)..=120).map(move |long| (short, long)))
        .collect();

    // Each combination gets its own Backtester; `data` is shared read-only across threads
    let mut results: Vec<_> = params
        .par_iter()
        .map(|&(short, long)| {
            let strategy = DualMACross::new(short, long);
            let mut bt = Backtester::new(strategy, 1_000_000.0, 0.001, 0.0003);
            let result = bt.run(data);
            (short, long, result.annual_return, result.max_drawdown)
        })
        .collect();

    // Sort by annualized return, highest first (total_cmp never panics, unlike
    // partial_cmp().unwrap() when a result is NaN)
    results.sort_by(|a, b| b.2.total_cmp(&a.2));

    println!("Top 10 parameter combinations:");
    for (short, long, ret, dd) in results.iter().take(10) {
        println!("  MA({},{}) -> Annual: {:.2}%, MaxDD: {:.2}%", short, long, ret, dd);
    }
}
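This grid contains 2,431 combinations, and Rayon spreads the work across all available CPU cores. Because the runs are independent, the speedup over a single thread on, say, a 16-core machine can approach the core count. In practice, memory bandwidth and uneven run times usually keep it somewhat below linear.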
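Rayon is an external crate. As a dependency-free illustration, the sketch below swaps in `std::thread::scope` to evaluate the same 5..=30 by (short + 10)..=120 grid in parallel, and it also confirms that the grid has 2,431 combinations. `score` is a hypothetical stand-in for a full backtest. The sketch relies on the same property as the Rayon version: every worker only borrows the price data, and the compiler accepts that borrow because the scope guarantees all threads finish before the function returns.

```rust
use std::thread;

/// The parameter grid from `optimize_parameters`.
fn param_grid() -> Vec<(usize, usize)> {
    (5..=30)
        .flat_map(|short| ((short + 10)..=120).map(move |long| (short, long)))
        .collect()
}

/// Hypothetical stand-in for a backtest: a pure function of the parameters
/// and the shared, read-only prices (here, short SMA minus long SMA at the end).
fn score(prices: &[f64], short: usize, long: usize) -> f64 {
    let n = prices.len();
    if n < long {
        return 0.0;
    }
    let sma = |p: usize| prices[n - p..].iter().sum::<f64>() / p as f64;
    sma(short) - sma(long)
}

/// Split the grid into one chunk per worker and score each chunk on a scoped
/// thread. Results come back in grid order.
fn parallel_scores(prices: &[f64], grid: &[(usize, usize)], workers: usize) -> Vec<f64> {
    let w = workers.max(1);
    let chunk = ((grid.len() + w - 1) / w).max(1);
    thread::scope(|s| {
        let handles: Vec<_> = grid
            .chunks(chunk)
            .map(|part| s.spawn(move || part.iter().map(|&(a, b)| score(prices, a, b)).collect::<Vec<f64>>()))
            .collect();
        handles.into_iter().flat_map(|h| h.join().unwrap()).collect()
    })
}

fn main() {
    let grid = param_grid();
    println!("{} combinations", grid.len()); // 2431 combinations
    let prices: Vec<f64> = (0..500).map(|i| 100.0 + (i as f64 * 0.1).sin()).collect();
    let scores = parallel_scores(&prices, &grid, 4);
    println!("{} scores", scores.len()); // 2431 scores
}
```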
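7. Memory-Mapped Data Files
For large historical datasets, memmap2 can map the data file into memory, so the OS pages bytes in on demand instead of the program first reading the whole file into a buffer. Note, however, that the function below still copies every record into a `Vec<Bar>`, so all bars end up in memory anyway. To keep the footprint small, parse records lazily, straight from the mapping.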
use memmap2::Mmap;
use std::fs::File;
use std::io;

/// Load bars from a fixed-layout binary file.
/// Assumed record layout (48 bytes, little-endian):
/// timestamp as i64 Unix seconds (8 B) + open, high, low, close, volume as f64 (5 * 8 B)
fn load_bars_mmap(path: &str) -> io::Result<Vec<Bar>> {
    let file = File::open(path)?;
    // Safety: the file must not be modified or truncated while it is mapped
    let mmap = unsafe { Mmap::map(&file)? };

    const RECORD_SIZE: usize = 48;
    let count = mmap.len() / RECORD_SIZE; // a trailing partial record is ignored
    let mut bars = Vec::with_capacity(count);

    for record in mmap.chunks_exact(RECORD_SIZE) {
        // Read the 8-byte field at `offset` within this record
        let field = |offset: usize| -> [u8; 8] { record[offset..offset + 8].try_into().unwrap() };
        let ts = i64::from_le_bytes(field(0));
        let datetime = chrono::DateTime::from_timestamp(ts, 0)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "timestamp out of range"))?
            .naive_utc();
        bars.push(Bar {
            datetime,
            open: f64::from_le_bytes(field(8)),
            high: f64::from_le_bytes(field(16)),
            low: f64::from_le_bytes(field(24)),
            close: f64::from_le_bytes(field(32)),
            volume: f64::from_le_bytes(field(40)),
        });
    }
    Ok(bars)
}
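To avoid materializing every bar, records can be decoded on demand from the mapped bytes. The sketch below works on any `&[u8]` (a `memmap2::Mmap` derefs to one) and assumes the same 48-byte layout as `load_bars_mmap`. It keeps the timestamp as raw Unix seconds to avoid a chrono dependency. `RawBar`, `records` and `encode` are hypothetical names.

```rust
/// One decoded record; the timestamp stays as raw Unix seconds.
#[derive(Debug, Clone, Copy, PartialEq)]
struct RawBar {
    ts: i64,
    open: f64,
    high: f64,
    low: f64,
    close: f64,
    volume: f64,
}

const RECORD_SIZE: usize = 48;

/// Lazily decode 48-byte little-endian records. Each record is decoded only
/// when the iterator reaches it; a trailing partial record is skipped.
fn records(bytes: &[u8]) -> impl Iterator<Item = RawBar> + '_ {
    bytes.chunks_exact(RECORD_SIZE).map(|r| {
        // The i-th 8-byte field of this record.
        let word = |i: usize| -> [u8; 8] { r[i * 8..i * 8 + 8].try_into().unwrap() };
        RawBar {
            ts: i64::from_le_bytes(word(0)),
            open: f64::from_le_bytes(word(1)),
            high: f64::from_le_bytes(word(2)),
            low: f64::from_le_bytes(word(3)),
            close: f64::from_le_bytes(word(4)),
            volume: f64::from_le_bytes(word(5)),
        }
    })
}

/// Encode a record in the same layout (handy for writing test files).
fn encode(bar: &RawBar, out: &mut Vec<u8>) {
    out.extend_from_slice(&bar.ts.to_le_bytes());
    for v in [bar.open, bar.high, bar.low, bar.close, bar.volume] {
        out.extend_from_slice(&v.to_le_bytes());
    }
}

fn main() {
    let bars = [
        RawBar { ts: 1_700_000_000, open: 10.0, high: 11.0, low: 9.5, close: 10.5, volume: 1_000.0 },
        RawBar { ts: 1_700_000_060, open: 10.5, high: 10.8, low: 10.1, close: 10.2, volume: 800.0 },
    ];
    let mut buf = Vec::new();
    for b in &bars {
        encode(b, &mut buf);
    }
    // Stream over the records without building a Vec: e.g. find the highest high.
    let max_high = records(&buf).map(|b| b.high).fold(f64::MIN, f64::max);
    println!("{} records, max high {max_high}", buf.len() / RECORD_SIZE); // 2 records, max high 11
}
```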
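Summary
This engine implements the core event-driven backtesting loop. A production system would still need multi-symbol support, limit and stop orders, a money-management module, more realistic slippage models, and an interface to live trading. As a skeleton, though, it shows why Rust suits quantitative work. The type system rules out many subtle bugs at compile time; for example, a `match` that forgets a `Signal` variant does not compile. Rust's performance also makes large parameter searches practical.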