Description
给定一个整数集合 \(c\),对于每个 \(i\in[1,m]\),求有多少种不同的带点权的二叉树使得这棵树点权和为 \(i\) 并且顶点的点权全部在集合 \(c\) 中。\(m\leq 10^5\)。
Solution
设 \(f[i]\) 为点权为 \(i\) 的二叉树的方案, \(c[i]\) 表示 \(i\) 是否在集合 \(c\) 中。
所以 \(f[i]=\sum\limits_{j=1}^{i} c[j]\cdot\sum\limits_{p=0}^{i-j}f[p]\cdot \sum\limits_{k=0}^{i-j-p}f[k],f[0]=1\)
发现这是个卷积形式,也就是说 \(f[i+j+k]=c[i]\cdot f[j]\cdot f[k]\),即 \(f=c\times f\times f\)。
解一下方程,\(f=\frac{1\pm \sqrt{1-4c}}{2c}\)
然而 \(c\) 的常数项为 \(0\),所以不能求逆。尝试分子有理化解得 \(f=\frac2{1\pm\sqrt{1-4c}}\)
当 \(x=0\) 时,\(c=0,f=1\),所以分母只能取正号。
求逆+开根即可。
Code
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=4e5+5;
const int mod=998244353;
#define inv(x) ksm(x,mod-2)
int lim,rev[N],f[N];
int n,m,tmpa[N],c[N];
int a[N],b[N],tmpb[N];
int ksm(int a,int b,int ans=1){
while(b){
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;b>>=1;
} return ans;
}
void ntt(int *f,int opt){
for(int i=0;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int mid=1;mid<lim;mid<<=1){
int tmp=ksm(3,(mod-1)/(mid<<1));
if(opt<0) tmp=inv(tmp);
for(int R=mid<<1,j=0;j<lim;j+=R){
int w=1;
for(int k=0;k<mid;k++,w=1ll*w*tmp%mod){
int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
}
}
} if(opt<0)
for(int in=inv(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
}
void get(int n){
lim=1;while(lim<=n) lim<<=1;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
}
void solveinv(int len,int *a,int *b){
if(len==1) return b[0]=inv(a[0]),void();
solveinv(len>>1,a,b);
get(len);
for(int i=0;i<len;i++) tmpa[i]=a[i];
ntt(tmpa,1),ntt(b,1);
for(int i=0;i<lim;i++) b[i]=1ll*b[i]*(2ll-1ll*tmpa[i]*b[i]%mod+mod)%mod;
ntt(b,-1);
for(int i=len;i<lim;i++) b[i]=0;
for(int i=0;i<lim;i++) tmpa[i]=0;
}
void solvesqr(int len,int *a,int *b){
if(len==1) return b[0]=1,void();
solvesqr(len>>1,a,b);
solveinv(len,b,tmpb);
get(len);
for(int i=0;i<len;i++) tmpa[i]=a[i];
ntt(tmpb,1),ntt(tmpa,1);
for(int i=0;i<lim;i++) tmpa[i]=1ll*tmpa[i]*tmpb[i]%mod;
ntt(tmpa,-1);
for(int i=0,inv2=mod+1>>1;i<lim;i++) b[i]=1ll*(tmpa[i]+b[i])%mod*inv2%mod;
for(int i=len;i<lim;i++) b[i]=0;
for(int i=0;i<lim;i++) tmpa[i]=tmpb[i]=0;
}
int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
}
signed main(){
n=getint(),m=getint();
for(int i=1;i<=n;i++){
int x=getint();
a[x]=1;
}
for(int i=1;i<=m;i++) if(a[i]) a[i]=mod-4;
get(m);a[0]=1;solvesqr(lim,a,b);
get(m);b[0]++;solveinv(lim,b,c);
for(int i=1;i<=m;i++) printf("%lld\n",2ll*c[i]%mod);
return 0;
}