#include <cstdio>
#include <queue>
#include <utility>

// Reads n, m, k followed by m triples (u, v, w). For every triple, vertex u is
// recorded in v's table under label w. Whenever two vertices end up stored
// under the same label in the same class's table, they are unioned. Unioning
// two classes merges their tables, which can expose further collisions; this
// repeats until no pending pairs remain. The program then prints the number of
// unordered vertex pairs (i < j) that share a final class.
//
// NOTE(review): the array sizes below assume n < 300001, that labels satisfy
// 1 <= w <= k, and that the total number of segment-tree nodes (at most
// m * (ceil(log2 k) + 1)) stays below 12000001. Confirm against the problem
// constraints.

using ll = long long;
using pi = std::pair<int, int>;

constexpr int kMaxVertices = 300001;     // 1-indexed vertices
constexpr int kMaxTreeNodes = 12000001;  // segment-tree node pool; index 0 means "null"

int num[kMaxTreeNodes];  // leaf payload: vertex stored under that label (0 = none)
int ls[kMaxTreeNodes];   // left child index (0 = absent)
int rs[kMaxTreeNodes];   // right child index (0 = absent)
int root[kMaxVertices];  // per-class segment-tree root, keyed by edge label
int siz[kMaxVertices];   // per-class counter, used only for the final tally
int fa[kMaxVertices];    // union-find parent links
std::queue<pi> que;      // pending (a, b) vertex pairs waiting to be unioned
int n, m, k;
int cnt;                 // number of segment-tree nodes allocated so far
ll ans;

// Reads one int from stdin. As in the original reader, any '-' seen while
// skipping non-digit characters makes the result negative. If EOF is reached
// before a digit, the function returns 0 instead of spinning forever.
inline int read() {
    int s = 0, w = 1;
    int ch;  // int, not char: EOF must stay distinguishable from every byte value
    while ((ch = std::getchar()) != EOF && (ch > '9' || ch < '0'))
        if (ch == '-') w = -1;
    while (ch >= '0' && ch <= '9') s = s * 10 + (ch - '0'), ch = std::getchar();
    return s * w;
}

// Returns the representative of x's class. Path compression is done
// iteratively, so a long parent chain cannot overflow the call stack.
inline int find(int x) {
    int r = x;
    while (fa[r] != r) r = fa[r];
    while (fa[x] != r) {
        const int next = fa[x];
        fa[x] = r;
        x = next;
    }
    return r;
}

// Stores vertex y under label x in the tree rooted at p, allocating nodes on
// demand over the label range [pl, pr]. If the leaf already holds a vertex,
// that vertex and y are queued for union, and y overwrites it.
void ins(int &p, int pl, int pr, int x, int y) {
    if (!p) p = ++cnt;
    if (pl == pr) {
        if (num[p]) que.push(pi(num[p], y));
        num[p] = y;
        return;
    }
    const int mid = (pl + pr) >> 1;
    if (x <= mid) ins(ls[p], pl, mid, x, y);
    else ins(rs[p], mid + 1, pr, x, y);
}

// Merges tree v into tree u over the label range [pl, pr] and returns the
// merged root. Where both trees hold a leaf for the same label, the two stored
// vertices are queued for union and u's leaf is kept. The two vertices are
// about to be unioned, so either leaf is a valid representative.
int merge(int u, int v, int pl, int pr) {
    if (!u || !v) return u ? u : v;
    if (pl == pr) {
        que.push(pi(num[u], num[v]));
        return u;
    }
    const int mid = (pl + pr) >> 1;
    ls[u] = merge(ls[u], ls[v], pl, mid);
    rs[u] = merge(rs[u], rs[v], mid + 1, pr);
    return u;
}

int main() {
    n = read(), m = read(), k = read();
    for (int i = 1; i <= n; i++) fa[i] = i;

    // Record each triple (u, v, w) as "u under label w" in v's table.
    // Two sources sharing a label collide and are queued inside ins().
    for (int i = 1; i <= m; i++) {
        const int u = read(), v = read(), w = read();
        ins(root[v], 1, k, w, u);
    }

    // Union queued pairs until a fixed point is reached.
    while (!que.empty()) {
        int u = que.front().first, v = que.front().second;
        que.pop();
        u = find(u), v = find(v);
        if (u == v) continue;
        fa[v] = u;
        // Merging the classes' tables may queue further collisions.
        root[u] = merge(root[u], root[v], 1, k);
    }

    // Each vertex pairs with every earlier vertex already counted in its
    // class, so the total is the sum of C(size, 2) over the classes.
    for (int i = 1; i <= n; i++) {
        const int p = find(i);
        ans += siz[p];
        siz[p]++;
    }
    std::printf("%lld\n", ans);
    return 0;
}