【题目描述】
作为一个生活散漫的人,小 Z 每天早上都要耗费很久从一堆五颜六色的袜子中找出一双来穿。终于有一天,小 Z 再也无法忍受这恼人的找袜子过程,于是他决定听天由命……
具体来说,小 Z 把这 $N$ 只袜子从 $1$ 到 $N$ 编号,然后从编号 $L$ 到 $R$ 的袜子中随机选取,尽管小 Z 并不在意两只袜子是不是完整的一双,甚至不在意两只袜子是否一左一右,他却很在意袜子的颜色,毕竟穿两只不同色的袜子会很尴尬。
你的任务便是告诉小 Z ,他有多大的概率抽到两只颜色相同的袜子。当然,小 Z 希望这个概率尽量高,所以他可能会询问多个 $(L, R)$ 以方便自己选择。
【题目链接】
BZOJ 2038 小 Z 的袜子 【国家集训队 2009】
【解题思路】
在区间 $[l, r]$ 内,设 $S$ 表示袜子的颜色集合,$f(x)$ 表示颜色 $x$ 出现的次数,根据古典概型:
$$ ans = \frac {\sum_{x \in S} C(2, f(x))} {C(2, r - l + 1)} $$
分母可以直接求出。
考虑到 $C(x, 2) = x(x - 1) = x^2 - x$,不妨将分子展开:
$$ \begin{align*} \sum_{x \in S} C(2, f(x)) & = \sum_{x \in S} (f^2(x) - f(x)) \\ & = \sum_{x \in S} f^2(x) - \sum_{x \in S} f(x) \\ & = \sum_{x \in S} f^2(x) - (r - l + 1) \end{align*} $$
所以我们需要求的是每种颜色出现次数的平方和。
这可以用莫队算法解决,详见 莫队算法 - 学习笔记。
【AC代码】
#include <cstdio>
#include <algorithm>
#include <cmath>
typedef long long int64;
inline int64 sqr(int64 x){
return x * x;
}
inline int64 gcd(int64 a, int64 b){
int64 d = 1;
while(a && b){
while(~a & 1 && ~b & 1) a >>= 1, b >>= 1, d <<= 1;
while(~a & 1) a >>= 1;
while(~b & 1) b >>= 1;
if(a < b) std::swap(a, b);
a = a - b >> 1;
}
return std::max(a, b) * d;
}
inline void reduce(int64 &u, int64 &d){
int64 g = gcd(u, d);
u /= g, d /= g;
}
const int MAXN = 50000;
const int MAXM = 50000;
int n, m;
int a[MAXN];
int64 ansU[MAXM], ansD[MAXM];
int blockSize;
struct Query{
int l, r;
int id;
inline friend bool operator<(const Query &a, const Query &b){
if(a.l / blockSize != b.l / blockSize) return a.l / blockSize < b.l / blockSize;
else return a.r < b.r;
}
void calc(int64 sqrSum){
ansU[id] = sqrSum - (r - l + 1);
ansD[id] = (int64)(r - l) * (r - l + 1);
reduce(ansU[id], ansD[id]);
}
} querys[MAXM];
int l, r;
int f[MAXN + 1];
bool in[MAXN];
int64 currAns;
inline void flip(int pos){
in[pos] ^= 1;
currAns -= sqr(f[ a[pos] ]);
if(in[pos]){
f[ a[pos] ]++;
} else{
f[ a[pos] ]--;
}
currAns += sqr(f[ a[pos] ]);
}
inline void solve(){
blockSize = static_cast<int>(std::ceil(std::sqrt(n)) + 1e-6);
std::sort(querys, querys + m);
l = 0, r = 0, flip(0);
for(Query *q = querys; q != querys + m; q++){
while(l > q->l) flip(--l);
while(r < q->r) flip(++r);
while(l < q->l) flip(l++);
while(r > q->r) flip(r--);
q->calc(currAns);
}
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 0; i < n; i++) scanf("%d", &a[i]);
for(int i = 0; i < m; i++){
Query *q = &querys[i];
scanf("%d%d", &q->l, &q->r), q->l--, q->r--;
q->id = i;
}
solve();
for(int i = 0; i < m; i++) printf("%lld/%lld\n", ansU[i], ansD[i]);
return 0;
}
就是这样咯~