View on GitHub

Wenchong Huang

旧博客,文章搬迁完后删除

返回首页 返回专题

多项式求逆 模板

核心思想:B=(2-AB’)B’

NTT版(只支持NTT模数)

模板题:我是链接

#include<bits/stdc++.h>
#define Mod 998244353
using namespace std;

const int N=400010;
const int g=3;
int r[N];
int a[N],b[N],c[N];

int Pow(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1) ans=1ll*ans*a%Mod;
		a=1ll*a*a%Mod;
		b>>=1;
	}
	return ans;
}

void NTT(int *a,int len,bool INTT)
{
	int inv=Pow(len,Mod-2);
	for(int i=0;i<len;i++) if(i<r[i]) swap(a[i],a[r[i]]);
	for(int i=1;i<len;i<<=1)
	{
		int p=(i<<1);
		int wn=Pow(g,(Mod-1)/p);
		if(INTT) wn=Pow(wn,Mod-2);
		for(int j=0;j<len;j+=p)
		{
			int w=1;
			for(int k=0;k<i;k++)
			{
				int x=a[j+k],y=1ll*w*a[i+j+k]%Mod;
				a[j+k]=(x+y)%Mod;
				a[i+j+k]=(x-y+Mod)%Mod;
				w=1ll*w*wn%Mod;
			}
		}
	}
	if(INTT) for(int i=0;i<len;i++) a[i]=1ll*a[i]*inv%Mod;
}

void Inv(int deg,int *a,int *b)
{
	if(deg==1){b[0]=Pow(a[0],Mod-2);return;}
	Inv((deg+1)>>1,a,b);
	int l=0,n=1;
	while(n<(deg<<1)) n<<=1,l++;
	for(int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
	for(int i=0;i<deg;i++) c[i]=a[i];
	for(int i=deg;i<n;i++) c[i]=0;
	NTT(c,n,0);NTT(b,n,0);
	for(int i=0;i<n;i++)
		b[i]=1ll*(2ll-1ll*c[i]*b[i]%Mod+Mod)%Mod*b[i]%Mod;
	NTT(b,n,1);
	for(int i=deg;i<n;i++) b[i]=0;
}

int main()
{
	int n;
	scanf("%d",&n);
	for(int i=0;i<n;i++) scanf("%d",a+i);
	Inv(n,a,b);
	for(int i=0;i<n;i++) printf("%d ",b[i]);
	return 0;
}

MTT版(支持任意模数)

模板题链接:我是链接

#include<bits/stdc++.h>
#define Mod 1000000007
#define double long double
using namespace std;

typedef long long LL;
struct Comp
{
	double r,i;
	Comp(double x=0,double y=0):r(x),i(y){}
	friend Comp operator + (Comp a,Comp b){return Comp(a.r+b.r,a.i+b.i);}
	friend Comp operator - (Comp a,Comp b){return Comp(a.r-b.r,a.i-b.i);}
	friend Comp operator * (Comp a,Comp b){return Comp(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r);}
};
const int M=sqrt(Mod);
const int N=400010;
const double pi=acos(-1);
Comp A[N],B[N],C[N],D[N],E[N],F[N],G[N];
int a[N],b[N],c[N],d[N],r[N];

int Pow(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1) ans=1ll*ans*a%Mod;
		a=1ll*a*a%Mod;
		b>>=1;
	}
	return ans;
}

void FFT(Comp *a,int n,int v)
{
	for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
	for(int i=1;i<n;i<<=1)
	{
		int p=(i<<1);
		Comp wn=Comp(cos(pi/i),v*sin(pi/i));
		for(int j=0;j<n;j+=p)
		{
			Comp w=Comp(1,0);
			for(int k=0;k<i;k++)
			{
				Comp x=a[j+k],y=w*a[i+j+k];
				a[j+k]=x+y;
				a[i+j+k]=x-y;
				w=w*wn;
			}
		}
	}
}

void MTT(int *a,int *b,int n,int *res)
{
	int len,l=0;
	for(len=1;len<(n<<1);len<<=1) l++;
	for(int i=0;i<len;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
	for(int i=0;i<n;i++)
	{
		A[i]=Comp(a[i]/M,0);
		B[i]=Comp(a[i]%M,0);
		C[i]=Comp(b[i]/M,0);
		D[i]=Comp(b[i]%M,0);
	}
	for(int i=n;i<len;i++) A[i]=B[i]=C[i]=D[i]=Comp(0,0);
	FFT(A,len,1);FFT(B,len,1);FFT(C,len,1);FFT(D,len,1);
	for(int i=0;i<len;i++)
	{
		E[i]=A[i]*C[i];
		F[i]=A[i]*D[i]+B[i]*C[i];
		G[i]=B[i]*D[i];
	}
	FFT(E,len,-1);FFT(F,len,-1);FFT(G,len,-1);
	for(int i=0;i<len;i++)
	{
		res[i]=(LL)round(E[i].r/len)%Mod*M%Mod*M%Mod;
		res[i]=(res[i]+(LL)round(F[i].r/len)%Mod*M%Mod)%Mod;
		res[i]=(res[i]+(LL)round(G[i].r/len)%Mod)%Mod;
	}
}

void Inv(int deg,int *a,int *b)
{
	if(deg==1){b[0]=Pow(a[0],Mod-2);return;}
	Inv((deg+1)>>1,a,b);
	MTT(a,b,deg,c);
	MTT(c,b,deg,d);
	for(int i=0;i<deg;i++)
	{
		b[i]=(b[i]+b[i])%Mod;
		b[i]=(b[i]-d[i])%Mod;
		if(b[i]<0) b[i]+=Mod;
	}
}

int main()
{
	int n;
	scanf("%d",&n);
	for(int i=0;i<n;i++) scanf("%d",a+i);
	Inv(n,a,b);
	for(int i=0;i<n;i++) printf("%d ",b[i]);
	return 0;
}