P2023 [AHOI2009] 维护序列题解

发布时间：2023-08-14 20:46:00　作者：SunnyYuan

题目描述

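（原文此处为题面截图，文字版如下）给定长度为 \(n\) 的数列 \(a_1,\dots,a_n\) 和模数 \(p\)，依次执行 \(m\) 个操作：

- `1 t g c`：将 \(a_t,\dots,a_g\) 全部乘上 \(c\)；
- `2 t g c`：将 \(a_t,\dots,a_g\) 全部加上 \(c\)；
- `3 t g`：输出 \(\left(\sum_{i=t}^{g} a_i\right) \bmod p\)。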

思路

我们可以用线段树维护区间和，

并为每个节点维护两个懒标记 \(\text{mul}\) 和 \(\text{add}\)，

表示该节点对应区间内的每个数还需要先乘上 \(\text{mul}\)、再加上 \(\text{add}\)，即 \(v \to v\cdot\text{mul}+\text{add}\)。

注意：如果一个区间需要整体乘上 \(x\)，它已有的懒标记 \(\text{add}\) 也要乘上 \(x\)，因为 \((v\cdot\text{mul}+\text{add})\cdot x = v\cdot(\text{mul}\cdot x)+\text{add}\cdot x\)。

下传标记时要按“先乘后加”的顺序更新子节点；其中的乘法结果可能超出 int 范围，需要先转成 long long 再取模。

代码

#include <bits/stdc++.h>

using namespace std;

// Capacity of the input array: positions 1..n are used, so n must stay below N.
constexpr int N = 100010;

// One segment-tree node.
//   sum      - sum of the node's range, kept modulo `mod`
//   mul, add - pending lazy tags: every element of the range still has to
//              become v * mul + add before the children are updated.
// The default tags (mul = 1, add = 0) are the identity transform.
struct edge {
	int sum{0};
	int mul{1};
	int add{0};
};

// N << 2 (= 4N) nodes for a recursive tree rooted at index 1.
edge tr[N << 2];

int n;    // sequence length
int mod;  // modulus applied to every stored value
int q;    // number of operations
int a[N]; // input sequence, 1-indexed

// Recompute node u's sum from its two children, reduced modulo `mod`.
// The children are widened to long long first so their sum cannot overflow int.
void pushup(int u) {
	const long long ls = tr[2 * u].sum;
	const long long rs = tr[2 * u + 1].sum;
	tr[u].sum = static_cast<int>((ls + rs) % mod);
}

// Apply one whole-range operation to node u, which covers [l, r].
//   type == 1: multiply every element by x. The sum and BOTH lazy tags are
//              scaled, since (v*mul + add)*x == v*(mul*x) + add*x.
//   type == 2: add x to every element. The sum grows by x * (r - l + 1) and
//              the add tag grows by x.
// Any other type leaves the node untouched.
// x is first reduced into [0, mod). C++ `%` keeps the sign of a negative
// operand, so an unreduced negative x would otherwise leave negative
// residues in the tree that could later be printed.
void addtag(int u, int l, int r, int type, int x) {
	x %= mod;
	if (x < 0) x += mod;
	if (type == 1) {
		tr[u].sum = (1ll * tr[u].sum * x) % mod;
		tr[u].add = (1ll * tr[u].add * x) % mod;
		tr[u].mul = (1ll * tr[u].mul * x) % mod;
	}
	else if (type == 2) {
		tr[u].sum = (tr[u].sum + 1ll * x * (r - l + 1)) % mod;
		tr[u].add = (1ll * tr[u].add + x) % mod;
	}
}

// Push node u's pending (mul, add) tags down to its two children, then reset
// u's tags to the identity.
// Each child undergoes v -> v * mul + add. Its sum, add tag and mul tag are
// first scaled by mul. Then its sum gains add * (child length) and its add
// tag gains add. Every product is formed in long long before reducing
// modulo `mod`.
void pushdown(int u, int l, int r) {
	// Identity tags: nothing to propagate.
	if (tr[u].mul == 1 && tr[u].add == 0) return;

	const long long k = tr[u].mul;
	const long long b = tr[u].add;
	const int mid = (l + r) >> 1;
	const int child[2] = {u << 1, u << 1 | 1};
	const long long len[2] = {mid - l + 1, r - mid};

	for (int i = 0; i < 2; ++i) {
		auto &c = tr[child[i]];
		c.add = static_cast<int>((c.add * k % mod + b) % mod);
		c.sum = static_cast<int>((c.sum * k % mod + b * len[i]) % mod);
		c.mul = static_cast<int>(c.mul * k % mod);
	}

	tr[u].mul = 1;
	tr[u].add = 0;
}

void build(int u, int l, int r) {
	if (l == r) {
		tr[u].sum = a[l];
		return;
	}
	
	int mid = (l + r) >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	pushup(u);
}

// Apply operation `type` with operand x (1 = multiply, 2 = add; see addtag)
// to every position in [pl, pr]. Node u covers [l, r].
// Fully covered nodes only receive a lazy tag. Partially covered nodes flush
// their tags, recurse into the overlapping children, and recompute their sum.
void modify(int u, int l, int r, int pl, int pr, int type, int x) {
	const bool covered = pl <= l && r <= pr;
	if (covered) {
		addtag(u, l, r, type, x);
		return;
	}

	// Children must see this node's pending tags before being modified.
	pushdown(u, l, r);

	const int mid = (l + r) >> 1;
	const int lc = u << 1;
	const int rc = lc | 1;
	if (pl <= mid) modify(lc, l, mid, pl, pr, type, x);
	if (mid < pr) modify(rc, mid + 1, r, pl, pr, type, x);

	pushup(u);
}

// Return the sum of positions [pl, pr] modulo `mod`, where node u covers [l, r].
// A fully covered node answers from its stored sum. Otherwise its pending tags
// are flushed first so the children's sums are up to date.
int query(int u, int l, int r, int pl, int pr) {
	const bool covered = pl <= l && r <= pr;
	if (covered) return tr[u].sum;

	pushdown(u, l, r);

	const int mid = (l + r) >> 1;
	long long acc = 0;
	// Reduce after each child so the running total never leaves long long range.
	if (pl <= mid) acc = (acc + query(u << 1, l, mid, pl, pr)) % mod;
	if (mid < pr) acc = (acc + query(u << 1 | 1, mid + 1, r, pl, pr)) % mod;

	return static_cast<int>(acc);
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	
	cin >> n >> mod;
	for (int i = 1; i <= n; i++) cin >> a[i];
	build(1, 1, n);
	
	cin >> q;
	int opt, a, b, c;
	while (q--) {
		cin >> opt;
		if (opt == 1 || opt == 2) {
			cin >> a >> b >> c;
			modify(1, 1, n, a, b, opt, c);
		}
		else {
			cin >> a >> b;
			cout << query(1, 1, n, a, b) << '\n';
		}
	}
	return 0;
}