【题目描述】

有一棵点数为 N 的树,以点 1 为根,且树上每个点都有点权。

然后有 M 个操作,分为三种:

  • 操作 1 :把某个节点 x 的点权增加 a 。
  • 操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
  • 操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

【题目链接】

BZOJ 4034 树上操作 【HAOI 2016】

【解题思路】

模板题,拿来练习一下树剖。

简要题解:树链剖分,DFS 时优先走重儿子来求 DFS 序,这样每棵子树在序列上是连续的一段,每条重链在序列上也是连续的一段,再用线段树维护单点修改、区间修改和区间求和即可。

【AC代码】

尝试了下标记永久化的线段树。

#include <algorithm>
#include <cstdio>
#include <stack>

typedef long long int64;

// Capacity of the static node array.
const int MAXN = 100000;

// Segment tree over the half-open index range [l, r) supporting point add,
// range add and range sum. Range adds are stored as permanent tags: a tag is
// never pushed down to the children; instead every query accumulates the tags
// of the ancestors it walks through and applies them at the nodes it stops on.
struct SegMentTree{
    int l, r;                // this node covers positions [l, r)
    SegMentTree *lc, *rc;    // children covering [l, mid) and [mid, r); NULL on leaves
    int64 sum, tag;          // sum: range total including this node's own tag and below
                             // tag: amount added to every position of [l, r) at this node

    SegMentTree(int l = 0, int r = 0) : l(l), r(r), lc(NULL), rc(NULL), sum(0), tag(0) {}

    // Recomputes sum from both children plus this node's own tag contribution.
    void update(){
        const int64 own = tag * (r - l);
        sum = lc->sum + rc->sum + own;
    }

    // Recursively creates the children until every leaf covers one position.
    void build(){
        if(r - l == 1) return;

        const int m = (l + r) >> 1;
        lc = new SegMentTree(l, m);
        lc->build();
        rc = new SegMentTree(m, r);
        rc->build();

        update();
    }

    // Adds x to position i: every node on the root-to-leaf path gains x.
    void modify(int i, int x){
        SegMentTree *node = this;
        for(;;){
            node->sum += x;
            if(node->r - node->l == 1) break;

            const int m = (node->l + node->r) >> 1;
            node = (i < m) ? node->lc : node->rc;
        }
    }

    // Adds x to every position of [l, r), which must lie inside this node's range.
    void modifyAll(int l, int r, int x){
        if(l == this->l && r == this->r){
            // Exact cover: record the tag here and stop descending.
            tag += x;
            sum += (int64)(r - l) * x;
            return;
        }

        const int m = (this->l + this->r) >> 1;
        if(l < m) lc->modifyAll(l, (r < m) ? r : m, x);
        if(m < r) rc->modifyAll((l < m) ? m : l, r, x);

        update();
    }

    // Returns the sum over [l, r). extra carries the tags of all ancestors
    // already passed, which apply to every position of the queried range.
    int64 query(int l, int r, int64 extra = 0){
        if(l == this->l && r == this->r) return sum + (r - l) * extra;

        const int m = (this->l + this->r) >> 1;
        const int64 inherited = extra + tag;
        int64 total = 0;
        if(l < m) total += lc->query(l, (r < m) ? r : m, inherited);
        if(m < r) total += rc->query((l < m) ? m : l, r, inherited);
        return total;
    }
} *segMentTree;

// A tree vertex together with the bookkeeping filled in by cut().
struct Node{
    struct Edge *edges;  // head of this vertex's adjacency list (NULL when empty)

    bool pushed;         // traversal flag used by the stack-based passes in cut()
    Node *fa;            // parent in the rooted tree (NULL for the root)
    Node *son;           // child with the largest subtree (NULL for a leaf)
    Node *top;           // first vertex of the chain this vertex belongs to
    int w;               // initial weight read from the input
    int sz;              // number of vertices in this vertex's subtree
    int begin;           // DFS index of this vertex
    int end;             // one past the last DFS index inside this vertex's subtree
} nodes[MAXN];

// One directed entry of a vertex's singly linked adjacency list.
struct Edge{
    Node *to;    // endpoint this entry leads to
    Edge *next;  // previous head of the owner's list

    // Creates an entry fr -> to that chains onto fr's current list head.
    // The caller is responsible for storing the new entry as fr->edges.
    Edge(Node *fr, Node *to){
        this->to = to;
        this->next = fr->edges;
    }
};

// Adds an undirected edge by prepending one entry to each endpoint's list.
inline void link(Node *u, Node *v){
    Edge *forward = new Edge(u, v);
    u->edges = forward;

    Edge *backward = new Edge(v, u);
    v->edges = backward;
}

// Number of tree vertices; read in main() and used as the length of the
// active prefix of nodes[] and as the size of the segment tree.
int n;

// Heavy-light decomposition of the tree rooted at root, done with an explicit
// stack instead of recursion. Each vertex stays on the stack across two
// visits: the first visit (pushed == false) expands it, the second visit
// (pushed == true) finalizes it and pops it.
//   Pass 1 fills fa, sz and son (the child with the largest subtree).
//   Pass 2 fills begin, end and top, always descending into son first.
inline void cut(Node *root){
    std::stack<Node*> pending;
    Node *const last = nodes + n;

    // ---- Pass 1: parents, subtree sizes, heavy children ----
    for(Node *p = nodes; p != last; ++p) p->pushed = false;

    root->fa = NULL;
    pending.push(root);
    while(!pending.empty()){
        Node *cur = pending.top();

        if(cur->pushed){
            // Second visit: all children are finished, so their sizes are final.
            cur->sz = 1;
            cur->son = NULL;
            for(Edge *e = cur->edges; e != NULL; e = e->next){
                Node *child = e->to;
                if(child == cur->fa) continue;

                cur->sz += child->sz;
                if(cur->son == NULL || cur->son->sz < child->sz) cur->son = child;
            }

            pending.pop();
            continue;
        }

        // First visit: record the parent of every child and schedule it.
        for(Edge *e = cur->edges; e != NULL; e = e->next){
            Node *child = e->to;
            if(child == cur->fa) continue;

            child->fa = cur;
            pending.push(child);
        }
        cur->pushed = true;
    }

    // ---- Pass 2: DFS indices and chain heads ----
    for(Node *p = nodes; p != last; ++p) p->pushed = false;

    int clock = 0;
    pending.push(root);
    while(!pending.empty()){
        Node *cur = pending.top();

        if(cur->pushed){
            // Second visit: the whole subtree occupies [begin, clock).
            cur->end = clock;
            pending.pop();
            continue;
        }

        cur->begin = clock++;

        // A vertex starts a new chain unless it is its parent's heavy child.
        const bool startsChain = (cur == root) || (cur != cur->fa->son);
        cur->top = startsChain ? cur : cur->fa->top;

        // The heavy child is pushed last so that it is visited next, which
        // keeps every chain contiguous in DFS order.
        for(Edge *e = cur->edges; e != NULL; e = e->next){
            Node *child = e->to;
            if(child != cur->fa && child != cur->son) pending.push(child);
        }
        if(cur->son != NULL) pending.push(cur->son);

        cur->pushed = true;
    }
}

inline void build(){
    cut(nodes);

    segMentTree = new SegMentTree(0, n);
    segMentTree->build();
    for(Node *v = nodes; v != nodes + n; v++) segMentTree->modify(v->begin, v->w);
}

// Adds x to the weight of the single vertex v.
inline void modify(Node *v, int x){
    const int pos = v->begin;
    segMentTree->modify(pos, x);
}

// Adds x to the weight of every vertex in v's subtree, which occupies the
// DFS index range [v->begin, v->end).
inline void modifyAll(Node *v, int x){
    const int first = v->begin;
    const int pastLast = v->end;
    segMentTree->modifyAll(first, pastLast, x);
}

// Returns the sum of weights on the path from v up to the root. The path is
// consumed chain by chain: each step sums the DFS range from the chain head
// to the current vertex, then jumps to the parent of the chain head.
inline int64 query(Node *v){
    int64 total = 0;

    for(Node *cur = v; cur != NULL; cur = cur->top->fa){
        const int from = cur->top->begin;
        const int to = cur->begin + 1;
        total += segMentTree->query(from, to);
    }

    return total;
}

int main(){
    int m;

    scanf("%d%d", &n, &m);
    for(Node *v = nodes; v != nodes + n; v++) scanf("%d", &v->w);
    for(int i = 0, u, v; i < n - 1; i++){
        scanf("%d%d", &u, &v), u--, v--;

        link(nodes + u, nodes + v);
    }

    build();

    for(int cmd, pos, val; m--; ){
        scanf("%d%d", &cmd, &pos), pos--;

        if(cmd == 1 || cmd == 2){
            scanf("%d", &val);

            if(cmd == 1) modify(nodes + pos, val); else modifyAll(nodes + pos, val);
        } else if(cmd == 3){
            printf("%lld\n", query(nodes + pos));
        } else{
            throw;
        }
    }

    return 0;
}

就是这样咯~