CF547E Mike and Friends题解

发布时间 2023-06-08 20:02:59作者: 霜木_Atomic

题目链接

温馨提示:做本题之前可以先尝试这个:洛谷 P2414 阿狸的打字机(是简单版的uwu)。

首先,这个题涉及多模式串匹配,首先想 AC 自动机。但是有个问题:我们如何去计算一个串出现的次数呢?

我们先考虑查询一个串 \(a\) 在串 \(b\) 中出现的次数。首先,在 AC 自动机上有一个性质,就是如果从某一点 \(u\)\(fail\) 往回跳的时候,会直接或间接经过另一个串 \(s\) 的结束位置,那这个点上一定有 \(s\) 串。这个结论很显然,因为 \(fail\) 跳的是最长相同后缀,那 \(s\) 一定是以 \(u\) 为结束点的一个串的后缀。我们首先想到一种暴力,那就是在 Trie 树上往下走 \(b\) 串,然后每走到一个点都跳 \(fail\),看 \(a\) 的结束位置能被经过几次。但是这样显然复杂的会炸,更别提区间查了。怎么办呢?

我们又发现一个性质。\(fail\) 每次只会往深度更小的点跳,这意味着,\(fail\) 的指针会形成一个树的结构!没错!我们可以建立一个 \(fail\) 树。那刚才的问题就可以转变为,以 \(a\) 串的结束位置为根的子树上,有多少个节点属于 \(b\) 串。这个可以通过dfn序和树状数组直接维护(想写线段树没人拦你)。每次将 \(b\) 上的节点赋成 \(1\),然后直接查询即可。查完之后,再赋成 \(0\)(这就是阿狸的打字机的做法awa)

我们来回到这个题。这个题要求区间查。我们首先来想,如果只查一个区间,那和上一道题类似,将整个区间的字符串的每个点都加上出现次数,然后树状数组查。但是这个是一堆区间。这时候,我们有一个常用的套路(在另一个题里也有用到(这道题)。那就是,既然是求区间和,那我们可以利用前缀和的性质,分别求出 \([1, l-1]\)\([1, r]\) 的答案,然后相减即可。

然鹅如果暴力做的话还是会不断加不断删。那我们可以考虑把询问离线下来,把位置按大小排序,这样,可以保证每个字符串只加不减,而且只会加一次。这样,这道题就可以做了。复杂度 \(O(q\log q+\max(q, \lvert s \rvert) \log \lvert s \rvert)\)

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e5+100, M = 5e5+100;

//快读
inline int read(){
    int x = 0;char ch = getchar();
    while(ch<'0' || ch>'9'){
        ch = getchar();
    }
    while(ch>='0'&&ch<='9'){
        x = x*10+ch-48;
        ch = getchar();
    }
    return x;
}

//AC自动机
struct Trie{
    int son[26];
}t[N];
int ps[N];
int idx = 1;//因为要建树,所以最好从1开始加点。
string ss[N];//因为后面还要用到原来的字符串,所以这里用string来存一下。
void insert(string s, int id){
    int lth = s.length();
    int u = 1;
    int v ;
    for(int i = 0; i<lth; i++){
        v = s[i]-'a';
        if(!t[u].son[v]){
            t[u].son[v] = ++idx;
        }
        u = t[u].son[v];
    }
    ps[id] = u;
}

queue<int> q;
int fail[N];

struct node{
    int nxt, to;
}edge[N];
int head[N], tot;
void add(int u, int v){
    edge[++tot].nxt = head[u];
    edge[tot].to = v;
    head[u] = tot;
}
void build(){
    int u = 1;
    for(int i = 0; i<26; ++i){
        if(t[u].son[i]){
            q.push(t[u].son[i]);
            fail[t[u].son[i]] = 1;
            add(1, t[u].son[i]);//fail树
        }
    }
    while(!q.empty()){
        u = q.front();
        q.pop();
        for(int i = 0; i<26; i++){
            if(t[u].son[i]){
                q.push(t[u].son[i]);
                int now = fail[u];
                while(now!=1&&!(t[now].son[i])){
                    now = fail[now];
                }
                if(t[now].son[i]){
                    now = t[now].son[i];
                }
                fail[t[u].son[i]] = now;
                add(now, t[u].son[i]);//fail树
            }
        }
    }
}
//fail树
int dfn[N], cntd;
int siz[N];
void dfs_fail(int u){
    siz[u] = 1;
    dfn[u] = ++cntd;
    for(int i = head[u]; i; i = edge[i].nxt){
        int v = edge[i].to;
        dfs_fail(v);
        siz[u]+=siz[v];
    }
}
//树状数组
int tc[N];
inline int lowbit(int x){
    return x&(-x);
}
void ins(int x, int pos){
    for(int i = pos; i<=cntd; i+=lowbit(i)){
        tc[i]+=x;
    }
}
int query(int pos){
    int sum = 0;
    for(int i = pos; i; i-=lowbit(i)){
        sum+=tc[i];
    }
    return sum;
}

char s[N];
int n, Q;
struct question{
    int p, k, id;
    bool lr;
    bool operator < (const question &b) const{
        return p<b.p;
    }
}qu[M<<1];
int totq;
int ans[M];

//添加字符串
void adds(int id){
    int u = 1, lth = ss[id].length();
    for(int i = 0; i<lth; ++i){
        int v = ss[id][i]-'a';
        ins(1, dfn[t[u].son[v]]);
        u = t[u].son[v];
    }
}
int main(){
    n = read(), Q = read();
    for(int i = 1 ;i<=n; i++){
        cin >> ss[i];
        insert(ss[i], i);
    }
    int K, l, r;
    for(int i = 1; i<=Q; ++i){
        l = read(), r = read(), K = read();
        qu[++totq].p = l-1 ,qu[totq].id = i, qu[totq].k = K,qu[totq].lr = 0;
        qu[++totq].p = r, qu[totq].id = i, qu[totq].k = K, qu[totq].lr = 1;
    }
    build();
    dfs_fail(1);
    sort(qu+1, qu+totq+1); 
    int now = 0;
    for(int i = 1; i<=totq; i++){
        while(now<qu[i].p){
            adds(++now);
        }
        int tk = qu[i].k;
        int u = ps[tk];
        if(qu[i].lr){//如果是右端点
            int tmp = query(dfn[u]+siz[u]-1)-query(dfn[u]-1);
            ans[qu[i].id] = tmp-ans[qu[i].id];
        }else{//是左端点
            ans[qu[i].id] = query(dfn[u]+siz[u]-1)-query(dfn[u]-1);
        }
    }
    for(int i = 1; i<=Q; i++){
        printf("%d\n", ans[i]);
    }
    return 0;
}