「题目描述」
有一张 $n \times m $ 的数表。
其中第 $i$ 行第 $j$ 列 ($i \in [1, n], j \in [1,m] $) 的数值为能同时整除 $i$ 和 $j$ 的所有自然数之和。
每组询问给定 $a$,计算数表中不大于 $a$ 的数之和。
「题目链接」
BZOJ 3539 数表 「SDOI 2014」
「解题思路」
先不管 $a$ 的限制,考虑怎么做。
设 $\sigma(x)$ 表示 $x$ 的约数和,则第 $i$ 行第 $j$ 列的数为 $\sigma((i, j))$ 。
然后 按套路 肆意推导一番,这是基本功了: $$ \begin{align*} \sum_{i = 1}^n \sum_{j = 1}^m \sigma((i, j)) % &= \sum_{d} \sigma(d) \sum_{i = 1}^{n} \sum_{j = 1}^m [(i, j) = d] % \newline &= \sum_{d} \sigma(d) \sum_{i = 1}^{\left \lfloor \frac n d \right \rfloor} \sum_{j = 1}^{\left \lfloor \frac m d \right \rfloor} [(i, j) = 1] % \newline &= \sum_{d} \sigma(d) \sum_{i = 1}^{\left \lfloor \frac n d \right \rfloor} \sum_{j = 1}^{\left \lfloor \frac m d \right \rfloor} \sum_{x \mid (i, j)} \mu(x) % \newline &= \sum_{d} \sigma(d) \sum_{x}\mu(x) \sum_{i = 1}^{\left \lfloor \frac n d \right \rfloor} \sum_{j = 1}^{\left \lfloor \frac m d \right \rfloor} [x \mid i] [x \mid j] % \newline &= \sum_{d} \sigma(d) \sum_{x}\mu(x) \left \lfloor \frac n {dx}\right \rfloor \left \lfloor \frac m {dx}\right \rfloor % \newline &= \sum_{T} \left \lfloor \frac n T \right \rfloor \left \lfloor \frac m T \right \rfloor \sum_{d \mid T} \sigma(d) \mu(\frac T d) \end{align*} $$ 至此,设 $g(T) = \sum_{d \mid T} \sigma(d) \mu(\frac T d)$ ,只要能计算出 $g$ ,然后就可以成块处理 $O(\sqrt n)$ 回答每次询问了。
两种计算方法:
方法一:线性筛预处理 $\sigma, \mu$ ,枚举 $d$ ,枚举 $d$ 的倍数,将 $d$ 的贡献累加进去,复杂度 $O(n \log n)$。
方法二:其实 $g(T)$ 是积性函数,所以线性筛直接筛之即可,$O(n)$ 。但是,在本题中 $a$ 的限制下,不能用这种方法,所以不再展开说。
有了 $a$ 的限制,我们可以离线处理:把询问按 $a$ 升序排序,$d$ 按 $\sigma(d)$ 升序排序,然后顺序处理询问,对于 $\sigma(d) \leq a$ 的 $d$ ,枚举倍数,将其贡献加入,树状数组维护 $g$ 的区间和,然后成块处理即可。
注意常数,减少不必要的取模。
「AC 代码」
#include <cstdio>
#include <algorithm>
typedef long long int64;
const int64 ONE = 1;
const int64 MOD = ONE << 31;
const int MAX_N = 100000;
const int MAX_Q = 20000;
bool isNotPrime[MAX_N + 1];
int primes[MAX_N], m;
int min[MAX_N + 1];
int mu[MAX_N + 1];
int64 sigma[MAX_N + 1];
inline void sieve(int n = MAX_N){
isNotPrime[1] = true, min[1] = 0, mu[1] = 1, sigma[1] = 1;
for(int i = 1; i <= n; i++){
if(!isNotPrime[i]){
primes[m++] = i;
min[i] = i;
mu[i] = -1;
sigma[i] = i + 1;
}
for(int j = 0, p, x; j < m; j++){
if((x = i * (p = primes[j])) > n) break;
isNotPrime[x] = true;
if(i % p == 0){
min[x] = min[i] * p;
mu[x] = 0;
}else{
min[x] = p;
mu[x] = -mu[i];
};
if(x == min[x]){
sigma[x] = ((1 - (int64)min[x] * p) / (1 - p)) % MOD;
} else{
sigma[x] = sigma[ min[x] ] * sigma[x / min[x]] % MOD;
}
if(i % p == 0) break;
}
}
}
struct BinaryIndexedTree{
int64 a[MAX_N + 1];
int n;
void init(int n){
this->n = n;
}
static int lowbit(int i){
return i & -i;
}
void add(int i, const int64& x){
for(; i <= n; i += lowbit(i)){
a[i] += x;
if(a[i] >= MOD) a[i] -= MOD;
};
}
int64 sum(int i){
int64 x = 0;
for(; i; i -= lowbit(i)){
x += a[i];
if(x >= MOD) x -= MOD;
}
return x;
}
int64 sum(int l, int r){
return (sum(r) - sum(l - 1) + MOD) % MOD;
}
} g;
struct Query{
int n, m, a;
int64 *ans;
inline friend bool operator<(const Query &x, const Query &y){
return x.a < y.a;
}
} querys[MAX_Q];
int qNum;
int n;
inline int64 query(int n, int m){
int64 ans = 0;
if(n > m) std::swap(n, m);
for(int64 l = 1, r; l <= n; l = r + 1){
r = std::min(n / (n / l), m / (m / l));
(ans += g.sum(l, r) * (n / l) % MOD * (m / l) % MOD) %= MOD;
}
return ans;
}
inline void apply(int64 val, int d){
for(int i = d; i <= n; i += d){
if(mu[i / d] == 1){
g.add(i, val);
} else if(mu[i / d] == -1){
g.add(i, -val + MOD);
}
}
}
inline void solve(){
std::sort(querys, querys + qNum);
sieve(n);
typedef std::pair<int64, int> Info;
static Info infos[MAX_N + 1];
for(int i = 1; i <= n; i++) infos[i] = std::make_pair(sigma[i], i);
std::sort(infos + 1, infos + n + 1);
g.init(n);
Info *it = infos + 1, *end = infos + n + 1;
for(Query *q = querys; q != querys + qNum; q++){
for(; it != end && it->first <= q->a; it++) apply(it->first, it->second);
*q->ans = query(q->n, q->m);
}
}
int main(){
scanf("%d", &qNum);
static int64 ans[MAX_Q];
for(Query *q = querys; q != querys + qNum; q++){
scanf("%d%d%d", &q->n, &q->m, &q->a), q->ans = &ans[q - querys];
n = std::max(n, std::max(q->n, q->m));
}
solve();
for(int i = 0; i < qNum; i++) printf("%lld\n", ans[i]);
return 0;
}
就是这样咯~