KMP算法与Z函数:字符串匹配的两大利器

字符串匹配是算法中最经典的问题之一。KMP算法和Z函数是两种高效的字符串匹配方法,它们的核心思想都是利用已匹配的信息避免重复比较,把暴力匹配的O(n*m)优化到O(n+m)。

问题定义

给定一个文本串text和一个模式串pattern,找出pattern在text中出现的所有位置。

例如:text = "ABABCABABC", pattern = "ABABC",答案是位置0和5。

KMP算法

暴力匹配的问题

暴力匹配每次失配后,模式串向右移动一位,从头开始比较:

text:    A B A B C A B A B C
pattern: A B A B D    ← 在D处失配
                ↓ 回退到开头,右移一位
pattern:   A B A B D  ← 重新开始比较

问题是:我们已经知道前面匹配过的字符是什么,为什么要重新比较?KMP的核心思想就是利用这些信息,让模式串尽可能少回退。

核心概念:前缀函数(next数组)

KMP的关键是预处理模式串,计算一个前缀函数π(也叫next数组)。

π[i]的定义:模式串pattern[0..i]最长的相等的真前缀与真后缀的长度的长度。

"真前缀"指不等于整个串本身的前缀。

pattern = "ABABC"为例:

i 子串 最长相等前后缀 π[i]
0 A 0
1 AB 0
2 ABA A 1
3 ABAB AB 2
4 ABABC 0

图解π[3] = 2的含义:

ABAB
├┤    前缀 "AB"
  ├┤  后缀 "AB"

前缀"AB"和后缀"AB"相同,长度为2。

前缀函数的计算

计算过程可以用一个自动机的思路来理解:

graph LR
    A["i=0, π=0"] --> B{"i>0 且<br/>pattern[i]≠pattern[j]?"}
    B -->|是| C["j = π[j-1]<br/>回退"]
    C --> B
    B -->|否| D{"pattern[i]==pattern[j]?"}
    D -->|是| E["j++"]
    D -->|否| F["j不变"]
    E --> G["π[i] = j"]
    F --> G
    G --> H["i++ → 处理下一个"]
    H --> B

Python实现:

def compute_prefix(pattern: str) -> list[int]:
    """计算模式串的前缀函数 π"""
    n = len(pattern)
    pi = [0] * n
    j = 0  # 当前已匹配的前缀长度
    
    for i in range(1, n):
        # 如果当前字符不匹配,回退 j
        while j > 0 and pattern[i] != pattern[j]:
            j = pi[j - 1]
        # 如果匹配,前缀长度+1
        if pattern[i] == pattern[j]:
            j += 1
        pi[i] = j
    
    return pi

pattern = "ABABC"为例,手动走一遍:

i=1: pattern[1]='B' ≠ pattern[0]='A', j=0 → π[1]=0
i=2: pattern[2]='A' == pattern[0]='A', j=1 → π[2]=1
i=3: pattern[3]='B' == pattern[1]='B', j=2 → π[3]=2
i=4: pattern[4]='C' ≠ pattern[2]='A', j=π[1]=0
     pattern[4]='C' ≠ pattern[0]='A', j=0 → π[4]=0

π = [0, 0, 1, 2, 0]

KMP匹配过程

有了前缀函数,匹配时遇到失配就不再回退到开头,而是跳到π[j-1]的位置:

graph TD
    A["开始匹配"] --> B{"text[i] == pattern[j]?"}
    B -->|是| C["i++, j++"]
    C --> D{"j == len(pattern)?"}
    D -->|是| E["找到一个匹配!<br/>记录位置 i-j"]
    E --> F["j = π[j-1]"]
    F --> B
    D -->|否| B
    B -->|否| G{"j > 0?"}
    G -->|是| H["j = π[j-1]<br/>利用前缀函数回退"]
    H --> B
    G -->|否| I["i++<br/>模式串未匹配任何字符"]
    I --> B

Python实现:

def kmp_search(text: str, pattern: str) -> list[int]:
    """KMP字符串匹配,返回所有匹配起始位置"""
    if not pattern:
        return []
    
    pi = compute_prefix(pattern)
    results = []
    j = 0  # pattern中已匹配的字符数
    
    for i in range(len(text)):
        # 失配时利用前缀函数回退
        while j > 0 and text[i] != pattern[j]:
            j = pi[j - 1]
        # 匹配则前进
        if text[i] == pattern[j]:
            j += 1
        # 完全匹配
        if j == len(pattern):
            results.append(i - j + 1)
            j = pi[j - 1]
    
    return results

匹配过程图解

text = "ABABCABABC", pattern = "ABABC"为例:

第一轮匹配 (i=0~4):
text:    [A B A B C] A B A B C
pattern: [A B A B C]
         j=5 → 记录匹配位置 0
         后续 j = π[4] = 0

第二轮匹配 (i=5~9):
text:    A B A B C [A B A B C]
pattern:           [A B A B C]
         j=5 → 记录匹配位置 5

最终匹配位置:[0, 5]

时间复杂度:O(n + m),其中n是文本长度,m是模式长度。每个字符最多被比较常数次。

Z函数(Z-Algorithm)

Z函数的定义

对于字符串sZ[i]表示从位置i开始的子串与整个字符串s最长公共前缀的长度。规定Z[0] = 0

s = "ABACABA"为例:

i 后缀 与s的最长公共前缀 Z[i]
0 ABACABA (整个串,定义Z[0]=0) 0
1 BACABA 0
2 ACABA A 1
3 CABA 0
4 ABA ABA 3
5 BA 0
6 A A 1

图解Z[4] = 3

s:      A B A C A B A
        │ │ │          ← 前缀 "ABA"
                │ │ │  ← s[4..6] = "ABA"

从位置4开始的"ABA"与开头的"ABA"完全匹配,长度为3。

Z函数的计算

Z函数的计算维护一个Z-Box[L, R],表示目前已知的最右边的匹配区间:

graph TD
    A["初始化 l=0, r=0, Z[0]=0"] --> B["遍历 i = 1 to n-1"]
    B --> C{"i ≤ r ?"}
    C -->|是| D["z[i] = z[i-l]"]
    C -->|否| E["z[i] = 0"]
    D --> F{"z[i] < r-i+1 ?"}
    F -->|是| G["z[i]已确定<br/>无需扩展"]
    F -->|否| H["z[i] = max(0, r-i+1)"]
    E --> I["从0开始扩展"]
    H --> J["继续暴力扩展"]
    I --> J
    J --> K["while: s[z[i]] == s[i+z[i]]"]
    K --> L["i+z[i]-1 > r ?<br/>更新 l=i, r=i+z[i]-1"]
    G --> L
    L --> M["i++ → 下一个"]
    M --> C

核心思想:如果位置i在Z-Box[l, r]内,那么s[i..r]s[i-l..r-l]相同。此时借用z[i-l]:若z[i-l] < r-i+1,说明匹配不会超出Z-Box,z[i]直接确定;否则从位置r之后继续暴力扩展。

Python实现:

def compute_z(s: str) -> list[int]:
    """计算字符串的Z函数"""
    n = len(s)
    z = [0] * n
    l, r = 0, 0  # Z-Box的闭区间 [l, r]
    
    for i in range(1, n):
        if i <= r and z[i - l] < r - i + 1:
            # 借用值严格小于Z-Box剩余长度,z[i]可直接确定
            z[i] = z[i - l]
        else:
            # 超出Z-Box或i不在Box内,从max(0, r-i+1)开始暴力扩展
            z[i] = max(0, r - i + 1)
            while i + z[i] < n and s[z[i]] == s[i + z[i]]:
                z[i] += 1
        # 更新Z-Box
        if i + z[i] - 1 > r:
            l, r = i, i + z[i] - 1
    
    return z

s = "ABACABA"为例,手动走一遍:

i=1: i>r(0), 从0开始扩展: s[0]='A'≠s[1]='B' → z[1]=0
i=2: i>r(0), 从0开始扩展: s[0]='A'==s[2]='A', s[1]='B'≠s[3]='C' → z[2]=1
     更新 l=2, r=2
i=3: i>r(2), 从0开始扩展: s[0]='A'≠s[3]='C' → z[3]=0
i=4: i>r(2), 从0开始扩展: s[0..2]="ABA"==s[4..6]="ABA" → z[4]=3
     更新 l=4, r=6
i=5: i≤r(6), z[5-l]=z[1]=0 < r-i+1=2, z[5]=z[1]=0(直接确定,不进入循环)
i=6: i≤r(6), z[6-l]=z[2]=1, r-i+1=1, z[6-l]=r-i+1,不满足<条件
     进入else: z[6]=max(0,1)=1, 扩展: s[1]='B'≠s[7]越界 → z[6]=1

z = [0, 0, 1, 0, 3, 0, 1]

用Z函数做字符串匹配

将模式串和文本串用特殊字符#连接:s = pattern + "#" + text,然后计算Z函数。

如果某个位置的Z值等于pattern的长度,说明找到了一个匹配。

def z_search(text: str, pattern: str) -> list[int]:
    """用Z函数做字符串匹配"""
    s = pattern + "#" + text
    z = compute_z(s)
    m = len(pattern)
    results = []
    
    for i in range(len(s)):
        if z[i] == m:
            results.append(i - m - 1)  # 减去pattern和#的长度
    
    return results

图解匹配过程,text = "ABABCABABC", pattern = "ABABC"

s = "ABABC#ABABCABABC"
     ↑pattern
           ↑text

计算出的Z数组:

i s[i] Z[i] 说明
0 A 0
1 B 0
2 A 2
3 B 0
4 C 0
5 # 0
6 A 5 → 匹配位置 = 6-5-1 = 0
7 B 0
8 A 2
9 B 0
10 C 0
11 A 5 → 匹配位置 = 11-5-1 = 5
12 B 0
13 A 2
14 B 0
15 C 0

Z值等于len(pattern)=5的位置(i=6和i=11)就是两个匹配起点。

时间复杂度同样是O(n + m)。

KMP vs Z函数:对比

特性 KMP Z函数
预处理 模式串的next数组 拼接串的Z数组
空间复杂度 O(m) O(n+m)
匹配方式 直接在text上滑动 构造拼接串一次计算
代码简洁度 需要理解回退逻辑 思路更直观
额外能力 专注字符串匹配 还能解决周期串、最长回文前缀等
graph LR
    subgraph "KMP"
        A1["预处理pattern"] --> A2["计算next数组"]
        A2 --> A3["在text上滑动匹配"]
    end
    subgraph "Z函数"
        B1["构造 pattern#text"] --> B2["计算Z数组"]
        B2 --> B3["找Z值等于m的位置"]
    end

实战:经典例题

例题1:字符串周期

问题:判断字符串是否有周期。如"abcabcabc"的周期为3。

解法:计算Z函数,如果i + Z[i] == nn % i == 0,则i是一个周期。

def find_period(s: str) -> int:
    """找字符串的最小周期"""
    z = compute_z(s)
    n = len(s)
    for i in range(1, n):
        if n % i == 0 and z[i] + i == n:
            return i
    return n  # 无周期,周期为自身长度

"abcabcabc":Z = [0, 0, 0, 6, 0, 0, 3, 0, 0],i=3时Z[3]+3=9=n且9%3=0,周期为3。

例题2:统计子串出现次数

问题:统计pattern在text中出现的次数(即KMP/Z函数的基本应用)。

上面kmp_searchz_search返回的列表长度就是答案。

例题3:最长回文前缀

问题:找字符串的最长回文前缀。

解法:构造s + "#" + reverse(s),计算Z函数。遍历rev部分中每个起始位置i,若Z[i] == len(combined)-i(即rev的后缀完全匹配s的前缀),则对应的s前缀是回文。

def longest_palindrome_prefix(s: str) -> str:
    """找最长回文前缀"""
    rev = s[::-1]
    combined = s + "#" + rev
    z = compute_z(combined)
    n = len(s)
    
    best = 1  # 单个字符一定是回文
    # 遍历 combined 中 rev 的部分(索引 n+1 到 2n)
    for i in range(n + 1, len(combined)):
        # Z[i] == len(combined) - i 表示从 i 到 combined 末尾的整个后缀都能匹配 s 的前缀
        # 设匹配长度 k = Z[i],则 rev 的后 k 个字符 == s 的前 k 个字符
        # rev 的后 k 个字符 = s 的前 k 个字符的反转
        # 所以 s[0..k-1] == reverse(s[0..k-1]),即 s 的前 k 个字符是回文
        if z[i] == len(combined) - i:
            best = max(best, z[i])
    
    return s[:best]

s = "aacecaaa"为例:

  • rev = "aaacecaa"combined = "aacecaaa#aaacecaa"(长度17)
  • 计算Z函数后,Z[10] = 7len(combined) - 10 = 7,满足条件
  • 返回s[:7] = "aacecaa",正是最长回文前缀(aacecaa是回文,但完整的aacecaaa不是)

总结

KMP和Z函数都是利用已匹配的信息来跳过冗余比较,将字符串匹配从O(n*m)优化到O(n+m)。

  • KMP:直接、空间效率好(O(m)),适合纯匹配场景
  • Z函数:思路更直观,功能更广(周期检测、回文前缀等),但需要O(n+m)空间

建议两个都掌握。KMP的next数组思想在很多字符串问题中都有变体应用,Z函数则在竞赛中更常用,因为它的定义更自然,一道题的思路往往更容易想到Z函数。

# 完整代码汇总,方便复制使用
def compute_prefix(pattern: str) -> list[int]:
    n = len(pattern)
    pi = [0] * n
    j = 0
    for i in range(1, n):
        while j > 0 and pattern[i] != pattern[j]:
            j = pi[j - 1]
        if pattern[i] == pattern[j]:
            j += 1
        pi[i] = j
    return pi

def kmp_search(text: str, pattern: str) -> list[int]:
    if not pattern:
        return []
    pi = compute_prefix(pattern)
    results = []
    j = 0
    for i in range(len(text)):
        while j > 0 and text[i] != pattern[j]:
            j = pi[j - 1]
        if text[i] == pattern[j]:
            j += 1
        if j == len(pattern):
            results.append(i - j + 1)
            j = pi[j - 1]
    return results

def compute_z(s: str) -> list[int]:
    n = len(s)
    z = [0] * n
    l, r = 0, 0
    for i in range(1, n):
        if i <= r and z[i - l] < r - i + 1:
            z[i] = z[i - l]
        else:
            z[i] = max(0, r - i + 1)
            while i + z[i] < n and s[z[i]] == s[i + z[i]]:
                z[i] += 1
        if i + z[i] - 1 > r:
            l, r = i, i + z[i] - 1
    return z

def z_search(text: str, pattern: str) -> list[int]:
    s = pattern + "#" + text
    z = compute_z(s)
    m = len(pattern)
    return [i - m - 1 for i in range(len(s)) if z[i] == m]

def longest_palindrome_prefix(s: str) -> str:
    rev = s[::-1]
    combined = s + "#" + rev
    z = compute_z(combined)
    best = 1
    for i in range(len(s) + 1, len(combined)):
        if z[i] == len(combined) - i:
            best = max(best, z[i])
    return s[:best]