「题目描述」

S国 有 $N$ 个 城市,编号从 $1$ 到 $N$ 。

城市间用 $N - 1$ 条双向道路连接,满足从一个城市出发可以到达其它所有城市。

每个城市信仰不同的宗教,如飞天面条神教、隐形独角兽教、绝地教都是常见的信仰,为了方便,我们用不同的正整数代表各种宗教。

S国 的居民常常旅行。旅行时他们总会走最短路,并且为了避免麻烦,只在 信仰和他们相同 的城市留宿。当然旅程的终点也是信仰与他相同的城市。

S国政府为每个城市标定了不同的旅行评级,旅行者们常会记下途中(包括起点和终点)留宿过的城市的评级 总和最大值

在S国的历史上常会发生以下几种事件:

  • CC x c :城市 x 的居民全体改信了 c 教。
  • CW x w :城市 x 的评级调整为 w。
  • QS x y :一位旅行者从城市 x 出发,到城市 y,并记下了途中留宿过的城市的评级总和。
  • QM x y :一位旅行者从城市 x 出发,到城市 y,并记下了途中留宿过的城市的评级最大值。

由于年代久远,旅行者记下的数字已经遗失了,但记录开始之前每座城市的信仰与评级,还有事件记录本身是完好的。

请根据这些信息,还原旅行者记下的数字。

为了方便,我们认为事件之间的间隔足够长,以致在任意一次旅行中,所有城市的评级和信仰保持不变。

「题目链接」

BZOJ 3531 旅行 「SDOI 2014」

「解题思路」

树链剖分,对每个宗教开一颗线段树。

修改信仰就是在原宗教的树中单点修改点权为 0,在新宗教的树中单点修改为真正的点权。

单点修改、树链查询什么的在对应的树中直接进行即可。

为保证复杂度,需要动态开点。

「AC 代码」

#include <cstdio>
#include <algorithm>
#include <stack>
#include <cassert>

template<typename T>
inline void uMax(T &x, const T &y){
    if(x < y) x = y;
}

const int MAX_N = 100000;
const int MAX_C = 100000;
const int MAX_NODE_NUM = 3400000;

struct SegMentTree{
#define mid ((l + r) >> 1)
    struct Node{
        Node *lc, *rc;
        int sum, max;

        void init(){
            lc = rc = NULL;
            sum = max = 0;
        }

        void update(){
            sum = (lc ? lc->sum : 0) + (rc ? rc->sum : 0);
            max = std::max((lc ? lc->max : 0), (rc ? rc->max : 0));
        }
    };

    static Node* newNode(){
        static Node nodes[MAX_NODE_NUM], *ptr = nodes;
        return ptr->init(), ptr++;
    }

    Node *root;
    int rootL, rootR;

    SegMentTree(int l, int r){
        rootL = l, rootR = r, root = NULL;
    }

    int querySum(Node *v, int l, int r, int qL, int qR){
        if(v == NULL) return 0;
        else if(l == qL && qR == r) return v->sum;
        else if(qR <= mid) return querySum(v->lc, l, mid, qL, qR);
        else if(qL >= mid) return querySum(v->rc, mid, r, qL, qR);
        else return querySum(v->lc, l, mid, qL, mid) + querySum(v->rc, mid, r, mid, qR);
    }

    int queryMax(Node *v, int l, int r, int qL, int qR){
        if(v == NULL) return 0;
        else if(l == qL && qR == r) return v->max;
        else if(qR <= mid) return queryMax(v->lc, l, mid, qL, qR);
        else if(qL >= mid) return queryMax(v->rc, mid, r, qL, qR);
        else return std::max(queryMax(v->lc, l, mid, qL, mid), queryMax(v->rc, mid, r, mid, qR));
    }

    void modify(Node *&v, int l, int r, int i, int x){
        if(v == NULL) v = newNode();

        if(r - l == 1){
            v->max = v->sum = x;
        } else{
            if(i < mid) modify(v->lc, l, mid, i, x); else modify(v->rc, mid, r, i, x);
            v->update();
        }
    }

    int querySum(int l, int r){
        return querySum(root, rootL, rootR, l, r);
    }

    int queryMax(int l, int r){
        return queryMax(root, rootL, rootR, l, r);
    }

    void modify(int i, int x){
        modify(root, rootL, rootR, i, x);
    }
} *segMentTree[MAX_C + 1];

struct Node{
    struct Edge *e;
    Node *f, *c, *p;
    int d, s, id;
    bool vis;

    int type, val;
} nodes[MAX_N];

inline Node* node(int i){
    return &nodes[i];
}

struct Edge{
    Node *t;
    Edge *n;

    Edge(Node *t = NULL, Edge *n = NULL) : t(t), n(n) {}
} edges[2 * (MAX_N - 1)], *ptr = edges;

inline void link(int a, int b){
    Node *u = node(a), *v = node(b);

    u->e = new (ptr++) Edge(v, u->e);
    v->e = new (ptr++) Edge(u, v->e);
}

int n, c;

inline void cut(Node *root){
    std::stack<Node*> S;

    S.push(root), root->f = NULL, root->d = 0;
    while(!S.empty()){
        Node *v = S.top();

        if(!v->vis){
            for(Edge *e = v->e; e; e = e->n) if(e->t != v->f){
                e->t->f = v, e->t->d = v->d + 1;
                S.push(e->t);
            }

            v->vis = true;
        } else{
            v->s = 1, v->c = NULL;
            for(Edge *e = v->e; e; e = e->n) if(e->t != v->f){
                v->s += e->t->s;
                if(!v->c || v->c->s < e->t->s) v->c = e->t;
            }

            S.pop();
        }
    }

    for(Node *v = node(0); v != node(n); v++) v->vis = false;

    int dfsClock = 0;
    S.push(root);
    while(!S.empty()){
        Node *v = S.top();

        if(!v->vis){
            v->id = dfsClock++;
            for(Edge *e = v->e; e; e = e->n) if(e->t != v->f && e->t != v->c) S.push(e->t);
            if(v->c) S.push(v->c);

            v->p = (!v->f || v != v->f->c) ? v : v->f->p;

            v->vis = true;
        } else{
            S.pop();
        }
    }
}

inline void build(){
    cut(nodes);

    for(int i = 1; i <= MAX_C; i++) segMentTree[i] = new SegMentTree(0, n);

    for(Node *v = node(0); v != node(n); v++) segMentTree[v->type]->modify(v->id, v->val);
}

inline void modifyType(int i, int x){
    Node *v = node(i);
    segMentTree[v->type]->modify(v->id, 0);
    segMentTree[v->type = x]->modify(v->id, v->val);
}

inline void modifyVal(int i, int x){
    Node *v = node(i);
    segMentTree[v->type]->modify(v->id, v->val = x);
}

inline int querySum(int x, int y){
    Node *u = node(x), *v = node(y);
    assert(u->type == v->type);
    SegMentTree *s = segMentTree[u->type];

    int ans = 0;
    while(u->p != v->p){
        if(u->p->d < v->p->d) std::swap(u, v);

        ans += s->querySum(u->p->id, u->id + 1);
        u = u->p->f;
    }

    if(u->d > v->d) std::swap(u, v);
    ans += s->querySum(u->id, v->id + 1);

    return ans;
}

inline int queryMax(int x, int y){
    Node *u = node(x), *v = node(y);
    assert(u->type == v->type);
    SegMentTree *s = segMentTree[u->type];

    int ans = 0;
    while(u->p != v->p){
        if(u->p->d < v->p->d) std::swap(u, v);

        uMax(ans, s->queryMax(u->p->id, u->id + 1));
        u = u->p->f;
    }

    if(u->d > v->d) std::swap(u, v);
    uMax(ans, s->queryMax(u->id, v->id + 1));

    return ans;
}

int main(){
    int q;

    scanf("%d%d", &n, &q);
    for(Node *v = node(0); v != node(n); v++) scanf("%d%d", &v->val, &v->type);
    for(int i = 0, u, v; i < n - 1; i++){
        scanf("%d%d", &u, &v), u--, v--;
        link(u, v);
    }

    build();

    static char cmd[2 + 1];
    for(int i = 0, x, y; i < q; i++){
        scanf("%s%d%d", cmd, &x, &y);

        if(cmd[0] == 'C'){
            x--;

            if(cmd[1] == 'C') modifyType(x, y);
            else if(cmd[1] == 'W') modifyVal(x, y);
            else throw;
        } else if(cmd[0] == 'Q'){
            x--, y--;

            if(cmd[1] == 'S') printf("%d\n", querySum(x, y));
            else if(cmd[1] == 'M') printf("%d\n", queryMax(x, y));
            else throw;
        } else throw;
    }

    return 0;
}

就是这样咯~