题意
U2FsdGVkX19a/FECrawlE6grok2qk+quCztZzi8zn70EsBShO5C+sKe6jSAsNlVy
6l8voQ8MneOb/jUd/vDJMJxwAdu7/YdrVChsXdueoeJ8HrDtHAJt/VFKptm2DPT1
H5D+WGa7xTWwy4DD8TLliLOL90GDtyJ+f16QR1PTCvYVaf+DBHLXRHrE5kJyaB+C
BV/hNr+Iru22LR9+DwwEqroxDkkU0dvP90itbD2SAg3CdpFua50CGCysBzcu0vkM
afPHPLlz3LuD6j1ihcGcTNXAxyscifKof1zsrWJsKJ8sM8kBE08sWG7dJITOgblm
wQcWfQUVE4tVbAaGyNtLzpkCqr+5wpDc1dd8iZO1k7rfnTx/F8b2yWo7R5K6iX/8
ybtfC/VOnVO/tZOPjHtc66ab/DPBJZYv0J/wrTD0Bqo=
解析
考虑这个子序列可能长什么样,设 表示序列中的最大值, 表示序列中的最小值, , 都是任意介于 , 之间的值,大概就是这样的:
c c c c c c ... c a b a a b b ... a b a d d d d d d ... d
|- c 段(可空)--|- ab 混合段 ---------|- d 段(可空)--|
所以枚举最大值和最小值 ,。枚举 ab 混合段的左右端点 和 ,可以预处理出 表示 之前小于 的作为 c 段的方案数,同理 表示 之后的。这个方案数就是: ,注意到你可以把这个贡献拆成:,然后可以类似扫描线一样,枚举 ,维护 之前所有 的那个贡献的和。注意,这样算有可能会算到只有 a 没有 b 或只有 b 没有 a 的情况,只需要再算一下只有 a 和只有 b 的,减去就行了。
实现
using mint = static_modint<998'244'353>;
const mint INV2 = (mint::mod() + 1) / 2;
int main()
{
set_io();
int m, k;
std::cin >> m >> k;
int n = m * k;
std::vector<int> a(n);
std::vector<std::vector<int>> pos(m);
for (int i = 0; i < n; i++) {
std::cin >> a[i];
a[i]--;
pos[a[i]].emplace_back(i);
}
std::vector<mint> base2(n + 1), ibase2(n + 1);
ibase2[0] = base2[0] = 1;
for (int i = 0; i < n; i++) base2[i + 1] = base2[i] * 2;
for (int i = 0; i < n; i++) ibase2[i + 1] = ibase2[i] * INV2;
std::vector pre_sum(m + 1, std::vector<mint>(n + 1));
{
std::vector<int> cnt(m);
for (int i = 0; i < n; i++) {
cnt[a[i]]++;
for (int j = 0; j < m; j++) {
pre_sum[j + 1][i + 1] = pre_sum[j][i + 1] + base2[cnt[j]] - 1;
}
}
}
std::vector suf_sum(m + 1, std::vector<mint>(n + 1));
{
std::vector<int> cnt(m);
for (int i = n; i > 0; i--) {
cnt[a[i - 1]]++;
for (int j = 0; j < m; j++) {
suf_sum[j + 1][i - 1] = suf_sum[j][i - 1] + base2[cnt[j]] - 1;
}
}
}
auto calc = [&pre_sum, &suf_sum, &base2, &ibase2](const std::vector<int> &p, int l, int r)
{
mint ans = 0, sum = 0;
for (size_t i = 0; i < p.size(); i++) {
int v = p[i];
ans += sum * base2[i] * (suf_sum[r][v] - suf_sum[l + 1][v] + 1);
sum += ibase2[i + 1] * (pre_sum[r][v] - pre_sum[l + 1][v] + 1);
}
return ans;
};
mint ans = 0;
std::vector<int> lrpos(k * 2);
for (int l = 0; l < m; l++) {
for (int r = l + 1; r < m; r++) {
std::merge(pos[l].begin(), pos[l].end(), pos[r].begin(), pos[r].end(), lrpos.begin());
ans += calc(lrpos, l, r);
ans -= calc(pos[l], l, r);
ans -= calc(pos[r], l, r);
}
}
ans += (base2[k] - 1) * m;
std::cout << ans.val() << std::endl;
}