字符串匹配之 KMP 算法

某天在群里看到 井大侠 提起传说中的“看毛片”算法,在网上搜到一篇好文章(见参考资料 [1]),发现从有限状态自动机的角度来理解可能会更容易一些。

简单的逐个字符比较

最简单的匹配算法是将两个字符串第一个字符对齐,然后开始逐字符比较。当在某个位置发现不匹配时,把模式串向右移动一个字符,再从模式串的第一个字符开始与主串对应位置的字符开始比较。例如要从字符串 T 中查找模式串 P,比较过程如下:

                        1                   2
pos 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2
T   A B C Z A B C D A E Z A B C D A B C D A B D E
P   A B C D A B D
          *

在 pos[3] 发现不匹配,把 P 向右移动一个字符,从 P 的第一个字符开始比较:

                        1                   2
pos 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2
T   A B C Z A B C D A E Z A B C D A B C D A B D E
P     A B C D A B D
      *

在 pos[1] 不匹配,再向右移动……直到 p[0] 移动到 pos[15],发现所有字符都匹配:

                        1                   2
pos 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2
T   A B C Z A B C D A E Z A B C D A B C D A B D E
P                                 A B C D A B D
                                              *

这种情况可以看成是由模式串构成的一个有限状态自动机,主串作为输入流,每次发现不匹配时都回到初始状态:

  +--------------------------+
  |                          |
  +-----------------+        |
  |                 |        |
  +--------+        |        |
  |        |        |        |
  v        |        |        |
+---+ A  +---+ B  +---+ C  +---+ D  +---+ A  +---+ B  +---+ D  +===+
| 0 | -> | 1 | -> | 2 | -> | 3 | -> | 4 | -> | 5 | -> | 6 | -> | 7 |
+---+    +---+    +---+    +---+    +---+    +---+    +---+    +===+
  ^                                   |        |        |
  |                                   |        |        |
  +-----------------------------------+        |        |
  |                                            |        |
  +--------------------------------------------+        |
  |                                                     |
  +-----------------------------------------------------+

除非在前两个字符就发现不匹配,否则输入需要回溯(比较的位置往前移)。

记 T 的长度为 n,P 的长度为 m,0 < m <= n。

最好的情况是 p[0] 与 t[0], t[1], ..., t[n-m-1] 均不相同,这样只需比较 n-m 次;最坏情况是 P[0, 1, ..., m-2] 与 T[i, i+1, ..., i+m-2] 匹配,而 p[m-1] 与 t[i+m-1] 不相同(0 <= i < n-m-1),比较次数为 m(n-m) 次。

KMP 算法

KMP 算法是由 D.E.Knuth,J.H.Morris 和 V.R.Pratt 共同发现的,因此以三人的首字母命名该算法。从状态转移的角度看,KMP 算法通过对模式串的预处理,使得每次在不匹配的时候不必回到初始状态,而是可以退回到初始状态和当前状态之间的某个状态再开始比较,并且输入不需要回溯。

假设当前 p[0] 位于 pos[i],已匹配的字符范围是 pos[i, ..., k],不匹配的位置是 pos[k+1],即 P[0, ..., k-i] 和 T[i, ..., k] 已经匹配,p[k-i+1] != t[k+1]。如果要退回到当前状态前的某个状态 s,即把 P 向右移动 x 个字符,如果移动前的字符串和移动后有交集的话,P[0, ..., j] 和 P[x, ..., x+j] 必须匹配(0 < j < m-x),从 p[j+1] 开始和主串对应位置的字符比较:

      i                   k
pos   11  12  13  14  15  16  17  18  19  20  21  22
    +---+---+---+---+---+---+---+---+---+---+---+---+
T   | A | B | C | D | A | B | C | D | A | B | D | E |
    +---+---+---+---+---+---+---+---+---+---+---+---+
    +---+---+---+---+---+---+---+
P   | A | B | C | D | A | B | D |
    +---+---+---+---+---+---+---+
                    +---+---+---+---+---+---+---+
P'                  | A | B | C | D | A | B | D |
                    +---+---+---+---+---+---+---+
                    | <- -> |

KMP 算法就是要找出 j 的最大值,这样就能尽量减少重复比较的次数。可以看到 j 的值与字符串 T 无关,只与模式串 P 有关。这里用一个一维数组 next 保存转移的状态,对于 j>0,next[j] 表示当前状态为 j,如果输入不匹配应该退回到状态 next[j] 继续比较。例如下面是比较过程中的某个状态:

                        1                   2
pos 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2
T   A B C Z A B C D A E Z A B C D A B C D A B D E
P           A B C D A B D
                      *

在 pos[9] 发现不匹配,根据预处理得到的下一个状态,应该把 p[0] 向右移动到 pos[8],p[0] 和 t[8] 匹配,然后从 pos[9] 开始比较:

                        1                   2
pos 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2
T   A B C Z A B C D A E Z A B C D A B C D A B D E
P                   A B C D A B D
                      *

发现不匹配,由于该位置前已经没有重叠的字串,所以只能向右移动一个字符,再从头比较。经过两次移动后:

                        1                   2
pos 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2
T   A B C Z A B C D A E Z A B C D A B C D A B D E
P                         A B C D A B D
                          *

当在 pos[17] 发现不匹配时,向右移动使 p[0] 和 pos[15] 对齐,从 pos[17] 开始比较:

                        1                   2
pos 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2
T   A B C Z A B C D A E Z A B C D A B C D A B D E
P                                 A B C D A B D
                                      *

到 pos[21] 发现完全匹配,比较过程结束。

根据以上的例子,可以得到下面的比较程序:

char* do_kmp(const char *text, int tlen, const char *pattern,
      int plen, int next[])
{
   int i = 0, j = 0;

   while (i < tlen) {
      while (j != -1 && pattern[j] != text[i])
         j = next[j];

      ++i;
      ++j;
      if (j == plen)
         return (char*)(text + i - j);
   }

   return NULL;
}

上面的程序有个小问题,就是当模式串的末尾已超出主串末尾,而当前的比较位置还没到达主串末尾的时候,比较并不会立刻终止。不过这样的比较次数一般不多,和加上判断条件相比,后者要做的额外判断可能还要多一些。后面给出了比较过程的另一种实现,可以避免这个问题。

参考资料 [3] 中提到了一个构造模式串的有限状态自动机来匹配的算法,其中保存转移状态用了一个二维数组:一维是输入字符,另一维是对应于某个输入的转移状态;而 KMP 算法只用了一个一维数组,原因在于期望输入字符可以通过当前状态来获得,并且状态只分为匹配和不匹配两种(自动机可以根据不匹配字符的不同跳转到不同的状态),从而节省了空间。KMP 算法可以认为是有限状态自动机的一个特例,两者的不同之处在于 KMP 的模式串是确定的,而有限状态自动机可以处理正则表达式等模式串。

假设已知 next[j]=k,即已知 P[0, ..., j] 和 P[x, ..., x+j] 匹配,求 next[j+1]。为了方便统一处理,令 next[0]=-1 作为后退的终止条件。这里分两种情况讨论:

  1. 如果 p[j+1]=p[x+j+1],说明是期望的输入,可以直接进入下一个状态,即 next[j+1]=next[j]+1。例如模式串“ABCDABD”,已知 next[5]=1,表示当在状态 5 时(这时期望输入是 p[5]='B'),如果输入和期望输入不匹配可以回到状态 1,让 p[1](新的期望输入 'B')和不匹配位置上的字符重新比较。如果 p[2]=p[6],说明在状态 6 时(此时期望输入为 p[b]='D'),如果实际输入和期望输入不匹配时可以回到状态 2,让 p[2](新的期望输入 'C')和不匹配位置上的字符比较,从而避免了 p[0],p[1] 的重复比较。
  2. 如果 p[j+1]<>p[x+j+1],说明下一个字符不匹配,那么退回到前一个匹配的状态 next[j](next[j] 表示 P[0, ..., next[j]] 和 T[x, ..., x+next[j]] 匹配),看新的期望输入是否和当前输入一样。如果一样则是情况 1,不一样则再退回到 next[next[j]]……直到进入情况 1 或者退到状态为 -1。

根据分析,下面是一个 c 语言实现:

void kmp_transition_table(const char *pattern, int len, int next[])
{
   int state = 0, last_state /* = next[0] */ = -1;

   next[0] = -1;

   while (state < len - 1) {
      while (last_state != -1 && pattern[state] != pattern[last_state])
         last_state = next[last_state];

      ++state;
      ++last_state;
      next[state] = last_state;
   }
}

可以发现计算状态转移的过程和实际比较的过程很相似(计算状态转移可以看成模式串的自我匹配)。根据上面的程序,对于模式串“ABCDABD”,可以得到下面的状态转移图:

            +-----------------------------------+
            |                                   |
            +--------------------------+        |
            |                          |        |
            +-----------------+        |        |
            |                 |        |        |
            +--------+        |        |        |
            |        |        |        |        |
            v        |        |        |        |
+----+    +---+ A  +---+ B  +---+ C  +---+ D  +---+ A  +---+ B  +---+ D  +===+
| -1 | -> | 0 | -> | 1 | -> | 2 | -> | 3 | -> | 4 | -> | 5 | -> | 6 | -> | 7 |
+----+    +---+    +---+    +---+    +---+    +---+    +---+    +---+    +===+
  ^         |        ^        ^                          |        |
  |         |        |        |                          |        |
  +---------+        +--------|--------------------------+        |
                              |                                   |
                              +-----------------------------------+

从上面的状态转移图看到,当在状态 5 不匹配时退回到状态 1,很明显状态 5 和状态 1 的期望输入都是 'B',既然在状态 5 时不匹配,退回到状态 1 也同样不匹配,应该继续往后退才对,因此有下面的改进:

void kmp_transition_table(const char *pattern, int len, int next[])
{
   int state = 0, last_state /* = next[0] */ = -1;

   next[0] = -1;

   while (state < len - 1) {
      while (last_state != -1 && pattern[state] != pattern[last_state])
         last_state = next[last_state];

      ++state;
      ++last_state;

      if (pattern[state] == pattern[last_state])
         next[state] = next[last_state];
      else
         next[state] = last_state;
   }
}

修改后对于模式串“ABCDABD”生成新的状态转移图如下:

            +--------------------------------------------+
            |                                            |
            +--------------------------+                 |
            |                          |                 |
            +-----------------+        |                 |
            |                 |        |                 |
            +--------+        |        |                 |
            |        |        |        |                 |
            v        |        |        |                 |
+----+    +---+ A  +---+ B  +---+ C  +---+ D  +---+ A  +---+ B  +---+ D  +===+
| -1 | -> | 0 | -> | 1 | -> | 2 | -> | 3 | -> | 4 | -> | 5 | -> | 6 | -> | 7 |
+----+    +---+    +---+    +---+    +---+    +---+    +---+    +---+    +===+
  ^         |                 ^                 |                 |
  |         |                 |                 |                 |
  +---------+                 +-----------------|-----------------+
  |                                             |
  +---------------------------------------------+

Linux 内核里也有 kmp 的实现,在 lib/ts_kmp.c 中,主要用于网络过滤等方面。

最后是一个带 debug 信息的完整 c 程序,输出状态转移表以及每次比较的情况。

/*--------------------------------------*/
/* a demonstration of the KMP algorithm */
/* http://ouonline.net/      2011.06.03 */
/*--------------------------------------*/

#include <stdio.h>
#include <stdlib.h>

#define KMP_DEBUG

#ifdef KMP_DEBUG
static void kmp_debug_transition_table(int next[], int len)
{
   int i;

   printf("==================================\n");
   printf("transition table:\n");

   printf("index   0");
   for (i = 1; i < len; ++i) {
      if (i < 10 && next[i] < 0)
         printf(" ");
      printf(" %d", i);
   }
   printf("\n");

   printf("state  ");
   for (i = 0; i < len; ++i) {
      if (i < 10)
         printf("%d ", next[i]);
      else /* if (i < 100) */ {
         if (next[i] >= 0 && next[i] < 10)
            printf(" %d ", next[i]);
         else
            printf("%d ", next[i]);
      }
   }

   printf("\n==================================\n");
}
#endif

void kmp_transition_table(const char *pattern, int len, int next[])
{
   int state = 0, last_state /* = next[0] */ = -1;

   next[0] = -1;

   while (state < len - 1) {
      while (last_state != -1 && pattern[state] != pattern[last_state])
         last_state = next[last_state];

      ++state;
      ++last_state;

      if (pattern[state] == pattern[last_state])
         next[state] = next[last_state];
      else
         next[state] = last_state;
   }

#ifdef KMP_DEBUG
   kmp_debug_transition_table(next, len);
#endif
}

#ifdef KMP_DEBUG
static void kmp_debug_compare(const char *text, int tlen,
                              const char *pattern, int plen,
                              int i, int j)
{
   int c, k;

   if (tlen > 10) {
      printf("     ");
      for (c = 1; c < tlen / 10; ++c)
         printf("                   %d", c);
      if (tlen % 10 != 0)
         printf("                   %d", c);
      printf("\n");
   }
   printf("pos");
   for (c = 0; c <= tlen / 10; ++c)
      for (k = 0; k < 10; ++k)
         if (c * 10 + k < tlen)
            printf(" %d", k);
   printf("\n");

   printf("T  ");
   for (c = 0; c < tlen; ++c)
      printf(" %c", text[c]);
   printf("\n");

   printf("P  ");
   for (c = 0; c < i - j; ++c)
      printf("  ");
   for (c = 0; c < plen; ++c)
      printf(" %c", pattern[c]);
   printf("\n");

   printf("   ");
   for (c = 0; c < i; ++c)
      printf("  ");
   printf(" *\n");

   printf("---------------------------------\n");
}
#endif

char* do_kmp(const char *text, int tlen, const char *pattern,
      int plen, int next[])
{
   int i = 0, j = 0;

   while (plen - j + i <= tlen) {
#ifdef KMP_DEBUG
      kmp_debug_compare(text, tlen, pattern, plen, i, j);
#endif
      if (pattern[j] == text[i]) {
         ++i;
         ++j;
         if (j == plen)
            return (char*)(text + i - j);
      } else {
         j = next[j];
         if (j == -1) {
            ++i;
            j = 0;
         }
      }
   }

   return NULL;
}

int kmp(const char *text, int tlen, const char *pattern, int plen,
      const char **match_pos)
{
   int *next;

   *match_pos = NULL;

   if (!text || !pattern || !match_pos || plen < 0 || tlen < 0)
      return -1;

   if (plen > tlen)
      return 0;

   next = malloc(plen * sizeof(int));
   if (!next)
      return -2;

   kmp_transition_table(pattern, plen, next);
   *match_pos = do_kmp(text, tlen, pattern, plen, next);

   free(next);

   return 0;
}

#include <string.h>

int main()
{
   int ret;
   const char *text = "ABCZABCDAEZABCDABCDABDE", *pattern = "ABCDABD", *pos;

   ret = kmp(text, strlen(text), pattern, strlen(pattern), &pos);
   if (ret != 0)
      fprintf(stderr, "error.\n");
   else {
      if (pos == NULL)
         printf("no match.\n");
      else
         printf("pos = %lu\n", pos - text);
   }

   return 0;
}

参考资料

[1] 基于有限状态自动机分析KMP字符串匹配算法
[2] 维基百科上的KMP词条
[3] 算法导论(第二版),(美) Thomas H.Cormen, Charles E.Leiserson, Ronald L.Rivest, Clifford Stein 著,潘金贵,顾铁成,李成法,叶懋 译
[4] KMP算法深度解析

Comments (3)

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注