【题目描述】
给定 $n, k$,求: $$ \sum_{i = 1}^n k\ \text{mod}\ i $$ $ 1 \leq n, k \leq 10^{9} $
【题目链接】
BZOJ 1257 余数之和 【CQOI 2007】
【解题思路】
对于 $i > k$ 的部分,$k\ \text{mod}\ i$ 都是 $k$,可以直接统计。
对于 $i \leq k$ 的部分,我们考虑 模的定义:
$$ k \ \text{mod}\ i = k - i \times \left \lfloor \frac k i \right \rfloor $$
性质:$\left \lfloor \frac k i \right \rfloor$ 至多有 $2\sqrt k$ 种取值。
简要证明:对于 $i \leq \sqrt k$,$i$ 至多有 $\sqrt k$ 种取值,所以该式至多有 $\sqrt k$ 种取值,对于 $i > \sqrt k$,有 $\left \lfloor \frac k i \right \rfloor \leq \sqrt k$,所以至多有 $\sqrt k$ 种取值,加起来就至多有 $2\sqrt k$ 种取值。
所以我们可以枚举这个取值快速计算,考虑使得取值相同的 $i$ 是连续的一段区间,所以我们可以枚举值相同的区间。
为了感性理解这一点,不如我们看一下 $k = 100$ 时,$\left \lfloor \frac k i \right \rfloor$ 的表:
100
50
33
25
20
16
14
12
11
10
9
8
7 7
6 6
5 5 5 5
4 4 4 4 4
3 3 3 3 3 3 3 3
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
假设我们枚举到了 $[l, r]$,其中所有的 $\left \lfloor \frac k i \right \rfloor$ 的取值都是 $x$,这个区间中的 $i$ 答案的贡献就是:
$$ \begin{align*} \sum\limits_{i = l}^r k - ix &= \sum\limits_{i = l}^r k - \sum\limits_{i = l}^r ix \\ &= (r - l + 1)k - x\sum\limits_{i = l}^r i \\ \end{align*} $$
最后那个求和可以用等差数列求和公式 $O(1)$ 计算,这样我们就可以快速计算一个区间的贡献了。
那么该如何枚举区间呢?可以用如下的代码:
for(int l = 1, r; l <= k; l = r + 1){
r = k / (k / l);
// 这样枚举,[l, r] 的取值都是 k / l
}
核心代码就是 r = k / (k / l)
这一句,计算出了区间右端点。
为什么呢?
考虑我们的任务是要找一个最大的 $r, r > l$,使得:
$$ \left \lfloor \frac k r \right \rfloor \geq \left \lfloor \frac k l \right \rfloor $$
左边的下取整可以直接去掉
$$ \frac k r \geq \left \lfloor \frac k l \right \rfloor $$
然后把 $r$ 乘过去,$\left \lfloor \frac k l \right \rfloor$ 除过来,得到:
$$ r \leq \frac k {\left \lfloor \frac k l \right \rfloor} $$
所以最大的满足这个式子的 整数 $r$ 就是 $\left \lfloor \frac k {\left \lfloor \frac k l \right \rfloor} \right \rfloor$ 咯。
这样我们单次询问就是 $O(\sqrt k)$ 的啦。
总结:主要用到了 $\left \lfloor \frac k i \right \rfloor$ 至多有 $2\sqrt k$ 种取值这一性质,因为是第一次写此类题解,写的详细了一点。利用这种性质成块处理的方法在 莫比乌斯反演 的题目中很常用,所以要熟练掌握(其实也很好写不是嘛)。
【AC 代码】
还要注意这题稍微特殊一点,因为 $i \leq n$,所以枚举到的区间端点要对 $n$ 取 $\min$。
#include <cstdio>
#include <algorithm>
typedef long long int64;
inline int64 sum(int64 n){
return n * (n + 1) / 2;
}
inline int64 solve(int64 k, int64 n){
int64 ans = 0;
if(n > k){
ans += k * (n - k);
n = k;
}
for(int64 l = 1, r; l <= n; l = r + 1){
r = std::min(k / (k / l), n);
ans += (r - l + 1) * k - (k / l) * (sum(r) - sum(l - 1));
}
return ans;
}
int main(){
int64 n, k;
scanf("%lld%lld", &n, &k);
printf("%lld\n", solve(k, n));
return 0;
}
就是这样咯~