$#C240726C. c
网址(Website)
题解(Solution)
考虑Kruskal算法:
- 每次找到边权最小的边,合并所在联通块。
- 使用并查集维护联通块。
考虑 $dp$ ,类似于 BFS 的思路,由于 $m\le 18$ ,因此我们可以维护 $2^m$ 个状态(转化为十进制的 $int$ 类型存储),令 $f[i][j]$ 表示当前状态为 $i$ (即每一个字符串转化为 $int$ 类型存储之后值),$j$ 表示不同元素的个数,此时状态 $i$ 有 $j$ 个差异能走到的点的编号,最初时,有 $f[s][0]=i$ (其中 $s$ 表示字符串 $s[i]$ 转化为 $int$ 的大小)。
考虑计算答案,首先枚举距离 $j$ ,枚举状态 $i\in[0,2^m-1]$ ,再枚举一个 $k$ ,表示新的状态只在第 $k$ 位上与状态 $i$ 不同,如果这两个状态都存在可通往的一个点,那么就将两个点f[i^(1<<k)][j]
和f[i][j]
加入并查集判断,两点之间的距离为 $j*2+1$ 。
考虑更新 $f[i][j]$ ,同样枚举 $k$ ,则新的状态为 f[i^(1<<k)][j+1]
,分类讨论:
- 如果
f[i^(1<<k)[j+1]]
已经存有一个点,那么直接将f[i][j]
和f[i^(1<<k)][j+1]
加入并查集判断,两点距离为 $j*2+2$ 。 - 如果
f[i^(1<<k)[j+1]]
没有点,那么f[i^(1<<k)][j+1] = f[i][j]
。
这样一来,对于每个 $j$ ,我们都计算了有贡献的答案,最后输出 $ans$ 即可。
时间复杂度:$O(\frac{2^m \times m^2}{2})$ (其实$\frac{m^2}{2}$可以看作 ${m}$)。
⭐⭐⭐
关于一些问题:
如果两个点的二进制状态相同,那最开始就应该连边,可是代码中并没有具体的体现呢?
答:由于最开始的状态会直接被覆盖掉,所以相同状态的点在计算过程中只会算作一个点加入计算,由于边长为 $0$ 的边对答案没有贡献,我们可以认为在一开始的时候就将所有距离 $0$ 的点合成为一个点,只需要计算剩下的有贡献的点即可。
为什么这里的 $j$ 要选择选择对半搜索的方式?
答:我也不知道,但是我构想了其他的方案,发现时间复杂度都不好过。
代码(Code)
#include<bits/stdc++.h>
#include<queue>
#define For(i,l,r) for(int i=l;i<=r;i++)
#define Rof(i,l,r) for(int i=l;i>=r;i--)
using namespace std;
#define int long long
#define fi first
#define se second
inline int input(){int x;return cin>>x,x;}
const int inf =0x3f3f3f3f3f3f3f3f;
const int N = 301234;
const int M = 64;
int f[N][M],a[N];//f[i][j]:i状态走j边权可以到的点的下标
int n,m,ans,cnt;
int fa[N];inline int find(int x){
if(fa[x]==x) return x;
return fa[x] = find(fa[x]);
}
inline int Solve(){
// 初始化答案和计数器
cnt = n-1; ans = 0;
For(j,0,m/2){//枚举距离
// 计算每个二进制位的连接情况
For(i,0,(1<<m)-1){//枚举状态
if(f[i][j])
For(k,0,m-1)
if(f[i^(1<<k)][j]){
int x = find(f[i^(1<<k)][j]), y = find(f[i][j]);
if(x!=y) fa[x] = y, ans += j*2+1, cnt--;
}
}
// 如果已经连接完所有点,则退出
if(!cnt) break;
// 处理下一个二进制位的连接情况
For(i,0,(1<<m)-1){//枚举状态
if(f[i][j])
For(k,0,m-1){
if(f[i^(1<<k)][j+1]){
int x = find(f[i^(1<<k)][j+1]), y = find(f[i][j]);
if(x!=y) fa[x] = y, ans += j*2+2, cnt--;
}else{
f[i^(1<<k)][j+1] = f[i][j];
}
}
}
// 如果已经连接完所有点,则退出
if(!cnt) break;
}return ans;
}
//多测清空
signed main(){
freopen("c.in","r",stdin),
freopen("c.out","w",stdout);
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
int T = input(), Tim = clock();
while(T--){
// 读取点的数量和每个点上字符串的长度
cin>>n>>m;
memset(f,0,sizeof f);
iota(fa+1,fa+n+1,1);
// 读取每个点的字符串并计算其对应的二进制表示
For(i,1,n){
char s[24];cin>>s;
a[i] = 0;
For(j,0,m-1) if(s[j]=='Y') a[i] += (1<<j);
f[a[i]][0] = i;
}
cout<<Solve()<<'\n';
}return cerr<<"TIME:"<<(clock()-Tim)/1000.,0;
}