Wankupi's Website

带修标号区间子树点权和之和

一句话题意:树上点权带修,查询标号区间内每个点的子树和的和。

\[\sum_{x=l}^{r}\sum_{y\in \operatorname{subtree}(x)} w(y) \]

结点数量和操作数不大于 \(1e5\)。时限 1.5s。

60pts

观察到和式中有两个求和号,考虑交换顺序。

\[\sum_{x=l}^{r}\sum_{y\in \operatorname{subtree}(x)} w(y) = \sum_{y=1}^{n} w(y)\sum_{x=l}^{r} [y\in \operatorname{subtree}(x)] \]

发现对于左边可以快速地对标号区间进行逐个增缩, 而右边可以方便地对单点权值的变化进行改动。

因此考虑 带修莫队,同时维护 点权的树状数组 和 根链中标号在 \([l,r]\) 中的节点数量的树状数组。

复杂度约为 \(O(n^{5/3}\log n)\)。取得60分。本地满数据最优块长需要6s左右。

cpp
#include <algorithm>
#include <cstdio>

using ll = long long;
int const maxn = 100003;
int const block = 2500;

int n = 0, m = 0;

int head[maxn], nxt[maxn << 1], to[maxn << 1], cnt = 0;
void insert(int u, int e) {
	nxt[++cnt] = head[u];
	head[u] = cnt;
	to[cnt] = e;
}

struct TreeArray {
	static int lowbit(int x) { return x & -x; }
	ll tr[maxn];
	void add(int p, ll v) {
		while (p <= n) {
			tr[p] += v;
			p += lowbit(p);
		}
	}
	ll query(int p) {
		ll ret = 0;
		while (p) {
			ret += tr[p];
			p -= lowbit(p);
		}
		return ret;
	}
	void add(int l, int r, ll v) {
		add(l, v);
		add(r + 1, -v);
	}
	ll query(int l, int r) {
		return query(r) - query(l - 1);
	}
};

int w[maxn];
int dfn[maxn], cdfn = 0;
int siz[maxn];
void dfs(int x, int fa) {
	dfn[x] = ++cdfn;
	siz[x] = 1;
	for (int i = head[x]; i; i = nxt[i]) {
		if (to[i] == fa) continue;
		dfs(to[i], x);
		siz[x] += siz[to[i]];
	}
}

struct Query {
	int id;
	int l, r, t;
};
struct Change {
	int x, y, from;
};
inline bool operator<(Query const &A, Query const &B) {
	if (A.l / block != B.l / block)
		return A.l < B.l;
	if (A.r / block != B.r / block)
		return ((A.l / block) & 1) ? A.r < B.r : A.r > B.r;
	return ((A.r / block) & 1) ? A.t < B.t : A.t > B.t;
}

Change ch[maxn];
int cnt_change = 0;
Query q[maxn];
int cnt_query = 0;

TreeArray sum, chain;
ll total = 0;
inline void add(int x) {
	total += sum.query(dfn[x], dfn[x] + siz[x] - 1);
	chain.add(dfn[x], dfn[x] + siz[x] - 1, +1);
}
inline void rmv(int x) {
	total -= sum.query(dfn[x], dfn[x] + siz[x] - 1);
	chain.add(dfn[x], dfn[x] + siz[x] - 1, -1);
}
inline void timeAdd(int t) {
	total += (ch[t].y - ch[t].from) * chain.query(dfn[ch[t].x]);
	sum.add(dfn[ch[t].x], ch[t].y - ch[t].from);
}
inline void timeRmv(int t) {
	total -= (ch[t].y - ch[t].from) * chain.query(dfn[ch[t].x]);
	sum.add(dfn[ch[t].x], -ch[t].y + ch[t].from);
}

int ql = 1, qr = 0, qt = 0;

ll ans[maxn];

int main() {
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; ++i)
		scanf("%d", w + i);
	for (int i = 1; i <= n; ++i) {
		int u, v;
		scanf("%d%d", &u, &v);
		insert(u, v);
		insert(v, u);
	}
	dfs(to[head[0]], 0);
	for (int i = 1; i <= n; ++i)
		sum.tr[dfn[i]] = w[i];
	for (int i = 1; i <= n; ++i)
		sum.tr[i] += sum.tr[i - 1];
	for (int i = n; i >= 1; --i)
		sum.tr[i] -= sum.tr[i - sum.lowbit(i)];

	for (int i = 1; i <= m; ++i) {
		int opt, x, y;
		scanf("%d %d %d", &opt, &x, &y);
		if (opt == 1) {
			ch[++cnt_change] = {x, y, w[x]};
			w[x] = y;
		}
		else {
			++cnt_query;
			q[cnt_query] = {cnt_query, x, y, cnt_change};
		}
	}
	for (int i = cnt_change; i >= 1; --i)
		w[ch[i].x] = ch[i].from;

	std::sort(q + 1, q + cnt_query + 1);

	for (int i = 1; i <= cnt_query; ++i) {
		auto [id, l, r, t] = q[i];
		while (qr < r) add(++qr);
		while (ql > l) add(--ql);
		while (qr > r) rmv(qr--);
		while (ql < l) rmv(ql++);
		while (qt < t) timeAdd(++qt);
		while (qt > t) timeRmv(qt--);
		ans[id] = total;
	}
	for (int i = 1; i <= cnt_query; ++i)
		printf("%lld\n", ans[i]);
	return 0;
}

70pts

考虑对标号区间分块。显然答案在标号区间具有可加性。

在上面的等式中,右边

\[\sum_{y=1}^{n} w(y)\sum_{x=l}^{r} [y\in \operatorname{subtree}(x)] \]

告诉了我们在已知区间时,该如何处理点权的变化。

现在我们已经分好了块 \([L,R]\),那么

\[\sum_{x=L}^{R} [y\in \operatorname{subtree}(x)] \]

就给出了 \(y\) 的权值发生改变时对该块的影响系数。

每一块对每一个点都有一个系数,总共 \(n \times \frac{n}{B}\) 个。

于是预处理出这些系数,当点权修改时, 扫描每一个块,维护其答案。

那么查询时,我们将块内答案直接加和, 块外散点仍然使用树状数组求子树和。

复杂度 \(O(n\sqrt{n}\log n)\),本地满数据 3s。

cpp
#include <algorithm>
#include <cstdio>

using ll = long long;

int const maxn = 100003;
int const block = 300;
int const maxBlocks = maxn / block + 3;

inline int read() {
	int x = 0, c = getchar();
	while (c < '0' || c > '9')
		c = getchar();
	while ('0' <= c && c <= '9') {
		x = 10 * x + c - '0';
		c = getchar();
	}
	return x;
}

int n = 0, m = 0;

int L[maxBlocks], R[maxBlocks], cnt_block = 0;
int belong[maxn];

int head[maxn], nxt[maxn << 1], to[maxn << 1], cnt = 0;
void insert(int u, int e) {
	nxt[++cnt] = head[u];
	head[u] = cnt;
	to[cnt] = e;
}

int w[maxn];
int dfn[maxn], cdfn = 0;
int siz[maxn];
void dfs(int x, int fa) {
	dfn[x] = ++cdfn;
	siz[x] = 1;
	for (int i = head[x]; i; i = nxt[i]) {
		if (to[i] == fa) continue;
		dfs(to[i], x);
		siz[x] += siz[to[i]];
	}
}

int chain[maxBlocks][maxn]; // be init in dfs2

int block_id_dfs2 = 0;
int count_nodes_in_range_dfs2 = 0;
void dfs2(int x, int fa) {
	if (L[block_id_dfs2] <= x && x <= R[block_id_dfs2])
		++count_nodes_in_range_dfs2;
	chain[block_id_dfs2][x] = count_nodes_in_range_dfs2;
	for (int i = head[x]; i; i = nxt[i])
		if (to[i] != fa)
			dfs2(to[i], x);
	if (L[block_id_dfs2] <= x && x <= R[block_id_dfs2])
		--count_nodes_in_range_dfs2;
}

struct TreeArray {
	static int lowbit(int x) { return x & -x; }
	ll tr[maxn];
	void add(int p, ll v) {
		while (p <= n) {
			tr[p] += v;
			p += lowbit(p);
		}
	}
	ll query(int p) {
		ll ret = 0;
		while (p) {
			ret += tr[p];
			p -= lowbit(p);
		}
		return ret;
	}
	void add(int l, int r, ll v) {
		add(l, v);
		add(r + 1, -v);
	}
	ll query(int l, int r) {
		return query(r) - query(l - 1);
	}
};
TreeArray sum;

ll total[maxBlocks];

void prepare() {
	cnt_block = (n + block - 1) / block;
	for (int i = 1; i <= cnt_block; ++i) {
		L[i] = R[i - 1] + 1;
		R[i] = L[i] + block - 1;
		if (R[i] > n) R[i] = n;
		for (int j = L[i]; j <= R[i]; ++j)
			belong[j] = i;
	}
	dfs(to[head[0]], 0);
	for (int i = 1; i <= cnt_block; ++i) {
		block_id_dfs2 = i;
		dfs2(to[head[0]], 0);
	}
	for (int i = 1; i <= n; ++i)
		sum.tr[dfn[i]] = w[i];
	for (int i = 1; i <= n; ++i)
		sum.tr[i] += sum.tr[i - 1];

	for (int i = 1; i <= cnt_block; ++i) {
		for (int j = L[i]; j <= R[i]; ++j)
			total[i] += sum.tr[dfn[j] + siz[j] - 1] - sum.tr[dfn[j] - 1];
	}

	for (int i = n; i >= 1; --i)
		sum.tr[i] -= sum.tr[i - sum.lowbit(i)];
}

void modify(int x, int y) {
	ll delta = y - w[x];
	w[x] = y;
	sum.add(dfn[x], delta);
	for (int i = 1; i <= cnt_block; ++i)
		total[i] += delta * chain[i][x];
}

ll query(int l, int r) {
	int lid = (l + block - 1) / block, rid = (r + block - 1) / block;
	ll ret = 0;
	auto add_range = [&ret](int l, int r) {
		for (int i = l; i <= r; ++i)
			ret += sum.query(dfn[i], dfn[i] + siz[i] - 1);
	};
	if (lid == rid) {
		add_range(l, r);
		return ret;
	}
	if (l != L[lid])
		add_range(l, R[lid++]);
	if (r != R[lid])
		add_range(L[rid--], r);
	for (int i = lid; i <= rid; ++i)
		ret += total[i];
	return ret;
}

int main() {
	n = read();
	m = read();
	for (int i = 1; i <= n; ++i)
		w[i] = read();
	for (int i = 1; i <= n; ++i) {
		int u = read(), v = read();
		insert(u, v);
		insert(v, u);
	}
	prepare();
	for (int i = 1; i <= m; ++i) {
		int opt = read(), x = read(), y = read();
		if (opt == 1)
			modify(x, y);
		else
			printf("%lld\n", query(x, y));
	}
	return 0;
}

100pts

在 70 分的做法,显然根号是无法优化的。

考虑优化掉 \(\log n\)。 这个 \(\log\) 是在树状数组中取得, 我们想一想树状数组都做了什么。 首先权值修改时,进行一个单点加。 查询子树和时,进行一个区间求和,转化为前缀和相减。

这个过程也可以分块。 你可能会说,分块之后查询不是 \(\sqrt{n}\) 吗?变差了!

但是,我们可以考虑再维护 块内前缀和 与 块前缀和。 于是就可以 \(O(1)\) 查询啦! 而且修改的时候只是从 \(O(\log n + \sqrt{n})\) 变成了 \(O(\sqrt{n} + \sqrt{n})\)

本机调一下块长后,足以通过此题。

cpp
struct Array {
	ll tr[maxn]; // 历史遗留,没改
	ll S[maxBlocks];
	ll s[maxBlocks][block + 1];
	void init() {
		for (int i = 1; i <= cnt_block; ++i) {
			S[i] = tr[R[i]];
			for (int j = L[i]; j <= R[i]; ++j)
				s[i][j - L[i] + 1] = tr[j] - tr[L[i] - 1];
		}
	}
	void add(int p, ll v) {
		int pid = (p + block - 1) / block;
		for (int j = p; j <= R[pid]; ++j)
			s[pid][j - L[pid] + 1] += v;
		for (int i = pid; i <= cnt_block; ++i)
			S[i] += v;
	}
	ll query(int l, int r) {
		int lid = (l + block - 1) / block, rid = (r + block - 1) / block;
		if (lid == rid) return s[lid][r - L[lid] + 1] - s[lid][l - L[lid]];
		ll ret = 0;
		if (r != R[rid]) {
			ret += s[rid][r - L[rid] + 1];
			--rid;
		}
		if (l != L[lid]) {
			ret += s[lid][R[lid] - L[lid] + 1] - s[lid][l - L[lid]];
			++lid;
		}
		ret += S[rid] - S[lid - 1];
		return ret;
	}
};
void prepare() {
	// FROM
	// for (int i = n; i >= 1; --i)
	//  sum.tr[i] -= sum.tr[i - sum.lowbit(i)];
	// TO
	sum.init();
}


引用

此题由 xjq 学长推荐为小作业。这是他的题解链接