ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

LOJ#6289. 花朵 树链剖分+分治NTT

2020-07-25 08:31:13  阅读:260  来源: 互联网

标签:return 6289 剖分 LOJ mid poly int len 重链


本来以为这道题会非常难调,但是没想到调了不到 5 分钟就 A 了.  

由于基于多项式的运算都可以方便地进行封装,所以细节就不是很多(或者说几乎没有细节)   

题意:给定一棵树,每个点有点权,求对于所有大小为 $m$ 的独立集的点权之积的和.     

数据范围:$n,m \leqslant 8 \times 10^4$.  

先考虑一个十分显然的 $O(n^2)$ 暴力:

令 $f[x][i],g[x][i]$ 分别表示点 $x$ 选/不选的情况下独立集大小为 $i$ 的点积 之和.  

考虑将 $x$ 与 $x$ 的一个儿子 $y$ 合并:$f[x][i+j]=f[x][i] \times f[y][j]$,$g$ 同理.  

然后 $x$ 的初始值是:$f[x][1]=w[x],g[x][0]=1$.    

树形DP 卡一下上界复杂度是 $O(n^2)$ 的.  

不难发现,上述 $f[x][i+j] = f[x][i] \times f[y][j]$ 是一个卷积的形式.  

如果是菊花图或者链的话可以直接用 NTT/分治NTT 来做.   

正解的话考虑进行轻重路径剖分:   

对于一条重链来说,先求出该重链中每个点轻儿子为根的多项式 $f,g$,然后对于重链中每个点都将其轻儿子与该点合并.   

最后对于一条重链进行分治,求出该重链链顶为根的多项式.   

分析一下时间复杂度: 

考虑一条重链链顶为根的子树会被卷多少次:其祖先中每一条重链都会将其贡献一次.  

那么树链剖分中一个点有 $O(\log n)$ 个祖先,而每次卷积的时候对链分治的复杂度是 $O(n \log^2 n)$.  

总复杂度就是 $O(n \log^3 n)$,但是由于树链剖分的常数比较小,跑的并不慢.   

code:  

#include <queue>
#include <cstdio>   
#include <vector>
#include <cstring> 
#include <algorithm>  
#define N 1000009 
#define ll long long 
#define mod 998244353 
#define pb push_back
#define setIO(s) freopen(s".in","r",stdin)  
using namespace std;  
int m; 
int A[N<<2],B[N<<2];      
int tim,edges,n; 
int size[N],son[N],top[N],hd[N],to[N<<1],nex[N<<1],fa[N],dep[N]; 
int dfn[N],bu[N],si[N],val[N];   
void add(int u,int v) { 
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;  
}
int ADD(int x,int y) { 
    return (ll)(x+y)%mod; 
}  
int DEC(int x,int y) { 
    return (ll)(x-y+mod)%mod; 
}  
int MUL(int x,int y) { 
    return (ll)x*y%mod; 
}
int qpow(int x,int y) { 
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod) {   
        if(y&1) tmp=(ll)tmp*x%mod; 
    }  
    return tmp; 
}
int get_inv(int x) { 
    return qpow(x,mod-2); 
}
void NTT(int *a,int len,int op) { 
    for(int i=0,k=0;i<len;++i) { 
        if(i>k) { 
            swap(a[i],a[k]); 
        }  
        for(int j=len>>1;(k^=j)<j;j>>=1); 
    }  
    for(int l=1;l<len;l<<=1) { 
        int wn=qpow(3,(mod-1)/(l<<1));  
        if(op==-1) wn=get_inv(wn);  
        for(int i=0;i<len;i+=l<<1) { 
            int w=1;  
            for(int j=0;j<l;++j) { 
                int x=a[i+j],y=(ll)w*a[i+j+l]%mod;  
                a[i+j]=(ll)(x+y)%mod;  
                a[i+j+l]=(ll)(x-y+mod)%mod;  
                w=(ll)w*wn%mod; 
            }
        }
    }
    if(op==-1) { 
        int iv=get_inv(len); 
        for(int i=0;i<len;++i) { 
            a[i]=(ll)a[i]*iv%mod;   
        }
    }
}
struct poly { 
    int len;
    vector<int>a;  
    poly() { len=0,a.clear(); } 
    void push(int x) { 
        a.pb(x),++len;
    }
    void resize(int x) {
        a.resize(x),len=x;    
    }                       
    poly operator*(const poly &b) const { 
        int lim;
        for(lim=1;lim<len+b.len-1;lim<<=1); 
        for(int i=0;i<lim;++i) A[i]=B[i]=0;
        for(int i=0;i<len;++i) A[i]=a[i];
        for(int i=0;i<b.len;++i) B[i]=b.a[i];
        NTT(A,lim,1),NTT(B,lim,1);
        for(int i=0;i<lim;++i) {    
            A[i]=(ll)A[i]*B[i]%mod;
        }
        NTT(A,lim,-1);
        poly c;
        for(int i=0;i<len+b.len-1;++i) { 
            c.push(A[i]); 
        }
        if(c.len>m+1) c.resize(m+1);
        return c;   
    }
    poly operator+(const poly &b) const {
        poly c; 
        c.resize(max(len,b.len));  
        for(int i=0;i<c.len;++i) c.a[i]=0; 
        for(int i=0;i<c.len;++i) {    
            if(i<len) c.a[i]=ADD(c.a[i],a[i]); 
            if(i<b.len) c.a[i]=ADD(c.a[i],b.a[i]);  
        }
        return c;   
    }
    poly operator-(const poly &b) const {    
        poly c;  
        c.resize(max(len,b.len));    
        for(int i=0;i<c.len;++i) c.a[i]=0;
        for(int i=0;i<c.len;++i) { 
            if(i<len) c.a[i]=ADD(c.a[i],a[i]); 
            if(i<b.len) c.a[i]=DEC(c.a[i],b.a[i]);  
        }  
        return c;  
    }
}f0[N],f1[N],g[2][N];      
struct data {
    poly f00,f01,f10,f11;           
    data operator+(const data &b) const { 
        data c;   
        c.f00=(f01*b.f00)+(f00*(b.f00+b.f10));   
        c.f11=(f11*b.f01)+(f10*(b.f11+b.f01));    
        c.f01=(f01*b.f01)+(f00*(b.f01+b.f11));      
        c.f10=(f11*b.f00)+(f10*(b.f10+b.f00));    
        return c;  
    }
}tmp;  
void dfs1(int x,int ff) {  
    fa[x]=ff,dep[x]=dep[ff]+1,size[x]=1;  
    for(int i=hd[x];i;i=nex[i]) { 
        int y=to[i];  
        if(y==ff) continue;  
        dfs1(y,x);
        size[x]+=size[y];
        if(size[y]>size[son[x]]) son[x]=y;
    }
}
void dfs2(int x,int tp) { 
    top[x]=tp;  
    dfn[x]=++tim;
    bu[tim]=x;
    ++si[tp];  
    if(son[x]) {  
        dfs2(son[x],tp); 
    }
    for(int i=hd[x];i;i=nex[i]) {    
        if(to[i]!=fa[x]&&to[i]!=son[x]) { 
            dfs2(to[i],to[i]);  
        }
    }
}
poly calc(int l,int r,int d) {     
    if(l==r) {   
        return g[d][l];  
    }
    int mid=(l+r)>>1;  
    return calc(l,mid,d)*calc(mid+1,r,d);  
}
data solve(int l,int r) {   
    if(l==r) {      
        int u=bu[l];   
        data e;   
        e.f00=f0[u];  
        e.f11=f1[u];  
        return e;  
    }
    int mid=(l+r)>>1;       
    return solve(l,mid)+solve(mid+1,r);  
}
int main() { 
    // setIO("input");  
    int x,y,z; 
    scanf("%d%d",&n,&m);    
    for(int i=1;i<=n;++i) scanf("%d",&val[i]);
    for(int i=1;i<n;++i) {
        scanf("%d%d",&x,&y); 
        add(x,y),add(y,x); 
    }
    dfs1(1,0),dfs2(1,1);       
    for(int i=1;i<=n;++i) {
        f0[i].push(1);  
        f1[i].push(0);  
        f1[i].push(val[i]);    
    }        
    for(int i=n;i>=1;--i) {
        int p=bu[i]; 
        if(top[p]==p) {
            for(int j=dfn[p];j<=dfn[p]+si[p]-1;++j) { 
                x=bu[j];         
                int p0=0,p1=0;      
                for(int e=hd[x];e;e=nex[e]) {
                    y=to[e];  
                    if(y==son[x]||y==fa[x]) continue;            
                    g[0][++p0]=f0[y]+f1[y];   
                    g[1][++p1]=f0[y];  
                }     
                if(p0) f0[x]=calc(1,p0,0);  
                if(p1) f1[x]=f1[x]*calc(1,p1,1); 
            } 
            tmp=solve(dfn[p],dfn[p]+si[p]-1);       
            f0[p]=tmp.f01+tmp.f00;  
            f1[p]=tmp.f10+tmp.f11;             
        }
    }   
    f0[1].resize(m+1); 
    f1[1].resize(m+1);  
    printf("%d\n",(ll)(f0[1].a[m]+f1[1].a[m])%mod);  
    return 0; 
}

  

标签:return,6289,剖分,LOJ,mid,poly,int,len,重链
来源: https://www.cnblogs.com/guangheli/p/13375471.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有