$#C240808C. 简单数论题
网址(Website)
题目(Problem)
题目描述
给定一个长度为 $ n $ 的正整数数列 $ a_1, a_2, \ldots, a_n $。
定义一个不交区间集为若干个不相交的区间组成的集合 $ [l_1, r_1], [l_2, r_2], \ldots, [l_k, r_k] $,其中 $ 1 \leq l_1 \leq r_1 < l_2 \leq r_2 < \ldots < l_k \leq r_k \leq n $。
给定一个不交区间集,首先对每一个区间计算出里面所有数字的 $ \gcd $,记作 $ \gcd_{i \in [l_j, r_j]} a_i $。我们称一个不交区间集是好的,当且仅当:
$ \gcd_{i \in [l_1, r_1]} a_i = \gcd_{i \in [l_2, r_2]} a_i = \ldots = \gcd_{i \in [l_k, r_k]} a_i $
两个不交区间集不同当且仅当里面包含的区间不同。请对于每个 $ a_i $,求出包含它的不同的好的不交区间集的个数。
输入格式
- 从文件
gcd.in读入数据。 - 输入的第一行包含一个正整数 $ n $,表示数列的长度。
- 接下来一行包含 $ n $ 个正整数,代表数列里面的每个数。
输出格式
- 输出到文件
gcd.out。 - 输出 $ n $ 行,每行一个正整数,代表包含每个数字的不同的好的不交区间集的个数。
- 答案对 $ 998244353 $ 取模。
样例
输入数据 1
3
2 3 2输出数据 1
4
4
4样例1解释
所有好的不交区间集为:
- {[1,1]}
- {[2,2]}
- {[3,3]}
- {[1,2]}
- {[1,3]}
- {[2,3]}
- {[1,1], [3,3]}
数据规模与约定
对于所有数据,有:
- $ 1 \leq n \leq 10^5 $
- $ 1 \leq a_i \leq 10^9 $
测试点
| 测试点 | $ n \leq $ | 特殊性质 |
|---|---|---|
| $1\sim6$ | $10$ | 无 |
| $7\sim12$ | $10^5$ | 所有数字均为 2 的整数次幂 |
| $13\sim20$ | $10^5$ | 无 |
题解(Solution)
你这家伙怎么自己都不会打部分分啊?(恼)(2024.08.08)
花费了大概一天的时间把这道题打了一遍但是还是有很多地方没有理解清楚。(2024.08.09)
如此看来,写题解还是很重要的,可以增进理解。(2024.08.12)
哇!题解写的什么抽象代码啊!每个线段树计算答案不统一写,害得我理解了几个小时。后面自己推了一遍,最后还是按自己的做法做出来了!(2024.08.14)
我的题解
通过 $RMQ$ 计算一段区间的 $gcd$ ,固定左端点 $i$ ,枚举 $2$ 的进制数 $j$ ,则有 fr[i][j] = __gcd(fr[i][j-1],fr[i+(1<<(j-1))][j-1]);固定右端点,同样有 fl[i][j] = gcd(fl[i][j-1],fl[i-(1<<(j-1))][j-1]); 。枚举端点 $O(n)$ ,枚举进制数 $O(logn)$ ,求 $gcd$ 复杂度 $O(loga)$ ,总复杂度 $O(nlognloga)$ 。
注:这里的 fl 和 fr 指的是固定右端点向左和固定左端点向右的意思。
性质1:集合里面加入数字后 $gcd$ 若改变则至少除 $2$ ,所以不同的 $gcd$ 数量最多 $O(nloga)$ 级别。
性质2:对于每一个点 $l$ ,以之为区间的左边界,在端点从 $r$ 转移到 $r+1$ 时,$gcd$ 要么不变,要么变化,因此,对于 $j\in [l,n]$ ,区间 $[l,j]$ 的 $gcd$ 会随着 $j$ 的增长呈阶梯式下降趋势。
令三元组 $gl[g]=\{id,l,r\}$ 表示固定当前区间的左端点 $id$ ,满足 $\forall j\in[l,r],GCD_{i=id}^{i\le j}a[i]=g$ ;同理,令三元组 $gr[g]=\{id,l,r\}$ 表示固定当前区间的右端点 $id$ ,满足 $\forall j\in[l,r],(GCD_{i=id}^{i\ge j}a[i])=g$ 。(解释:$gl$ 可以理解为左边界为 $id$ ,右边界 $j$ 满足 $j\in[l,r]$ 的所有区间 $[i,j]$ 的区间 $gcd$ 为 $g$ ,$gr$ 同理)
性质3:对于一个 $g$ ,当前点 $a_i$ 对其的贡献为包含这个点的区间的个数,同时也等于总的 $gcd=g$ 的区间的个数减去不包含这个点的区间的个数。
考虑维护不包含当前点的区间的个数,令 $f[i]$ 表示从 $1$ 到 $i$ 范围内满足 $gcd=g$ 的区间组成的方案数,$g[i]$ 表示从 $i$ 到 $n$ 范围内满足 $gcd=g$ 的区间组成的方案数,那么对于一个点 $x$ ,包含它的区间的方案数为 $f[n]-f[x-1]\times g[x+1]$ ($f[n]$ 刚好就是总方案数)。
考虑对于每个 $gcd$ 求出 $f[i]$ 和 $g[i]$ 。对于 $f[i]$ ,将 $gcd=g$ 区间(用 $gl[g]$ 存储)的右端点 $id$ 从小到大排序,对于每一个三元组 $(i,l,r)$ ,分类讨论(这里以 $f[i]$ 举例,$g[i]$ 同理):
- 对于右端点小于 $l-1$ 的区间,它们可以和所有 $j\in[l,r]$ 的区间 $[j,i]$ 组合,贡献为 $\sum_{j=0}^{l-1} (f[j]\times (r-l+1))$ ,其中 $r-l+1$ 可以提出来,只需要维护 $f[j]$ 的区间和即可。
- 对于右端点在 $[l,r]$ 之间的区间,它们只能和自己右端点右边的区间组合,贡献为 $\sum_{j=l}^{r}(f[j]\times (r-j))$ ;
- 对于右端点大于 $r$ 的点,当前区间必定相交,没有贡献。
综上,对于 $f[i]$ 和 $g[i]$ ,有:
$$
f[i]=\sum_{j=0}^{l-1} (f[j]\times (r-l+1))+\sum_{j=l}^{r}(f[j]\times (r-j))\\
g[i]=\sum_{j=r+1}^{n+1}(g[j]\times(r-l+1))+\sum_{j=l}^r(g[j]\times(j-l))
$$
$$
f[i]=(\sum_{j=0}^{l-1} f[j])\times(r-l+1) + (\sum_{j=l}^rf[j])\times r-\sum_{j=l}^{r}f[j]\times j\\
g[i]=(\sum_{j=r+1}^{n+1} g[j])\times(r-l+1) - (\sum_{j=l}^rg[j])\times l+\sum_{j=l}^rg[j]\times j
$$
分别用线段树维护 $f[i]$、$f[i]\times i$、$g[i]$、$g[i]\times i$ 即可,对于点 $i$ 的答案的贡献就是总方案数 - $(\sum_{j=0}^{i-1}f[j])\times (\sum_{j=i+1}^{n+1}g[j])$ 。
发现对于每一个 $gcd$ 都要更新一遍 $1\sim n$ 的答案,不满足时间复杂度。考虑优化。
性质4:对于排序后的三元组 $(i,l_i,r_i)$ 和 $(j,l_j,r_j)$ ,当前 $gcd$ 对 $i$ 和 $j$ 中间的数(即 $k\in[i+1,j-1]$ )的贡献是相同的。
对于每个 $gcd$ 我们都要枚举对应的三元组,两个三元组之间的数直接用前缀和记录贡献,最后计算即可。
问题1:为什么 $(\sum_{j=0}^{i-1}f[j])\times (\sum_{j=i+1}^{n+1}g[j])$ 是从 $0$ 和 $n+1$ 开始?算上 $f[0]=1$ 我可以理解,相当于 $i$ 右边的区间自由组合的方案数, $g[n+1]=1$ 算上我也可以理解,但是乘的时候算上了 $f[0]\times g[n+1]$ ,这是什么意思?左边右边都不选?那怎么保证 $gcd= g$ 呢?
解答:(2024.08.16)后面才发现,计算 $sum=\sum_{j=0}^{n+1}f[j]$ 的时候也计算了 $f[0]\times g[n+1]$ ,二者相减刚好就抵消掉了。
问题2:为什么最后代码打出来发现 $\sum_{j=0}^{n+1}f[j]$ 不等于 $\sum_{j=0}^{n+1}g[j]$ 呢?最大的问题是:令 $sum$ 等于 $\sum_{j=0}^{n+1}f[j]$ 或 $\sum_{j=0}^{n+1}g[j]$ 最后计算出来的答案是一样的,但是这两个值可能不同(详见我的代码第 $132$ 行)。

题解
集合里面加入数字后 $gcd$ 若改变则至少除 $2$ ,所以不同的 $gcd$ 数量最多 $O(nloga)$ 级别。
固定左端点,通过倍增预处理区间 $gcd$ 可以在 $O(nlognloga)$ 的时间复杂度内找出每一个 $gcd$ 发生变化的右端点,记录每一个这样的极大区间。
枚举 $gcd$,考虑以 $fi$ 代表 $1\sim i$ 中 $gcd$ 等于枚举的数的不交区间集的方案数,这是个简单的 $DP$,可以将区间按左端点排序后利用线段树优化。复杂度与区间个数有关,总复杂度均摊下来也是 $O(nlognloga)$。
对后缀按相同方式做一遍,则包含 $i$ 位置的方案数就是总数减去只在 $[1,i−1],[i+1,n]$ 中选区间的方案数,后者可以直接通过前后缀计算得到的方案数乘算获得。
代码(Code)
题解代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
struct Info{
int id, p, _p;
Info(int id = 0, int p = 0, int _p = 0): id(id), p(p), _p(_p){}
};
const int MAXN = 100000+5;
const ll mod = 998244353;
int n, cnt;
int a[MAXN];
int lx[MAXN][18], rx[MAXN][18];
ll lsum[MAXN<<2], lsum2[MAXN<<2], rsum[MAXN<<2], rsum2[MAXN<<2];
ll ans[MAXN], b[MAXN];
vector<Info> lv[MAXN*10], rv[MAXN*10];
map<int, int> mp;
int gcd(int a, int b){
return b ? gcd(b, a%b) : a;
}
int find(int g){
if(mp.find(g) == mp.end()) mp[g] = ++cnt;
return mp[g];
}
void push_up(int x){
int ls = x<<1, rs = x<<1|1;
lsum[x] = lsum[ls] + lsum[rs];
rsum[x] = rsum[ls] + rsum[rs];
lsum2[x] = lsum2[ls] + lsum2[rs];
rsum2[x] = rsum2[ls] + rsum2[rs];
}
void build(int x, int l, int r){
if(l == r){
if(l == 0){
lsum[x] = 1;
}
if(r == n+1){
rsum[x] = 1;
}
return;
}
int mid = (l+r)/2;
build(x<<1, l, mid);
build(x<<1|1, mid+1, r);
push_up(x);
}
ll qlsum(int x, int l, int r, int L, int R){
if(L <= l && r <= R){
return lsum[x];
}
int mid = (l+r)/2;
ll res = 0;
if(L <= mid){
res += qlsum(x<<1, l, mid, L, R);
}
if(R > mid){
res += qlsum(x<<1|1, mid+1, r, L, R);
}
return res;
}
ll qlsum2(int x, int l, int r, int L, int R){
if(L <= l && r <= R){
return lsum2[x];
}
int mid = (l+r)/2;
ll res = 0;
if(L <= mid){
res += qlsum2(x<<1, l, mid, L, R);
}
if(R > mid){
res += qlsum2(x<<1|1, mid+1, r, L, R);
}
return res;
}
void updl(int x, int l, int r, int p, ll v){
if(l == r){
lsum[x] = v;
lsum2[x] = v*(n-p)%mod;
return;
}
int mid = (l+r)/2;
if(p <= mid){
updl(x<<1, l, mid, p, v);
}else{
updl(x<<1|1, mid+1, r, p, v);
}
lsum[x] = lsum[x<<1] + lsum[x<<1|1];
lsum2[x] = lsum2[x<<1] + lsum2[x<<1|1];
}
ll qrsum(int x, int l, int r, int L, int R){
if(L <= l && r <= R){
return rsum[x];
}
int mid = (l+r)/2;
ll res = 0;
if(L <= mid){
res += qrsum(x<<1, l, mid, L, R);
}
if(R > mid){
res += qrsum(x<<1|1, mid+1, r, L, R);
}
return res;
}
ll qrsum2(int x, int l, int r, int L, int R){
if(L <= l && r <= R){
return rsum2[x];
}
int mid = (l+r)/2;
ll res = 0;
if(L <= mid){
res += qrsum2(x<<1, l, mid, L, R);
}
if(R > mid){
res += qrsum2(x<<1|1, mid+1, r, L, R);
}
return res;
}
void updr(int x, int l, int r, int p, ll v){
if(l == r){
rsum[x] = v;
rsum2[x] = v*p%mod;
return;
}
int mid = (l+r)/2;
if(p <= mid){
updr(x<<1, l, mid, p, v);
}else{
updr(x<<1|1, mid+1, r, p, v);
}
rsum[x] = rsum[x<<1] + rsum[x<<1|1];
rsum2[x] = rsum2[x<<1] + rsum2[x<<1|1];
}
void add(int l, int r, ll v){
b[l] += v;
b[r+1] -= v;
}
int main(){
freopen("gcd.in", "r", stdin);
freopen("gcd.out", "w", stdout);
cin.tie(0);
ios::sync_with_stdio(false);
cin >> n;
for(int i = 1; i <= n; i++)cin >> a[i];
for(int i = 1; i <= n; i++){
lx[i][0] = a[i];
rx[i][0] = a[i];
}
for(int j = 1; j < 18; j++){
for(int i = 1; i <= n; i++){
if(i-(1<<j)+1 >= 1){
lx[i][j] = gcd(lx[i][j-1], lx[i-(1<<(j-1))][j-1]);
}
if(i+(1<<j)-1 <= n){
rx[i][j] = gcd(rx[i][j-1], rx[i+(1<<(j-1))][j-1]);
}
}
}
for(int i = 1; i <= n; i++){
int p = i, g = a[i];
while(p >= 1){
int _p = p;
for(int j = 17; j >= 0; j--){
if(p-(1<<j)+1 >= 1 && gcd(g, lx[p][j]) == g){
p = p - (1<<j);
}
}
lv[find(g)].push_back(Info(i, p+1, _p));
g = gcd(g, a[p]);
}
}
for(int i = 1; i <= n; i++){
int p = i, g = a[i];
while(p <= n){
int _p = p;
for(int j = 17; j >= 0; j--){
if(p+(1<<j)-1 <= n && gcd(g, rx[p][j]) == g){
p = p + (1<<j);
}
}
rv[find(g)].push_back(Info(i, p-1, _p));
g = gcd(g, a[p]);
}
}
build(1, 0, n+1);
for(auto bp : mp){
int g = bp.first, id = bp.second;
sort(lv[id].begin(), lv[id].end(), [&](Info a, Info b){
return a.id < b.id;
});
ll sum = 1;
vector<int> vec;
for(auto info : lv[id]){
int x = info.id, p = info.p, _p = info._p;
vec.push_back(x);
ll res = qlsum(1, 0, n+1, 0, p-1)*(_p-p+1) + qlsum2(1, 0, n+1, p, _p) - qlsum(1, 0, n+1, p, _p)*(n-_p);
res = (res % mod + mod) % mod;
sum += res;
updl(1, 0, n+1, x, res);
}
sum %= mod;
sort(rv[id].begin(), rv[id].end(), [&](Info a, Info b){
return a.id > b.id;
});
for(auto info : rv[id]){
int x = info.id, p = info.p, _p = info._p;
vec.push_back(x);
ll res = qrsum(1, 0, n+1, p+1, n+1)*(p-_p+1) + qrsum2(1, 0, n+1, _p, p) - qrsum(1, 0, n+1, _p, p)*_p;
res = (res % mod + mod) % mod;
updr(1, 0, n+1, x, res);
}
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
int x = 0;
for(auto _x : vec){
ll tmp = (qlsum(1, 0, n+1, 0, _x-1)%mod)*(qrsum(1, 0, n+1, _x+1, n+1)%mod)%mod;
ans[_x] += sum-tmp;
if(_x - x > 1){
add(x+1, _x-1, (sum - (qlsum(1, 0, n+1, 0, x)%mod)*(qrsum(1, 0, n+1, _x, n+1)%mod))%mod);
}
x = _x;
}
if(x < n){
add(x+1, n, sum - qlsum(1, 0, n+1, 0, x));
}
for(auto info : lv[id]){
int x = info.id;
updl(1, 0, n+1, x, 0);
}
for(auto info : rv[id]){
int x = info.id;
updr(1, 0, n+1, x, 0);
}
}
for(int i = 1; i <= n; i++){
b[i] += b[i-1];
ans[i] += b[i];
ans[i] = (ans[i] % mod + mod) % mod;
printf("%d\n", ans[i]);
}
return 0;
}我的代码
#include<bits/stdc++.h>
#include<vector>
#include<map>
#include<set>
#define For(i,l,r) for(int i=l;i<=r;i++)
#define Rof(i,l,r) for(int i=l;i>=r;i--)
using namespace std;
#define P pair<int,int>
#define int long long
#define x first
#define y second
#define gcd __gcd
inline int rd(){int x;return cin>>x,x;}
const int inf = 0x3f3f3f3f3f3f3f3f;
const int mod = 998244353;
const int N = 201234;
bool ST;
int n,a[N],ans[N];
//用 map 存储 gcd ,离散化
map<int,int> mp;
int tot = 0;
inline int find(int x){
if(!mp.count(x)) mp[x] = ++tot;
return mp[x];
}
//预处理 RMQ
int fl[N][64],fr[N][64];//fl是向左,fr是向右
void pre_RMQ(){
For(i,1,n) fl[i][0] = fr[i][0] = a[i];
For(j,1,24) for(int i=n;i-(1<<j)+1>=1;i--)
fl[i][j] = gcd(fl[i][j-1],fl[i-(1<<(j-1))][j-1]);
For(j,1,24) for(int i=1;i+(1<<j)-1<=n;i++)
fr[i][j] = gcd(fr[i][j-1],fr[i+(1<<(j-1))][j-1]);
}
//倍增求 gcd 及其区间
struct I{int id,x,y;};
vector<I> gl[N],gr[N];
void Get(){
//gl[g]={i,l,r} 固定右端点 i ,对于 j 属于 [l,r] 满足 GCD{a[j~i]} = g
For(i,1,n){
int l = i, g = a[i];
while(l>=1){
int r = l;
Rof(j,24,0)
if(l-(1<<j)+1>=1 && gcd(g,fl[l][j])==g)
l -= (1<<j);
gl[find(g)].push_back({i,l+1,r});//注意:当前的 l 已经超出限制
g = gcd(g, a[l]);
}
}
//gr[g]={i,l,r} 固定左端点 i ,对于 j 属于 [l,r] 满足 GCD{a[i~j]} = g
For(i,1,n){
int r = i, g = a[i];
while(r<=n){
int l = r;
Rof(j,24,0)
if(r+(1<<j)-1<=n && gcd(g,fr[r][j])==g)
r += (1<<j);
gr[find(g)].push_back({i,l,r-1});//注意:当前的 r 已经超出限制
g = gcd(g ,a[r]);
}
}
}
#define mid ((l+r)>>1)
#define ls (p<<1)
#define rs (p<<1|1)
struct Segment_Tree{
struct Tr{int sum,sum2;}tr[N<<2|3];
void pushup(Tr &T,Tr L,Tr R){T = {L.sum+R.sum, L.sum2+R.sum2};}
void update(int p,int l,int r,int pos,int val){
if(l==r){
tr[p].sum2 = val*pos%mod;
tr[p].sum = val;return;}
if(pos<=mid) update(ls,l,mid,pos,val);
else update(rs,mid+1,r,pos,val);
pushup(tr[p],tr[ls],tr[rs]);
}
inline Tr ask(int p,int l,int r,int L,int R){
if(L<=l && r<=R) return tr[p];
if(R<=mid) return ask(ls,l,mid,L,R);
if(L>mid) return ask(rs,mid+1,r,L,R);
Tr T;return pushup(T,ask(ls,l,mid,L,mid),ask(rs,mid+1,r,mid+1,R)),T;
}
}T1,T2;
int c[N];//维护前缀和
void Solve(){
//建树,更新初始值(注意!!!这不是一种方案!!!)
T1.update(1,0,n+1,0,1);
T2.update(1,0,n+1,n+1,1);
//枚举 gcd O(nlogn)
for(auto pg:mp){
int id = pg.y;//当前 gcd 对应的 id
sort(gl[id].begin(),gl[id].end(),[&](I a,I b){return a.id<b.id;});//起始点从小到大
sort(gr[id].begin(),gr[id].end(),[&](I a,I b){return a.id>b.id;});//起始点从大到小
vector<int> vec;//记录出现的区间的id
for(auto p:gl[id]){//p:区间interval
int i=p.id,l=p.x,r=p.y;
vec.emplace_back(i);
int res = T1.ask(1,0,n+1,0,l-1).sum * (r-l+1)
+ T1.ask(1,0,n+1,l,r).sum * r
- T1.ask(1,0,n+1,l,r).sum2;
res = (res%mod+mod)%mod;
T1.update(1,0,n+1,i,res);
}
for(auto p:gr[id]){//p:区间interval
int i=p.id,l=p.x,r=p.y;
vec.emplace_back(i);
int res = T2.ask(1,0,n+1,r+1,n+1).sum * (r-l+1)
- T2.ask(1,0,n+1,l,r).sum * l
+ T2.ask(1,0,n+1,l,r).sum2;
res = (res%mod+mod)%mod;
T2.update(1,0,n+1,i,res);
}
//将出现的区间id排序加去重
sort(vec.begin(),vec.end());
vec.erase(unique(vec.begin(),vec.end()),vec.end());
auto add = [&](int l, int r, int v){c[l] += v, c[r+1] -= v;};//前缀和
int sum = T1.ask(1,0,n+1,0,n+1).sum;
int l = 0;
for(auto r:vec){//对 x 贡献为 f(0,x-1)*g(x+1,n+1)
int tmp = (T1.ask(1,0,n+1,0,r-1).sum%mod) * (T2.ask(1,0,n+1,r+1,n+1).sum%mod) % mod;
ans[r] += sum - tmp;//总数减去没有用到当前 x 的个数
if(r-l>1){//如果 l 和 r 之间有数字
//维护前缀和加减 (对[l+1,r-1]区间所有的点的贡献都是...(见下))
add(l+1, r-1, (sum - (T1.ask(1,0,n+1,0,l).sum%mod)*(T2.ask(1,0,n+1,r,n+1).sum%mod))%mod);
}l = r;
}if(l<n) add(l+1, n, sum - (T1.ask(1,0,n+1,0,l).sum%mod));
//清空线段树
for(auto p:gl[id]) T1.update(1,0,n+1,p.id,0);
for(auto p:gr[id]) T2.update(1,0,n+1,p.id,0);
}
}
bool ED;
//看到题目就开走,不跑样例是小狗
//不打暴力是小狗
signed main(){
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
cerr<<abs(&ST-&ED)/1024./1024.<<"MB\n";
freopen("gcd.in","r",stdin);
freopen("gcd.out","w",stdout);
cin>>n;int Tim = clock();
For(i,1,n) a[i] = rd();
pre_RMQ(),Get(),Solve();
For(i,1,n){
c[i] += c[i-1];
ans[i] += c[i];
ans[i] = (ans[i]%mod+mod)%mod;
cout<<ans[i]<<'\n';
}return cerr<<"TIME:"<<(clock()-Tim)/1000.,0;
}
