给定一个字符串 (),求有多少个字符串 使得 为 的子序列。
这题的难点主要在于防止算重。我们定义 表示前一个 中的字符在原串中出现的位置, 表示后一个 中的字符在原串中出现的位置,令 表示从 往后第一个 出现的位置(包括 )。
为了避免算重,我们强制要求 和 尽量小。举个例子,如果原串是 acbcab
,
是 ab
,那么 和 便如下计算:
可以发现,这样算出来的 和 一定是最小的,并且对于每一个 ,最多只有一个 和 。
我们枚举一个 ,则 。令 表示 的末尾元素为 , 的末尾元素为 时的方案数,显然 。转移方程:。当然,人人为我的转移写起来很困难,可以考虑我为人人的转移,枚举一个字符 , 会对 贡献。最后统计 的答案就 。
时间复杂度 。
#include <algorithm>
#include <array>
#include <iostream>
#include <string>
#include <vector>
#include <atcoder/modint>
using mint = atcoder::modint998244353;
int main()
{
std::string s;
std::cin >> s;
int n = s.size();
std::vector<std::array<int, 26>> next(n + 1);
std::fill(next[n].begin(), next[n].end(), n);
for (int i = n; i > 0; ){
i--;
next[i] = next[i + 1];
next[i][s[i] - 'a'] = i;
}
mint ans = 0;
for (int q = 0; q < n; q++) {
int p = next[0][s[q] - 'a'];
if (p >= q) continue;
std::vector f(n, std::vector<mint>(n));
f[p][q] = 1;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
for (int ch = 0; ch < 26; ch++) {
int ni = next[i + 1][ch];
int nj = next[j + 1][ch];
if (ni >= q || nj >= n) continue;
f[ni][nj] += f[i][j];
}
}
}
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (next[i + 1][s[q] - 'a'] != q) continue;
ans += f[i][j];
}
}
}
std::cout << ans.val() << std::endl;
}