[ZJOI2008]树的统计(树链剖分,线段树)

题目

描述

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身

输入格式

输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来一行n个整数,第i个整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。

输出格式

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

输入样例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

输出样例

1
2
3
4
5
6
7
8
9
10
4
1
2
2
10
6
5
6
5
16

说明

对于100%的数据,保证$1 \leq n \leq 30000,0 \leq q \leq 200000$;中途操作中保证每个节点的权值w在$-30000$到$30000$之间。


解题思路

这是一道树链剖分的模板题了。
注意权值有可能为负数,所以求最大值时要初始化为-INF

复杂度$O(n \log ^2n)$


Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include<cstdio>
#include<algorithm>

using namespace std;

typedef long long LL;
const LL INF = 1e16;
const int N = 30005;
int n, u, v, q;
char opt[10];
LL a[N];

struct Edge{
int nxt, to;
}edge[N<<1];
int head[N], edgeNum;
void addEdge(int from, int to){
edge[++edgeNum].nxt = head[from];
edge[edgeNum].to = to;
head[from] = edgeNum;
}

int size[N], fa[N], dep[N], son[N];
void dfs1(int x, int f, int depth){
size[x] = 1, fa[x] = f, dep[x] = depth, son[x] = 0;
for(int i = head[x]; i; i = edge[i].nxt){
if(edge[i].to == f) continue;
dfs1(edge[i].to, x, depth+1);
size[x] += size[edge[i].to];
if(size[edge[i].to] > size[son[x]]) son[x] = edge[i].to;
}
}
int st[N], ed[N], belong[N], dfsClock, fun[N];
void dfs2(int x, int top){
st[x] = ++dfsClock, fun[dfsClock] = x;
belong[x] = top;
if(son[x]) dfs2(son[x], top);
for(int i = head[x]; i; i = edge[i].nxt){
if(edge[i].to == fa[x] || edge[i].to == son[x]) continue;
dfs2(edge[i].to, edge[i].to);
}
ed[x] = dfsClock;
}

#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l+tr[id].r)>>1)

struct segTree{
int l, r;
LL mx, sum;
segTree(){
l = r = 0;
sum = 0ll;
mx = -INF;
}
}tr[N<<2];

struct {
inline void pushup(int id){
tr[id].mx = max(tr[lid].mx, tr[rid].mx);
tr[id].sum = tr[lid].sum + tr[rid].sum;
}
void build(int id, int l, int r){
tr[id].l = l, tr[id].r = r;
if(tr[id].l == tr[id].r){
tr[id].mx = tr[id].sum = a[fun[l]];
return;
}
build(lid, l, mid);
build(rid, mid+1, r);
pushup(id);
}
LL query(int id, int l, int r, int k){
if(tr[id].l == l && tr[id].r == r)
return k == 0 ? tr[id].mx : tr[id].sum;
if(r <= mid) return query(lid, l, r, k);
else if(l > mid) return query(rid, l, r, k);
else return k == 0 ? max(query(lid, l, mid, k), query(rid, mid+1, r, k)) : query(lid, l, mid, k) + query(rid, mid+1, r, k);
}
void modify(int id, int pos, LL v){
if(tr[id].l == tr[id].r){
tr[id].sum = tr[id].mx = v;
return;
}
if(pos <= mid) modify(lid, pos, v);
else modify(rid, pos, v);
pushup(id);
}
}seg;

LL query(int u, int v, int k){
LL res = k == 0 ? -INF : 0;
while(belong[u] != belong[v]){
if(dep[belong[u]] < dep[belong[v]]) swap(u, v);
res = k == 0 ? max(res, seg.query(1, st[belong[u]], st[u], 0)) : res + seg.query(1, st[belong[u]], st[u], 1);
u = fa[belong[u]];
}
if(dep[u] > dep[v]) swap(u, v);
res = k == 0 ? max(res, seg.query(1, st[u], st[v], 0)) : res + seg.query(1, st[u], st[v], 1);
return res;
}

int main(){
scanf("%d", &n);
for(int i = 1; i < n; i++){
scanf("%d%d", &u, &v);
addEdge(u, v);
addEdge(v, u);
}
for(int i = 1; i <= n; i++)
scanf("%lld", &a[i]);
dfs1(1, 1, 1);
dfs2(1, 1);
seg.build(1, 1, n);
scanf("%d", &q);
while(q--){
scanf("%s%d%d", opt, &u, &v);
if(opt[0] == 'C')
seg.modify(1, st[u], 1ll*v);
else if(opt[1] == 'M')
printf("%lld\n", query(u, v, 0));
else if(opt[1] == 'S')
printf("%lld\n", query(u, v, 1));
}
return 0;
}