【题目描述】

给定一棵 $n$ 个点的树，每个点有一种颜色；对每次询问 $(u, v)$，求 $u$ 到 $v$ 的路径上（含两端点）不同颜色的种数。

【题目链接】

SPOJ COT2 Count on a tree II

【解题思路】

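树上莫队，详见《莫队算法 - 学习笔记》。简述：用“王室联邦”分块法把树分成大小约为 $\sqrt{n}$ 的块，把询问按（$u$ 所在块，$v$ 的 DFS 序）排序；维护“路径去掉 LCA”的点集 $T(u, v)$，利用 $T(u, v') = T(u, v) \oplus T(v, v')$ 在相邻询问之间转移，回答时再单独把 LCA 加入。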

【AC代码】

#include <cstdio>
#include <algorithm>
#include <cmath>

// Capacity limits for the statically sized arrays of this solution.
static const int MAXN = 40000;          // maximum number of tree vertices
static const int MAXL = 2 * MAXN - 1;   // Euler tour length: one entry per vertex + one per child return
static const int MAXLOGL = 17;          // sparse table keeps levels 0..MAXLOGL, and 2^MAXLOGL > MAXL
static const int MAXM = 100000;         // maximum number of queries

// A tree vertex. `nodes` has static storage, so every field starts zeroed.
struct Node{
    struct Edge *edges; // head of this vertex's adjacency list
    Node *fa;           // parent in the rooted tree (NULL for the root)
    int id;             // 1-based DFS preorder index, used to order queries
    int pos;            // index of this vertex's first occurrence in the Euler tour
    int blockId;        // block assigned by the DFS tree partition
    int depth;          // root has depth 1; 0 means "not visited yet"
    int w;              // colour (compressed to a dense range before solving)
    bool in;            // whether this vertex is currently counted in buc[]
};
Node nodes[MAXN];

// One directed half of an undirected tree edge, allocated from a fixed pool.
struct Edge{
    Node *to;   // vertex this half-edge points to
    Edge *next; // next half-edge leaving the same source vertex

    Edge() {}
    // Builds a half-edge from -> target whose `next` is from's current list
    // head; the caller is responsible for making it the new head.
    Edge(Node *from, Node *target) : to(target), next(from->edges) {}
};
Edge epool[2 * (MAXN - 1)]; // two half-edges per tree edge
Edge *eptr = epool;         // next free slot in epool

// Adds the undirected edge {u, v}: takes two consecutive slots from epool and
// pushes one half-edge onto each endpoint's adjacency list.
// Slots are filled by copy-assignment rather than placement new, so this does
// not depend on <new> (which this file does not include).
inline void link(Node *u, Node *v){
    *eptr = Edge(u, v); // next = u's old head
    u->edges = eptr++;
    *eptr = Edge(v, u); // next = v's old head
    v->edges = eptr++;
}

int n;          // number of vertices
int m;          // number of queries
int ans[MAXM];  // ans[i] = answer to the i-th query in input order

// Euler tour and sparse table used for LCA queries.
Node *f[MAXL][MAXLOGL + 1]; // f[i][k]: a shallowest node among tour positions [i, i + 2^k), clamped to the tour
int logBase2[MAXL + 1];     // logBase2[x] = floor(log2(x))
int len;                    // current length of the Euler tour

// Stack and counters used by dfs() to partition the tree into blocks.
Node *S[MAXN];
int top;
int blockCnt;
int blockSize;

/*
 * Recursive DFS from v. It:
 *  - assigns v its 1-based preorder id;
 *  - appends v to the Euler tour on entry (recording v->pos) and again after
 *    returning from each child;
 *  - sets depth/fa of every unvisited neighbour (depth == 0) before descending;
 *  - partitions the tree: after each child, if at least blockSize vertices
 *    have piled up on S since v was entered, they are popped into a new block;
 *    v itself is pushed onto S when it finishes.
 * The `fa` argument is not used.
 */
inline void dfs(Node *v, Node *fa = NULL){
    static int dfsClock = 0;
    ++dfsClock;
    v->id = dfsClock;

    v->pos = len;
    f[len][0] = v;
    ++len;

    const int stackBase = top;
    for(Edge *e = v->edges; e != NULL; e = e->next){
        Node *child = e->to;
        if(child->depth != 0) continue; // already visited (in a tree: v's parent)

        child->depth = v->depth + 1;
        child->fa = v;
        dfs(child);

        if(top - stackBase >= blockSize){
            while(top > stackBase){
                --top;
                S[top]->blockId = blockCnt;
            }
            ++blockCnt;
        }

        f[len][0] = v;
        ++len;
    }

    S[top] = v;
    ++top;
}

// Returns the shallower of a and b; when depths are equal, b is returned.
inline Node* min(Node *a, Node *b){
    if(b->depth <= a->depth) return b;
    return a;
}

// Shallowest node among Euler-tour positions [l, r] (expects l <= r < len),
// taken from two overlapping power-of-two windows of the sparse table.
inline Node* rmq(int l, int r){
    const int level = logBase2[r - l + 1];
    const int width = 1 << level;
    Node *leftWindow = f[l][level];
    Node *rightWindow = f[r - width + 1][level];
    return min(leftWindow, rightWindow);
}

// LCA of a and b: the shallowest node in the Euler tour between their
// first occurrences.
inline Node* queryLCA(Node *a, Node *b){
    int lo = a->pos;
    int hi = b->pos;
    if(lo > hi) std::swap(lo, hi);
    return rmq(lo, hi);
}

/*
 * One-time preprocessing of the tree rooted at nodes[0]:
 *  1. runs dfs() with block size ceil(sqrt(n)) to get depths, parents,
 *     preorder ids, the Euler tour in f[..][0], and the block partition;
 *  2. puts every vertex still left on the partition stack into the last block;
 *  3. fills logBase2[] and the sparse table f[i][k] used by rmq().
 */
void build(){
    blockSize = int(std::ceil(std::sqrt(double(n))) + 1e-6);
    nodes->depth = 1, nodes->fa = NULL, dfs(nodes);

    // Leftover vertices join the last block. On very small trees dfs() may not
    // have closed any block (blockCnt == 0); open block 0 for them so no
    // vertex is left with the invalid block id -1.
    if(top > 0 && blockCnt == 0) blockCnt = 1;
    while(top) S[--top]->blockId = blockCnt - 1;

    logBase2[1] = 0;
    for(int i = 2; i <= len; i++) logBase2[i] = logBase2[i >> 1] + 1;

    // f[i][k] combines two level-(k-1) windows. Where the second window would
    // start past the end of the tour, the entry is copied from level k-1
    // (rmq() with r < len never reads those entries).
    for(int k = 1; k <= logBase2[len]; k++){
        const int half = 1 << (k - 1);
        for(int i = 0; i < len; i++){
            if(i + half < len){
                f[i][k] = min(f[i][k - 1], f[i + half][k - 1]);
            } else{
                f[i][k] = f[i][k - 1];
            }
        }
    }
}

inline void compress(){
    static int tmp[MAXN];
    for(int i = 0; i < n; i++) tmp[i] = (nodes + i)->w;
    std::sort(tmp, tmp + n);
    int size = std::unique(tmp, tmp + n) - tmp;
    for(int i = 0; i < n; i++){
        (nodes + i)->w = std::lower_bound(tmp, tmp + size, (nodes + i)->w) - tmp;
    }
}

// An offline path query (u, v); `id` is its position in the input.
struct Query{
    Node *u;
    Node *v;
    int id;

    // Mo's order: first by the block of u, then by the preorder id of v.
    inline friend bool operator<(const Query &a, const Query &b){
        if(a.u->blockId != b.u->blockId) return a.u->blockId < b.u->blockId;
        return a.v->id < b.v->id;
    }
};
Query querys[MAXM];

// Colour counters for the currently counted vertices:
// buc[c] is how many counted vertices have colour c, and currAns is how many
// colours have a non-zero counter.
int currAns;
int buc[MAXN];

// Toggles whether v is counted, keeping buc[] and currAns in sync.
inline void flip(Node *v){
    v->in = !v->in;
    int &count = buc[v->w];
    if(!v->in){
        --count;
        if(count == 0) --currAns;
        return;
    }
    ++count;
    if(count == 1) ++currAns;
}

// Toggles every vertex on the u-v path except their LCA:
// first the u side, then the v side, each walking up until the LCA.
inline void flip(Node *u, Node *v){
    Node *const stop = queryLCA(u, v);
    Node *ends[2] = { u, v };
    for(int side = 0; side < 2; side++){
        for(Node *p = ends[side]; p != stop; p = p->fa) flip(p);
    }
}

inline void solve(){
    compress();

    build();
    std::sort(querys, querys + m);

    Node *u = nodes, *v = nodes, *lca = nodes;
    flip(nodes);
    for(Query *q = querys; q != querys + m; q++){
        flip(lca);
        lca = queryLCA(q->u, q->v);
        flip(lca);

        flip(u, q->u), u = q->u;
        flip(v, q->v), v = q->v;

        ans[q->id] = currAns;
    }
}

// Reads n, m, the n colours, n-1 edges and m queries (1-based vertices),
// solves, then prints one answer per line in input order.
int main(){
    scanf("%d%d", &n, &m);
    for(int i = 0; i < n; i++) scanf("%d", &nodes[i].w);

    for(int e = 1; e < n; e++){
        int a, b;
        scanf("%d%d", &a, &b);
        link(&nodes[a - 1], &nodes[b - 1]);
    }

    for(int i = 0; i < m; i++){
        int a, b;
        scanf("%d%d", &a, &b);
        querys[i].u = &nodes[a - 1];
        querys[i].v = &nodes[b - 1];
        querys[i].id = i;
    }

    solve();

    for(int i = 0; i < m; i++) printf("%d\n", ans[i]);

    return 0;
}

就是这样啦~