bzoj 3625 小朋友和二叉树 多项式开根

论坛 期权论坛 编程之家     
选择匿名的用户   2021-6-2 16:25   1304   0

常数大到飞起。

O(n log n) 的算法在 CF 上跑了 2000ms，也是神奇。

有空再看看怎么把常数写小一点。

NTT 做了个小优化，快了一点。

#include <iostream>  
#include <cstdio>  
#include <cstring>  
#include <cmath>  
#include <algorithm>  
#include <queue>  
#include <set>  
#include <ctime>  
#include <cstdlib>  
#include <tr1/unordered_map>  
  
using namespace std;  
using namespace std::tr1;  
  
// Capacity of every global coefficient buffer. getSqrt/getInv with length k
// touch indices up to 2k - 1 of their buffers, so the top-level length must
// stay at or below N / 2.
#define N 600020  
// NOTE(review): none of the macros below (LL, ls, rs, md, lson, rson, inf, K)
// is referenced in the code shown here; they look like leftover contest-template
// helpers. The single-letter macro K silently rewrites any later identifier
// named K — consider removing them.
#define LL long long  
#define ls (i << 1)  
#define rs (ls | 1)  
#define md ((ll + rr) >> 1)  
#define lson ll, md, ls  
#define rson md + 1, rr, rs  
#define inf 0x3f3f3f3f  
#define K 3  
  
// Modulus for all arithmetic and for the number-theoretic transform.
// 998244353 = 119 * 2^23 + 1, so roots of unity of power-of-two order exist mod P.
const int P = 998244353;  
// Generator raised to (P - 1) / 2^e in getWn() to obtain the transform's roots.
const int pRoot = 3;  
  
/// Modular exponentiation by repeated squaring.
/// @param x base
/// @param k exponent (callers always pass a non-negative value)
/// @param p modulus
/// @return x^k mod p (1 when k == 0)
int qpow(int x, int k, int p) {
    long long acc = 1;
    long long base = x;
    while (k != 0) {
        if ((k & 1) != 0) {
            acc = acc * base % p;
        }
        base = base * base % p;
        k >>= 1;
    }
    return static_cast<int>(acc);
}
// wn[e] (e = 1..21, filled by getWn) is the root of unity used by FFT's
// butterfly stage of length 2^e; higher entries are never filled.
int wn[25];  
// Scratch buffers: t1 and t2 are used by getSqrt, t3 by getInv; t4 is unused here.
int t1[N], t2[N], t3[N], t4[N];  
// Modular inverse of 2 via Fermat (2^(P-2) mod P); getSqrt uses it to halve.
int inv2 = qpow(2, P - 2, P);  
  
/// Precomputes wn[e] = pRoot^((P - 1) / 2^e) mod P for e = 1..21, i.e. the root
/// FFT() uses for its butterfly stage of length 2^e. Must run before any FFT.
void getWn() {
    for (int e = 21; e >= 1; --e) {
        // (P - 1) is positive, so the shift equals division by 2^e.
        wn[e] = qpow(pRoot, (P - 1) >> e, P);
    }
}
  
// rev: bit-reversal permutation, rebuilt by change() on every transform.
// wnPow: powers of the current stage's root, rebuilt per stage inside FFT().
int rev[N], wnPow[N];
  
/// Permutes y[0..len) into bit-reversed index order (a true bit reversal when
/// len is a power of two), storing the permutation in rev[0..len).
void change(int y[], int len) {
    const int half = len / 2;
    for (int idx = 0; idx < len; ++idx) {
        // rev[idx >> 1] is already built because idx >> 1 < idx for idx >= 1;
        // rev[0] stays 0 since the recurrence maps 0 to 0.
        const int lowBit = idx & 1;
        rev[idx] = (rev[idx >> 1] >> 1) + (lowBit ? half : 0);
        // Swap each pair only once, from its smaller index.
        if (rev[idx] > idx) {
            std::swap(y[idx], y[rev[idx]]);
        }
    }
}
  
/// In-place iterative number-theoretic transform of y[0..len) modulo P.
/// on == -1 selects the inverse transform (index reversal plus 1/len scaling);
/// any other value selects the forward transform. len is expected to be a
/// power of two; getWn() only fills roots for stage lengths up to 2^21.
void FFT(int y[], int len, int on) {
    change(y, len);
    int stage = 0;
    for (int span = 2; span <= len; span <<= 1) {
        ++stage;
        const int half = span / 2;
        // Twiddle factors for this stage, shared by every block of length `span`.
        wnPow[0] = 1;
        for (int j = 1; j < half; ++j) {
            wnPow[j] = static_cast<int>(1LL * wnPow[j - 1] * wn[stage] % P);
        }
        for (int start = 0; start < len; start += span) {
            int* const lo = y + start;
            int* const hi = lo + half;
            for (int j = 0; j < half; ++j) {
                const int even = lo[j];
                const int odd = static_cast<int>(1LL * wnPow[j] * hi[j] % P);
                // Both operands lie in [0, P), so one conditional fix-up suffices.
                int sum = even + odd;
                if (sum >= P) sum -= P;
                int diff = even - odd;
                if (diff < 0) diff += P;
                lo[j] = sum;
                hi[j] = diff;
            }
        }
    }
    if (on != -1) return;
    // Inverse = forward transform with outputs at indices i and len - i swapped,
    // then scaled by len^{-1} mod P.
    const int mid = len / 2;
    for (int i = 1; i < mid; ++i) {
        std::swap(y[i], y[len - i]);
    }
    const int invLen = qpow(len, P - 2, P);
    for (int i = 0; i < len; ++i) {
        y[i] = static_cast<int>(1LL * y[i] * invLen % P);
    }
}
  
  
/// Cyclic convolution modulo P of length len: x <- x * y.
/// Side effect: y is left holding its forward transform.
/// Not called anywhere in the code shown here.
void mul(int x[], int y[], int len) {
    FFT(x, len, 1);
    FFT(y, len, 1);
    for (int i = len - 1; i >= 0; --i) {
        x[i] = static_cast<int>(1LL * x[i] * y[i] % P);
    }
    FFT(x, len, -1);
}
  
/// Power-series inverse: on return A0[0..k) holds A^{-1} mod x^k.
/// Each level applies the Newton step B <- B * (2 - A * B), doubling the number
/// of correct coefficients. k is expected to be a power of two and A[0] must be
/// nonzero mod P. Writes A0[0..2k) (entries from k upward are not part of the
/// result) and uses the global buffer t3 as scratch.
void getInv(int A[], int A0[], int k) {
    if (k != 1) {
        const int half = k / 2;
        const int size = 2 * k;
        // A0[0..half) <- inverse to half precision.
        getInv(A, A0, half);
        // t3 <- A mod x^k, zero-padded to the transform size.
        for (int i = 0; i < size; ++i) {
            t3[i] = (i < k) ? A[i] : 0;
        }
        // Discard leftovers above the half-precision inverse before transforming.
        std::fill(A0 + half, A0 + size, 0);
        FFT(t3, size, 1);
        FFT(A0, size, 1);
        for (int i = 0; i < size; ++i) {
            const long long ab = 1LL * t3[i] * A0[i] % P;
            int correction = static_cast<int>(2 - ab);
            if (correction < 0) correction += P;
            t3[i] = correction;
            A0[i] = static_cast<int>(1LL * A0[i] * correction % P);
        }
        FFT(A0, size, -1);
        return;
    }
    // Base case: invert the constant term via Fermat's little theorem.
    A0[0] = qpow(A[0], P - 2, P);
}
  
  
/// Power-series square root: on return A0[0..k) holds sqrt(A) mod x^k.
/// Each level applies the Newton step B <- (B + A / B) / 2, starting from B = 1,
/// so A[0] is assumed to be 1. k is expected to be a power of two.
/// Writes A0[0..2k) (entries from k upward are not part of the result) and uses
/// the global buffers t1, t2 (and t3 through getInv) as scratch.
void getSqrt(int A[], int A0[], int k) {
    if (k == 1) {
        A0[0] = 1;
        return;
    }
    const int half = k / 2;
    const int size = 2 * k;
    // A0[0..half) <- square root to half precision, then clear everything above.
    getSqrt(A, A0, half);
    std::fill(A0 + half, A0 + size, 0);
    // t1 <- A0^{-1} mod x^k, zero-padded.
    getInv(A0, t1, k);
    std::fill(t1 + k, t1 + size, 0);
    // t2 <- A mod x^k, zero-padded.
    for (int i = 0; i < size; ++i) {
        t2[i] = (i < k) ? A[i] : 0;
    }
    FFT(A0, size, 1);
    FFT(t1, size, 1);
    FFT(t2, size, 1);
    // Point-wise: A0 <- A0 + A / A0 (still missing the factor 1/2).
    for (int i = 0; i < size; ++i) {
        t1[i] = static_cast<int>(1LL * t1[i] * t2[i] % P);
        int s = A0[i] + t1[i];
        if (s >= P) s -= P;
        A0[i] = s;
    }
    FFT(A0, size, -1);
    // Halve only the coefficients that form the result.
    for (int i = 0; i < k; ++i) {
        A0[i] = static_cast<int>(1LL * A0[i] * inv2 % P);
    }
}
// n: number of weights read by main; m: highest coefficient index printed.
// c: first the 0/1 indicator of weights <= m, later the final series;
// d: holds the square-root series computed in main.
int n, m, c[N], d[N];  
  
/// Ad-hoc self-check (not called from main): runs getSqrt on 1 + x with k = 2
/// and prints all 8 entries of the output buffer, each doubled mod P. Only the
/// first two printed values come from the k = 2 result; later entries are
/// scratch or untouched initial values.
void debug() {
    int poly[8] = {1, 1, 0, 0};
    int root[8] = {1, 1, 0, 1};
    getSqrt(poly, root, 2);
    for (const int coef : root) {
        printf("%d ", coef * 2 % P);
    }
    puts("");
}
  
int main() {  
    getWn();  
    scanf("%d%d", &n, &m);  
    for(int i = 1; i <= n; ++i) {  
        int v;  
        scanf("%d", &v);  
        if(v <= m) c[v] = 1;  
    }  
    int len = 1;  
    while(len <= m) len <<= 1;  
    for(int i = 0; i < len; ++i) {  
        c[i] = - 4 * c[i];  
        if(c[i] < 0) c[i] += P;  
    }  
    c[0]++;  
    getSqrt(c, d, len);  
    d[0]++; if(d[0] >= P) d[0] -= P;  
    getInv(d, c, len);  
    for(int i = 0; i < len; ++i) {  
        c[i] *= 2;  
        if(c[i] >= P) c[i] -= P;  
    }  
    for(int i = 1; i <= m; ++i) printf("%d\n", c[i]);  
    return 0;  
}  


分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:3875789
帖子:775174
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP