OpenAI 的 Whisper 在通用场景下语音识别效果不错,但在中文特定领域(比如医疗、法律术语)识别率可能明显下降。Whisper 是支持微调的,通过微调可以让模型更适应特定场景。
为什么要微调
Whisper 的预训练数据覆盖了多语言多场景,是一个很好的通用模型。但具体到某个垂直领域时,它会遇到几个问题:
- 领域术语识别差:医疗场景下,"胸腔镜"可能被识别成"胸腔经","心房颤动"变成"心防颤动"
- 中文口音和方言:带口音的普通话识别率明显下降
- 特殊格式:有些场景需要输出特定格式,比如数字需要转成阿拉伯数字而不是汉字
微调的目标就是在保持通用能力的基础上,让模型更适应特定场景。
数据准备
微调需要"音频+文本"对。数据质量比数量重要得多。
数据格式
HuggingFace 的 Whisper 微调流程期望的数据格式:
# 每条数据包含:
{
"audio": {
"path": "audio_001.wav",
"array": [...], # numpy 数组
"sampling_rate": 16000
},
"sentence": "对应的文本标注"
}
音频要求:采样率 16kHz,单声道。如果原始音频不是这个格式,需要先做重采样:
import librosa
audio, sr = librosa.load("input.mp3", sr=16000, mono=True)
数据量建议
根据社区经验:
- 5-10 小时:有一定效果,但可能过拟合
- 20-50 小时:比较理想的范围
- 100+ 小时:效果更好,但标注成本高
标注工作很耗时。可以使用半自动方式:先用未微调的 Whisper 跑一遍得到初始文本,然后人工校对修正。比纯人工标注效率高很多。
数据集构建
from datasets import Dataset, Audio
# 假设已有 audio_files 和 transcriptions 列表
data = {
"audio": audio_files,
"sentence": transcriptions
}
dataset = Dataset.from_dict(data).cast_column("audio", Audio(sampling_rate=16000))
# 拆分训练集和测试集
dataset = dataset.train_test_split(test_size=0.1)
print(f"训练集: {len(dataset['train'])} 条")
print(f"测试集: {len(dataset['test'])} 条")
HuggingFace Transformers 微调流程
加载模型和处理器
from transformers import WhisperProcessor, WhisperForConditionalGeneration
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)
# 设置语言和任务
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
language="zh", task="transcribe"
)
model.config.suppress_tokens = []
推荐从 whisper-small(244M 参数)开始,在单卡 GPU 上微调即可。whisper-large-v2 效果更好但训练资源需求大得多。
数据预处理
def prepare_dataset(batch):
audio = batch["audio"]
# 处理音频特征
batch["input_features"] = processor.feature_extractor(
audio["array"],
sampling_rate=audio["sampling_rate"]
).input_features[0]
# 处理文本标签
batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
return batch
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"])
训练配置
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq
import evaluate
# 数据整理器
data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
# WER 指标
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = wer_metric.compute(predictions=pred_str, references=label_str)
cer = cer_metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer, "cer": cer}
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-small-zh-finetuned",
per_device_train_batch_size=16,
gradient_accumulation_steps=2,
learning_rate=1e-5,
warmup_steps=500,
max_steps=5000,
fp16=True,
evaluation_strategy="steps",
eval_steps=500,
save_steps=500,
logging_steps=100,
predict_with_generate=True,
generation_max_length=225,
report_to="tensorboard",
load_best_model_at_end=True,
metric_for_best_model="cer",
greater_is_better=False,
)
开始训练
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
trainer.train()
训练参数选择
几个关键参数的经验:
- learning_rate:1e-5 是比较安全的起点。太大容易破坏预训练权重,太小收敛太慢。
- batch_size:取决于 GPU 显存。whisper-small 在 A100 40GB 上可以用 batch_size=16。显存不够就用 gradient_accumulation_steps 来补。
- max_steps:根据数据量调整,可以通过观察验证集 loss 来判断是否需要更多步数。
- warmup_steps:前 500 步 warmup 是个不错的选择,让学习率缓慢上升。
- fp16:建议开启,能节省显存且训练速度更快。
评估指标
语音识别常用两个指标:
- WER(Word Error Rate):词错误率,适合英文。计算公式:(替换+删除+插入) / 参考词数
- CER(Character Error Rate):字错误率,适合中文。以字为单位计算错误率
中文场景主要看 CER。
微调效果
在领域术语识别上,微调后的模型通常会有明显改善。例如:
- "胸腔镜手术" → 微调前可能识别为 "胸腔经手术",微调后可正确识别
- "房颤消融术" → 微调前可能识别为 "防颤小容术",微调后可正确识别
- "普萘洛尔" → 微调前可能识别为 "普奈落尔",微调后可正确识别
具体的 CER/WER 下降幅度取决于数据质量和训练参数,一般在领域场景下有显著改善。
注意事项
- 不要过拟合:如果训练集 loss 持续下降但验证集 loss 上升,说明过拟合了。及时停止或减少训练步数。
- 数据多样性:标注数据要覆盖不同的说话人、语速、环境噪音。只用一个人的录音微调会导致模型只对这个人效果好。
- 保持通用能力:微调过度可能导致通用场景识别能力下降。可以在训练数据中混入一部分通用数据。
- 模型选择:显存够就用 whisper-medium 或 large,效果更好。small 是性价比最高的选择。
微调 Whisper 的门槛其实不高,HuggingFace 的工具链非常完善。瓶颈主要在数据标注上。如果场景有足够的标注数据,微调是提升识别效果最直接的方法。