【题目描述】

给定 $n, k$,求: $$ \sum_{i = 1}^n k\ \text{mod}\ i $$ $ 1 \leq n, k \leq 10^{9} $

【题目链接】

BZOJ 1257 余数之和 【CQOI 2007】

【解题思路】

对于 $i > k$ 的部分,$k\ \text{mod}\ i$ 都是 $k$,可以直接统计。

对于 $i \leq k$ 的部分,我们考虑 模的定义:

$$ k \ \text{mod}\ i = k - i \times \left \lfloor \frac k i \right \rfloor $$

性质:$\left \lfloor \frac k i \right \rfloor$ 至多有 $2\sqrt k$ 种取值。

简要证明:对于 $i \leq \sqrt k$,$i$ 至多有 $\sqrt k$ 种取值,所以该式至多有 $\sqrt k$ 种取值,对于 $i > \sqrt k$,有 $\left \lfloor \frac k i \right \rfloor \leq \sqrt k$,所以至多有 $\sqrt k$ 种取值,加起来就至多有 $2\sqrt k$ 种取值。

所以我们可以枚举这个取值快速计算,考虑使得取值相同的 $i$ 是连续的一段区间,所以我们可以枚举值相同的区间。

为了感性理解这一点,不如我们看一下 $k = 100$ 时,$\left \lfloor \frac k i \right \rfloor$ 的表:

100
50
33
25
20
16
14
12
11
10
9
8
7 7 
6 6
5 5 5 5 
4 4 4 4 4
3 3 3 3 3 3 3 3
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

假设我们枚举到了 $[l, r]$,其中所有的 $\left \lfloor \frac k i \right \rfloor$ 的取值都是 $x$,这个区间中的 $i$ 答案的贡献就是:

$$ \begin{align*} \sum\limits_{i = l}^r k - ix &= \sum\limits_{i = l}^r k - \sum\limits_{i = l}^r ix \\ &= (r - l + 1)k - x\sum\limits_{i = l}^r i \\ \end{align*} $$

最后那个求和可以用等差数列求和公式 $O(1)$ 计算,这样我们就可以快速计算一个区间的贡献了。

那么该如何枚举区间呢?可以用如下的代码:

for(int l = 1, r; l <= k; l = r + 1){
    r = k / (k / l);
    
    // 这样枚举,[l, r] 的取值都是 k / l
}

核心代码就是 r = k / (k / l) 这一句,计算出了区间右端点。

为什么呢?

考虑我们的任务是要找一个最大的 $r, r > l$,使得:

$$ \left \lfloor \frac k r \right \rfloor \geq \left \lfloor \frac k l \right \rfloor $$

左边的下取整可以直接去掉

$$ \frac k r \geq \left \lfloor \frac k l \right \rfloor $$

然后把 $r$ 乘过去,$\left \lfloor \frac k l \right \rfloor$ 除过来,得到:

$$ r \leq \frac k {\left \lfloor \frac k l \right \rfloor} $$

所以最大的满足这个式子的 整数 $r$ 就是 $\left \lfloor \frac k {\left \lfloor \frac k l \right \rfloor} \right \rfloor$ 咯。

这样我们单次询问就是 $O(\sqrt k)$ 的啦。

总结:主要用到了 $\left \lfloor \frac k i \right \rfloor$ 至多有 $2\sqrt k$ 种取值这一性质,因为是第一次写此类题解,写的详细了一点。利用这种性质成块处理的方法在 莫比乌斯反演 的题目中很常用,所以要熟练掌握(其实也很好写不是嘛)。

【AC 代码】

还要注意这题稍微特殊一点,因为 $i \leq n$,所以枚举到的区间端点要对 $n$ 取 $\min$。

#include <cstdio>
#include <algorithm>

typedef long long int64;

inline int64 sum(int64 n){
    return n * (n + 1) / 2;
}

inline int64 solve(int64 k, int64 n){
    int64 ans = 0;

    if(n > k){
        ans += k * (n - k);
        n = k;
    }

    for(int64 l = 1, r; l <= n; l = r + 1){
        r = std::min(k / (k / l), n);
        ans += (r - l + 1) * k - (k / l) * (sum(r) - sum(l - 1));
    }

    return ans;
}

int main(){
    int64 n, k; 

    scanf("%lld%lld", &n, &k);

    printf("%lld\n", solve(k, n));

    return 0;
}

就是这样咯~