蓄水池抽样算法

在别人的指点下,了解到了这个算法,研究后发现是一个特别巧妙的随机算法。

问题特点

  1. 需要在一组数据流中选出 m 个数据;
  2. 数据流的规模 N 可能很大,甚至未知,无法一次性全部加载在内存中;
  3. 要保证每个数据被选中的概率是一致的。

需要注意的是第二点,规模可能是未知的,或者说很难得到规模。如果我们能提前确定 N 的大小,那么可以简单的使用 rand()%N 的方式得到一个确切的随机位置,且每个数据被选中的概率都是 $\frac{1}{N}$ ,无需使用此算法。

应用场景

有一些特定的场景可能会使用该算法,场景比较类似,都是大数据量且规模未知,要公平公平还是TMD公平!

  1. 一个很大的文本,无法得知有多少行,需要随机抽取其中若干行,每一行概率都要相等。
  2. 规模未知的用户参与抽奖,奖品若干,保证每个用户中奖的概率是一致的。

算法流程

设定数据的规模为 N,需要采样的数量为 m,即蓄水池的容量。

  1. 将 N 个数据的前 m 个直接放入蓄水池;
  2. 从第 i 个数据开始( i 从1开始), i > m,在 [0, i) 范围内随机选择整数 d, 如果 d 在[0, m) 范围内,则用第 i 个数据替换原蓄水池中下标为 d 的数据;
  3. 不停地重复步骤2,直至没有数据。

巧妙之处在于:当处理完所有数据时,蓄水池中的每个数据都是以 $\frac{m}{N}$ 的概率获得的。

代码实现

实现的核心代码约10行:ReservoirSampling.java
单元测试,运行后观察在重复若干次后,每个数据被选中的概率:ReservoirSamplingTest.java

大部分人看到这里就差不多结束了,下面会文字说下证明过程,感兴趣的可以再看看。


等概率证明过程

依旧设定数据的规模为 N,需要采样的数量为 m,数据编号 i 从1开始。

每个数据被选中的概率的计算公式如下:
\begin{equation}
\begin{split}
第 i 个接收到的数据最后能留在蓄水池中的概率&=第 i 个数据进入过蓄水池 \\
&\quad* 之后所有数据都没把 i 替换掉的概率
\end{split}
\end{equation}

我们需要讨论两种情况,一种是 i 在 [1, m] 这个范围内的情况,一种是 i 在 (m, N] 的范围内。

  1. 当 i <= m 时, 数据直接放入蓄水池,第 i 个数据进入过蓄水池的概率 = 1

  2. 当 i > m 时,在 [0, i) 内随机选取一个数 d,如果 d < m,则用第 i 个数据,替换蓄水池中下标为 d 的数据,所以**第 i 个数据进入过蓄水池的概率 = $\frac{m}{i}$**。

  3. 当 i <=m 时,算法从接收到第 m+1 个数据时开始有可能出现替换操作,根据第2点得出的规律,第 m+1 个数据会进入水池并替换掉池其它数据的概率为 $\frac{m}{m+1}$ ,刚好替换第 i 个数据的概率为 $\frac{1}{m+1}$,那么第 i 个数据不被第 m+1 个数据替换的概率为 1 - $\frac{1}{m+1}$ = $\frac{m}{m+1}$。

    同理,第 m+2 个数据会进入水池并替换掉池其它数据的概率为 $\frac{m}{m+2}$,刚好替换第 i 个数据的概率是 $\frac{1}{m+2}$,第 i 个数据不被第 m+2 个数据替换的概率为 1 - $\frac{1}{m+2}$ = $\frac{m+1}{m+2}$,……, 第 N 个数据不替换第 i 个数据的概率为 $\frac{N-1}{N}$。

    所以有以下结果:

\begin{split}
第 i 个数据不被后面所有数据替换的概率&=不被第m+1个数据替换的概率 \\
&\quad* 不被第m+2个数据替换的概率 \\
&\quad* 不被第m+3个数据替换的概率 \\
&\quad* … \\
&\quad* 不被第N个数据替换的概率 \\
&= \frac{m}{m+1} * \frac{m+1}{m+2} * \frac{m+2}{m+3} * … * \frac{N-1}{N} \\
&= \frac{m}{N}
\end{split}
4. 当 i > m 时,算法从接收到第 i+1 个数据时开始有可能替换第 i 个数据,根据第3点得到的结果, 第 i 个数据不被之后所有数据替换的概率为 $\frac{i}{N}$。

总结:
根据以上四条规律,以及开始所说的公式 (1),可以得出结果:

  • 结合第1点、第3点,当 i <= m 时,第 i 个数据被选中的概率为 $ 1 * \frac{m}{N} = \frac{m}{N}$;
  • 结合第2点、第4点,当 i > m 时,第 i 个数据被选中的概率为 $ \frac{m}{i} * \frac{i}{N} = \frac{m}{N}$。

==> 因此,对于任意 i ,都有 $\frac{m}{N}$ 的概率。