【题目描述】
给定一棵 $n$ 个节点的树,每个节点有一种颜色。共有 $m$ 次询问,每次询问从节点 $u$ 到节点 $v$ 的路径上(含两个端点)不同颜色的种数。
【题目链接】
SPOJ COT2 Count on a tree II
【解题思路】
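树上莫队,详见 莫队算法 - 学习笔记。大致做法:先将颜色离散化,再对树分块(每块约 $\sqrt n$ 个节点),询问按 $u$ 所在块为第一关键字、$v$ 的 DFS 序为第二关键字排序;维护路径 $(u, v)$ 上除 LCA 以外的点集,移动端点时翻转沿途节点的状态,LCA 单独处理。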
【AC代码】
#include <cstdio>
#include <algorithm>
#include <cmath>
// Capacity of the node pool (upper bound on the number of tree nodes).
const int MAXN = 40000;

// Capacity of the Euler tour: dfs() records a node once on entry and once
// more after each child returns, i.e. n + (n - 1) entries for n nodes.
const int MAXL = MAXN * 2 - 1;

// Highest sparse-table level index; the table has MAXLOGL + 1 levels.
const int MAXLOGL = 17;

// Capacity of the query and answer arrays.
const int MAXM = 100000;
struct Edge;  // adjacency-list cell, defined right after Node

// A tree vertex together with all per-vertex bookkeeping used by the solver.
struct Node {
    Edge *edges;   // head of this node's adjacency list (most recently linked edge first)
    Node *fa;      // parent in the rooted tree; NULL for the root
    int id;        // DFS entry timestamp, starting at 1
    int pos;       // index of this node's first occurrence in the Euler tour
    int blockId;   // block assigned by the tree decomposition; used only to order queries
    int depth;     // root has depth 1; 0 means "not visited yet" during dfs()
    int w;         // color; replaced by its rank among distinct colors in compress()
    bool in;       // whether the node is currently counted in the color buckets
};

// Node pool; input node i (1-based) is stored at nodes[i - 1], nodes[0] is the root.
Node nodes[MAXN];
// One directed half of a tree edge, stored as a cell of a singly linked
// adjacency list.
struct Edge {
    Node *to;    // endpoint this half-edge leads to
    Edge *next;  // next cell in the owner's adjacency list

    // Intentionally empty: pool slots are filled in later by link().
    Edge() {}

    // Builds the cell for from -> dest and chains it in front of from's
    // current adjacency list (the caller still has to update from->edges).
    Edge(Node *from, Node *dest) : to(dest), next(from->edges) {}
};

// Static edge pool: an n-node tree has n - 1 edges, each stored in both directions.
Edge epool[2 * (MAXN - 1)];

// Next free slot in epool.
Edge *eptr = epool;
// Adds the undirected edge u - v by taking two cells from the edge pool and
// pushing one onto the front of each endpoint's adjacency list.
//
// The cells are filled by plain assignment instead of placement new, so the
// function does not depend on <new> (which this file never includes).
inline void link(Node *u, Node *v) {
    // Edge(u, v) captures u's current list head as `next` before the head
    // is redirected to the new cell.
    *eptr = Edge(u, v);
    u->edges = eptr++;

    *eptr = Edge(v, u);
    v->edges = eptr++;
}
int n;  // number of tree nodes
int m;  // number of queries

// ans[i] is the answer to the query that was i-th in the input.
int ans[MAXM];

// Sparse table over the Euler tour: f[i][0] is the node at tour position i,
// f[i][k] is the shallowest node among the tour positions starting at i that
// level k covers (see build()).
Node *f[MAXL][MAXLOGL + 1];

// logBase2[x] = floor(log2(x)) for 1 <= x <= len, filled in build().
int logBase2[MAXL + 1];

// Number of Euler-tour entries written so far.
int len;

// Stack of nodes waiting to be assigned to a block during dfs().
Node *S[MAXN];
int top;  // current height of S

int blockCnt;   // number of blocks closed so far
int blockSize;  // minimum number of stacked nodes needed to close a block
// Depth-first traversal from v that, in a single pass,
//   * assigns DFS timestamps (id),
//   * appends to the Euler tour stored in f[..][0] and records each node's
//     first tour position (pos),
//   * sets depth and fa of every child, and
//   * groups nodes into blocks using the stack S.
//
// The `fa` argument is not used: the parent pointer is written into the child
// before the recursive call. Unvisited nodes are recognized by depth == 0, so
// the caller must give the root a non-zero depth first.
inline void dfs(Node *v, Node *fa = NULL) {
    (void)fa;

    static int dfsClock = 0;
    dfsClock += 1;
    v->id = dfsClock;

    v->pos = len;
    f[len][0] = v;
    len += 1;

    // Everything pushed above this mark belongs to v's subtree.
    const int bot = top;

    for (Edge *e = v->edges; e != NULL; e = e->next) {
        Node *child = e->to;
        if (child->depth != 0) {
            continue;  // already visited (this is the edge back to the parent)
        }

        child->depth = v->depth + 1;
        child->fa = v;
        dfs(child);

        // Enough pending nodes from the children processed so far: close a block.
        if (top - bot >= blockSize) {
            while (top > bot) {
                top -= 1;
                S[top]->blockId = blockCnt;
            }
            blockCnt += 1;
        }

        // Returning from a child puts v on the tour again.
        f[len][0] = v;
        len += 1;
    }

    // v itself is pushed last and is assigned to a block by a caller.
    S[top] = v;
    top += 1;
}
// Returns whichever node is strictly shallower; on equal depth returns b.
inline Node* min(Node *a, Node *b) {
    if (a->depth < b->depth) {
        return a;
    }
    return b;
}
// Returns the shallowest node among Euler-tour positions l..r (inclusive,
// l <= r) by combining two overlapping power-of-two windows of the sparse table.
inline Node* rmq(int l, int r) {
    const int level = logBase2[r - l + 1];
    Node *leftWindow = f[l][level];
    Node *rightWindow = f[r - (1 << level) + 1][level];
    return min(leftWindow, rightWindow);
}
// Lowest common ancestor of a and b: the shallowest node on the Euler tour
// between the first occurrences of the two nodes.
inline Node* queryLCA(Node *a, Node *b) {
    int lo = a->pos;
    int hi = b->pos;
    if (lo > hi) {
        std::swap(lo, hi);
    }
    return rmq(lo, hi);
}
// Prepares everything the query phase needs:
//   1. chooses the block size, ceil(sqrt(n));
//   2. roots the tree at nodes[0] and runs dfs() to fill ids, depths, parents,
//      the Euler tour and the block assignment;
//   3. builds the floor(log2) table and the sparse table over the Euler tour.
void build() {
    // The explicit double cast keeps the call unambiguous on pre-C++11
    // compilers, where std::sqrt has no overload for integer arguments.
    blockSize = int(std::ceil(std::sqrt(static_cast<double>(n))) + 1e-6);

    // Depth 0 means "unvisited" inside dfs(), so the root starts at depth 1.
    nodes->depth = 1;
    nodes->fa = NULL;
    dfs(nodes);

    // Nodes still on the stack join the last closed block. If no block was
    // ever closed they all get blockId -1, which is harmless because block ids
    // are only compared with each other when sorting queries.
    while (top) {
        S[--top]->blockId = blockCnt - 1;
    }

    logBase2[1] = 0;
    for (int i = 2; i <= len; i++) {
        logBase2[i] = logBase2[i >> 1] + 1;
    }

    for (int k = 1; k <= logBase2[len]; k++) {
        const int half = 1 << (k - 1);
        for (int i = 0; i < len; i++) {
            if (i + half < len) {
                f[i][k] = min(f[i][k - 1], f[i + half][k - 1]);
            } else {
                // The second half starts past the end of the tour: reuse the
                // shorter window.
                f[i][k] = f[i][k - 1];
            }
        }
    }
}
inline void compress(){
static int tmp[MAXN];
for(int i = 0; i < n; i++) tmp[i] = (nodes + i)->w;
std::sort(tmp, tmp + n);
int size = std::unique(tmp, tmp + n) - tmp;
for(int i = 0; i < n; i++){
(nodes + i)->w = std::lower_bound(tmp, tmp + size, (nodes + i)->w) - tmp;
}
}
// A path query between u and v; id is the query's position in the input so
// the answer can be stored in input order after sorting.
struct Query {
    Node *u, *v;
    int id;

    // Mo ordering: primary key is the block of u, secondary key is the DFS
    // timestamp of v.
    inline friend bool operator<(const Query &a, const Query &b) {
        const int blockA = a.u->blockId;
        const int blockB = b.u->blockId;
        if (blockA != blockB) {
            return blockA < blockB;
        }
        return a.v->id < b.v->id;
    }
} querys[MAXM];
// buc[c] is the number of currently counted nodes whose compressed color is c.
int buc[MAXN];

// Number of colors c with buc[c] > 0.
int currAns;
// Toggles whether v is counted, keeping the color buckets and the
// distinct-color counter in sync.
inline void flip(Node *v) {
    v->in = !v->in;

    if (!v->in) {
        // v was just removed: its color disappears if this was the last one.
        buc[v->w] -= 1;
        if (buc[v->w] == 0) {
            currAns -= 1;
        }
        return;
    }

    // v was just added: its color is new if the bucket was empty.
    buc[v->w] += 1;
    if (buc[v->w] == 1) {
        currAns += 1;
    }
}
// Toggles every node on the path between u and v except their LCA, by
// walking up from each endpoint until the LCA is reached.
inline void flip(Node *u, Node *v) {
    Node *const lca = queryLCA(u, v);

    Node *p = u;
    while (p != lca) {
        flip(p);
        p = p->fa;
    }

    p = v;
    while (p != lca) {
        flip(p);
        p = p->fa;
    }
}
inline void solve(){
compress();
build();
std::sort(querys, querys + m);
Node *u = nodes, *v = nodes, *lca = nodes;
flip(nodes);
for(Query *q = querys; q != querys + m; q++){
flip(lca);
lca = queryLCA(q->u, q->v);
flip(lca);
flip(u, q->u), u = q->u;
flip(v, q->v), v = q->v;
ans[q->id] = currAns;
}
}
int main(){
scanf("%d%d", &n, &m);
for(Node *v = nodes; v != nodes + n; v++) scanf("%d", &v->w);
for(int i = 0, u, v; i < n - 1; i++){
scanf("%d%d", &u, &v), u--, v--;
link(nodes + u, nodes + v);
}
for(int i = 0, u, v; i < m; i++){
scanf("%d%d", &u, &v), u--, v--;
Query *q = querys + i;
q->u = nodes + u, q->v = nodes + v, q->id = i;
}
solve();
for(int i = 0; i < m; i++) printf("%d\n", ans[i]);
return 0;
}
就是这样啦~