数学系の問題が苦手すぎて困ったのでちゃんと記事にして忘れないようにしました.
総和のついた数式を変形させていくタイプの問題でした.
atcoder.jp
問題概要
長さ $ N $ の整数列 $ A $ と $ K $ が与えられる.
$ 1 \leq X \leq K $ について,以下の値を求めよ.
$$ \left( \sum _ {L=1}^{N-1} \sum _ {R=L+1}^{N} (A_L + A_R)^{K} \right) \bmod 998244353 $$
制約
- $ 2 \leq N \leq 2 \times 10^{5} $
- $ 1 \leq K \leq 300 $
- $ 1 \leq A_i \leq 10^{8} $
考察
本質ではないところから.問題文では 1-indexed になっていて嫌な気持ちになるので, $ i = L - 1, j = R - 1 $ として 0-indexed に直す.
$$ \sum _ {i=0}^{N-2} \sum _ {j=i+1}^{N-1} (A_i + A_j)^{K} $$
また,求める値を $ ans_X (1 \leq X \leq N) $ と置くことにする.
$ (A_i + A_j)^{K} $ はこのままの形で扱うのは難しいので変形したいが,この形は二項定理で展開するのが良いと相場が決まっている.すると,
$$ ans_X = \sum _ {i=0}^{N-2} \sum _ {j=i+1}^{N-1} \sum _ {k=0}^{X} \binom{X}{k} (A_i)^{k} (A_j)^{X-k} $$
のようになる.シグマがいっぱいで困っちゃう.
知識 (シグマを含む式の変形)
上の式のようにシグマが二重・三重になっているような式は次の3つの変形を使ってうまいこと書き換えるのが定石.
項ごとに分割して考えることができる
対象の式をカウンタ変数( $ i $ や $ j $ )の式として見たときに定数なら括り出せる
2つの並んだシグマは,2つのカウンタ変数の初期値/終端値に依存関係がないなら入れ替えられる
1, 2は総和の線形性より明らかです.
$$ \sum _ {i=0}^{N-1} (\alpha A_i + \beta B_i) = \alpha \sum _ {i=0}^{N-1} A_i + \beta \sum _ {i=0}^{N-1} B_i $$
3は自分の数学力では何といえばいいのかわかりませんが中高で習うはずです.
$ j $ の初期値/終端値に $ i $ が含まれているなどの場合は $ i \rightarrow j $ の順番を守らなければいけません.
閑話休題,問題の考察に戻って実際に変形してみる.なにも考えず変形規則をできるだけ適用した.
\begin{align}
ans_X &= \sum _ {i=0}^{N-2} \sum _ {j=i+1}^{N-1} \sum _ {k=0}^{X} \binom{X}{k} (A_i)^{k} (A_j)^{X-k} \\
&=\sum _ {k=0}^{X} \binom{X}{k} \sum _ {i=0}^{N-2} \left\{ (A_i)^{k} \cdot \sum _ {j=i+1}^{N-1} (A_j)^{X-k} \right\}
\end{align}
この形だと $ \sum _ {j=i+1}^{N-1} (A_j)^{X-k} $ を前計算しておくことで $ O(NK^{2}) $ で答えが得られる( $ ans_X $ $ 1 $ つにつき $ O(NK) $).
当然制約上このままでは間にあわない.式を見ると $ \sum _ {j=i+1}^{N-1} (A_j)^{X-k} $ が $ i $ の関数になっていることで $ i $ のシグマから括り出せていないことがわかる.この式が $ i $ の関数になってしまっている原因は明らかに $ j=i+1 $ であるからここを $ j=0 $ に変えてしまうことを考える.
最初の式に立ち返って $ j=0 $ にすると,どのぐらい答えに影響するのかを考える.これは意外と単純で $ ans_X $ は $ i < j $ の場合の和なので,$ j = i+1 $ を $ j=0 $ に変えた和から $ i = j $ と $ i > j $ の場合の和を引けば $ i < j $ の場合の和が算出されるだろう.また対称性から$ i < j $ の場合と $ i > j $ の場合の和は等しい.以上を式で言えば次の通りである.
\begin{align}
ans_X &= \sum _ {i=0}^{N-1}\ \sum _ {j=0}^{N-1} (A_i + A_j)^{X} - \sum _ {i=0}^{N-1} (A_i + A_i)^{X} - \sum _ {j=0}^{N-1} \sum _ {i=j+1}^{N-1} (A_i + A_j)^{X} \\
&= \left( \sum _ {i=0}^{N-1} \sum _ {j=0}^{N-1} (A_i + A_j)^{X} - \sum _ {i=0}^{N-1} (A_i + A_i)^{X} \right) \times \frac{1}{2} \\
&= \left( \sum _ {i=0}^{N-1} \sum _ {j=0}^{N-1} (A_i + A_j)^{X} - 2^{X} \cdot \sum _ {i=0}^{N-1} (A_i)^{X} \right) \times \frac{1}{2} \tag{*}
\end{align}
$ \sum _ {i=0}^{N-1} \sum _ {j=0}^{N-1} (A_i + A_j)^{X} $ を最初の変形と同様なにも考えず変形.
\begin{align}
\sum _ {i=0}^{N-1} \sum _ {j=0}^{N-1} (A_i + A_j)^{X} &= \sum _ {i=0}^{N-1} \sum _ {j=0}^{N-1} \sum _ {k=0}^{X} \binom{X}{k} (A_i)^{k} (A_j)^{X-k} \\
&= \sum _ {k=0}^{X} \binom{X}{k} \left( \sum _ {i=0}^{N-1} (A_i)^{k} \right) \left( \sum _ {j=0}^{N-1} (A_j)^{X-k} \right)
\end{align}
ここで $ f(k) = \sum _ {i=0}^{N-1} (A_i)^{k} $ と置くと,$ f(k) $ $ (0 \leq k \leq K) $ を $ O(NK) $ で前計算することにより答えが $ O(K^{2}) $ で求められる.
今一度全体の式 $ (*) $ を書き下すと
$$ ans_X = \left( \sum _ {k=0}^{X} \binom{X}{k} f(k) f(X-k) - 2^{X} f(X) \right) \times \frac{1}{2} $$
二項係数,$ f(k) $,$ 2^{k} $ を前計算することで全体の計算量は $ O(NK + K^{2}) $ になり間に合う.
余談
整理した式 $ (*) $を見ると
$$ \sum _ {k=0}^{X} \binom{X}{k} f(k) f(X-k) $$
といういかにも畳み込みっぽい形が見つかる.実際,
\begin{align}
\sum _ {k=0}^{X} \binom{X}{k} f(k) f(X-k) &= \sum _ {k=0}^{X} \frac{X!}{k!(X-k)!} f(k) f(X-k) \\
&= X! \sum _ {k=0}^{X} \frac{ f(k) }{ k! } \frac{ f(X-k) }{ (X-k)! } \\
&= X! \sum _ {k=0}^{X} g(k) g(X-k) \\
& \left( g(k) = \frac{ f(k) }{ k! } \right)
\end{align}
となりFFTを用いて $ O(NK + K \log{K}) $ になる.ボトルネックそこじゃないけど.
本当に余談だが,$ j $ の初期化を $ j=i+1 $ のままにしても畳み込みの形にすることができてその時の計算量は $ O(NK \log{K}) $ になる.試したけど間に合わんかった.つらい.
個人的なポイント
- 2項の式の累乗の式を見たら,とりあえず展開すべき
- シグマが二重三重になっている問題は変形が肝要であることが多い
- シグマ同士の入れ替えや線形性を利用した変形を手元でするべき
- シグマの変形の中でも
$$ \sum _ {i=0}^{N-1} \sum _ {j=0}^{N-1} f(i) f(j) = \left( \sum _ {i=0}^{N-1} f(i) \right) \cdot \left( \sum _ {j=0}^{N-1} f(j) \right) $$
の変形は典型な気がする (やはり $ N $ が $ 1 $ つ落ちるのは大きい)
提出URL1
提出URL2(FFT)
▶ソースコードを展開
#include <bits/stdc++.h>
#define rep(i,n) for(int i=0;i<(int)(n);i++)
#define FOR(i,n,m) for(int i=(int)(n); i<=(int)(m); i++)
#define RFOR(i,n,m) for(int i=(int)(n); i>=(int)(m); i--)
#define ITR(x,c) for(__typeof(c.begin()) x=c.begin();x!=c.end();x++)
#define RITR(x,c) for(__typeof(c.rbegin()) x=c.rbegin();x!=c.rend();x++)
#define setp(n) fixed << setprecision(n)
template<class T> bool chmax(T &a, const T &b) { if (a<b) { a=b; return 1; } return 0; }
template<class T> bool chmin(T &a, const T &b) { if (a>b) { a=b; return 1; } return 0; }
#define ll long long
#define vll vector<ll>
#define vi vector<int>
#define pll pair<ll,ll>
#define pi pair<int,int>
#define all(a) (a.begin()),(a.end())
#define rall(a) (a.rbegin()),(a.rend())
#define fi first
#define se second
#define pb push_back
#define ins insert
#define debug(a) cerr<<(a)<<endl
#define dbrep(a,n) rep(_i,n) cerr<<(a[_i])<<" "; cerr<<endl
#define dbrep2(a,n,m) rep(_i,n){rep(_j,m) cerr<<(a[_i][_j])<<" "; cerr<<endl;}
using namespace std;
template<class A, class B>
ostream &operator<<(ostream &os, const pair<A,B> &p){return os<<"("<<p.fi<<","<<p.se<<")";}
template<class A, class B>
istream &operator>>(istream &is, pair<A,B> &p){return is>>p.fi>>p.se;}
const ::std::uint_fast64_t MOD = 998244353;
class mint
{
private:
using value_type = ::std::uint_fast64_t;
value_type n;
public:
mint():n(0){}
mint(::std::int_fast64_t _n):n(_n<0 ? MOD-(-_n)%MOD : _n%MOD){}
mint(const mint &m):n(m.n){}
friend ::std::ostream& operator<<(::std::ostream &os, const mint &a){
return os << a.n;
}
friend ::std::istream& operator>>(::std::istream &is, mint &a){
value_type temp; is>>temp;
a = mint(temp);
return is;
}
mint& operator+=(const mint &m){n+=m.n; n=(n<MOD)?n:n-MOD; return *this;}
mint& operator-=(const mint &m){n+=MOD-m.n; n=(n<MOD)?n:n-MOD; return *this;}
mint& operator*=(const mint &m){n=n*m.n%MOD; return *this;}
mint& operator/=(const mint &m){return *this*=m.inv();}
mint& operator++(){return *this+=1;}
mint& operator--(){return *this-=1;}
mint operator+(const mint &m) const {return mint(*this)+=m;}
mint operator-(const mint &m) const {return mint(*this)-=m;}
mint operator*(const mint &m) const {return mint(*this)*=m;}
mint operator/(const mint &m) const {return mint(*this)/=m;}
mint operator++(int){mint t(*this); *this+=1; return t;}
mint operator--(int){mint t(*this); *this-=1; return t;}
bool operator==(const mint &m) const {return n==m.n;}
bool operator!=(const mint &m) const {return n!=m.n;}
mint operator-() const {return mint(MOD-n);}
mint pow(value_type b) const {
mint ret(1), m(*this);
while(b){
if (b & 1) ret*=m;
m*=m;
b>>=1;
}
return ret;
}
mint inv() const {return pow(MOD-2);}
};
class Combination
{
private:
::std::vector<mint> _fact;
::std::vector<mint> _finv;
public:
Combination(int n):_fact(n+1), _finv(n+1){
_fact[0] = _fact[1] = 1;
_finv[0] = _finv[1] = 1;
for(int i=2; i<=n; i++){
_fact[i] = _fact[i-1]*i;
_finv[i] = _fact[i].inv();
}
}
mint fact(int x){return _fact[x];}
mint finv(int x){return _finv[x];}
mint comb(int x, int y){
if (y>x || y<0) return 0;
return _fact[x]*_finv[y]*_finv[x-y];
}
mint homo(int x, int y){return comb(x+y-1, y);}
};
const int N_MAX = 2e5;
const int K_MAX = 300;
mint asum[K_MAX+1];
mint pow2[K_MAX+1];
int main(void)
{
cin.tie(0);
ios::sync_with_stdio(false);
ll N,K; cin>>N>>K;
vll a(N);
rep(i,N) cin>>a[i];
rep(i,N){
mint m = 1;
rep(j,K+1){
asum[j]+=m;
m*=a[i];
}
}
pow2[0]=1;
rep(i,K) pow2[i+1]=pow2[i]*2;
Combination bn(K);
mint inv2 = mint(2).inv();
FOR(x,1,K){
mint ans=0;
FOR(i,0,x){
ans+=bn.comb(x,i)*asum[x-i]*asum[i];
}
ans-=pow2[x]*asum[x];
cout<<ans*inv2<<"\n";
}
return 0;
}