题意
题目链接。
使用 coze.com 的 gpt-4 进行翻译和格式美化.
一个长度为$M$的数组$A$被称为前缀平衡,如果它满足以下条件:
- 让$S_A$表示在$A$中出现的所有元素的集合。
- 对于每个$x \in S_A$,和索引$i(1 \leq i \leq M)$,让$f_i(x)$表示$x$在$[A_1, A_2, \dots, A_i]$中出现的次数。也就是说,$f_i(x)$表示$x$在$A$的前$i$个元素中的频率。
- 然后,$A$被称为前缀平衡的如果,对于每个整数三元组$(x, y, i)$使得$x, y \in S_A$ 和 $1 \leq i \leq M$,我们有$|f_i(x) – f_i(y)| \leq 1$。
例如,数组[1, 2, 1] 和 [1, 3, 2, 3] 是前缀平衡的,但是 [2, 3, 2, 4] 不是(长度为3的前缀包含两个2且没有4)。
你将得到一个长度为$N$的数组$A$和一个整数$M$。$A$中的每个元素都在0和$M$之间。
令$K$表示$A$中0的数量。通过将$A$中的每个0替换为一个在1和$M$之间的整数,可以得到$M^K$个数组。这$M^K$个数组中有多少个是前缀平衡的?
答案可能会很大,所以找它对998244353取模的值。
解析
首先容易发现,对于一个合法的前缀平衡的数列,如果这个数列有 $m$ 种数,则这个数列中的 $a_{km+1},a_{km+2},\dots,a_{(k+1)m}]$ 这些位置,必定是每个数都出现且仅出现一次。
我们可以考虑枚举 $m$ 的数值,记已给数列中非 $0$ 数的种类为 $sum$,则 $m$ 的范围为 $\max(1,sum) \sim \min(N,M)$。但这些 $m$ 中有一些 $m$ 是不合法的。
不妨先考虑对于一个 $m$ 怎么统计答案:我们对 $0$ 做前缀和后,可以 $O(1)$ 得到每个区间的 $0$ 的数量,让母后我们对于数列中每个长度为 $m$ 的区间统计答案(略去对于每个区间怎么统计答案,可以结合代码理解),再额外处理一下最后长度不固定的那一块即可,时间复杂度为 $O(n/m)$。对所有 $m$ 求和,最终时间复杂度为 $O(n\log n)$.
关键是如何判断一个 $m$ 是否合法。发现问题可以转化为对于一个区间,如何快速判断这个区间内有无重复的数。
记 $R_i$ 表示一个最小的数满足 $i\sim R_i$ 中有重复的数。则:
- 若 $a_i =0,R_i=R_{i+1}$
- 若 $a_i \neq 0$,记 $k$ 为 $i$ 右侧第一个与 $a_i$ 相等的数(这个也可以 $O(n)$ 预处理,此处略去),则 $R_i=\min (R_{i+1},k)$。
然后对于区间 $[km+1,(k+1)m]$,我们可以通过 $R_{km+1} \le (k+1)m$ 是否成立来判断该区间内是否有相同的数。
代码
#include<bits/stdc++.h>
#define ll long long
#define int long long
#define pi pair<int,int>
#define fi first
#define se second
#define mk make_pair
#define pb push_back
using namespace std;
const int N=5e5+5;
const int mod=998244353;
int T,n,M,fac[N],a[N];
ll ksm(ll a,ll b){
ll ret=1;
while(b){
if(b&1) ret=ret*a%mod;
a=a*a%mod;
b>>=1;
}
return ret;
}
int C(int mm,int nn){
return fac[mm]*ksm(fac[nn],mod-2)%mod*ksm(fac[mm-nn],mod-2)%mod;
}
int sum,zero[N],vis[N],r[N],nxt[N];
void solve(){
scanf("%d%d",&n,&M);sum=0;
for(int i=1;i<=n;i++) nxt[i]=0;
for(int i=1;i<=M;i++) vis[i]=0;
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<=n;i++){
zero[i]=zero[i-1];
if(a[i]==0){
zero[i]++;
}
else{
if(!vis[a[i]]){
vis[a[i]]=i;
sum++;
}
else{
nxt[vis[a[i]]]=i;
vis[a[i]]=i;
}
}
}
for(int i=1;i<=n;i++){
if(!nxt[i]) nxt[i]=n+1;
}
r[n+1]=n+1;
for(int i=n;i>=1;i--){
if(a[i]==0) r[i]=r[i+1];
else r[i]=min(r[i+1],nxt[i]);
}
int ans=0;
for(int m=max(sum,1ll);m<=min(n,M);m++){
int tmp=C(M-sum,m-sum);
for(int i=m;i<=n;i+=m){
if(r[i-m+1]<=i){
tmp=0;break;
}
int k=zero[i]-zero[i-m];
tmp=tmp*fac[k]%mod;
}
if(tmp!=0&&n%m!=0){
if(r[n-n%m+1]<=n){
tmp=0;
}
else{
int k=zero[n]-zero[n-n%m];
tmp=tmp*C(m-(n%m-k),k)%mod*fac[k]%mod;
}
}
ans=(ans+tmp)%mod;
}
printf("%lld\n",ans);
}
signed main(){
fac[1]=1;fac[0]=1;
for(int i=2;i<=500000;i++) fac[i]=fac[i-1]*i%mod;
scanf("%lld",&T);
while(T--) solve();
return 0;
}