树上的边题解

题面

给一棵 $n$ 个点的树,定义 $f(l,r)$ 为:

$\forall i \in [l,r]$,$j \in [l,r]$,都存在 $i \to j$ 的路径时,需要选择的最少树边数量。

现在你要求出:

$$ \sum \limits _{l=1}^n \sum \limits _{r=l}^n f(l,r) $$

题解

枚举 $l$ 或者 $r$ 是没有前途的,时间复杂度的下界是 $O(n^2)$,并且无法在左端点移动时更新整体的答案,没有进一步优化的空间。

因此换一个思路,考虑统计每条边的贡献。

一条 $u \to v$ 的边,将原树分成两棵子树。该边会对答案产生贡献,当且仅当:

$\exists i \in [l,r],j \in [l,r]$,满足 $i \in \operatorname{subtree}(u),j \in \operatorname{subtree}(v)$.

对于一个编号序列 $1,2,3,\ldots,n$,每个编号只有两种状态:$i \in \operatorname{subtree}(u)$ 或 $i \in \operatorname{subtree}(v)$.

选择的一个连续编号区间 $[l,r]$,钦定的边会产生贡献,当且仅当该区间内同时包含 $i \in \operatorname{subtree}(u)$ 与 $j \in \operatorname{subtree}(v)$.

那么有一个自然的想法:维护每一段连续的相同状态的编号,即维护若干个线段 $[L,R]$ 满足 $\forall i \in [L,R], i \in \operatorname{subtree}(u \text{ or } v)$,那么对于每一条线段,只要 $l < L,r \ge R$ 或 $l \le L,r > R$ 就能满足条件。

然而不难发现,这样会算重很多贡献,容斥不太可做的样子。

正难则反,统计不合法的情况数,然后用总情况数减去不合法情况统计合法情况数。

一个连续编号区间 $[l,r]$ 不合法,当且仅当 $\forall i \in [l,r]$,$i \in \operatorname{subtree}(u \text{ or } v)$,可以发现当且仅当 $[l,r]$ 是求出的线段 $[L,R]$ 的连续子段,它就是不合法的。

那么对于每一个连续相同状态编号区间,可以计算出它对应的不合法情况数:$\dbinom{R - L + 1}{2}$.

而对于一条边,它不能产生贡献的总情况数是:$\sum \dbinom{R - L + 1}{2}$,总情况数是 $\dbinom{n}{2}$,能产生的贡献就是 $\dbinom{n}{2} - \sum \dbinom{R - L + 1}{2}$.

现在只剩下一个问题:如何维护连续相同状态编号区间及其对应贡献?

这个做法多多,但核心大同小异:加入一个点最多改变一条线段的贡献。

例如有一种 dsu on tree 的做法,用 set 来暴力维护线段,当加入一个点时,找到对应线段更新贡献,时间复杂度 $O(n \log^2 n)$.

当然可以使用线段树来维护线段,这就是一个维护连续子段的变体。

假如当前搜索到 $u$ 点,定义在 $u$ 点子树内状态为 $1$,在 $u$ 点子树外状态为 $0$,以编号为下标,每个子树就有一个唯一对应的 $01$ 串。

在这个 $01$ 串上,用线段树维护连续子段产生的贡献。唯一的难点就是上推操作。

对于左半线段与右半线段的上推合并,仅有左半线段的右端点与右半线段的左端点拼接可能产生的贡献。因此记录左右端点开始的最长连续子段,记录左右端点的值,然后更新贡献即可。

可以使用线段树合并,每个点仅会产生一次插入操作,时间复杂度 $O(n \log n)$,使用优先重儿子搜索、垃圾回收的空间优化方法,可以比较轻松地通过。

代码

不算难码,细节也不多。为了实现方便,牺牲了一定的常数。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
#include <cstdio>
#include <algorithm>
#include <ctype.h>
const int bufSize = 1e6;
inline char nc()
{
    #ifdef DEBUG
    return getchar();
    #endif
    static char buf[bufSize], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, bufSize, stdin), p1 == p2) ? EOF : *p1++;
}
template<typename T>
inline T read(T &r)
{
    static char c;
    r = 0;
    for (c = nc(); !isdigit(c); c = nc());
    for (; isdigit(c); c = nc()) r = r * 10 + c - 48;
    return r;
}
const int maxn = 1e5 +100;
struct node
{
    int to, next;
} E[maxn << 1];
int head[maxn], tot;
inline void add(const int& x, const int& y) { E[++tot].next = head[x], E[tot].to = y, head[x] = tot; }
struct segnode
{
    int llen, rlen, lval, rval, ls, rs;
    long long val;
    bool isEmpty() const { return llen == 0; }
} A[maxn << 2];
int n, root[maxn],s[maxn << 2], top, ind, size[maxn], son[maxn];
inline int newnode() { return top > 0 ? s[top--] : ++ind; }
inline void delnode(int x) { if (x) A[x] = segnode(), s[++top] = x; }
inline long long getx(int x) { return 1ll * x * (x - 1) >> 1; }
inline segnode makezero(int l, int r)
{
    segnode res;
    res.llen = res.rlen = (r - l + 1), res.lval = res.rval = 0;
    res.val = getx(res.llen);
    return res;
}
inline void pushup(int l, int r, segnode& res, segnode ls, segnode rs)
{
    int mid = (l + r) >> 1;
    if (ls.isEmpty()) ls = makezero(l, mid);
    if (rs.isEmpty()) rs = makezero(mid + 1, r);
    res.llen = ls.llen, res.rlen = rs.rlen, res.lval = ls.lval, res.rval = rs.rval;
    res.val = ls.val + rs.val;
    if (ls.rval == rs.lval)
    {
        res.val -= getx(ls.rlen) + getx(rs.llen);
        res.val += getx(ls.rlen + rs.llen);
        if (ls.llen == (mid - l + 1)) res.llen += rs.llen;
        if (rs.rlen == (r - mid)) res.rlen += ls.rlen;
    }
}
void modify(int l, int r, int& p, int pos, int k)
{
    if(!p) p = newnode();
    if (l == r) return (void)(A[p].llen = A[p].rlen = 1, A[p].lval = A[p].rval = k, A[p].val = 0);
    int mid = (l + r) >> 1;
    if (pos <= mid) modify(l, mid, A[p].ls, pos, k);
    else modify(mid + 1, r, A[p].rs, pos, k);
    pushup(l, r, A[p], A[A[p].ls], A[A[p].rs]);
}
void merge(int l, int r, int& x, int y)
{
    if (!x || !y) return (void)(x += y);
    int mid = (l + r) >> 1;
    merge(l, mid, A[x].ls, A[y].ls), merge(mid + 1, r, A[x].rs, A[y].rs), delnode(y);
    pushup(l, r, A[x], A[A[x].ls], A[A[x].rs]);
}
long long res;
void dfs1(int u, int fa)
{
    size[u] = 1;
    for (int p = head[u]; p; p = E[p].next)
    {
        int v = E[p].to;
        if (v == fa) continue;
        dfs1(v, u), size[u] += size[v];
        if (size[v] > size[son[u]]) son[u] = v;
    }
}
void dfs2(int u, int fa)
{
    if (son[u]) dfs2(son[u], u), merge(1, n, root[u], root[son[u]]);
    for (int p = head[u]; p; p = E[p].next)
    {
        int v = E[p].to;
        if (v == fa || v == son[u]) continue;
        dfs2(v, u), merge(1, n, root[u], root[v]);
    }
    modify(1, n, root[u], u, 1);
    if (fa) res += getx(n) - A[root[u]].val;
}
int main()
{
    read(n);
    for (int i = 1, a, b; i < n; ++i) read(a), read(b), add(a, b), add(b, a);
    dfs1(1, 0), dfs2(1, 0);
    printf("%lld\n", res);
    return 0;
}