AtCoder Regular Contest 170 C. Prefix Mex Sequence(dp mex性质)

题目

给定一个长为n(n<=5e3)的01字符串s,

求满足以下条件的长为n的序列a的方案数,答案对998244353取模

条件:

对于任意i∈[1,n],

如果s[i]=1,则有A[i]=mex(A[1],A[2],...,A[i-1])

如果s[i]=0,则有A[i]≠mex(A[1],A[2],...,A[i-1])

其中,mex为未出现在集合内的最小正整数

思路来源

官方题解

题解

直观地想,是dp[i][j]表示前i个mex为j的方案数,

然后注意到会选了一些大于mex的数,所以为了转移,

就想办法把这些位置记录下来,实际对mex有贡献的时候再填进去,

但是,需要处理一个用x个位置覆盖了y种值的方案数,

求这个的复杂度很难降下来,应该是O(n^3)的

考虑另辟蹊径,dp[i][j]表示前i个出现了j种不同值的方案数,

可以这么做的原因是,无论mex是何值,

s[i]=1在值种类数确定的时候转移方程是相同的,s[i]=0也是相同的

所以,相当于是把mex不同的方案聚类起来,放在一起统计了

①如果s[i]=1,代表本次一定新选了一种数

[0,m]共m+1种数,只有之前<m+1种本次才可以,

插入之前种类数里不存在的最小正整数的这个空隙,

每种方案对应的选法都是唯一的,dp[i][j]->dp[i+1][j+1]

②如果s[i]=0,本次可以新选,也可以不新选

如果不新选,之前有j种,本次挑一种,有j种方案,j*dp[i][j]->dp[i+1][j+1]

如果新选,之前有j种,

没出现的m+1-j种中,恰有一种不能选(会导致s[i]=1),其余都可以选,(m-j)*dp[i][j]->dp[i+1][j+1]

代码

// #include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<vector>
// #include<map>
// #include<random>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d
",a)
#define ptlle(a) printf("%lld
",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
const int N=5e3+10,mod=998244353;
int n,m,dp[N][N],v;//dp[i][j]表示前i个出现了j种不同的数字且s[1,i]合法的方案数,前i个最多i种
void add(int &x,int y){
    x=(x+y)%mod;
}
int main(){
    sci(n),sci(m);
    dp[0][0]=1;
    rep(i,0,n-1){
        sci(v);
        if(v==1){//[0,m] 共m+1种 只有之前<m+1种才可以 插入最小空隙 方案唯一
            rep(j,0,min(i,m)){
                add(dp[i+1][j+1],dp[i][j]);
            }
        }
        else{// 要么新增一种,要么没有新增,但是新增一种的时候有一个值不能选,也就是m+1-j种里只能选m-j种,否则s[i]=1不合法
            rep(j,0,min(i,m+1)){
                if(j<m)add(dp[i+1][j+1],1ll*(m-j)*dp[i][j]%mod); // 新增一种
                if(j)add(dp[i+1][j],1ll*j*dp[i][j]%mod);// 不新增
            }
        }
    }
    int ans=0;
    rep(i,0,min(n,m+1)){
        add(ans,dp[n][i]);
    }
    pte(ans);
    return 0;
}