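// The constant factor of this solution is enormous.
// Surprisingly, this O(n log n) algorithm took 2000 ms on Codeforces.
// TODO: when there is time, look into how to reduce the constant factor.
// A small optimization was made to the NTT, which made it a bit faster.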
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <set>
#include <ctime>
#include <cstdlib>
#include <tr1/unordered_map>
using namespace std;
using namespace std::tr1;
// Capacity of the global int arrays declared in this file (t1..t4, rev, wnPow, c, d).
#define N 600020
// Shorthand for long long; not referenced in the visible code.
#define LL long long
// Index/interval helpers written in terms of identifiers i, ll and rr that must exist
// at the point of use; nothing in the visible code expands them.
// NOTE(review): they look like leftover segment-tree template macros — confirm before removing.
#define ls (i << 1)
#define rs (ls | 1)
#define md ((ll + rr) >> 1)
#define lson ll, md, ls
#define rson md + 1, rr, rs
// Large sentinel value; not referenced in the visible code.
#define inf 0x3f3f3f3f
// Not referenced in the visible code.
#define K 3
// Prime modulus used for all arithmetic in this file.
const int P = 998244353;
// Base that getWn raises to (P - 1) / 2^i to obtain the transform's roots of unity
// (3 is a primitive root modulo 998244353).
const int pRoot = 3;
// Modular exponentiation by repeated squaring.
//   x: base, k: exponent (expected to be non-negative), p: modulus.
// Returns x^k mod p. The accumulator starts at 1 and is only reduced when a
// multiplication happens, so k == 0 yields 1 regardless of p.
int qpow(int x, int k, int p) {
    int result = 1;
    int base = x;
    for(int e = k; e != 0; e >>= 1) {
        // Fold the current power of the base in whenever the low exponent bit is set.
        if((e & 1) != 0) {
            result = static_cast<int>(static_cast<long long>(result) * base % p);
        }
        // Square in 64-bit so the intermediate product cannot overflow an int.
        base = static_cast<int>(static_cast<long long>(base) * base % p);
    }
    return result;
}
// wn[i] is filled by getWn() for i = 1..21 with pRoot^((P - 1) / 2^i) mod P.
int wn[25];
// Scratch buffers: getSqrt uses t1 and t2, getInv uses t3; t4 is not referenced in the visible code.
int t1[N], t2[N], t3[N], t4[N];
// Modular inverse of 2, computed as 2^(P - 2) mod P; getSqrt uses it for its final halving step.
int inv2 = qpow(2, P - 2, P);
// Fills wn[1..21]: wn[b] = pRoot^((P - 1) / 2^b) mod P.
// FFT reads these values, so this must run before the first transform.
void getWn() {
    int order = 2;  // always equals 2^bits
    for(int bits = 1; bits <= 21; ++bits, order <<= 1) {
        wn[bits] = qpow(pRoot, (P - 1) / order, P);
    }
}
// rev: bit-reversal table rebuilt by change() on every call.
// wnPow: per-stage table of root powers filled by FFT().
int rev[N], wnPow[N];

// Permutes y[0..len) into bit-reversed index order, in place.
// rev[i] is derived from rev[i >> 1], so the table is built in the same pass
// that performs the swaps; len is expected to be a power of two.
void change(int y[], int len) {
    const int topBit = len / 2;
    for(int i = 0; i < len; ++i) {
        int mirrored = rev[i >> 1] >> 1;
        if(i & 1) mirrored += topBit;
        rev[i] = mirrored;
        // Swap each pair exactly once: only when the partner lies ahead of i.
        if(mirrored > i) std::swap(y[i], y[mirrored]);
    }
}
// In-place number-theoretic transform of y[0..len) modulo P.
//   y:   coefficients, each expected to lie in [0, P)
//   len: transform length (a power of two)
//   on:  -1 selects the inverse transform; any other value runs the forward one
// Requires getWn() to have been called so that wn[] is populated.
void FFT(int y[], int len, int on) {
    change(y, len);
    int level = 0;
    for(int span = 2; span <= len; span <<= 1) {
        const int half = span / 2;
        ++level;
        // Tabulate the root powers once per stage instead of once per butterfly group.
        wnPow[0] = 1;
        for(int j = 1; j < half; ++j) {
            wnPow[j] = 1LL * wnPow[j - 1] * wn[level] % P;
        }
        for(int start = 0; start < len; start += span) {
            for(int off = 0; off < half; ++off) {
                int &lo = y[start + off];
                int &hi = y[start + off + half];
                const int u = lo;
                const int t = 1LL * wnPow[off] * hi % P;
                const int sum = u + t;
                const int diff = u - t;
                // Keep both outputs inside [0, P) with a single conditional correction.
                lo = (sum >= P) ? sum - P : sum;
                hi = (diff < 0) ? diff + P : diff;
            }
        }
    }
    if(on != -1) return;
    // Inverse transform: mirror indices 1..len-1, then scale everything by len^{-1} mod P.
    for(int a = 1, b = len - 1; a < len / 2; ++a, --b) {
        std::swap(y[a], y[b]);
    }
    const int scale = qpow(len, P - 2, P);
    for(int i = 0; i < len; ++i) {
        y[i] = 1LL * y[i] * scale % P;
    }
}
// Cyclic convolution of x and y of length len, modulo P.
// The product is left in x; y is left in its transformed state.
void mul(int x[], int y[], int len) {
    FFT(x, len, 1);
    FFT(y, len, 1);
    for(int i = 0; i < len; ++i) {
        const long long product = 1LL * x[i] * y[i];
        x[i] = product % P;
    }
    FFT(x, len, -1);
}
// Power-series inverse: writes into A0[0..k) the series B with A * B = 1 (mod x^k, mod P).
//   A:  input coefficients; only A[0..k) is read
//   A0: output buffer; indices up to 2*k - 1 are used as workspace, and
//       A0[k..2k) holds leftover values on return
//   k:  precision, a power of two
// Doubles the precision recursively with the update B <- B * (2 - A * B).
// Uses the global scratch array t3.
void getInv(int A[], int A0[], int k) {
    if(k == 1) {
        A0[0] = qpow(A[0], P - 2, P);
        return;
    }
    getInv(A, A0, k / 2);

    const int size = 2 * k;
    // t3 <- the low k coefficients of A, zero-padded to the transform length.
    for(int i = 0; i < k; ++i) t3[i] = A[i];
    std::fill(t3 + k, t3 + size, 0);
    // Discard everything in A0 above the k/2 coefficients produced by the recursion.
    std::fill(A0 + k / 2, A0 + size, 0);

    FFT(t3, size, 1);
    FFT(A0, size, 1);
    for(int i = 0; i < size; ++i) {
        int factor = 2 - 1LL * t3[i] * A0[i] % P;
        if(factor < 0) factor += P;
        t3[i] = factor;
        A0[i] = 1LL * A0[i] * factor % P;
    }
    FFT(A0, size, -1);
}
// Power-series square root: writes into A0[0..k) a series B with B * B = A (mod x^k, mod P).
//   A:  input coefficients; only A[0..k) is read. The base case sets B[0] = 1,
//       so A[0] is assumed to be 1.
//   A0: output buffer; indices up to 2*k - 1 are used as workspace, and
//       A0[k..2k) holds leftover values on return
//   k:  precision, a power of two
// Doubles the precision recursively with the update B <- (B + A / B) / 2.
// Uses the global scratch arrays t1 and t2 (and t3 through getInv).
void getSqrt(int A[], int A0[], int k) {
    if(k == 1) {
        A0[0] = 1;
        return;
    }
    getSqrt(A, A0, k / 2);

    const int size = 2 * k;
    // Keep only the k/2 coefficients produced by the recursion.
    std::fill(A0 + k / 2, A0 + size, 0);
    // t1 <- 1 / B (mod x^k); clear the leftovers getInv leaves above index k.
    getInv(A0, t1, k);
    std::fill(t1 + k, t1 + size, 0);
    // t2 <- the low k coefficients of A, zero-padded to the transform length.
    for(int i = 0; i < k; ++i) t2[i] = A[i];
    std::fill(t2 + k, t2 + size, 0);

    FFT(A0, size, 1);
    FFT(t1, size, 1);
    FFT(t2, size, 1);
    for(int i = 0; i < size; ++i) {
        const int quotient = 1LL * t1[i] * t2[i] % P;
        t1[i] = quotient;
        int total = A0[i] + quotient;
        if(total >= P) total -= P;
        A0[i] = total;
    }
    FFT(A0, size, -1);
    // Halve the low k coefficients to finish the update.
    for(int i = 0; i < k; ++i) {
        A0[i] = 1LL * A0[i] * inv2 % P;
    }
}
// n, m: the two integers main reads first (n values follow; results are printed for indices 1..m).
// c: marks which input values occur, then holds the series fed to getSqrt, and finally the answer.
// d: receives the square root computed by getSqrt.
int n, m, c[N], d[N];
// Manual sanity check: runs getSqrt on the series 1 + x at precision 2 and
// prints twice each entry of the 8-element output buffer (mod P) on one line.
// Requires getWn() to have been called beforehand.
void debug() {
    int x[8] = {1, 1, 0, 0};
    int y[8] = {1, 1, 0, 1};
    getSqrt(x, y, 2);
    for(const int *it = y; it != y + 8; ++it) {
        printf("%d ", *it * 2 % P);
    }
    puts("");
}
int main() {
getWn();
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; ++i) {
int v;
scanf("%d", &v);
if(v <= m) c[v] = 1;
}
int len = 1;
while(len <= m) len <<= 1;
for(int i = 0; i < len; ++i) {
c[i] = - 4 * c[i];
if(c[i] < 0) c[i] += P;
}
c[0]++;
getSqrt(c, d, len);
d[0]++; if(d[0] >= P) d[0] -= P;
getInv(d, c, len);
for(int i = 0; i < len; ++i) {
c[i] *= 2;
if(c[i] >= P) c[i] -= P;
}
for(int i = 1; i <= m; ++i) printf("%d\n", c[i]);
return 0;
}
|