字符串匹配是算法中最经典的问题之一。KMP算法和Z函数是两种高效的字符串匹配方法,它们的核心思想都是利用已匹配的信息避免重复比较,把暴力匹配的O(n*m)优化到O(n+m)。
问题定义
给定一个文本串text和一个模式串pattern,找出pattern在text中出现的所有位置。
例如:text = "ABABCABABC", pattern = "ABABC",答案是位置0和5。
KMP算法
暴力匹配的问题
暴力匹配每次失配后,模式串向右移动一位,从头开始比较:
text: A B A B C A B A B C
pattern: A B A B D ← 在D处失配
↓ 回退到开头,右移一位
pattern: A B A B D ← 重新开始比较
问题是:我们已经知道前面匹配过的字符是什么,为什么要重新比较?KMP的核心思想就是利用这些信息,让模式串尽可能少回退。
核心概念:前缀函数(next数组)
KMP的关键是预处理模式串,计算一个前缀函数π(也叫next数组)。
π[i]的定义:模式串pattern[0..i]的最长的相等的真前缀与真后缀的长度的长度。
"真前缀"指不等于整个串本身的前缀。
以pattern = "ABABC"为例:
| i | 子串 | 最长相等前后缀 | π[i] |
|---|---|---|---|
| 0 | A | 无 | 0 |
| 1 | AB | 无 | 0 |
| 2 | ABA | A | 1 |
| 3 | ABAB | AB | 2 |
| 4 | ABABC | 无 | 0 |
图解π[3] = 2的含义:
ABAB
├┤ 前缀 "AB"
├┤ 后缀 "AB"
前缀"AB"和后缀"AB"相同,长度为2。
前缀函数的计算
计算过程可以用一个自动机的思路来理解:
graph LR
A["i=0, π=0"] --> B{"i>0 且<br/>pattern[i]≠pattern[j]?"}
B -->|是| C["j = π[j-1]<br/>回退"]
C --> B
B -->|否| D{"pattern[i]==pattern[j]?"}
D -->|是| E["j++"]
D -->|否| F["j不变"]
E --> G["π[i] = j"]
F --> G
G --> H["i++ → 处理下一个"]
H --> B
Python实现:
def compute_prefix(pattern: str) -> list[int]:
"""计算模式串的前缀函数 π"""
n = len(pattern)
pi = [0] * n
j = 0 # 当前已匹配的前缀长度
for i in range(1, n):
# 如果当前字符不匹配,回退 j
while j > 0 and pattern[i] != pattern[j]:
j = pi[j - 1]
# 如果匹配,前缀长度+1
if pattern[i] == pattern[j]:
j += 1
pi[i] = j
return pi
以pattern = "ABABC"为例,手动走一遍:
i=1: pattern[1]='B' ≠ pattern[0]='A', j=0 → π[1]=0
i=2: pattern[2]='A' == pattern[0]='A', j=1 → π[2]=1
i=3: pattern[3]='B' == pattern[1]='B', j=2 → π[3]=2
i=4: pattern[4]='C' ≠ pattern[2]='A', j=π[1]=0
pattern[4]='C' ≠ pattern[0]='A', j=0 → π[4]=0
π = [0, 0, 1, 2, 0]
KMP匹配过程
有了前缀函数,匹配时遇到失配就不再回退到开头,而是跳到π[j-1]的位置:
graph TD
A["开始匹配"] --> B{"text[i] == pattern[j]?"}
B -->|是| C["i++, j++"]
C --> D{"j == len(pattern)?"}
D -->|是| E["找到一个匹配!<br/>记录位置 i-j"]
E --> F["j = π[j-1]"]
F --> B
D -->|否| B
B -->|否| G{"j > 0?"}
G -->|是| H["j = π[j-1]<br/>利用前缀函数回退"]
H --> B
G -->|否| I["i++<br/>模式串未匹配任何字符"]
I --> B
Python实现:
def kmp_search(text: str, pattern: str) -> list[int]:
"""KMP字符串匹配,返回所有匹配起始位置"""
if not pattern:
return []
pi = compute_prefix(pattern)
results = []
j = 0 # pattern中已匹配的字符数
for i in range(len(text)):
# 失配时利用前缀函数回退
while j > 0 and text[i] != pattern[j]:
j = pi[j - 1]
# 匹配则前进
if text[i] == pattern[j]:
j += 1
# 完全匹配
if j == len(pattern):
results.append(i - j + 1)
j = pi[j - 1]
return results
匹配过程图解
以text = "ABABCABABC", pattern = "ABABC"为例:
第一轮匹配 (i=0~4):
text: [A B A B C] A B A B C
pattern: [A B A B C]
j=5 → 记录匹配位置 0
后续 j = π[4] = 0
第二轮匹配 (i=5~9):
text: A B A B C [A B A B C]
pattern: [A B A B C]
j=5 → 记录匹配位置 5
最终匹配位置:[0, 5]
时间复杂度:O(n + m),其中n是文本长度,m是模式长度。每个字符最多被比较常数次。
Z函数(Z-Algorithm)
Z函数的定义
对于字符串s,Z[i]表示从位置i开始的子串与整个字符串s的最长公共前缀的长度。规定Z[0] = 0。
以s = "ABACABA"为例:
| i | 后缀 | 与s的最长公共前缀 | Z[i] |
|---|---|---|---|
| 0 | ABACABA | (整个串,定义Z[0]=0) | 0 |
| 1 | BACABA | 无 | 0 |
| 2 | ACABA | A | 1 |
| 3 | CABA | 无 | 0 |
| 4 | ABA | ABA | 3 |
| 5 | BA | 无 | 0 |
| 6 | A | A | 1 |
图解Z[4] = 3:
s: A B A C A B A
│ │ │ ← 前缀 "ABA"
│ │ │ ← s[4..6] = "ABA"
从位置4开始的"ABA"与开头的"ABA"完全匹配,长度为3。
Z函数的计算
Z函数的计算维护一个Z-Box[L, R],表示目前已知的最右边的匹配区间:
graph TD
A["初始化 l=0, r=0, Z[0]=0"] --> B["遍历 i = 1 to n-1"]
B --> C{"i ≤ r ?"}
C -->|是| D["z[i] = z[i-l]"]
C -->|否| E["z[i] = 0"]
D --> F{"z[i] < r-i+1 ?"}
F -->|是| G["z[i]已确定<br/>无需扩展"]
F -->|否| H["z[i] = max(0, r-i+1)"]
E --> I["从0开始扩展"]
H --> J["继续暴力扩展"]
I --> J
J --> K["while: s[z[i]] == s[i+z[i]]"]
K --> L["i+z[i]-1 > r ?<br/>更新 l=i, r=i+z[i]-1"]
G --> L
L --> M["i++ → 下一个"]
M --> C
核心思想:如果位置i在Z-Box[l, r]内,那么s[i..r]与s[i-l..r-l]相同。此时借用z[i-l]:若z[i-l] < r-i+1,说明匹配不会超出Z-Box,z[i]直接确定;否则从位置r之后继续暴力扩展。
Python实现:
def compute_z(s: str) -> list[int]:
"""计算字符串的Z函数"""
n = len(s)
z = [0] * n
l, r = 0, 0 # Z-Box的闭区间 [l, r]
for i in range(1, n):
if i <= r and z[i - l] < r - i + 1:
# 借用值严格小于Z-Box剩余长度,z[i]可直接确定
z[i] = z[i - l]
else:
# 超出Z-Box或i不在Box内,从max(0, r-i+1)开始暴力扩展
z[i] = max(0, r - i + 1)
while i + z[i] < n and s[z[i]] == s[i + z[i]]:
z[i] += 1
# 更新Z-Box
if i + z[i] - 1 > r:
l, r = i, i + z[i] - 1
return z
以s = "ABACABA"为例,手动走一遍:
i=1: i>r(0), 从0开始扩展: s[0]='A'≠s[1]='B' → z[1]=0
i=2: i>r(0), 从0开始扩展: s[0]='A'==s[2]='A', s[1]='B'≠s[3]='C' → z[2]=1
更新 l=2, r=2
i=3: i>r(2), 从0开始扩展: s[0]='A'≠s[3]='C' → z[3]=0
i=4: i>r(2), 从0开始扩展: s[0..2]="ABA"==s[4..6]="ABA" → z[4]=3
更新 l=4, r=6
i=5: i≤r(6), z[5-l]=z[1]=0 < r-i+1=2, z[5]=z[1]=0(直接确定,不进入循环)
i=6: i≤r(6), z[6-l]=z[2]=1, r-i+1=1, z[6-l]=r-i+1,不满足<条件
进入else: z[6]=max(0,1)=1, 扩展: s[1]='B'≠s[7]越界 → z[6]=1
z = [0, 0, 1, 0, 3, 0, 1]
用Z函数做字符串匹配
将模式串和文本串用特殊字符#连接:s = pattern + "#" + text,然后计算Z函数。
如果某个位置的Z值等于pattern的长度,说明找到了一个匹配。
def z_search(text: str, pattern: str) -> list[int]:
"""用Z函数做字符串匹配"""
s = pattern + "#" + text
z = compute_z(s)
m = len(pattern)
results = []
for i in range(len(s)):
if z[i] == m:
results.append(i - m - 1) # 减去pattern和#的长度
return results
图解匹配过程,text = "ABABCABABC", pattern = "ABABC":
s = "ABABC#ABABCABABC"
↑pattern
↑text
计算出的Z数组:
| i | s[i] | Z[i] | 说明 |
|---|---|---|---|
| 0 | A | 0 | |
| 1 | B | 0 | |
| 2 | A | 2 | |
| 3 | B | 0 | |
| 4 | C | 0 | |
| 5 | # | 0 | |
| 6 | A | 5 | → 匹配位置 = 6-5-1 = 0 |
| 7 | B | 0 | |
| 8 | A | 2 | |
| 9 | B | 0 | |
| 10 | C | 0 | |
| 11 | A | 5 | → 匹配位置 = 11-5-1 = 5 |
| 12 | B | 0 | |
| 13 | A | 2 | |
| 14 | B | 0 | |
| 15 | C | 0 |
Z值等于len(pattern)=5的位置(i=6和i=11)就是两个匹配起点。
时间复杂度同样是O(n + m)。
KMP vs Z函数:对比
| 特性 | KMP | Z函数 |
|---|---|---|
| 预处理 | 模式串的next数组 | 拼接串的Z数组 |
| 空间复杂度 | O(m) | O(n+m) |
| 匹配方式 | 直接在text上滑动 | 构造拼接串一次计算 |
| 代码简洁度 | 需要理解回退逻辑 | 思路更直观 |
| 额外能力 | 专注字符串匹配 | 还能解决周期串、最长回文前缀等 |
graph LR
subgraph "KMP"
A1["预处理pattern"] --> A2["计算next数组"]
A2 --> A3["在text上滑动匹配"]
end
subgraph "Z函数"
B1["构造 pattern#text"] --> B2["计算Z数组"]
B2 --> B3["找Z值等于m的位置"]
end
实战:经典例题
例题1:字符串周期
问题:判断字符串是否有周期。如"abcabcabc"的周期为3。
解法:计算Z函数,如果i + Z[i] == n且n % i == 0,则i是一个周期。
def find_period(s: str) -> int:
"""找字符串的最小周期"""
z = compute_z(s)
n = len(s)
for i in range(1, n):
if n % i == 0 and z[i] + i == n:
return i
return n # 无周期,周期为自身长度
"abcabcabc":Z = [0, 0, 0, 6, 0, 0, 3, 0, 0],i=3时Z[3]+3=9=n且9%3=0,周期为3。
例题2:统计子串出现次数
问题:统计pattern在text中出现的次数(即KMP/Z函数的基本应用)。
上面kmp_search和z_search返回的列表长度就是答案。
例题3:最长回文前缀
问题:找字符串的最长回文前缀。
解法:构造s + "#" + reverse(s),计算Z函数。遍历rev部分中每个起始位置i,若Z[i] == len(combined)-i(即rev的后缀完全匹配s的前缀),则对应的s前缀是回文。
def longest_palindrome_prefix(s: str) -> str:
"""找最长回文前缀"""
rev = s[::-1]
combined = s + "#" + rev
z = compute_z(combined)
n = len(s)
best = 1 # 单个字符一定是回文
# 遍历 combined 中 rev 的部分(索引 n+1 到 2n)
for i in range(n + 1, len(combined)):
# Z[i] == len(combined) - i 表示从 i 到 combined 末尾的整个后缀都能匹配 s 的前缀
# 设匹配长度 k = Z[i],则 rev 的后 k 个字符 == s 的前 k 个字符
# rev 的后 k 个字符 = s 的前 k 个字符的反转
# 所以 s[0..k-1] == reverse(s[0..k-1]),即 s 的前 k 个字符是回文
if z[i] == len(combined) - i:
best = max(best, z[i])
return s[:best]
以s = "aacecaaa"为例:
rev = "aaacecaa",combined = "aacecaaa#aaacecaa"(长度17)- 计算Z函数后,
Z[10] = 7且len(combined) - 10 = 7,满足条件 - 返回
s[:7] = "aacecaa",正是最长回文前缀(aacecaa是回文,但完整的aacecaaa不是)
总结
KMP和Z函数都是利用已匹配的信息来跳过冗余比较,将字符串匹配从O(n*m)优化到O(n+m)。
- KMP:直接、空间效率好(O(m)),适合纯匹配场景
- Z函数:思路更直观,功能更广(周期检测、回文前缀等),但需要O(n+m)空间
建议两个都掌握。KMP的next数组思想在很多字符串问题中都有变体应用,Z函数则在竞赛中更常用,因为它的定义更自然,一道题的思路往往更容易想到Z函数。
# 完整代码汇总,方便复制使用
def compute_prefix(pattern: str) -> list[int]:
n = len(pattern)
pi = [0] * n
j = 0
for i in range(1, n):
while j > 0 and pattern[i] != pattern[j]:
j = pi[j - 1]
if pattern[i] == pattern[j]:
j += 1
pi[i] = j
return pi
def kmp_search(text: str, pattern: str) -> list[int]:
if not pattern:
return []
pi = compute_prefix(pattern)
results = []
j = 0
for i in range(len(text)):
while j > 0 and text[i] != pattern[j]:
j = pi[j - 1]
if text[i] == pattern[j]:
j += 1
if j == len(pattern):
results.append(i - j + 1)
j = pi[j - 1]
return results
def compute_z(s: str) -> list[int]:
n = len(s)
z = [0] * n
l, r = 0, 0
for i in range(1, n):
if i <= r and z[i - l] < r - i + 1:
z[i] = z[i - l]
else:
z[i] = max(0, r - i + 1)
while i + z[i] < n and s[z[i]] == s[i + z[i]]:
z[i] += 1
if i + z[i] - 1 > r:
l, r = i, i + z[i] - 1
return z
def z_search(text: str, pattern: str) -> list[int]:
s = pattern + "#" + text
z = compute_z(s)
m = len(pattern)
return [i - m - 1 for i in range(len(s)) if z[i] == m]
def longest_palindrome_prefix(s: str) -> str:
rev = s[::-1]
combined = s + "#" + rev
z = compute_z(combined)
best = 1
for i in range(len(s) + 1, len(combined)):
if z[i] == len(combined) - i:
best = max(best, z[i])
return s[:best]