#C241112B. 树上字符串
标签(Label)
- 数学
- 树上前缀和
- 容斥原理
网址(Website)
题目(Problem)
题解(Solution)
$\qquad$令 $a_x$ 表示 $x$ 对应的字符。
$\qquad$这道题是一道容斥题,由于暴力求出路径并计算时间会超,又发现 $|S|$ 很小,大概猜想出来一个 $O((n+q)|S|^2)$ 的做法(其实我根本没往这方面想)
$\qquad$由于求的是路径,考虑使用 $\text{LCA}$ 来解决这个问题,从暴力的角度来推理,我们只能先预处理出 $x\to lca$ 的路径,然后在接上 $lca\to y$ 的路径,考虑要怎么做才能在 $O(|S|^2)$ 时间内做到这个操作,令 $f_{x,l,r}\vert l\le r$ 表示从 $x$ 到根的子序列上有多少个 $S[l,r]$ ,这个东西可以通过深搜的时候预处理:枚举每一个 $S_i$ ,如果 $a_x=S_i$ ,那么当前的 $f_{x,i,i}$ 肯定为 $1$ ,对于$j\in[i+1,m]$ ,有 $f_{x,i,j}\leftarrow f_{fa,i+1,j}$ ($fa$ 是父节点)。除此之外我们发现还要记录从根节点到 $x$ 的数据,然后又惊奇的发现空间根本不够,于是干脆就用 $f_{x,r,l}\vert l\le r$ 表示从根节点到 $x$ 的序列的子序列上有多少个 $S[l,r]$ (注意是反过来的),由于当 $l=r$ 的时候和顺序没有关系,所以也不需要考虑重复的问题,转移和第一种情况类似。
$\qquad$那么就只需要对于每个询问的路径 $x\to y$ ,先求出 $lca(x,y)$ (后面就直接使用 $lca$ 表示) ,然后我们需要容斥出 $x\to lca$ 的答案和 $lca\to y$ 的答案即可。先考虑如果已经求出两个点 $x$ 和 $y$ 的 $f$ 数组,如何将这两条路接起来,假定 $lca=1$ ,对于每一个 $f_{x,1,k}$ ,可以接上每一个 $f_{y,m,k+1}|k+1\le m$ ,贡献就是相乘,即 $h_{1,m}\leftarrow f_{x,1,k}\times f_{y,k+1,n}$ 。同理我们原先计算的 $f_{x,l,r}$ 是 $x\to 1$ 的答案,也需要减少前面 $lca\to 1$ 的贡献。注意,$x$ 和 $y$ 两个点一个要减去 $lca\to 1$ 的贡献,一个减去 $fa[lca]\to 1$ 的贡献,因为 $lca$ 也有贡献。
出题者题解
算法分析
考虑把一条询问拆成 $u$ 到 $\text{lca}$ 和 $\text{lca}$ 到 $v$,前者需要求出 $S[1,i]$ 作为子序列出现了多少次,后者需要求出 $S[i,|S|]$ 作为子序列出现了多少次。两者是对称的,因此只讲前者。
对于每个 $u$,求出 $u$ 到根的路径上有多少个子序列为 $S[l,r]$,记为 $f_{u,l,r}$,这个数组容易在 $O(n|S|^2)$ 内的时间预处理。查询时考虑 $f_{u,1,i}$ 代表的 $S[1,i]$ 在 $u$ 到 $\text{lca}$ 和 $\text{lca}$ 父亲到根两条路径之间的分布。比如 $u$ 到 $\text{lca}$ 上有 $S[1,j]$,$\text{lca}$ 父亲到根有 $S[j+1,i]$,后者预处理时已经被求出来了。因此可以容斥掉这部分贡献。总时间复杂度 $O(n\log n + (n+q)|S|^2)$。
代码(Code)
40分
#include<bits/stdc++.h>
#include<vector>
#define For(i,l,r) for(int i=l;i<=r;i++)
#define Rof(i,l,r) for(int i=l;i>=r;i--)
using namespace std;
#define P pair<int,int>
#define int long long
#define x first
#define y second
inline int rd(){
char c;bool f=false;while(!isdigit(c=getchar()))f=c=='-';int x=c^48;
while(isdigit(c=getchar())){x=(((x<<2)+x)<<1)+(c^48);}return f?-x:x;
}const int inf = 0x3f3f3f3f3f3f3f3f;
const int mod = 998244353;
const int N = 201234;
bool ST;
inline void add(int &x,int y){if((x+=y)>=mod)x-=mod;}
char a[N],s[N],b[N];
vector<int> ft[N];
int n,m,q,fa[N];
struct LCA{
int f[N][32],dfn[N],idx;
int dep[N];
void dfs(int x,int F){
f[dfn[x]=++idx][0] = F, dep[x]=dep[F]+1, fa[x]=F;
for(auto y:ft[x]) if(y^F) dfs(y,x);
}
inline int cmn(int x,int y){return dfn[x]<dfn[y]?x:y;}
inline int lca(int x,int y){
if(x==y) return x;
if((x=dfn[x])>(y=dfn[y]))swap(x,y);
int d = __lg(y-x++);
return cmn(f[x][d], f[y-(1<<d)+1][d]);
}
void init(){
dfs(1,0);
For(j,1,18) for(int i=1;i+(1<<j)-1<=n;i++)
f[i][j] = cmn(f[i][j-1], f[i+(1<<(j-1))][j-1]);
}inline int dis(int x,int y){return dep[x]+dep[y]-(dep[lca(x,y)]<<1);}
}G;
int f[36];
void Solve(){
n=rd(),q=rd();For(i,1,n-1){
int x=rd(),y=rd();
ft[x].emplace_back(y);
ft[y].emplace_back(x);
}scanf("%s%s",a+1,s+1);
m=strlen(s+1);
G.init();
while(q--){
int x=rd(),y=rd(),F=G.lca(x,y);
int d = G.dep[x]+G.dep[y]-(G.dep[F]<<1)+1;
b[G.dep[x]-G.dep[F]+1] = a[F];
for(int i=1;x^F;x=fa[x],i++) b[i] = a[x];
for(int i=d;y^F;y=fa[y],i--) b[i] = a[y];
memset(f,0,sizeof f),
f[0] = 1;
For(i,1,d) Rof(j,m,1){
if(b[i]==s[j])
add(f[j], f[j-1]);
}
printf("%lld\n",f[m]);
}
}
bool ED;
signed main(){
cerr<<abs(&ST-&ED)/1024./1024.<<"MB\n";
freopen("treestr.in","r",stdin);
freopen("treestr.out","w",stdout);
int Tt=1;double Tim=clock();
while(Tt--) Solve();
return cerr<<"TIME:"<<(clock()-Tim)/CLOCKS_PER_SEC,0;
}
100分
#include<bits/stdc++.h>
#include<vector>
#define For(i,l,r) for(int i=l;i<=r;i++)
#define Rof(i,l,r) for(int i=l;i>=r;i--)
using namespace std;
#define P pair<int,int>
#define x first
#define y second
inline int rd(){
char c;bool f=false;while(!isdigit(c=getchar()))f=c=='-';int x=c^48;
while(isdigit(c=getchar())){x=(((x<<2)+x)<<1)+(c^48);}return f?-x:x;
}const int inf = 0x3f3f3f3f;
const int mod = 998244353;
const int N = 101234;
bool ST;
inline void add(int &x,int y){if((x+=y)>=mod)x-=mod;}
inline void sub(int &x,int y){if((x-=y)<0)x+=mod;}
char a[N],s[N];
vector<int> ft[N];
int n,m,q,fa[N];
int g[N][32][32];
struct LCA{
int f[N][32],dfn[N],idx;
void dfs(int x,int F){
f[dfn[x]=++idx][0] = F, fa[x]=F;
memcpy(g[x],g[F],sizeof g[x]);
Rof(i,m,1){
if(a[x]^s[i]) continue;
For(j,1,i-1) add(g[x][i][j], g[F][i-1][j]);
add(g[x][i][i], 1);
For(j,i+1,m) add(g[x][i][j], g[F][i+1][j]);
}
for(auto y:ft[x]) if(y^F) dfs(y,x);
}
inline int cmn(int x,int y){return dfn[x]<dfn[y]?x:y;}
inline int lca(int x,int y){
if(x==y) return x;
if((x=dfn[x])>(y=dfn[y]))swap(x,y);
int d = __lg(y-x++);
return cmn(f[x][d], f[y-(1<<d)+1][d]);
}
void init(){
dfs(1,0);
For(j,1,18) for(int i=1;i+(1<<j)-1<=n;i++)
f[i][j] = cmn(f[i][j-1], f[i+(1<<(j-1))][j-1]);
}
}G;
int h[2][32];
void Solve(){
n=rd(),q=rd();For(i,1,n-1){
int x=rd(),y=rd();
ft[x].emplace_back(y);
ft[y].emplace_back(x);
}scanf("%s%s",a+1,s+1), m=strlen(s+1);
G.init();
h[0][0] = h[1][m+1] = 1;
while(q--){
int x=rd(),y=rd(),lca=G.lca(x,y);
For(i,1,m){
h[0][i] = g[x][1][i];
For(j,0,i-1)
sub(h[0][i], 1ll*h[0][j]*g[fa[lca]][j+1][i]%mod);
}
Rof(i,m,1){
h[1][i] = g[y][m][i];
Rof(j,m+1,i+1)
sub(h[1][i], 1ll*h[1][j]*g[lca][j-1][i]%mod);
}
int ans = 0;
For(i,0,m) add(ans, 1ll*h[0][i]*h[1][i+1]%mod);
printf("%d\n",ans);
}
}
bool ED;
signed main(){
cerr<<abs(&ST-&ED)/1024./1024.<<"MB\n";
freopen("treestr.in","r",stdin);
freopen("treestr.out","w",stdout);
int Tt=1;double Tim=clock();
while(Tt--) Solve();
return cerr<<"TIME:"<<(clock()-Tim)/CLOCKS_PER_SEC,0;
}