【题目描述】
给定一颗$n$个节点的树,树的节点标号从0开始。每个节点可以是黑色或白色。要求支持以下操作:
- 将x节点涂黑
- 查询节点x到所有黑点的距离之和
【解题思路】
首先想到,在每次询问时,把查询点到所有黑点的距离都全部计算一遍,这显然在时间上是行不通的。
想想我们每次计算的内容,有哪些地方是共通的而不是必须每次重新计算得到的。 设黑点集合$S$,我们要计算的是 $$ ans(x) = \sum_{v \in S} dist(x, v) $$ 我们知道 $$ dist(x, v) = preDist(x) + preDist(v) - 2 * preDist(LCA(x, v)) $$ 其中$preDist(x)$表示x到根节点距离。
观察上式,在我们的累加求和过程中,$preDist(x)$ 和$preDist(v)$两项的和都是可以预处理出来的,我们真正需要实时计算的,只是$preDist(LCA(x, v))$一项的和。
树链剖分,每次在将一个点染黑时,将该点到根的路径上所有边权的系数加一(初始系数为零),用线段树维护,在查询时,查询该点到根路径上的带系数边权和,即为所有 $LCA$ 的距离之和。
每次修改都影响 黑点到根的一整条路径上的信息,而询问时是查询 到根的整条路径上的信息,也就是说,每次修改对每次询问造成的实质影响是两条路径的公共部分,即 $LCA$ 到根的的路径,也就是我们需要的信息。
系数含义的实质就是该子树中的黑点个数。
【代码】
记得开long long。 树剖写了非递归,感觉也挺自然的。 线段树写的还不是很熟,要多加练习。
#include <cstdio>
#include <algorithm>
#include <queue>
#include <stack>
#include <iostream>
#define MAXN 1000020
#define int64 long long
struct Node;
struct Path;
struct Node{
Node *father;
Node *children, *next;
bool asked;
int depth, size;
Node *maxChild;
int maxDepth, pos;
Path *path;
int w;
bool blacked;
int64 dist;
Node(){
father = children = next = maxChild = NULL;
path = NULL;
asked = blacked = false;
depth = size = maxDepth = 0;
}
} vs[MAXN];
int64 n;
Node *root = vs;
inline void addChild(int64 u, int64 v){
(vs + v)->next = (vs + u)->children;
(vs + v)->father = vs + u;
(vs + u)->children = vs + v;
}
#define mid (this->l + this->r >> 1)
struct SegmentTree{
SegmentTree *lchild, *rchild;
int l, r;
int base; // base为带系数和的基
int lazy; // 区间延迟修改量
int64 sum; // 该区间带系数和
void update(){
sum = base = 0;
if(lchild) sum += lchild->sum, base += lchild->base;
if(rchild) sum += rchild->sum, base += rchild->base;
}
void pushDown(){
if(lazy){
if(lchild) lchild->lazy += lazy, lchild->sum += lazy * lchild->base;
if(rchild) rchild->lazy += lazy, rchild->sum += lazy * rchild->base;
lazy = 0;
}
}
SegmentTree(int64 l, int64 r) : l(l), r(r), base(0), lazy(0), sum(0){
if(r - l == 1) lchild = rchild = NULL;
else{
lchild = new SegmentTree(l, mid);
rchild = new SegmentTree(mid, r);
update();
}
}
void setBase(int64 pos, int64 x){
if(r - l == 1) this->base = x;
else{
if(pos < mid) lchild->setBase(pos, x);
else rchild->setBase(pos, x);
update();
}
}
void add(int l, int r, int delta){
if(this->l == l && this->r == r) lazy += delta, sum += base * delta;
else{
pushDown();
if(l < mid) lchild->add(l, std::min(mid, r), delta);
if(r > mid) rchild->add(std::max(mid, l), r, delta);
update();
}
}
int64 query(int l, int r){
if(this->l == l && this->r == r) return sum;
else{
pushDown();
int64 ans = 0;
if(l < mid) ans = ans + lchild->query(l, std::min(mid, r));
if(r > mid) ans = ans + rchild->query(std::max(l, mid), r);
return ans;
}
}
};
struct Path{
SegmentTree *S;
Node *top;
Path(Node *v){
top = v;
S = new SegmentTree(0, v->maxDepth - v->depth + 1);
}
};
inline void cut(){
std::stack<Node*> S;
for(Node *v = vs; v != vs + n; v++) v->asked = false;
root->depth = 0;
root->dist = root->w;
S.push(root);
while(!S.empty()){
Node *v = S.top();
if(!v->asked){
for(Node *vi = v->children; vi; vi = vi->next){
vi->depth = v->depth + 1;
vi->dist = v->dist + vi->w;
S.push(vi);
}
v->asked = true;
}
else{
v->size = 1;
for(Node *vi = v->children; vi; vi = vi->next){
v->size += vi->size;
if(!v->maxChild || vi->size > v->maxChild->size){
v->maxChild = vi;
}
}
if(v->maxChild) v->maxDepth = v->maxChild->maxDepth;
else v->maxDepth = v->depth;
S.pop();
}
}
std::queue<Node*> Q;
Q.push(root);
while(!Q.empty()){
Node *v = Q.front(); Q.pop();
if(v == root || v != v->father->maxChild) v->path = new Path(v), v->pos = 0;
else v->path = v->father->path, v->pos = v->father->pos + 1;
for(Node *vi = v->children; vi; vi = vi->next) Q.push(vi);
}
for(Node *v = vs; v != vs + n; v++) v->path->S->setBase(v->pos, v->w);
}
int64 allBlackDist = 0;
int allBlackCount = 0;
inline int64 query(int64 u){
Node *v = vs + u;
int dist = v->dist;
int64 lcaSum = 0;
while(v){
lcaSum += v->path->S->query(0, v->pos + 1);
v = v->path->top->father;
}
return allBlackDist + (int64) dist * allBlackCount - 2 * lcaSum;
}
inline void black(int u){
Node *v = vs + u;
if(v->blacked) return;
else v->blacked = true;
allBlackDist = allBlackDist + v->dist;
allBlackCount++;
while(v){
v->path->S->add(0, v->pos + 1, 1);
v = v->path->top->father;
}
}
int main(){
int m;
scanf("%d%d", &n, &m);
for(int i = 1; i <= n - 1; i++){
int v;
scanf("%d", &v);
addChild(v, i);
}
for(Node *v = vs + 1; v != vs + n; v++) scanf("%d", &v->w);
cut();
for(int i = 0; i < m; i++){
int opt, x;
scanf("%d%d", &opt, &x);
if(opt == 1) black(x);
else printf("%I64d\n", query(x));;
}
return 0;
}
就这样啦