[国家集训队] 聪聪可可 解题记录
前言
看到题解区全是用容斥做的,但是我太蒻了不会,所以来水一发不用容斥的题解。
题意简述
给定一棵树,边有边权,任意选择一条路径,求这条路径的长度是 3 3 3 的倍数的概率。
题目分析
易知,概率即为 合法的路径 所有路径 \frac{合法的路径}{所有路径} 所有路径合法的路径。
关键在于合法路径怎么去求。
使用点分治算法,我们枚举重心 r t rt rt,对于它的每个儿子,用 d i s t i dist_i disti 表示结点 i i i 在模 3 3 3 意义下距离根的长度;用 c n t 0 / 1 / 2 cnt_{0/1/2} cnt0/1/2 表示在遍历当前结点之前,距离根的长度为 0 / 1 / 2 0/1/2 0/1/2 的链的数量。
接下来分类讨论:
- 如果 d i s t i = 0 dist_i = 0 disti=0,那么它一定能与之前所有距离根长度为 0 0 0 的链组合成一条合法路径,合法数量加上 c n t 0 × 2 + 2 cnt_0\times 2+2 cnt0×2+2, + 2 +2 +2 是因为它本身到根也可以组成两个数对。
- 如果 d i s t i = 1 dist_i = 1 disti=1,那么它一定能与之前所有距离根长度为 2 2 2 的链组合成一条合法路径,合法数量加上 c n t 2 × 2 cnt_2\times 2 cnt2×2 。
- 如果 d i s t i = 2 dist_i = 2 disti=2,那么它一定能与之前所有距离根长度为 2 2 2 的链组合成一条合法路径,合法数量加上 c n t 1 × 2 cnt_1\times 2 cnt1×2 。
在dfs
函数内部,我们枚举当前结点 x x x 的儿子,先更新 d i s t dist dist 数组,然后更新答案,最后才更新 c n t cnt cnt,因为 c n t cnt cnt 保存的是当前儿子之前的数据。
AC Code
// Problem: P2634 [国家集训队] 聪聪可可
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P2634
// Memory Limit: 125 MB
// Time Limit: 1000 ms
// Author: Li_Feiy
#include<bits/stdc++.h>
#define arrout(a,n) rep(i,1,n)std::cout<<a[i]<<" "
#define arrin(a,n) rep(i,1,n)a[i]=read()
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define dep(i,x,n) for(int i=x;i>=n;i--)
#define erg(i,x) for(int i=head[x];i;i=e[i].nex)
#define dbg(x) std::cout<<#x<<":"<<x<<" "
#define mem(a,x) memset(a,x,sizeof a)
#define all(x) x.begin(),x.end()
#define arrall(a,n) a+1,a+1+n
#define PII std::pair<int,int>
#define m_p std::make_pair
#define u_b upper_bound
#define l_b lower_bound
#define p_b push_back
#define CD const double
#define CI const int
#define int long long
#define il inline
#define ss second
#define ff first
#define itn int
int read() {
char ch=getchar();
int r=0,w=1;
while(ch<'0'||ch>'9') w=ch=='-'?-1:w,ch=getchar();
while(ch>='0'&&ch<='9') r=r*10+ch-'0',ch=getchar();
return r*w;
}
CI N=2e4+5,INF=1e8+5;
int n,fz,fm,rt,tot,sum,cnt[N],max[N],dis[N],dist[N],size[N],head[N];
bool tf[INF],vis[N];
std::queue<int> t;
struct edge {
int to,nex,data;
}e[N<<1];
void add(int x,int y,int z) {
e[++tot].to=y;
e[tot].data=z;
e[tot].nex=head[x];
head[x]=tot;
}
void update_size(int x,int fa) {
size[x]=1;
max[x]=0;
erg(i,x) {
int y=e[i].to;
if(y==fa||vis[y]) continue;
update_size(y,x);
size[x]+=size[y];
max[x]=std::max(max[x],size[y]);
}
max[x]=std::max(max[x],sum-size[x]);
if(max[x]<max[rt]) rt=x;
}
void update_dist(int x,int fa) {
dis[++dis[0]]=dist[x]%3;
erg(i,x) {
int y=e[i].to,z=e[i].data;
if(y==fa||vis[y]) continue;
dist[y]=(dist[x]+z)%3;
update_dist(y,x);
}
}
void update_cnt(int x,int fa) {
cnt[dist[x]%3]++;
erg(i,x) {
int y=e[i].to;
if(y==fa||vis[y]) continue;
update_cnt(y,x);
}
}
void dfs(int x,int fa) {
vis[x]=1;
erg(i,x) {
int y=e[i].to,z=e[i].data;
if(y==fa||vis[y]) continue;
dist[y]=z%3;
update_dist(y,x);
rep(i,1,dis[0]) {
switch(dis[i]) {
case 0:
fz+=cnt[0]*2+2;
break;
case 1:
fz+=cnt[2]*2;
break;
case 2:
fz+=cnt[1]*2;
break;
}
}
dis[0]=0;
update_cnt(y,x);
}
cnt[0]=cnt[1]=cnt[2]=0;
erg(i,x) {
int y=e[i].to,z=e[i].data;
if(y==fa||vis[y]) continue;
sum=size[y];
rt=0;
max[rt]=INF;
update_size(y,x);
update_size(rt,-1);
dfs(rt,x);
}
}
signed main() {
n=read();
rep(i,1,n-1) {
int x=read(),y=read(),z=read();
add(x,y,z);
add(y,x,z);
}
sum=fz=n;
rt=0;
max[rt]=INF;
update_size(1,-1);
update_size(rt,-1);
dfs(rt,-1);
fm=n*n;
while(std::__gcd(fz,fm)!=1) {
int gcd=std::__gcd(fz,fm);
fz/=gcd,fm/=gcd;
}
printf("%lld/%lld",fz,fm);
return 0;
}