Description

题目链接

题目大意:$m$ 次询问,每次给定 $l, r, z$,求 $\sum \limits_{i = l}^r \text{dep}(\text{LCA}(i, z))$。$n, m \le 5 \times 10^4$。

Solution

一个很自然的想法是计算每个点作为 $\text{LCA}$ 的贡献。

注意到 $\text{dep}(u)$ 相当于将根到 $u$ 的路径上的点加 $1$,然后查询从 $u$ 到根的点权和。

由于 $\text{LCA}(i, z)$ 一定在 $z$ 到根的路径上,我们可以得出一个暴力算法:枚举 $i \in [l, r]$,将根到 $i$ 的路径上的点加 $1$,查询的答案即为 $z$ 到根的路径上的点权和。

然而这样做的时间并不优,因为无法避免其他点对当前询问的影响,每次都需要清空线段树。

考虑离线后树上差分,将询问拆成 $[1, r]$ 的贡献减去 $[1, l - 1]$ 的贡献,然后按右端点排序,这样就解决了上面的问题。

接下来从 $1$ 到 $n$ 枚举点 $i$,同时维护差分后的询问编号 $j$,每次查询即可。

Code

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>

template <class T>
inline void read(T &x) {
    x = 0;
    int f = 0;
    char ch = getchar();
    while (!isdigit(ch))    { f |= ch == '-'; ch = getchar(); }
    while (isdigit(ch))     { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
    x = f ? -x : x;
    return ;
}

typedef long long LL;
typedef unsigned long long uLL;

const int mod = 201314;

struct Query {
    int l, r, id;
    Query() : l(), r(), id() {}
    Query(const int &il, const int &ir, const int &iid) : l(il), r(ir), id(iid) {}
    friend bool operator < (const Query &a, const Query &b) {
        return a.l < b.l;
    }
} a[100010];

struct Node {
    int l, r, v, tag;
} t[200010];

std::vector<int> g[50010];
int dfn[50010], siz[50010], son[50010], top[50010], fa[50010], dep[50010], ans[50010];
int n, q, cnt, tot;

inline int lson(int x) { return x << 1; }
inline int rson(int x) { return x << 1 | 1; }

void dfs1(int x, int p) {
    fa[x] = p;
    siz[x] = 1;
    dep[x] = dep[p] + 1;
    for (auto i : g[x]) {
        if (i != p) {
            dfs1(i, x);
            if (siz[i] > siz[son[x]])    son[x] = i;
        }
    }
}

void dfs2(int x, int p) {
    dfn[x] = ++cnt;
    top[x] = p;
    if (son[x]) {
        dfs2(son[x], p);
        for (auto i : g[x]) {
            if (i != fa[x] && i != son[x]) {
                dfs2(i, i);
            }
        }
    }
}

void buildTree(int x, int l, int r) {
    t[x].l = l, t[x].r = r;
    if (l == r)    return ;
    int mid = (l + r) >> 1;
    buildTree(lson(x), l, mid), buildTree(rson(x), mid + 1, r);
}

void pushDown(int x) {
    if (t[x].tag) {
        t[lson(x)].tag += t[x].tag, t[rson(x)].tag += t[x].tag;
        t[lson(x)].v += t[x].tag * (t[lson(x)].r - t[lson(x)].l + 1), t[rson(x)].v += t[x].tag * (t[rson(x)].r - t[rson(x)].l + 1);
        t[lson(x)].v %= mod, t[rson(x)].v %= mod;
        t[x].tag = 0;
    }
}

void modify(int x, int l, int r, int d) {
    if (l <= t[x].l && t[x].r <= r) {
        t[x].v += (t[x].r - t[x].l + 1) * d;
        t[x].v %= mod;
        t[x].tag += d;
        return ;
    }
    pushDown(x);
    int mid = (t[x].l + t[x].r) >> 1;
    if (l <= mid)    modify(lson(x), l, r, d);
    if (r > mid)     modify(rson(x), l, r, d);
    t[x].v = t[lson(x)].v + t[rson(x)].v;
    t[x].v %= mod;
}

int query(int x, int l, int r) {
    if (l <= t[x].l && t[x].r <= r)    return t[x].v;
    pushDown(x);
    int mid = (t[x].l + t[x].r) >> 1, s = 0;
    if (l <= mid)    s += query(lson(x), l, r);
    if (r > mid)     s += query(rson(x), l, r);
    return s % mod;
}

void update(int u) {
    while (u) {
        modify(1, dfn[top[u]], dfn[u], 1);
        u = fa[top[u]];
    }
}

int querySum(int u) {
    int s = 0;
    while (u) {
        s += query(1, dfn[top[u]], dfn[u]);
        u = fa[top[u]];
    }
    return s;
}

int main() {
    read(n), read(q);
    for (int i = 2, f; i <= n; ++i) {
        read(f);
        g[i].push_back(f + 1);
        g[f + 1].push_back(i);
    }
    buildTree(1, 1, n);
    dfs1(1, 0);
    dfs2(1, 1);
    for (int i = 1, l, r, z; i <= q; ++i) {
        read(l), read(r), read(z);
        ++l, ++r, ++z;
        a[++tot] = Query(r, z, i * 2 - 1), a[++tot] = Query(l - 1, z, i * 2);
    }
    int now = 1;
    std::sort(a + 1, a + q + q + 1);
    while (a[now].l == 0 && now <= tot) {
        int d = querySum(a[now].r);
        if (a[now].id & 1)    ans[(a[now].id + 1) >> 1] += d;
        else    ans[(a[now].id + 1) >> 1] -= d;
        ++now;
    }
    for (int i = 1; i <= n; ++i) {
        update(i);
        while (a[now].l == i && now <= tot) {
            int d = querySum(a[now].r);
            if (a[now].id & 1)    ans[(a[now].id + 1) >> 1] += d;
            else    ans[(a[now].id + 1) >> 1] -= d;
            ++now;
        }
    }
    for (int i = 1; i <= q; ++i)    printf("%d\n", (ans[i] % mod + mod) % mod);
    return 0;
}
Last modification:November 9th, 2021 at 01:13 am