Problem
You are given a tree with $n$ nodes rooted at node $1$, and $m$ available colors; every node must be assigned one color.
Find the number of essentially different colorings, modulo $998244353$.
Two colorings are considered essentially the same if the resulting colored trees are isomorphic as rooted trees: node labels are ignored, the root must map to the root, and corresponding nodes must have the same color.
Constraints: $n \le 500$.
Solution
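First, the counting part. Let $f(u)$ be the number of essentially different colorings of the subtree of $u$.
Group the children of $u$ into isomorphism classes. Suppose a class contains $k$ children, and each of their subtrees has $f$ essentially different colorings. Only the multiset of colorings chosen for those children matters, so the class contributes $\binom{f+k-1}{k}$.
Hence $f(u) = m\prod_{\text{classes}}\binom{f+k-1}{k}$, and the answer is $f(1)$. The code builds each binomial incrementally: the $(c+1)$-th child of a class multiplies the running product by $\frac{f+c}{c+1}$.
(Counting problems are not my strong suit; I'll expand this section once I learn more about them.)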
Next, the other key point of this problem: how do we determine whether two subtrees are isomorphic?
Tree Hashing
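Tree hashing computes a value from the structure of a subtree, so that isomorphic subtrees receive equal values. In the code below, a node's hash combines its subtree size with a product over its children's hashes, so the order of the children does not matter.
Distinct subtrees can occasionally collide. To reduce that risk, one can compute several different hash functions and require all of them to match.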
Code
#include<bits/stdc++.h>
// Every `int` below (and in the rest of the file) becomes `long long`;
// this is why main is declared as `signed main`.
#define int long long
const int mod = 998244353;  // modulus for the answer
const int N = 510;          // array capacity (n <= 500 plus slack)
using namespace std;
int n, m;  // n: number of nodes, m: number of colors
// Forward-star adjacency list: head[x] is the most recently added edge out of
// x, nxt[e] is the next edge in the same list, to[e] is the edge's endpoint.
// Edge ids start at 2 because cnt is pre-incremented from 1; 0 ends a list.
int head[N], nxt[N << 1], to[N << 1], cnt = 1;
int siz[N];                 // siz[u]: number of nodes in u's subtree (filled by dfs1)
unsigned long long hsh[N];  // hsh[u]: structural hash of u's subtree (filled by dfs2)
int sum[N];                 // sum[u]: distinct colorings of u's subtree (filled by dfs3)
// Insert the directed edge x -> y at the front of x's adjacency list.
void add(int x, int y)
{
    const int e = ++cnt;  // fresh edge id
    to[e] = y;
    nxt[e] = head[x];
    head[x] = e;
}
// Accumulate subtree sizes: after the call, siz[u] equals its initial value
// (main sets it to 1) plus the sizes of all child subtrees of u, where `fa`
// is u's parent and is skipped.
void dfs1(int u, int fa)
{
    for (int e = head[u]; e != 0; e = nxt[e])
    {
        const int child = to[e];
        if (child != fa)
        {
            dfs1(child, u);
            siz[u] += siz[child];
        }
    }
}
// Compute hsh[u], a structural hash of the subtree rooted at u (`fa` is the
// parent and is skipped). The hash starts from the subtree size and is
// multiplied by a term derived from each child's hash. Multiplication is
// commutative, so the result does not depend on the order in which children
// are visited. All arithmetic is unsigned and wraps modulo 2^64.
void dfs2(int u, int fa)
{
    unsigned long long acc = siz[u] + 10;
    for (int e = head[u]; e; e = nxt[e])
    {
        const int child = to[e];
        if (child == fa) continue;
        dfs2(child, u);
        const unsigned long long t = hsh[child] * 114 + 4869;
        acc *= t * t - 100;
    }
    hsh[u] = acc - 4869;
}
// Binary (fast) modular exponentiation: returns a^n mod `mod` for n >= 0.
// Called with n = mod - 2 it yields the modular inverse of a by Fermat's
// little theorem (998244353 is prime), provided a is not a multiple of mod.
int ksm(int a, int n)
{
    // Reduce the base first so that a * a stays below mod^2 (fits in 64 bits)
    // even when the caller passes a value >= mod, and normalise negatives.
    a %= mod;
    if (a < 0) a += mod;
    if (a == 1) return 1;
    int ans = 1;
    // `n > 0` (rather than `n != 0`) guarantees termination: for a negative n,
    // `n >>= 1` would stay at -1 forever.
    for (; n > 0; n >>= 1)
    {
        if (n & 1) ans = ans * a % mod;
        a = a * a % mod;
    }
    return ans;
}
// Count the essentially different colorings of the subtree rooted at u and
// store the count, reduced modulo `mod`, in sum[u]. `fa` is u's parent.
//
// Children with equal hashes (hsh[]) are treated as interchangeable. If k
// children of one class each have f distinct colorings, the class contributes
// C(f + k - 1, k), the number of multisets of size k drawn from f options.
// This is built incrementally: the (c+1)-th child of a class multiplies the
// running product by (f + c) / (c + 1). The division is done with the
// modular inverse from ksm.
void dfs3(int u, int fa)
{
    // Reduce m up front. Otherwise a leaf's value (and the printed answer
    // when n == 1) stays unreduced, and the product below can overflow 64
    // bits for large m.
    sum[u] = m % mod;
    // Maps a child-subtree hash to how many children of that class have been
    // folded in so far.
    std::unordered_map<unsigned long long, int> seen;
    for (int e = head[u]; e; e = nxt[e])
    {
        const int v = to[e];
        if (v == fa) continue;
        dfs3(v, u);  // fills sum[v]; touches only v's subtree, not sum[u]
        int &c = seen[hsh[v]];
        // Every operand is < mod, so each product is < mod^2 and fits in 64 bits.
        sum[u] = sum[u] * ((sum[v] + c) % mod) % mod * ksm(c + 1, mod - 2) % mod;
        ++c;
    }
}
// Read n, m and the n - 1 undirected edges. Then compute subtree sizes,
// subtree hashes and coloring counts in three passes from root 1, and print
// the count for the whole tree.
signed main()
{
    scanf("%lld%lld", &n, &m);
    std::fill_n(siz + 1, n, 1);  // every node initially counts only itself
    for (int i = 1; i < n; ++i)
    {
        int u, v;
        std::cin >> u >> v;
        add(u, v);
        add(v, u);
    }
    dfs1(1, 0);
    dfs2(1, 0);
    dfs3(1, 0);
    printf("%lld\n", sum[1]);
    return 0;
}