luckYrat's library.

This documentation is automatically generated by competitive-verifier/competitive-verifier.


✔ cpp/math/convolution.cpp

Depends on

Required by

Verified with

Code

// Ref: https://qiita.com/AngrySadEight/items/0dfde26060daaf6a2fda

#include "./binary-power-method.cpp"
#include "../data-structure/mod-int/mod-int.cpp"
using namespace std;

// Recursive radix-2 number-theoretic transform (decimation in time).
//   X     : coefficients; X.size() is expected to be a power of two.
//   depth : log2(X.size()) - 1, so that root[depth] is an X.size()-th root
//           of unity (depth is only read when X.size() >= 2).
//   root  : root[d] is expected to be a 2^(d+1)-th root of unity, as built
//           by make_root (or its inverses, from make_invroot).
// Returns Y with Y[k] = sum_j X[j] * w^(j*k), where w = root[depth].
// An empty or single-element input is returned unchanged.
template<typename MINT>
std::vector<MINT> ntt(const std::vector<MINT> &X, int depth, const std::vector<MINT> &root) {
  const std::size_t n = X.size();
  // n == 0 previously recursed forever (its halves are empty, too).
  if(n <= 1) return X;

  const std::size_t half = n / 2;
  std::vector<MINT> even, odd;
  even.reserve(half);
  odd.reserve(half);
  for(std::size_t i = 0; i < n; i++){
    if(i % 2 == 0) even.push_back(X[i]);
    else odd.push_back(X[i]);
  }

  const std::vector<MINT> ntt_even = ntt(even, depth - 1, root);
  const std::vector<MINT> ntt_odd = ntt(odd, depth - 1, root);

  // Use the element type MINT (not the global `mint`) so any modulus type works.
  const MINT r = root[depth];

  std::vector<MINT> val;
  val.reserve(n);
  MINT now = 1;  // r^i
  for(std::size_t i = 0; i < n; i++){
    // Sub-transforms have period n/2, so both halves of the output reuse them.
    val.push_back(ntt_even[i % half] + now * ntt_odd[i % half]);
    now *= r;
  }
  return val;
}

// Builds the table of roots of unity consumed by ntt.
//   p : prime modulus; p - 1 is split as c * 2^k with c odd.
//   g : primitive root modulo p; the default 3 is the generator used for
//       998244353. Other NTT primes may need a different g.
// Returns k values where val[d] = (g^c)^(2^(k-1-d)), i.e. a 2^(d+1)-th root
// of unity (primitive when g is a primitive root). For p = 998244353 this is
// exactly the former hard-coded table: 23 entries derived from 3^119.
template<typename MINT>
std::vector<MINT> make_root(long long p, long long g = 3){
  // Factor p - 1 = c * 2^k.
  long long c = p - 1;
  int k = 0;
  while(c > 0 && c % 2 == 0){
    c /= 2;
    k++;
  }

  // r = g^c, computed with MINT arithmetic (the element type, not the global mint).
  MINT r = 1;
  MINT base = g;
  for(long long e = c; e > 0; e >>= 1){
    if(e & 1) r *= base;
    base *= base;
  }

  // Successive squares halve the order each time: r^(2^i) has order 2^(k-i).
  std::vector<MINT> val;
  val.reserve(k);
  for(int i = 0; i < k; i++){
    val.push_back(r);
    r *= r;
  }
  std::reverse(val.begin(), val.end());
  return val;
}

// Element-wise inverse of a root table: val[i] = root[i]^(-1).
// Feeding the result to ntt computes the (unscaled) inverse transform.
// The output has the same length as `root` (previously a fixed 23 entries,
// which read out of bounds for shorter tables).
template<typename MINT>
std::vector<MINT> make_invroot(const std::vector<MINT> &root){
  std::vector<MINT> val;
  val.reserve(root.size());
  for(const MINT &w : root){
    val.push_back(w.inverse());
  }
  return val;
}

template<typename MINT>
vector<MINT> convolution(vector<MINT> A, vector<MINT> B){
  long long p = A[0].getMod(); // each mod must be same

  vector<MINT> root = make_root<MINT>(p);
  vector<MINT> invroot = make_invroot<MINT>(root);

  size_t size = (A.size()+B.size()-1);
  int n = 1;
  int log2_n = 0;
  while(n < size){
    n *= 2;
    log2_n++;
  }
  
  while(A.size() < n)A.push_back(0);
  while(B.size() < n)B.push_back(0);

  // AとBのNTTを求める
  auto nttA = ntt(A, log2_n-1, root);
  auto nttB = ntt(B, log2_n-1, root);

  vector<MINT> nttC(n);
  for(int i = 0; n > i; i++){
    nttC[i] = nttA[i]*nttB[i];
  }

  auto nC = ntt(nttC, log2_n-1, invroot);
  vector<MINT> C(size);
  for(int i = 0; size > i; i++){
    C[i] = nC[i]/(mint)n;
  }

  return C;
}

#line 1 "cpp/math/convolution.cpp"
// Ref: https://qiita.com/AngrySadEight/items/0dfde26060daaf6a2fda

#line 2 "cpp/math/binary-power-method.cpp"

// Binary exponentiation: returns z^n, reduced modulo `mod` after every
// multiplication when mod != 0 (mod == 0 means a plain integer power).
// For n <= 0 the result is 1 (not reduced). Intermediate products such as
// (mod - 1)^2 must fit in T.
template <typename T>
T uPow(T z, T n, T mod){
  T ans = 1;
  // Reduce the base first so a large z cannot overflow on the first squaring.
  if(mod) z %= mod;
  // `n > 0` (not `n != 0`): a negative n previously looped forever, since >>= keeps it negative.
  while(n > 0){
    if(n % 2){
      ans *= z;
      if(mod) ans %= mod;
    }
    n >>= 1;
    // Skip the final, unused squaring; it could overflow even when the answer fits.
    if(n == 0) break;
    z *= z;
    if(mod) z %= mod;
  }
  return ans;
}

#line 2 "cpp/data-structure/mod-int/mod-int.cpp"

// Integer modulo the compile-time constant `mod` (assumed positive).
// Invariant: the residue `n` is always in [0, mod).
// += and -= compute n + p.n (or n + mod - p.n) in int, so 2 * mod must fit in int.
template <int mod>
struct ModInt{
  int n;

  // Maps any signed value to its canonical residue in [0, mod).
  // Taking % first and then shifting negatives up maps exact negative
  // multiples of mod to 0 (the old `mod - ((-n_) % mod)` produced `mod`),
  // and it cannot overflow on LLONG_MIN.
  static int normalize(long long v){
    long long r = v % mod;
    if(r < 0) r += mod;
    return static_cast<int>(r);
  }

  ModInt():n(0){}
  ModInt(long long n_):n(normalize(n_)){}
  ModInt(int n_):n(normalize(n_)){}

  ModInt &operator+=(const ModInt &p){
    if((n+=p.n) >= mod)n-=mod;
    return *this;
  }
  ModInt &operator-=(const ModInt &p){
    n+=mod-p.n;
    if(n >= mod)n-=mod;
    return *this;
  }
  ModInt &operator*=(const ModInt &p){
    n = (int) ((1LL*n*p.n)%mod);
    return *this;
  }
  ModInt &operator/=(const ModInt &p){
    *this *= p.inverse();
    return *this;
  }
  ModInt operator-() const {return ModInt(-n);}
  ModInt operator+(const ModInt &p) const {return ModInt(*this) += p;}
  ModInt operator-(const ModInt &p) const {return ModInt(*this) -= p;}
  ModInt operator*(const ModInt &p) const {return ModInt(*this) *= p;}
  ModInt operator/(const ModInt &p) const {return ModInt(*this) /= p;}

  // Comparisons act on the canonical residue.
  bool operator==(const ModInt &p) const {return n==p.n;}
  bool operator<(const ModInt &p) const {return n<p.n;}
  bool operator>(const ModInt &p) const {return n>p.n;}
  bool operator>=(const ModInt &p) const {return n>=p.n;}
  bool operator<=(const ModInt &p) const {return n<=p.n;}
  bool operator!=(const ModInt &p) const {return n!=p.n;}

  // Modular inverse via the extended Euclidean algorithm.
  // Meaningful only when gcd(n, mod) == 1; for n == 0 the result is 0.
  ModInt inverse() const {
    int a = n,b = mod,u = 1,v = 0;
    while(b){
      int t = a/b;
      a -= t*b; std::swap(a,b);
      u -= t*v; std::swap(u,v);
    }
    return ModInt(u);
  }

  // this^z by binary exponentiation. A negative z uses the inverse:
  // x^-k = (x^-1)^k. The old code silently returned 1 for z < 0.
  ModInt pow(int64_t z) const {
    ModInt ret(1),mul(*this);
    unsigned long long e = static_cast<unsigned long long>(z);
    if(z < 0){
      mul = mul.inverse();
      e = 0ULL - e;  // |z|, well-defined even for INT64_MIN
    }
    while(e > 0){
      if(e & 1) ret *= mul;
      mul *= mul;
      e >>= 1;
    }
    return ret;
  }

  int getMod() const {
    return mod;
  }

  friend std::ostream &operator<<(std::ostream &os, const ModInt &p){
    return os << p.n;
  }
  friend std::istream &operator>>(std::istream &is, ModInt &a){
    int64_t t;
    is >> t;
    a = ModInt<mod> ((long long)t);
    return (is);
  }
};
// Shorthand for ModInt instantiated with the project-wide modulus MOD.
// NOTE(review): MOD is not defined in this file; presumably the including
// translation unit defines it as a compile-time int constant — confirm.
using mint = ModInt<MOD>;
#line 5 "cpp/math/convolution.cpp"
using namespace std;

// Recursive radix-2 number-theoretic transform (decimation in time).
//   X     : coefficients; X.size() is expected to be a power of two.
//   depth : log2(X.size()) - 1, so that root[depth] is an X.size()-th root
//           of unity (depth is only read when X.size() >= 2).
//   root  : root[d] is expected to be a 2^(d+1)-th root of unity, as built
//           by make_root (or its inverses, from make_invroot).
// Returns Y with Y[k] = sum_j X[j] * w^(j*k), where w = root[depth].
// An empty or single-element input is returned unchanged.
template<typename MINT>
std::vector<MINT> ntt(const std::vector<MINT> &X, int depth, const std::vector<MINT> &root) {
  const std::size_t n = X.size();
  // n == 0 previously recursed forever (its halves are empty, too).
  if(n <= 1) return X;

  const std::size_t half = n / 2;
  std::vector<MINT> even, odd;
  even.reserve(half);
  odd.reserve(half);
  for(std::size_t i = 0; i < n; i++){
    if(i % 2 == 0) even.push_back(X[i]);
    else odd.push_back(X[i]);
  }

  const std::vector<MINT> ntt_even = ntt(even, depth - 1, root);
  const std::vector<MINT> ntt_odd = ntt(odd, depth - 1, root);

  // Use the element type MINT (not the global `mint`) so any modulus type works.
  const MINT r = root[depth];

  std::vector<MINT> val;
  val.reserve(n);
  MINT now = 1;  // r^i
  for(std::size_t i = 0; i < n; i++){
    // Sub-transforms have period n/2, so both halves of the output reuse them.
    val.push_back(ntt_even[i % half] + now * ntt_odd[i % half]);
    now *= r;
  }
  return val;
}

// Builds the table of roots of unity consumed by ntt.
//   p : prime modulus; p - 1 is split as c * 2^k with c odd.
//   g : primitive root modulo p; the default 3 is the generator used for
//       998244353. Other NTT primes may need a different g.
// Returns k values where val[d] = (g^c)^(2^(k-1-d)), i.e. a 2^(d+1)-th root
// of unity (primitive when g is a primitive root). For p = 998244353 this is
// exactly the former hard-coded table: 23 entries derived from 3^119.
template<typename MINT>
std::vector<MINT> make_root(long long p, long long g = 3){
  // Factor p - 1 = c * 2^k.
  long long c = p - 1;
  int k = 0;
  while(c > 0 && c % 2 == 0){
    c /= 2;
    k++;
  }

  // r = g^c, computed with MINT arithmetic (the element type, not the global mint).
  MINT r = 1;
  MINT base = g;
  for(long long e = c; e > 0; e >>= 1){
    if(e & 1) r *= base;
    base *= base;
  }

  // Successive squares halve the order each time: r^(2^i) has order 2^(k-i).
  std::vector<MINT> val;
  val.reserve(k);
  for(int i = 0; i < k; i++){
    val.push_back(r);
    r *= r;
  }
  std::reverse(val.begin(), val.end());
  return val;
}

// Element-wise inverse of a root table: val[i] = root[i]^(-1).
// Feeding the result to ntt computes the (unscaled) inverse transform.
// The output has the same length as `root` (previously a fixed 23 entries,
// which read out of bounds for shorter tables).
template<typename MINT>
std::vector<MINT> make_invroot(const std::vector<MINT> &root){
  std::vector<MINT> val;
  val.reserve(root.size());
  for(const MINT &w : root){
    val.push_back(w.inverse());
  }
  return val;
}

template<typename MINT>
vector<MINT> convolution(vector<MINT> A, vector<MINT> B){
  long long p = A[0].getMod(); // each mod must be same

  vector<MINT> root = make_root<MINT>(p);
  vector<MINT> invroot = make_invroot<MINT>(root);

  size_t size = (A.size()+B.size()-1);
  int n = 1;
  int log2_n = 0;
  while(n < size){
    n *= 2;
    log2_n++;
  }
  
  while(A.size() < n)A.push_back(0);
  while(B.size() < n)B.push_back(0);

  // AとBのNTTを求める
  auto nttA = ntt(A, log2_n-1, root);
  auto nttB = ntt(B, log2_n-1, root);

  vector<MINT> nttC(n);
  for(int i = 0; n > i; i++){
    nttC[i] = nttA[i]*nttB[i];
  }

  auto nC = ntt(nttC, log2_n-1, invroot);
  vector<MINT> C(size);
  for(int i = 0; size > i; i++){
    C[i] = nC[i]/(mint)n;
  }

  return C;
}

Back to top page