Yuulis.log

Yuulis.log

トンネルを抜けるとそこは参照エラーであった。

【AtCoder】ABC 405 E - Fruit Lineup | 緑コーダーが解くAtCoder

atcoder.jp

配点: 475 点 / 実行時間制限: 2 sec / メモリ制限: 1024 MB / Difficulty: ??? / NoviSteps: 1D

問題概要

 A 個のリンゴと  B 個のオレンジと  C 個のバナナと  D 個のブドウがある。これらの  A+B+C+D 個の果物を、以下の条件全てを満たすように左右一列に並べる方法は何通りあるか。答えを  998244353 で割った余りを求めよ。ただし、同じ種類の果物同士は区別できないとする。

  • リンゴはすべて、バナナよりも左側に並べる。
  • リンゴはすべて、ブドウよりも左側に並べる。
  • オレンジはすべて、ブドウよりも左側に並べる。

制約

  •  1 \leq A, B, C, D \leq 10^6
  •  A, B, C, D は全て整数。

考察

並び替えの総数を計算する

数学Aの「場合の数」の単元で目にするような、並び替えの総数を数え上げる問題。

リンゴとオレンジ、バナナとオレンジ、そしてバナナとブドウの位置関係については条件がないことに注意したい。

このような場合は、より条件が強いものに注目して考えていくとよい。今回で言えば、条件文に複数回登場している「ブドウ」だ。


とりあえず、あるブドウ1個の位置を固定して、その左右に果物を並べていくことを考える。


ブドウの左側には、リンゴ  A 個とオレンジ  B 個を必ず並べなくてはならない。それに加えて、ブドウとの位置の縛りがないバナナも並べることができる。この個数を  k \: (0 \leq k \leq C) 個とする。

このとき、問題の条件より、位置関係は左から「リンゴ  A 個 → バナナ  k 個」が確定し、残りのオレンジ  B 個はこれらの果物の隙間に入れていくことになる。

このような並べ方の総数は、 Combination を用いて  {}_{A + B + k} C_{B} 通りと計算することができる。


一方、ブドウの右側には、ブドウの残り  D - 1 個とバナナの残り  C - k 個を並べる必要がある。両者に位置の縛りはないので、このような並べ方の総数は  {}_{(D - 1) + (C - k)} C_{(D-1)} 通りと計算することができる。


以上より、左右の場合の数を掛け合わせて、それを  k = 0, 1, \dots C までの和を取り、  998244353 の剰余を取った値が答えとなる :

 \begin{align*}
\sum_{k = 0}^{C} {}_{A + B + k} C_{B} \times {}_{(D - 1) + (C - k)} C_{(D-1)} \mod 998244353
\end{align*}

二項係数の素数剰余の計算

答えを求める計算式は導けたわけだが、実装にあたって厄介なのが二項係数(Combination)の計算である。

まともにやろうとすると  {}_n C_r = \frac{n!}{r!(n-r)!} を計算することになり、一つの二項係数の計算に  O(n) の時間がかかってしまい、 TLE は必至だ。


実は、素数  p の剰余を取る二項係数の高速な計算は、「前処理  O(n) \mod p の下での階乗とその逆元を求めておき、各係数の計算は  O(1) で済ませる」という方法で行うことができる。

詳細は以下の記事を参照のこと。実装例のコードもここからお借りさせてもらった。

drken1215.hatenablog.com

なお、実装時には ACLmodintを使うと簡単になる。

実装例

#include <bits/stdc++.h>
using namespace std;

#if __has_include(<atcoder/all>)
#include <atcoder/all>
using namespace atcoder;
#endif

#define repe(i, start, end) for (auto i = (start); (i) <= (end); (i)++)

using mint = modint998244353;

// ======================================== //

const int MAX = 4000100;
mint fac[MAX], finv[MAX], inv[MAX];

// 前処理
void COMinit()
{
    const int MOD = mint::mod();
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for (int i = 2; i < MAX; i++)
    {
        fac[i] = fac[i - 1] * i;
        inv[i] = MOD - inv[MOD % i] * (MOD / i);
        finv[i] = finv[i - 1] * inv[i];
    }
}

// 二項係数計算
mint COM(int n, int k)
{
    if (n < k)
        return 0;
    if (n < 0 || k < 0)
        return 0;
    return fac[n] * finv[k] * finv[n - k];
}

int main()
{
    int A, B, C, D;
    cin >> A >> B >> C >> D;

    COMinit();

    mint ans = 0;
    repe(k, 0, C)
    {
        mint com1 = COM(A + B + k, B);
        mint com2 = COM((D - 1) + (C - k), (D - 1));
        ans += com1 * com2;
    }

    cout << ans.val() << endl;

    return 0;
}

atcoder.jp

実装時間: 60分

コメント

本質は前半の式の導出部分。高校数学の知識が使えるのはおもしろいかも。