原题面:Codeforces
扯
这个缩点有点 nb。
这个转化枚举对象有点 nb。
题意简述
给定一张 \(n\) 个节点 \(m\) 条边的无向图,没有自环重边。
每个节点都在一个组 (group) 中,共有 \(k\) 组,可能存在某组为空。
求选出两组点,是它们能构成二分图的 方案数。 \(1\le n\le 5\times 10^5\),\(0\le m\le 5\times 10^5\),\(2\le k\le 5\times 10^5\)。 \(1\le c_i\le k\)。
3S,512MB。
分析题意
从某提交记录学到的
感谢 DrCold 的注释代码!
提交记录:https://codeforces.com/contest/1445/submission/97408824。
一张图是二分图的充要条件是不存在奇环。
先考虑每一组是否为二分图,使用扩展域并查集维护。
将内部不为二分图的组删除,再反向转化一下,求剩下的 \(k'\) 组里,有多少对不能组成二分图,答案即 \(\frac{k'(k'-1)}{2} -\) 对数。
之后枚举所有点对检查?太慢啦!
发现两组点 不能构成二分图的必要条件,是两组点之间有至少一条边。
发现图很稀疏,考虑枚举所有的 连接不同组 的边检查。
先将所有连接不同组的边按照 它们两端的点的组号排序,使得连接的组相同的边 是 被连续枚举的。
在判断每组内是否为二分图后,得到的扩展域并查集的基础上,将枚举到的新的边加入图中。
若新边连接的两个点的组中出现奇环,则该边无用,直接跳过。
若加入后出现奇环,则说明连接的两组不能构成二分图,答案 -1。
当当前枚举的边 与 上一条边连接的组不同时,回退之前的连接操作。
因此需要一个可撤销并查集维护,则以上过程的复杂度为 \(O(m\log n)\)。
官方题解
首先使用 DFS 黑白染色求每组内是否为二分图。
将一个内部为二分图的组压缩成两个节点一条边(黑白点各压成一个节点)。
将只有一个节点的组保留,删去内部不为二分图的组。
考虑不同组之间的点,若它们之间存在一条边,则在新图中它们对应的压缩节点之间连一条边。
检查时枚举连接不同组的边,对于连接的组相同的边放在一起,对连接的两组在加入上述边的情况下进行黑白染色。
压缩后上述过程最多只会枚举 4 个点,6 条边,所有黑白染色的总复杂度可看做 \(O(n)\)。
算法总复杂度可看做 \(O(n+m)\)。
爆零小技巧
注意一些为了使答案不重不漏而添加的特判,详见代码。
代码实现
//知识点:可撤销并查集,枚举
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
#define LL long long
const int kMaxn = 5e5 + 10;
//=============================================================
LL ans;
int n, m, k, cnt, col[kMaxn];
bool no[kMaxn];
int top, fa[kMaxn << 1], size[kMaxn << 1];
struct Stack {
int u, v, fa, size;
} st[kMaxn << 1];
int e_num;
struct Edge {
int u, v, colu, colv;
} e[kMaxn];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir_, int sec_) {
if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
if (sec_ < fir_) fir_ = sec_;
}
bool CompareEdge(Edge f_, Edge s_) {
if (f_.colu != s_.colu) return f_.colu< s_.colu;
return f_.colv < s_.colv;
}
int Find(int x_) {
while (x_ != fa[x_]) x_ = fa[x_];
return x_;
}
void Union(int u_, int v_) {
int fu = Find(u_), fv = Find(v_);
if (size[fu] > size[fv]) std::swap(fu, fv); //没写按秩合并
st[++ top] = (Stack) {fu, fv, fa[fu], size[fu]};
if (fu != fv) {
fa[fu] = fv;
size[fv] += size[fu];
size[fu] = 0;
}
}
void Restore() {
Stack now = st[top];
if (now.u != now.v) {
fa[now.u] = now.fa;
size[now.u] = now.size;
size[now.v] -= now.size;
}
top --;
}
//=============================================================
int main() {
n = read(), m = read(), cnt = k = read();
for (int i = 1; i <= n; ++ i) col[i] = read();
for (int i = 1; i <= 2 * n; ++ i) {
fa[i] = i;
size[i] = 1;
}
for (int i = 1; i <= m; ++ i) {
int u_ = read(), v_ = read();
if (col[u_] != col[v_]) {
e[++ e_num] = (Edge) {u_, v_, col[u_], col[v_]};
if (e[e_num].colu > e[e_num].colv) {
std::swap(e[e_num].colu, e[e_num].colv);
}
continue ;
}
if (no[col[u_]]) continue ; //特判,防止重复统计
if (Find(u_) == Find(v_)) {
cnt --;
no[col[u_]] = true;
continue ;
}
Union(u_, v_ + n);
Union(v_, u_ + n);
}
ans = 1ll * (cnt - 1ll) * cnt / 2ll;
int last_top = top, flag = false;
std::sort(e + 1, e + e_num + 1, CompareEdge);
for (int i = 1; i <= e_num; ++ i) {
Edge now = e[i];
if (no[now.colu] || no[now.colv]) continue ;
if (now.colu != e[i - 1].colu || now.colv != e[i - 1].colv) {
while (top != last_top) Restore();
flag = false;
}
if (flag) continue ; //特判,防止重复统计
if (Find(now.u) == Find(now.v)) {
ans --;
flag = true;
continue ;
}
Union(now.u, now.v + n);
Union(now.v, now.u + n);
}
printf("%lld\n", ans);
return 0;
}
|
请发表评论