hdu-1540(线段树+区间合并)

发布时间 2023-04-07 17:58:22 作者: 魏老6

Tunnel Warfare

HDU - 1540

思路:

把没被摧毁的村庄记为1,被摧毁的村庄记为0,这个值保存在叶子节点的len中

线段树维护区间的两个信息:

pre:从区间左端点开始的最长连续1的长度

suf:以区间右端点结尾的最长连续1的长度

父节点与左右子节点的关系:

//lc为左节点,rc为右节点

1.默认情况(下面两种情况都不成立时):tr[p].pre = tr[lc].pre,tr[p].suf = tr[rc].suf

2.若左节点满1(即tr[lc].pre等于左节点的区间长度),则tr[p].pre = tr[lc].pre + tr[rc].pre;

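3.若右节点满1(即tr[rc].suf等于右节点的区间长度),则tr[p].suf = tr[lc].suf + tr[rc].suf;

查询村庄x时:若x落在左子节点的后缀1中,或落在右子节点的前缀1中,答案就是tr[lc].suf + tr[rc].pre;否则递归进入x所在的子节点,到达叶子时返回len(被摧毁则为0)。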

代码:

#define _CRT_SECURE_NO_WARNINGS 1
#include<algorithm>
#include<fstream>
#include<iostream>
#include<cstdio>
#include<deque>
#include<string>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<vector>
#include<stack>
#include<queue>
#include<map>
#include<set>
#include<bitset>
#include<unordered_map>
using namespace std;
// General-purpose constants and helpers; only N, lc, rc and the typedefs
// are used by this solution.
#define INF 2e9
#define MAXN 310000
#define N 1000010          // capacity of his[]; the segment tree tr[] uses N*4 nodes
#define M 10007
#define endl '\n'
// NOTE(review): this object-like macro replaces every later `exp` token, so a
// call such as exp(1.0) after this point expands to 1e-8(1.0) and fails to
// compile. Rename (e.g. EPS) if the math function is ever needed.
#define exp 1e-8
// Heap-style child indices of the node stored in a variable named `p` at the
// point of use. Parenthesized so they stay correct inside larger expressions
// such as `rc * 2` or `lc + 1`.
#define lc (p << 1)
#define rc (p << 1 | 1)
#define lowbit(x) ((x)&-(x))
const double pi = acos(-1.0);
typedef long long LL;
typedef unsigned long long ULL;
// Reads the next integer from stdin, skipping any non-digit characters that
// precede it. A '-' seen while skipping makes the result negative; since the
// value is returned as ULL, a negative number comes back in two's-complement
// form (cast to LL to recover it). Returns 0 if input ends before a digit.
inline ULL read() {
	ULL x = 0;
	bool negative = false;
	// ch is an int so EOF can be told apart from real characters; stopping at
	// EOF keeps the skip loop from spinning forever when input runs out.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-')
			negative = true;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ULL)(ch - '0');
		ch = getchar();
	}
	// Unsigned negation: same bit pattern the signed value would have.
	return negative ? 0 - x : x;
}
// Writes x to stdout in decimal, with no sign, padding, or trailing newline.
void print(ULL x) {
	char digits[20];   // 2^64 - 1 has exactly 20 decimal digits
	int count = 0;
	do {
		digits[count++] = (char)('0' + x % 10);
		x /= 10;
	} while (x != 0);
	// Digits were produced least-significant first; emit them in reverse.
	while (count > 0)
		putchar(digits[--count]);
}
int n, m;        // village count and remaining command count for the current case
int idx;         // number of entries in the destroyed-village stack his[1..idx]
int his[N];      // villages in the order they were destroyed
// One segment-tree node covering villages [l, r].
struct tree
{
	int l, r;    // covered interval
	int len;     // set at leaves only: 1 if the village is intact, 0 if destroyed
	int pre;     // longest run of intact villages starting at l
	int suf;     // longest run of intact villages ending at r
};
tree tr[N * 4];  // heap-indexed: the children of node p are 2p and 2p+1
void pushup(int p)
{
	int len = tr[p].r - tr[p].l + 1;
	tr[p].pre = tr[lc].pre;
	tr[p].suf = tr[rc].suf;
	if (len - (len >> 1) == tr[lc].pre) tr[p].pre = tr[lc].pre + tr[rc].pre;
	if (len >> 1 == tr[rc].suf) tr[p].suf = tr[lc].suf + tr[rc].suf;

}
// Builds the subtree rooted at p over villages [l, r], with every village intact.
void build(int p, int l, int r)
{
	tr[p].l = l;
	tr[p].r = r;
	if (l != r)
	{
		int mid = (l + r) >> 1;
		build(lc, l, mid);
		build(rc, mid + 1, r);
		pushup(p);
		return;
	}
	// Leaf: a single intact village is a run of length 1 in every sense.
	tr[p].len = 1;
	tr[p].pre = 1;
	tr[p].suf = 1;
}
// Sets village x to c (1 = intact, 0 = destroyed) inside the subtree rooted at
// p, then refreshes every node on the path from that leaf back up to p.
void update(int p, int x, int c)
{
	int node = p;
	// Walk down to the leaf for x.
	while (tr[node].l != tr[node].r)
	{
		int mid = (tr[node].l + tr[node].r) >> 1;
		node = (x <= mid) ? (node << 1) : (node << 1 | 1);
	}
	tr[node].len = tr[node].pre = tr[node].suf = c;
	// Heap indexing: ancestors are node>>1, node>>2, ...; those at or below p
	// all have index >= p, so stop once we pass above p.
	for (node >>= 1; node >= p; node >>= 1)
		pushup(node);
}
// Returns the length of the run of intact villages that contains village x
// (0 if x itself is destroyed), searching from node p downward.
int query(int p, int x)
{
	while (tr[p].l != tr[p].r)
	{
		int mid = (tr[p].l + tr[p].r) >> 1;
		int leftSuf = tr[lc].suf;
		int rightPre = tr[rc].pre;
		// x sits in the run that straddles mid when it lies in the left
		// child's intact suffix or in the right child's intact prefix.
		bool straddles = (x <= mid) ? (x > mid - leftSuf) : (x <= mid + rightPre);
		if (straddles)
			return leftSuf + rightPre;
		p = (x <= mid) ? lc : rc;
	}
	return tr[p].len;
}
// Processes test cases until input is exhausted. Each case is "n m" followed
// by m commands:
//   D x  destroy village x (and remember it in his[]),
//   Q x  print the length of the run of intact villages containing x
//        (0 if x is destroyed),
//   any other letter (R)  rebuild the most recently destroyed village.
int main()
{
	// Compare with 2, not EOF: a malformed header makes scanf return 0,
	// which would otherwise loop forever.
	while (scanf("%d%d", &n, &m) == 2)
	{
		build(1, 1, n);
		idx = 0; // his[1..idx] is a stack of destroyed villages
		while (m--)
		{
			char op;
			if (scanf(" %c", &op) != 1)
				break; // input ended mid-case; the outer scanf then stops too
			if (op == 'D')
			{
				int x;
				if (scanf("%d", &x) != 1)
					break;
				his[++idx] = x;
				update(1, x, 0);
			}
			else if (op == 'Q')
			{
				int x;
				if (scanf("%d", &x) != 1)
					break;
				printf("%d\n", query(1, x));
			}
			else if (idx > 0)
			{
				// Rebuild. With an empty history there is nothing to restore:
				// popping would read his[0] (restoring village 1 by accident)
				// and drive idx negative for later commands.
				update(1, his[idx--], 1);
			}
		}
	}
	return 0;
}