Description

题目链接

Solution

前八个点是暴力分,然而当年我根本没看懂这题,暴力都敲不懂。

考虑 $m \le 3$ 的情况,直接设 $f(i, s_1, s_2, s_3)$ 表示使用了前 $i$ 种方法,三种食材分别用了 $s_1, s_2, s_3$ 次的方案数,然后直接 DP。

我们发现如果不考虑第三个条件,那么总的方案数很好统计,为 $\prod \limits_{i = 1}^n (\sum \limits_{j = 1}^m a_{i, j} + 1) - 1$。

注意到在所有方案中,如果有一种方案不合法,那么一定只存在一种食材用了超过 $\lfloor \dfrac{k}{2} \rfloor$ 次。如果我们直接设状态的话,需要存下所有其他食材的使用数,不如我们枚举这个不合法的食材,然后容斥计算答案。

那么我们设 $f(i, j, k)$ 表示前 $i$ 种方法,不合法的食材选了 $j$ 个,其他食材选了 $k$ 个,钦定 $m$ 为不合法的食材,那么 $f(i, j, k) = f(i - 1, j, k) + f(i - 1, j - 1, k) \cdot a_{i, m} + f(i - 1, j, k - 1) \cdot (\sum \limits_{x = 1}^m a_{i, x} - a_{i, m})$。最终对于 $m$ 这种食材,不合法的方案数为 $\sum \limits_{j > k} f(n, j, k)$。这样可以拿到 $84$ 分。

我们发现后两维其实可以压缩成一维,令 $j - k$ 作为第二维,代入原式中,可以得到类似的转移方程,答案为 $\sum \limits_{j - k > 0} f(i, j - k)$。边界 $f(0, 0) = 1$。由于 $j - k$ 可能为负数,需要加上一个偏移量。

Code

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

template <class T>
inline void read(T &x) {
    x = 0;
    int f = 0;
    char ch = getchar();
    while (!isdigit(ch))    { f |= ch == '-'; ch = getchar(); }
    while (isdigit(ch))     { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
    x = f ? -x : x;
    return ;
}

typedef unsigned long long uLL;
typedef long long LL;

const int mod = 998244353;

LL a[110][2010], f[110][210];
LL s[110];
LL n, m, ans = 1, sum;

int main() {
    read(n), read(m);
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            read(a[i][j]);
            s[i] += a[i][j];
            s[i] %= mod;
        }
    }
    for (int i = 1; i <= n; ++i)    ans = ans * (1 + s[i]) % mod;
    --ans;
    for (int i = 1; i <= m; ++i) {
        memset(f, 0, sizeof(f));
        f[0][n] = 1;
        for (int j = 1; j <= n; ++j) {
            for (int k = n - j; k <= n + j; ++k) {
                f[j][k] = f[j - 1][k] % mod + f[j - 1][k - 1] % mod * a[j][i] % mod + f[j - 1][k + 1] % mod * (s[j] % mod - a[j][i] % mod) % mod;
                f[j][k] %= mod;
            }
        }
        for (int j = 1; j <= n; ++j)    sum = sum % mod + f[n][n + j] % mod;
    }
    printf("%lld\n", ((ans - sum) % mod + mod) % mod);
    return 0;
}
Last modification:September 2nd, 2021 at 02:04 am